In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import torch
from nn import MotifCaller, NaiveCaller
from training_data import data_preproc, load_training_data
from utils import get_savepaths
import numpy as np
from sklearn.preprocessing import normalize
from greedy_decoder import GreedyCTCDecoder
from Levenshtein import ratio
from utils import load_model, get_metrics_for_evaluation
from transcript_sorting import sort_transcript_reduced_spacers, sort_transcript
from sklearn.model_selection import train_test_split
from beam_search_decoder import beam_search_ctc, torch_ctc
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np
#import matplotlib.pyplot as plt
from typing import List
from sklearn.preprocessing import normalize
from torch.nn.utils.rnn import pad_sequence
from beam_search_decoder import beam_search_ctc, torch_ctc
import matplotlib.pyplot as plt

### Loading models and data

In [None]:
# Output classes of the CTC models: class 0 is the CTC blank (see CTCLoss below).
# Elsewhere ids <= 8 are treated as payload motifs and ids > 8 as spacers.
n_classes = 19
model_path_forward = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\models\empirical\final_models\latest_models\edit_forward.pth"
model_path_mixed = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\models\empirical\final_models\latest_models\edit_mixed.pth"
model_path_reverse = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\models\empirical\final_models\latest_models\edit_reverse.pth"

labels_int = np.arange(n_classes).tolist()
labels = [f"{i}" for i in labels_int] # Tokens to be fed into greedy decoder
greedy_decoder = GreedyCTCDecoder(labels = labels)
# Not used by the visible evaluation cells; kept for loss computations.
ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

In [None]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# New tensors (e.g. those built in the batching loops) are created on `device` by default.
torch.set_default_device(device)
# Three models: forward- and reverse-strand callers, plus a "mixed" model whose
# greedy output is only used to detect read orientation.
model_forward = load_model(model_path=model_path_forward, device=device, n_classes=n_classes, hidden_size=256)
model_reverse = load_model(model_path=model_path_reverse, device=device, n_classes=n_classes, hidden_size=256)
model_mixed = load_model(model_path=model_path_mixed, device=device, n_classes=n_classes, hidden_size=256)

test_dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\full_datasets\diluted_squiggles.pkl"

#test_dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\sequencing_runs\01-04run\misc_datasets\small_squiggle_df.pkl"
#test_dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\sequencing_runs\01-04run\finetuning_datasets\edit_train.pkl"

# NOTE: unpickling can execute arbitrary code; only load trusted files.
dataset = pd.read_pickle(test_dataset_path)

In [None]:
# For the new datasets
# NOTE(review): X and y are not read by the visible cells below (X is rebound
# later in the convergence sections); this re-reads the same pickle.
X, y = load_training_data(
        test_dataset_path, column_x='squiggle', column_y='payload', payload=False, orientation=False)

In [None]:
# Disabled train/test split, kept as an inert string literal. `payloads` is not
# defined anywhere in this notebook, so it would fail if re-enabled as-is.
"""
X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42)

X_train, X_test, payloads_train, payloads_test = train_test_split(
        X, payloads, test_size=0.2, random_state=42)
"""

In [None]:
def get_error_by_cycle_and_motif(preds, actual, n_cycles=8, n_motifs=8):
    """Tally mis-called motifs per cycle and per motif across many reads.

    A motif counts as an error when it appears in a predicted cycle but not in
    the matching cycle of the ground truth. Only false positives are counted;
    missed motifs are not.

    Args:
        preds: per-read predictions. Each is a sequence of cycles, and each
            cycle is a collection of 1-based motif ids.
        actual: ground-truth cycles, paired read-by-read with ``preds``.
            Cycles are paired with ``zip``, so any extra cycles are ignored.
        n_cycles: number of cycle slots to tally (default 8).
        n_motifs: number of distinct motifs, ids ``1..n_motifs`` (default 8).

    Returns:
        ``(cycle_errors, motif_errors)``: float arrays of length ``n_cycles``
        and ``n_motifs`` holding the error counts divided by the number of
        reads, times 100. For an empty ``preds`` both arrays are NaN, since an
        average over zero reads is undefined.
    """
    cycles = np.zeros(n_cycles)
    motifs = np.zeros(n_motifs)

    n_reads = len(preds)
    if n_reads == 0:
        # Same NaN values the division would produce, without numpy's
        # divide-by-zero RuntimeWarning.
        return np.full(n_cycles, np.nan), np.full(n_motifs, np.nan)

    for pred, act in zip(preds, actual):
        for cycle_idx, (pred_cycle, true_cycle) in enumerate(zip(pred, act)):
            for motif in pred_cycle:
                if motif not in true_cycle:
                    cycles[cycle_idx] += 1
                    motifs[motif - 1] += 1  # motif ids are 1-based

    return cycles / n_reads * 100, motifs / n_reads * 100

In [None]:
def evaluate_prediction(prediction, original):
    """Count predicted motifs that are / are not present in the true cycle.

    ``prediction`` and ``original`` are paired cycle-by-cycle with ``zip``, so
    surplus cycles on either side are ignored.

    Returns:
        ``(found, err)``: the number of predicted motifs contained in the
        matching true cycle, and the number that are not.
    """
    hits = misses = 0
    for predicted_cycle, true_cycle in zip(prediction, original):
        for motif in predicted_cycle:
            if motif in true_cycle:
                hits += 1
            else:
                misses += 1
    return hits, misses

In [None]:
def detect_reverse_oriented_read(transcript_reverse_model):
    """Guess whether a read is reverse-oriented from a decoded transcript.

    Only spacer tokens (class ids greater than 8) are considered. The read is
    flagged as reverse when the first spacer id that differs from the leading
    spacer is *smaller* than it, i.e. the spacer ids decrease along the read.

    Args:
        transcript_reverse_model: iterable of decoded tokens, each convertible
            with ``int()``. In this notebook it is the greedy decoding of the
            mixed-orientation model's output.

    Returns:
        True if the spacer order looks reversed. False otherwise, including
        when no spacer or only one distinct spacer id was decoded: the
        orientation cannot be determined, so the read is treated as forward.
    """
    spacers = [int(token) for token in transcript_reverse_model if int(token) > 8]

    if not spacers:
        # Short or noisy reads may decode without any spacer; indexing
        # spacers[0] used to raise IndexError here.
        return False

    first_spacer = spacers[0]
    for spacer in spacers[1:]:
        if spacer != first_spacer:
            return spacer < first_spacer

    return False

    

In [None]:
def _decode_transcript(model_outputs, beam, beam_width, prob_threshold):
    """Decode one read's model output into a space-separated token string.

    Returns ``(transcript, quality)``. ``quality`` is None for beam search,
    because ``torch_ctc`` does not report one.
    """
    if beam:
        transcript = torch_ctc(
            n_classes=19, model_output=model_outputs.unsqueeze(0),
            beam_width=beam_width)
        return transcript, None
    result, quality = greedy_decoder.forward_with_quality(
        model_outputs, prob_threshold=prob_threshold)
    return " ".join(result), quality


def sort_and_evaluate_predictions(
        model_outputs_forward, labels, payload, model_outputs_reverse=None, model_outputs_mixed=None, orientation=None, beam=False, beam_width=10, prob_threshold=0.0):
    """Decode one read, sort the transcript into cycles and score it against ``payload``.

    Args:
        model_outputs_forward: forward-strand model output for a single read.
        labels: ground-truth tokens. Currently unused, because the
            search-based evaluation is disabled.
        payload: true cycles the sorted prediction is scored against
            (see ``evaluate_prediction``).
        model_outputs_reverse: reverse-strand model output. When given, the
            read orientation is detected and the matching decoding is scored.
        model_outputs_mixed: mixed-orientation model output used for
            orientation detection. Required when ``model_outputs_reverse`` is
            given.
        orientation: ignored. Orientation is always inferred from the mixed
            model.
        beam, beam_width: use ``torch_ctc`` beam search instead of greedy
            decoding.
        prob_threshold: forwarded to the greedy decoder.

    Returns:
        With ``model_outputs_reverse``: a 7-tuple
        ``(found, errs, found_search, errs_search, quality, sorted_transcript, orientation)``
        for the forward (orientation 1) or reverse (orientation 0) decoding,
        as chosen by the mixed model.
        Without it: a 4-tuple ``(found, errs, found_search, errs_search)`` for
        the forward decoding.
        ``found_search``/``errs_search`` are always empty lists, and
        ``quality`` is None when ``beam=True``.

    Raises:
        ValueError: ``model_outputs_reverse`` is given without
            ``model_outputs_mixed``.
    """
    original = payload

    transcript_forward, quality_forward = _decode_transcript(
        model_outputs_forward, beam, beam_width, prob_threshold)
    sorted_forward = sort_transcript(transcript_forward)
    motifs_found_caller_forward, motif_errs_caller_forward = evaluate_prediction(
        sorted_forward, original)

    # The search-based evaluation of `labels` is disabled. These placeholders
    # keep the tuple layout that callers unpack.
    motifs_found_search, motif_errs_search = [], []

    if model_outputs_reverse is None:
        # Forward-only evaluation: reverse/mixed outputs are only decoded when
        # they are actually provided.
        return motifs_found_caller_forward, motif_errs_caller_forward, motifs_found_search, motif_errs_search

    if model_outputs_mixed is None:
        raise ValueError(
            "model_outputs_mixed is required when model_outputs_reverse is given; "
            "it is used to detect the read orientation")

    transcript_reverse, quality_reverse = _decode_transcript(
        model_outputs_reverse, beam, beam_width, prob_threshold)
    sorted_reverse = sort_transcript(transcript_reverse)
    motifs_found_caller_reverse, motif_errs_caller_reverse = evaluate_prediction(
        sorted_reverse, original)

    # The mixed-orientation model decides which strand-specific decoding to report.
    result_mixed = greedy_decoder.forward(model_outputs_mixed)
    if detect_reverse_oriented_read(result_mixed):
        return motifs_found_caller_reverse, motif_errs_caller_reverse, motifs_found_search, motif_errs_search, quality_reverse, sorted_reverse, 0
    return motifs_found_caller_forward, motif_errs_caller_forward, motifs_found_search, motif_errs_search, quality_forward, sorted_forward, 1


### Batched evaluation loop

In [None]:
# Per-read inputs as positional lists; the batched loops below index them by
# position (ind + k).
X_test = dataset['squiggle'].to_list()
#y_test = dataset['motif_seq'].to_list()
#edit_seq = dataset['edit_search_seq'].to_list()  # edit_search_seq
#read_ids_ = dataset['read_id'].to_list()
# NOTE: rebound by the next cell, which re-parses each payload into a list.
payloads_test = dataset['payload'].to_list()  # payload_seq
#orientations = dataset['orientation_x'].to_list()  # strand
#orientations = [1 if i.startswith('+') else 0 for i in orientations]

In [None]:
import ast

# Payloads are stored as their string repr; parse each into a list of cycles.
# ast.literal_eval accepts only Python literals, unlike eval(), which would run
# any code embedded in the pickled data.
# NOTE: the result is a Series indexed like `dataset`. The loops below index it
# with positional offsets (ind + k), which assumes a default RangeIndex.
payloads_test = dataset['payload'].apply(lambda x: list(ast.literal_eval(x)))

In [None]:
# Redundant: the default device was already set in the model-loading cell.
torch.set_default_device(device)

In [None]:

# Evaluate every read with all three models; per-read metrics go to results_dict.
n_test_samples = len(X_test)
batch_size = 8
prob_threshold = 0.85


# Only the caller/orientation/quality lists are filled below. The search/edit
# lists stay empty because their appends are commented out.
results_dict = {
    "mf_caller": [],
    "me_caller": [],
    "mf_search": [],
    "me_search": [],
    "mf_edit": [],
    "me_edit": [],
    "orientation": [],
    "quality": []
}
orientations_as_you_go = []

with torch.no_grad():
    for ind in tqdm(range(0, n_test_samples, batch_size)):

        # NOTE: the final partial batch is skipped, so the last
        # n_test_samples % batch_size reads are never evaluated.
        if n_test_samples - ind < batch_size:
            continue
        
        # Scale each squiggle by its maximum absolute value (sklearn norm='max').
        input_seqs = [
            normalize([X_test[k]], norm='max').flatten() for k in range(ind, ind + batch_size)]
        
        # Zero-pad to the longest read in the batch -> (batch, length).
        input_seqs = pad_sequence([torch.tensor(
                    i, dtype=torch.float32) for i in input_seqs], batch_first=True)
        
        # Add a single channel axis -> (batch, 1, length).
        input_seqs = input_seqs.view(input_seqs.shape[0], 1, input_seqs.shape[1])
        
        
        model_output_forward = model_forward(input_seqs).detach().cpu()
        model_output_reverse = model_reverse(input_seqs).detach().cpu()
        model_output_mixed = model_mixed(input_seqs).detach().cpu()

        for k in range(batch_size):
            #orientation = orientations[ind + k]

            
            # `orientation=False` is passed but not used by
            # sort_and_evaluate_predictions; orientation comes from the mixed model.
            motifs_found_caller, motif_errs_caller, motifs_found_search, motif_errs_search, quality, transcript, orientation = sort_and_evaluate_predictions(
            model_outputs_forward=model_output_forward[k], labels=[],
            payload=payloads_test[ind + k], model_outputs_reverse=model_output_reverse[k],
            model_outputs_mixed=model_output_mixed[k], orientation=False, beam=False,
            beam_width=15, prob_threshold=prob_threshold)

            #motifs_found_edit, motif_errs_edit = evaluate_prediction(
            #    edit_seq[ind + k][2:], payloads_test[ind + k])
            
            results_dict['mf_caller'].append(motifs_found_caller)
            results_dict['me_caller'].append(motif_errs_caller)
            #results_dict['mf_search'].append(motifs_found_search)
            #results_dict['me_search'].append(motif_errs_search)
            #results_dict['mf_edit'].append(motifs_found_edit)
            #results_dict['me_edit'].append(motif_errs_edit)
            results_dict['orientation'].append(orientation)
            results_dict['quality'].append(quality)
        
        torch.cuda.empty_cache()
        

In [None]:
columns = ['mf_caller', 'me_caller', 'orientation', 'quality']
results_df = pd.DataFrame({i: results_dict[i] for i in columns})

In [None]:
filtered_df = results_df

In [None]:
quality_threshold = 11

filtered_df = results_df.loc[results_df['quality'] > quality_threshold]

mean_values = filtered_df.mean(numeric_only=True)

print("Mean metrics combined:")
print(mean_values.to_frame(name='Mean').T)
print()

mean_values = filtered_df.loc[filtered_df['orientation'] == 1].mean(numeric_only=True)
print("Mean metrics forward")
print(mean_values.to_frame(name='Mean').T)
print()

mean_values = filtered_df.loc[filtered_df['orientation'] == 0].mean(numeric_only=True)
print("Mean metrics reverse")
print(mean_values.to_frame(name='Mean').T)
print()

print(f"{len(filtered_df) / len(results_df) * 100}% of read pool")

In [None]:
ef = pd.read_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\sequencing_runs\01-04run\finetuning_datasets\edit_forward_01_04.pkl")

In [None]:
6.68/8

In [None]:
reverse_reads = results_df.loc[results_df['orientation'] == 0]

### Collecting inference outputs

In [None]:

# Decode every read with the orientation-appropriate model and keep the sorted
# transcript, qualities and orientation so they can be attached to `dataset`.
n_test_samples = len(X_test)
batch_size = 8
prob_threshold = 0
sorted_caller = []
qualities = []
full_qs = []
caller_orientations = []
greedy_transcripts = []
read_ids = []  # positional index of each decoded read in X_test

with torch.no_grad():
    for ind in tqdm(range(0, n_test_samples, batch_size)):

        # The final partial batch is skipped, so `dataset` must later be trimmed
        # by len(X_test) % batch_size rows to stay aligned with these lists.
        if n_test_samples - ind < batch_size:
            continue

        # Max-normalise each read, zero-pad the batch and add a channel axis
        # -> (batch, 1, length).
        input_seqs = [
            normalize([X_test[k]], norm='max').flatten() for k in range(ind, ind + batch_size)]
        input_seqs = pad_sequence([torch.tensor(
                    i, dtype=torch.float32) for i in input_seqs], batch_first=True)
        input_seqs = input_seqs.view(input_seqs.shape[0], 1, input_seqs.shape[1])

        # One forward pass per model (the mixed model used to be run twice).
        model_output_mixed = model_mixed(input_seqs).detach().cpu()
        model_output_forward = model_forward(input_seqs).detach().cpu()
        model_output_reverse = model_reverse(input_seqs).detach().cpu()

        for k in range(batch_size):
            # The mixed model's greedy decoding decides which strand model to trust.
            greedy_result_mixed = greedy_decoder(model_output_mixed[k])

            if detect_reverse_oriented_read(greedy_result_mixed):
                caller_orientations.append(0)
                greedy_result, quality, full_q = greedy_decoder.forward_with_quality(
                    model_output_reverse[k], prob_threshold=prob_threshold, full_qualities=True)
            else:
                caller_orientations.append(1)
                greedy_result, quality, full_q = greedy_decoder.forward_with_quality(
                    model_output_forward[k], prob_threshold=prob_threshold, full_qualities=True)
            greedy_transcript = " ".join(greedy_result)

            sorted_caller.append(sort_transcript(greedy_transcript))
            qualities.append(quality)
            full_qs.append(full_q)
            greedy_transcripts.append(greedy_transcript)
            read_ids.append(ind + k)

        torch.cuda.empty_cache()
        
        

In [None]:
input_seqs = [
            normalize([X_test[k]], norm='max').flatten() for k in range(ind, ind + batch_size)]

In [None]:
dataset.drop(dataset.tail(4).index,
        inplace = True)

In [None]:
len(sorted_caller)

In [None]:
sorted_caller

In [None]:
dataset['caller_orientations'] = caller_orientations
dataset['caller_seq'] = sorted_caller
dataset['qualities'] = qualities
dataset['full_q'] = full_qs
dataset['greedy_transcripts'] = greedy_transcripts

In [None]:
for ind, row in dataset.iterrows():

    print(evaluate_prediction(row['caller_seq'], row['payload_seq']))

In [None]:
dataset['full_q'] = dataset['full_q'].apply(lambda x: [float(i) for i in x])

In [None]:
dataset.drop(columns=['squiggle'], inplace=True)

In [None]:
dataset.to_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\full_datasets\master_test_labelled.pkl")

In [None]:
dataset

### Spacers classification

In [None]:
# Spacers
# For every read, decode with the forward model and compare the set of decoded
# spacer ids (> 8) against the ground-truth spacer sequence.

# NOTE(review): computed but never used; the loop below runs over all reads.
filtered_df = dataset.loc[dataset['orientation'].str.startswith('+')]

# Ground-truth token sequence per read (presumably spacer ids; verify column).
y_test = dataset['edit_spacer_seq']

# NOTE: X_test was built before `dataset` was trimmed, so y_test may be shorter.
# The skipped final partial batch keeps the indices below in range only while
# the trim removed exactly that remainder.
n_training_samples = len(X_test)
batch_size = 8
# `diff` and `n_test_samples` are only referenced by the disabled block at the end.
diff = 0
n_test_samples = 100
# Never appended to in this cell (they stay empty); the box-plot section reads them.
motifs_found_caller_arr = []
motif_errs_caller_arr = []
motifs_found_search_arr = []
motif_errs_search_arr = []
model = model_forward.to(device)

# Per read: count of distinct decoded spacers present in / absent from the truth.
matched_spacers_arr = []
unique_spacers_arr = []


with torch.no_grad():
    for ind in tqdm(range(0, n_training_samples, batch_size)):

        if n_training_samples - ind < batch_size:
            continue
        
        input_seqs = [
            normalize([X_test[k]], norm='max').flatten() for k in range(ind, ind + batch_size)]
        
        target_seqs = y_test[ind: ind + batch_size]

        input_seqs = pad_sequence([torch.tensor(
                    i, dtype=torch.float32) for i in input_seqs], batch_first=True)
        # Leftover from a CTC-loss computation: target_seqs, pad lengths,
        # input_lengths and label_lengths below are not used for the spacer metrics.
        target_seqs = pad_sequence([torch.tensor(
                    i, dtype=torch.float32) for i in target_seqs], batch_first=True)
        
        input_seqs = input_seqs.view(input_seqs.shape[0], 1, input_seqs.shape[1])
        #input_seqs = input_seqs.to(device)
        
        pad_length_input = input_seqs.shape[2]
        n_samples = input_seqs.shape[0]

        pad_length_target = target_seqs.shape[1]

        model_output = model(input_seqs)
        model_output = model_output.permute(1, 0, 2)  # Assuming log probs are computed in network
        
        
        n_timesteps = model_output.shape[0]
        input_lengths = torch.tensor([n_timesteps for i in range(n_samples)])
        label_lengths = torch.tensor([len(y_test[ind + i]) for i in range(n_samples)])
        
        
        
        # Undo the permute above (net no-op), then move to CPU for decoding.
        model_output = model_output.permute(1, 0, 2).detach().cpu()
        #print(model_output.shape)

        for k in range(batch_size):
            # NOTE(review): `original` is unused here.
            original = payloads_test[ind + k]
            greedy_result = greedy_decoder(model_output[k])
            spacers = [int(i) for i in greedy_result if int(i) > 8]
            actual_spacers =([i for i in y_test[ind + k] if i > 8])
            # Distinct decoded spacers that appear in the truth ...
            matched_spacers = set(spacers).intersection(actual_spacers)
            # ... and distinct decoded spacers that do not (false positives).
            unique_spacers = set(spacers)  - set(actual_spacers)
            matched_spacers_arr.append(len(matched_spacers))
            unique_spacers_arr.append(len(unique_spacers))
            
            
            
            
        torch.cuda.empty_cache()
        """
        if ind >= n_test_samples:
            print(diff / (n_test_samples))
            break
        """

In [None]:
# Mean number of distinct decoded spacers per read that appear in the true sequence.
np.mean(matched_spacers_arr)

In [None]:
# Mean number of distinct decoded spacers per read that do NOT appear in the truth.
np.mean(unique_spacers_arr)

In [None]:
# Scratch arithmetic on a hard-coded value.
6.82/10

In [None]:
# Precision-style ratio built from the two means (a ratio of averages, not an
# average of per-read ratios).
np.mean(matched_spacers_arr) / (np.mean(unique_spacers_arr) + np.mean(matched_spacers_arr))

### Box Plot

In [None]:
import numpy as np

In [None]:
def get_box_stats(data):
    """Compute Tukey box-plot statistics for a 1-D sample.

    Whiskers follow the 1.5 * IQR rule. The upper whisker is the largest
    observation not above Q3 + 1.5*IQR; the lower whisker is the smallest
    observation not below Q1 - 1.5*IQR.

    Returns:
        ``(median, upper_quartile, lower_quartile, upper_whisker, lower_whisker)``
    """
    values = np.asarray(data)
    lower_q, upper_q = np.percentile(values, [25, 75])
    reach = 1.5 * (upper_q - lower_q)

    upper_whisker = values[values <= upper_q + reach].max()
    lower_whisker = values[values >= lower_q - reach].min()
    return np.median(values), upper_q, lower_q, upper_whisker, lower_whisker

In [None]:
results_df = pd.DataFrame(results_dict)

In [None]:
f

In [None]:
mf_caller = results_df['mf_edit']

In [None]:
mf_caller

In [None]:
get_box_stats(mf_caller)

In [None]:
columns = ['mf_caller', 'mf_edit', 'mf_search']
box_stat_arr = []

for i in columns:
    if i == 'mf_caller':
        filtered_df = results_df.loc[results_df['quality'] > 12]
    else:
        filtered_df = results_df
    detected_arr = filtered_df[i]
    detected_arr *= 100/8
    box_stat = get_box_stats(detected_arr)
    box_stat_arr.append(box_stat)

In [None]:
box_stat_arr

In [None]:
filtered_df = results_df.loc[results_df['quality'] > 0]

In [None]:
motifs_found_search_arr_

In [None]:
frac = 1/ (8) * 100

In [None]:
motifs_found_caller_arr_ = [i * frac for i in motifs_found_caller_arr]

In [None]:
np.mean(motifs_found_caller_arr_)

In [None]:
motifs_found_search_arr_ = [i * frac for i in motifs_found_search_arr]

In [None]:
t = pd.DataFrame([motifs_found_caller_arr_, motifs_found_search_arr_])

In [None]:
t.to_csv('motifs_found_empirical.csv', header=False)

## Convergence

In [None]:
# Alias of `dataset` (not a copy) used by the convergence cells below.
test_df = dataset

In [None]:
def update_tallies(tallies, prediction):
    """Add one read's predicted motifs to per-cycle motif counters.

    For every motif ``m`` (1-based) that ``prediction`` lists in cycle ``c``,
    ``tallies[c][m - 1]`` is incremented. ``tallies`` is modified in place
    and also returned for convenience.
    """
    for cycle_idx, cycle_motifs in enumerate(prediction):
        for motif in cycle_motifs:
            tallies[cycle_idx][motif - 1] += 1
    return tallies

def evaluate_motif_tallies(motif_tallies: List[List[int]], payload_cycles: List[List[int]]):
    """Score a consensus built from per-cycle motif tallies.

    For each cycle, the four most-tallied motifs (1-based ids) form the
    consensus, which is compared with the true cycle. Cycles are paired with
    ``zip``.

    Returns:
        The number of consensus motifs present in the truth, divided by 32.
    """
    hits = 0
    for counts, true_cycle in zip(motif_tallies, payload_cycles):
        top_4 = {idx + 1 for idx in np.argsort(counts)[::-1][:4]}
        hits += len(top_4 & set(true_cycle))
    return hits / 32

In [None]:
def get_model_prediction_batched(
        squiggles: List[List[float]], batch_size:int, beam: bool, beam_width: int = 30, prob_threshold:float = 0.0) -> List[List[int]]:
    """Decode a batch of raw squiggles into sorted motif transcripts.

    The mixed-orientation model decides, per read, whether the forward or the
    reverse model's output is greedily decoded.

    Args:
        squiggles: raw signal per read. Each read is max-normalised and the
            batch is zero-padded to its longest read.
        batch_size: number of leading reads in ``squiggles`` to decode.
        beam, beam_width: accepted for interface compatibility only; decoding
            is always greedy.
        prob_threshold: forwarded to ``greedy_decoder.forward_with_quality``.

    Returns:
        ``(greedy_transcripts, qualities)``: the sorted transcript and the
        decoder-reported quality of each decoded read. The return annotation
        is kept unchanged for compatibility, although the function returns
        this pair.
    """
    input_seqs = [normalize([squiggle], norm='max').flatten() for squiggle in squiggles]
    input_seqs = pad_sequence(
        [torch.tensor(seq, dtype=torch.float32) for seq in input_seqs], batch_first=True)
    # (batch, 1, length): single input channel for the models.
    input_seqs = input_seqs.view(input_seqs.shape[0], 1, input_seqs.shape[1])

    greedy_transcripts = []
    qualities = []
    # Inference only: no autograd graph is needed (outputs are detached anyway).
    with torch.no_grad():
        model_output_mixed = model_mixed(input_seqs).detach().cpu()
        # Forward/reverse outputs are computed lazily, at most once per batch.
        # Previously the whole batch was re-run through a model for every read.
        model_output_forward = None
        model_output_reverse = None

        for k in range(batch_size):
            greedy_result_mixed = greedy_decoder.forward(model_output_mixed[k])

            if detect_reverse_oriented_read(greedy_result_mixed):
                if model_output_reverse is None:
                    model_output_reverse = model_reverse(input_seqs).detach().cpu()
                greedy_result, quality = greedy_decoder.forward_with_quality(
                    model_output_reverse[k], prob_threshold=prob_threshold)
            else:
                if model_output_forward is None:
                    model_output_forward = model_forward(input_seqs).detach().cpu()
                greedy_result, quality = greedy_decoder.forward_with_quality(
                    model_output_forward[k], prob_threshold=prob_threshold)

            greedy_transcripts.append(sort_transcript(" ".join(greedy_result)))
            qualities.append(quality)

    return greedy_transcripts, qualities

### Barcoded runs

In [None]:
# Barcode groups: every 4th ONT barcode starting at 2, 3 and 4 respectively
# (np.arange stops before 81, so the groups end at 78, 79 and 80).
t2_barcodes = np.arange(2, 81, 4)
t3_barcodes = np.arange(3, 81, 4)
t4_barcodes = np.arange(4, 81, 4)

In [None]:
# Alias, not a copy; the next cell rebinds test_df rather than mutating dataset.
test_df = dataset

In [None]:
# Keep only reads from the t4 barcode group.
test_df = test_df.loc[test_df['ONT_Barcode'].isin(t4_barcodes)]

In [None]:
# Requires the 'squiggle' column to still be present in dataset/test_df.
X = test_df['squiggle'].tolist()

In [None]:
# Positionally aligned with X: X[n] is the signal of read read_ids_[n].
read_ids_ = test_df['read_id'].tolist()

In [None]:
# For every (barcode, address) pair, decode that pair's reads in batches and
# keep the predictions whose quality clears the threshold.
caller_predictions_arr = []
search_predictions_arr = []  # not collected here; see the consensus cell
edit_predictions_arr = []    # not collected here; see the consensus cell
payloads_arr = []
batch_size = 8
quality_threshold = 0
prob_threshold = 0.85
max_reads_per_address = 80  # stop decoding an address once more reads than this pass

# read_id -> squiggle lookup. It replaces `read_ids_.index(k)`, which scanned
# the whole list for every read (quadratic). The first occurrence wins, like list.index.
squiggle_by_read_id = {}
for read_id, squiggle in zip(read_ids_, X):
    squiggle_by_read_id.setdefault(read_id, squiggle)

with torch.no_grad():
    for barcode in tqdm(test_df['ONT_Barcode'].unique()):
        u_cycle = 0
        for cycle in tqdm(test_df['HW_Address'].unique()):
            selected_df = test_df.loc[
                (test_df['ONT_Barcode'] == barcode) &
                (test_df['HW_Address'] == cycle)
            ]
            read_ids = selected_df['read_id'].tolist()
            if len(read_ids) == 0:
                continue
            # NOTE(review): may be the stored string repr; the consensus cell
            # expects a list of cycles.
            payload = selected_df['payload'].tolist()[0]
            squiggles = [squiggle_by_read_id[k] for k in read_ids]
            decoded_predictions = []
            prediction_indices = []

            for ind in range(0, len(squiggles), batch_size):
                batch = squiggles[ind: ind + batch_size]  # the last batch may be shorter
                model_prediction, qualities = get_model_prediction_batched(
                    batch, batch_size=len(batch), beam=False, prob_threshold=prob_threshold)

                selected_indices = [
                    i for i in range(len(qualities)) if qualities[i] > quality_threshold]
                # Offsets into `squiggles`. `ind` is already an element offset;
                # it was previously multiplied by batch_size a second time.
                prediction_indices.extend([i + ind for i in selected_indices])
                decoded_predictions.extend([model_prediction[i] for i in selected_indices])

                if len(prediction_indices) > max_reads_per_address:
                    break

            caller_predictions_arr.append(decoded_predictions)
            payloads_arr.append(payload)
            u_cycle += 1


### Non-barcoded runs

In [None]:
X = dataset['squiggle'].to_numpy()

In [None]:
## For the non barcoded runs

caller_predictions_arr = []
#search_predictions_arr = []
payloads_arr = []
batch_size = 8
quality_threshold = 12
prob_threshold = 0.85


for cycle in tqdm(test_df['HW_Address'].unique()):
    selected_df = test_df.loc[
        (test_df['HW_Address'] == cycle) #&
        #(test_df['strand'].str.startswith('+')) 
    ]
    read_ids = selected_df['read_id'].tolist()
    payload = selected_df['payload'].tolist()[0]
    squiggles = [X[read_ids_.index(k)] for k in read_ids]
    #search_predictions = selected_df['motif_seq'].tolist()
    #search_predictions_sorted = [
    #    sort_transcript(i) for i in selected_df['motif_seq'].tolist()]
    decoded_predictions = []
    
    for ind in tqdm(range(0, len(squiggles), batch_size)):
        if len(squiggles) - ind < batch_size:
            model_prediction, qualities = get_model_prediction_batched(
            squiggles[ind:],
            search_predictions[ind:],
            batch_size=len(squiggles) - ind,
            beam=False, prob_threshold=prob_threshold)
        else:
            model_prediction, qualities = get_model_prediction_batched(
                squiggles[ind: ind + batch_size],
                batch_size=batch_size,
                beam=False, prob_threshold=prob_threshold)
        
        selected_indices = [
                i for i in range(len(qualities)) if qualities[i] > quality_threshold]
            
        #prediction_indices.extend([i + ind * batch_size for i in selected_indices])
            
        #for i in model_prediction:
        #    print(evaluate_prediction(i, payload))
        decoded_predictions.extend([model_prediction[i] for i in selected_indices])

        if len(decoded_predictions) > 80:
            break
    
    #search_predictions_arr.append(search_predictions_sorted[:len(decoded_predictions)])
    caller_predictions_arr.append(decoded_predictions)
    
    payloads_arr.append(payload)
    


### Running consensus decoding

In [None]:
import heapq

In [None]:
def get_top_4_motifs(tallies: dict):
    """Pick the consensus motifs for every cycle from motif tallies.

    ``tallies[i]`` maps motif id -> count for cycle ``i``, with cycles indexed
    ``0..len(tallies) - 1``. For each cycle, up to four motifs with the highest
    counts are returned, highest first; ties keep the mapping's iteration
    order. Motifs with a count of zero are left out.
    """
    prediction = []
    for cycle_idx in range(len(tallies)):
        counts = tallies[cycle_idx]
        ranked = heapq.nlargest(4, counts.keys(), key=lambda motif: counts[motif])
        prediction.append([motif for motif in ranked if counts[motif] > 0])
    return prediction

In [None]:
# For each address, accumulate per-cycle motif tallies one read at a time and
# record the consensus accuracy after every additional read.
search_accs_arr = []
caller_accs_arr = []
edit_accs_arr = []


### Comment out when non-barcoded
# NOTE(review): these aliases substitute the caller predictions for the
# search/edit predictions, which the cells above do not collect. Without them,
# the zip below runs over empty search/edit lists and yields no results.
search_predictions_arr = caller_predictions_arr
edit_predictions_arr = caller_predictions_arr

for search_prediction, caller_prediction, edit_prediction, payload in zip(
    search_predictions_arr, caller_predictions_arr, edit_predictions_arr, payloads_arr):

    n_reads = 0
    # 8 cycles x motif ids 1..8; a cycle index >= 8 or another motif id raises.
    motif_tallies_caller = [{i: 0 for i in range(1, 9)} for i in range(8)]
    motif_tallies_search = [{i: 0 for i in range(1, 9)} for i in range(8)]
    motif_tallies_edit = [{i: 0 for i in range(1, 9)} for i in range(8)]
    search_accs, caller_accs, edit_accs = [], [], []
    for i, j, k in zip(search_prediction, caller_prediction, edit_prediction):  # for each prediction
        cycle_num = 0
        for search_cycle, caller_cycle, edit_cycle in zip(i, j, k):  # loop over cycles and update tallies
            for motif in search_cycle:
                motif_tallies_search[cycle_num][motif] += 1
            for motif in caller_cycle:
                motif_tallies_caller[cycle_num][motif] += 1
            for motif in edit_cycle:
                motif_tallies_edit[cycle_num][motif] += 1
            cycle_num += 1

        #  Get caller and search predictions
        caller_prediction_ = get_top_4_motifs(motif_tallies_caller)
        search_prediction_ = get_top_4_motifs(motif_tallies_search)
        edit_prediction_ = get_top_4_motifs(motif_tallies_edit)
        
        n_reads += 1

        #print(payload)
        #payload = eval(payload)  # comment out for other datasets
        #  Evaluate predictions
        # Fraction of the 8 cycles x 4 consensus motifs = 32 slots that are correct.
        # NOTE(review): `payload` must be a list of cycles here; if it is still
        # the stored string, evaluate_prediction compares characters.
        search_acc = evaluate_prediction(search_prediction_, payload)[0]/32
        caller_acc  = evaluate_prediction(caller_prediction_, payload)[0]/32
        edit_acc = evaluate_prediction(edit_prediction_, payload)[0]/32
        

        # Entry n of each list is the accuracy after n + 1 reads.
        search_accs.append(search_acc)
        caller_accs.append(caller_acc)
        edit_accs.append(edit_acc)
        
    
    search_accs_arr.append(search_accs)
    caller_accs_arr.append(caller_accs)
    edit_accs_arr.append(edit_accs)
    


In [None]:
np.mean([i[-1] for i in caller_accs_arr if len(i) > 0])

In [None]:
avg_search = []
avg_caller = []
avg_edit = []

for i in range(7):

    search_sum = 0
    search_samples = 0
    for k in search_accs_arr:
        if i < len(k):
            search_sum += k[i]
            search_samples += 1
    
    caller_sum = 0
    caller_samples = 0
    for k in caller_accs_arr:
        if i < len(k):
            caller_sum += k[i]
            caller_samples += 1

    edit_sum = 0
    edit_samples = 0
    for k in edit_accs_arr:
        if i < len(k):
            edit_sum += k[i]
            edit_samples += 1

    avg_search.append(search_sum / search_samples)
    avg_caller.append(caller_sum / caller_samples)
    avg_edit.append(edit_sum / edit_samples)


In [None]:
len(avg_search)

In [None]:
test_df['ONT_Barcode'].unique()

In [None]:
sum = 0

for i in test_df['ONT_Barcode'].unique():
    sum += test_df.loc[test_df['ONT_Barcode'] == i]['HW_Address'].value_counts().sum()/ 64

sum/4

In [None]:
avg_caller

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Mean consensus accuracy versus number of reads for each motif-inference method.
# NOTE(review): the curves hold fractions in [0, 1] although the y label says
# "percentage", and the x axis extends well past the number of averaged points.
fig, ax = plt.subplots()
for curve, curve_label in (
        (avg_search, 'zero-error search'),
        (avg_caller, 'motif caller'),
        (avg_edit, 'edit search'),
):
    ax.plot(curve, label=curve_label)
ax.grid()
ax.set_yticks(np.arange(0, 1.01, 0.05))
ax.set_xticks(np.arange(0, 50, 2))
ax.set_xlim(0, 50)
ax.set_xlabel("Number of reads")
ax.set_ylabel("Recovery percentage")
ax.set_title("Convergence of motif-inferring methods")
ax.legend();

## Edit-distance search

In [None]:
import pandas as pd

# Reads with their zero-error search results, to be joined with the
# edit-distance search output on read_id.
df = pd.read_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\full_datasets\master_test.pkl")

In [None]:

# NOTE: this binds `f` in the notebook namespace.
with open(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\edit_distance_motif_search\res-loose-chain.txt", 'r') as f:
    lines = f.readlines()

In [None]:
# Parse one whitespace-separated record per line. Fixed-length prefixes are
# stripped from fields 0, 2 and 4 (presumably "key=" style labels; verify
# against the file format). Field 3 is ignored.
# NOTE: rebinds the global `read_ids` / `orientations` used by earlier cells.
read_ids = []
positions = []
orientations = []
motifs_found = []

for i in lines:
    split_i = i.split()
    read_ids.append(split_i[0][3:])
    orientations.append(split_i[1])
    positions.append(split_i[2][4:].split('-'))
    motifs_found.append(split_i[4][8:].split('->'))
    

In [None]:
# Keep only motif tokens ('m...'; 'fake' and other tokens are dropped) as ints.
# NOTE(review): int(j[1]) reads a single digit, so an id such as 'm10' would
# become 1; confirm that motif ids are single-digit.
motifs_found_ = [[int(j[1]) for j in i if not j == 'fake' and j.startswith('m')] for i in motifs_found ]

In [None]:
edit_distance_df = pd.DataFrame({'read_id': read_ids, 'positions': positions, 'orientations': orientations, 'edit_motif_seq': motifs_found_})

In [None]:
edit_distance_df.to_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\edit_distance_motif_search.pkl")

In [None]:
# Inner join: only reads present in both sources are kept.
merged_df = pd.merge(df, edit_distance_df, on='read_id')

In [None]:
df.head()

In [None]:
from transcript_sorting import sort_transcript

In [None]:
# For reads with at most 8 edit-search motifs, print (found, err) for the
# edit-distance search (each motif treated as its own cycle), then for the
# zero-error search after sorting its transcript into cycles.
for i, row in merged_df.iterrows():
    payload_seq = row['payload_seq']
    search_seq = row['motif_seq']
    edit_seq = row['edit_motif_seq']
    orientation = row['orientation']  # NOTE(review): unused below
    
    if len(edit_seq) <= 8:
        print(evaluate_prediction([[i] for i in edit_seq], payload_seq))
        print(evaluate_prediction(sort_transcript(search_seq), payload_seq))
        print()


