In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model

# Directory holding the processed tables. The trailing slash matters: file names
# are appended to this string by plain concatenation.
data_path = "/allen/programs/celltypes/workgroups/mousecelltypes/Rohan/dat/proc/nautilex/"

# Read the sequence table and order its rows from the shortest to the longest
# nucleotide sequence, renumbering the index to match the new order.
df = (
    pd.read_csv(data_path + "prot_nuc_seqs_mouse.csv")
    .sort_values(by="nuc_seq_length", ascending=True)
    .reset_index(drop=True)
)
display(df.head(2))

Unnamed: 0,gene_symbol,ensg_id,enst_id,nuc_seq_length,aa_seq_length,nuc_aa_seq_ratio,chromosome,start,end,strand,nuc_seq,aa_seq
0,Gm16968,ENSMUSG00000076632.2,ENSMUST00000103441,23,7.0,3.285714,GRCm39:12,113491652.0,113491674.0,-1.0,TATATAACTAAAGTGGTAGCTCA,YITKVVA
1,Ighd1-1,ENSMUSG00000076630.2,ENSMUST00000103439,23,7.0,3.285714,GRCm39:12,113445790.0,113445812.0,-1.0,TTTATTACTACGGTAGTAGCTAC,FITTVVA


In [2]:
# Fetch the pretrained weights, the (not yet transformed) forward function,
# the tokenizer and the model config. Only the layer-20 embeddings are
# requested from the forward pass.
parameters, model_fn, tokenizer, config = get_pretrained_model(
    model_name="500M_multi_species_v2",
    max_positions=1000,
    embeddings_layers_to_save=(20,),
)

# Wrap with Haiku so the model can be called as forward_fn.apply(params, key, tokens).
forward_fn = hk.transform(model_fn)

Downloading model's hyperparameters json file...
Downloaded model's hyperparameters.
Downloading model's weights...
Downloaded model's weights...


In [3]:
batch_size = 20
ref_dict = {}
file_name = data_path+"nuc_seqs_emb_mouse.csv"

# Every sequence is cut to at most this many nucleotides before tokenization.
# NOTE(review): presumably chosen so the tokenized sequence fits within the
# max_positions=1000 the model was loaded with -- confirm.
max_nuc_len = 5952

# The same fixed key is used for every batch, so build it once outside the loop.
random_key = jax.random.PRNGKey(0)

# Wrapping the range (rather than the enumerate) gives tqdm a length, so the
# progress bar can show the total number of batches.
for step, start in enumerate(tqdm(range(0, len(df), batch_size))):
    # Row positions of `df` handled in this batch (the last batch may be shorter).
    batch_idx = np.arange(start, min(start + batch_size, len(df)))
    ref_dict[step] = batch_idx

    # Slicing is a no-op for sequences already shorter than max_nuc_len, so no
    # length check is needed before truncating.
    sequences = [seq[:max_nuc_len] for seq in df["nuc_seq"].values[batch_idx]]
    tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
    tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

    # Infer
    outs = forward_fn.apply(parameters, random_key, tokens)

    # Get embeddings at layer 20 - as proposed in the repo readme.
    # Index 0 along the second axis corresponds to the CLS token for each sequence.
    x = np.asarray(outs["embeddings_20"])[:, 0, :]

    cols = [f"emb_{i}" for i in range(x.shape[-1])]
    id_cols = df.iloc[batch_idx][["gene_symbol", "ensg_id", "enst_id"]].reset_index(drop=True)
    df_emb = pd.concat([id_cols, pd.DataFrame(x, columns=cols)], axis=1)

    # The first batch (re)creates the file with a header; later batches append.
    # Deciding on the step number instead of on whether the file exists keeps a
    # re-run from appending a second copy of every row to a stale file.
    is_first_batch = step == 0
    df_emb.to_csv(
        file_name,
        mode="w" if is_first_batch else "a",
        header=is_first_batch,
        index=False,
    )

1046it [16:42,  1.04it/s]


In [4]:
import anndata as ad
import numpy as np
import pandas as pd


# Directory holding the processed tables, declared again in this cell.
# The trailing slash matters: file names are appended by string concatenation.
data_path = "/allen/programs/celltypes/workgroups/mousecelltypes/Rohan/dat/proc/nautilex/"

emb_csv = data_path + "nuc_seqs_emb_mouse.csv"
seq_csv = data_path + "prot_nuc_seqs_mouse.csv"

# Embeddings written by the inference loop, and the sequence table in its
# on-disk row order.
df_emb = pd.read_csv(emb_csv)
df_seq = pd.read_csv(seq_csv)

# Preview the first two rows of each table, embeddings first.
display(df_emb.head(2), df_seq.head(2))

Unnamed: 0,gene_symbol,ensg_id,enst_id,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,...,emb_1014,emb_1015,emb_1016,emb_1017,emb_1018,emb_1019,emb_1020,emb_1021,emb_1022,emb_1023
0,Gm16968,ENSMUSG00000076632.2,ENSMUST00000103441,-1.410019,3.665271,-1.840973,-1.458262,-0.664769,0.3217,-1.95364,...,-1.567432,2.522302,-0.24541,2.003354,-0.093926,2.50656,-3.417,0.67888,-1.512638,4.36806
1,Ighd1-1,ENSMUSG00000076630.2,ENSMUST00000103439,-0.359152,1.68417,-1.071354,2.210377,2.27511,3.122249,0.021253,...,3.170171,-1.524124,-0.414902,5.149659,-1.012639,0.191297,-4.027051,3.148464,-0.578641,1.203099


Unnamed: 0,gene_symbol,ensg_id,enst_id,nuc_seq_length,aa_seq_length,nuc_aa_seq_ratio,chromosome,start,end,strand,nuc_seq,aa_seq
0,Gm20730,ENSMUSG00000076500.3,ENSMUST00000103301,359,119.0,3.016807,GRCm39:6,43058452.0,43059031.0,-1.0,ATGAGGTGCCTAGCTGAGTTCCTGAGGCTACTTGTGCTCTGGATCC...,MRCLAEFLRLLVLWIPATGDIVMTQAAPSVPANPGESVSISCRSSK...
1,Gm54608,ENSMUSG00000090395.2,ENSMUST00000166255,278,92.0,3.021739,GRCm39:12,113618587.0,113618864.0,1.0,CATGGCTGTGTACTCAGACCTCAGACTGTTTATTTTCAGGTAAAGT...,HGCVLRPQTVYFQVKCVFVIISGDGESALHCVYIVGATSTTKNYCH...


In [5]:
# Bring the embeddings into the row order of df_seq, keyed by gene id.
# Only the key and the embedding columns are merged, so no "_x"/"_y" duplicates
# of the shared metadata columns are created and nothing has to be dropped.
emb_cols = [c for c in df_emb.columns if c.startswith("emb_")]
df_merged = df_seq[["ensg_id"]].merge(
    df_emb[["ensg_id"] + emb_cols],
    on="ensg_id",
    how="left",
    # Fail loudly if a gene id occurs more than once on either side: a
    # many-to-many match would silently multiply rows and break the
    # row-for-row correspondence with df_seq.
    validate="one_to_one",
)
display(df_merged.head(2))

Unnamed: 0,ensg_id,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,...,emb_1014,emb_1015,emb_1016,emb_1017,emb_1018,emb_1019,emb_1020,emb_1021,emb_1022,emb_1023
0,ENSMUSG00000076500.3,-3.227321,-1.504367,-2.361634,3.114111,1.495391,1.380513,0.211368,2.959904,3.53552,...,-1.492077,-0.865977,-2.748574,-1.280103,-1.712528,3.291653,-0.045543,1.394469,-0.161717,0.432455
1,ENSMUSG00000090395.2,-1.758324,3.623576,-3.298756,1.120855,-0.533384,1.455197,-1.628992,4.152858,3.325852,...,1.555718,-0.32884,-1.204153,0.162675,-4.117949,1.733881,0.627468,4.001059,-0.683634,-0.09787


In [6]:
# Key both tables by gene id, moving the column into the index.
index_col = "ensg_id"
df_merged = df_merged.set_index(index_col)
df_seq = df_seq.set_index(index_col)

# Preview both re-indexed tables, embeddings first.
display(df_merged.head(2), df_seq.head(2))

Unnamed: 0_level_0,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,...,emb_1014,emb_1015,emb_1016,emb_1017,emb_1018,emb_1019,emb_1020,emb_1021,emb_1022,emb_1023
ensg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ENSMUSG00000076500.3,-3.227321,-1.504367,-2.361634,3.114111,1.495391,1.380513,0.211368,2.959904,3.53552,0.400706,...,-1.492077,-0.865977,-2.748574,-1.280103,-1.712528,3.291653,-0.045543,1.394469,-0.161717,0.432455
ENSMUSG00000090395.2,-1.758324,3.623576,-3.298756,1.120855,-0.533384,1.455197,-1.628992,4.152858,3.325852,0.65521,...,1.555718,-0.32884,-1.204153,0.162675,-4.117949,1.733881,0.627468,4.001059,-0.683634,-0.09787


Unnamed: 0_level_0,gene_symbol,enst_id,nuc_seq_length,aa_seq_length,nuc_aa_seq_ratio,chromosome,start,end,strand,nuc_seq,aa_seq
ensg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
ENSMUSG00000076500.3,Gm20730,ENSMUST00000103301,359,119.0,3.016807,GRCm39:6,43058452.0,43059031.0,-1.0,ATGAGGTGCCTAGCTGAGTTCCTGAGGCTACTTGTGCTCTGGATCC...,MRCLAEFLRLLVLWIPATGDIVMTQAAPSVPANPGESVSISCRSSK...
ENSMUSG00000090395.2,Gm54608,ENSMUST00000166255,278,92.0,3.021739,GRCm39:12,113618587.0,113618864.0,1.0,CATGGCTGTGTACTCAGACCTCAGACTGTTTATTTTCAGGTAAAGT...,HGCVLRPQTVYFQVKCVFVIISGDGESALHCVYIVGATSTTKNYCH...


In [7]:
# Bundle the embedding table (X) and the sequence metadata table (obs)
# into a single AnnData object.
adata = ad.AnnData(obs=df_seq, X=df_merged)

# Checks: a corner of the matrix, then the first rows of var and obs.
for preview in (adata.X[:2, :5], adata.var.head(2), adata.obs.head(2)):
    display(preview)

# Save the result next to the input tables.
h5ad_path = data_path + "nuc_seqs_emb_mouse.h5ad"
adata.write_h5ad(h5ad_path)

array([[-3.2273207, -1.504367 , -2.3616338,  3.1141107,  1.4953909],
       [-1.7583236,  3.623576 , -3.2987561,  1.1208553, -0.5333837]])

emb_0
emb_1


Unnamed: 0_level_0,gene_symbol,enst_id,nuc_seq_length,aa_seq_length,nuc_aa_seq_ratio,chromosome,start,end,strand,nuc_seq,aa_seq
ensg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
ENSMUSG00000076500.3,Gm20730,ENSMUST00000103301,359,119.0,3.016807,GRCm39:6,43058452.0,43059031.0,-1.0,ATGAGGTGCCTAGCTGAGTTCCTGAGGCTACTTGTGCTCTGGATCC...,MRCLAEFLRLLVLWIPATGDIVMTQAAPSVPANPGESVSISCRSSK...
ENSMUSG00000090395.2,Gm54608,ENSMUST00000166255,278,92.0,3.021739,GRCm39:12,113618587.0,113618864.0,1.0,CATGGCTGTGTACTCAGACCTCAGACTGTTTATTTTCAGGTAAAGT...,HGCVLRPQTVYFQVKCVFVIISGDGESALHCVYIVGATSTTKNYCH...
