In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import torch
from nn import MotifCaller, NaiveCaller
from training_data import data_preproc, load_training_data
from utils import get_savepaths, load_model
from Bio import Align
import numpy as np
from sklearn.preprocessing import normalize
from greedy_decoder import GreedyCTCDecoder
from Levenshtein import ratio
from utils import load_model, get_metrics_for_evaluation, sort_transcript
from sklearn.model_selection import train_test_split
from beam_search_decoder import beam_search_ctc
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [5]:
n_classes = 9
model_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\training_logs\20250325.123145.630545\model.pth"
labels_int = np.arange(n_classes).tolist()
labels = [f"{i}" for i in labels_int] # Tokens to be fed into greedy decoder
greedy_decoder = GreedyCTCDecoder(labels = labels)
ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
model = load_model(model_path=model_path, device=device, n_classes=n_classes, hidden_size=256)

dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\synthetic\pickled_datasets\no_spacers_long.pkl"

In [26]:
X, y = load_training_data(
       dataset_path, column_x='squiggle', column_y='motif_seq', payload=False, sampling_rate=0.1)

8580
Selected 8580 forward reads


In [27]:
X = data_preproc(X, window_size=1000, step_size=300, normalize_values=True)

100%|██████████| 858/858 [00:01<00:00, 833.95it/s]


In [28]:
X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42)

In [29]:
from evaluation import evaluate_cycle_prediction

In [30]:
aligner = Align.PairwiseAligner()

In [37]:

counter = 0
sum_diff = 0
gaps = 0
identities = 0
mismatches = 0

for x, y in tqdm(zip(X_test, y_test), total=len(X_test)):
    
    input_sequence = x.to(device)
    target_sequence = torch.tensor(y).to(device)

    #input_sequence = normalize([input_sequence], norm='l1')
    print(input_sequence.shape)

    model_output = model(input_sequence)
    model_output = model_output.permute(1, 0, 2)  # Assuming log probs are computed in network
    print(model_output.shape)
    model_output = model_output.reshape(
        model_output.shape[0] * model_output.shape[1], n_classes)
    print(model_output.shape)
    

    n_timesteps = model_output.shape[0]
    input_lengths = torch.tensor([n_timesteps])
    label_lengths = torch.tensor([len(target_sequence)])
    
  
    loss = ctc(
        log_probs=model_output, targets=target_sequence, input_lengths=input_lengths, target_lengths=label_lengths)
    print(loss.item())
    
    #model_output_flattened = model_output.view(
    #    model_output.shape[0] * model_output.shape[1], n_classes)
    
    greedy_transcript = "".join(greedy_decoder(model_output))
    beam_transcript = beam_search_ctc(
        model_output.detach().cpu(), beam_width=30)
    

    actual_transcript = " ".join([str(i) for i in y])
    print(actual_transcript)
    print(greedy_transcript)
    print(ratio(beam_transcript, actual_transcript))
    print(ratio(greedy_transcript, actual_transcript))
    
    aligned = aligner.align(greedy_transcript, actual_transcript)[0].counts()

    gaps += aligned[0] / len(actual_transcript)
    identities += aligned[1] / len(actual_transcript)
    mismatches += aligned[2] / len(actual_transcript)
    #print(aligner.align(greedy_transcript, actual_transcript)[0].counts())

print(gaps/64)
print(identities/64)
print(mismatches/64)




  0%|          | 0/172 [00:00<?, ?it/s]

torch.Size([61, 1, 1000])
torch.Size([15, 61, 9])
torch.Size([915, 9])
21.253889083862305


  new_prob = np.log(np.exp(old_prob) + np.exp(alignment_probs[ind] + prob))


3 4 5 1 1 3 6 1 4 6 7 5 8 4 4 4 4 4 3 3 2 8 1 7 3 7 8 2 3 6 5 5 3 2 4 7 5 7 3 6 4 4 5 6 2 2 2 4 6 8 2 6 8 7 1 8 4 6 7 8 2 7 4 5 8 6 6 5 7 5 1 1
444444444
0.0
0.11842105263157898
torch.Size([67, 1, 1000])
torch.Size([15, 67, 9])
torch.Size([1005, 9])
18.527786254882812
8 5 4 6 2 1 3 3 7 4 1 3 8 5 7 2 1 7 3 2 6 2 1 4 6 4 6 6 3 2 6 1 8 3 4 3 3 5 7 8 3 7 5 5 1 8 5 8 6 7 4 1 1 5 6 8 7 7 8 4 4 3 4 6 8 5 4 7 1 5 3 1 4 8 5 6 6 5 4 2 6 4 2 8 6 4 2 3 8
4444444444
0.0
0.106951871657754
torch.Size([69, 1, 1000])
torch.Size([15, 69, 9])
torch.Size([1035, 9])
18.435863494873047
6 1 3 1 1 8 5 8 6 7 3 7 4 8 2 6 4 6 5 1 4 4 5 2 5 3 2 1 4 1 1 2 6 3 1 6 2 4 4 4 6 3 8 5 2 4 4 5 5 8 6 5 4 2 2 7 6 6 3 8 1 1 1 1 5 4 6 5 6 4 1 4 6 1 4 3 6 3 1 8 4 7 8 6 5 6 3 1 3 3 6 3
4444444
0.0
0.0736842105263158
torch.Size([51, 1, 1000])
torch.Size([15, 51, 9])
torch.Size([765, 9])
26.930673599243164
6 6 8 5 4 8 4 4 5 1 6 5 2 3 5 2 5 8 4 2 6 1 7 6 4 4 8 5 8 6 7 4 4 8 8 3 8 1 7 8 1 5 4 3 7 7 3 3 2
4444444444
0.0
0.168224299

KeyboardInterrupt: 

In [1]:
import pandas as pd


df = pd.read_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\datasets\synthetic\working_datasets\unnormalized\synth_dataset.pkl")