In [12]:
from custom_torch_dataset import SwipeDataset
import os
from torch.utils.data import random_split
import torch

dataset_path = os.path.join(os.getcwd(), "dataset")

data = SwipeDataset(data_dir=dataset_path,
                    batch=False)

gen = torch.Generator().manual_seed(42)

train_set, val_set, test_set = random_split(data, [0.8, 0.1, 0.1], generator=gen)

In [21]:
import torch
from torch.nn.utils.rnn import pack_sequence

def collate_fn(batch):
    """
    :param batch: List of tuples (input, word, word_tensor)
                  - input: (T, 6)
                  - word: a string of characters
                  - word_tensor: encoded word as indicies with 0 as the blank

    """
    # Sort batch by sequence length (descending order)
    batch.sort(key=lambda x: x[0].shape[0], reverse=True)

    inputs, words, targets = zip(*batch)
    input_lengths = torch.LongTensor([x.shape[0] for x in inputs])  # store the lengths of inputs
    input = pack_sequence(inputs)   # pack the inputs

    target_lengths = torch.LongTensor([len(x) for x in words])
    targets = torch.cat(targets)    # concatenate all the targets

    return input, targets, input_lengths, target_lengths, words

In [22]:
from torch.utils.data import DataLoader
batch_size = 128

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

dataloaders = {"train": train_loader,
               "val": val_loader}

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_sequence

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class CTCEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=2, output_size=27, bidirectional=True, dropout = 0.5, lstm_dropout = 0):
        super(CTCEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.lstm_dropout = lstm_dropout
        
        self.lstm = nn.LSTM(input_size=input_size, 
                            hidden_size=hidden_size, 
                            num_layers=num_layers, 
                            batch_first=False,
                            bidirectional=bidirectional,
                            dropout=lstm_dropout)
        self.layer_norm = nn.LayerNorm(2 * hidden_size if bidirectional else hidden_size)
        # self.batch_norm = nn.BatchNorm1d(2 * hidden_size if bidirectional else hidden_size)
        self.drop = nn.Dropout(p=dropout)
        
        self.fc = nn.Linear(2 * hidden_size if bidirectional else hidden_size, output_size)
        
    def forward(self, x):
        """
        : param x: (batch_size, seq_len, input_size)
        : return logits: (seq_len, batch_size, output_size)
        """
        lstm_outputs, _ = self.lstm(x)  # lstm_out shape: (seq_len, batch_size, hidden_size*2 if bidirectional)

        lstm_out = pad_packed_sequence(lstm_outputs, batch_first=False)[0]

        lstm_out = self.layer_norm(lstm_out)
        lstm_out = self.drop(lstm_out)

        # lstm_out = self.batch_norm(lstm_out.permute(0, 2, 1))
        # lstm_out = self.drop(lstm_out.permute(0, 2, 1))

        logits = self.fc(lstm_out)  # shape: (batch_size, seq_len, output_size)
        return F.log_softmax(logits, dim=-1)  # Log-softmax for CTC loss
    

In [26]:
import os
import torch
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model_file = "lstm_n.pt"
save_path = os.path.join(os.getcwd(), "models", model_file)

t_model = torch.load(save_path, weights_only=False).to(device)

In [27]:
def test_model(model, test_data, datapoints):
    criterion = torch.nn.CTCLoss(blank=0, zero_infinity=True)

    model_outputs = []
    ground_truth = []
    # set the model into evaluation mode
    model.eval()

    running_loss = 0.0

    # Iterate over data.
    for inputs, targets, input_lengths, target_lengths, words in test_data:
        inputs = inputs.to(device)
        targets = targets.to(device)

        input_lengths = input_lengths.to(device)
        target_lengths = target_lengths.to(device)

        outputs = model(inputs)
        model_outputs.append(outputs.cpu())
        ground_truth.append(words)
        loss = criterion(outputs, targets, input_lengths, target_lengths)
        
        running_loss += loss * len(target_lengths)  # multiply by batch size

    avg_loss = running_loss / datapoints    # average over the entire test set

    print(f"Average test loss: {avg_loss}")

    return model_outputs, ground_truth

In [28]:
logits, truth = test_model(t_model, test_loader, len(test_set))

Average test loss: 1.1748939752578735


In [29]:
vocabulary = {'_': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8,
              'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16,
              'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24,
              'y': 25, 'z': 26}
reversed_vocab = {k: u for u, k in vocabulary.items()}

import torch
import torch.nn.functional as F
import heapq

def beam_search(logits, ground_truth, beam_width=3):
    """
    Decodes a tensor of logits using beam search, adapted for the provided input structure.

    Args:
        logits (torch.Tensor): A tensor of logits with shape (sequence_length, batch_size, vocab_size).
        ground_truth (list): A list of ground truth words (not directly used in beam search but included for consistency).
        reversed_vocab (dict): A dictionary mapping vocabulary indices to characters.
        beam_width (int): The width of the beam.

    Returns:
        list: A list of decoded strings, one for each sequence in the batch.
    """
    decoded_words = []
    for batch, _ in zip(logits, ground_truth):  # ground_truth is not used in the beam search.
        batched_results = []

        for b in range(batch.shape[1]):
            initial_beam = [([(0.0, "")], 0.0)]
            final_beams = []

            for t in range(batch.shape[0]):
                new_beam = []
                for seq, total_log_prob in initial_beam:
                    timestep_logits = batch[t, b, :]

                    topk_probs, topk_indices = torch.topk(timestep_logits, beam_width)

                    for i in range(beam_width):
                        char_index = topk_indices[i].item()
                        char = reversed_vocab.get(char_index, '_')
                        new_seq = seq + [(topk_probs[i].item(), char)]
                        new_total_log_prob = total_log_prob + topk_probs[i].item()
                        new_beam.append((new_seq, new_total_log_prob))

                initial_beam = heapq.nlargest(beam_width, new_beam, key=lambda x: x[1])

            best_sequence, _ = max(initial_beam, key=lambda x: x[1])
            decoded_word = "".join([char for log_prob, char in best_sequence[1:] if char != "_"])
            batched_results.append(decoded_word)

        decoded_words.append(batched_results)

    return decoded_words


In [30]:
beam_words = beam_search(logits, truth)

In [31]:
def evaluate(predictions, truth):
    exact_match = 0
    length_match = 0
    same_first = 0
    total = 0
    
    for i in range(len(predictions)):
        for j in range(len(predictions[i])):
            pred = predictions[i][j]
            ground = truth[i][j]
            total += 1

            if pred == ground:
                exact_match += 1
            
            if len(pred) == len(ground):
                length_match += 1
            
            if pred[0] == ground[0]:
                same_first += 1
    

    EM = exact_match / total
    LM = length_match / total
    FM = same_first / total

    print(f"Exact Match: {EM}")
    print(f"Correct Length: {LM}")
    print(f"Correct First Letter: {FM}")

In [32]:
evaluate(beam_words, truth)

Exact Match: 0.24599434495758718
Correct Length: 0.4665409990574929
Correct First Letter: 0.7224316682375118


In [37]:
from transformers import T5ForConditionalGeneration, AutoTokenizer

path_to_model = "ai-forever/T5-large-spell"

model = T5ForConditionalGeneration.from_pretrained(path_to_model).to(device)
tokenizer = AutoTokenizer.from_pretrained(path_to_model)

In [None]:
prefix = "spelling: "
sentence = prefix + beam_words[0][0]

encodings = tokenizer(sentence, return_tensors="pt")
encodings = encodings.to(device)
generated_tokens = model.generate(**encodings)
answer = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(answer)

['Everything']


In [None]:
from tqdm import tqdm

prefix = "grammar: "

def autocorrect(predictions):
    autocorrected_words = []

    for batch in tqdm(predictions):
        sentences = [prefix + word for word in batch]
        encodings = tokenizer(sentences, return_tensors="pt", padding=True)
        encodings = encodings.to(device)
        generated_tokens = model.generate(**encodings)
        answers = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        autocorrected_words.append(answers)

    return autocorrected_words

In [53]:
final_predictions = autocorrect(beam_words)

100%|██████████| 17/17 [00:51<00:00,  3.03s/it]


In [54]:
def evaluate_autocorrect(predictions, truth):
    exact_match = 0
    total = 0
    
    for i in range(len(predictions)):
        for j in range(len(predictions[i])):
            pred = predictions[i][j].lower()
            ground = truth[i][j]
            total += 1

            if pred == ground:
                exact_match += 1
    

    EM = exact_match / total

    print(f"Exact Match: {EM}")

In [55]:
evaluate_autocorrect(final_predictions, truth)

Exact Match: 0.17436380772855797
