In [49]:
import scanpy as sc
from cellflow.model import CellFlow
import requests
import pandas as pd
import anndata
import random
import torch
from esm import pretrained

In [50]:
class ESMConverter:
  def __init__(self, model:str):
    self.model, self.alphabet = pretrained.load_model_and_alphabet(model)
    self.batch_converter = self.alphabet.get_batch_converter()

  def convert(self, sequences):
    batch_labels, batch_strs, batch_tokens = self.batch_converter(sequences)
    with torch.no_grad():
      token_embeddings = self.model(batch_tokens, repr_layers=[33])
      embeddings = token_embeddings['representations'][33]
      average_embeddings = embeddings.mean(dim=1)
    return average_embeddings

In [51]:
from UniProtMapper import ProtMapper
def get_protein_sequence_by_gene(gene_name):
    mapper = ProtMapper()
    result, failed = mapper.get(
        ids=gene_name, from_db="Gene_Name", to_db="UniProtKB"
    )
    result = result[(result['Organism'] == "Homo sapiens (Human)")&(result['Reviewed'] == "reviewed")]
    protein = result.iloc[0]["Entry"]
    # print(protein)
    # Define the UniProt API endpoint
    sequence_url = f"https://www.uniprot.org/uniprot/{protein}.fasta"
    sequence_response = requests.get(sequence_url)
        
    if sequence_response.status_code == 200:
        # Extract and return the protein sequence
        sequence = ''.join(sequence_response.text.splitlines()[1:])
        return sequence
    else:
        return "NONE"

# Example usage
# gene_name = "PTPRC"  # Replace with your gene name
# protein_sequence = get_protein_sequence_by_gene(gene_name)
# print(f"Protein Sequence for {gene_name}:\n{protein_sequence}")

In [52]:
# filePath = "data/vcc_data/adata_Training.h5ad"
filePath = "../vcc_sample.h5ad"

In [53]:
adata = sc.read_h5ad(filePath)

In [54]:
adata

AnnData object with n_obs × n_vars = 88509 × 18080
    obs: 'target_gene', 'guide_id', 'batch', 'control'
    var: 'gene_id'

In [55]:
adata.obs['control'] = [(lambda x: True if x == "non-targeting" else False)(x) for x in adata.obs['target_gene']]

In [None]:
# Split train_test data
x = adata[adata.obs['control'] == True]
y = adata[adata.obs['control'] == False]
# For runability test, sample little data
x_t = sc.pp.sample(x, n = 1000, copy = True)
y_t = sc.pp.sample(y, n = 5000, copy = True)
x_train = x_t[:500, :]
y_train = y_t[:2500, :]
x_eval = x_t[500:, :]
y_eval = y_t[2500:, :]

train = anndata.concat([x_train, y_train])
eval = anndata.concat([x_eval, y_eval])
x_eval.obs['target_gene'] = random.sample(list(y_eval.obs['target_gene']), x_eval.n_obs)

  x_eval.obs['target_gene'] = random.sample(list(y_eval.obs['target_gene']), x_eval.n_obs)


In [56]:
# Parameters for preparing data
sample_rep = "X"
control_key = "control"
perturbation_covariates = {"gene": ("target_gene",)}
split_covariates = ["batch"]
perturbation_covariate_reps = {"gene": "gene_embedding"}
sample_covariates = None
sample_covariate_reps = None

In [57]:
# If embedding is ready, load it
import pickle
import torch
embedding = pickle.load(open("subsample_gene_embedding.pkl", "rb"))

In [None]:
# If not, prepare gene embeddings
# Sort out target genes
genes = adata.obs[adata.obs['control'] == False]['target_gene'].to_list()
genes = list(set(genes))

embedding = pd.DataFrame(columns=["gene", "protein", "embedding"])
embedding["gene"] = genes
embedding.index = genes
embedding["protein"] = embedding['gene'].apply(get_protein_sequence_by_gene)

In [None]:
converter = ESMConverter("esm2_t33_650M_UR50D")
sequences = list(zip(embedding['gene'], embedding['protein']))
em = []
for s in sequences:
  em.append(converter.convert([s]))
embedding['embedding'] = em
# Save embedding to pickle
pd.to_pickle(embedding,"subsample_gene_embedding.pkl")

In [58]:
embedding['embedding'] = embedding['embedding'].apply(torch.flatten)
adata.uns['gene_embedding'] = {}
for g in embedding['gene']:
    adata.uns['gene_embedding'][g] = embedding.loc[g]['embedding']

In [None]:
train.uns = adata.uns
eval.uns = adata.uns

In [69]:
cf = CellFlow(train)

In [70]:
cf.prepare_data(
    sample_rep = sample_rep,
    control_key = control_key,
    perturbation_covariates = perturbation_covariates,
    perturbation_covariate_reps = perturbation_covariate_reps,
    split_covariates = split_covariates,
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 43/43 [00:00<00:00, 590.64it/s]
100%|██████████| 46/46 [00:00<00:00, 1772.85it/s]
100%|██████████| 35/35 [00:00<00:00, 1690.18it/s]
100%|██████████| 37/37 [00:00<00:00, 1685.72it/s]
100%|██████████| 45/45 [00:00<00:00, 1949.65it/s]
100%|██████████| 42/42 [00:00<00:00, 1644.20it/s]
100%|██████████| 39/39 [00:00<00:00, 1654.56it/s]
100%|██████████| 41/41 [00:00<00:00, 1726.07it/s]
100%|██████████| 42/42 [00:00<00:00, 1754.68it/s]
100%|██████████| 49/49 [00:00<00:00, 1452.93it/s]
100%|██████████| 41/41 [00:00<00:00, 1689.24it/s]
100%|██████████| 34/34 [00:00<00:00, 1649.85it/s]
100%|██████████| 44/44 [00:00<00:00, 1722.86it/s]
100%|██████████| 49/49 [00:00<00

In [71]:
cf.prepare_model()

In [72]:
cf.prepare_validation_data(y_eval, name = "test")

  adata.obs[self._control_key] = adata.obs[self._control_key].astype("boolean")


ValueError: No control cells found in adata.

In [16]:
cf.train(num_iterations=10, batch_size = 512)

100%|██████████| 10/10 [01:23<00:00,  8.31s/it]


KeyboardInterrupt: 

In [None]:
# cf.predict(test, covariate_data=, sample_rep = "X", key_added_prefix = "pred_")