In [1]:
from custom_torch_dataset import SwipeDataset
import os
from torch.utils.data import random_split

dataset_path = os.path.join(os.getcwd(), "dataset")

data = SwipeDataset(data_dir=dataset_path,
                    batch=True,
                    batch_first=True)

train_set, val_set, test_set = random_split(data, [0.8, 0.1, 0.1])

In [26]:
data[:][0][0].shape

torch.Size([1, 26, 6])

In [None]:
input_lengths = [x.shape[1] for x in data[:][0][:]]
max(input_lengths)

6

In [32]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def collate_fn(batch):
    # Separate inputs and targets
    inputs = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    
    # Pad inputs dynamically
    padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=(0, 0, -1, 0, 0, 0))
    
    # Pad targets dynamically (if needed)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value="PAD")
    
    return padded_inputs, padded_targets

test_loader = DataLoader(test_set, batch_size=32, collate_fn=collate_fn)

In [33]:
test_in, test_word, test_tensor = next(iter(test_loader))

TypeError: pad_sequence(): argument 'padding_value' (position 3) must be float, not tuple

In [30]:
import torch
from model import Seq2Seq

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
t_model = Seq2Seq(hidden_size=32,
                  num_layers=2,
                  bidirectional=True,
                  input_size=6,
                  max_letters=20,
                  force_ratio=0.7,
                  output_size=27).to(device)

In [62]:
t_input, t_word, t_word_tensor = train_set[0]
t_input = t_input.to(device)
t_word_tensor = t_word_tensor.to(device)

In [63]:
t_output = t_model(t_input, t_word_tensor)

In [64]:
log_probs = t_output.permute(1, 0, 2)
input_lengths = torch.LongTensor([t_output.shape[1]])
target_lengths = torch.LongTensor([len(t_word)])

In [65]:
loss_fn = torch.nn.CTCLoss(blank=0, zero_infinity=True)
loss = loss_fn(log_probs, t_word_tensor, input_lengths, target_lengths)
loss

tensor(3.2831, device='cuda:0', grad_fn=<MeanBackward0>)

In [61]:
loss.backward()

In [2]:
import time
import random
import numpy as np
import torch.nn as nn

def train_model(model, train_set, val_set, optimiser, criterion=nn.CTCLoss(), batch_size=32, num_epochs=10):

    train_losses = []
    val_losses = []

    since = time.time()

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print("-" * 10)

        for phase in ["train", "val"]:
            # set the mode of the model based on the phase and change the data used
            if phase == "train":
                model.train()
                batches = list(range(len(train_set)))
                data = train_set
            else:
                model.eval()
                batches = list(range(len(val_set)))
                data = val_set

            # batch like this because the words are of different lenghts and batching isnt used
            random.shuffle(batches)
            batches = np.array_split(batches, len(batches) // batch_size)

            for batch in batches:   # iterate over each batch of dataset
                batch_loss = 0
                current_loss = 0

                optimiser.zero_grad()
                # enable gradients only if in training mode
                with torch.set_grad_enabled(phase == "train"):
                    for i in batch:     # for every datapoint in the batch
                        input, word, word_tensor = data[i]

                        input = input.to(device)
                        word_tensor = word_tensor.to(device)
                        word_length = len(word)

                        output = t_model(input, word_tensor)
                        # rearrange the output for CTC loss
                        output = output.permute(1, 0, 2)    # (T, N, C)
                        # convert to tensors
                        input_lengths = torch.tensor([output.size(1)], dtype=torch.long).to(device)
                        target_lengths = torch.tensor([word_length], dtype=torch.long).to(device)

                        loss = criterion(output, word_tensor, input_lengths, target_lengths)
                        batch_loss += loss

                    # back prop only if in train
                    if phase == "train":
                        batch_loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), 3)
                        optimiser.step()
                # calculates the total loss for epoch over all batches
                current_loss += batch_loss.item()
            # track epoch oss
            if phase == "train":
                epoch_loss = current_loss / len(train_set)
                train_losses.append(epoch_loss)
            else:
                epoch_loss = current_loss / len(val_set)
                val_losses.append(epoch_loss)
            
            print(f'{phase} Loss: {epoch_loss:.4f}')

        time_elapsed = time.time() - since
        print(f"Time elapsed: {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")       

    return model    

In [3]:
import torch
from model import Seq2Seq

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
t_model = Seq2Seq(hidden_size=32,
                  num_layers=2,
                  bidirectional=True,
                  input_size=6,
                  max_letters=20,
                  force_ratio=0.7,
                  output_size=27).to(device)

optimiser = torch.optim.SGD(t_model.parameters(), lr=0.001)
criterion = nn.CTCLoss(blank=0)
t_model = train_model(t_model, train_set, val_set, optimiser,criterion, num_epochs=5)

Epoch 1/5
----------
train Loss: nan
val Loss: nan
Time elapsed: 2m 6s
Epoch 2/5
----------


KeyboardInterrupt: 

In [70]:
test_in, test_word, test_ten = test_set[0]
test_in = test_in.to(device)

In [71]:
test_out = t_model(test_in)

In [72]:
vocabulary = {'_': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8,
              'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16,
              'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24,
              'y': 25, 'z': 26}
reversed_vocab = {k: u for u, k in vocabulary.items()}

In [73]:
def handle_outputs(decoder_output):
    indicies = torch.argmax(decoder_output.squeeze(1), dim=-1).tolist()
    words = []
    for word in indicies:
        characters = [reversed_vocab[i] for i in word]
        words.append(characters)
    
    return words

In [74]:
handle_outputs(test_out)

[['_',
  's',
  's',
  'e',
  's',
  's',
  'e',
  's',
  's',
  's',
  'e',
  's',
  's',
  's',
  'e',
  's',
  's',
  's',
  'e',
  's']]