# Preparing Data

In [124]:
from hazm import *
import string
import os
from pathlib import Path
import numpy as np

In [38]:
# Read the whole corpus, split it on newlines and skip the first two lines
# (presumably a header; the rest is assumed to be one half-verse per line).
half_verses = Path('ferdousi.txt').read_text(encoding='utf-8').split('\n')[2:]
len(half_verses)

99217

In [39]:
# Drop the final entry (the last mesra) before pairing.
half_verses = half_verses[:-1]
# Glue every two consecutive half-verses into one verse, separated by a space.
verses = [
    f'{half_verses[pos]} {half_verses[pos + 1]}'
    for pos in range(0, len(half_verses), 2)
]
len(verses)

49608

In [40]:
# Resources used by sentence preprocessing: the stopword list, the token
# replacement table and the characters treated as punctuation.
stopwords = []
replace_dict = {}
# Raw string: the value is character-for-character what it was before (the
# backslashes are literal), but sequences such as '\.' are no longer parsed as
# invalid escape sequences, which newer Python versions warn about.
punctuations = r'\.:!،؛؟»\]\)\}«\[\(\{' + string.punctuation

# stopwords.txt: one stopword per line; blank lines are ignored.
with open('stopwords.txt', 'r', encoding='utf-8') as f:
    for line in f:
        word = line.strip()
        if word:
            stopwords.append(word)

# replace.txt: one "<key> - <value>" pair per line. Blank lines are ignored and
# only the first '-' separates key from value, so a value may itself contain
# '-'. A non-blank line without any '-' still raises ValueError.
with open('replace.txt', 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        key, value = line.split('-', 1)
        replace_dict[key.strip()] = value.strip()

In [41]:
# Shared text-processing helpers used by sent_pre_process; the classes are
# expected to come from the `from hazm import *` star import.
normalizer = Normalizer()
stemmer = Stemmer()
lemmatizer = Lemmatizer()

def replace_function(string):
    """Return the entry of `replace_dict` for `string`, or `string` itself when it has none."""
    return replace_dict.get(string, string)

def sent_pre_process(sentence, normalize=True, remove_stopwords=False, stemme=False, lemmatize=True, replace=True, remove_punctuations=True, is_first=True):
    """Clean a sentence and split it into tokens.

    Steps, in order: character replacement (plus optional punctuation
    removal), optional normalization, optional lemmatization, tokenization,
    then optional per-token stemming, replacement via `replace_function` and
    stopword removal. When `is_first` is true the resulting tokens are joined
    with spaces and the same pipeline is run once more on that string.

    Args:
        sentence: the text to process.
        normalize: pass the text through `normalizer.normalize`.
        remove_stopwords: drop tokens that appear in `stopwords`.
        stemme: stem every token with `stemmer.stem`.
        lemmatize: pass the text through `lemmatizer.lemmatize`.
        replace: map every token through `replace_function`.
        remove_punctuations: turn every character of `punctuations` into a space.
        is_first: whether this is the first pass (triggers the second pass).

    Returns:
        The list of tokens.
    """
    # replace some characters (multi-character variants are rewritten first)
    replace_char = {'هٔ': 'ه',
                    'ۀ' : 'ه',
                    'ه‌ی' : 'ه'}

    if remove_punctuations:
        for char in punctuations:
            replace_char[char] = " "

    for key, value in replace_char.items():
        sentence = sentence.replace(key, value)

    if normalize:
        sentence = normalizer.normalize(sentence)
    if lemmatize:
        # NOTE(review): the lemmatizer is applied to the whole sentence string,
        # exactly as before; it presumably expects a single word, so confirm
        # whether per-token lemmatization was intended before changing this.
        sentence = lemmatizer.lemmatize(sentence)

    tokens = word_tokenize(sentence)

    if stemme:
        # The stemmer works on single words and its method is `stem`
        # (`stemmer.stemme` does not exist), so stem after tokenizing.
        tokens = [stemmer.stem(token) for token in tokens]
    if replace:
        tokens = [replace_function(token) for token in tokens]
    if remove_stopwords:
        # Build the set once so each membership test is O(1).
        stopword_set = set(stopwords)
        tokens = [token for token in tokens if token not in stopword_set]

    if is_first:
        return sent_pre_process(" ".join(tokens), normalize, remove_stopwords, stemme, lemmatize, replace, remove_punctuations, False)

    return tokens

In [42]:
# Preprocess every verse into a list of tokens using the default options.
processed_verses = [sent_pre_process(verse) for verse in verses]

In [44]:
# Display the tokens of the first verse as a quick sanity check.
processed_verses[0]

['به', 'نام', 'خداوند', 'جان', 'و', 'خرد', 'کزین', 'برتر', 'اندیشه', 'برنگذرد']

In [67]:
# Special vocabulary markers.
start = '<s>'
end = '</s>'
pad = '<pad>'
unkown = '<unk>'
# Surround every tokenized verse with the start and end markers.
verse_tokens = [[start, *verse, end] for verse in processed_verses]

In [68]:
# Collect the distinct tokens of the corpus. They are sorted so that the
# word <-> index mapping is identical on every run: iterating a set of strings
# yields a different order between interpreter sessions, which would make a
# saved model decode with the wrong vocabulary after a restart.
all_tokens = {word for verse in verse_tokens for word in verse}
# pad and unknown always take indices 0 and 1; they are removed from the corpus
# set first so that neither can receive a second, conflicting index.
all_tokens = [pad, unkown] + sorted(all_tokens - {pad, unkown})

# converting words to numbers
word2idx = {word: i for i, word in enumerate(all_tokens)}
idx2word = {i: word for word, i in word2idx.items()}

In [92]:
# Look up the vocabulary index of each special token once.
unkown_idx, pad_idx, start_idx, end_idx = (
    word2idx[token] for token in (unkown, pad, start, end)
)

def verse_to_numbers(verse):
    """Map each word of `verse` to its `word2idx` index, using `unkown_idx` for unseen words."""
    indices = []
    for token in verse:
        indices.append(word2idx.get(token, unkown_idx))
    return indices

def numbers_to_verse(numbers):
    """Convert a sequence of vocabulary indices back into a list of words.

    Conversion stops at the first `end_idx`, which is not included in the
    result. Every other index, including the unknown-word index, is looked up
    in `idx2word` exactly once (the unknown token used to be emitted twice).

    Args:
        numbers: iterable of vocabulary indices.

    Returns:
        The list of words preceding the end marker.
    """
    verse = []
    for number in numbers:
        if number == end_idx:
            break
        verse.append(idx2word[number])
    return verse

In [70]:
# Encode every token list (verse) as a list of vocabulary indices.
tokens_idx = list(map(verse_to_numbers, verse_tokens))

# Preparing Datasets

In [75]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.nn.utils.rnn import pad_sequence

In [94]:
# One index tensor per verse, right-padded with pad_idx to a common length.
X = [torch.tensor(x) for x in tokens_idx]
X = pad_sequence(X, batch_first=True, padding_value=pad_idx)
print(X.shape)
# Append 5 more padding columns at the end of the sequence dimension.
X = F.pad(X, (0, 5, 0, 0), value=pad_idx)
print(X.shape)

# Pair every verse (input) with the verse that FOLLOWS it (target). The slices
# used to be the other way round, which made each target the preceding verse.
# Integer tensors carry no gradient, so no torch.no_grad() guard is needed.
X, Y = X[:-1].clone(), X[1:].clone()
train_set = TensorDataset(X, Y)

torch.Size([49608, 22])
torch.Size([49608, 27])


In [95]:
# Random 90% / 10% split of the samples into training and validation subsets.
dataset_size = len(train_set)
train_len = int(dataset_size * 0.9)
val_len = dataset_size - train_len
train_set, val_set = random_split(train_set, [train_len, val_len])

In [96]:
# Both subsets are served in shuffled mini-batches of 128 samples.
train_loader, val_loader = (
    DataLoader(subset, batch_size=128, shuffle=True)
    for subset in (train_set, val_set)
)

# Model

In [132]:
class VersePredictor(nn.Module):
    """Encoder-decoder RNN that produces a verse conditioned on an input verse.

    The input verse is embedded and run through an LSTM/GRU encoder. The
    encoder's final hidden state (after dropout) initialises a decoder of the
    same type, and a linear layer projects the decoder outputs onto the
    vocabulary.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, model_type='lstm',
                        num_layers=1, bidirectional=False, dropout=0, max_T=30, pad_idx=pad_idx):
        """Build the embedding, encoder, decoder and output projection.

        Args:
            vocab_size: number of entries in the vocabulary.
            embedding_dim: size of the token embeddings.
            hidden_dim: hidden size of the encoder and decoder RNNs.
            model_type: 'lstm' or 'gru'.
            num_layers: number of stacked RNN layers.
            bidirectional: whether the RNNs are bidirectional.
            dropout: dropout probability (inside the RNNs and on the encoder state).
            max_T: number of tokens generated by `predict`.
            pad_idx: index of the padding token.

        Raises:
            ValueError: if `model_type` is neither 'lstm' nor 'gru'.
        """
        super().__init__()
        rnn_classes = {'lstm': nn.LSTM, 'gru': nn.GRU}
        if model_type not in rnn_classes:
            # Previously an unknown type silently produced a model without
            # encoder/decoder that only failed on first use.
            raise ValueError(f"model_type must be 'lstm' or 'gru', got {model_type!r}")
        rnn_cls = rnn_classes[model_type]

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.model_type = model_type
        self.encoder = rnn_cls(embedding_dim, hidden_dim, num_layers, bidirectional=bidirectional, dropout=dropout, batch_first=True)
        self.decoder = rnn_cls(embedding_dim, hidden_dim, num_layers, bidirectional=bidirectional, dropout=dropout, batch_first=True)
        # A bidirectional decoder emits 2 * hidden_dim features per step.
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * num_directions, vocab_size)
        self.max_T = max_T
        self.dropout = nn.Dropout(dropout)

    def _encode(self, x):
        """Encode token indices `x` (batch, time) and return the decoder's initial state."""
        _, state = self.encoder(self.embedding(x))
        # For an LSTM the state is (h_n, c_n); only h_n is carried over.
        hidden = state[0] if self.model_type == 'lstm' else state
        hidden = self.dropout(hidden)
        if self.model_type == 'lstm':
            # The decoder's cell state starts from zeros.
            return hidden, torch.zeros_like(hidden)
        return hidden

    def forward(self, x, y):
        """Return vocabulary logits for every position of `y`, conditioned on `x`.

        Args:
            x: (batch, time) indices of the input verse.
            y: (batch, time) indices of the target verse, fed to the decoder.

        Returns:
            Logits with one vocabulary-sized vector per position of `y`.
        """
        state = self._encode(x)
        out, _ = self.decoder(self.embedding(y), state)
        return self.fc(out)

    def loss(self, x, y):
        """Cross-entropy between each decoder step's prediction and the next token of `y`."""
        logits = self(x, y)
        # Step t predicts token t + 1: drop the first target and the last prediction.
        targets = y[:, 1:]
        logits = logits[:, :-1]
        # Padding positions carry no information; exclude them from the loss.
        padding_idx = self.embedding.padding_idx
        ignore_index = -100 if padding_idx is None else padding_idx
        return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1),
                               ignore_index=ignore_index)

    @torch.no_grad()
    def predict(self, x):
        """Sample `max_T` tokens for each input verse in `x`.

        Decoding starts from the module-level `start_idx` token; at each step
        the next token is sampled from the softmax distribution and fed back
        into the decoder together with the updated recurrent state.

        Args:
            x: (batch, time) indices of the input verses.

        Returns:
            A (batch, max_T) long tensor of sampled token indices, on `x`'s device.
        """
        batch_size = x.shape[0]
        state = self._encode(x)
        tokens = torch.full((batch_size, 1), start_idx, dtype=torch.long, device=x.device)
        result = torch.zeros(batch_size, self.max_T, dtype=torch.long, device=x.device)
        for t in range(self.max_T):
            out, state = self.decoder(self.embedding(tokens), state)
            probs = F.softmax(self.fc(out[:, -1]), dim=-1)
            tokens = probs.multinomial(1)  # (batch, 1), also the next decoder input
            result[:, t] = tokens.squeeze(1)
        return result

# Training

In [121]:
import matplotlib.pyplot as plt
from tqdm import tqdm

In [122]:
def _run_epoch(model, loader, device, optimizer=None):
    """Run one pass over `loader` and return the sample-weighted mean of the batch losses.

    When `optimizer` is given, a gradient step is taken for every batch.
    Returns NaN for an empty loader.
    """
    total_loss, n_samples = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        if optimizer is not None:
            optimizer.zero_grad()
        loss = model.loss(x, y)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        total_loss += loss.item() * x.shape[0]
        n_samples += x.shape[0]
    return total_loss / n_samples if n_samples else float('nan')


def train(model, train_loader, val_loader, optimizer, epochs, device, name='model.pt'):
    """Train `model`, checkpoint the best validation epoch and plot the loss curves.

    Args:
        model: module exposing `loss(x, y)`.
        train_loader: iterable of (x, y) training batches.
        val_loader: iterable of (x, y) validation batches.
        optimizer: optimizer updating `model`'s parameters.
        epochs: number of passes over the training data.
        device: device the batches are moved to.
        name: path the best `state_dict` is saved to.

    Returns:
        A dict with the per-epoch 'train' and 'val' losses.
    """
    loss_history = {'train': [], 'val': []}
    min_val_loss = float('inf')

    for epoch in tqdm(range(epochs), total=epochs, desc='Epochs'):
        model.train()
        # Losses are weighted by batch size, so they are averaged over the
        # number of samples (the old code divided by the number of batches).
        train_loss = _run_epoch(model, train_loader, device, optimizer)

        model.eval()
        # No graph is needed for evaluation.
        with torch.no_grad():
            val_loss = _run_epoch(model, val_loader, device)

        # Keep the weights of the epoch with the lowest validation loss.
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            torch.save(model.state_dict(), name)

        print(f'Epoch: {epoch + 1:02} | Train Loss: {train_loss:.3f} | Val Loss: {val_loss:.3f}')
        loss_history['train'].append(train_loss)
        loss_history['val'].append(val_loss)

    plt.plot(loss_history['train'], label='train')
    plt.plot(loss_history['val'], label='val')
    plt.legend()
    plt.show()

    return loss_history

In [None]:
# Hyper-parameters of the run.
lr = 0.001
epochs = 15

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VersePredictor(vocab_size=len(word2idx), embedding_dim=128, hidden_dim=256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
train(model, train_loader, val_loader, optimizer, epochs, device)

# Testing on Train and Validation Data

In [125]:
def print_samples(set, n=5):
    """Print `n` randomly drawn samples of `set` with two model predictions each.

    For every drawn (x, y) pair this prints the decoded input, two calls to
    `model.predict` on it, and the decoded target. Uses the module-level
    `model` and `device`.
    """
    for _ in range(n):
        sample_idx = np.random.randint(len(set))
        source, target = (tensor.to(device) for tensor in set[sample_idx])
        print('Input verse:')
        print(numbers_to_verse(source.cpu().numpy()))
        print('Next verse prediction:')
        for _attempt in range(2):
            sampled = model.predict(source.unsqueeze(0))
            print(numbers_to_verse(sampled.cpu().numpy()[0]))
        print('Real next verse:')
        print(numbers_to_verse(target.cpu().numpy()))
        print()

In [None]:
# Restore the checkpoint saved during training and switch to evaluation mode.
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('model.pt', map_location=device))
model.eval()

In [None]:
# Sample predictions on the validation set.
print_samples(val_set)

In [None]:
# Sample predictions on the training set.
print_samples(train_set)