In [9]:
import random

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from nucleotide_transformer.pretrained import get_pretrained_model

In [11]:
# Load the pretrained Nucleotide Transformer checkpoint. Layer 20 is requested in
# `embeddings_layers_to_save`; later cells read it back as 'embeddings_20'.
# `max_positions` bounds the tokenized input length the model is built for.
model_name = '500M_multi_species_v2'
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    max_positions=5800,
    embeddings_layers_to_save=(20,),
)
# Wrap with Haiku so the model is invoked as forward_fn.apply(params, rng, tokens).
forward_fn = hk.transform(forward_fn)

In [None]:
# One-letter amino-acid code -> every DNA codon that encodes it (standard genetic
# code, 61 sense codons). switch_codon draws synonymous replacements from here.
# 'O' and 'U' are presumably pyrrolysine / selenocysteine (IUPAC codes), listed
# with the TAG / TGA codons they are read through. TAA and a stop symbol ('*')
# are not present.
# NOTE(review): most_frequent_codons has no entries for 'O'/'U', so a protein
# containing them raises KeyError before this table is consulted.
AA_MAPPING = {
  'A': ['GCT', 'GCC', 'GCA', 'GCG'],
  'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],
  'N': ['AAT', 'AAC'],
  'D': ['GAT', 'GAC'],
  'C': ['TGT', 'TGC'],
  'Q': ['CAA', 'CAG'],
  'E': ['GAA', 'GAG'],
  'G': ['GGT', 'GGC', 'GGA', 'GGG'],
  'H': ['CAT', 'CAC'],
  'I': ['ATT', 'ATC', 'ATA'],
  'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'],
  'K': ['AAA', 'AAG'],
  'M': ['ATG'],
  'F': ['TTT', 'TTC'],
  'P': ['CCT', 'CCC', 'CCA', 'CCG'],
  'O': ['TAG'],
  'U': ['TGA'],
  'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'],
  'T': ['ACT', 'ACC', 'ACA', 'ACG'],
  'W': ['TGG'],
  'Y': ['TAT', 'TAC'],
  'V': ['GTT', 'GTC', 'GTA', 'GTG'],
}

def most_frequent_codons(protein: str) -> str:
    """Back-translate a protein into DNA using one preferred codon per residue.

    The preferred codon for each amino acid is its highest-frequency codon,
    picked by hand from https://www.genscript.com/tools/codon-frequency-table.

    Args:
        protein: Amino-acid sequence in uppercase one-letter codes.

    Returns:
        DNA string of length ``3 * len(protein)`` (empty for an empty protein).

    Raises:
        KeyError: For any residue without an entry below, which includes 'O',
            'U', stop symbols and lowercase letters.
    """
    preferred = dict(
        A='GCG', C='TGC', D='GAT', E='GAA', F='TTT',
        G='GGC', H='CAT', I='ATT', K='AAA', L='CTG',
        M='ATG', N='AAC', P='CCG', Q='CAG', R='CGT',
        S='AGC', T='ACC', V='GTG', W='TGG', Y='TAT',
    )
    dna = []
    for residue in protein:
        dna.append(preferred[residue])
    return ''.join(dna)

def switch_codon(protein: str, rng=random) -> str:
    """Return the canonical back-translation of ``protein`` with one codon resampled.

    One residue position is chosen uniformly at random. Its codon is replaced by a
    codon drawn uniformly from ``AA_MAPPING`` for that residue. Every other
    position keeps the codon produced by ``most_frequent_codons``. The draw may
    return the codon that was already there (always, for single-codon residues
    such as 'M' and 'W'); the result then equals the canonical sequence.

    Args:
        protein: Amino-acid sequence in one-letter codes.
        rng: Randomness source exposing ``randrange`` and ``choice``. Defaults to
            the global ``random`` module. Pass a seeded ``random.Random`` for
            reproducible output.

    Returns:
        DNA string of length ``3 * len(protein)``.

    Raises:
        IndexError: If ``protein`` is empty (same type random.choice raises).
        KeyError: If ``protein`` contains a residue missing from the codon tables.
    """
    basis = most_frequent_codons(protein)
    if not protein:
        raise IndexError('cannot switch a codon in an empty protein')
    position = rng.randrange(len(protein))
    start = position * 3
    new_codon = rng.choice(AA_MAPPING[protein[position]])
    # Splice the new codon over the 3 bases at `start`; basis is already a str.
    return basis[:start] + new_codon + basis[start + 3:]

def generate_plasmids(protein: str, k: int):
    """Sample ``k`` distinct codon variants of ``protein``.

    The result always contains the canonical back-translation from
    ``most_frequent_codons``. The remaining members are single-codon variants
    drawn with ``switch_codon`` until the set holds ``k`` distinct sequences.

    Args:
        protein: Amino-acid sequence in one-letter codes.
        k: Target number of distinct sequences. Values ``<= 1`` return only the
            canonical sequence.

    Returns:
        A set of ``max(k, 1)`` distinct DNA strings.

    Raises:
        ValueError: If ``k`` exceeds the number of distinct sequences
            ``switch_codon`` can produce for ``protein``. Sampling would never
            terminate in that case.
        KeyError: If ``protein`` contains a residue missing from the codon tables.
    """
    basis = most_frequent_codons(protein)
    # Reachable sequences: the canonical one plus, at every position, each
    # synonymous codon that differs from the canonical codon there.
    reachable = 1 + sum(
        len(set(AA_MAPPING[residue]) - {basis[3 * i:3 * i + 3]})
        for i, residue in enumerate(protein)
    )
    if k > reachable:
        raise ValueError(
            f'k={k} exceeds the {reachable} distinct single-codon variants '
            f'reachable for this protein'
        )
    res = {basis}
    # Rejection sampling: duplicates are absorbed by the set, so loop until k unique.
    while len(res) < k:
        res.add(switch_codon(protein))
    return res

In [3]:
# Load the gzip-compressed iCodon training table; the CSV's first column becomes the index.
train_df = pd.read_csv(
    'data/icodon/training.csv.gz',
    compression='gzip',
    index_col=0,
)

In [5]:
# Keep only a small subset of rows for a quick experiment.
# NOTE(review): `.loc[:10]` slices by index *label* and is end-inclusive. With an
# integer index starting at 0 it keeps 11 rows. With a non-integer index (index_col=0
# uses the CSV's first column) it may raise or select by label. If a positional
# "first n rows" subset is intended, `.iloc[:n]` / `.head(n)` is safer — confirm.
train_df = train_df.loc[:10]

In [6]:
# Plain Python list of the coding-sequence strings, ready for the tokenizer.
sequences = train_df['coding'].to_list()

In [7]:
# Tokenize once and reuse the result. Each element is indexed as
# (token strings, token ids), per how the two views are named here.
tokenized = tokenizer.batch_tokenize(sequences)
tokens_ids = [b[1] for b in tokenized]
tokens_str = [b[0] for b in tokenized]
# Stack the id lists into one int32 array; this requires every sequence to
# tokenize to the same length.
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

In [8]:
%%time
# Fixed PRNG key (seed 0), so every run passes the same key to forward_fn.apply.
random_key = jax.random.PRNGKey(0)

# Single forward pass over the whole tokenized set at once (no batching).
# NOTE(review): the saved AssertionError (input length 5000 vs. max positions 1000)
# predates the current model cell: that cell's execution count (11) is later than
# this one's (8), and it now passes max_positions=5800. Re-run from a fresh kernel
# to confirm this cell works.
outs = forward_fn.apply(parameters, random_key, tokens)

AssertionError: Inputs to the learned positional embeddings layer have a length 5000 greater than the max positions used to instantiate it: 1000

In [None]:
def _iter_embedding_batches(tokens, model, random_key, batch_size, layer):
    """Yield the saved ``embeddings_{layer}`` output for consecutive batches of ``tokens``."""
    forward_fn, parameters = model
    key = f'embeddings_{layer}'
    for start in range(0, len(tokens), batch_size):
        # Slicing past the end is clamped, so the final batch may be shorter.
        batch = tokens[start:start + batch_size]
        yield forward_fn.apply(parameters, random_key, batch)[key]


def get_embeddings(tokens, model, random_key, batch_size=32, layer=20):
    """Run the model over ``tokens`` in batches and stack one layer's embeddings.

    Args:
        tokens: Token-id array; batches are taken along axis 0.
        model: ``(forward_fn, parameters)`` tuple, in that order. ``forward_fn``
            must expose ``apply(parameters, random_key, batch)`` returning a mapping
            that contains ``f'embeddings_{layer}'``.
        random_key: RNG key passed unchanged to every ``apply`` call.
        batch_size: Number of rows per forward call (must be >= 1).
        layer: Which saved layer to collect. It must be one the model was
            configured to save; the loading cell saves only layer 20.

    Returns:
        The per-batch embeddings concatenated along axis 0.

    Raises:
        ValueError: If ``batch_size < 1`` or ``tokens`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    batches = list(_iter_embedding_batches(tokens, model, random_key, batch_size, layer))
    if not batches:
        raise ValueError('tokens is empty; nothing to embed')
    return jnp.concatenate(batches, axis=0)

In [None]:
def pool(embeddings, lengths):
    """Mean-pool each row's embeddings over its first ``lengths[i]`` positions.

    Args:
        embeddings: Array indexed as ``(row, position, feature)``.
        lengths: Per-row count of leading positions to average. Only the first
            ``embeddings.shape[0]`` entries are used. Values larger than the
            position axis are clamped to it; values are expected to be >= 0.

    Returns:
        ``np.ndarray`` of shape ``(rows, features)`` with dtype float64. A row
        whose length is 0 comes back as NaN (mean of an empty slice).
    """
    emb = jnp.asarray(embeddings)
    n_rows, n_positions = emb.shape[0], emb.shape[1]
    # Clamp over-long lengths so they average the whole row, matching slice semantics.
    lens = jnp.clip(jnp.asarray(lengths)[:n_rows], 0, n_positions)
    # mask[i, t] is True for positions that contribute to row i's mean.
    mask = jnp.arange(n_positions)[None, :] < lens[:, None]
    summed = jnp.where(mask[:, :, None], emb, 0).sum(axis=1)
    pooled = summed / lens[:, None]
    return np.asarray(pooled, dtype=np.float64)

In [None]:
def normalize(X):
    """Z-score each column of ``X``: subtract the column mean, divide by the column std.

    Statistics are taken over axis 0 using the population std (``ddof=0``).
    Columns with zero std are only centred, so they become 0 rather than NaN.

    Args:
        X: Array-like of numbers, typically 2-D.

    Returns:
        Float ``np.ndarray`` with the same shape as ``X``.
    """
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # Constant columns: divide by 1 instead of 0 to avoid 0/0 = NaN.
    std = np.where(std == 0, 1.0, std)
    # Parenthesised: centre first, then scale (X - mean / std would scale only the mean).
    return (X - mean) / std