In [2]:
import pandas as pd
import torch
from nn import MotifCaller, NaiveCaller
from training_data import data_preproc, load_training_data
from utils import get_savepaths
import numpy as np
from sklearn.preprocessing import normalize
from greedy_decoder import GreedyCTCDecoder
from Levenshtein import ratio
from beam_search_decoder import decode

In [3]:
n_classes = 17
model_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\models\synthetic\local_trained.pth"

In [4]:
labels_int = np.arange(n_classes).tolist()
labels = [f"{i}" for i in labels_int] # Tokens to be fed into greedy decoder
greedy_decoder = GreedyCTCDecoder(labels = labels)

In [5]:
def load_model(model_path, device):
    """
    Loading model purely for inference
    Will need to lead optimizer to fine tune
    """
    # Model Definition
    model = NaiveCaller(num_classes=17)
    
    if device == torch.device('cpu'):
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(model_path)

    model.load_state_dict(checkpoint['model_state_dict'])

    model = model.to(device)
    return model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
    

# Load model
# device
# port these useful methods to utils

In [6]:
model = load_model(model_path=model_path, device=device)

In [7]:
dataset = get_savepaths(running_on_hpc=False)[0]

In [8]:
df = pd.read_pickle(dataset)
df.head()

Unnamed: 0,motif_seq,base_seq,squiggle
0,"[5, 11, 11, 7, 11, 11, 8, 11, 12, 2, 12, 12, 5...",TTTATCGTCGTCACATCAGTCGACATCAGTCGGCATGAAGACACTA...,"[504, 506, 491, 502, 504, 515, 458, 472, 458, ..."
1,"[14, 14, 8, 14, 15, 2, 15, 15, 5, 15, 15, 7, 1...",TGACGTCGGATGACGTCGGCAGCGCCACCAACTCCACAAATGACGT...,"[469, 480, 479, 479, 481, 482, 483, 498, 505, ..."
2,"[10, 10, 3, 10, 10, 7, 10, 11, 2, 11, 11, 4, 1...",GGACAGCTAGGGACAGCTACACCCCCGTATTTTGAGCGGGGGACAG...,"[465, 470, 479, 486, 539, 532, 533, 525, 527, ..."
3,"[12, 1, 12, 12, 4, 12, 12, 6, 12, 12, 7, 12, 1...",TCGCCTTCATACCCCACTAACGTAGAGTACTGCCCTTCATACCCCT...,"[485, 487, 496, 495, 488, 488, 490, 488, 524, ..."
4,"[11, 12, 3, 12, 12, 4, 12, 12, 6, 12, 12, 7, 1...",AGTCGCCTTCATACCCACCCCCGTATTTTGAGCGGCCTTCATACCC...,"[544, 556, 539, 547, 536, 534, 538, 493, 488, ..."


In [9]:
dataset_path, model_save_path, file_write_path = get_savepaths(
        running_on_hpc=False)

X, y = load_training_data(
        dataset_path, column_x='squiggle', column_y='motif_seq',
        sampling_rate=0.3)


Index(['motif_seq', 'base_seq', 'squiggle'], dtype='object')


In [10]:
from sklearn.preprocessing import normalize
import torch.nn as nn

ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

In [None]:
import numpy as np
import heapq

def update_alignments(alignments, alignment_probs, top_n):

    if len(alignments) == 0:
        alignments.append(top_n)
        alignment_probs.append(top_n_probs)
        return alignments,
    
    for ind, alignment in enumerate(alignments):
        last_char = alignment[-1]
        for i, prob in zip(top_n, top_n_probs):
            if i == last_char:  # If it's the same as before (whether blank or repeated char - it gets collapsed)
                alignment_probs[ind] += prob

            elif last_char == blank_index:  # If previous is a blank and this is a character, we can get rid of the previous blank
                new_alignment = alignment[-1] + i
                if new_alignment in alignments:
                    alignments[alignments.index(new_alignment)] += prob
                else:
                    alignments.append(alignment[-1] + i)
                    alignment_probs.append(alignment_probs[ind] + prob)

            else:
                alignments.append(alignment + i)
                alignment_probs.append(alignment_probs[ind] + prob)
                
    # go through all the alignments and merge the leftover ones
    # return the most probable one

    

def beam_search_ctc(prob_matrix, beam_width=3, blank=0, n_classes=17):
    
    # Get top n probabilities and their corresponding indices for each time step
    # Create a list of alignments sequentially, collapsing and combining as you go
    indices = np.arange(n_classes)
    alignments = []
    probabilities = []

    for ind, probs in enumerate(prob_matrix):
        
        # Get the top 3
        # previous_alignments adding - collapse at will - if the same as previous, don't add 
        # If new and the previous is blank, remove the blank
        top_n = heapq.nlargest(3, enumerate(probs), key=lambda x: x[1])

        alignments



In [38]:
counter = 0
for x_, y_ in zip(X, y):
    #x = normalize([x_])
    #input_sequence = normalize([x_], norm='l1')
    input_sequence = torch.tensor(
        x_, dtype=torch.float32)
    input_sequence = input_sequence.view(1, 1, len(x_)).to(device)
    model_output = model(input_sequence)
    model_output = model_output.permute(1, 0, 2)
    
    label_lengths = torch.tensor([len(y_)])
    target_sequence = torch.tensor(y_).to(device)

    
    n_timesteps = model_output.shape[0]
    input_lengths = torch.tensor([n_timesteps])

    print(model_output.shape)
    
    model_output_flattened = model_output.view(model_output.shape[0]* model_output.shape[1], n_classes)
    print(model_output_flattened.shape)

    loss = ctc(
        model_output, target_sequence, input_lengths, label_lengths)
    #print(loss.item())
    
    greedy_transcript = " ".join(greedy_decoder(model_output))
    beam_transcript = beam_search_ctc(model_output_flattened.detach().cpu())
    actual_transcript = " ".join([str(i) for i in y_])
    print(greedy_transcript)
    print(beam_transcript)
    print(actual_transcript)
    print(ratio(greedy_transcript, actual_transcript))
    print(ratio(beam_transcript, actual_transcript))
    print(loss)

    break
    
    counter += 1
    if counter == 2000:
        break
    

torch.Size([106, 1, 17])
torch.Size([106, 17])
[(0, tensor(0.)), (10, tensor(-19.0663)), (4, tensor(-19.7087))]
[(0, tensor(-2.2530e-05)), (10, tensor(-10.6991)), (9, tensor(-17.3199))]
[(0, tensor(0.)), (10, tensor(-18.4950)), (9, tensor(-23.0844))]
[(0, tensor(-0.0006)), (10, tensor(-7.3804)), (9, tensor(-13.9946))]
[(0, tensor(-1.4186e-05)), (10, tensor(-12.0249)), (5, tensor(-12.7543))]
[(0, tensor(-0.0005)), (10, tensor(-7.5912)), (9, tensor(-11.7136))]
[(0, tensor(0.)), (10, tensor(-19.5494)), (9, tensor(-20.2829))]
[(0, tensor(-0.0004)), (10, tensor(-7.7915)), (9, tensor(-11.9784))]
[(0, tensor(-3.9577e-05)), (8, tensor(-10.7000)), (7, tensor(-11.8956))]
[(0, tensor(-0.0020)), (10, tensor(-6.2384)), (9, tensor(-11.6482))]
[(0, tensor(0.)), (10, tensor(-20.0583)), (9, tensor(-22.7519))]
[(0, tensor(-0.0003)), (11, tensor(-8.6603)), (10, tensor(-9.1101))]
[(0, tensor(-0.0004)), (1, tensor(-8.5233)), (2, tensor(-8.8442))]
[(0, tensor(-0.0035)), (11, tensor(-5.7383)), (10, tensor(-8

In [61]:
from model_config import ModelConfig

In [62]:
model_config = ModelConfig(
    n_classes=n_classes, hidden_size=256, window_size=1024, window_step=800, train_epochs=50, device=device,
    model_save_path="", write_path="", dataset='synthetic', windows=True, sampling_rate=1.0
)

In [66]:
from training_loop import run_epoch
import torch.optim as optim
from greedy_decoder import GreedyCTCDecoder

In [67]:
optimizer = optim.Adam(model.parameters(), lr=0.001)
labels_int = np.arange(n_classes).tolist()
labels = [f"{i}" for i in labels_int] # Tokens to be fed into greedy decoder
greedy_decoder = GreedyCTCDecoder(labels=labels)

In [69]:
result_dict = run_epoch(
    model=model, model_config=model_config, X=X, y=y, ctc=ctc,
    optimizer=optimizer, decoder=greedy_decoder)

100%|██████████| 2997/2997 [00:17<00:00, 167.28it/s]


In [75]:
result_dict['losses'][:500]

array([4.43492031, 3.12125254, 4.32518435, 3.63519311, 2.99334025,
       4.46822262, 2.96295476, 4.26412487, 4.25731993, 3.65833068,
       3.53381896, 3.62321472, 3.70630789, 3.79388523, 3.80692482,
       4.41694069, 3.25016284, 4.02096653, 4.08079147, 2.94028783,
       3.65108752, 3.69296646, 3.59262991, 3.48503113, 4.36539078,
       3.61073542, 4.01794147, 4.51349545, 3.69478416, 4.07377005,
       3.52363276, 3.65702391, 4.58049345, 3.65981936, 4.16589594,
       4.36304283, 3.97712994, 3.81253338, 4.91858339, 3.19839787,
       3.82459402, 3.97900391, 3.48522186, 4.26855755, 3.50788808,
       3.45172596, 4.58332777, 3.93838859, 3.8704803 , 3.57057595,
       3.34357715, 3.42299008, 3.44375992, 4.59503126, 3.65452838,
       3.50559545, 3.79823709, 3.97137856, 3.20284557, 3.33833027,
       4.01198483, 4.26743793, 4.20309877, 3.68024731, 2.64314175,
       3.15585208, 3.95795512, 4.24880981, 4.16048098, 3.24434352,
       3.70698881, 3.86576986, 3.04915285, 3.12248182, 2.59087

Best sequence: [('01', np.float64(0.11616)), ('010', np.float64(0.07680000000000001))]
