In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
import pickle
from tqdm import tqdm
from dataclasses import dataclass
from collections import deque
from sklearn.neural_network import MLPRegressor
import random

In [2]:
def seqDist(s1, s2):
    return sum([1 if b1 != b2 else 0 for b1, b2 in zip(s1, s2)])

In [3]:
def mutateSeq(S, mut):
    pos, mut_base = mut           
    
    if pos >= len(S):
        raise Exception("Position out of range") 
    
    if pos == 0:
        return mut_base + S[1:]
    elif pos == len(S)-1:
        return S[:len(S)-1] + mut_base
    else:
        return S[:pos] + mut_base + S[pos+1:]

In [4]:
def makeMutations(S, mut):
    Sx1 = S
    for i in range(0, len(mut[0])):
        Sx2 = mutateSeq(Sx1, (mut[0][i], mut[1][1][i]))
        Sx1 = Sx2
    return Sx1

In [5]:
def hot1Encode(S):
    encoded = []
    dct_bases = {'A': [1, 0, 0, 0], 'T': [0, 1, 0, 0], 'G': [0, 0, 1, 0], 'C': [0, 0, 0, 1]}
    for b in S:
        encoded.extend(dct_bases[b])
    return encoded

In [6]:
@dataclass
class BfsNode:
    seq: str
    cpm: int
    depth: int
    muts_from_start: int

@dataclass
class Mutant:
    seq: str
    mut: list
    pred_cpm: float

In [7]:
def mutantToNode(M, N):
    
    M_depth = N.depth+1
    M_muts_from_start = N.muts_from_start+len(M.mut[0])
    
    return BfsNode(seq=M.seq, cpm=M.pred_cpm, depth=M_depth, muts_from_start=M_muts_from_start)

In [8]:
def beamSearch(start_seq_seq, start_seq_cpm, mutations, model, beam_width, mode, max_depth, exploration_cpm_threshold, top_explored):

    queue = deque([BfsNode(seq=start_seq_seq, cpm=start_seq_cpm, depth=0, muts_from_start=0)])
    explored_sequences = []
    
    while len(queue):
        
        # Get next node
        next_node = queue.popleft()
        
        # Make every possible mutation -- get mutant batch
        mutants_batch = [Mutant(seq=makeMutations(next_node.seq, mut), mut=mut, pred_cpm=None) for mut in mutations if ''.join([next_node.seq[x] for x in mut[0]]) == mut[1][0]]

        # Run mutants through the blacklist filter
        mutants_batch = [m for m in mutants_batch if m.seq not in blacklist]
        if len(mutants_batch) == 0:
            continue

        # Encode mutants
        mutants_encoded = np.array([np.array(hot1Encode(m.seq)) for m in mutants_batch])

        # Make predictions about mutants
        if len(mutants_encoded) == 1:
            mutants_encoded = mutants_encoded.reshape(1, -1)

        if mode == 'directed':
            predictions = model.predict(mutants_encoded)
        elif mode == 'random':
            predictions = np.random.rand(len(mutants_encoded))

        # Assign the predicted cpms
        for i in range(0, len(predictions)):
            mutants_batch[i].pred_cpm = predictions[i]

        # Convert mutants from mutant class to bfs node class and sort by cpm
        mutants_batch = [mutantToNode(m, next_node) for m in mutants_batch]
        mutants_batch = sorted(mutants_batch, key=lambda p: p.cpm, reverse=True)

        for i, m in enumerate(mutants_batch):
            if len(explored_sequences) < top_explored:
                    explored_sequences.append(m)
                
            else:
                if m.cpm > explored_sequences[-1].cpm:
                    del explored_sequences[-1]
                    explored_sequences.append(m)
                    explored_sequences = sorted(explored_sequences, key=lambda p: p.cpm, reverse=True)                        

            if (i < beam_width) and (m.depth <= max_depth):
                queue.append(m)

            blacklist.add(m.seq)
        
    return explored_sequences

---
## Run

In [9]:
# Load the entire dataset
df = pd.read_csv('/home/jardic/Documents/projects/jk/datasets/datasets_prepped/strc_km.csv', usecols=['varseq', 'cpm'])

# Load the splits
with open('../splits/splits.pkl', mode='rb') as sf:
    splits = pickle.load(sf)

# Load the possible mutations
with open('../prep/allowed_mutations_permuations.pkl', mode='rb') as mf:
    mutations = pickle.load(mf)

# Load the hyperparameter combinations (hyperparameter grid)
with open('../prep/mlp_hyperparameters.pkl', mode='rb')as hpc_f:
    hpcs_all = pickle.load(hpc_f)

# Get the sampled dfs
# [1, 0.1, 0.01, 0.001] These are the sampling of the training data
df_tst, df_val, df_trn = df.loc[splits['tst']], df.loc[splits['val']], df.loc[splits['trn'][2]]

# I'll be using the best training sequences as starting points
df_trn = df_trn.sort_values('cpm', ascending=False)

# Make helpers
seq_2_cpm = {s : c for s, c in zip(df['varseq'], df['cpm'])}

---

In [10]:
beam_search_params = {
    'topN_start' : 5,
    'beam_width' : 5,
    'max_depth' : 5,
    'top_explored' : 100,
    'mode' : 'directed'
}

In [13]:
df_results = pd.DataFrame(columns=['hpc_index', 'hidden_layer_sizes', 'learning_rate_init', 'batch_size', 'precision', 'recall', 'search_space'])

for hpc_index, hpc in hpcs_all:

    # Load one of the trained models
    model_file = '../hpc_tuning_models/model_trained_0p01_hpc_' + str(hpc_index) + '.pkl'
    with open(model_file, mode='rb') as mf:
        model = pickle.load(mf)

    # Run beam search
    blacklist = set(df_trn['varseq'].tolist())
    res_all = []
    for i in range(0, beam_search_params['topN_start']):
        res_start = beamSearch(df_trn.iloc[i]['varseq'], df_trn.iloc[i]['cpm'],
                               mutations=mutations, 
                               model=model,
                               mode=beam_search_params['mode'],
                               beam_width=beam_search_params['beam_width'],
                               max_depth=beam_search_params['max_depth'],
                               top_explored=beam_search_params['top_explored'],
                               exploration_cpm_threshold=df_trn.iloc[0]['cpm'])
        
        res_all.extend(res_start)
    
    df_res_all = pd.DataFrame([[r.seq, r.cpm, r.depth, r.muts_from_start] for r in res_all], columns=['varseq', 'predicted_cpm', 'depth', 'muts'])
    df_res_all = df_res_all.head(100)
    
    seqs_exp = set(df_res_all['varseq'].tolist())
    seqs_val = set(df_val['varseq'].tolist())
    seqs_tst = set(df_tst['varseq'].tolist())
    
    seqs_exp = seqs_exp - seqs_tst
    
    precision = len(seqs_exp.intersection(seqs_val)) / len(seqs_exp)
    #print('precision:', precision)
    
    recall = len(seqs_exp.intersection(seqs_val)) / len(seqs_val)
    #print('recall:', recall)
    
    search_space_size = len(blacklist) - len(df_trn)
    #print('search space size:', search_space_size)
    
    df_results.loc[len(df_results)] = [hpc_index,
                                       hpc['hidden_layer_sizes'],
                                       hpc['learning_rate_init'],
                                       hpc['batch_size'],
                                       precision,
                                       recall,
                                       search_space_size
                                      ]

In [16]:
df_results.sort_values('precision', ascending=False)

Unnamed: 0,hpc_index,hidden_layer_sizes,learning_rate_init,batch_size,precision,recall,search_space
17,17,"(100, 100, 100, 100)",0.0001,100,0.37931,0.22,361858
35,35,"(100, 100, 100, 100)",0.0001,200,0.372881,0.22,361423
11,11,"(100, 100, 100)",0.0001,100,0.368421,0.21,334170
8,8,"(100, 100)",0.0001,100,0.362069,0.21,337987
26,26,"(100, 100)",0.0001,200,0.355932,0.21,359068
14,14,"(200, 200, 200)",0.0001,100,0.345455,0.19,363146
32,32,"(200, 200, 200)",0.0001,200,0.339286,0.19,367077
29,29,"(100, 100, 100)",0.0001,200,0.338983,0.2,367427
34,34,"(100, 100, 100, 100)",0.001,200,0.322581,0.2,357475
2,2,"(50, 50)",0.0001,100,0.322034,0.19,358721


In [17]:
df_results.to_csv('MLP_hyperparams_optimization.csv', index=False)