In [1]:
# Re-import edited modules (e.g. the local nn / utils / evaluation helpers) automatically
# before each cell runs, so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
import pandas as pd
import torch
from nn import MotifCaller, NaiveCaller
from training_data import data_preproc, load_training_data
from utils import get_savepaths
import numpy as np
from sklearn.preprocessing import normalize
from greedy_decoder import GreedyCTCDecoder
from Levenshtein import ratio
from utils import load_model, get_metrics_for_evaluation, sort_transcript
from sklearn.model_selection import train_test_split
from beam_search_decoder import beam_search_ctc
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Size of the model's output alphabet; class 0 is the CTC blank (see `blank=0` below),
# the remaining ids are presumably the motif labels — TODO confirm against the dataset.
n_classes = 17
model_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\models\synthetic\50_epochs.pth"

# Integer class ids 0..n_classes-1 and their string tokens, which the greedy decoder
# emits when turning per-timestep argmaxes into a transcript.
labels_int = list(range(n_classes))
labels = list(map(str, labels_int))
greedy_decoder = GreedyCTCDecoder(labels=labels)

# Mean-reduced CTC loss; zero_infinity replaces infinite losses (impossible alignments) with 0.
ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

In [3]:

# Prefer the GPU when one is visible, otherwise fall back to the CPU, and make that the
# default device so tensors created later land there without an explicit `.to()`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
torch.set_default_device(device)

model = load_model(
    model_path=model_path,
    device=device,
    n_classes=n_classes,
)

# Pickled synthetic dataset that the evaluation below reads from.
dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\synthetic\pickled_datasets\25_2_25.pkl"

In [4]:
# Raw dataset frame, for inspection. NOTE: unpickling can execute arbitrary code — only
# load trusted files. NOTE(review): `df` is not used by the cells visible here, and
# load_training_data below reads the same `dataset_path` again separately.
df = pd.read_pickle(dataset_path)

In [5]:
# Features come from the 'squiggle' column and targets from 'motif_seq'
# (presumably raw signal reads and their motif label sequences — TODO confirm).
X, y = load_training_data(
    dataset_path,
    column_x='squiggle',
    column_y='motif_seq',
    payload=False,
    sampling_rate=1.0,
)

In [6]:
# Hold out 20% of the reads for evaluation; the fixed random_state keeps the split reproducible.
split = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_test, y_train, y_test = split

In [8]:
from evaluation import evaluate_cycle_prediction

In [13]:

# Run the model on held-out reads and, for each of the first N_EVAL_SAMPLES reads with a
# non-empty ground truth, print the greedy transcript, the true transcript, and the
# result of evaluate_cycle_prediction. Per-read CTC loss and metrics are also collected
# in `eval_records` for later inspection.
N_EVAL_SAMPLES = 30  # number of non-empty test reads to inspect

# Inference only: switch off dropout / batch-norm updates, if the model has any.
model.eval()

counter = 0
eval_records = []  # one dict per evaluated read: {"loss": float, "cycle_metrics": ...}

with torch.no_grad():  # no backprop here, so don't build an autograd graph per read
    # Loop variables are deliberately not named x / y: `y` would otherwise overwrite the
    # full label list returned by load_training_data above.
    for read_signal, read_labels in tqdm(zip(X_test, y_test), total=len(X_test)):
        # L1-normalise the read and shape it as (1, 1, len(read_signal)) for the model.
        input_sequence = torch.tensor(
            normalize([read_signal], norm='l1'), dtype=torch.float32
        ).view(1, 1, len(read_signal)).to(device)

        # permute(1, 0, 2) swaps the first two axes; presumably the model returns
        # (batch, time, classes) and CTCLoss wants (time, batch, classes) — TODO confirm.
        model_output = model(input_sequence).permute(1, 0, 2)

        n_timesteps = model_output.shape[0]
        input_lengths = torch.tensor([n_timesteps])
        label_lengths = torch.tensor([len(read_labels)])
        target_sequence = torch.tensor(read_labels).to(device)
        loss = ctc(model_output, target_sequence, input_lengths, label_lengths)

        greedy_transcript = " ".join(greedy_decoder(model_output))
        actual_transcript = " ".join(str(motif) for motif in read_labels)
        print(greedy_transcript)
        print(actual_transcript)

        decoded_prediction = sort_transcript(greedy_transcript)
        search_prediction = sort_transcript(actual_transcript)  # ground truth, same layout
        n_motifs = sum(len(group) for group in search_prediction)
        if n_motifs == 0:
            continue  # nothing to score this read against

        found_motifs_caller = evaluate_cycle_prediction(decoded_prediction, search_prediction)
        eval_records.append({"loss": loss.item(), "cycle_metrics": found_motifs_caller})
        print(found_motifs_caller)
        print()

        counter += 1
        if counter == N_EVAL_SAMPLES:
            break

  0%|          | 0/1998 [00:00<?, ?it/s]

13 13 5 13 13 8 13 14 1 14 14 3 14 14 5 14 14 8 14 15 1 15 15 3 15 15 6 15 15 8 15 16 1 16 16 3 16 16 6 16 16 8 16
3 13 13 5 13 13 6 13 14 2 14 14 3 14 14 4 14 14 8 14 15 2 15 15 3 15 15 5 15 15 6 15 16 2 16 16 4 16 16 5 16 16 7 16
(0.3333333333333333, 0.6)

10 1 10 10 3 10 10 6 10 10 8 10 11 1 11 11 4 11 11 5 11 11 8 11 12 1 12 12 2 12 12 5 12 12 8 12 13 1 13 13 3 13 13 5 13 13 8 13 14 1 14 14 3 14 14 5 14 14 8 14 15 1 15 15 3 15 15 5 15 15 8 15 16 1 16 16 3 16 16 6 16 16 8 16
9 10 1 10 10 2 10 10 4 10 10 8 10 11 3 11 11 4 11 11 5 11 11 6 11 12 2 12 12 3 12 12 5 12 12 6 12 13 1 13 13 3 13 13 4 13 13 7 13 14 1 14 14 3 14 14 4 14 14 7 14 15 2 15 15 5 15 15 7 15 15 8 15 16 1 16 16 6 16 16 7 16 16 8 16
(0.5416666666666666, 0.4583333333333333)

14 14 3 14 14 5 14 14 8 14 15 1 15 15 4 15 15 5 15 15 7 15 16 1 16 16 3 16 16 6 16 16 8 16
14 14 3 14 14 4 14 14 7 14 15 1 15 15 2 15 15 3 15 15 8 15 16 3 16 16 5 16 16 6 16 16 8 16
(0.45454545454545453, 0.5454545454545454)

2 15 15 3 15 15 6 15 15 