In [1]:
import torch
import pandas as pd
import numpy as np
import re
import gensim
import collections
import pyarabic.araby as araby



# Data
### Helping Functions

In [2]:
def normalization(t):
    t = araby.strip_tashkeel(t)
    t = araby.normalize_hamza(t)
    t = araby.normalize_alef(t)
    t = araby.strip_tatweel(t)
    t = araby.normalize_teh(t)
    t = re.sub("ى","ي",t)
    return t

### Data

In [3]:
data= np.load('../translation project/AD_NMT-master/LAV-MSA-2-both.pkl',allow_pickle=True)

In [4]:
data[0] # lav , msa

['لا انا بعرف وحدة راحت ع فرنسا و معا شنتا حطت فيها الفرش',
 'لا اعرف واحدة ذهبت الى فرنسا و لها غرفة و ضعت فيها الافرشة']

In [5]:
# extract only msa text
msa=[]
for i,ex in enumerate(data):
    msa_text = normalization(ex[1])
    data[i][1] = msa_text
    msa.append(msa_text)

In [6]:
msa = ' '.join(msa)

Dictionaries

In [7]:
msa_d=collections.Counter(msa.split())

In [8]:
min_count = 2

In [9]:
idx2msa = np.array([word for word,freq in msa_d.items() if freq > min_count ])

In [10]:
msa2idx = {word:i for i,word in enumerate(idx2msa)}

In [11]:
msa_data = [' '.join([i for i in t[1].split() if (msa2idx.get(i,-1) != -1 and t[1] != '')]) for t in data]

In [12]:
msa_data = [i.replace('','') for i in msa_data if i != '']

Load Embeddings

In [13]:
t_model = gensim.models.Word2Vec.load('../resources/models/word vectors/word2vec/wiki/full_grams_cbow_100_wiki/full_grams_cbow_100_wiki.mdl')

In [83]:
i2l = list(set(normalization(araby.LETTERS)))
i2v = {}
for index,letter in enumerate(i2l):
    if letter in t_model.wv.index_to_key :
        i2v[index] = t_model.wv.get_vector(letter)

In [84]:
i2v.keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28])

In [85]:
i2l.append(' ')
#i2l.append('')
i2l.append('X')

In [86]:
l2i = {v:i for i,v in enumerate(i2l)}

Deep Learning

In [87]:
from torch.utils.data import Dataset,DataLoader

In [88]:
from torch.nn.utils.rnn import pad_sequence 

In [89]:
def noise(txt):
    noise_sz = np.random.randint(0,len(txt),1)
    replace_idx = np.random.choice(len(txt),noise_sz,replace=False)
    letters_idx = np.random.choice(len(i2l),noise_sz,replace=True)
    txt = list(txt)
    for rep,let in zip(replace_idx,letters_idx):
        txt[rep] = i2l[let]
    return ''.join(txt)

In [107]:
class arrDs(Dataset):
    def __init__(self,txt_list,l2i):
        self.data = txt_list
        self.l2i = l2i
    def __len__(self):
        return len(self.data)
    def __getitem__(self,idx):
        X = noise(self.data[idx])
        Y = self.data[idx]
        
        X = torch.tensor([self.l2i.get(i,30) for i in X])
        Y = torch.tensor([self.l2i.get(i,30) for i in Y])
        #numerilize
        return (X,Y)

In [108]:
trn_data, val_data = msa_data[:int(0.8*len(msa_data))],msa_data[int(0.8*len(msa_data)):]

In [109]:
trn_ds,val_ds = arrDs(trn_data,l2i),arrDs(val_data,l2i)

In [110]:
def collate_fn(data):
    label  = [i for _,i in data]
    label = pad_sequence(label,batch_first=True)
    data = [i for i,_ in data]
    data = pad_sequence(data,batch_first=True)
    return data,label

In [111]:
trn_dl = DataLoader(trn_ds,batch_size=4,collate_fn=collate_fn)
val_dl = DataLoader(val_ds,batch_size=4,collate_fn=collate_fn)

Model

In [112]:
import torch, torch.nn as nn

In [113]:
class autocorrect(nn.Module):
    def __init__(self,num_emb,vs,hs,bidirectional=True):
        super().__init__()
        self.emb = nn.Embedding(num_emb,vs)
        self.gru = nn.GRU(vs,hs,num_layers=2,bidirectional=bidirectional,batch_first=True)
        self.lin = nn.Linear(2*hs if bidirectional == True else hs,num_emb)
    def forward(self,x):
        bs,seq_len=x.shape
        x = self.emb(x)
        x,_ = self.gru(x)
        x = nn.functional.relu(x)
        x = self.lin(x)
        return torch.softmax(x,dim=-1).view(bs*seq_len,-1)

In [114]:
num_emb = len(i2l)

In [127]:
model = autocorrect(num_emb,100,128).cuda()

In [128]:
#Load available vectors
model.emb.weight.requires_grad_(False)
for i in i2v.keys():
    model.emb.weight[i] = nn.Parameter(torch.from_numpy(i2v[i])).requires_grad_(False)
model.emb.weight.requires_grad_(True)

Parameter containing:
tensor([[-2.0347,  0.3087, -2.3566,  ...,  0.3500, -0.1091,  0.2543],
        [-0.7845,  0.8250,  0.5404,  ..., -3.3356,  1.9303,  0.4786],
        [-0.8951,  1.9891, -2.1849,  ...,  0.0729,  0.0684,  5.5708],
        ...,
        [-2.0458,  0.2678, -2.0031,  ...,  1.6643,  1.2496, -1.9677],
        [ 0.3824,  0.4593,  2.9516,  ..., -0.6370, -0.6974,  1.7829],
        [ 0.3815, -1.6622,  0.0140,  ...,  0.1231, -0.5392,  0.4025]],
       device='cuda:0', requires_grad=True)

In [129]:
opt = torch.optim.Adam(model.parameters(),lr=1e-4)

In [130]:
loss = nn.CrossEntropyLoss()

In [135]:
def train(epoch,model,trn_dl,val_dl,loss_fnc):
    model.train()
    for i in range(epoch):
        model.train()
        for batch in trn_dl:
            opt.zero_grad()
            ip,label = batch
            op = model(ip.cuda())
            trn_l = loss_fnc(op,label.view(-1).cuda())
            trn_l.backward()
            opt.step()
        
        with torch.no_grad():
            model.eval()
            for batch in val_dl:
                ip,label = batch
                op = model(ip.cuda())
                val_loss = loss_fnc(op,label.view(-1).cuda())
        print('train_ loss ->',trn_l.item() , 'val_loss ->',val_loss.item())

In [136]:
train(5,model,trn_dl,val_dl,loss)

train_ loss -> 2.9865188598632812 val_loss -> 3.298280954360962
train_ loss -> 2.9467990398406982 val_loss -> 2.7887275218963623
train_ loss -> 2.875417947769165 val_loss -> 2.9136877059936523
train_ loss -> 2.90619158744812 val_loss -> 2.7545876502990723
train_ loss -> 2.9010863304138184 val_loss -> 2.7453982830047607


In [139]:
d = next(iter(val_dl))

In [145]:
for i in d:
    sent = []
    for j in i:
        sent.append([i2l[k] for k in j])

In [146]:
sent

[['س',
  'ا',
  'ع',
  'و',
  'د',
  ' ',
  'ي',
  'ا',
  ' ',
  'س',
  'ا',
  'ر',
  'ه',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك'],
 ['ع',
  'ن',
  'د',
  'م',
  'ا',
  ' ',
  'ت',
  'ذ',
  'ك',
  'ر',
  'ت',
  ' ',
  'ك',
  'د',
  'ت',
  ' ',
  'ا',
  'ب',
  'ك',
  'ي',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك'],
 ['ه',
  'ل',
  ' ',
  'ا',
  'ن',
  'ت',
  ' ',
  'ب',
  'خ',
  'ي',
  'ر',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك',
  'ك'],
 ['ل',
  'ق',
  'د',
  ' ',
  'ا',
  'ص',
  'ب',
  'ح',
  'ت',
  ' ',
  'ا',
  'ل',
  'د',
  'ر',
  'ا',
  'س',
  '