# probing nucleotide transformer to predict universal CHMs from CpG-rich regions   
(using CPU; looping in steps of 100 sequences)  
(to classify sequences of universal CHMs, universal-complementary CHMs, and nonCHM CpG-rich regions)

**!!! The installed jax version conflicts with our GPU's CUDA version !!!** Data cannot be loaded onto the GPU using `jax.device_put()`.  
Solved by re-installing jax:  
`pip install -U jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`


In [1]:
import nucleotide_transformer
from nucleotide_transformer.pretrained import get_pretrained_model

In [2]:
import pandas as pd
import numpy as np

In [3]:
import haiku as hk
import jax
import jax.numpy as jnp

In [4]:
# Select the CPU platform for JAX via its 'jax_platform_name' config option.
jax.config.update('jax_platform_name', 'cpu')

In [5]:
import os
# NOTE(review): jax was already imported in an earlier cell, so this
# environment variable is presumably not re-read by the running process;
# the jax.config.update call in the previous cell looks like the effective
# setting — confirm.
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
# Display the devices JAX will use.
jax.devices()

[CpuDevice(id=0)]

# Download the weights

In [6]:
#@title Select a model
#@markdown ---
# Name of the pretrained checkpoint; passed to get_pretrained_model in a later cell.
model_name = '2B5_multi_species'#@param['500M_human_ref', '500M_1000G', '2B5_1000G', '2B5_multi_species']
#@markdown ---

In [7]:
# Scratch calculation: 1000 (the sequence length loaded later in this notebook) divided by 6.
# NOTE(review): presumably the approximate number of 6-mer tokens per sequence,
# used to choose max_positions — confirm.
1000 / 6

166.66666666666666

In [8]:
# Get pretrained model
# Returns the model parameters, the (untransformed) forward function, the
# tokenizer and the model config for the checkpoint named by `model_name`.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    mixed_precision=False,
    # Layers 10..20 inclusive, passed as an explicit tuple of ints (the
    # original `(range(10, 21))` was a parenthesised range, not a tuple).
    embeddings_layers_to_save=tuple(range(10, 21)),
    attention_maps_to_save=((1, 4), (7, 18), (13, 20)),
    # Token batches produced later in this notebook have width 171.
    max_positions=171,
)
# Wrap with Haiku so the model can be run as
# forward_fn.apply(parameters, random_key, tokens).
forward_fn = hk.transform(forward_fn)


# prepare fasta for universal-complementary CHMs

In [9]:
%%bash
cd /mnt/Storage/home/wangyiman/NLP_model/universalCHM_prediction_with_NLP_model/use_nucleotide_transformer/dataset_prepare
# The loop below is left commented out; it would extract <feature>.fasta from
# the mm10 genome for the regions in <feature>.bed using bedtools getfasta.
# NOTE(review): presumably it was run once and then disabled — confirm.
# for feature in Universal.ComplementCHM.30CpG1kb;do
#     bedtools getfasta -fi /mnt/Storage/home/wangyiman/annotations/mm10/mm10.fa -bed ${feature}.bed -fo ${feature}.fasta
# done

# define input data and tokenize it  

In [10]:
%%bash
cd /mnt/Storage/home/wangyiman/NLP_model/universalCHM_prediction_with_NLP_model/use_nucleotide_transformer/dataset_prepare
# The command below is left commented out; it would symlink the FASTA files
# from the use_DNABERT dataset_prepare directory into this directory.
# ln -s /mnt/Storage/home/wangyiman/NLP_model/universalCHM_prediction_with_NLP_model/use_DNABERT/dataset_prepare/*.fasta .


In [11]:
# Integer class label for the sequences of each input FASTA file,
# keyed by the file name without its '.fasta' suffix.
label_dict = {}
label_dict['Universal.CHM.30CpG1kb'] = 1
label_dict['Universal.ComplementCHM.30CpG1kb'] = 2
label_dict['NonCHMsCpGrich'] = 0

In [14]:
os.chdir("/mnt/Storage/home/wangyiman/NLP_model/universalCHM_prediction_with_NLP_model/use_nucleotide_transformer/dataset_prepare")


def _read_fasta_sequences(fasta_path):
    """Yield the upper-cased sequence of every record in a FASTA file.

    Lines starting with '>' are treated as record headers. The sequence
    lines that follow a header are concatenated, so single-line and
    line-wrapped records are both read as one sequence. Blank lines are
    ignored, and records without any sequence line yield nothing.
    """
    chunks = []
    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # A new header closes the previous record, if any.
                if chunks:
                    yield ''.join(chunks).upper()
                    chunks = []
                continue
            chunks.append(line)
    if chunks:
        yield ''.join(chunks).upper()


# Collect every sequence together with the class label of its source file.
sequence_ls = []
label_ls = []
for feature in ['Universal.CHM.30CpG1kb', 'Universal.ComplementCHM.30CpG1kb', 'NonCHMsCpGrich']:
    feature_sequences = list(_read_fasta_sequences(feature + '.fasta'))
    sequence_ls.extend(feature_sequences)
    label_ls.extend([label_dict[feature]] * len(feature_sequences))

# Flag ambiguous bases on the upper-cased sequence, so a lowercase 'n' in the
# file is detected as well (the raw line was checked before).
if_N_ls = ['N' in seq for seq in sequence_ls]

# One row per sequence; columns: sequence, if_N, label, length.
sequence_df = pd.DataFrame({
    'sequence': sequence_ls,
    'if_N': if_N_ls,
    'label': label_ls,
    'length': [len(seq) for seq in sequence_ls],
})
sequence_df

Unnamed: 0,sequence,if_N,label,length
0,CGGCCAGGAAGAACACAACAGACCAGAATCTTCTGCGGCAAAACTT...,False,1,1000
1,ATCTACAACTCCAGGGTGGACAATAAGACCTTGTAGGCTGTAAGAG...,False,1,1000
2,TGTTGACAATCCATAACTCCAGGGTGGACTACTAAGCCCTGCAAGG...,False,1,1000
3,ACTAGGGAGAGCGGCTTTTACAACCGTTTGCCAGTCGGCAGGAGTT...,False,1,1000
4,CCAGAAACTTAGGATACCCAAGATATAAGATATAATTTGCTAAAAA...,False,1,1000
...,...,...,...,...
27858,CACCGGCCGAGCTCGCGGGCTGGGCTTTCCCCGTCCAGCCTGGCTG...,False,0,1000
27859,GTACTTGGACTTCGGAAAGTCCCAGTCCCAGAGTTCTCAGCTCTTT...,False,0,1000
27860,TTGTTTAGATGTATACAAAGTGACACTTACTAGGAATTGATTGCTG...,False,0,1000
27861,TGTAAGTAGATTAAGTCCTGAGCCTCTGCCCCATTTCTGTCTGGAA...,False,0,1000


In [15]:
# Show the rows whose 'if_N' flag is True.
sequence_df[sequence_df['if_N']]

Unnamed: 0,sequence,if_N,label,length
6954,AATCTGGACCCCCCCGAAATCTCTCAAACACTGGACCACCAAACAG...,True,0,1000
13887,GACTTTGAATTTAGACTGCTTTATGCTTCTGATTTCAGCTCCAACC...,True,0,1000
19078,TCTTAAGGTAGAGTGTTATCTTGCTTTTTGTTGCAACAACAGTTGT...,True,0,1000
21110,TCGCACCTTCCATCGGGGCTGCAGGCCGGGCCTGCCGGGGCCGAGC...,True,0,1000
21112,GGCATTCTGGGCCCGGAAGTGCGGCGCACGCGGCTGGGCGCGCCAT...,True,0,1000
21114,NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...,True,0,1000
27862,NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...,True,0,1000


In [17]:
def tokenize_seq_in_part(sequence_df, start_line, num_line_used = 100) :
    """Tokenize a contiguous slice of rows and return the token ids.

    Uses the module-level `tokenizer`.

    Args:
        sequence_df: DataFrame with a 'sequence' column.
        start_line: Positional index of the first row to tokenize.
        num_line_used: Maximum number of rows to take; the positional slice
            is shorter when it runs past the end of the DataFrame.

    Returns:
        A jnp int32 array built from the token ids, one row per sequence.
        Its shape is also printed.
    """
    seq_df_part = sequence_df.iloc[start_line:(start_line + num_line_used), :]

    # Get data and tokenize it. Each item returned by batch_tokenize is
    # indexed as (token strings, token ids); only the ids are needed, so the
    # batch is tokenized once instead of twice.
    sequences = seq_df_part['sequence']
    tokens_ids = [pair[1] for pair in tokenizer.batch_tokenize(sequences)]
    tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
    print(tokens.shape)
    return tokens

# do the inference & retrieve embeddings

Tokenizing, inference, and embedding retrieval have already been done for universal CHMs and nonCHM CpG-rich regions.  
They only need to be done for 'Universal.ComplementCHM.30CpG1kb'.  

In [18]:
%%time

import pickle

# Initialize random key
random_key = jax.random.PRNGKey(0)
sequence_df_woN = sequence_df.loc[~sequence_df['if_N']]
sequence_df_woN.to_pickle('/mnt/Storage/home/wangyiman/NLP_model/universalCHM_prediction_with_NLP_model/use_nucleotide_transformer/embedding_results/total_CPU/dataset_y_wiUniversalComplement.pkl')
sequence_df_woN_universalComplement = sequence_df_woN.loc[sequence_df_woN['label'] == 2]

for i in range(0, 2392, 100) :
    # Infer
    num_line_used = 100 if i < 2300 else 92
    tokens = tokenize_seq_in_part(sequence_df_woN_universalComplement, start_line=i, num_line_used = num_line_used)
    outs = forward_fn.apply(parameters, random_key, tokens)

    # retrieve embeddings
    for embed in range(10,21) :
        outs_key = f'embeddings_{embed}'
        embeddings = outs[outs_key][:, 1:, :]  # removing CLS token
        padding_mask = jnp.expand_dims(tokens[:, 1:] != tokenizer.pad_token_id, axis=-1)
        masked_embeddings = embeddings * padding_mask  # multiply by 0 pad tokens embeddings
        sequences_lengths = jnp.sum(padding_mask, axis=1)
        mean_embeddings = jnp.sum(masked_embeddings, axis=1) / sequences_lengths        
        ### write embedding results into files :
        mean_embeddings_np = jax.device_get(mean_embeddings).copy()
        np.save(
            f'/mnt/Storage/home/wangyiman/NLP_model/universalCHM_prediction_with_NLP_model/use_nucleotide_transformer/embedding_results/total_CPU/seqStartLine{i}_mean_embeddings_{embed}_universalComplement.npy', 
            mean_embeddings_np)

        

(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(100, 171)
(92, 171)
CPU times: user 13h 49min 22s, sys: 3h 43min 18s, total: 17h 32min 41s
Wall time: 48min 8s


In [16]:
# Scratch arithmetic left in the notebook.
# NOTE(review): purpose is not recorded — presumably a run-time estimate; confirm or remove.
130 * 5 / 60

10.833333333333334