In [1]:
# Enable IPython's autoreload extension so edited project modules are re-imported
# automatically before each cell executes (mode 2 = reload all modules).
%load_ext autoreload
%autoreload 2

In [37]:
import pandas as pd
import torch
from nn import MotifCaller, NaiveCaller
from training_data import data_preproc, load_training_data
from utils import get_savepaths
import numpy as np
from sklearn.preprocessing import normalize
from greedy_decoder import GreedyCTCDecoder
from Levenshtein import ratio
from utils import load_model, get_metrics_for_evaluation
from transcript_sorting import sort_transcript_reduced_spacers, sort_transcript
from sklearn.model_selection import train_test_split
from beam_search_decoder import beam_search_ctc
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from typing import List

In [2]:
# Decoder / loss configuration shared by the cells below.
n_classes = 19
model_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\models\empirical\cleaned_5_512.pth"

# Class indices 0 .. n_classes - 1 and their string forms
# (the string forms are the tokens fed into the greedy decoder).
labels_int = list(range(n_classes))
labels = [str(class_index) for class_index in labels_int]
greedy_decoder = GreedyCTCDecoder(labels=labels)

# CTC loss using class 0 as the blank symbol; infinite losses are zeroed out.
ctc = nn.CTCLoss(
    blank=0,
    reduction='mean',
    zero_infinity=True,
)

In [3]:

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# From here on, tensors created without an explicit device are placed on `device`.
torch.set_default_device(device)

# Restore the trained model from `model_path`.
model = load_model(
    model_path=model_path,
    device=device,
    n_classes=n_classes,
    hidden_size=512,
)

test_dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\datasets\empirical\full_empirical_test_dataset_v5_payload_seq.pkl"

# Earlier loading experiments, left disabled:
#dataset = pd.read_pickle(dataset_path)
#dataset = pd.read_pickle(test_dataset_path)

#X, y, payloads = load_training_data(
#        test_dataset_path, column_x='squiggle', column_y='motif_seq', payload=True, sampling_rate=0.2)

In [4]:
# Load the 'squiggle' column as inputs (X), the 'Spacer_Sequence' column as
# labels (y), plus the payloads, from the dataset at `test_dataset_path`.
X, y, payloads = load_training_data(
    test_dataset_path,
    column_x='squiggle',
    column_y='Spacer_Sequence',
    payload=True,
    sampling_rate=0.2,
    orientation=False,
)

In [5]:
# Hold out 20% of the reads for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [6]:
# Split the payloads with the same test_size and random_state as the X / y split,
# so payloads_test is meant to line up with X_test / y_test. This relies on X, y
# and payloads all having the same length. X_train / X_test are reassigned here.
X_train, X_test, payloads_train, payloads_test = train_test_split(
    X,
    payloads,
    test_size=0.2,
    random_state=42,
)

In [7]:
def evaluate_prediction(prediction, original):
    """Count predicted elements that do / do not occur in the matching reference group.

    ``prediction`` and ``original`` are walked in lockstep (``zip`` stops at
    the shorter of the two). Each element of a predicted group is tested for
    membership in the reference group at the same position; an element that
    appears several times in a predicted group is counted every time.

    Args:
        prediction: Sequence of predicted groups (each a sequence of elements).
        original: Sequence of reference groups, position-aligned with
            ``prediction``.

    Returns:
        A ``(found, err)`` tuple: how many predicted elements are present in
        their reference group, and how many are absent from it.
    """
    hits = 0
    misses = 0
    for predicted_group, reference_group in zip(prediction, original):
        matched = sum(
            1 for element in predicted_group if element in reference_group)
        hits += matched
        # Everything in the predicted group that did not match is an error.
        misses += len(predicted_group) - matched

    return hits, misses

In [40]:

# Compare the motif caller (beam-search decode of the model output) against the
# search-derived labels on a capped number of held-out reads.
MAX_EVAL_SAMPLES = 200  # cap on evaluated reads so the cell stays quick to run

counter = 0   # number of reads evaluated so far
sum_diff = 0  # running sum of (caller hits - search hits)

# Size the progress bar by the number of reads that will actually be processed.
n_to_evaluate = min(len(X_test), MAX_EVAL_SAMPLES)

# The loop targets are deliberately NOT called `x` / `y`: those names would
# rebind the notebook-level `y` (the full label set) and break a later re-run
# of the train/test split cells.
for counter, (squiggle, target, payload) in enumerate(
        tqdm(zip(X_test, y_test, payloads_test), total=n_to_evaluate), start=1):

    # Inference only: no autograd graph is needed for the forward pass or the loss.
    with torch.no_grad():
        input_sequence = torch.tensor(squiggle, dtype=torch.float32)
        input_sequence = input_sequence.view(1, 1, len(squiggle)).to(device)

        # After the permute, axis 0 is treated as the time axis below
        # (assumes the model returns (batch, time, classes) -- TODO confirm).
        model_output = model(input_sequence).permute(1, 0, 2)

        n_timesteps = model_output.shape[0]
        input_lengths = torch.tensor([n_timesteps])
        label_lengths = torch.tensor([len(target)])
        target_sequence = torch.tensor(target).to(device)

        loss = ctc(
            model_output, target_sequence, input_lengths, label_lengths)
    print(loss.item())

    # Collapse the (time, batch) axes so the beam search sees one row per step.
    model_output_flattened = model_output.view(
        model_output.shape[0] * model_output.shape[1], n_classes)
    beam_transcript = beam_search_ctc(
        model_output_flattened.cpu(), beam_width=20)
    actual_transcript = " ".join(str(token) for token in target)

    decoded_prediction = sort_transcript_reduced_spacers(beam_transcript)
    search_prediction = sort_transcript(actual_transcript)
    original = sort_transcript(" ".join(str(token) for token in payload))
    print(decoded_prediction)
    print(search_prediction)
    print(original)

    # Each result is a (found, err) tuple measured against the true payload.
    found_motifs_caller = evaluate_prediction(decoded_prediction, original)
    found_motifs_search = evaluate_prediction(search_prediction, original)
    print(found_motifs_caller)
    print(found_motifs_search)
    sum_diff += found_motifs_caller[0] - found_motifs_search[0]

    print(sum_diff)
    print()

    if counter == MAX_EVAL_SAMPLES:
        break

# Final total over all evaluated reads.
print(sum_diff)

  0%|          | 0/2762 [00:00<?, ?it/s]

4.061060905456543
[[1], [1, 1, 4, 4], [4], [], [], [], [], [7]]
[[], [], [3, 6], [3], [6], [], [], []]
[[1, 5, 6, 8], [1, 2, 3, 7], [3, 4, 5, 6], [1, 3, 6, 8], [1, 3, 6, 7], [2, 3, 7, 8], [2, 3, 7, 8], [2, 6, 7, 8]]
(5, 2)
(4, 0)
1

30.538583755493164
[[8], [1, 1, 4, 4], [7], [], [], [], [], [7]]
[[], [4], [], [], [], [], [], []]
[[2, 3, 4, 5], [1, 2, 3, 4], [2, 4, 6, 7], [1, 3, 4, 5], [3, 5, 7, 8], [2, 3, 5, 6], [2, 3, 4, 7], [1, 3, 4, 8]]
(5, 2)
(1, 0)
5

26.7930850982666
[[1], [1, 8, 1, 4], [4], [], [], [], [], [7]]
[[], [], [5], [], [], [], [], []]
[[2, 4, 6, 7], [1, 2, 4, 7], [1, 2, 3, 5], [2, 3, 4, 5], [1, 3, 4, 5], [1, 3, 5, 7], [1, 2, 3, 6], [2, 4, 5, 8]]
(3, 4)
(1, 0)
7

11.062883377075195
[[1], [1, 1, 4, 4], [4], [], [], [], [], [7]]
[[], [], [5], [], [], [1], [], []]
[[2, 3, 6, 8], [3, 4, 5, 6], [2, 5, 6, 7], [1, 3, 5, 6], [1, 2, 4, 7], [1, 4, 5, 8], [3, 4, 7, 8], [1, 4, 7, 8]]
(3, 4)
(2, 0)
8

13.580032348632812
[[8], [1, 1, 4, 4], [4], [], [], [], [], [7]]
[[], [], [], [3]

KeyboardInterrupt: 

In [10]:
# Read the pickled test dataset into `test_df` for the per-barcode analysis below.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
test_df = pd.read_pickle(test_dataset_path)

In [15]:
# Display the loaded test DataFrame (pandas abbreviates the rendering of long frames).
test_df

Unnamed: 0,ONT_Barcode,HW_Address,Payload,Library_Motifs,read_id,squiggle,Motifs,Spacer_Sequence,Payload_Sequence
0,5,barcode_external01_internal01,"[[3, 6, 7, 8], [3, 4, 5, 8], [1, 2, 4, 6], [1,...","[ltm8_2x1, ltm8_3x8, ltm8_4x4, ltm8_5x1, ltm8_...",0038e7e2-ab7a-4e8e-a9b5-d39eefa8b0f2,"[0.5791722622390428, 0.5589061164692956, 0.557...","[1, 8, 4, 1, 8, 8, 4]","[10, 1, 10, 11, 8, 11, 12, 4, 12, 13, 1, 13, 1...","[11, 3, 11, 11, 6, 11, 11, 7, 11, 11, 8, 11, 1..."
1,5,barcode_external01_internal01,"[[3, 6, 7, 8], [3, 4, 5, 8], [1, 2, 4, 6], [1,...","[ltm8_1x1, ltm8_3x6, ltm8_5x4, ltm8_6x1]",0073a9a2-8ee0-4332-9722-72837b77b29c,"[0.5748992796972286, 0.5625686729337077, 0.565...","[1, 6, 4, 1]","[9, 1, 9, 11, 6, 11, 13, 4, 13, 14, 1, 14]","[11, 3, 11, 11, 6, 11, 11, 7, 11, 11, 8, 11, 1..."
2,5,barcode_external01_internal01,"[[3, 6, 7, 8], [3, 4, 5, 8], [1, 2, 4, 6], [1,...","[ltm8_4x8, ltm8_5x4]",00e4308c-baf1-49b1-848a-9c7ff35971e5,"[0.9999999999999999, 0.6249999999999999, 0.613...","[8, 4]","[12, 8, 12, 13, 4, 13]","[11, 3, 11, 11, 6, 11, 11, 7, 11, 11, 8, 11, 1..."
3,5,barcode_external01_internal01,"[[3, 6, 7, 8], [3, 4, 5, 8], [1, 2, 4, 6], [1,...","[ltm8_1x1, ltm8_2x1, ltm8_3x7, ltm8_4x8, ltm8_...",017a284c-952d-422e-ba1a-c4d88e06a3ac,"[0.694006309148265, 0.7634069400630915, 0.7555...","[1, 1, 7, 8, 4, 4, 7]","[9, 1, 9, 10, 1, 10, 11, 7, 11, 12, 8, 12, 13,...","[11, 3, 11, 11, 6, 11, 11, 7, 11, 11, 8, 11, 1..."
4,5,barcode_external01_internal01,"[[3, 6, 7, 8], [3, 4, 5, 8], [1, 2, 4, 6], [1,...","[ltm8_6x8, ltm8_7x1]",0182d415-b83b-4eeb-9bcd-4f217aba5e8c,"[0.5516483516483517, 0.6747252747252748, 0.698...","[8, 1]","[14, 8, 14, 15, 1, 15]","[11, 3, 11, 11, 6, 11, 11, 7, 11, 11, 8, 11, 1..."
...,...,...,...,...,...,...,...,...,...
70897,53,barcode_external08_internal08,"[[1, 3, 4, 6], [1, 2, 3, 6], [1, 2, 4, 8], [2,...","[ltm8_3x3, ltm8_4x6, ltm8_5x1, ltm8_7x7, ltm8_...",fd1e552d-6a0e-45b9-9cd4-c9b7051ea826,"[0.5481626175070199, 0.5524356000488341, 0.555...","[3, 6, 1, 7, 8, 2]","[11, 3, 11, 12, 6, 12, 13, 1, 13, 15, 7, 15, 1...","[11, 1, 11, 11, 3, 11, 11, 4, 11, 11, 6, 11, 1..."
70898,53,barcode_external08_internal08,"[[1, 3, 4, 6], [1, 2, 3, 6], [1, 2, 4, 8], [2,...","[ltm8_2x8, ltm8_3x6, ltm8_6x5]",fd6271bc-36f0-4940-8f17-9d164e6081bf,"[0.5573190086680503, 0.5589061164692956, 0.557...","[8, 6, 5]","[10, 8, 10, 11, 6, 11, 14, 5, 14]","[11, 1, 11, 11, 3, 11, 11, 4, 11, 11, 6, 11, 1..."
70899,53,barcode_external08_internal08,"[[1, 3, 4, 6], [1, 2, 3, 6], [1, 2, 4, 8], [2,...","[ltm8_2x8, ltm8_6x2, ltm8_7x1, ltm8_9x5]",fdf1bb91-241d-406b-bd48-65a6a92874c7,"[0.4012875536480687, 0.40772532188841204, 0.54...","[8, 2, 1, 5]","[10, 8, 10, 14, 2, 14, 15, 1, 15, 17, 5, 17]","[11, 1, 11, 11, 3, 11, 11, 4, 11, 11, 6, 11, 1..."
70900,53,barcode_external08_internal08,"[[1, 3, 4, 6], [1, 2, 3, 6], [1, 2, 4, 8], [2,...","[ltm8_4x3, ltm8_5x4]",fdfdc956-36c7-4fed-86b8-358966b0d71c,"[0.16560820713238886, 0.1524181729360039, 0.16...","[3, 4]","[12, 3, 12, 13, 4, 13]","[11, 1, 11, 11, 3, 11, 11, 4, 11, 11, 6, 11, 1..."


In [42]:
def get_model_prediction(
        squiggle: list[float], model: NaiveCaller, beam: bool, beam_width: int = 30) -> List[List[int]]:
    """Run the model on one squiggle and return its decoded, sorted transcript.

    Relies on the notebook-level names ``device``, ``n_classes``,
    ``greedy_decoder``, ``beam_search_ctc`` and
    ``sort_transcript_reduced_spacers``.

    Args:
        squiggle: Signal values for a single read.
        model: Model invoked on a ``(1, 1, len(squiggle))`` float32 tensor.
        beam: If True decode with ``beam_search_ctc``; otherwise use the
            greedy decoder.
        beam_width: Beam width passed to ``beam_search_ctc`` (ignored when
            ``beam`` is False).

    Returns:
        Whatever ``sort_transcript_reduced_spacers`` returns for the decoded
        transcript (annotated as a list of integer lists).
    """
    # Inference only: disable autograd so no graph is built and kept alive
    # for every read that gets decoded.
    with torch.no_grad():
        input_sequence = torch.tensor(
            squiggle, dtype=torch.float32)
        input_sequence = input_sequence.view(
            1, 1, len(squiggle)).to(device)
        # Swap the first two axes before decoding
        # (assumes the model returns (batch, time, classes) -- TODO confirm).
        model_output = model(input_sequence).permute(1, 0, 2)

    if beam:
        # Merge the first two axes so the beam search gets one row per step.
        model_output_flattened = model_output.view(
            model_output.shape[0] * model_output.shape[1], n_classes)
        transcript = beam_search_ctc(
            model_output_flattened.detach().cpu(), beam_width=beam_width)
    else:
        transcript = " ".join(greedy_decoder(model_output))

    return sort_transcript_reduced_spacers(transcript)


In [None]:
# Testing convergence
# Iterate over the (barcode, cycle) groups that actually occur in test_df.
# sort=False keeps first-appearance order, so the first group is still the one
# containing the first row, and no empty selections are produced for
# barcode / address combinations that never occur together.
for (barcode, cycle), selected_df in test_df.groupby(
        ['ONT_Barcode', 'HW_Address'], sort=False):

    # TODO: 8x8 tallies are initialised here but not accumulated yet.
    motif_tallies_search = [[0 for _ in range(8)] for _ in range(8)]
    motif_tallies_caller = [[0 for _ in range(8)] for _ in range(8)]

    # Positional access: the group keeps test_df's row labels, so a label
    # lookup of 0 only succeeds for the group that contains the first row.
    payload = selected_df['Payload'].iloc[0]
    squiggles = selected_df['squiggle'].tolist()
    search_predictions = [
        sort_transcript(i) for i in selected_df['Spacer_Sequence'].tolist()]
    print(payload)

    for squiggle, search_prediction in zip(
            squiggles, search_predictions):
        decoded_prediction = get_model_prediction(squiggle, model, beam=True)

        # Show the search-based and model-based predictions side by side.
        print(search_prediction)
        print(decoded_prediction)
        print()

    # Only the first group is inspected while this cell is under development.
    break


[[3, 6, 7, 8], [3, 4, 5, 8], [1, 2, 4, 6], [1, 2, 5, 6], [4, 5, 6, 8], [1, 2, 5, 7], [1, 3, 4, 7], [2, 4, 5, 7]]
[[8], [4], [1], [8], [8], [], [4], []]
[[8], [1, 1, 4, 4], [4], [], [], [], [], [7]]

[[6], [], [4], [1], [], [], [], []]
[[2], [1, 1, 4, 4], [4], [], [], [], [], [7]]

[[], [8], [4], [], [], [], [], []]
[[7], [1, 1, 4, 4], [4], [], [], [], [], [7]]

[[7], [8], [4], [], [4], [7], [], []]
[[2], [1, 1, 4, 4], [4], [], [], [], [], [7]]

[[], [], [], [8], [1], [], [], []]
[[6], [], [8, 4, 7, 7], [], [], [], [], [7]]

[[], [], [], [], [], [], [1], []]
[[6], [], [1, 4, 4, 7], [], [], [], [], [7]]

[[], [3], [], [], [5], [7], [], []]
[[2], [1, 1, 4, 4], [7], [], [], [], [], [7]]

[[6], [], [6], [5], [8], [], [4], [5]]
[[6], [1, 1, 4], [7], [], [], [], [], [7]]

[[], [], [], [1], [], [], [4], []]
[[8], [1, 1, 4, 4], [4], [], [], [], [7], []]

[[], [], [4], [], [], [], [], []]
[[8], [1, 1, 4, 4], [4], [], [], [], [], [7]]

[[8], [], [6], [5], [], [], [], []]
[[2], [1, 1, 4, 4], [4], 

KeyboardInterrupt: 