In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import torch
from nn import MotifCaller, NaiveCaller
from training_data import data_preproc, load_training_data
from utils import get_savepaths
import numpy as np
from sklearn.preprocessing import normalize
from greedy_decoder import GreedyCTCDecoder
from Levenshtein import ratio
from utils import load_model, get_metrics_for_evaluation
from transcript_sorting import sort_transcript_reduced_spacers, sort_transcript
from sklearn.model_selection import train_test_split
from beam_search_decoder import beam_search_ctc
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple

In [3]:
orientation_df = pd.read_csv(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\motif_search_barcoded.csv")

In [4]:
n_classes = 19
model_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\models\empirical\model (6).pth"
labels_int = np.arange(n_classes).tolist()
labels = [f"{i}" for i in labels_int] # Tokens to be fed into greedy decoder
greedy_decoder = GreedyCTCDecoder(labels = labels)
ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

In [5]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
model = load_model(model_path=model_path, device=device, n_classes=n_classes, hidden_size=512)

test_dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\datasets\empirical\full_empirical_test_dataset_v5_payload_seq.pkl"
#dataset = pd.read_pickle(dataset_path)
dataset = pd.read_pickle(test_dataset_path)

#X, y, payloads = load_training_data(
#        test_dataset_path, column_x='squiggle', column_y='motif_seq', payload=True, sampling_rate=0.2)

In [6]:
orientation_df = orientation_df[['read_id', 'orientation']]
merged_df = pd.merge(dataset, orientation_df, on='read_id')

In [7]:
sampled_df = merged_df.sample(n=100)

In [8]:
def evaluate_prediction(prediction, original):

    found = 0
    err = 0
    for i, j in zip(prediction, original):
        for k in range(len(i)):
            if i[k] in j:
                found += 1
            else:
                err += 1

    return found, err

In [11]:

counter = 0
sum_diff = 0

for ind, row in sampled_df.iterrows():
    
    x = row['squiggle']
    y = row['Spacer_Sequence']
    payload = row['Payload_Sequence']
    orientation = row['orientation']

    if orientation.startswith('-'):
        continue

    input_sequence = torch.tensor(
        x, dtype=torch.float32)
    input_sequence = input_sequence.view(
        1, 1, len(x)).to(device)
    model_output = model(input_sequence)
    model_output = model_output.permute(1, 0, 2)
    
    label_lengths = torch.tensor([len(y)])
    target_sequence = torch.tensor(y).to(device)

    
    n_timesteps = model_output.shape[0]
    input_lengths = torch.tensor([n_timesteps])
    
    model_output_flattened = model_output.view(
        model_output.shape[0] * model_output.shape[1], n_classes)

    loss = ctc(
        model_output, target_sequence, input_lengths, label_lengths)
    print(loss.item())
    print(orientation)
    
    greedy_transcript = " ".join(greedy_decoder(model_output_flattened))
    greedy_transcript_2 = " ".join(greedy_decoder(model_output))

    beam_transcript = beam_search_ctc(
        model_output_flattened.detach().cpu(), beam_width=20)
    actual_transcript = " ".join([str(i) for i in y])
    #print(greedy_transcript)
    #print(beam_transcript)

    decoded_prediction = sort_transcript_reduced_spacers(greedy_transcript)
    decoded_prediction_2 = sort_transcript_reduced_spacers(greedy_transcript_2)
    search_prediction = sort_transcript(actual_transcript)
    original = sort_transcript(" ".join([str(i) for i in payload]))
    print(decoded_prediction)
    print(decoded_prediction_2)
    print(search_prediction)
    print(original)
    
    found_motifs_caller = evaluate_prediction(decoded_prediction, original)
    found_motifs_search = evaluate_prediction(search_prediction, original)
    print(found_motifs_caller)
    print(found_motifs_search)
    sum_diff += found_motifs_caller[0] - found_motifs_search[0]
    
    #greedy_ratio = ratio(greedy_transcript, actual_transcript)
    #beam_ratio = ratio(beam_transcript, actual_transcript)
    #sum_diff += beam_ratio - greedy_ratio
    #print()
    counter+=1
    print(sum_diff)
    print()

    if counter == 5:
        print(sum_diff)
        break

8.335997581481934
+|+|+|+|+
[[5], [], [5], [3], [], [6], [1], [3]]
[[5], [], [5], [3], [], [6], [1, 1], [3]]
[[], [], [], [3], [6], [6], [], [3]]
[[1, 4, 5, 8], [4, 6, 7, 8], [1, 2, 3, 4], [1, 3, 5, 6], [5, 6, 7, 8], [2, 3, 4, 6], [1, 4, 5, 6], [1, 2, 3, 4]]
(5, 1)
(4, 0)
1

6.338716983795166
+|+|+|+|+|+
[[1], [], [5], [1], [], [7], [5], [3]]
[[1], [], [5], [1], [], [7], [5], [3]]
[[], [1], [8], [1], [3], [], [8], []]
[[2, 6, 7, 8], [1, 2, 3, 4], [1, 4, 7, 8], [1, 2, 5, 8], [1, 3, 4, 8], [1, 4, 5, 7], [1, 3, 5, 8], [1, 3, 7, 8]]
(4, 2)
(5, 0)
0

16.298765182495117
+|+
[[8], [], [1], [8], [], [8], [3], [5]]
[[8], [], [1], [8], [], [8], [3], [5]]
[[3], [], [], [6], [], [], [], []]
[[3, 4, 5, 6], [2, 5, 6, 7], [1, 2, 4, 8], [2, 4, 6, 7], [2, 5, 7, 8], [2, 4, 7, 8], [2, 3, 4, 8], [2, 5, 6, 7]]
(4, 2)
(2, 0)
2

55.10380172729492
+
[[8], [6], [6], [6], [3], [], [3], [2]]
[[8], [6, 6], [6, 6], [6], [3], [], [3], [2]]
[[8], [], [], [], [], [], [], []]
[[2, 3, 7, 8], [1, 5, 7, 8], [1, 2, 4, 5],

In [14]:
def get_model_prediction(
        squiggle: list[float], model: NaiveCaller, beam: bool, beam_width: int = 30) -> List[List[int]]:

    input_sequence = torch.tensor(
                squiggle, dtype=torch.float32)
    input_sequence = input_sequence.view(
        1, 1, len(squiggle)).to(device)
    model_output = model(input_sequence)
    model_output = model_output.permute(1, 0, 2)
    
    model_output_flattened = model_output.view(
    model_output.shape[0] * model_output.shape[1], n_classes)

    if beam:
        transcript = beam_search_ctc(
        model_output_flattened.detach().cpu(), beam_width=beam_width)
    else:
        transcript = " ".join(greedy_decoder(model_output_flattened))

    return sort_transcript_reduced_spacers(transcript)


In [15]:
def evaluate_motif_tallies(motif_tallies: List[List[int]], payload_cycles: List[List[int]]) -> Tuple[float, float]:
    
    correct = 0
    errs = 0
    for tallies, cycle in zip(motif_tallies, payload_cycles):
        sorted_tallies = sorted(range(len(tallies)), key=lambda i: tallies[i], reverse=True)
        top_4 = [i+1 for i in sorted_tallies[:4]]


        correct += len(set(top_4).intersection(set(cycle)))
        errs += len(set(top_4) - set(cycle))

    return correct / 32, errs / 32

In [17]:
# Testing convergence
for barcode in merged_df['ONT_Barcode'].unique():
    for cycle in merged_df['HW_Address'].unique():

        motif_tallies_search = [[0 for i in range(8)] for i in range(8)]
        motif_tallies_caller = [[0 for i in range(8)] for i in range(8)]

        selected_df = merged_df.loc[
            (merged_df['ONT_Barcode'] == barcode) &
            (merged_df['HW_Address'] == cycle) &
            (merged_df['orientation'].str.startswith('+'))       
        ]
        payload = selected_df['Payload'].tolist()[0]
        squiggles = selected_df['squiggle'].tolist()
        search_predictions = [
            sort_transcript(i) for i in selected_df['Spacer_Sequence'].tolist()]
        orientations = selected_df['orientation']
        
        for squiggle, search_prediction in tqdm(zip(
            squiggles[:30], search_predictions), total=len(squiggles)):
            
            decoded_prediction = get_model_prediction(squiggle, model, beam=False)
            

            for i in range(len(search_prediction)):
                for j, k in zip(search_prediction[i], decoded_prediction[i]):
                    motif_tallies_search[i][j-1] += 1
                    motif_tallies_caller[i][k-1] += 1

        
        search_acc, search_err = evaluate_motif_tallies(motif_tallies_search, payload)
        caller_acc, caller_err = evaluate_motif_tallies(motif_tallies_caller, payload)

        print(f"Search acc {search_acc} err {search_err}\n"
              f"Caller acc {caller_acc} err {caller_err}")
        
        break
        
    break


  0%|          | 0/261 [00:00<?, ?it/s]

Search acc 0.8125 err 0.1875
Caller acc 0.65625 err 0.34375
