In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np

import re
import random



import time
import math

from tqdm import tqdm
from transformers import BertTokenizer,DistilBertTokenizer
from transformers import  DistilBertModel,DistilBertConfig


import random

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [21]:
device  # echo the selected device as this cell's output

In [4]:
# Load the tokenizer and the DistilBERT encoder from the same pretrained
# checkpoint name; the encoder is moved to `device`.
tokenizer = DistilBertTokenizer.from_pretrained("dbmdz/distilbert-base-turkish-cased")
language_model = DistilBertModel.from_pretrained("dbmdz/distilbert-base-turkish-cased").to(device)

In [22]:
# Read the whole corpus into memory, one list entry per line (trailing
# newlines are kept by readlines()).
# The encoding is pinned to UTF-8 so the Turkish characters decode the same way
# on every platform instead of depending on the locale's default encoding.
with open('../input/tr-corpus-sent-txt/datasentbysent.txt', encoding='utf-8') as f:
    lines = f.readlines()

In [23]:
class Dataset():
    """Map-style dataset yielding (corrupted, clean) tokenizations of a sentence.

    Each item is built on the fly: the sentence at ``index`` is corrupted with
    ``corrupt_sentence`` and both the corrupted and the original text are
    tokenized to ``max_seq_len`` ids and moved to ``device``.

    Args:
        sentence: indexable collection of sentence strings.
        tokenizer: callable tokenizer accepting ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` keyword arguments.
        max_seq_len: length every tokenized sequence is padded/truncated to.
        device: device the tokenized tensors are moved to.
    """

    def __init__(self,sentence,tokenizer,max_seq_len,device):
        self.sentence = sentence
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.device = device

    def _tokenize(self, text):
        """Tokenize ``text`` to exactly ``max_seq_len`` ids on ``self.device``."""
        # truncation=True is required: padding='max_length' only pads, so without
        # it a long sentence would come back longer than max_seq_len and could
        # not be stacked into a batch with the other items.
        return self.tokenizer(text,
                              max_length=self.max_seq_len,
                              padding='max_length',
                              truncation=True,
                              return_tensors="pt").to(self.device)

    def __getitem__(self, index: int):
        """Return ``(corrupted_encoding, clean_encoding)`` for sentence ``index``."""
        clean = self.sentence[index]
        # NOTE(review): corrupt_sentence only appends a space to the words it
        # corrupted, so joining with '' glues the untouched words together --
        # confirm this is intended and not a missing ' '.join.
        noisy = ''.join(corrupt_sentence(clean))
        return (self._tokenize(noisy), self._tokenize(clean))

    def __len__(self) -> int:
        """Number of sentences in the dataset."""
        return len(self.sentence)

In [24]:
# Regex character class matching '?', '.', newline, tab, '&' and '!'.
# In the visible code it is only referenced by the commented-out line below.
pattern = '[?.\n\t&!]'
#lines_ = [re.split(pattern,i) for i in lines]
# Maps Turkish-specific lowercase letters to their ASCII look-alikes
# (the name is Turkish for "dotted vowel"); used by dot_transform.
noktali_sesli = {"ı":"i","ö":"o","ü":"u","ğ":"g","ç":"c","ş":"s"}
# Vowels in both cases; used to locate the characters to drop in a word.
vowel = ["a","e","ı","i","o","ö","u","ü","A","E","I","İ","O","Ö","U","Ü"]
# Lowercase vowels as one string; containsVowels checks lower-cased input against it.
vowel_str = "aeıioöuü"


def containsVowels(string):
    """Return True when the lower-cased ``string`` has a character found in ``vowel_str``."""
    return any(ch in vowel_str for ch in string.lower())



def dot_transform(word):
    """Replace every character listed in ``noktali_sesli`` with its mapped counterpart."""
    for dotted, plain in noktali_sesli.items():
        word = word.replace(dotted, plain)
    return word


def random_drop(word):
    """Remove one character from ``word`` at a uniformly random position.

    Args:
        word: string to corrupt.

    Returns:
        ``word`` with one character removed; an empty ``word`` is returned
        unchanged instead of raising from ``random.randint(0, -1)``.
    """
    if not word:
        # Reachable when chained after a corruption that emptied the word.
        return word
    idx = random.randint(0, len(word) - 1)
    return word[:idx] + word[idx + 1:]


def duplicate_char(word,max_dup = 3):
    """Repeat one randomly chosen character of ``word``.

    The chosen character is written between 1 and ``max_dup`` times in place
    of the original, so a draw of 1 leaves the word unchanged.

    Args:
        word: string to corrupt.
        max_dup: upper bound (inclusive) on how many times the chosen
            character appears in the result.

    Returns:
        The corrupted word; an empty ``word`` is returned unchanged instead of
        raising from ``random.randint(0, -1)``.
    """
    if not word:
        # Reachable when chained after a corruption that emptied the word.
        return word
    idx = random.randint(0, len(word) - 1)
    repeats = random.randint(1, max_dup)
    return word[:idx] + word[idx] * repeats + word[idx + 1:]

def drop_vowel(word):
    """Remove one randomly chosen vowel from ``word``.

    Words for which ``containsVowels`` is false are returned unchanged.
    """
    if not containsVowels(word):
        return word

    # Positions of every character that is listed in ``vowel``.
    vowel_positions = [pos for pos, ch in enumerate(word) if ch in vowel]

    victim = vowel_positions[random.randint(0, len(vowel_positions) - 1)]
    return word[:victim] + word[victim + 1:]

def drop_all_vowels(word):
    """Remove every character of ``word`` that is listed in ``vowel``.

    Words for which ``containsVowels`` is false are returned unchanged.
    """
    if not containsVowels(word):
        return word
    return "".join(ch for ch in word if ch not in vowel)
# Pool of word-level corruption functions that corrupt_word draws from.
list_of_corruptions = [dot_transform,random_drop,drop_vowel,duplicate_char,drop_all_vowels]
def corrupt_word(word, corruption_prob= 0.8):
    """Apply a random chain of corruptions from ``list_of_corruptions`` to ``word``.

    A value is drawn from {0.1, 0.2, ..., 1.0} and doubled until it reaches
    ``corruption_prob``; one corruption is queued per doubling, so small draws
    queue several corruptions and draws at or above ``corruption_prob`` queue
    none.

    Args:
        word: the word to corrupt; words shorter than 2 characters are
            returned unchanged.
        corruption_prob: threshold controlling how many corruptions are queued.

    Returns:
        The corrupted word (possibly unchanged, possibly empty).
    """
    if len(word) < 2:
        return word
    prob = random.randint(1,10)/10
    corruptions = []
    while prob < corruption_prob:
        # Index range is 0..len-1 so every corruption is equally likely. The
        # previous `randint(0, len) - 1` produced -1..len-1, which selected the
        # last entry twice as often as the others.
        pick = random.randint(0, len(list_of_corruptions) - 1)
        corruptions.append(list_of_corruptions[pick])
        prob = prob*2
    for corruption in corruptions:
        if not word:
            # An earlier corruption removed every character; the remaining
            # ones have nothing to act on.
            break
        word = corruption(word)
    return word
        
  
def corrupt_sentence(sentence,corruption_prob_w= 0.8,corruption_prob_s = 0.5):
    """Corrupt a random subset of the space-separated words of ``sentence``.

    Args:
        sentence: input text; it is split on single spaces.
        corruption_prob_w: threshold forwarded to ``corrupt_word`` for every
            selected word.
        corruption_prob_s: fraction of the words to corrupt (the count is
            truncated to an integer and clamped to the number of words).

    Returns:
        The list of words. Every word selected and successfully corrupted has
        a single trailing space appended; the other words are returned as-is.
        NOTE(review): callers that ''.join this list glue the untouched words
        together -- confirm that is intended.
    """
    words = sentence.split(" ")

    # Clamp so a fraction outside [0, 1] cannot make random.sample raise.
    corrupt_num = max(0, min(len(words), int(len(words)*corruption_prob_s)))

    # Sample over ALL word positions; the previous range(0, len(words)-1)
    # could never select the last word and made sample() raise whenever
    # corrupt_num equalled len(words).
    c_idx = random.sample(range(len(words)), corrupt_num)

    for i in c_idx:
        try:
            words[i] = corrupt_word(words[i], corruption_prob_w)+' '
        except Exception:
            # Best effort: a word that cannot be corrupted is left untouched.
            # Catching Exception (not a bare except) keeps KeyboardInterrupt
            # and SystemExit working.
            continue
    return words


In [25]:
class Transformer(nn.Module):
    """Encoder-decoder transformer over token ids with learned positions.

    Token embeddings, a learned positional table and an externally supplied
    ``space_shift`` vector are summed and fed to ``nn.Transformer``; the
    decoder output is layer-normalised and projected to vocabulary logits.

    Args:
        d_model: embedding / model width.
        n_heads: number of attention heads.
        device: device passed to ``nn.Transformer`` for parameter creation.
        tokenizer: stored on the module; its ``pad_token_id`` (if any) is used
            as the decoder start token.
        vocab_size: size of the embedding table and of the output projection.
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        max_seq_len: number of rows in the positional table, i.e. the longest
            sequence the model accepts.
    """

    def __init__(self,d_model,n_heads,
                 device,tokenizer,
                 vocab_size,
                 num_encoder_layers = 2,
                 num_decoder_layers = 2,
                 max_seq_len = 50,

                ):

        super().__init__()

        # batch_first is left at its default (False): the core expects
        # (seq, batch, d_model) tensors, hence the permutes in forward().
        self.transformer = nn.Transformer(d_model=d_model,
                                          nhead=n_heads,
                                          device = device,
                                          num_encoder_layers = num_encoder_layers,
                                          num_decoder_layers = num_decoder_layers,
                                          dim_feedforward = 1024,
                                         )
        self.tokenizer = tokenizer
        self.pos_embed = nn.Parameter(torch.randn(max_seq_len,d_model))
        self.max_seq_len = max_seq_len
        self.embedding = nn.Embedding(num_embeddings=vocab_size,embedding_dim=d_model,)
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model,vocab_size)

    def forward(self,src_input_ids,src_att_mask,trg_input_ids,space_shift):
        """Compute next-token logits for ``trg_input_ids`` given the source.

        Args:
            src_input_ids: ``(batch, src_len)`` source token ids.
            src_att_mask: ``(batch, src_len)`` mask, non-zero for real tokens
                and 0 for padding; ``None`` disables padding masking.
            trg_input_ids: ``(batch, trg_len)`` gold target ids. They are
                shifted right by one position internally, so the logits at
                position ``t`` are computed from target tokens before ``t``
                only and can be scored directly against ``trg_input_ids``.
            space_shift: tensor broadcastable to ``(batch, len, d_model)`` that
                is added to both source and target embeddings.

        Returns:
            Logits of shape ``(trg_len, batch, vocab_size)``.
        """
        src_len = src_input_ids.shape[1]
        trg_len = trg_input_ids.shape[1]

        # Teacher forcing: feed the decoder the target shifted right by one,
        # starting with the padding id. Feeding the unshifted target (as was
        # done before) lets the model copy the label it is asked to predict.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 0
        start_column = torch.full_like(trg_input_ids[:, :1], pad_id)
        dec_input_ids = torch.cat([start_column, trg_input_ids[:, :-1]], dim=1)

        # Slice the positional table so sequences shorter than max_seq_len
        # are accepted as well.
        embeds_trg = (self.embedding(dec_input_ids) + self.pos_embed[:trg_len] + space_shift).permute(1, 0, 2)
        embeds_src = (self.embedding(src_input_ids) + self.pos_embed[:src_len] + space_shift).permute(1, 0, 2)

        # True above the diagonal = "may not attend": position t only sees
        # decoder inputs at positions <= t.
        causal_mask = torch.triu(
            torch.ones(trg_len, trg_len, dtype=torch.bool, device=trg_input_ids.device),
            diagonal=1,
        )
        # nn.Transformer expects True at the positions to IGNORE.
        src_pad_mask = None if src_att_mask is None else (src_att_mask == 0)

        transformer_outs = self.transformer(
            embeds_src,
            embeds_trg,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_pad_mask,
            memory_key_padding_mask=src_pad_mask,
        )  # trg_len x batch x d_model

        out = self.norm(transformer_outs)
        out = self.output(out)  # trg_len x batch x vocab_size
        return out

    def num_of_parameters(self,):
        """Return the total number of elements over all parameters."""
        return sum(p.numel() for p in self.parameters())

def encode(sentence,corrupt=False):
    """Tokenize ``sentence`` with the notebook-level tokenizer settings.

    Uses the module-level ``tokenizer``, ``max_seq_len`` and ``device``.

    Args:
        sentence: text to tokenize.
        corrupt: when True the sentence is first passed through
            ``corrupt_sentence`` and re-joined the same way ``Dataset`` does.
            (The flag used to be accepted but ignored.)

    Returns:
        The tokenizer output, padded/truncated to ``max_seq_len`` and moved to
        ``device``.
    """
    if corrupt:
        sentence = ''.join(corrupt_sentence(sentence))

    # truncation=True keeps over-long inputs at max_seq_len; padding alone
    # never shortens a sequence.
    return tokenizer(sentence,max_length=max_seq_len,
                          padding='max_length',
                          truncation=True,
                          return_tensors="pt").to(device)


In [43]:
### HYPERPARAMETERS
# Length every tokenized sequence is padded to; also the size of the model's
# positional table.
max_seq_len = 250
# Learning rate handed to the Adam optimizer.
LR = 0.0031
# Number of optimisation steps passed to Trainer.train.
n_iter = 10000
# NOTE(review): not referenced by the visible DataLoader construction, which
# hard-codes batch_size=8 -- confirm which value is intended.
batch_size = 16
# Forwarded to the DataLoader's shuffle argument.
shuffle = True
# How often (in iterations) the progress bar text and best-model check run.
update_bar_per_iter = 1



In [44]:


# Build the seq2seq model. The language model's output is added to the token
# embeddings inside the model, so d_model presumably has to equal the
# DistilBERT hidden size -- TODO confirm 768 matches the loaded checkpoint.
trans = Transformer(d_model = 768,n_heads = 4,
                    device = device,vocab_size = tokenizer.vocab_size,
                    tokenizer = tokenizer,
                    max_seq_len = max_seq_len,
                    ).to(device)

# Targets equal to 0 are excluded from the loss (presumably the tokenizer's
# padding id -- confirm against tokenizer.pad_token_id).
loss_fn = nn.CrossEntropyLoss(ignore_index = 0)
optim = torch.optim.Adam(trans.parameters(),lr=LR)


In [45]:
trans.num_of_parameters()  # total number of parameter elements in the model

In [46]:
class Trainer:
    """Training loop for the seq2seq model conditioned on a frozen language model.

    Args:
        model: module called as ``model(src_input_ids=..., src_att_mask=...,
            trg_input_ids=..., space_shift=...)`` and returning logits shaped
            ``(seq, batch, vocab)``.
        loss_fn: loss called as ``loss_fn(logits_(batch, vocab, seq), targets_(batch, seq))``.
        optimizer: optimizer over ``model``'s parameters.
        language_model: module called with ``input_ids`` / ``attention_mask``;
            the first position of its first output is used as ``space_shift``.
            It is only run under ``torch.no_grad()``.
        param_init: when truthy, ``init_param`` is applied to ``model`` on
            construction.
    """

    def __init__(self,model,loss_fn,optimizer,language_model,param_init,):

        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.language_model = language_model

        if param_init:
            self.init_param()

    def _snapshot(self):
        """Return an independent copy of the model's current ``state_dict``."""
        # state_dict() returns references to the live tensors, so without a
        # copy the stored "best" weights would silently follow every later
        # optimizer step.
        import copy
        return copy.deepcopy(self.model.state_dict())

    def train(self,data,shuffle = True,n_iter = 1000,update_bar_per_iter = 100):
        """Run ``n_iter`` optimisation steps over ``data``.

        Args:
            data: re-iterable of ``(src, trg)`` batches (e.g. a DataLoader).
                ``src`` provides ``'input_ids'`` and ``'attention_mask'``,
                ``trg`` provides ``'input_ids'``; each tensor has a singleton
                dimension 1 that is squeezed away. When ``data`` is exhausted
                it is iterated again from the start.
            shuffle: unused; kept for backward compatibility (shuffling is
                whatever ``data`` itself does).
            n_iter: number of optimisation steps.
            update_bar_per_iter: every this many iterations the progress bar
                text is refreshed and the mean loss since the previous refresh
                is compared with the best one seen; on improvement the weights
                are copied into ``self.model2download``.

        Raises:
            ValueError: if ``data`` yields no batches.
        """
        self.model.train()
        self.language_model.eval()
        bar = tqdm(range(n_iter),desc = f"Epoch -> {0} Loss -> {0.000} ")

        general_loss = 0.0
        window_loss = 0.0
        window_steps = 0
        best_loss = float("inf")
        self.model2download = self._snapshot()

        # One persistent iterator instead of a fresh iter(...) per step.
        batches = iter(data)
        for i in bar:
            try:
                src, trg = next(batches)
            except StopIteration:
                # End of an epoch: start over.
                batches = iter(data)
                try:
                    src, trg = next(batches)
                except StopIteration:
                    raise ValueError("`data` yielded no batches") from None

            self.optimizer.zero_grad()

            src_input_ids = src['input_ids'].squeeze(1)
            src_att_mask = src['attention_mask'].squeeze(1)
            trg_input_ids = trg['input_ids'].squeeze(1)

            # The language model is frozen: first-position hidden state,
            # kept as (batch, 1, hidden) so it broadcasts over the sequence.
            with torch.no_grad():
                space_shift = self.language_model(input_ids=src_input_ids,
                                                  attention_mask=src_att_mask)[0][:,0].unsqueeze(1)

            preds = self.model(src_input_ids=src_input_ids,src_att_mask=src_att_mask,
                  trg_input_ids=trg_input_ids,space_shift=space_shift)

            # (seq, batch, vocab) -> (batch, vocab, seq) as the loss expects
            # the class dimension second.
            loss = self.loss_fn(preds.permute(1,2,0),trg_input_ids)

            general_loss += loss.item()
            window_loss += loss.item()
            window_steps += 1
            loss.backward()
            self.optimizer.step()

            if not(i % update_bar_per_iter):
                # Running mean of the per-step loss (the loss is not divided
                # by the batch size a second time).
                bar.set_description(f"Iteration-> {i+1} | Loss -> {general_loss /(i+1)}")
                window_mean = window_loss / window_steps
                if window_mean < best_loss:
                    best_loss = window_mean
                    self.model2download = self._snapshot()
                # Reset the window on every refresh, improved or not, so later
                # windows can still beat the best one.
                window_loss = 0.0
                window_steps = 0

    def init_param(self,):
        """Re-initialise every model parameter from a standard normal distribution."""
        # NOTE(review): this also overwrites biases and normalisation weights
        # with N(0, 1) draws -- confirm that is intended.
        for x in self.model.parameters():
            nn.init.normal_(x)

    def download_best_model(self,PATH):
        """Save the best weights recorded by ``train`` to ``PATH``."""
        torch.save(self.model2download, PATH)

    def load_pretrained_model(self,PATH):
        """Load a ``state_dict`` from ``PATH`` into the model.

        The checkpoint is mapped onto the device the model currently lives on,
        so a checkpoint saved on GPU can be loaded on a CPU-only machine.
        """
        first_param = next(self.model.parameters(), None)
        target_device = first_param.device if first_param is not None else None
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        self.model.load_state_dict(torch.load(PATH, map_location=target_device))
             

In [47]:
# Wrap the corpus lines in the on-the-fly corrupting Dataset.
data = Dataset(lines,tokenizer=tokenizer,max_seq_len=max_seq_len,device=device)
# NOTE(review): batch_size is hard-coded to 8 here although the hyperparameter
# cell defines batch_size = 16 -- confirm which value is intended.
dataloader = torch.utils.data.DataLoader(dataset=data,batch_size=8,shuffle=shuffle)

In [48]:
# param_init=True makes the Trainer re-initialise the model's parameters on construction.
trainer = Trainer(model=trans,loss_fn=loss_fn,optimizer=optim,language_model = language_model,param_init=True)

In [49]:
# Run n_iter optimisation steps, passing the DataLoader as `data`.
trainer.train(data=dataloader,n_iter=n_iter,update_bar_per_iter=update_bar_per_iter)

In [54]:
# Saves the model's current weights; the best weights tracked by the trainer
# are written separately via trainer.download_best_model(PATH).
torch.save(trans.state_dict(), "./model.pth")
