In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model

# Directory holding the processed sequence tables for this analysis.
data_path = "/allen/programs/celltypes/workgroups/mousecelltypes/Rohan/dat/proc/nautilex/"

# Read the sequence table and order it from shortest to longest nucleotide
# sequence, renumbering rows so positional indices follow that order.
df = (
    pd.read_csv(data_path + "prot_nuc_seqs_human.csv")
    .sort_values(by="nuc_seq_length", ascending=True)
    .reset_index(drop=True)
)
display(df.head(2))

Unnamed: 0,gene_symbol,ensg_id,enst_id,nuc_seq_length,aa_seq_length,nuc_aa_seq_ratio,chromosome,start,end,strand,nuc_seq,aa_seq
0,TRDD1,ENSG00000223997.1,ENST00000415118,8,2.0,4.0,GRCh38:14,22438547.0,22438554.0,1.0,GAAATAGT,EI
1,TRBD1,ENSG00000282431.1,ENST00000632684,12,4.0,3.0,GRCh38:7,142786213.0,142786224.0,1.0,GGGACAGGGGGC,GTGG


In [2]:
# Fetch the pretrained model: weights, the un-transformed forward function,
# the tokenizer and the model config. Layer 20 is the only layer whose
# embeddings are kept in the forward-pass outputs.
parameters, raw_forward_fn, tokenizer, config = get_pretrained_model(
    model_name="500M_human_ref",
    max_positions=1000,
    embeddings_layers_to_save=(20,),
)

# Wrap with Haiku so the model can be invoked through `forward_fn.apply(...)`.
forward_fn = hk.transform(raw_forward_fn)

Downloading model's hyperparameters json file...
Downloaded model's hyperparameters.
Downloading model's weights...
Downloaded model's weights...


In [3]:
# Embed every nucleotide sequence in `df` in batches and stream the per-sequence
# embeddings (plus gene/transcript identifiers) to a CSV file.
batch_size = 20
ref_dict = {}  # step -> positional row indices of `df` processed in that step
file_name = data_path+"nuc_seqs_emb_human.csv"

# Longest nucleotide prefix handed to the tokenizer; longer sequences are
# truncated to this length, shorter ones pass through unchanged.
# NOTE(review): presumably chosen so the token count stays below the model's
# max_positions - confirm against the tokenizer's k-mer size.
max_nuc_len = 5952

# The same PRNG key is used for every batch, so build it once outside the loop.
random_key = jax.random.PRNGKey(0)

batch_starts = range(0, len(df), batch_size)
for step, start in tqdm(enumerate(batch_starts), total=len(batch_starts)):
    # Positional indices of the rows in this batch (the last batch may be short).
    batch_idx = np.arange(start, min(start + batch_size, len(df)))
    ref_dict[step] = batch_idx

    # Truncated copies of the sequences; `df` itself is not modified.
    sequences = [seq[:max_nuc_len] for seq in df["nuc_seq"].values[batch_idx]]
    tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
    tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

    # Infer
    outs = forward_fn.apply(parameters, random_key, tokens)

    # Get embeddings at layer 20 - as proposed in the repo readme.
    x = jnp.array(outs["embeddings_20"])

    cols = [f"emb_{i}" for i in range(x.shape[-1])]
    # x[:,0,:] corresponds to the CLS token for each sequence
    df_emb = pd.DataFrame(np.asarray(x[:, 0, :]), columns=cols)
    id_cols = df.iloc[batch_idx][["gene_symbol", "ensg_id", "enst_id"]].reset_index(drop=True)
    df_emb = pd.concat([id_cols, df_emb], axis=1)

    # Overwrite on the first batch and append afterwards. Appending whenever the
    # file already exists would duplicate every row each time the cell is re-run.
    first_batch = step == 0
    df_emb.to_csv(
        file_name,
        mode="w" if first_batch else "a",
        header=first_batch,
        index=False,
    )

969it [15:57,  1.01it/s]


In [4]:
import anndata as ad
import numpy as np
import pandas as pd


# Directory holding the processed sequence tables for this analysis.
data_path = "/allen/programs/celltypes/workgroups/mousecelltypes/Rohan/dat/proc/nautilex/"

# Load the saved embeddings first, then the sequence/metadata table.
df_emb, df_seq = (
    pd.read_csv(data_path + csv_name)
    for csv_name in ("nuc_seqs_emb_human.csv", "prot_nuc_seqs_human.csv")
)

# Preview the first two rows of each table.
display(df_emb.head(2), df_seq.head(2))

Unnamed: 0,gene_symbol,ensg_id,enst_id,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,...,emb_1270,emb_1271,emb_1272,emb_1273,emb_1274,emb_1275,emb_1276,emb_1277,emb_1278,emb_1279
0,TRDD1,ENSG00000223997.1,ENST00000415118,0.720733,5.7028,4.866824,5.400179,-7.698461,4.84692,-7.968596,...,10.754392,-2.053164,-7.541198,-0.157856,3.116768,-2.352816,-14.016531,-18.613827,-6.036346,9.650827
1,TRBD1,ENSG00000282431.1,ENST00000632684,-0.371508,5.976748,3.231199,10.971562,-8.412083,3.675622,-6.364289,...,18.006914,2.895096,-1.755377,0.985638,4.407529,-4.338012,-15.876388,-19.709509,-8.248082,5.13817


Unnamed: 0,gene_symbol,ensg_id,enst_id,nuc_seq_length,aa_seq_length,nuc_aa_seq_ratio,chromosome,start,end,strand,nuc_seq,aa_seq
0,TRDJ1,ENSG00000211825.1,ENST00000390473,51,16.0,3.1875,GRCh38:14,22450089.0,22450139.0,1.0,ACACCGATAAACTCATCTTTGGAAAAGGAACCCGTGTGACTGTGGA...,TDKLIFGKGTRVTVEP
1,TRAJ54,ENSG00000211836.1,ENST00000390484,60,20.0,3.0,GRCh38:14,22482287.0,22482346.0,1.0,TAATTCAGGGAGCCCAGAAGCTGGTATTTGGCCAAGGAACCAGGCT...,XIQGAQKLVFGQGTRLTINP


In [5]:
# Attach embeddings to the sequence table, keeping the row order of `df_seq`.
# `validate="one_to_one"` makes pandas raise a MergeError if `ensg_id` is
# duplicated on either side; without it, duplicate keys would silently multiply
# rows and the result would no longer line up with `df_seq`.
df_merged = df_seq.merge(df_emb, on="ensg_id", how="left", validate="one_to_one")

# Keep only the join key and the embedding columns; everything else (including
# the suffixed duplicates of columns present in both tables) is dropped.
drop_cols = [c for c in df_merged.columns if not (c == "ensg_id" or c.startswith("emb_"))]
df_merged = df_merged.drop(columns=drop_cols)
display(df_merged.head(2))

Unnamed: 0,ensg_id,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,...,emb_1270,emb_1271,emb_1272,emb_1273,emb_1274,emb_1275,emb_1276,emb_1277,emb_1278,emb_1279
0,ENSG00000211825.1,-0.793367,10.310186,0.873271,9.414678,-12.495082,4.481245,-11.396527,9.682812,0.5453,...,21.416643,4.687433,-5.007195,2.696289,4.393206,-5.634792,-19.484884,-17.335997,-5.436633,1.359537
1,ENSG00000211836.1,0.148229,9.606176,0.533709,3.475356,-8.501716,8.174894,-4.582773,11.462854,3.492025,...,15.05146,0.713071,-2.778783,2.205617,4.067134,-4.380317,-20.43778,-10.456451,-3.425991,5.654714


In [6]:
# Index both tables by ensg_id. Each step is guarded on the index name so that
# re-running this cell does not raise KeyError once ensg_id has already been
# moved out of the columns and into the index.
if df_merged.index.name != "ensg_id":
    df_merged = df_merged.set_index("ensg_id")
display(df_merged.head(2))

if df_seq.index.name != "ensg_id":
    df_seq = df_seq.set_index("ensg_id")
display(df_seq.head(2))

Unnamed: 0_level_0,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,...,emb_1270,emb_1271,emb_1272,emb_1273,emb_1274,emb_1275,emb_1276,emb_1277,emb_1278,emb_1279
ensg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ENSG00000211825.1,-0.793367,10.310186,0.873271,9.414678,-12.495082,4.481245,-11.396527,9.682812,0.5453,41.201874,...,21.416643,4.687433,-5.007195,2.696289,4.393206,-5.634792,-19.484884,-17.335997,-5.436633,1.359537
ENSG00000211836.1,0.148229,9.606176,0.533709,3.475356,-8.501716,8.174894,-4.582773,11.462854,3.492025,39.57487,...,15.05146,0.713071,-2.778783,2.205617,4.067134,-4.380317,-20.43778,-10.456451,-3.425991,5.654714


Unnamed: 0_level_0,gene_symbol,enst_id,nuc_seq_length,aa_seq_length,nuc_aa_seq_ratio,chromosome,start,end,strand,nuc_seq,aa_seq
ensg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
ENSG00000211825.1,TRDJ1,ENST00000390473,51,16.0,3.1875,GRCh38:14,22450089.0,22450139.0,1.0,ACACCGATAAACTCATCTTTGGAAAAGGAACCCGTGTGACTGTGGA...,TDKLIFGKGTRVTVEP
ENSG00000211836.1,TRAJ54,ENST00000390484,60,20.0,3.0,GRCh38:14,22482287.0,22482346.0,1.0,TAATTCAGGGAGCCCAGAAGCTGGTATTTGGCCAAGGAACCAGGCT...,XIQGAQKLVFGQGTRLTINP


In [7]:
# Create an AnnData object: embeddings as X (one row per ensg_id, one column
# per embedding dimension) and the sequence table as per-row annotations (obs).
# Check explicitly that both tables list the same ids in the same order, so a
# misaligned or row-multiplied merge fails here with a clear message instead of
# attaching metadata to the wrong embeddings.
if not df_merged.index.equals(df_seq.index):
    raise ValueError(
        "df_merged and df_seq must share the same index in the same order "
        "before building the AnnData object"
    )
adata = ad.AnnData(X=df_merged, obs=df_seq)

In [8]:
# Sanity checks: a corner of the embedding matrix, then the first rows of the
# variable (embedding dimension) and observation (sequence) annotations.
display(
    adata.X[:2, :5],
    adata.var.head(2),
    adata.obs.head(2),
)

array([[ -0.793367  ,  10.310186  ,   0.8732713 ,   9.414678  ,
        -12.495082  ],
       [  0.14822918,   9.606176  ,   0.533709  ,   3.4753563 ,
         -8.501716  ]])

emb_0
emb_1


Unnamed: 0_level_0,gene_symbol,enst_id,nuc_seq_length,aa_seq_length,nuc_aa_seq_ratio,chromosome,start,end,strand,nuc_seq,aa_seq
ensg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
ENSG00000211825.1,TRDJ1,ENST00000390473,51,16.0,3.1875,GRCh38:14,22450089.0,22450139.0,1.0,ACACCGATAAACTCATCTTTGGAAAAGGAACCCGTGTGACTGTGGA...,TDKLIFGKGTRVTVEP
ENSG00000211836.1,TRAJ54,ENST00000390484,60,20.0,3.0,GRCh38:14,22482287.0,22482346.0,1.0,TAATTCAGGGAGCCCAGAAGCTGGTATTTGGCCAAGGAACCAGGCT...,XIQGAQKLVFGQGTRLTINP


In [9]:
# Persist the annotated embeddings next to the input tables.
h5ad_path = data_path + "nuc_seqs_emb_human.h5ad"
adata.write_h5ad(h5ad_path)