In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import torch
from nn import MotifCaller, NaiveCaller
from training_data import data_preproc, load_training_data
from utils import get_savepaths
import numpy as np
from sklearn.preprocessing import normalize
from greedy_decoder import GreedyCTCDecoder
from Levenshtein import ratio
from utils import load_model, get_metrics_for_evaluation, sort_transcript
from sklearn.model_selection import train_test_split
from beam_search_decoder import beam_search_ctc
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# --- Decoder / loss configuration -------------------------------------------
# Number of output classes of the caller; class index 0 is the CTC blank (see `ctc`).
n_classes = 17
model_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\models\synthetic\0.5_right.pth"

# Integer class ids 0..n_classes-1 and their string forms; the greedy decoder
# is given the string tokens.
labels_int = list(range(n_classes))
labels = [str(class_id) for class_id in labels_int]

greedy_decoder = GreedyCTCDecoder(labels=labels)
# zero_infinity=True turns infinite losses (and their gradients) into zeros.
ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

In [4]:
# Build the NaiveCaller architecture with n_classes output classes. No checkpoint is
# loaded here; `model` is reassigned from load_model(...) in a later cell.
model = NaiveCaller(num_classes=n_classes)

In [5]:
# Show the layer structure of the newly constructed model.
model

NaiveCaller(
  (cnn): Sequential(
    (0): Conv1d(1, 4, kernel_size=(5,), stride=(2,))
    (1): ReLU()
    (2): MaxPool1d(kernel_size=5, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv1d(4, 16, kernel_size=(5,), stride=(3,))
    (4): ReLU()
    (5): MaxPool1d(kernel_size=5, stride=3, padding=0, dilation=1, ceil_mode=False)
    (6): Conv1d(16, 64, kernel_size=(5,), stride=(1,))
    (7): ReLU()
    (8): Conv1d(64, 128, kernel_size=(5,), stride=(1,))
    (9): ReLU()
  )
  (lstm): LSTM(128, 256, num_layers=3, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=512, out_features=17, bias=True)
)

In [6]:
# Peek at the parameter names stored in the checkpoint so they can be compared
# against the layers of NaiveCaller shown above.
# map_location='cpu' keeps this working on a machine without CUDA even if the
# checkpoint tensors were saved from a GPU; only the key names are needed here.
# NOTE: torch.load unpickles the file - only open checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location='cpu')
print(checkpoint['model_state_dict'].keys())

odict_keys(['cnn.0.weight', 'cnn.0.bias', 'cnn.3.weight', 'cnn.3.bias', 'cnn.6.weight', 'cnn.6.bias', 'cnn.8.weight', 'cnn.8.bias', 'lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0', 'lstm.weight_ih_l0_reverse', 'lstm.weight_hh_l0_reverse', 'lstm.bias_ih_l0_reverse', 'lstm.bias_hh_l0_reverse', 'lstm.weight_ih_l1', 'lstm.weight_hh_l1', 'lstm.bias_ih_l1', 'lstm.bias_hh_l1', 'lstm.weight_ih_l1_reverse', 'lstm.weight_hh_l1_reverse', 'lstm.bias_ih_l1_reverse', 'lstm.bias_hh_l1_reverse', 'lstm.weight_ih_l2', 'lstm.weight_hh_l2', 'lstm.bias_ih_l2', 'lstm.bias_hh_l2', 'lstm.weight_ih_l2_reverse', 'lstm.weight_hh_l2_reverse', 'lstm.bias_ih_l2_reverse', 'lstm.bias_hh_l2_reverse', 'fc.weight', 'fc.bias'])


In [7]:

# Pickled synthetic dataset read by the cells below.
dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\synthetic\pickled_datasets\big_synth_5_3_25.pkl"

# Use the GPU when one is available, and make that device the default for
# tensors created from here on.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_default_device(device)

# Load the caller from the saved checkpoint; this replaces the `model`
# instance constructed earlier in the notebook.
model = load_model(model_path=model_path, device=device, n_classes=n_classes)

In [8]:
# Show the layer structure of the model returned by load_model.
model

NaiveCaller(
  (cnn): Sequential(
    (0): Conv1d(1, 4, kernel_size=(5,), stride=(2,))
    (1): ReLU()
    (2): MaxPool1d(kernel_size=5, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv1d(4, 16, kernel_size=(5,), stride=(3,))
    (4): ReLU()
    (5): MaxPool1d(kernel_size=5, stride=3, padding=0, dilation=1, ceil_mode=False)
    (6): Conv1d(16, 64, kernel_size=(5,), stride=(1,))
    (7): ReLU()
    (8): Conv1d(64, 128, kernel_size=(5,), stride=(1,))
    (9): ReLU()
  )
  (lstm): LSTM(128, 256, num_layers=3, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=512, out_features=17, bias=True)
)

In [9]:
# Read the pickled dataset into a DataFrame for inspection.
# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
df = pd.read_pickle(dataset_path)

In [10]:
# Preview the first rows of the dataset.
df.head()

Unnamed: 0,motif_seq,base_seq,squiggle
0,"[6, 12, 13, 2, 13, 13, 3, 13, 13, 5, 13, 13, 7...",ACTGGTGTAGTCACTGATTGACACATCAACATAAAAAGCTGTTACT...,"[525, 525, 508, 521, 520, 521, 530, 519, 465, ..."
1,"[10, 10, 7, 10, 10, 8, 10, 11, 1, 11, 11, 2, 1...",CACTGATTGAAGTCGATCGAAGTCGATCGGGGGTCGCTATAGAGTA...,"[512, 506, 510, 512, 510, 512, 514, 508, 517, ..."
2,"[5, 13, 13, 7, 13, 14, 1, 14, 14, 4, 14, 14, 7...",GCTAGAAGTTCTGTAGGTCCCGCATAAAAAGCTATAAAAAGCTGGG...,"[513, 515, 513, 506, 469, 480, 483, 485, 481, ..."
3,"[8, 9, 10, 1, 10, 10, 2, 10, 10, 3, 10, 10, 7,...",AGTGGACTCGCGGCCTTAGCTAAGTCGATCGAGGCATTCGCCAGGA...,"[537, 534, 535, 535, 538, 545, 544, 464, 492, ..."
4,"[15, 5, 15, 15, 8, 15, 16, 1, 16, 16, 4, 16, 1...",GACTAGTACAGAAGTTCTGTAGGTCCCGCAGACTAGTACAGACTAG...,"[493, 503, 501, 510, 502, 490, 505, 507, 500, ..."


In [12]:
# Load model inputs (the 'squiggle' column) and targets (the 'motif_seq' column)
# from the same pickle file via the project helper.
# NOTE(review): payload=False and sampling_rate=0.1 are interpreted inside
# load_training_data; presumably sampling_rate keeps a 10% subset of the rows -
# confirm against that helper.
X, y = load_training_data(
       dataset_path, column_x='squiggle', column_y='motif_seq', payload=False, sampling_rate=0.1)

In [13]:
# Hold out 20% of the samples as a test set; the fixed random_state makes the
# shuffle, and therefore the split, identical on every run.
X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42)

In [14]:
from evaluation import evaluate_cycle_prediction

In [25]:

# Qualitative check of the beam-search decoder: find the first test read whose
# CTC loss is at most MAX_CTC_LOSS, print the per-read ratio of output timesteps
# to label length, the sorted ground-truth transcript, and then every
# beam-search candidate (sorted the same way) for comparison.
MAX_CTC_LOSS = 1   # reads with a higher loss are skipped
BEAM_WIDTH = 10    # number of beams kept by beam_search_ctc

# Inference only: no_grad avoids building an autograd graph for every read.
with torch.no_grad():
    # The loop names are deliberately different from the global X / y returned by
    # load_training_data, so this cell does not clobber them and the
    # train/test-split cell can still be re-run afterwards.
    for squiggle, motif_labels in tqdm(zip(X_test, y_test), total=len(X_test)):

        # A read without labels would make the ratio below divide by zero.
        if len(motif_labels) == 0:
            continue

        # L1-normalise the raw signal and shape it as (batch=1, channel=1, length).
        input_sequence = normalize([squiggle], norm='l1')
        input_sequence = torch.tensor(input_sequence, dtype=torch.float32)
        input_sequence = input_sequence.view(1, 1, len(squiggle)).to(device)

        # nn.CTCLoss takes (time, batch, class); this assumes the model emits
        # (batch, time, class) - TODO confirm against NaiveCaller.forward.
        model_output = model(input_sequence).permute(1, 0, 2)
        n_timesteps = model_output.shape[0]

        target_sequence = torch.tensor(motif_labels).to(device)
        label_lengths = torch.tensor([len(motif_labels)])
        input_lengths = torch.tensor([n_timesteps])

        # Output timesteps available per target label for this read.
        print(n_timesteps/len(target_sequence))

        loss = ctc(
            model_output, target_sequence, input_lengths, label_lengths)
        if loss.item() > MAX_CTC_LOSS:
            continue

        # Collapse the (time, batch=1) axes into one for the beam-search decoder.
        model_output_flattened = model_output.reshape(-1, n_classes)
        beam_transcript = beam_search_ctc(
            model_output_flattened.cpu(), beam_width=BEAM_WIDTH, return_alignments=True)

        actual_transcript = " ".join(str(i) for i in motif_labels)
        print(sort_transcript(actual_transcript))
        print()

        # With return_alignments=True the decoder yields one transcript per beam.
        for candidate in beam_transcript:
            print(sort_transcript(candidate))

        # One report is enough: this cell is an inspection aid, not a full evaluation.
        break

  0%|          | 0/2000 [00:00<?, ?it/s]

3.0232558139534884
[[], [], [5, 7], [1, 2, 3, 6], [4, 5, 7, 8], [2, 5, 7, 8], [], []]

[[], [], [5, 8], [1, 4, 5, 8], [1, 4, 6, 8], [1, 4, 7, 8], [], []]
[[], [], [5, 8], [1, 2, 5, 8], [1, 4, 6, 8], [1, 4, 7, 8], [], []]
[[], [], [6, 8], [1, 4, 5, 8], [1, 4, 6, 8], [1, 4, 7, 8], [], []]
[[], [], [5, 8], [2, 4, 5, 8], [1, 4, 6, 8], [1, 4, 7, 8], [], []]
[[], [], [5, 8], [1, 5, 5, 8], [1, 4, 6, 8], [1, 4, 7, 8], [], []]
[[], [], [5, 8], [1, 4, 5, 8], [1, 4, 6, 8], [1, 4, 5, 8], [], []]
[[], [], [5, 8], [1, 4, 5, 8], [1, 4, 7, 8], [1, 4, 7, 8], [], []]
[[], [], [5, 8], [1, 4, 6, 8], [1, 4, 6, 8], [1, 4, 7, 8], [], []]
[[], [], [5, 8], [1, 4, 5, 8], [1, 4, 6, 8], [1, 4, 6, 8], [], []]
[[], [], [6, 8], [1, 2, 5, 8], [1, 4, 6, 8], [1, 4, 7, 8], [], []]
