In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
import pickle
from tqdm import tqdm
from dataclasses import dataclass
from collections import deque
from sklearn.neural_network import MLPRegressor
import random

In [2]:
# Named hex colours
scheme = {
    'blue': '#2f788e',
    'red': '#d15b4f',
    'green': '#45b563',
    'grey': '#8a8888',
}

# Global matplotlib defaults applied to every figure in this notebook
plt.rcParams.update({
    'figure.facecolor': 'white',
    'font.size': 12,
    'font.family': 'Arial',
})

In [3]:
def seqDist(s1, s2):
    """Count the positions at which two sequences differ.

    The sequences are compared pairwise with ``zip``, so when their lengths
    differ only the overlapping prefix is compared.
    """
    mismatches = 0
    for base1, base2 in zip(s1, s2):
        if base1 != base2:
            mismatches += 1
    return mismatches

In [4]:
def mutateSeq(S, mut):
    """Return a copy of ``S`` with the base at one position replaced.

    Parameters
    ----------
    S : str
        Sequence to mutate.
    mut : tuple
        ``(pos, mut_base)``: zero-based position and the replacement base.

    Returns
    -------
    str
        The mutated sequence.

    Raises
    ------
    IndexError
        If ``pos`` is negative or not smaller than ``len(S)``.
    """
    pos, mut_base = mut

    # Negative positions have to be rejected as well: with pos == -1 the
    # slicing below would return S[:-1] + mut_base + S, i.e. a longer,
    # corrupted sequence instead of an error.
    if not 0 <= pos < len(S):
        raise IndexError("Position out of range")

    # Slicing already covers the first and last position (the empty slice
    # contributes nothing), so no special cases are needed.
    return S[:pos] + mut_base + S[pos+1:]

In [5]:
def makeMutations(S, mut):
    """Apply every single-base substitution described by ``mut`` to ``S``.

    ``mut[0]`` holds the positions to change and ``mut[1][1]`` holds the
    replacement base for each of those positions, in the same order.
    The substitutions are applied one after another and the final
    sequence is returned.
    """
    mutated = S
    for idx, pos in enumerate(mut[0]):
        mutated = mutateSeq(mutated, (pos, mut[1][1][idx]))
    return mutated

In [6]:
def hot1Encode(S):
    """One-hot encode a nucleotide sequence into a flat list of 0/1 values.

    Every base contributes four entries in the order A, T, G, C, so the
    result has ``4 * len(S)`` elements. A base other than 'A', 'T', 'G'
    or 'C' raises ``KeyError``.
    """
    one_hot = {
        'A': [1, 0, 0, 0],
        'T': [0, 1, 0, 0],
        'G': [0, 0, 1, 0],
        'C': [0, 0, 0, 1],
    }
    return [bit for base in S for bit in one_hot[base]]

In [7]:
@dataclass
class BfsNode:
    """A sequence placed in the beam-search tree."""
    # Nucleotide sequence of this node.
    seq: str
    # Score of the sequence: the cpm handed in for the start node, the
    # predicted cpm for every mutant. Predictions are floats, hence `float`.
    cpm: float
    # Number of expansion steps between the start node (depth 0) and this node.
    depth: int
    # Total number of mutated positions accumulated since the start node.
    muts_from_start: int

@dataclass
class Mutant:
    """A candidate sequence produced by applying one mutation entry to a node."""
    # Sequence after the mutation has been applied.
    seq: str
    # The mutation entry that produced `seq`. The search code indexes it as
    # mut[0] (positions to change), mut[1][0] (bases expected at those
    # positions) and mut[1][1] (replacement bases).
    mut: list
    # Predicted score; beamSearch creates mutants with None here and fills
    # the value in once the batch has been scored.
    pred_cpm: float

In [8]:
def mutantToNode(M, N):
    """Turn the scored mutant ``M`` into a search node that is a child of ``N``.

    The child sits one level below ``N`` and carries ``N``'s mutation count
    plus the number of positions changed by ``M.mut``.
    """
    return BfsNode(
        seq=M.seq,
        cpm=M.pred_cpm,
        depth=N.depth + 1,
        muts_from_start=N.muts_from_start + len(M.mut[0]),
    )

In [9]:
def beamSearch(start_seq_seq, start_seq_cpm, mutations, model, beam_width, mode, max_depth, exploration_cpm_threshold, top_explored, blacklist=None):
    """Breadth-first beam search through the mutants of a start sequence.

    Starting from ``start_seq_seq``, every applicable mutation is applied to
    the current node, the resulting sequences are scored, the best
    ``beam_width`` of them are queued for further expansion and the
    ``top_explored`` best-scoring sequences seen overall are collected.

    Parameters
    ----------
    start_seq_seq : str
        Sequence the search starts from.
    start_seq_cpm : float
        Score stored on the start node.
    mutations : iterable
        Mutation entries. An entry is applied to a node only when the bases
        found at positions ``mut[0]`` equal ``mut[1][0]``; the replacement
        bases are ``mut[1][1]``.
    model : object
        Must provide ``predict(X)``; only used when ``mode == 'directed'``.
    beam_width : int
        Number of best-scoring mutants of each expansion that are queued.
    mode : {'directed', 'random'}
        ``'directed'`` scores mutants with ``model.predict``; ``'random'``
        assigns scores from ``np.random.rand`` (unseeded here).
    max_depth : int
        Mutants are queued while their depth is ``<= max_depth``. Because a
        queued node is expanded, returned sequences can sit at depth
        ``max_depth + 1``.
    exploration_cpm_threshold : float
        Not used by the search; kept so existing calls keep working.
    top_explored : int
        Maximum number of sequences returned. Values ``<= 0`` return ``[]``.
    blacklist : set, optional
        Sequences that must not be scored. It is updated in place with every
        scored mutant. When omitted, the module-level ``blacklist`` is used,
        which is how the function originally obtained it.

    Returns
    -------
    list of BfsNode
        At most ``top_explored`` nodes, sorted by ``cpm`` in descending order.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'directed'`` nor ``'random'``.
    NameError
        If no ``blacklist`` is passed and no module-level one exists.
    """
    # Fail early: an unknown mode used to leave `predictions` undefined (or
    # stale from the previous iteration).
    if mode not in ('directed', 'random'):
        raise ValueError("mode must be 'directed' or 'random', got %r" % (mode,))

    # Backward compatibility: callers that do not pass the set rely on the
    # module-level `blacklist` and read it back after the search.
    if blacklist is None:
        try:
            blacklist = globals()['blacklist']
        except KeyError:
            raise NameError("name 'blacklist' is not defined") from None

    def by_cpm(node):
        return node.cpm

    queue = deque([BfsNode(seq=start_seq_seq, cpm=start_seq_cpm, depth=0, muts_from_start=0)])
    explored_sequences = []  # kept sorted by cpm, best first

    while queue:

        # Get next node
        next_node = queue.popleft()

        # Apply every mutation whose expected bases match the node and whose
        # result is not blacklisted
        mutants_batch = []
        for mut in mutations:
            current_bases = ''.join(next_node.seq[x] for x in mut[0])
            if current_bases != mut[1][0]:
                continue
            mutant_seq = makeMutations(next_node.seq, mut)
            if mutant_seq in blacklist:
                continue
            mutants_batch.append(Mutant(seq=mutant_seq, mut=mut, pred_cpm=None))

        if not mutants_batch:
            continue

        # Encode mutants; a list of equal-length encodings is already a 2-D
        # array, also for a single mutant, so no reshape is needed
        mutants_encoded = np.array([hot1Encode(m.seq) for m in mutants_batch])

        # Score the mutants
        if mode == 'directed':
            predictions = model.predict(mutants_encoded)
        else:
            predictions = np.random.rand(len(mutants_encoded))

        for mutant, predicted in zip(mutants_batch, predictions):
            mutant.pred_cpm = predicted

        # Convert mutants to search nodes, best first
        children = sorted((mutantToNode(m, next_node) for m in mutants_batch), key=by_cpm, reverse=True)

        for rank, child in enumerate(children):
            # Keep the overall top `top_explored` nodes. The list is re-sorted
            # after every insertion so that its last element really is the
            # current minimum; list.sort is stable, so on ties the earlier
            # node is kept, matching the strict `>` comparison.
            if top_explored > 0 and (len(explored_sequences) < top_explored
                                     or child.cpm > explored_sequences[-1].cpm):
                explored_sequences.append(child)
                explored_sequences.sort(key=by_cpm, reverse=True)
                del explored_sequences[top_explored:]

            # Only the best `beam_width` children of this expansion are
            # expanded further
            if (rank < beam_width) and (child.depth <= max_depth):
                queue.append(child)

            # Every scored sequence is blacklisted so it is never scored twice
            blacklist.add(child.seq)

    return explored_sequences

---
## Run

In [21]:
# One row of metrics per training-set sampling; filled by the sampling loop below.
df_res_metrics = pd.DataFrame(columns=['sampling', 'precision', 'recall', 'search_space'])
# Per-sampling DataFrames of the sequences returned by the beam search.
df_res_seqs = []

In [22]:
# Load the entire dataset (only the sequence and cpm columns are read)
# NOTE(review): absolute, machine-specific path -- this cell only runs on the
# original author's machine; consider a configurable data directory.
df = pd.read_csv(
    '/home/jardic/Documents/projects/jk/datasets/datasets_prepped/strc_km.csv',
    usecols=['varseq', 'cpm'],
)

# Lookup helper: sequence -> cpm
seq_2_cpm = dict(zip(df['varseq'], df['cpm']))

In [23]:
# Settings for the beam search run in the sampling loop below.
beam_search_params = {
    # Number of best training sequences used as starting points.
    'topN_start' : 1,
    # Number of best-scoring mutants of each expansion that are queued.
    'beam_width' : 3,
    # Mutants are queued for expansion while their depth is <= this value.
    'max_depth' : 5,
    # Maximum number of sequences beamSearch returns.
    'top_explored' : 100,
    # 'directed' scores mutants with the model; 'random' uses random scores.
    'mode' : 'directed'
}

In [24]:
# Load the splits (indexed below with the keys 'tst', 'val' and 'trn';
# 'trn' holds one entry per sampling level)
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('../testing_splits/splits.pkl', mode='rb') as sf:
    splits = pickle.load(sf)

In [25]:
# Load the possible mutations (the entries beamSearch iterates over)
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('../prep/allowed_mutations_permuations.pkl', mode='rb') as mf:
    mutations = pickle.load(mf)

In [26]:
# For each sampling level of the training data: load the matching model, run
# the beam search from the best training sequences and score the returned
# sequences against the test set.
for sampling_i, sampling_value in enumerate([1, 0.1, 0.01, 0.001]):

    # Get the sampled df
    # [1, 0.1, 0.01, 0.001] These are the sampling of the training data
    df_tst, df_val, df_trn = df.loc[splits['tst']], df.loc[splits['val']], df.loc[splits['trn'][sampling_i]]

    # Load the trained model
    # NOTE: pickle.load can execute arbitrary code -- only load files you trust.
    with open('../testing_models/model_trained_final_trn' + str(sampling_i) + '.pkl', mode='rb') as mf:
        model = pickle.load(mf)

    # The best training sequences are the starting points, so sort the
    # training dataframe by cpm, best first
    df_trn = df_trn.sort_values('cpm', ascending=False)

    # Run beamsearch
    # `blacklist` is a module-level set that beamSearch reads and updates in
    # place; it starts with the training sequences so they are never proposed.
    blacklist = set(df_trn['varseq'].tolist())
    res_all = []
    # Never ask for more starting points than there are training sequences
    n_starts = min(beam_search_params['topN_start'], len(df_trn))
    for i in range(0, n_starts):
        res_start = beamSearch(df_trn.iloc[i]['varseq'], df_trn.iloc[i]['cpm'],
                               mutations=mutations,
                               model=model,
                               mode=beam_search_params['mode'],
                               beam_width=beam_search_params['beam_width'],
                               max_depth=beam_search_params['max_depth'],
                               top_explored=beam_search_params['top_explored'],
                               exploration_cpm_threshold=df_trn.iloc[0]['cpm'])

        res_all.extend(res_start)

    df_res_all = pd.DataFrame([[r.seq, r.cpm, r.depth, r.muts_from_start] for r in res_all], columns=['varseq', 'predicted_cpm', 'depth', 'muts'])

    # Keep the overall best sequences: sort first (results of several starting
    # points are simply concatenated) and cut at the configured size instead
    # of a hard-coded 100. mergesort keeps the order of ties stable.
    df_res_all = (df_res_all
                  .sort_values('predicted_cpm', ascending=False, kind='mergesort')
                  .head(beam_search_params['top_explored'])
                  .reset_index(drop=True))

    seqs_exp = set(df_res_all['varseq'].tolist())
    seqs_tst = set(df_tst['varseq'].tolist())
    n_hits = len(seqs_exp.intersection(seqs_tst))

    # Guard against empty sets so an unproductive search does not crash the run
    precision = n_hits / len(seqs_exp) if seqs_exp else 0.0
    recall = n_hits / len(seqs_tst) if seqs_tst else 0.0

    # Sequences added to the blacklist during the search
    search_space_size = len(blacklist) - len(df_trn)

    df_res_metrics.loc[len(df_res_metrics)] = [sampling_value, precision, recall, search_space_size]
    # Store the sequences exactly once per sampling level (they used to be
    # appended twice, duplicating every entry of df_res_seqs)
    df_res_seqs.append(df_res_all)

In [27]:
# Show the metrics table: one row per training-set sampling level.
df_res_metrics

Unnamed: 0,sampling,precision,recall,search_space
0,1.0,0.77,0.77,3381.0
1,0.1,0.69,0.69,8537.0
2,0.01,0.41,0.41,9063.0
3,0.001,0.02,0.02,8874.0


In [29]:
# Persist the metrics table and the per-sampling sequence tables together as
# a two-element list: [df_res_metrics, df_res_seqs].
with open('results_directed.pkl', mode='wb') as rf:
    pickle.dump([df_res_metrics, df_res_seqs], rf)