In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed the global RNG so the random demos in this notebook are reproducible.
torch.manual_seed(0)

# Create dummy model output
batch_size = 2
max_seq_len = 50  # Time steps
num_classes = 5  # (A, T, C, G, blank)
blank_idx = 4  # index of the CTC blank among the num_classes outputs

# Random model output logits, shape: (batch, seq_len, num_classes)
logits = torch.randn(batch_size, max_seq_len, num_classes)

# CTCLoss consumes log-probabilities: exp(log_probs) sums to 1 over the class dimension.
log_probs = F.log_softmax(logits, dim=-1)

# CTCLoss expects (T, N, C) = (seq_len, batch, num_classes)
log_probs = log_probs.permute(1, 0, 2)

# Per-sequence label sequences (A=0, T=1, C=2, G=3); they are passed to CTCLoss
# in concatenated 1-D form, with target_lengths saying where each one ends.
target_seqs = [[0, 1, 2, 3], [0, 1, 2, 3]]
targets = torch.tensor([label for seq in target_seqs for label in seq])

# Every sequence uses all max_seq_len frames; lengths are derived, not hard-coded.
input_lengths = torch.full((batch_size,), max_seq_len, dtype=torch.long)
target_lengths = torch.tensor([len(seq) for seq in target_seqs], dtype=torch.long)

# Initialize CTC loss function with the blank at blank_idx
ctc_loss_fn = nn.CTCLoss(blank=blank_idx)

# Compute CTC Loss
loss = ctc_loss_fn(log_probs, targets, input_lengths, target_lengths)
print("CTC Loss:", loss.item())


CTC Loss: 16.153362274169922


In [3]:

# Unbatched CTC example (effectively N=1): input is (T, C) and target is 1-D.
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
# Random per-frame log-probabilities, size = (T, C).
# NOTE: `input` shadows the builtin; the name is kept because later cells inspect it.
input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
input_lengths = torch.tensor(T, dtype=torch.long)
# Random target (0 = blank, 1:C = classes).
# A target of length L with R adjacent repeated labels needs at least L + R input
# frames; otherwise CTCLoss returns inf. Since R <= L - 1, capping L at (T + 1) // 2
# guarantees a feasible alignment.
max_target_len = (T + 1) // 2
target_lengths = torch.randint(low=1, high=max_target_len + 1, size=(), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(int(target_lengths),), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
print(loss)

tensor(6.5846, grad_fn=<MeanBackward0>)


In [5]:
# Dimensions of whichever `input` tensor was assigned most recently (a torch.Size)
input.size()

torch.Size([50, 20])

In [11]:
# Dimensions of whichever `target` tensor was assigned most recently (a torch.Size)
target.size()

torch.Size([17])

In [12]:
# Inspect target_lengths; after the unbatched example it is a 0-d tensor equal to target.shape[0].
target_lengths

tensor(17)

In [13]:
# Inspect input_lengths; after the unbatched example it is the 0-d tensor T.
input_lengths

tensor(50)

In [6]:
# Padded-target CTC example: targets form an (N, S) matrix, and CTCLoss only
# reads the first target_lengths[n] labels of row n.
T, C, N = 50, 20, 16   # input length, number of classes incl. blank, batch size
S = 30                 # padded target length (longest target in the batch)
S_min = 10             # smallest target length to sample, for demonstration

ctc_loss = nn.CTCLoss()

# Random per-frame log-probabilities laid out as (T, N, C)
input = torch.randn(T, N, C).log_softmax(dim=-1).detach().requires_grad_()
# Random labels drawn from 1..C-1 (0 = blank)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
# Every input sequence uses all T frames
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)

loss = ctc_loss(input, target, input_lengths, target_lengths)

In [13]:
# Padded targets from the batch example, shape (N, S); entries past target_lengths[n] in row n are padding.
target

tensor([[ 1, 13, 12, 19,  5, 11, 14, 10, 19, 10,  7, 16, 17, 15,  4,  3,  9, 10,
          8,  4,  3, 17, 11,  3,  2, 14,  7, 14,  6, 17],
        [ 2,  8, 12, 12,  9, 19,  8, 14,  6, 10, 10,  1,  7, 12, 12, 17, 17,  6,
          3,  5, 18,  3, 11, 17, 11,  4,  1,  1, 11,  7],
        [ 5,  1,  9, 16, 12, 11,  4, 14, 19,  3,  1,  1, 12,  7, 19,  9,  3,  8,
          6, 15, 12, 16, 15, 12, 12,  3,  4,  3, 19, 17],
        [16, 16,  5,  1, 15, 10, 17, 17, 11, 10, 10,  4, 15,  8, 10,  8, 16, 18,
         10,  9,  4,  7,  4, 14, 10, 15, 15, 10,  5,  2],
        [ 6,  2, 10,  7, 14,  6,  9,  7,  6, 10, 11,  6,  9, 12, 16, 16, 11,  2,
          8, 12, 18, 17,  2,  9, 16,  9, 11,  6,  9, 17],
        [ 3, 15,  3,  4, 12,  2,  5, 15, 18,  9, 19, 10, 11, 14, 19,  4, 11, 13,
          5, 10, 12, 15, 13,  1, 13, 16,  3, 15, 19, 17],
        [10,  1,  5, 14, 14,  6,  3,  7, 10,  4, 12, 13,  6, 14,  9, 11, 10,  9,
         17, 14,  8,  1, 15,  2, 12, 13,  6, 11, 12, 12],
        [ 1, 14, 14,  4, 14

In [7]:
# After the padded batch example: one length per sequence, sampled from [S_min, S) = [10, 30).
target_lengths

tensor([22, 29, 27, 19, 12, 29, 25, 10, 20, 29, 24, 14, 10, 15, 29, 21])

In [8]:
# After the batch examples every input sequence uses all T frames, so this is T repeated N times.
input_lengths

tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50])

In [3]:

# Targets are un-padded: `target` is 1-D and holds all N label sequences back to back,
# with target_lengths marking where each one ends.
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes).
# A target of length L with R adjacent repeats needs at least L + R frames; otherwise
# its loss is inf and, with the default mean reduction, so is the batch loss.
# Since R <= L - 1, capping L at (T + 1) // 2 guarantees every sequence is feasible.
max_target_len = (T + 1) // 2
target_lengths = torch.randint(low=1, high=max_target_len + 1, size=(N,), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(int(target_lengths.sum()),), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)

In [40]:
# After the un-padded batch example: every input sequence uses all T frames (T repeated N times).
input_lengths

tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50])

In [10]:
# Flattened targets from the first CTC demo cell. The training-loop cell also assigns `targets`,
# so the value shown depends on which of those cells ran last.
targets

tensor([0, 1, 2, 3, 0, 1, 2, 3])

In [14]:

import os

import pandas as pd
from training_data import load_training_data
from sklearn.model_selection import train_test_split

# Dataset location: set the MOTIFCALLER_DATASET environment variable to point elsewhere
# instead of editing this absolute, machine-specific default path.
# NOTE(review): this is a .pkl file; if load_training_data unpickles it, only load trusted files.
dataset_path = os.environ.get(
    "MOTIFCALLER_DATASET",
    r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\synthetic\pickled_datasets\rc.pkl",
)
X, y = load_training_data(
       dataset_path, column_x='squiggle', column_y='motif_seq', payload=False, sampling_rate=0.1)
# Fixed random_state keeps the 80/20 split reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42)

4954
Selected 4954 forward reads


In [30]:
from torch.nn.utils.rnn import pad_sequence
# `nn` here is a top-level module on sys.path — presumably the project's nn.py, not torch.nn.
from nn import NaiveCaller

NUM_CLASSES = 17  # TODO: document whether this count includes the CTC blank
model = NaiveCaller(num_classes=NUM_CLASSES)

In [None]:
def pad_input_seq_to_longest(input_seqs, padding_value=0.0):
    """Pad variable-length 1-D signals to the longest one and add a singleton axis at dim 1.

    Args:
        input_seqs: non-empty sequence of 1-D array-likes (lists, NumPy arrays, tensors).
        padding_value: value written after the end of each shorter signal.

    Returns:
        float32 tensor of shape (batch, 1, max_len).

    Raises:
        ValueError: if ``input_seqs`` is empty (pad_sequence cannot pad an empty list).
    """
    import torch
    from torch.nn.utils.rnn import pad_sequence

    if len(input_seqs) == 0:
        raise ValueError("input_seqs must contain at least one sequence")
    tensors = [torch.as_tensor(seq, dtype=torch.float32) for seq in input_seqs]
    padded = pad_sequence(tensors, batch_first=True, padding_value=padding_value)  # (batch, max_len)
    return padded.unsqueeze(1)

In [None]:
# CTC criterion for the training loop, with nn.CTCLoss's documented defaults spelled out.
ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)

In [45]:
# Smoke test: push one batch through the model and the CTC loss, backprop, then stop.
BATCH_SIZE = 32

for start in range(0, len(X_train), BATCH_SIZE):  # step by a whole batch, not by 1
    batch_x = X_train[start:start + BATCH_SIZE]
    batch_y = y_train[start:start + BATCH_SIZE]
    n_in_batch = len(batch_x)  # the last batch may be smaller than BATCH_SIZE

    # Pad variable-length signals to the longest in the batch -> (batch, 1, max_len).
    input_seqs = pad_sequence(
        [torch.tensor(x, dtype=torch.float32) for x in batch_x], batch_first=True
    ).unsqueeze(1)

    model_output = model(input_seqs)

    # NOTE(review): assumes the model returns (batch, time, classes) log-probabilities — TODO confirm.
    # nn.CTCLoss expects log_softmax outputs laid out as (time, batch, classes).
    model_output = model_output.permute(1, 0, 2)
    n_timesteps = model_output.shape[0]
    # TODO: every sequence is credited with all output frames, including frames produced
    # from padding; per-sequence lengths would need the model's time downsampling factor.
    input_lengths = torch.full((n_in_batch,), n_timesteps, dtype=torch.long)

    # CTC targets are integer class indices, not floats. With the default nn.CTCLoss()
    # blank=0, labels must not use 0 — TODO confirm y_train's label range.
    targets = pad_sequence(
        [torch.tensor(y, dtype=torch.long) for y in batch_y], batch_first=True
    )
    target_lengths = torch.tensor([len(y) for y in batch_y], dtype=torch.long)

    print(target_lengths)

    loss = ctc_loss(model_output, targets, input_lengths, target_lengths)
    loss.backward()
    print(loss.item())
    break


tensor([35, 38, 45, 21, 41, 24, 25, 35, 25, 11, 35, 35, 40, 41, 22, 35, 37, 18,
        43, 47, 28, 35, 43, 28, 43, 23, 40, 48, 37, 45, 11, 44])
9.00639533996582
