In [1]:
import sys
import torch
import tape
import sklearn
import pickle
import numpy as np
import pandas as pd
import scipy
import random
from sklearn.linear_model import LassoLars, LassoLarsCV, Ridge, RidgeCV, BayesianRidge
from sklearn.neighbors import KNeighborsRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from torch import nn
from tape import TAPETokenizer
from tape import UniRepForLM
import copy

sys.path.append('../common')
import data_io_utils
import paths
import utils
import constants

import A003_common
import acquisition_policies
import models

In [2]:
# Show the installed PyTorch version so the environment used for this run is recorded.
torch.__version__

'1.9.0+cu111'

# Load language model

In [6]:
# Load the UniRep language model from TAPE using the "babbler-1900" identifier,
# then echo the module so its layer layout is visible in the output.
model = UniRepForLM.from_pretrained("babbler-1900")
model

UniRepForLM(
  (unirep): UniRepModel(
    (embed_matrix): Embedding(26, 10)
    (encoder): mLSTM(
      (mlstm_cell): mLSTMCell(
        (wmx): Linear(in_features=10, out_features=1900, bias=False)
        (wmh): Linear(in_features=1900, out_features=1900, bias=False)
        (wx): Linear(in_features=10, out_features=7600, bias=False)
        (wh): Linear(in_features=1900, out_features=7600, bias=True)
      )
    )
  )
  (feedforward): Linear(in_features=1900, out_features=25, bias=True)
)

In [8]:
# Reload the pretrained model and replace its output head with a freshly initialised
# 1900 -> 26 linear layer (the stock head printed earlier has 25 outputs).
# NOTE(review): presumably the 26-way head matches the shape stored in the fine-tuned
# checkpoint that is loaded afterwards -- confirm against the training script.
model = UniRepForLM.from_pretrained("babbler-1900")
model.feedforward = nn.Linear(1900,26)

In [None]:
# Restore the fine-tuned weights into `model`.
# map_location="cpu" lets the checkpoint be read on hosts without a GPU even if it was
# saved from CUDA tensors; load_state_dict then copies the values into the parameters
# that `model` already owns, so the model's own device is unaffected.
# Only the 'model_state_dict' entry is needed here; the optimizer state, epoch and loss
# entries that the training run may have stored are intentionally not restored.
checkpoint = torch.load("fp16_64_trial_fine_turning.pt", map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])

In [6]:
# Switch the model to evaluation mode; eval() returns the module, so its structure is echoed.
model.eval()

UniRepForLM(
  (unirep): UniRepModel(
    (embed_matrix): Embedding(26, 10)
    (encoder): mLSTM(
      (mlstm_cell): mLSTMCell(
        (wmx): Linear(in_features=10, out_features=1900, bias=False)
        (wmh): Linear(in_features=1900, out_features=1900, bias=False)
        (wx): Linear(in_features=10, out_features=7600, bias=False)
        (wh): Linear(in_features=1900, out_features=7600, bias=True)
      )
    )
  )
  (feedforward): Linear(in_features=1900, out_features=26, bias=True)
)

In [7]:
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back its input untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself (no copy, no transformation)."""
        return x
# Swap the output head for the pass-through module so that whatever is fed into
# `feedforward` (1900 features, per the Linear layer it replaces) comes out unchanged.
model.feedforward = Identity()

In [7]:
# Echo the model again to confirm that `feedforward` is now Identity().
model

UniRepForLM(
  (unirep): UniRepModel(
    (embed_matrix): Embedding(26, 10)
    (encoder): mLSTM(
      (mlstm_cell): mLSTMCell(
        (wmx): Linear(in_features=10, out_features=1900, bias=False)
        (wmh): Linear(in_features=1900, out_features=1900, bias=False)
        (wx): Linear(in_features=10, out_features=7600, bias=False)
        (wh): Linear(in_features=1900, out_features=7600, bias=True)
      )
    )
  )
  (feedforward): Identity()
)

# Training data for top model

In [None]:
# Load the pickled training set for the top model: a dict holding the sequence
# embeddings under "embedding" and the measured target values under "qfunc".
# The context manager guarantees the file handle is closed after reading.
# NOTE: pickle.load can execute arbitrary code -- only open files from a trusted source.
with open("training.p", "rb") as training_file:
    emb_qfunc = pickle.load(training_file)
training_embedding = emb_qfunc["embedding"]
training_qfunc = emb_qfunc["qfunc"]

# Validation data for top model

In [42]:
# Load the pickled validation set for the top model: a dict holding the sequence
# embeddings under "embedding" and the measured target values under "qfunc".
# The context manager guarantees the file handle is closed after reading.
# NOTE: pickle.load can execute arbitrary code -- only open files from a trusted source.
# NOTE: `emb_qfunc` is rebound here, so it no longer refers to the training dict.
with open("validation.p", "rb") as validation_file:
    emb_qfunc = pickle.load(validation_file)
validation_embedding = emb_qfunc["embedding"]
validation_qfunc = emb_qfunc["qfunc"]

# embedding

# top model

In [77]:
# Passed as `do_sparse_refit` when the ensembled ridge top model is trained.
TOP_MODEL_DO_SPARSE_REFIT = True

In [10]:
# Baseline top model: LassoLars trained via the A003_common helper on the training
# embeddings, with the sparse refit switched off.
# NOTE(review): this model is not referenced again in the visible notebook.
LassoLars_model = A003_common.cv_train_lasso_lars_with_sparse_refit(
                training_embedding, training_qfunc, do_sparse_refit=False)

In [67]:
# Baseline top model: ridge regression trained via the A003_common helper on the
# training embeddings, with the sparse refit switched off.
# NOTE(review): this model is not referenced again in the visible notebook.
Ridge_model = A003_common.cv_train_ridge_with_sparse_refit(
                training_embedding, training_qfunc, do_sparse_refit=False)

In [68]:
# Baseline top model trained with A003_common.train_blr (presumably Bayesian linear
# regression, judging by the variable name -- confirm in A003_common).
# NOTE(review): this model is not referenced again in the visible notebook.
BayesianRidge_model = A003_common.train_blr(training_embedding, training_qfunc)

In [78]:
# Top model actually used downstream (validation predictions and get_fitness):
# an ensemble of ridge regressors trained via the A003_common helper.
# NOTE(review): the meaning of n_members / subspace_proportion / pval_cutoff / normalize
# is defined in A003_common.train_ensembled_ridge, which is not visible here.
EnsembledRidge = A003_common.train_ensembled_ridge(
                training_embedding, 
                training_qfunc, 
                do_sparse_refit=TOP_MODEL_DO_SPARSE_REFIT, 
                n_members=100, 
                subspace_proportion=0.5, 
                pval_cutoff=0.01, 
                normalize=True
            )

In [69]:
# Baseline top model: k-nearest-neighbours trained via the A003_common helper on the
# training embeddings.
# NOTE(review): this model is not referenced again in the visible notebook.
KNN_model = A003_common.cv_train_knn(
                training_embedding, training_qfunc, do_sparse_refit=False)

In [49]:
# TAPE tokenizer configured with the UniRep vocabulary; get_fitness uses it to encode
# amino-acid strings into token ids for the language model.
tokenizer = TAPETokenizer(vocab='unirep')

In [31]:
# Predict target values for the held-out validation embeddings with the ensembled ridge model.
validation_results = EnsembledRidge.predict(validation_embedding)

In [None]:
import seaborn as sns
from scipy import stats

# Compare the top model's validation predictions against the measured values.
# The predictions come from `validation_results` (EnsembledRidge.predict on the
# validation embeddings); the raw embeddings themselves are not what is scored.
validation_actual = np.asarray(validation_qfunc).reshape(-1)
validation_predicted = np.asarray(validation_results).reshape(-1)

# Pearson correlation between measured and predicted values.
_, _, r_value, _, _ = stats.linregress(validation_actual, validation_predicted)

pred_vs_actual_df = pd.DataFrame({
    "actual": validation_actual,
    "predicted": validation_predicted,
})
#--------------------------------------------------#
sns.set_theme(style="darkgrid")

# Use one shared range for both axes, padded by 10% of the combined spread of
# measured and predicted values, so the identity line runs corner to corner.
all_values = np.concatenate((validation_actual, validation_predicted), axis=0)
y_interval = all_values.max() - all_values.min()
x_y_range = (all_values.min() - 0.1*y_interval, all_values.max() + 0.1*y_interval)

g = sns.jointplot(x="actual", y="predicted", data=pred_vs_actual_df,
                kind="reg", truncate=False,
                xlim=x_y_range, ylim=x_y_range,
                color="blue",height=7)

g.fig.suptitle("Predictions vs. Actual Values, R = " + str(round(r_value,3)) , fontsize=18, fontweight='bold')
g.fig.tight_layout()
g.fig.subplots_adjust(top=0.95)
# Hide the marginal histograms; only the joint scatter/regression panel is shown.
g.ax_marg_x.set_axis_off()
g.ax_marg_y.set_axis_off()
g.ax_joint.set_xlabel('Actual Values',fontsize=18 ,fontweight='bold')
g.ax_joint.set_ylabel('Predictions',fontsize=18 ,fontweight='bold');

# MCMC (simulated annealing over sequence space)

In [23]:
# Simulated-annealing configuration.
SIM_ANNEAL_K = 1                      # `k` in exp((f_proposal - f_current) / (k * T))
SIM_ANNEAL_INIT_SEQ_MUT_RADIUS = 3    # mean number of edits applied to each starting sequence
n_chains = 350                        # number of sequences annealed in parallel
# One starting temperature per chain. The length must equal n_chains because the
# temperature array is combined element-wise with the per-chain fitness arrays.
T_max = 0.01*np.ones(n_chains)
GFP_LIB_REGION = [29, 110]            # mutable positions: start inclusive, end exclusive
seed = 9
temp_decay_rate = 1.0                 # 1.0 keeps the temperature constant across iterations
sa_n_iter = 30000
nmut_threshold = 15                   # sequences further than this from the parent score -inf

In [24]:
# Seed both generators: make_n_random_edits draws from Python's `random` (shuffle of
# positions) as well as from NumPy's global RNG (residue choice, Poisson edit counts).
np.random.seed(seed)
random.seed(seed)

In [17]:
def acceptance_prob(f_proposal, f_current, k, T):
    """Metropolis acceptance probability for each proposed move.

    Evaluates ``min(1, exp((f_proposal - f_current) / (k * T)))`` element-wise, so a
    proposal that is at least as fit as the current state is always accepted and a
    worse one is accepted with exponentially decaying probability.

    Args:
        f_proposal: fitness of the proposed state(s); scalar, sequence or array.
        f_current: fitness of the current state(s); broadcastable with f_proposal.
        k: scalar multiplying the temperature.
        T: temperature; scalar or array broadcastable with the fitness inputs.

    Returns:
        Acceptance probabilities with the broadcast shape of the inputs (a NumPy
        scalar for scalar inputs). A NaN difference (e.g. -inf minus -inf) stays NaN.
    """
    delta = np.asarray(f_proposal) - np.asarray(f_current)
    # np.minimum caps at 1 without in-place masking, which also makes scalar inputs work.
    return np.minimum(1, np.exp(delta/(k*np.asarray(T))))

def make_n_random_edits(seq, nedits, alphabet=constants.AA_ALPHABET_STANDARD_ORDER,
        min_pos=None, max_pos=None):
    """Return a copy of ``seq`` with ``nedits`` random substitutions.

    Positions are drawn without replacement from ``[min_pos, max_pos)`` and each
    chosen position is replaced by a symbol of ``alphabet`` different from the
    original one, so every edit is a genuine change.

    Args:
        seq: sequence (string) to mutate.
        nedits: number of substitutions; capped at the number of eligible positions.
        alphabet: symbols to substitute from.
        min_pos: first mutable position, inclusive (defaults to 0).
        max_pos: end of the mutable region, exclusive (defaults to ``len(seq)``).

    Returns:
        The mutated sequence as a string.
    """
    lseq = list(seq)
    # De-duplicate while keeping the alphabet's own order. Iterating a set of strings
    # depends on per-process hash randomisation, which would make the residue picked
    # for a given random index differ between runs even with seeded RNGs.
    lalphabet = list(dict.fromkeys(alphabet))

    if min_pos is None:
        min_pos = 0

    if max_pos is None:
        max_pos = len(seq)

    # Create non-redundant list of positions to mutate.
    positions = list(range(min_pos, max_pos))
    nedits = max(0, min(len(positions), nedits))
    random.shuffle(positions)

    for pos in positions[:nedits]:
        candidates = [aa for aa in lalphabet if aa != seq[pos]]
        lseq[pos] = candidates[np.random.randint(len(candidates))]

    return "".join(lseq)

def propose_seqs(seqs, mu_muts_per_seq, min_pos=None, max_pos=None):
    """Build one mutated proposal for every sequence in ``seqs``.

    The number of edits for sequence ``i`` is ``1 + Poisson(mu_muts_per_seq[i] - 1)``,
    so each proposal carries at least one edit. Edits are confined to
    ``[min_pos, max_pos)`` by ``make_n_random_edits``.

    Returns:
        List of mutated sequences, in the same order as ``seqs``.
    """
    def _mutate(idx, parent):
        edit_count = np.random.poisson(mu_muts_per_seq[idx]-1) + 1
        return make_n_random_edits(parent, edit_count, min_pos=min_pos, max_pos=max_pos)

    return [_mutate(idx, parent) for idx, parent in enumerate(seqs)]


def anneal(
        init_seqs,
        k,
        T_max,
        mu_muts_per_seq,
        get_fitness_fn,
        n_iter=1000,
        decay_rate=0.99,
        min_mut_pos=None,
        max_mut_pos=None):
    """Run parallel simulated-annealing chains over sequences.

    Each iteration proposes one mutated sequence per chain, scores the proposals with
    ``get_fitness_fn`` and accepts or rejects each one independently using
    ``acceptance_prob`` at temperature ``T_max * decay_rate**iteration``.

    Args:
        init_seqs: starting sequence for every chain (not modified).
        k: scaling constant forwarded to ``acceptance_prob``.
        T_max: starting temperature (scalar or one value per chain).
        mu_muts_per_seq: per-chain mean number of edits, forwarded to ``propose_seqs``.
        get_fitness_fn: callable mapping a list of sequences to a
            ``(fitness, fitness_std, member_predictions)`` triple indexable by chain.
        n_iter: number of annealing iterations.
        decay_rate: multiplicative temperature decay applied per iteration.
        min_mut_pos, max_mut_pos: mutable region, forwarded to ``propose_seqs``.

    Returns:
        Dict with the four ``*_history`` lists (initial state plus one deep-copied
        snapshot per iteration, so memory grows linearly with ``n_iter``) and an echo
        of the run parameters.
    """
    print('Initializing')
    current_seqs = copy.deepcopy(init_seqs)
    current_fit, current_std, current_mem = get_fitness_fn(current_seqs)

    history = {
        'seq_history': [copy.deepcopy(current_seqs)],
        'fitness_history': [copy.deepcopy(current_fit)],
        'fitness_std_history': [copy.deepcopy(current_std)],
        'fitness_mem_pred_history': [copy.deepcopy(current_mem)],
    }

    for step in range(n_iter):
        print('Iteration:', step)

        print('\tProposing sequences.')
        candidate_seqs = propose_seqs(
            current_seqs, mu_muts_per_seq, min_pos=min_mut_pos, max_pos=max_mut_pos)

        print('\tCalculating predicted fitness.')
        candidate_fit, candidate_std, candidate_mem = get_fitness_fn(candidate_seqs)

        print('\tMaking acceptance/rejection decisions.')
        temperature = T_max*(decay_rate**step)
        accept_probs = acceptance_prob(candidate_fit, current_fit, k, temperature)

        # One uniform draw per chain. The comparison is kept as `draw < prob` so that a
        # NaN probability compares False and the proposal is rejected.
        for chain, prob in enumerate(accept_probs):
            if np.random.rand() < prob:
                current_seqs[chain] = copy.deepcopy(candidate_seqs[chain])
                current_fit[chain] = copy.deepcopy(candidate_fit[chain])
                current_std[chain] = copy.deepcopy(candidate_std[chain])
                current_mem[chain] = copy.deepcopy(candidate_mem[chain])

        # Snapshot the post-decision state of every chain.
        history['seq_history'].append(copy.deepcopy(current_seqs))
        history['fitness_history'].append(copy.deepcopy(current_fit))
        history['fitness_std_history'].append(copy.deepcopy(current_std))
        history['fitness_mem_pred_history'].append(copy.deepcopy(current_mem))

    history.update({
        'init_seqs': init_seqs,
        'T_max': T_max,
        'mu_muts_per_seq': mu_muts_per_seq,
        'k': k,
        'n_iter': n_iter,
        'decay_rate': decay_rate,
        'min_mut_pos': min_mut_pos,
        'max_mut_pos': max_mut_pos,
    })
    return history


In [18]:
# Starting point of every chain: the parent sequence constants.AVGFP_AA_SEQ with
# 1 + Poisson(SIM_ANNEAL_INIT_SEQ_MUT_RADIUS - 1) random substitutions (see propose_seqs),
# restricted to positions [GFP_LIB_REGION[0], GFP_LIB_REGION[1]).
init_seqs = propose_seqs(
        [constants.AVGFP_AA_SEQ]*n_chains, 
        [SIM_ANNEAL_INIT_SEQ_MUT_RADIUS]*n_chains, 
        min_pos=GFP_LIB_REGION[0], 
        max_pos=GFP_LIB_REGION[1])
# Per-chain mean number of edits per proposal, drawn uniformly from [1, 2.5).
# This is drawn after init_seqs, so the statement order affects the seeded RNG stream.
mu_muts_per_seq = 1.5*np.random.rand(n_chains) + 1
print('mu_muts_per_seq:', mu_muts_per_seq)

mu_muts_per_seq: [1.79368457 1.77927441 1.21903641 ... 1.88690493 1.45501428 2.23954184]


In [29]:
def get_fitness(seqs):
    """Predict fitness for a batch of sequences with the language model + top model.

    Each sequence is tokenized, passed through ``model``, and its per-position outputs
    (excluding the first and last token positions) are averaged into one representation
    vector, which ``EnsembledRidge`` turns into a fitness prediction.

    Sequences whose Levenshtein distance to ``constants.AVGFP_AA_SEQ`` exceeds
    ``nmut_threshold`` are masked: prediction and member predictions become ``-inf``
    and the standard deviation becomes 0.

    Args:
        seqs: sequences of equal length (their token ids are stacked into one batch).

    Returns:
        ``(yhat, yhat_std, yhat_mem)`` as returned by ``EnsembledRidge.predict`` with
        ``return_std=True, return_member_predictions=True``, after masking.
    """
    # Run on whichever device the model currently lives on instead of assuming CUDA,
    # so the function works whether or not the model was moved to a GPU.
    device = next(model.parameters()).device

    tokens = [torch.tensor([tokenizer.encode(seq)]) for seq in seqs]
    # (n_seqs, 1, n_tokens) -> (n_seqs, n_tokens); the token count is taken from the
    # data rather than hard-coded.
    inputs = torch.stack(tokens, dim=0).view(len(tokens), -1).to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(inputs)

    # Mean over sequence positions, dropping the first and last token positions.
    embedding = torch.mean(output[0][:,1:-1,:],1)
    reps = embedding.cpu().detach().numpy()

    yhat, yhat_std, yhat_mem = EnsembledRidge.predict(reps,
            return_std=True, return_member_predictions=True)

    nmut = utils.levenshtein_distance_matrix(
            [constants.AVGFP_AA_SEQ], list(seqs)).reshape(-1)

    # Sequences too far from the parent are never accepted by the annealer.
    mask = nmut > nmut_threshold
    yhat[mask] = -np.inf
    yhat_std[mask] = 0
    yhat_mem[mask,:] = -np.inf

    return yhat, yhat_std, yhat_mem

In [None]:
# Run the annealing chains with the configuration defined above.
# anneal() prints several lines per iteration and keeps a full snapshot of every
# iteration, so expect long output and growing memory use for large sa_n_iter.
sa_results = anneal(
    init_seqs,
    k=SIM_ANNEAL_K,
    T_max=T_max,
    mu_muts_per_seq=mu_muts_per_seq,
    get_fitness_fn=get_fitness,
    n_iter=sa_n_iter,
    decay_rate=temp_decay_rate,
    min_mut_pos=GFP_LIB_REGION[0],
    max_mut_pos=GFP_LIB_REGION[1],
)

In [None]:
# Persist the annealing results. pickle.dump needs a writable binary file object
# (passing a filename string raises TypeError); the context manager also makes sure
# the file is flushed and closed.
with open('result.p', 'wb') as result_file:
    pickle.dump(sa_results, result_file)