In [None]:
import os
import re
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# Load the annotated word list and hold out 10% of it, keeping the label
# proportions equal in both splits (stratified) with a fixed seed.
df = pd.read_csv('word_annotation_2m.csv')
X_train, X_test, y_train, y_test = train_test_split(
    df['word'],
    df['label'],
    test_size=0.1,
    random_state=42,
    stratify=df['label'],
)

In [None]:
# Everything in this notebook runs on the CPU.
device = torch.device('cpu')

# Alphabet used for one-hot encoding; the position of a character in this
# string is its feature index, and the final character is a space.
# NOTE(review): 'ئ' occurs twice, so the second occurrence's slot can never be
# selected by str.find (it returns the first match) — confirm before changing,
# since n_letters fixes the model's input width.
# NOTE(review): plain alif ('ا') does not appear to be listed here — confirm
# that this is intentional for the dataset.
all_letters = 'بتثجحخدذرزسشصضطظعغفقكلمنهويءآٱأإةؤئىئ '
n_letters = len(all_letters)

# Materialise the splits as lists of (word, label) pairs for index access.
test_data = list(zip(X_test, y_test))
train_data = list(zip(X_train, y_train))

In [None]:
class RNN(nn.Module):
    """Minimal recurrent cell with a sigmoid read-out.

    At every step the current input is concatenated with the previous hidden
    state.  One linear layer maps that to the next hidden state (no
    non-linearity is applied to it) and a second linear layer followed by a
    sigmoid maps it to the output.

    The module does not pick a device itself: it is created on PyTorch's
    default device and follows whatever ``.to(...)`` the caller applies.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden state.
        output_size: Number of output units; each lies in (0, 1).
    """

    def __init__(self, input_size, hidden_size, output_size,):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, hidden):
        """Run one time step.

        Args:
            input: Tensor of shape ``(batch, input_size)``.
            hidden: Tensor of shape ``(batch, hidden_size)``.

        Returns:
            ``(output, hidden)`` with shapes ``(batch, output_size)`` and
            ``(batch, hidden_size)``.
        """
        combined = torch.cat((input, hidden), 1)
        next_hidden = self.i2h(combined)
        output = self.sigmoid(self.i2o(combined))
        return output, next_hidden

    def initHidden(self, batch_size=1):
        """Return an all-zero hidden state of shape ``(batch_size, hidden_size)``.

        The tensor is allocated with the same device and dtype as the module's
        weights, so it stays valid after the module has been moved or cast.
        """
        weight = self.i2h.weight
        return torch.zeros(
            batch_size, self.hidden_size, dtype=weight.dtype, device=weight.device
        )


In [None]:
class ErrorDetectionModel:
    """Training / evaluation wrapper around the character-level ``RNN``.

    Every batch produced by the dataloaders is unpacked as
    ``(words, labels, label_tensors, word_tensors)``.  Only the last two are
    used: ``word_tensors`` is iterated word by word and each word step by step,
    and ``label_tensors`` is the binary-cross-entropy target.

    Args:
        input_size: Width of each per-character input vector; must match the
            last dimension of the word tensors that will be fed to the model.
        hidden_size: Size of the recurrent hidden state.
        output_size: Number of output units of the network.
        ckp_path: Optional path of a ``state_dict`` checkpoint to restore.
        train_dataloader: Iterable of training batches (needed by ``train``).
        valid_dataloader: Optional iterable of validation batches.
    """

    def __init__(self,input_size=50,hidden_size=128, output_size=1,ckp_path=None,train_dataloader=None,valid_dataloader=None):
        # Build the network from the caller-supplied input_size rather than
        # from whatever global alphabet size happens to be defined.
        self.rnn = RNN(input_size, hidden_size, output_size).to(device)
        self.criterion = nn.BCELoss()
        # Restore a previous checkpoint if requested.
        if ckp_path:
            print('loading checkpoint saved at ',ckp_path)
            # map_location lets a checkpoint written on another device load here.
            # NOTE: torch.load unpickles the file — only load trusted checkpoints.
            self.rnn.load_state_dict(torch.load(ckp_path, map_location=device))
        self.train_dataloader=train_dataloader
        self.valid_dataloader=valid_dataloader

    def labelFromOutput(self,output):
        """Threshold a single model output at 0.5 and return 0 or 1."""
        return int(output > 0.5)

    def save_checkpoint(self,path):
        """Write the network's ``state_dict`` to ``path``."""
        torch.save(self.rnn.state_dict(), path)

    def predict(self,word_tensor_batch):
        """Run the network over every word and return the final-step outputs.

        Args:
            word_tensor_batch: Sequence of word tensors; ``word_tensor[i]`` is
                the input for time step ``i`` of that word.

        Returns:
            Tensor of shape ``(len(word_tensor_batch), 1)`` holding, for each
            word, the output produced after its last time step.

        Raises:
            ValueError: If a word has no time steps (there would be no output).
        """
        model_device = next(self.rnn.parameters()).device
        output_batch_tensor = torch.zeros(len(word_tensor_batch), 1, device=model_device)
        for idx, word_tensor in enumerate(word_tensor_batch):
            n_steps = word_tensor.size(0)
            if n_steps == 0:
                # Without this guard the previous word's output would be reused.
                raise ValueError('cannot predict on an empty sequence (index {})'.format(idx))
            hidden = self.rnn.initHidden()
            for i in range(n_steps):
                output, hidden = self.rnn(word_tensor[i], hidden)
            output_batch_tensor[idx] = output
        return output_batch_tensor

    def evaluate(self,test_dataloader):
        """Compute accuracy and mean per-batch loss over ``test_dataloader``.

        Gradients are disabled for the duration of the call and the network's
        train/eval mode is restored afterwards.

        Returns:
            ``(accuracy, avg_loss)`` as Python floats, where ``avg_loss`` is
            the mean of the per-batch BCE losses.

        Raises:
            ValueError: If the dataloader yields no batches.
        """
        was_training = self.rnn.training
        self.rnn.eval()
        total_loss = 0.0
        n_batches = 0
        n_correct = 0
        n_seen = 0
        try:
            with torch.no_grad():
                for words_batch, label_batch, label_tensor_batch, word_tensor_batch in tqdm(test_dataloader):
                    pred_batch = self.predict(word_tensor_batch)
                    # .item() keeps a plain float instead of a tensor chain.
                    total_loss += self.criterion(pred_batch, label_tensor_batch).item()
                    n_batches += 1
                    predicted = (pred_batch > 0.5).to(label_tensor_batch.dtype)
                    n_correct += int(predicted.eq(label_tensor_batch).sum().item())
                    n_seen += label_tensor_batch.numel()
        finally:
            self.rnn.train(was_training)
        if n_batches == 0:
            raise ValueError('evaluate() received a dataloader with no batches')
        return n_correct / n_seen, total_loss / n_batches

    def train(self,epochs=10,learning_rate=0.001,save_every=1,ckp_dir='.'):
        """Optimise the network with Adam and write checkpoints as it goes.

        Args:
            epochs: Number of passes over ``self.train_dataloader``.
            learning_rate: Adam learning rate.
            save_every: Write ``ckp_<epoch>.pt`` every this many epochs; a
                falsy value disables these periodic checkpoints.
            ckp_dir: Directory receiving ``ckp_best.pt`` (lowest validation
                loss so far), ``ckp_<epoch>.pt`` and ``ckp_last.pt``.

        Returns:
            List with the mean training loss of each epoch.

        Raises:
            ValueError: If no training dataloader was given or it is empty.
        """
        if self.train_dataloader is None:
            raise ValueError('train() requires a train_dataloader')
        os.makedirs(ckp_dir, exist_ok=True)
        all_losses=[]
        min_valid_loss=float('inf')
        optimizer = torch.optim.Adam(self.rnn.parameters(), lr=learning_rate)
        for epoch in range(1,epochs+1):
            self.rnn.train()
            current_loss=0.0
            n_batches=0
            for words_batch, label_batch, label_tensor_batch, word_tensor_batch in tqdm(self.train_dataloader):
                optimizer.zero_grad()
                output_batch_tensor=self.predict(word_tensor_batch)
                loss = self.criterion(output_batch_tensor, label_tensor_batch)
                loss.backward()
                optimizer.step()
                current_loss+=loss.item()
                n_batches+=1

            if n_batches == 0:
                raise ValueError('train_dataloader yielded no batches')
            # Average over the batches of *this model's* loader, not a global.
            current_loss/=n_batches
            all_losses.append(current_loss)

            if self.valid_dataloader:
                acc , current_valid_loss=self.evaluate(self.valid_dataloader)
                print('Epoch: {} | valid Loss= {:.3f}  | valid accuracy= {:.3f}'.format(epoch,current_valid_loss,acc))
                if current_valid_loss < min_valid_loss:
                    min_valid_loss=current_valid_loss
                    self.save_checkpoint(os.path.join(ckp_dir,'ckp_best.pt'))

            if save_every and epoch % save_every == 0:
                self.save_checkpoint(os.path.join(ckp_dir,'ckp_{}.pt'.format(epoch)))
            print('Epoch: {} | train loss= {:.3f}'.format(epoch,current_loss))
            self.save_checkpoint(os.path.join(ckp_dir,'ckp_last.pt'))
        return all_losses

In [None]:
class WordsDataset(Dataset):
    """Dataset of ``(word, label)`` pairs served as one-hot character tensors.

    Each item is ``(word, label, label_tensor, word_tensor)`` where
    ``word_tensor`` has shape ``(max_len, 1, n_letters)``: words longer than
    ``max_len`` are truncated and shorter ones are left zero-padded.
    """

    def __init__(self,data,max_len=16):
        self.max_len = max_len
        self.data = data

    def letterToIndex(self,letter):
        """Position of ``letter`` in ``all_letters``.

        ``str.find`` returns -1 for a character that is not in the alphabet;
        callers use the result as an index, so such characters end up in the
        last slot of the one-hot vector.
        """
        return all_letters.find(letter)

    def letterToTensor(self,letter):
        """One-hot row vector of shape ``(1, n_letters)`` for one character."""
        one_hot = torch.zeros(1, n_letters)
        one_hot[0, self.letterToIndex(letter)] = 1
        return one_hot

    def lineToTensor(self,line):
        """One-hot encode the first ``max_len`` characters of ``line``."""
        encoded = torch.zeros(self.max_len, 1, n_letters)
        for position, letter in enumerate(line[:self.max_len]):
            encoded[position, 0, self.letterToIndex(letter)] = 1
        return encoded.to(device)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        word, target = self.data[idx]
        target_tensor = torch.tensor([target], dtype=torch.float).to(device)
        return word, target, target_tensor, self.lineToTensor(word)

In [None]:
# Wrap both splits so that words are one-hot encoded lazily, item by item.
training_data = WordsDataset(train_data)
test_dataset = WordsDataset(test_data)

# Batches of 64 words; both loaders reshuffle on every pass.
train_dataloader = DataLoader(dataset=training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True)

In [None]:
# NOTE(review): the held-out test loader is also passed as the validation
# loader, so the "best" checkpoint is selected on the same data that is
# reported below — confirm this is intended.
model = ErrorDetectionModel(
    input_size=n_letters,
    hidden_size=32,
    train_dataloader=train_dataloader,
    valid_dataloader=test_dataloader,
)
# ErrorDetectionModel.train runs the optimisation loop with its default
# epochs / learning rate and writes checkpoints to the working directory.
model.train()
# Displayed as the cell output: (accuracy, average loss) on the test loader.
model.evaluate(test_dataloader)