In [39]:
# Automatically reload edited local modules (nn, utils, decoders, ...) before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [42]:
from itertools import islice

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from Levenshtein import ratio
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from tqdm.notebook import tqdm

from beam_search_decoder import beam_search_ctc
from greedy_decoder import GreedyCTCDecoder
from nn import MotifCaller, NaiveCaller
from training_data import data_preproc, load_training_data
from utils import get_savepaths
from utils import load_model, get_metrics_for_evaluation, sort_transcript

In [17]:
# Output vocabulary of the model: class 0 is the CTC blank (see `ctc` below),
# the remaining classes are label tokens.
n_classes = 19
model_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\models\empirical\3_3_25.pth"

labels_int = list(range(n_classes))
labels = list(map(str, labels_int))  # string tokens consumed by the greedy decoder
greedy_decoder = GreedyCTCDecoder(labels=labels)

# zero_infinity zeroes the loss (and its gradient) when a target cannot be
# aligned to the model output, instead of returning inf.
ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

In [36]:

# Prefer the GPU when one is visible, and make it the default device so that
# tensors created later are allocated there.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
torch.set_default_device(device)

model = load_model(model_path=model_path, device=device, n_classes=n_classes)

# Pickled empirical dataset. sampling_rate=0.2 presumably loads a subset of the
# reads — TODO confirm against load_training_data.
dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\pickled_datasets\28_2_25.pkl"
X, y, payloads = load_training_data(
    dataset_path,
    column_x='squiggle',
    column_y='motif_seq',
    payload=True,
    sampling_rate=0.2,
)

Index(['ONT_Barcode', 'HW_Address', 'Payload', 'Library_Motifs', 'read_id',
       'squiggle', 'Motifs', 'motif_seq', 'Payload_Sequence'],
      dtype='object')


In [37]:
# Split squiggles, motif labels and payloads in ONE call. The same shuffled
# indices are then applied to all three, so X_test, y_test and payloads_test
# are guaranteed to stay row-aligned (the evaluation loop zips them together).
# This replaces two separate splits that only lined up because they reused
# the same random_state.
X_train, X_test, y_train, y_test, payloads_train, payloads_test = train_test_split(
        X, y, payloads, test_size=0.2, random_state=42)

In [47]:

# Spot-check the model on held-out reads: per-read CTC loss, greedy vs.
# beam-search transcripts, and how much beam search improves the Levenshtein
# ratio over greedy decoding.
N_EVAL = 10  # number of test reads to inspect; set to None to evaluate the whole test split

counter = 0   # reads evaluated so far
sum_diff = 0  # running sum of (beam_ratio - greedy_ratio)

# Inference only. eval() disables training-time behaviour (a no-op if
# load_model already did it), and no_grad() stops autograd from recording a
# graph for every read.
model.eval()
n_eval = len(X_test) if N_EVAL is None else min(N_EVAL, len(X_test))
eval_reads = islice(zip(X_test, y_test, payloads_test), N_EVAL)

with torch.no_grad():
    # The label is bound to `label`, not `y`. Using `y` would overwrite the full
    # label list `y` that the train/test split cell reads.
    for squiggle, label, payload in tqdm(eval_reads, total=n_eval):
        # Raw signal as a (batch=1, channel=1, length) float tensor.
        input_sequence = torch.tensor(squiggle, dtype=torch.float32)
        input_sequence = input_sequence.view(1, 1, len(squiggle)).to(device)

        # NOTE(review): assumes the model returns (batch, time, classes)
        # log-probabilities. The permute then gives the (time, batch, classes)
        # layout that nn.CTCLoss expects — TODO confirm.
        model_output = model(input_sequence).permute(1, 0, 2)

        target_sequence = torch.tensor(label).to(device)
        label_lengths = torch.tensor([len(label)])
        n_timesteps = model_output.shape[0]
        input_lengths = torch.tensor([n_timesteps])

        # (time, batch, classes) -> (time * batch, classes) for the beam decoder.
        # reshape (not view) so this also works for batch > 1, where the
        # permuted tensor is non-contiguous.
        model_output_flattened = model_output.reshape(-1, n_classes)

        loss = ctc(model_output, target_sequence, input_lengths, label_lengths)
        print(loss.item())

        # NOTE(review): the greedy decoder receives the 3-D output, while the beam
        # decoder receives the flattened 2-D one — confirm GreedyCTCDecoder
        # handles the batch dimension.
        greedy_transcript = " ".join(greedy_decoder(model_output))
        beam_transcript = beam_search_ctc(
            model_output_flattened.detach().cpu(), beam_width=10)
        actual_transcript = " ".join(str(i) for i in label)

        print(sort_transcript(beam_transcript))
        print(beam_transcript)
        print(sort_transcript(actual_transcript))
        print(sort_transcript(" ".join(str(i) for i in payload)))

        greedy_ratio = ratio(greedy_transcript, actual_transcript)
        beam_ratio = ratio(beam_transcript, actual_transcript)

        print(get_metrics_for_evaluation(greedy_transcript, actual_transcript, payload))
        sum_diff += beam_ratio - greedy_ratio
        print()

        counter += 1

# Report the accumulated beam-vs-greedy difference (previously computed but never shown).
if counter:
    print(f"Mean (beam - greedy) Levenshtein ratio over {counter} reads: {sum_diff / counter:.4f}")

  0%|          | 0/936 [00:00<?, ?it/s]

1.748037338256836
[[4], [], [1], [1], [], [], [], [1]]
11 4 11 12 12 13 1 13 14 1 14 15 15 17 18 1 18
[[], [], [6], [1], [8], [1], [8], [6]]
[[3, 4, 7, 8], [1, 2, 3, 5], [3, 4, 6, 8], [1, 3, 4, 5], [1, 3, 4, 8], [1, 3, 4, 7], [2, 3, 4, 8], [1, 4, 6, 8]]
(0.16666666666666666, 0.75, 0.16666666666666666, 0.3333333333333333)

1.3448082208633423
[[4], [], [], [1], [1, 2], [], [], []]
11 4 11 13 13 14 1 14 15 1 15 2 18
[[8], [6], [7], [2], [8], [5], [5], [2]]
[[1, 3, 4, 8], [2, 4, 6, 8], [2, 3, 5, 7], [1, 2, 6, 7], [3, 6, 7, 8], [1, 2, 5, 7], [4, 5, 6, 7], [2, 4, 5, 6]]
(0.0, 1.0, 0.0, 0.5)

1.774727702140808
[[1], [], [], [5], [1], [], [], []]
11 1 11 13 13 14 5 14 15 1 15 18
[[1], [], [], [4], [7], [8], [2], []]
[[1, 3, 4, 6], [1, 4, 7, 8], [1, 3, 4, 8], [1, 2, 4, 7], [3, 5, 7, 8], [3, 6, 7, 8], [2, 5, 7, 8], [1, 3, 4, 7]]
(0.2, 0.6666666666666666, 0.2, 0.4)

1.2886781692504883
[[4], [], [1], [1], [], [], [], [5]]
11 4 11 12 12 13 1 13 14 1 14 15 15 17 18 5 18
[[3], [], [5], [5], [], [6], 