In [1]:
from custom_torch_dataset import SwipeDataset
import os
from torch.utils.data import random_split
import torch

dataset_path = os.path.join(os.getcwd(), "dataset")

data = SwipeDataset(data_dir=dataset_path,
                    batch=False)

gen = torch.Generator().manual_seed(42)

train_set, val_set, test_set = random_split(data, [0.8, 0.1, 0.1], generator=gen)

In [2]:
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """
    :param batch: List of tuples (input, word, word_tensor)
                  - input: (T, 6)
                  - word: a string of characters
                  - word_tensor: encoded word as indicies with 0 as the blank

    """
    # Sort batch by sequence length (descending order)
    batch.sort(key=lambda x: x[0].shape[0], reverse=True)
    inputs, words, targets = zip(*batch)
    input_lengths = torch.LongTensor([x.shape[0] for x in inputs])  # store the lengths of inputs

    target_lengths = torch.LongTensor([len(x) for x in words])
    targets = torch.cat(targets)    # concatenate all the targets

    padded_inputs = pad_sequence(inputs, batch_first=True)  # pad inputs to max length with zeros   (B, T, 6)

    return padded_inputs, targets, input_lengths, target_lengths, words

In [3]:
from torch.utils.data import DataLoader
batch_size = 128

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

dataloaders = {"train": train_loader,
               "val": val_loader}

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class SwipeToTextCTC(nn.Module):
    def __init__(self, input_size=6, conv_channels=32, hidden_size=128,
                 num_layers=2, output_size=27, bidirectional=True,
                 dropout=0.1):
        super(SwipeToTextCTC, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=input_size, out_channels=conv_channels, kernel_size=5, padding=2, stride=1),
            nn.ReLU(),
            nn.Dropout(p=dropout)
        )

        self.layer_norm1 = nn.LayerNorm(conv_channels)

        self.lstm = nn.LSTM(
            input_size=conv_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=False,
            bidirectional=bidirectional,
            dropout=dropout
        )

        lstm_out_dim = 2 * hidden_size if bidirectional else hidden_size
        self.layer_norm2 = nn.LayerNorm(lstm_out_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(lstm_out_dim, output_size)

    def forward(self, x, input_lengths):
        """
        x: (B, T, 6) = (batch, sequence length, feature dim)
        output: (T, B, output_size)
        """
        x = x.permute(0, 2, 1)         # (B, 6, T)
        x = self.conv(x)               # (B, conv_channels, T)
        x = x.permute(2, 0, 1)         # (T, B, conv_channels)

        x = self.layer_norm1(x)
        lstm_in = pack_padded_sequence(x, input_lengths.cpu(), batch_first=False)
        lstm_outputs, _ = self.lstm(lstm_in)   
        lstm_out = pad_packed_sequence(lstm_outputs, batch_first=False)[0]  

        lstm_out = self.layer_norm2(lstm_out)
        lstm_out = self.dropout(lstm_out)
        logits = self.fc(lstm_out)     # (T, B, output_size)
        
        return F.log_softmax(logits, dim=-1)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class SwipeToTextCTC(nn.Module):
    def __init__(self, input_size=6, conv_channels=32, kernel_size = 5,
                 hidden_size=128, num_layers=2, output_size=27, bidirectional=True,
                 dropout=0.1):
        super(SwipeToTextCTC, self).__init__()

        padding = kernel_size // 2 

        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=input_size, out_channels=conv_channels, kernel_size=kernel_size, padding=padding, stride=1),
            nn.ReLU(),
        )

        self.norm_drop1 = nn.Sequential(
            nn.LayerNorm(conv_channels),
            nn.Dropout(p=dropout)
        )

        self.lstm = nn.LSTM(
            input_size=conv_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=False,
            bidirectional=bidirectional,
            dropout=dropout
        )

        lstm_out_dim = 2 * hidden_size if bidirectional else hidden_size

        self.norm_drop2 = nn.Sequential(
            nn.LayerNorm(lstm_out_dim),
            nn.Dropout(p=dropout)
        )

        self.fc = nn.Linear(lstm_out_dim, output_size)

    def forward(self, x, input_lengths):
        """
        x: (B, T, 6) = (batch, sequence length, feature dim)
        output: (T, B, output_size)
        """
        x = x.permute(0, 2, 1)         # (B, 6, T)
        x = self.conv(x)               # (B, conv_channels, T)
        x = x.permute(2, 0, 1)         # (T, B, conv_channels)

        x = self.norm_drop1(x)
        lstm_in = pack_padded_sequence(x, input_lengths.cpu(), batch_first=False)
        lstm_outputs, _ = self.lstm(lstm_in)   
        lstm_out = pad_packed_sequence(lstm_outputs, batch_first=False)[0]  

        lstm_out = self.norm_drop2(lstm_out)
        logits = self.fc(lstm_out)     # (T, B, output_size)
        
        return F.log_softmax(logits, dim=-1)

In [5]:
import os
import torch
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model_file = "conv_32_k11_128_4.pt"
save_path = os.path.join(os.getcwd(), "models", model_file)

t_model = torch.load(save_path, weights_only=False).to(device)

In [6]:
t_model

SwipeToTextCTC(
  (conv): Sequential(
    (0): Conv1d(6, 32, kernel_size=(11,), stride=(1,), padding=(5,))
    (1): ReLU()
  )
  (norm_drop1): Sequential(
    (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (1): Dropout(p=0.3, inplace=False)
  )
  (lstm): LSTM(32, 128, num_layers=4, dropout=0.3, bidirectional=True)
  (norm_drop2): Sequential(
    (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (1): Dropout(p=0.3, inplace=False)
  )
  (fc): Linear(in_features=256, out_features=27, bias=True)
)

In [7]:
def test_model(model, test_data, datapoints):
    criterion = torch.nn.CTCLoss(blank=0, zero_infinity=True)

    model_outputs = []
    ground_truth = []
    # set the model into evaluation mode
    model.eval()

    running_loss = 0.0

    # Iterate over data.
    for inputs, targets, input_lengths, target_lengths, words in test_data:
        inputs = inputs.to(device)
        targets = targets.to(device)

        input_lengths = input_lengths.to(device)
        target_lengths = target_lengths.to(device)

        outputs = model(inputs, input_lengths)
        model_outputs.append(outputs.cpu())
        ground_truth.append(words)
        loss = criterion(outputs, targets, input_lengths, target_lengths)
        
        running_loss += loss * len(target_lengths)  # multiply by batch size

    avg_loss = running_loss / datapoints    # average over the entire test set

    print(f"Average test loss: {avg_loss}")

    return model_outputs, ground_truth

In [8]:
logits, truth = test_model(t_model, test_loader, len(test_set))

Average test loss: 1.1674144268035889


In [9]:
vocabulary = {'_': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8,
              'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16,
              'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24,
              'y': 25, 'z': 26}
reversed_vocab = {k: u for u, k in vocabulary.items()}

In [10]:
import torch
import torch.nn.functional as F
import heapq

def beam_search(logits, ground_truth, beam_width=3):
    """
    Decodes a tensor of logits using beam search, adapted for the provided input structure.

    Args:
        logits (torch.Tensor): A tensor of logits with shape (sequence_length, batch_size, vocab_size).
        ground_truth (list): A list of ground truth words (not directly used in beam search but included for consistency).
        reversed_vocab (dict): A dictionary mapping vocabulary indices to characters.
        beam_width (int): The width of the beam.

    Returns:
        list: A list of decoded strings, one for each sequence in the batch.
    """
    decoded_words = []
    for batch, _ in zip(logits, ground_truth):  # ground_truth is not used in the beam search.
        batched_results = []

        for b in range(batch.shape[1]):
            initial_beam = [([(0.0, "")], 0.0)]
            final_beams = []

            for t in range(batch.shape[0]):
                new_beam = []
                for seq, total_log_prob in initial_beam:
                    timestep_logits = batch[t, b, :]

                    topk_probs, topk_indices = torch.topk(timestep_logits, beam_width)

                    for i in range(beam_width):
                        char_index = topk_indices[i].item()
                        char = reversed_vocab.get(char_index, '_')
                        new_seq = seq + [(topk_probs[i].item(), char)]
                        new_total_log_prob = total_log_prob + topk_probs[i].item()
                        new_beam.append((new_seq, new_total_log_prob))

                initial_beam = heapq.nlargest(beam_width, new_beam, key=lambda x: x[1])

            best_sequence, _ = max(initial_beam, key=lambda x: x[1])
            decoded_word = "".join([char for log_prob, char in best_sequence[1:] if char != "_"])
            batched_results.append(decoded_word)

        decoded_words.append(batched_results)

    return decoded_words


In [11]:
beam_words = beam_search(logits, truth)

In [12]:
def evaluate(predictions, truth):
    exact_match = 0
    length_match = 0
    same_first = 0
    total = 0
    
    for i in range(len(predictions)):
        for j in range(len(predictions[i])):
            pred = predictions[i][j]
            ground = truth[i][j]
            total += 1

            if pred == ground:
                exact_match += 1
            
            if len(pred) == len(ground):
                length_match += 1
            
            if pred[0] == ground[0]:
                same_first += 1
    

    EM = exact_match / total
    LM = length_match / total
    FM = same_first / total

    print(f"Exact Match: {EM}")
    print(f"Correct Length: {LM}")
    print(f"Correct First Letter: {FM}")



In [13]:
evaluate(beam_words, truth)

Exact Match: 0.2648444863336475
Correct Length: 0.4934024505183789
Correct First Letter: 0.735626767200754


In [14]:
from transformers import T5ForConditionalGeneration, AutoTokenizer

path_to_model = "ai-forever/T5-large-spell"

model = T5ForConditionalGeneration.from_pretrained(path_to_model).to(device)
tokenizer = AutoTokenizer.from_pretrained(path_to_model)

In [15]:
from tqdm import tqdm

prefix = "grammar: "

def autocorrect(predictions):
    autocorrected_words = []

    for batch in tqdm(predictions):
        sentences = [prefix + word for word in batch]
        encodings = tokenizer(sentences, return_tensors="pt", padding=True)
        encodings = encodings.to(device)
        generated_tokens = model.generate(**encodings)
        answers = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        autocorrected_words.append(answers)

    return autocorrected_words

In [16]:
final_predictions = autocorrect(beam_words)

100%|██████████| 17/17 [02:59<00:00, 10.56s/it]


In [17]:
def evaluate_autocorrect(predictions, truth):
    exact_match = 0
    total = 0
    
    for i in range(len(predictions)):
        for j in range(len(predictions[i])):
            pred = predictions[i][j].lower()
            ground = truth[i][j]
            total += 1

            if pred == ground:
                exact_match += 1
    

    EM = exact_match / total

    print(f"Exact Match: {EM}")

In [18]:
evaluate_autocorrect(final_predictions, truth)

Exact Match: 0.10980207351555137
