In [17]:
import scanpy as sc
from cellflow.model import CellFlow
import requests
import pandas as pd
import anndata
import random
import torch
from esm import pretrained

In [18]:
class ESMConverter:
  def __init__(self, model:str):
    self.model, self.alphabet = pretrained.load_model_and_alphabet(model)
    self.batch_converter = self.alphabet.get_batch_converter()

  def convert(self, sequences):
    batch_labels, batch_strs, batch_tokens = self.batch_converter(sequences)
    with torch.no_grad():
      token_embeddings = self.model(batch_tokens, repr_layers=[33])
      embeddings = token_embeddings['representations'][33]
      average_embeddings = embeddings.mean(dim=1)
    return average_embeddings

In [19]:
from UniProtMapper import ProtMapper
def get_protein_sequence_by_gene(gene_name):
    mapper = ProtMapper()
    result, failed = mapper.get(
        ids=gene_name, from_db="Gene_Name", to_db="UniProtKB"
    )
    result = result[(result['Organism'] == "Homo sapiens (Human)")&(result['Reviewed'] == "reviewed")]
    protein = result.iloc[0]["Entry"]
    # print(protein)
    # Define the UniProt API endpoint
    sequence_url = f"https://www.uniprot.org/uniprot/{protein}.fasta"
    sequence_response = requests.get(sequence_url)
        
    if sequence_response.status_code == 200:
        # Extract and return the protein sequence
        sequence = ''.join(sequence_response.text.splitlines()[1:])
        return sequence
    else:
        return "NONE"

# Example usage
# gene_name = "PTPRC"  # Replace with your gene name
# protein_sequence = get_protein_sequence_by_gene(gene_name)
# print(f"Protein Sequence for {gene_name}:\n{protein_sequence}")

In [20]:
# filePath = "data/vcc_data/adata_Training.h5ad"
filePath = "../vcc_sample.h5ad"

In [21]:
adata = sc.read_h5ad(filePath)

In [None]:
adata.write_h5ad("../vcc_sample.h5ad")

{'RNF20': tensor([ 0.0104, -0.0884, -0.0148,  ..., -0.1244,  0.1133,  0.1652]),
 'TMSB10': tensor([-0.0109, -0.0171, -0.0571,  ..., -0.0265, -0.1325,  0.0159]),
 'ANTXR1': tensor([-0.0367, -0.0803, -0.1321,  ..., -0.2080,  0.0250,  0.1001]),
 'HIRA': tensor([-0.0164, -0.0753, -0.0067,  ..., -0.1768,  0.0007,  0.1631]),
 'ATP6V0C': tensor([ 0.0886, -0.1398, -0.0269,  ..., -0.3051, -0.0488,  0.2683]),
 'SUPT4H1': tensor([ 0.0669, -0.0283,  0.0173,  ..., -0.2693, -0.0681,  0.1898]),
 'DHCR24': tensor([ 0.0492, -0.0253,  0.0156,  ..., -0.1282, -0.0437,  0.0311]),
 'OXCT1': tensor([[ 0.0332, -0.0064, -0.0084,  ..., -0.0846, -0.0035,  0.0739]]),
 'SMARCA4': tensor([-0.0085, -0.0615, -0.0069,  ..., -0.0540,  0.0795,  0.0830]),
 'CASP2': tensor([ 0.0352, -0.0567, -0.0319,  ..., -0.1854,  0.0906,  0.1119]),
 'ACAT2': tensor([ 0.0030, -0.0700,  0.0587,  ..., -0.0953,  0.0767,  0.0887]),
 'ETV4': tensor([ 0.0121, -0.0638, -0.0065,  ..., -0.0799,  0.0840,  0.0173]),
 'PHF10': tensor([ 0.0099, -0.0

In [23]:
adata.obs['control'] = [(lambda x: True if x == "non-targeting" else False)(x) for x in adata.obs['target_gene']]

In [24]:
# Split train_test data
x = adata[adata.obs['control'] == True]
y = adata[adata.obs['control'] == False]
# For runability test, sample little data
x_t = sc.pp.sample(x, n = 1000, copy = True)
y_t = sc.pp.sample(y, n = 5000, copy = True)
x_train = x_t[:500, :]
y_train = y_t[:2500, :]
x_eval = x_t[500:, :]
y_eval = y_t[2500:, :]

train = anndata.concat([x_train, y_train])
eval = anndata.concat([x_eval, y_eval])
x_eval.obs['target_gene'] = random.sample(list(y_eval.obs['target_gene']), x_eval.n_obs)

  x_eval.obs['target_gene'] = random.sample(list(y_eval.obs['target_gene']), x_eval.n_obs)


In [25]:
# Parameters for preparing data
sample_rep = "X"
control_key = "control"
perturbation_covariates = {"gene": ("target_gene",)}
split_covariates = ["batch"]
perturbation_covariate_reps = {"gene": "gene_embedding"}
sample_covariates = None
sample_covariate_reps = None

In [26]:
# If embedding is ready, load it
import pickle
import torch
embedding = pickle.load(open("subsample_gene_embedding.pkl", "rb"))

In [None]:
# If not, prepare gene embeddings
# Sort out target genes
genes = adata.obs[adata.obs['control'] == False]['target_gene'].to_list()
genes = list(set(genes))

embedding = pd.DataFrame(columns=["gene", "protein", "embedding"])
embedding["gene"] = genes
embedding.index = genes
embedding["protein"] = embedding['gene'].apply(get_protein_sequence_by_gene)

In [None]:
converter = ESMConverter("esm2_t33_650M_UR50D")
sequences = list(zip(embedding['gene'], embedding['protein']))
em = []
for s in sequences:
  em.append(converter.convert([s]))
embedding['embedding'] = em
embedding['embedding'] = embedding['embedding'].apply(torch.flatten)
# Save embedding to pickle
pd.to_pickle(embedding,"subsample_gene_embedding.pkl")

In [None]:
adata.uns['gene_embedding'] = {}
for g in embedding['gene']:
    adata.uns['gene_embedding'][g] = embedding.loc[g]['embedding']

In [28]:
train.uns = adata.uns
eval.uns = adata.uns

In [29]:
cf = CellFlow(train)

In [30]:
cf.prepare_data(
    sample_rep = sample_rep,
    control_key = control_key,
    perturbation_covariates = perturbation_covariates,
    perturbation_covariate_reps = perturbation_covariate_reps,
    split_covariates = split_covariates,
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 44/44 [00:00<00:00, 1458.74it/s]
100%|██████████| 38/38 [00:00<00:00, 1822.09it/s]
100%|██████████| 38/38 [00:00<00:00, 1789.50it/s]
100%|██████████| 46/46 [00:00<00:00, 1704.97it/s]
100%|██████████| 45/45 [00:00<00:00, 1624.08it/s]
100%|██████████| 50/50 [00:00<00:00, 1635.70it/s]
100%|██████████| 42/42 [00:00<00:00, 1576.93it/s]
100%|██████████| 35/35 [00:00<00:00, 1720.67it/s]
100%|██████████| 38/38 [00:00<00:00, 1729.23it/s]
100%|██████████| 44/44 [00:00<00:00, 1732.14it/s]
100%|██████████| 36/36 [00:00<00:00, 1609.17it/s]
100%|██████████| 35/35 [00:00<00:00, 1636.50it/s]
100%|██████████| 35/35 [00:00<00:00, 1700.75it/s]
100%|██████████| 45/45 [00:00<0

In [33]:
cf.prepare_model(condition_embedding_dim=128, time_freqs=64, time_encoder_dims=(128,128,128), hidden_dims=(128,128,128), decoder_dims=(256,256,256))

In [34]:
cf.prepare_validation_data(eval, name = "test")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 38/38 [00:00<00:00, 1546.65it/s]
100%|██████████| 33/33 [00:00<00:00, 1679.66it/s]
100%|██████████| 44/44 [00:00<00:00, 1752.94it/s]
100%|██████████| 45/45 [00:00<00:00, 1741.02it/s]
100%|██████████| 43/43 [00:00<00:00, 1777.60it/s]
100%|██████████| 46/46 [00:00<00:00, 1679.63it/s]
100%|██████████| 30/30 [00:00<00:00, 1542.40it/s]
100%|██████████| 48/48 [00:00<00:00, 1761.54it/s]
100%|██████████| 41/41 [00:00<00:00, 1593.73it/s]
100%|██████████| 47/47 [00:00<00:00, 1729.97it/s]
100%|██████████| 37/37 [00:00<00:00, 1796.57it/s]
100%|██████████| 40/40 [00:00<00:00, 1580.22it/s]
100%|██████████| 42/42 [00:00<00:00, 1856.69it/s]
100%|██████████| 45/45 [00:00<0

In [35]:
cf.train(num_iterations=1, batch_size = 512)

100%|██████████| 1/1 [00:04<00:00,  4.75s/it]


In [45]:
cf.save("./models/","cf-128-256", overwrite=False)

In [86]:
res = cf.predict(x_eval, covariate_data=x_eval.obs, sample_rep = "X")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _covariate_data["cell_index"] = _covariate_data.index
100%|██████████| 13/13 [00:00<00:00, 1827.89it/s]


In [91]:
pickle.dump(res, open("result.pkl", "wb"))

In [63]:
x_eval = x_eval[x_eval.obs['batch'] == "Flex_1_10"]

In [51]:
cf.adata.uns['gene_embedding']['OXCT1'] = res

In [40]:
s = get_protein_sequence_by_gene("OXCT1")


Fetched: 500 / 688


In [42]:
converter = ESMConverter("esm2_t33_650M_UR50D")
res = converter.convert([("OXCT1", s)])

In [83]:
x_eval.obsm["pred_('ACVR1B', 'Flex_1_10')"].shape

(15, 18080)

In [None]:
x_eval.obsm

AnnData object with n_obs × n_vars = 15 × 18080
    obs: 'target_gene', 'guide_id', 'batch', 'control'
    var: 'gene_id'
    uns: 'gene_embedding', 'OXCT1'
    obsm: "pred_('ACVR1B', 'Flex_1_10')", "pred_('AKT2', 'Flex_1_10')", "pred_('BIRC2', 'Flex_1_10')", "pred_('CLDN6', 'Flex_1_10')", "pred_('HMGN1', 'Flex_1_10')", "pred_('IGF2R', 'Flex_1_10')", "pred_('INSIG1', 'Flex_1_10')", "pred_('KAT2A', 'Flex_1_10')", "pred_('KDR', 'Flex_1_10')", "pred_('OXCT1', 'Flex_1_10')", "pred_('POLB', 'Flex_1_10')", "pred_('TADA1', 'Flex_1_10')", "pred_('WFS1', 'Flex_1_10')"