In [1]:
import json, nltk, re, time, string
import pandas as pd
import numpy as np
from gensim.models import KeyedVectors
from gensim.models import Word2Vec
import gensim.downloader as api
import nlpaug.augmenter.char as nac
from tqdm import tqdm

import torch
import torch.nn as nn

In [4]:
all_letters = string.ascii_letters + " .,;'-" + "0123456789"
n_letters = len(all_letters) + 1 # Plus EOS marker

abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'-0123456789 69


In [None]:
df = pd.read_csv('./clean_dataset.csv')
df['code'] = df['code'].apply(lambda code: code.replace(" ", ""))

In [None]:
def word_augmentation(df, n):
    aug = nac.KeyboardAug()
    aug_df = pd.DataFrame(columns=['input','target','code'])
    
    for i in tqdm(range(len(df))):
        words = df.iloc[i]
        for j in range(n):
            augmented_data = aug.augment(words["input"])
            aug_df = aug_df.append({ "input": augmented_data, "target": words["target"], "code": words["code"] }, ignore_index=True)
    
    return df.append(aug_df)

In [12]:
class SpellCheckerRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SpellCheckerRNN, self).__init__()
        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)
        self.criterion = nn.NLLLoss()

    def forward(self, input, hidden):
        input_combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)
    
        
    def train_step(input_line_tensor, target_line_tensor, lr=0.0005):
        target_line_tensor.unsqueeze_(-1)
        hidden = self.initHidden()

        self.zero_grad()
        loss = 0

        for i in range(input_line_tensor.size(0)):
            output, hidden = self.forward(input_line_tensor[i], hidden)
            l = criterion(output, target_line_tensor[i])
            loss += l

        loss.backward()

        for p in self.parameters():
            p.data.add_(p.grad.data, alpha=-lr)

        return output, loss.item() / input_line_tensor.size(0)
    
    def train(epochs, lr):
        # print_every = 5000
        plot_every = epochs // 50
        all_losses = []
        total_loss = 0 # Reset every plot_every iters

        start = time.time()

        for iter in tqdm(range(1, epochs + 1)):
            output, loss = self.train_step("achraf", lr)
            total_loss += loss

            if iter % print_every == 0:
                #print('%s (%d %d%%) %.4f' % (timeSince(start), iter, iter / n_iters * 100, loss))
                print(start, iter, iter/epochs * 100, loss)

            if iter % plot_every == 0:
                all_losses.append(total_loss / plot_every)
                total_loss = 0

In [None]:
# Turn a Unicode string to plain ASCII
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

In [6]:
all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1 # Plus EOS marker

In [13]:
rnn = SpellCheckerRNN(n_letters, 128, n_letters)

In [16]:
# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li in range(len(line)):
        letter = line[li]
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor

In [19]:
input =  inputTensor("achraf")
hidden = rnn.initHidden()
rnn(input,hidden)

RuntimeError: Tensors must have same number of dimensions: got 3 and 2