In [1]:
import torch
import pandas as pd
import numpy as np
import re
import gensim
import collections
import pyarabic.araby as araby



# Data
### Helping Functions

In [2]:
def normalization(t):
    t = araby.strip_tashkeel(t)
    t = araby.normalize_hamza(t)
    t = araby.normalize_alef(t)
    t = araby.strip_tatweel(t)
    t = araby.normalize_teh(t)
    t = re.sub("ى","ي",t)
    return t

### Data

In [3]:
data= np.load('../translation project/AD_NMT-master/LAV-MSA-2-both.pkl',allow_pickle=True)

In [4]:
data[0] # lav , msa

['لا انا بعرف وحدة راحت ع فرنسا و معا شنتا حطت فيها الفرش',
 'لا اعرف واحدة ذهبت الى فرنسا و لها غرفة و ضعت فيها الافرشة']

In [5]:
# extract only msa text
msa=[]
for i,ex in enumerate(data):
    msa_text = normalization(ex[1])
    data[i][1] = msa_text
    msa.append(msa_text)

In [6]:
msa = ' '.join(msa)

Dictionaries

In [7]:
msa_d=collections.Counter(msa.split())

In [8]:
min_count = 2

In [9]:
idx2msa = np.array([word for word,freq in msa_d.items() if freq > min_count ])

In [10]:
msa2idx = {word:i for i,word in enumerate(idx2msa)}

In [11]:
msa_data = [' '.join([i for i in t[1].split() if (msa2idx.get(i,-1) != -1 and t[1] != '')]) for t in data]

In [12]:
msa_data = [i for i in msa_data if i != '']

Load Embeddings

In [13]:
t_model = gensim.models.Word2Vec.load('../resources/models/word vectors/word2vec/wiki/full_grams_cbow_100_wiki/full_grams_cbow_100_wiki.mdl')

In [14]:
i2l = list(set(normalization(araby.LETTERS)))
i2v = {}
for index,letter in enumerate(i2l):
    if letter in t_model.wv.index_to_key :
        i2v[index] = t_model.wv.get_vector(letter)

In [15]:
i2v.keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28])

In [16]:
i2l.append(' ')#Space
i2l.append('E')#Empty
i2l.append('X')#UNK
i2l.append('P')#pad

In [17]:
l2i = {v:i for i,v in enumerate(i2l)}

In [18]:
len(i2l)

33

Deep Learning

In [19]:
from torch.utils.data import Dataset,DataLoader

In [20]:
from torch.nn.utils.rnn import pad_sequence 

In [21]:
def noise(txt):
    sz = int(len(txt)*0.2)
    noise_sz = np.random.randint(0,sz if sz>1 else 1,1)
    replace_idx = np.random.choice(len(txt),noise_sz,replace=False)
    letters_idx = np.random.choice(len(i2l)-3,noise_sz,replace=True)
    txt = list(txt)
    for rep,let in zip(replace_idx,letters_idx):
        txt[rep] = i2l[let]
    return ''.join(txt)

In [22]:
class arrDs(Dataset):
    def __init__(self,txt_list,l2i):
        self.data = txt_list
        self.l2i = l2i
    def __len__(self):
        return len(self.data)
    def __getitem__(self,idx):
        X = noise(self.data[idx])
        Y = self.data[idx]
        
        X = torch.tensor([self.l2i.get(i,31) for i in X])
        Y = torch.tensor([self.l2i.get(i,31) for i in Y])
        #numerilize
        return (X,Y)

In [23]:
trn_data, val_data = msa_data[:int(0.8*len(msa_data))],msa_data[int(0.8*len(msa_data)):]

In [24]:
trn_ds,val_ds = arrDs(trn_data,l2i),arrDs(val_data,l2i)

In [78]:
def collate_fn(data):
    label  = [i for _,i in data]
    label = pad_sequence(label,batch_first=True,padding_value=32)
    data = [i for i,_ in data]
    data = pad_sequence(data,batch_first=True,padding_value=32)
    return label,label

In [79]:
trn_dl = DataLoader(trn_ds,batch_size=256,collate_fn=collate_fn,drop_last=False)
val_dl = DataLoader(val_ds,batch_size=256,collate_fn=collate_fn,drop_last=False)

Model

In [80]:
import torch, torch.nn as nn

In [81]:
class autocorrect(nn.Module):
    def __init__(self,num_emb,vs,hs,bidirectional=True):
        super().__init__()
        self.emb = nn.Embedding(num_emb,vs)
        self.gru = nn.GRU(vs,hs,num_layers=3,bidirectional=bidirectional,batch_first=True,dropout=0.2)
        self.lin = nn.Linear(2*hs if bidirectional == True else hs,num_emb)
    def forward(self,x):
        bs,seq_len=x.shape
        x = self.emb(x)
        x,_ = self.gru(x)
        x = nn.functional.relu(x)
        x = self.lin(x)
        return torch.softmax(x,dim=-1).view(bs*seq_len,-1)

In [82]:
num_emb = len(i2l)

In [102]:
model = autocorrect(num_emb,100,128).cuda()

In [84]:
#Load available vectors
model.emb.weight.requires_grad_(False)
for i in i2v.keys():
    model.emb.weight[i] = nn.Parameter(torch.from_numpy(i2v[i].copy())).requires_grad_(False)
model.emb.weight.requires_grad_(True);

In [104]:
opt = torch.optim.Adam(model.parameters(),lr=1e-4)

In [105]:
loss = nn.CrossEntropyLoss()

In [106]:
def train(epoch,model,val_dl,trn_dl,loss_fnc):
    model.train()
    for i in range(epoch):
        model.train()
        for batch in trn_dl:
            opt.zero_grad()
            ip,label = batch
            op = model(ip.cuda())
            trn_l = loss_fnc(op,label.view(-1).cuda())
            trn_l.backward()
            opt.step()
        
        with torch.no_grad():
            model.eval()
            for batch in val_dl:
                ip,label = batch
                op = model(ip.cuda())
                val_loss = loss_fnc(op,label.view(-1).cuda())
        print('train_loss ->',trn_l.item() , 'val_loss ->',val_loss.item())

In [107]:
train(5,model,trn_dl,val_dl,loss)

train_loss -> 3.4525818824768066 val_loss -> 3.435774803161621
train_loss -> 3.3414692878723145 val_loss -> 3.294494152069092
train_loss -> 3.15637469291687 val_loss -> 3.087650775909424
train_loss -> 3.0223097801208496 val_loss -> 2.953524589538574
train_loss -> 2.9499573707580566 val_loss -> 2.881425380706787


In [90]:
d = next(iter(val_dl))

In [91]:
op = model(d[0].cuda())

In [92]:
d[0]

tensor([[12, 16, 26,  ..., 32, 32, 32],
        [26, 10, 11,  ..., 32, 32, 32],
        [ 4, 19, 29,  ..., 32, 32, 32],
        ...,
        [16, 10, 27,  ..., 32, 32, 32],
        [19, 16, 29,  ..., 29,  4, 18],
        [18, 16,  2,  ..., 32, 32, 32]])

In [93]:
op = torch.argmax(op.view(256,-1,33),dim=-1)

In [94]:
sent = []
for i in op:
    sent.append(''.join([i2l[j] for j in i]))

In [95]:
sent

['           PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 '                  PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 '       PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 '                                       PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 '                                  PPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 '             

In [101]:
?nn.GRU

In [97]:
all=[]
for i in d:
    sent = []
    for j in i:
        sent.append(''.join([i2l[k] for k in j]))
    all.append(sent)

In [99]:
all[1]

['ساعود يا سارهPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 'عندما تذكرت كدت ابكيPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 'هل انت بخيرPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 'لقد اصبحت الدراسه في وقتنا الحالي صعبه جداPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 'لو ان المرقي غير متمكن سيهزا به الجنPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 'انه مفيد لكنه

In [106]:
all[1]

['ساعود يا سارهPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 'عندما تذكرت كدت ابكيPPPPPPPPPPPPPPPPPPPPPP',
 'هل انت بخيرPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP',
 'لقد اصبحت الدراسه في وقتنا الحالي صعبه جدا']