In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from crnn import CNN_BiGRU_Classifier
import math
from tqdm.notebook import tqdm
from training_data import data_preproc, load_pre_data
from sklearn.model_selection import train_test_split
from greedy_decoder import GreedyCTCDecoder, beam_search_decoder
from utils import get_actual_transcript, get_model_savepath, get_motifs_identified
import torchaudio
import datetime

# Tensorboard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()


# Loading the data
dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\datasets\empirical\full_test_dataset_v4_spacers.pkl"

sampled_dataset_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\datasets\empirical\sampled_test_dataset_v4_spacers.pkl"

model_path = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\models\empirical\model_10_11_24.pth"

# Can add it this way - the function there is literal ass
squiggle_column = ''
label_column = ''

X,y = data_preproc(dataset_path=sampled_dataset_path, chop_reads=1)
print(f"Number of samples: {len(X)}")

100%|██████████| 6894/6894 [00:12<00:00, 553.65it/s]


Number of samples: 6894


In [2]:

# Loading the model
import pandas as pd
import numpy as np

dataset_df = pd.read_pickle(dataset_path)

output_classes = 19

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
print(f"Running on {device}")

labels_int = np.arange(output_classes).tolist()
labels = [f"{i}" for i in labels_int] # Tokens to be fed into greedy decoder
greedy_decoder = GreedyCTCDecoder(labels=labels)

# Model Parameters
input_size = 1  # Number of input channels
hidden_size = 256
num_layers = 4
output_size = output_classes  # Number of output classes
dropout_rate = 0.2

saved_model = True

# Model Definition
model = CNN_BiGRU_Classifier(input_size, hidden_size, num_layers, output_size, dropout_rate)

if saved_model:
    model_path = model_path
    if device == torch.device('cpu'):
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])

model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001)
ctc_loss = nn.CTCLoss(zero_infinity=True)

X_test, y_test = X, y

torch.autograd.set_detect_anomaly(True)

n_classes = output_classes
step_sequence = 100
window_overlap = 50
length_per_sample = 150
model_output_split_size = 1

epochs = 2


Running on cuda


In [47]:
def create_payload_spacer_sequence(payload):

    payload_sequence = []

    cycle_number = 11


    for i in payload:
        for j in i:
            payload_sequence.append(cycle_number)
            payload_sequence.append(j)
            payload_sequence.append(cycle_number)
        cycle_number+=1

    return payload_sequence

def sort_transcript(transcript):

    cycles = [[] for i in range(8)]

    split_transcript = transcript.split()
    split_transcript = [int(i) for i in split_transcript  if i != '']

    for i in range(len(split_transcript)):
        
        found_motif = split_transcript[i]
        
        # If we have a payload motif
        if found_motif < 9:

            # finding the spacers - only for payload cycles
            if i > 0:

                # Checking for Back Spacer
                if split_transcript[i-1] > 10:
                    cycle_number = split_transcript[i-1] - 11
                    print(split_transcript[i-1])
                    print(cycle_number)
                    cycles[cycle_number].append(split_transcript[i])

                # Checking for Forward Spacer
                elif i < len(split_transcript) - 1:
                    if split_transcript[i+1] > 10:
                        cycle_number = split_transcript[i+1] - 11
                        print(split_transcript[i+1])
                        print(cycle_number)
                        cycles[cycle_number].append(split_transcript[i])

            else:
                if i < len(split_transcript) - 1:
                    # Checking for Forward Spacer
                    if split_transcript[i+1] > 10:
                        cycle_number = split_transcript[i+1] - 11
                        print(split_transcript[i+1])
                        print(cycle_number)
                        cycles[cycle_number].append(split_transcript[i])   

    return cycles



def sort_transcript_caller(transcript):

    # Looking for spacers instead

    cycles = [[] for i in range(8)]

    split_transcript = transcript.split()
    split_transcript = [int(i) for i in split_transcript  if i != '']

    for i in range(len(split_transcript)-1):
        current_cycle = 11

        found_motif = split_transcript[i]
        # Looking for spacers
        if found_motif >= current_cycle:
            current_cycle = split_transcript[i] - 11
            if split_transcript[i+1] < 9:
                cycles[current_cycle].append(split_transcript[i+1])
    return cycles

In [21]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [74]:


payloads = dataset_df['Payload'].tolist()

# Test Loop
model.eval()
test_loss = 0.0
distances_arr = []
motifs_identifed_arr = []
greedy_transcripts = []
actual_transcripts = []
with torch.no_grad():
    for i in range(len(X_test[1:10])):

        if i % 500 == 0:
            print(f"Processing {i}/{len(X_test)}")

        test_sequence, target_sequence = torch.tensor(X_test[i]).to(device), torch.tensor(y_test[i]).to(device)



        model_output_timestep = model(test_sequence) # Getting model output

        input_lengths = torch.tensor(X_test[i].shape[0])
        target_lengths = torch.tensor(len(target_sequence))

        loss = ctc_loss(model_output_timestep, target_sequence, input_lengths, target_lengths)

        print(loss.item())
        test_loss += loss.item()

        greedy_result = greedy_decoder(model_output_timestep)
        beam_result = beam_search_decoder(model_output_timestep)
        greedy_transcript = " ".join(greedy_result)
        actual_transcript = get_actual_transcript(target_sequence)
        greedy_transcripts.append(greedy_transcript)
        actual_transcripts.append(actual_transcript)
        payload_spacer = create_payload_spacer_sequence(payloads[i])

        print(greedy_result)
        print(actual_transcript)
        print(payload_spacer)
        print()


        cycles_motif_caller = sort_transcript(greedy_result)
        cycles_motif_read = sort_transcript(actual_transcript)
        

        print(cycles_motif_caller)
        print(cycles_motif_read)
        print(payloads[i])
        print()

        correct_motifs_predicted_motif_read = 0
        correct_motifs_predicted_motif_caller = 0

        incorrect_motifs_predicted_motif_read = 0
        incorrect_motifs_predicted_motif_caller = 0

        for i,j,k in zip(cycles_motif_read, cycles_motif_caller, payloads[i]):

            correct_motifs_predicted_motif_read += len(set(i).intersection(set(k)))
            correct_motifs_predicted_motif_caller += len(set(j).intersection(set(k)))
            incorrect_motifs_predicted_motif_read += len(set(i).difference(set(k)))
            incorrect_motifs_predicted_motif_caller += len(set(j).difference(set(k)))

        print(f"Correct motifs caller {correct_motifs_predicted_motif_caller}")
        print(f"Incorrect motifs caller {incorrect_motifs_predicted_motif_caller}")
        print(f"Correct motifs read {correct_motifs_predicted_motif_read}")
        print(f"Incorrect motifs read {incorrect_motifs_predicted_motif_read}")
        print()


        motif_err = torchaudio.functional.edit_distance(actual_transcript, greedy_transcript) / len(actual_transcript)
        distances_arr.append(motif_err)

        motifs_identifed = get_motifs_identified(actual_transcript, greedy_transcript)
        motifs_identifed_arr.append(motifs_identifed)


"""
test_loss /= len(X_test)
test_accuracy = np.mean(distances_arr)
motifs_identifed = np.mean(motifs_identifed_arr)
print(f"Test Loss: {test_loss:.4f}, Test Edit Distance: {test_accuracy:.4f}, Motifs Identified: {motifs_identifed:.4f}")
"""

Processing 0/6894
5.602296352386475
[([], 0.0), ([1], 0.0)]


  for s in range(S):


9 1 9 8 10 11 8 11 3 12 13 13 14 1 14 6 16 17 17 18
 10 8 10 14 1 14 16 3 16 18 7 18
[11, 3, 11, 11, 6, 11, 11, 7, 11, 11, 8, 11, 12, 3, 12, 12, 4, 12, 12, 5, 12, 12, 8, 12, 13, 1, 13, 13, 2, 13, 13, 4, 13, 13, 6, 13, 14, 1, 14, 14, 2, 14, 14, 5, 14, 14, 6, 14, 15, 4, 15, 15, 5, 15, 15, 6, 15, 15, 8, 15, 16, 1, 16, 16, 2, 16, 16, 5, 16, 16, 7, 16, 17, 1, 17, 17, 3, 17, 17, 4, 17, 17, 7, 17, 18, 2, 18, 18, 4, 18, 18, 5, 18, 18, 7, 18]

11
0
11
0
14
3
14
3
14
3
16
5
18
7
[[8, 3], [], [], [1, 6], [], [], [], []]
[[], [], [], [1], [], [3], [], [7]]
[[3, 6, 7, 8], [3, 4, 5, 8], [1, 2, 4, 6], [1, 2, 5, 6], [4, 5, 6, 8], [1, 2, 5, 7], [1, 3, 4, 7], [2, 4, 5, 7]]

Correct motifs caller 4
Incorrect motifs caller 0
Correct motifs read 2
Incorrect motifs read 1

5.112450122833252
[([11, 5], 0.0), ([11, 5, 1], 0.0)]
9 10 6 10 11 3 11 12 7 12 13 14 5 14 5 15 6 16 17 7 6
 11 3 11 13 5 13 14 5 14 18 6 18
[11, 3, 11, 11, 6, 11, 11, 7, 11, 11, 8, 11, 12, 3, 12, 12, 4, 12, 12, 5, 12, 12, 8, 12, 13, 1, 1

'\ntest_loss /= len(X_test)\ntest_accuracy = np.mean(distances_arr)\nmotifs_identifed = np.mean(motifs_identifed_arr)\nprint(f"Test Loss: {test_loss:.4f}, Test Edit Distance: {test_accuracy:.4f}, Motifs Identified: {motifs_identifed:.4f}")\n'

In [11]:
model_output_timestep[5]

tensor([ -0.1066, -16.0322, -15.7049, -17.1003, -16.0035, -15.6052, -20.9624,
        -16.2155, -15.0726,  -3.8783,  -7.3401, -15.4738, -16.0508, -17.0568,
        -14.7530, -17.4408, -16.7398, -17.2372, -16.6317], device='cuda:0')