In [1]:
#@title Install softwares
import os
import glob
import pathlib
import sys
import time
import traceback
import pickle
import re
from IPython.utils import io
import subprocess
import tqdm.notebook
import urllib3
import gzip


# Detect whether the notebook runs inside Google Colab. Only a failed import
# means "not Colab"; any other error should surface instead of being
# swallowed by a bare `except`.
try:
  from google.colab import files
  IN_COLAB = True
except ImportError:
  IN_COLAB = False

# Progress-bar layout passed as `bar_format` to the tqdm bars in this notebook.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

# External resources. NOTE(review): none of these three are referenced by the
# cells visible in this notebook export.
JAX_UNIREP_REPO='https://github.com/ElArkk/jax-unirep.git'
DOWNLOAD_PFAM_SCRIPT='https://raw.githubusercontent.com/xqding/PEVAE_Paper/master/pfam_msa/script/download_MSA.py'
ENTREZ_URL='ftp://ftp.ncbi.nlm.nih.gov/entrez/entrezdirect//install-edirect.sh'

# Everything below is laid out relative to the current working directory.
ROOT_DIR='.'
TMP_DIR = f"{ROOT_DIR}/tmp"

# Scratch sub-directories under TMP_DIR.
SCRIPTS_DIR=f'{TMP_DIR}/scripts'
FLAG_DIR=f'{TMP_DIR}/flag'
WEIGHTS_DIR=f'{TMP_DIR}/weights'

JOBS_DIR=f'{ROOT_DIR}/jobs'

# Result sub-directories; all of them are created by the setup loop below.
RES_DIR=f'{ROOT_DIR}/results'
RES_DIR_MSA=f"{RES_DIR}/MSA"
RES_DIR_BLAST=f"{RES_DIR}/MSA/blast"
RES_DIR_FITNESS=f"{RES_DIR}/fitness"
RES_DIR_SEQ_OUT=f"{RES_DIR}/sequence_out"
RES_DIR_PICKLE=f'{RES_DIR}/pickle'
RES_DIR_REPORT=f'{RES_DIR}/report'
RES_DIR_FEATURE=f'{RES_DIR}/feature'
RES_DIR_FIGURE=f'{RES_DIR}/figure'
# Directories prepended to PATH by the setup loop below.
pathes=['/usr/local/cuda-11.4/bin']
# Only referenced by the disabled LD_LIBRARY_PATH block in the setup loop.
libpathes=['/usr/local/cuda-11.4/lib64']

# NOTE(review): hard-coded absolute path, not used in the visible cells.
CONDA_PATH='/opt/anaconda3'

# Prepare the working tree and the process environment in three tracked steps.
total = 3
with tqdm.notebook.tqdm(total=total, bar_format=TQDM_BAR_FORMAT) as pbar:
  # Step 1: create every scratch/result directory (no-op when it already exists).
  # The loop variable is deliberately not called `dir`, which would shadow the builtin.
  for directory in [ROOT_DIR,
                    TMP_DIR,FLAG_DIR,WEIGHTS_DIR,SCRIPTS_DIR,
                    JOBS_DIR,
                    RES_DIR,RES_DIR_MSA,RES_DIR_BLAST,RES_DIR_FITNESS,RES_DIR_SEQ_OUT,RES_DIR_PICKLE,RES_DIR_REPORT,RES_DIR_FEATURE,RES_DIR_FIGURE]:
    os.makedirs(directory, exist_ok=True)
  pbar.update(1)

  # Step 2: prepend the tool directories to PATH. Membership is tested against
  # the individual PATH entries rather than as a substring of the whole string,
  # so e.g. ".../bin" is not considered present just because ".../bin2" is.
  path_entries = os.environ.get('PATH', '').split(os.pathsep)
  for path in pathes:
    resolved = str(pathlib.Path(path).resolve())
    if resolved not in path_entries:
      path_entries.insert(0, resolved)
  os.environ['PATH'] = os.pathsep.join(path_entries)
  # Disabled: the same treatment for LD_LIBRARY_PATH.
  # for path in libpathes:
  #   if f"{pathlib.Path(path).resolve()}" not in os.environ['LD_LIBRARY_PATH']:
  #     os.environ['LD_LIBRARY_PATH'] = f"{pathlib.Path(path).resolve()}:{os.environ['LD_LIBRARY_PATH'] }"
  pbar.update(1)

  # Step 3: package installation is disabled; the step only advances the bar.
  # for p in ["tqdm","jax-unirep","biopython","awscli","optuna","seaborn","python-Levenshtein","feather-format"]:
  #   os.system(f'pip install {p}')
  pbar.update(1)
  

  0%|          | 0/3 [elapsed: 00:00 remaining: ?]

In [2]:
#@title Functions used

from jax_unirep import get_reps, fit
from jax_unirep.utils import load_params
from Bio import SeqIO
import pandas as pd
import glob
import os
import numpy as np

from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split, KFold
from sklearn.linear_model import RidgeCV, LinearRegression, HuberRegressor
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsRegressor

import warnings
warnings.filterwarnings('ignore') 

from sklearn.preprocessing import normalize, StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
import feather


# read FASTA file:
# input: file name
# output: names and sequences in the file as an array of dim-2 arrays [name, sequence].
def read_fasta(name):
    """Read a sequence file and return its records as ``[id, sequence]`` pairs.

    The Biopython format name is derived from the file extension with the dots
    removed (``x.fasta`` -> ``"fasta"``), so the extension must spell a format
    name that ``SeqIO.parse`` understands.

    Args:
        name: path of the file to read, as ``str`` or ``pathlib.Path``.

    Returns:
        A list of ``[record id, stripped sequence string]`` lists, in file order.
    """
    # Accept plain strings too: the original required a Path for `.suffix`.
    path = pathlib.Path(name)
    file_format = path.suffix.replace(".", '')
    # The records are materialised inside the `with` so the handle is closed
    # before returning (SeqIO.parse is lazy).
    with open(path) as handle:
        return [[record.id, str(record.seq).strip()]
                for record in SeqIO.parse(handle, file_format)]


# write FASTA file:
# input: file name; df to write
def write_fasta(name, seqs_df):
    """Write the rows of ``seqs_df`` as FASTA records.

    The output goes to ``{RES_DIR_SEQ_OUT}/{name}_out.fasta``; each row becomes
    a ``>name`` header line followed by its sequence line.

    Args:
        name: stem of the output file name.
        seqs_df: dataframe with ``name`` and ``sequence`` columns.
    """
    # Iterate the columns directly so rows are written in order whatever the
    # dataframe index looks like; `with` closes the file even on error.
    with open(f'{RES_DIR_SEQ_OUT}/{name}_out.fasta', "w") as out_file:
        for seq_name, sequence in zip(seqs_df['name'], seqs_df['sequence']):
            out_file.write('>' + seq_name + '\n')
            out_file.write(sequence + '\n')


# input: takes in a sequence
# output: True if a sequence contains only standard amino acids, returns False if contains non-standard ones.
def validate(seq, pattern=re.compile(r'^[FIWLVMYCATHGSQRKNEPD]+$')):
    """Tell whether ``seq`` matches ``pattern``.

    With the default pattern this is True for a non-empty run of the twenty
    listed one-letter residue codes and False otherwise.
    """
    return pattern.match(seq) is not None

    
# Remove sequences longer than k residues and with non-standard residues
# inputs: seqs_df = dataframe of name, sequence; k = max length (in residues) to keep
# output: cleaned dataframe
def clean(seqs_df, k):
    """Filter a sequence dataframe by length, alphabet and uniqueness.

    Three filters run in order, each followed by an index reset and a
    progress message: drop sequences longer than ``k``, drop sequences
    rejected by ``validate``, then drop duplicate sequences (first kept).

    Args:
        seqs_df: dataframe with a ``sequence`` column of strings.
        k: maximum sequence length to keep (inclusive).

    Returns:
        A new dataframe with a fresh 0..n-1 index.
    """
    # Boolean masks work for any index, unlike positional `sequence[i]` lookups
    # which silently assume a default RangeIndex.
    too_long = seqs_df['sequence'].map(len) > k

    print('Total number of sequences dropped due to length >', k, ':', int(too_long.sum()))

    seqs_df = seqs_df[~too_long].reset_index(drop=True)

    print('Total number of sequences remaining:', len(seqs_df))

    # remove sequences with invalid AA residues
    is_valid = seqs_df['sequence'].map(validate).astype(bool)

    print('Total number of invalid sequences dropped:', int((~is_valid).sum()))

    seqs_df = seqs_df[is_valid].reset_index(drop=True)

    print('Total number of valid sequences remaining:', len(seqs_df))

    seqs_df = seqs_df.drop_duplicates(subset='sequence').reset_index(drop=True)

    print('Total sequences remaining after duplicate removal', len(seqs_df))

    return seqs_df


# calculate the Levenshtein distance of multiple sequences to a target sequence
# also plots a histogram of distances
# inputs: t_seq = target sequence; seqs_df = dataframe of sequences;
# num_bins = bins for histogram; hist_range = range for histogram
# outputs: numpy array of distances
def lev_dist(t_seq, seqs_df, num_bins=20, hist_range=(0,350)):
    """Compute the distance of every sequence in ``seqs_df`` to ``t_seq``.

    Prints the mean, median, minimum and maximum distance and shows a
    histogram of the distances.

    NOTE(review): ``distance`` is not defined or imported anywhere in the
    visible notebook; it must be in scope before this function is called.

    Args:
        t_seq: target sequence.
        seqs_df: dataframe with a ``sequence`` column.
        num_bins: number of histogram bins.
        hist_range: ``(low, high)`` range of the histogram.

    Returns:
        numpy array with one distance per row of ``seqs_df``.
    """
    distances = np.array([distance(t_seq, seq) for seq in seqs_df['sequence']])

    mean_dist = np.mean(distances)
    median_dist = np.median(distances)
    min_dist = np.min(distances)
    max_dist = np.max(distances)

    print("Mean Levenstein distance:", mean_dist)
    # Fixed: this line used to print the mean a second time.
    print("Median Levenstein distance:", median_dist)
    print("Min Levenstein distance:", min_dist)
    print("Max Levenstein distance:", max_dist)

    # histogram of distances from the target sequence
    plt.clf()
    plt.hist(distances, bins=num_bins, range=hist_range)
    plt.show()

    return distances



"""
File formatting note.
Data should be preprocessed as sequences of comma-separated ints, with
sequences separated by newlines.
"""

# Lookup tables
# Residue letter -> integer code. The 22 unambiguous letters receive the codes
# 1..22 in the order of the string below ('O', code 22, is pyrrolysine).
aa_to_int = dict(zip("MRHKDESTNQCUGPAVIFYWLO", range(1, 23)))
aa_to_int.update({
    'X': 23,  # Unknown
    'Z': 23,  # Glutamic acid or glutamine
    'B': 23,  # Asparagine or aspartic acid
    'J': 23,  # Leucine or isoleucine
    'start': 24,
    'stop': 25,
})

# Integer code -> residue letter. Four letters share code 23; because the table
# is inverted in insertion order, the last of them ('J') is the one kept.
int_to_aa = {code: residue for residue, code in aa_to_int.items()}


def get_aa_to_int():
    """Return the residue -> integer lookup table (the shared dict itself)."""
    return aa_to_int


def get_int_to_aa():
    """Return the integer -> residue lookup table (the shared dict itself)."""
    return int_to_aa
    
def aa_seq_to_int(s):
    """
    Encode an amino-acid string as a list of ints, wrapped in the start (24)
    and stop (25) codes.

    NOTE(review): a second ``aa_seq_to_int`` defined further down in this
    notebook replaces this one once that later definition has been executed.
    """
    encoded = [24]
    encoded.extend(aa_to_int[a] for a in s)
    encoded.append(25)
    return encoded

def int_seq_to_aa(s):
    """
    Return the amino-acid string for a given sequence of integer codes,
    obtained by concatenating the ``int_to_aa`` entry of every code.
    """
    return "".join(map(int_to_aa.__getitem__, s))

    
def format_seq(seq,stop=False):
    """
    Takes an amino acid sequence, returns a list of integers in the codex of the babbler.

    The result always begins with the start code. The stop code is appended
    only when ``stop`` is True; the default (``stop=False``) leaves it off,
    which is what you want when generating a representation.

    The encoding is done here directly from ``aa_to_int`` instead of going
    through ``aa_seq_to_int``: that name is defined twice in this notebook and
    the later definition adds no start/stop codes, which made this function
    omit the start code and drop the final residue.

    Args:
        seq: amino-acid string; surrounding whitespace is stripped.
        stop: whether to append the stop code.

    Returns:
        List of integer codes.

    Raises:
        KeyError: if ``seq`` contains a character missing from ``aa_to_int``.
    """
    int_seq = [aa_to_int['start']]
    int_seq.extend(aa_to_int[a] for a in seq.strip())
    if stop:
        int_seq.append(aa_to_int['stop'])
    return int_seq

def is_valid_seq(seq, max_len=2000):
    """
    Check a sequence for use with the babbler: it must be strictly shorter
    than ``max_len`` and consist only of the 22 accepted residue letters.
    """
    return len(seq) < max_len and set(seq).issubset("MRHKDESTNQCUGPAVIFYWLO")

def fasta_to_input(in_path, max_len=275):
    """Convert ``<in_path>.fasta.txt`` into the comma-separated int format.

    Every record that passes ``is_valid_seq`` and is shorter than ``max_len``
    is encoded with ``format_seq`` and written as one line of comma-separated
    ints to ``<in_path>_formatted.fasta.txt``.

    Args:
        in_path: path prefix shared by the input and the output file.
        max_len: exclusive upper bound on the sequence length (previously the
            hard-coded value 275).
    """
    records = SeqIO.parse(in_path + '.fasta.txt','fasta')
    with open(in_path + "_formatted.fasta.txt", "w") as destination:
        # Fixed: the loop used to iterate an undefined name (`fasta_seqs`) and
        # call `.strip()` on the record object instead of its sequence string.
        for record in records:
            seq = str(record.seq).strip()
            if is_valid_seq(seq) and len(seq) < max_len:
                formatted = ",".join(map(str,format_seq(seq)))
                destination.write(formatted)
                destination.write('\n')

def seqs_to_input(name, in_seqs, stop=False):
    """Write the valid sequences of ``in_seqs`` in comma-separated int format.

    The file ``outputs/<name>_formatted.fasta.txt`` receives one line per
    sequence accepted by ``is_valid_seq``; ``stop`` is forwarded to
    ``format_seq``.
    """
    out_path = 'outputs/' + name + "_formatted.fasta.txt"
    with open(out_path, "w") as destination:
        for raw_seq in in_seqs:
            cleaned = raw_seq.strip()
            if not is_valid_seq(cleaned):
                continue
            codes = format_seq(cleaned, stop=stop)
            destination.write(",".join(str(code) for code in codes))
            destination.write('\n')
                
                

def read_labeled_data(fitness_csv,seq_col,fit_col):
  """Load (sequence, fitness) pairs from a CSV file.

  Only the two named columns are read. Each sequence is converted to ``str``
  and stripped of surrounding whitespace; fitness values are passed through.

  Returns:
    A list of ``[sequence, fitness]`` lists in row order.
  """
  table = pd.read_csv(pathlib.Path(fitness_csv), usecols=[seq_col, fit_col])
  return [[str(seq).strip(), fitness]
          for seq, fitness in zip(table.loc[:, seq_col], table.loc[:, fit_col])]
# data=read_labeled_data("./P450_experi_data_encoded.csv",'Seq','11H-Cuol')


# save represented dataframe of features as feather
def save_reps(df, path):
  """Write ``df`` to ``<path>.feather`` and print a confirmation line."""
  target = path + '.feather'
  feather.write_dataframe(df, target)
  print(target, 'saved!')


# read represented dataframe of features as feather
def read_reps(path):
  """Load the dataframe stored at ``<path>.feather``."""
  feather_file = path + '.feather'
  return feather.read_dataframe(feather_file)


def _one_hot(x, k, dtype=np.float32):
  """Boolean one-hot encoding of ``x`` over the classes ``0..k-1``.

  Entry ``[i, j]`` of the result is True exactly when ``x[i] == j``.
  ``dtype`` is accepted but not applied: the result stays boolean.
  """
  class_ids = np.arange(k)
  matches = x[:, None] == class_ids
  return np.array(matches)


def aa_seq_to_int(s):
  """Encode a string of amino acids as a list of ints (no start/stop codes).

  NOTE(review): this redefines an earlier ``aa_seq_to_int`` in this notebook
  that wrapped the result in the start/stop codes.

  Raises:
    ValueError: if ``s`` contains a character that is not a key of ``aa_to_int``.
  """
  # Reject anything outside the lookup table before encoding.
  unsupported = set(s).difference(set(aa_to_int.keys()))
  if unsupported:
    raise ValueError(
      f"Unsupported character(s) in sequence found:"
      f" {unsupported}"
    )

  return list(map(aa_to_int.__getitem__, s))


def aa_seq_to_onehot(seq, n_classes=21):
  """Flattened one-hot encoding of an amino-acid sequence.

  Each residue code from ``aa_seq_to_int`` is compared against
  ``0..n_classes-1``; the resulting ``(len(seq), n_classes)`` 0/1 matrix is
  flattened row by row.

  NOTE(review): the codes in ``aa_to_int`` start at 1, so column 0 is never
  set, and with the default ``n_classes=21`` any code >= 21 (which includes
  'L', code 21) yields an all-zero row. The default is kept so existing
  feature files stay comparable; pass a larger ``n_classes`` to cover them.

  Args:
    seq: amino-acid string.
    n_classes: number of one-hot columns per residue (previously hard-coded 21).

  Returns:
    1-D integer numpy array of length ``len(seq) * n_classes``.
  """
  codes = np.array(aa_seq_to_int(seq))
  return 1*np.equal(codes[:,None], np.arange(n_classes)).flatten()
  

def multi_onehot(seqs):
  """Stack the one-hot vectors of all sequences in ``seqs`` into one array.

  ``seqs`` must offer ``.tolist()``; one row is produced per sequence.
  """
  encoded = []
  for s in seqs.tolist():
    encoded.append(aa_seq_to_onehot(s))
  return np.stack(encoded)


def distance_matrix(N):
    """Build the ``N x N`` matrix with entries ``1 - |i - j| / N``.

    The diagonal is 1 and values decrease linearly with the distance between
    the row and column index.

    The original allocation (``np.zeros((N, N))``) had been swallowed into a
    comment, so the loop tried to assign into the function object itself and
    raised ``TypeError``; the matrix is now built directly with broadcasting.

    Args:
        N: number of rows/columns.

    Returns:
        Float numpy array of shape ``(N, N)``.
    """
    positions = np.arange(N)
    # alternative weighting kept from the original: 1 - ((abs(i-j)/N)**2)
    return 1 - np.abs(positions[:, None] - positions[None, :]) / N


def confusion_matrix_loss(Y_test,Y_preds_test):
  """Fraction of ordered pairs whose ranking differs between truth and prediction.

  For every ordered pair ``(i, j)`` the indicator ``y[i] > y[j]`` is computed
  for the true values and for the predictions; the loss is the share of pairs
  on which the two indicators disagree.

  The pairwise comparison is done with broadcasting instead of a double Python
  loop, and the inputs are converted to flat arrays first so that a pandas
  Series with a non-default index is handled positionally.

  Args:
    Y_test: true values (sequence, array or Series).
    Y_preds_test: predicted values, same length as ``Y_test``.

  Returns:
    The mismatch fraction in ``[0, 1]`` as a float.

  Raises:
    ValueError: if the two inputs differ in length.
  """
  y_true = np.asarray(Y_test).ravel()
  y_pred = np.asarray(Y_preds_test).ravel()
  if y_true.shape != y_pred.shape:
    raise ValueError(
      f"Y_test and Y_preds_test must have the same length, got {y_true.size} and {y_pred.size}"
    )

  # [i, j] is True when element i ranks strictly above element j.
  Y_rank_matrix = y_true[:, None] > y_true[None, :]
  Y_preds_rank_matrix = y_pred[:, None] > y_pred[None, :]

  confusion_matrix = Y_preds_rank_matrix != Y_rank_matrix
  # dist_mat = distance_matrix(N)
  # confusion_matrix = confusion_matrix*dist_mat
  loss = np.sum(confusion_matrix)/confusion_matrix.size

  return loss



In [3]:
#@title  load labeled training data
# Colab form fields: CSV column holding the sequences and column holding the
# fitness label used throughout the notebook.
col_seq='Seq' #@param {type:"string"}
col_fitness='11H-Cuol' #@param {type:"string"}
# Labeled data as a two-column frame ('sequence', 'fitness') read from the fitness CSV.
seqs_df = pd.DataFrame(read_labeled_data(f"{RES_DIR_FITNESS}/P450_experi_data_encoded.csv",col_seq,col_fitness), columns = ['sequence', 'fitness'])

In [4]:
#@title  define hyper params 
# NOTE(review): PROJECT_NAME is reassigned by a later cell (before the pickle
# dump); DIR_PATH below keeps the value computed from this first assignment.
PROJECT_NAME='evotuning_P450_from'

# Despite its name, gdrive_path points at the local pickle results directory.
gdrive_path=f'{RES_DIR_PICKLE}/'
DIR_PATH = f'{RES_DIR_FEATURE}/{PROJECT_NAME}'
#os.system(f'mkdir -p {pathlib.Path(DIR_PATH).resolve()}')

# NOTE(review): the order here (global, randinit, None, one_hot) differs from
# the order of `dfs`/`df_names` below (unirep, eunirep, one hot, rand_eunirep);
# confirm before pairing these lists by position.
PARAMS = [ '_global/iter_0','_randinit/iter_0',None, 'one_hot',]

FEATHER_PATH = gdrive_path 

# Precomputed representation tables, one feather file per representation.
unirep_df = read_reps(FEATHER_PATH + 'unirep')
eunirep_df = read_reps(FEATHER_PATH + 'evotuning_P450_from-_global_iter_0')
rand_eunirep_df=read_reps(FEATHER_PATH + 'evotuning_P450_from-_randinit_iter_0')
onehot_df = read_reps(FEATHER_PATH + 'one_hot')

# Parallel lists: dfs[i] is described by df_names[i].
dfs = [unirep_df, eunirep_df, onehot_df,rand_eunirep_df]
df_names = ['unirep', 'eunirep', 'one hot','rand_eunirep']

In [5]:
# Mutation string of the variant used as the starting point below.
top_variant='I46L-L48F-S49L-I61F-L120T-C343Y-T352I-L356P'

# Full fitness table (all columns), unlike `read_labeled_data` which keeps two.
fitness_table_df = pd.read_csv(f"{RES_DIR_FITNESS}/P450_experi_data_encoded.csv")

# Sequence of the row whose 'variants' value equals `top_variant`. Despite the
# "wt" in its name this is the top variant's sequence, not the 'WT' row.
# Raises IndexError when no row matches.
s_wt = fitness_table_df[(fitness_table_df.loc[:,'variants']==top_variant)].Seq.tolist()[0]

DE_record_folder = f'{RES_DIR_SEQ_OUT}/P450_DE_test_100_from_top' # folder holding the *_seqs.txt / *_fitness.txt trajectory files


In [6]:
s_wt

'MWTILLGLATLAIAYYIHWVNKWKDSKFNGVLPPGTMGLPLIGETLQFLRPSDSLDVHPFFQRKVKRYGPIFKTCLAGRPVVVSTDAEFNHYIMLQEGRAVEMWYLDTLSKFFGLDTEWTKALGLIHKYIRSITLNHFGAESLRERFLPRIEESARETLHYWSTQTSVEVKESAAAMVFRTSIVKMFSEDSSKLLTEGLTKKFTGLLGGFLTLPLNLPGTTYHKCIKDMKQIQKKLKDILEERLAKGVKIDEDFLGQAIKDKESQQFISEEFIIQLLFSISFASFESISTTLTLILNFLADHPDVVKELEAEHEAIRKARADPDGPITWEEYKSMNFTLNVIYETLRLGSVIPALPRKTTKEIQIKGYTIPEGWTVMLVTASRHRDPEVYKDPDTFNPWRWKELDSITIQKNFMPFGGGLRHCAGAEYSKVYLCTFLHILFTKYRWRKLKGGKIARAHILRFEDGLYVNFTPKE'

In [7]:

def inspect_variant_name(wt,variant,prefix=None):
    """Name a variant by listing its substitutions relative to ``wt``.

    Every differing position contributes ``<wt residue><1-based position><new
    residue>``; the pieces are joined with ``_``. When ``prefix`` is given the
    result is ``<prefix>_<name>`` (so an unchanged sequence yields ``<prefix>_``).
    Both sequences must have the same length (checked with ``assert``).
    """
    assert len(wt)==len(variant)
    substitutions = [
        f'{wt_aa}{position}{variant_aa}'
        for position, (wt_aa, variant_aa) in enumerate(zip(wt, variant), start=1)
        if wt_aa != variant_aa
    ]
    variant_name = "_".join(substitutions)

    if prefix is None:
        return variant_name
    return f'{prefix}_{variant_name}'


def parse_DE_results(DE_record_folder):
    """Collect every sequence/fitness pair recorded in a trajectory folder.

    For each ``*_seqs.txt`` file the matching ``*_fitness.txt`` file is read.
    The first two lines of the sequence file and the first line of the fitness
    file are skipped, and the remaining lines are paired up in order.

    Each sequence is named by its substitutions relative to the global
    ``s_wt`` with the prefix ``top``. Names are dict keys, so a sequence that
    occurs more than once keeps only its last recorded fitness.

    Args:
        DE_record_folder: directory containing the trajectory files.

    Returns:
        Dict mapping variant name to ``{'seq': str, 'fitness': float}``.
    """
    trajectory_files=[x for x in os.listdir(DE_record_folder) if x.endswith('_seqs.txt')]
    DE_sequences={}
    # Context managers close the progress bar and both files even on error.
    with tqdm.notebook.tqdm(total=len(trajectory_files), bar_format=TQDM_BAR_FORMAT) as pbar:
        for trajectory_file in trajectory_files:
            pbar.update(1)
            fitness_file=trajectory_file.replace('_seqs.txt','_fitness.txt')
            with open(f'{DE_record_folder}/{trajectory_file}','r') as seq_handle:
                seqs=seq_handle.readlines()[2:]
            with open(f'{DE_record_folder}/{fitness_file}','r') as fitness_handle:
                fitnesses=fitness_handle.readlines()[1:]
            assert len(seqs)==len(fitnesses)
            for seq,fitness in zip(seqs,fitnesses):
                seq=seq.strip()
                variant_name=inspect_variant_name(s_wt,seq,prefix='top')
                DE_sequences[variant_name]={'seq':seq,
                                           'fitness':float(fitness.strip())}
    return DE_sequences
    

In [8]:
# Parse all trajectories; variant names are computed relative to the global `s_wt`.
DE_sequences=parse_DE_results(DE_record_folder)

  0%|          | 0/10002 [elapsed: 00:00 remaining: ?]

In [70]:
# 'top_' is the name given to a sequence identical to `s_wt` (prefix 'top', no
# substitutions), so this is the recorded fitness of the starting sequence.
top_fitness=DE_sequences['top_']['fitness']
# Keep only the designs whose fitness is strictly above the starting one.
improved_design={name:record for name,record in DE_sequences.items() if record['fitness']>top_fitness}


In [71]:
improved_design

{'top_N216H_Y222T': {'seq': 'MWTILLGLATLAIAYYIHWVNKWKDSKFNGVLPPGTMGLPLIGETLQFLRPSDSLDVHPFFQRKVKRYGPIFKTCLAGRPVVVSTDAEFNHYIMLQEGRAVEMWYLDTLSKFFGLDTEWTKALGLIHKYIRSITLNHFGAESLRERFLPRIEESARETLHYWSTQTSVEVKESAAAMVFRTSIVKMFSEDSSKLLTEGLTKKFTGLLGGFLTLPLHLPGTTTHKCIKDMKQIQKKLKDILEERLAKGVKIDEDFLGQAIKDKESQQFISEEFIIQLLFSISFASFESISTTLTLILNFLADHPDVVKELEAEHEAIRKARADPDGPITWEEYKSMNFTLNVIYETLRLGSVIPALPRKTTKEIQIKGYTIPEGWTVMLVTASRHRDPEVYKDPDTFNPWRWKELDSITIQKNFMPFGGGLRHCAGAEYSKVYLCTFLHILFTKYRWRKLKGGKIARAHILRFEDGLYVNFTPKE',
  'fitness': 0.20600366592407227},
 'top_N216H_G219Y_Y222T_K224S_M229Q_K235V': {'seq': 'MWTILLGLATLAIAYYIHWVNKWKDSKFNGVLPPGTMGLPLIGETLQFLRPSDSLDVHPFFQRKVKRYGPIFKTCLAGRPVVVSTDAEFNHYIMLQEGRAVEMWYLDTLSKFFGLDTEWTKALGLIHKYIRSITLNHFGAESLRERFLPRIEESARETLHYWSTQTSVEVKESAAAMVFRTSIVKMFSEDSSKLLTEGLTKKFTGLLGGFLTLPLHLPYTTTHSCIKDQKQIQKVLKDILEERLAKGVKIDEDFLGQAIKDKESQQFISEEFIIQLLFSISFASFESISTTLTLILNFLADHPDVVKELEAEHEAIRKARADPDGPITWEEYKSMNFTLNVIYETLRLGSVIPALPRKTTKEIQIKGYTIPEGWTVMLVTASRHRDPEVYKDPDTFNPWRWKELDS

In [72]:
import pickle
# NOTE(review): this overwrites the PROJECT_NAME set in the hyper-parameter
# cell; values derived from the earlier name (e.g. DIR_PATH) are not updated.
PROJECT_NAME='P450_Low-N_890-11H-Cuol'
# Persist every parsed trajectory entry. The `with` block closes the file so
# the pickle is fully flushed to disk before the next cell runs.
with open(f'{RES_DIR_PICKLE}/{PROJECT_NAME}_all_trajectory.pkl','wb') as pickle_file:
    pickle.dump(DE_sequences,pickle_file)

In [75]:

# Cutoff = mean fitness of the improved designs minus the starting fitness.
# Raises ZeroDivisionError when `improved_design` is empty.
fitness_jump_cutoff=sum(record['fitness'] for record in improved_design.values())/len(improved_design)-top_fitness

In [76]:
fitness_jump_cutoff

0.03396023819120686

In [9]:
def parse_DE_fitness_jump(DE_record_folder,fitness_jump_cutoff,fitness_baseline):
    """Find single trajectory steps that raise the fitness by a large margin.

    Trajectory files are read as in ``parse_DE_results`` (first two sequence
    lines and first fitness line skipped). A step from entry ``i-1`` to entry
    ``i`` is reported when the fitness gain exceeds ``fitness_jump_cutoff`` and
    the fitness before the step exceeds ``fitness_baseline``.

    Steps are named by their substitutions relative to the previous sequence,
    prefixed with ``jump``. Names are dict keys, so identical substitution sets
    from different steps overwrite each other (last one wins).

    Args:
        DE_record_folder: directory containing the trajectory files.
        fitness_jump_cutoff: minimum fitness gain of a reported step (exclusive).
        fitness_baseline: minimum fitness before the step (exclusive).

    Returns:
        Dict mapping jump name to fitness gain.
    """
    trajectory_files=[x for x in os.listdir(DE_record_folder) if x.endswith('_seqs.txt')]
    jumps={}
    # Context managers close the progress bar and both files even on error.
    with tqdm.notebook.tqdm(total=len(trajectory_files), bar_format=TQDM_BAR_FORMAT) as pbar:
        for trajectory_file in trajectory_files:
            pbar.update(1)
            fitness_file=trajectory_file.replace('_seqs.txt','_fitness.txt')
            # Strip line endings so a missing newline on the last line cannot
            # cause a length mismatch in inspect_variant_name.
            with open(f'{DE_record_folder}/{trajectory_file}','r') as seq_handle:
                seqs=[line.strip() for line in seq_handle.readlines()[2:]]
            # Parse every fitness once instead of three times per step.
            with open(f'{DE_record_folder}/{fitness_file}','r') as fitness_handle:
                fitnesses=[float(line.strip()) for line in fitness_handle.readlines()[1:]]
            assert len(seqs)==len(fitnesses)
            # NOTE(review): the scan starts at index 2, so the step from entry 0
            # to entry 1 is never examined - kept as in the original; confirm
            # whether that is intended.
            for i in range(2,len(seqs)):
                previous_fitness=fitnesses[i-1]
                gain=fitnesses[i]-previous_fitness
                if gain>fitness_jump_cutoff and previous_fitness>fitness_baseline:
                    jumps[inspect_variant_name(seqs[i-1],seqs[i],prefix='jump')]=gain
    return jumps
    

In [92]:
# Steps gaining more than the cutoff, starting from a fitness above `top_fitness`.
jumps=parse_DE_fitness_jump(DE_record_folder,fitness_jump_cutoff,top_fitness)

  0%|          | 0/10002 [elapsed: 00:00 remaining: ?]

In [93]:
jumps

{'jump_H91I_E102D_K103F': 0.04995298385620117,
 'jump_N340M_R347F': 0.048888206481933594,
 'jump_W104E': 0.04612159729003906,
 'jump_L331H_Y332H_L339E_L346N': 0.06214284896850586,
 'jump_Q128G_A140E': 0.06411099433898926,
 'jump_S96E_L106S': 0.06299424171447754,
 'jump_E284G': 0.06506609916687012,
 'jump_L346E_P356T': 0.052550315856933594,
 'jump_C114W_P125R_I126G_H127K': 0.053603410720825195,
 'jump_G139Q_E145S': 0.08072280883789062,
 'jump_V341G_G349F_V351L_A354T': 0.042560577392578125,
 'jump_K249N_Y250N_A256H_C257Y': 0.04041290283203125,
 'jump_K77S_P84F': 0.06647825241088867,
 'jump_R347A_G349H': 0.035164594650268555,
 'jump_H349S': 0.034276723861694336,
 'jump_R347N': 0.05757284164428711,
 'jump_I202T_C205Q': 0.0379948616027832,
 'jump_L276R_L277W': 0.03583526611328125,
 'jump_R357M_A369H': 0.04451584815979004,
 'jump_P99A_I101Y_W104H': 0.061484336853027344,
 'jump_T345P': 0.03979349136352539,
 'jump_W337S_K345M': 0.03856229782104492,
 'jump_G325T_I327Q_W329S_H339T': 0.0361030101

In [94]:
len(jumps)

721

In [96]:
import random
# The twenty one-letter residue codes.
alphabet='ARNDCQEGHILKMFPSTWYV'
# Same alphabet with 'I' removed.
# NOTE(review): neither name is used in the visible cells; the reason for excluding 'I' is not shown here.
alphabet_excluded=alphabet.replace('I','')


In [51]:
def predict_fitness(seq,param,Model):
    """Predict the fitness of one sequence with a fitted top model.

    The sequence is embedded with ``get_reps`` using the supplied parameters
    and the first returned array is fed to ``Model.predict``.

    Args:
        seq: amino-acid sequence.
        param: representation parameters forwarded to ``get_reps(params=...)``.
        Model: fitted estimator exposing ``predict``.

    Returns:
        Whatever ``Model.predict`` returns for the single embedded sequence.
    """
    # Fixed: the body used the global `params` and silently ignored `param`.
    x,_,_ = get_reps([seq],params=param)
    y = Model.predict(x)
    return y

In [34]:


def inspect_variant_sequence(wt_seq,mutations,sep='-'):
    """Apply a mutation string such as ``"A12G-C40T"`` to a wild-type sequence.

    Each token is ``<original residue><1-based position><new residue>``. The
    original residue is checked against the current sequence before it is
    replaced. Blank tokens (empty string, trailing separator, surrounding
    whitespace) are ignored, so an empty mutation string returns ``wt_seq``
    unchanged.

    Args:
        wt_seq: sequence the mutations refer to.
        mutations: mutation tokens joined by ``sep``.
        sep: token separator.

    Returns:
        The mutated sequence as a string.

    Raises:
        ValueError: if a position is not a positive integer or the stated
            original residue does not match the sequence.
        IndexError: if a position lies beyond the end of the sequence.
    """
    mutant_seq=list(wt_seq)
    for mutation in mutations.split(sep):
        mutation=mutation.strip()
        if not mutation:
            continue
        orginal_aa=mutation[0]
        mutated_aa=mutation[-1]
        mutate_loc=int(mutation[1:-1])
        # Positions are 1-based; 0 or a negative value would otherwise wrap
        # around and silently address residues from the end of the sequence.
        if mutate_loc<1:
            raise ValueError(f'mutation {mutation}: position must be >= 1')
        if mutant_seq[mutate_loc-1]==orginal_aa:
            mutant_seq[mutate_loc-1]=mutated_aa
        else:
            raise ValueError(f'orginal_aa {mutate_loc-1} {mutant_seq[mutate_loc-1]} != {orginal_aa}')
    return ''.join(mutant_seq)

In [48]:
def init_by_rep(df, alpha_val, N, param_file):
  """Load representation parameters and fit the top model.

  ``param_file`` is handed to ``load_params`` as is (``None`` included) and
  the second element of its result is kept as the parameters. The top model is
  fitted by ``get_top_model`` on ``df`` with the single alpha ``alpha_val`` and
  ``N`` training rows.

  Returns:
    Tuple ``(params, DE_model)``.
  """
  # Both original branches differed only in the literal passed to load_params.
  params = load_params(param_file)[1]
  DE_model = get_top_model(df, [alpha_val], N)
  return params, DE_model


def read_labeled_data(fitness_csv,seq_col,fit_col):
  """Read ``[sequence, fitness]`` pairs from the two named CSV columns.

  Sequences are cast to ``str`` and stripped; fitness values are kept as read.
  NOTE(review): this repeats an identical definition from an earlier cell.
  """
  frame = pd.read_csv(pathlib.Path(fitness_csv), usecols=[seq_col, fit_col])
  pairs = zip(frame.loc[:, seq_col], frame.loc[:, fit_col])
  return [[str(seq).strip(), fitness] for seq, fitness in pairs]

def get_top_model(df, alpha, train_batch_size):
  """Fit a ``RidgeCV`` top model on a seeded random subset of ``df``.

  The global numpy seed is set to 42, the rows are permuted and the first
  ``train_batch_size`` entries of the permutation are selected by label with
  ``df.loc``. Features are all columns from the third one onwards; the target
  is the ``fitness`` column. Cross-validation uses a shuffled 10-fold split
  with the same seed.

  Args:
    df: representation dataframe.
    alpha: iterable of ridge penalties passed as ``alphas``.
    train_batch_size: number of training rows.

  Returns:
    The fitted ``RidgeCV`` estimator.
  """
  rand_state_num = 42
  np.random.seed(rand_state_num)

  train_rows = np.random.permutation(df.shape[0])[:train_batch_size]
  feature_columns = df.columns[2:]

  X_train = df.loc[train_rows, feature_columns]
  Y_train = df.loc[train_rows, "fitness"]

  cv_splitter = KFold(n_splits=10, random_state=rand_state_num, shuffle=True)
  ridge = RidgeCV(alphas=alpha, cv=cv_splitter)
  return ridge.fit(X_train, Y_train)


# read represented dataframe of features as feather
def read_reps(path):
  """Return the dataframe stored at ``<path>.feather``.

  NOTE(review): this repeats an identical definition from an earlier cell.
  """
  feather_file = path + '.feather'
  return feather.read_dataframe(feather_file)



In [44]:
#@title define input parameters
# Same load as in the "load labeled training data" cell; relies on col_seq /
# col_fitness defined there.
seqs_df = pd.DataFrame(read_labeled_data(f"{RES_DIR_FITNESS}/P450_experi_data_encoded.csv",col_seq,col_fitness), columns = ['sequence', 'fitness'])

# Representation table used to fit the top model (same feather file as eunirep_df).
df = read_reps(FEATHER_PATH + 'evotuning_P450_from-_global_iter_0')

# Path handed to load_params via init_by_rep.
param_file=f'{RES_DIR_FEATURE}/evotuning_P450_from_global/iter_0'
print(param_file)


./results/feature/evotuning_P450_from_global/iter_0


In [45]:
#@title define hyperparameters
# Ridge penalty passed (as a one-element list) to RidgeCV by init_by_rep.
alpha = 0.0015

# Total number of rows drawn; split 80/20 into train and hold-out below.
BATCH_SIZE = 42

# Seed for the numpy permutation in the next cell.
rand_state_num = 42

In [52]:
#@title initialize the directed evolution

# Seeded row permutation used for the train / hold-out split below.
np.random.seed(rand_state_num)
rndperm = np.random.permutation(df.shape[0])


# int() truncates: with BATCH_SIZE = 42 this gives 33 training and 8 hold-out rows.
TRAIN_BATCH_SIZE = int(BATCH_SIZE*0.8)
HOLDOUT_BATCH_SIZE = int(BATCH_SIZE*0.2)

training_df = df.iloc[rndperm[:TRAIN_BATCH_SIZE],:]

testing_df = df.iloc[rndperm[TRAIN_BATCH_SIZE:TRAIN_BATCH_SIZE+HOLDOUT_BATCH_SIZE],:]

# NOTE(review): training_df / testing_df are not passed on; init_by_rep receives
# the full `df` and get_top_model draws its own seeded permutation from it.
params, Model = init_by_rep(df, alpha, TRAIN_BATCH_SIZE, param_file)

In [53]:
# Sequence of the row labelled 'WT'; the validation mutations are applied to it.
wt_890_seq=fitness_table_df[(fitness_table_df.loc[:,'variants']=='WT')].Seq.tolist()[0]

validate_list=f'{RES_DIR_FITNESS}/validation_set_890_11-H-Cuol.txt'
# One mutation string per line; the `with` block closes the file after reading.
with open(validate_list,'r') as validate_file:
    validate_sequence_name=validate_file.read().split('\n')
#print(f'>WT\n{wt_890_seq}')
predicted_fitness={}
for variant in validate_sequence_name:
    variant=variant.strip()
    # A trailing newline in the list file yields an empty entry; skip it
    # instead of failing while parsing an empty mutation string.
    if not variant:
        continue
    #print(f'>{variant}')
    variant_sequence=inspect_variant_sequence(wt_890_seq,variant,sep='-')
    #print(variant_sequence)
    variant_fitness=predict_fitness(variant_sequence,params,Model)
    predicted_fitness[variant]={'seq':variant_sequence,
                               'fitness':variant_fitness}

In [60]:
# One predicted fitness value per line, in the order the variants were read.
for v, record in predicted_fitness.items():
    print(f'{record["fitness"].tolist()[0]}')

-0.015680789947509766
0.12171673774719238
0.021290302276611328
0.04916954040527344
0.048044681549072266
0.04303622245788574
0.027156591415405273
0.05421566963195801
0.014623165130615234
0.11374425888061523
0.044205427169799805
0.07687664031982422
0.10109996795654297
0.04227924346923828
0.04429054260253906
0.056641340255737305
0.08367705345153809
0.026111841201782227
-0.017404556274414062
0.07050752639770508
0.05086469650268555
0.07879352569580078
0.09364438056945801
0.02442145347595215
0.05504965782165527
0.062159061431884766
0.08647966384887695
0.05785179138183594
0.04699397087097168
0.07329249382019043
0.09370565414428711
0.06436419486999512
0.08708047866821289
0.09451150894165039
0.027662038803100586
0.13572478294372559
0.09847354888916016
0.0423130989074707
0.11659717559814453
0.02649402618408203
0.007451057434082031
0.07086515426635742
0.061036109924316406
0.03975796699523926
