In [1]:
from torch import nn
import torch

# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)


cuda:0


In [2]:
import wandb
# Start a Weights & Biases run under the "ferdousi-generator" project.
wandb.init(project="ferdousi-generator")

[34m[1mwandb[0m: Currently logged in as: [33msoroushtabesh[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Hyperparameters are stored on wandb.config and passed as `args` to Dataset and train().
config = wandb.config
config.max_epochs = 20  # number of passes over the dataset in train()
config.batch_size = 256  # batch size handed to the DataLoader in train()
config.sequence_length = 6  # tokens per input/target window produced by Dataset
config.log_interval = 10  # train() logs the loss to wandb every this many batches

In [4]:


class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> vocabulary logits.

    Args:
        dataset: Object exposing ``uniq_words``; its length fixes the
            vocabulary size of the embedding table and of the output layer.

    The module moves itself to the notebook-level ``device`` on construction.

    NOTE(review): the LSTM is built without ``batch_first=True``, so it reads
    dim 0 of its input as time and dim 1 as batch. The state returned by
    ``init_state`` must therefore be sized to dim 1 of the tensor that is
    passed to ``forward``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.lstm_size = 256
        self.embedding_dim = 256
        self.num_layers = 3

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width is the
            # embedding width. It used to be tied to lstm_size, which only
            # worked because both values are 256.
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)
        self.to(device)

    def forward(self, x, prev_state):
        """Run one forward pass.

        Args:
            x: Integer tensor of token indices.
            prev_state: ``(h, c)`` tuple as returned by ``init_state`` or by a
                previous call to ``forward``.

        Returns:
            ``(logits, state)`` where ``logits`` has the vocabulary size as its
            last dimension and ``state`` is the updated ``(h, c)`` tuple.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return zeroed ``(h_0, c_0)`` tensors for the LSTM.

        Each tensor has shape ``(num_layers, sequence_length, lstm_size)`` and
        is created on the same device as the model parameters, so this works
        on CPU-only machines as well as on GPU.
        """
        param_device = next(self.parameters()).device
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=param_device),
                torch.zeros(shape, device=param_device))

In [5]:
import torch
import pandas as pd
from collections import Counter


class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a whitespace-tokenised text file.

    Every line of the file is prefixed with a marker token: ``__bom__`` for
    even-numbered lines (0-based) and ``__bos__`` for odd-numbered lines.
    NOTE(review): presumably these mark the two halves of each couplet in the
    corpus; confirm against the data file.

    Item ``i`` is a pair ``(x, y)`` of index tensors of length
    ``args.sequence_length`` where ``y`` is ``x`` shifted one token ahead.

    Args:
        args: Object with a ``sequence_length`` attribute.
        data_path: Path of the corpus file. Defaults to the path that was
            previously hard-coded, so ``Dataset(config)`` keeps working.
    """

    def __init__(
            self,
            args,
            data_path='../data/ferdousi_norm.txt',
    ):
        self.args = args
        self.data_path = data_path
        # Tensors are created on the GPU when one is available (as before) and
        # on the CPU otherwise, instead of unconditionally calling .cuda().
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read the corpus and return it as a flat list of tokens with line markers."""
        # The corpus is Persian text: decode explicitly as UTF-8 rather than
        # relying on the platform default encoding.
        with open(self.data_path, encoding='utf-8') as f:
            lines = [line.strip() for line in f]
        lines = ['__bom__ ' + line if i % 2 == 0 else '__bos__ ' + line for i, line in
                 enumerate(lines)]
        words = [word for line in lines for word in line.split()]
        return words

    def get_uniq_words(self):
        """Return the vocabulary ordered from most to least frequent token."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # One window per start position that still has a full target window.
        # Clamp at zero so a corpus shorter than the window is simply empty.
        return max(0, len(self.words_indexes) - self.args.sequence_length)

    def __getitem__(self, index):
        """Return the ``(input, target)`` index tensors for window ``index``.

        Raises:
            IndexError: If ``index`` is outside ``range(len(self))``. Without
                this check, plain iteration over the dataset would never stop,
                because slicing past the end silently yields empty tensors.
        """
        size = len(self)
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(f'index out of range for dataset of length {size}')

        seq_len = self.args.sequence_length
        tensors = (
            torch.tensor(self.words_indexes[index:index + seq_len], device=self.device),
            torch.tensor(self.words_indexes[index + 1:index + seq_len + 1], device=self.device),
        )
        return tensors

In [6]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader


def train(dataset, model, args):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Args:
        dataset: Map-style dataset yielding ``(x, y)`` pairs of token-index
            tensors.
        model: Module providing ``init_state(n)`` and a forward pass
            ``model(x, (h, c)) -> (logits, (h, c))``.
        args: Object with ``max_epochs``, ``batch_size``, ``sequence_length``
            and ``log_interval`` attributes.

    The loss is printed for every batch and sent to wandb every
    ``args.log_interval`` batches. The LSTM state is re-initialised at the
    start of each epoch and carried (detached) from batch to batch.
    """
    wandb.watch(model)
    model.train()

    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Batches are moved to wherever the model lives; this is a no-op when the
    # dataset already produces tensors on that device.
    model_device = next(model.parameters()).device

    print({'batch_count': len(dataloader), 'epoch_count': args.max_epochs})
    for epoch in range(args.max_epochs):
        # NOTE(review): the state is sized with sequence_length, not batch_size.
        # That matches dim 1 of x, which is the axis an LSTM built without
        # batch_first treats as batch. Confirm this layout is intended.
        state_h, state_c = model.init_state(args.sequence_length)
        for batch, (x, y) in enumerate(dataloader):
            x = x.to(model_device)
            y = y.to(model_device)
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class dimension second.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Cut the graph so gradients do not flow into earlier batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            # Read the scalar once and log a plain float rather than a tensor.
            loss_value = loss.item()
            print({'epoch': epoch, 'batch': batch, 'loss': loss_value})
            if batch % args.log_interval == 0:
                wandb.log({"loss": loss_value})

In [7]:
def predict(dataset, model, text, next_words=100):
    """Extend ``text`` by sampling ``next_words`` tokens from ``model``.

    Args:
        dataset: Object exposing ``word_to_index`` and ``index_to_word``.
        model: Module providing ``init_state(n)`` and a forward pass
            ``model(x, (h, c)) -> (logits, (h, c))``.
        text: Whitespace-separated prompt; every word must be in the vocabulary.
        next_words: Number of tokens to sample.

    Returns:
        The prompt words followed by the sampled words, as a list of strings.

    Raises:
        ValueError: If ``text`` contains no words.
        KeyError: If a prompt word is not in ``dataset.word_to_index``.
    """
    model.eval()

    # split() without an argument ignores repeated, leading and trailing
    # whitespace, which would otherwise become '' tokens and raise KeyError.
    words = text.split()
    if not words:
        raise ValueError('text must contain at least one word')

    model_device = next(model.parameters()).device
    state_h, state_c = model.init_state(len(words))

    # Sampling needs no gradients. Without no_grad the carried LSTM state would
    # keep the autograd history of every previous step alive.
    with torch.no_grad():
        for i in range(0, next_words):
            x = torch.tensor(
                [[dataset.word_to_index[w] for w in words[i:]]],
                device=model_device,
            )
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[int(word_index)])

    return words

In [8]:
# Build the dataset; the window size comes from config.sequence_length.
dataset = Dataset(config)
# Print the first 10 (input, target) windows; each target is its input shifted one token ahead.
for i in range(10):
    print(dataset[i])

(tensor([  0,   3,  81, 363, 118,   2], device='cuda:0'), tensor([  3,  81, 363, 118,   2,  98], device='cuda:0'))
(tensor([  3,  81, 363, 118,   2,  98], device='cuda:0'), tensor([ 81, 363, 118,   2,  98,   1], device='cuda:0'))
(tensor([ 81, 363, 118,   2,  98,   1], device='cuda:0'), tensor([363, 118,   2,  98,   1, 365], device='cuda:0'))
(tensor([363, 118,   2,  98,   1, 365], device='cuda:0'), tensor([118,   2,  98,   1, 365, 698], device='cuda:0'))
(tensor([118,   2,  98,   1, 365, 698], device='cuda:0'), tensor([  2,  98,   1, 365, 698, 221], device='cuda:0'))
(tensor([  2,  98,   1, 365, 698, 221], device='cuda:0'), tensor([  98,    1,  365,  698,  221, 3552], device='cuda:0'))
(tensor([  98,    1,  365,  698,  221, 3552], device='cuda:0'), tensor([   1,  365,  698,  221, 3552,    0], device='cuda:0'))
(tensor([   1,  365,  698,  221, 3552,    0], device='cuda:0'), tensor([ 365,  698,  221, 3552,    0,  363], device='cuda:0'))
(tensor([ 365,  698,  221, 3552,    0,  363], devi

In [9]:
# Build a fresh model whose vocabulary size is taken from the dataset.
model = Model(dataset)

# Train the model in place; the loss is printed per batch and logged to wandb.
train(dataset, model, config)

{'epoch': 7, 'batch': 363, 'loss': 4.209418773651123}
{'epoch': 7, 'batch': 364, 'loss': 3.847109079360962}
{'epoch': 7, 'batch': 365, 'loss': 3.8111469745635986}
{'epoch': 7, 'batch': 366, 'loss': 3.6310434341430664}
{'epoch': 7, 'batch': 367, 'loss': 3.771148443222046}
{'epoch': 7, 'batch': 368, 'loss': 3.8145344257354736}
{'epoch': 7, 'batch': 369, 'loss': 3.7349956035614014}
{'epoch': 7, 'batch': 370, 'loss': 3.785506248474121}
{'epoch': 7, 'batch': 371, 'loss': 3.9743564128875732}
{'epoch': 7, 'batch': 372, 'loss': 3.65920090675354}
{'epoch': 7, 'batch': 373, 'loss': 3.840552568435669}
{'epoch': 7, 'batch': 374, 'loss': 3.857278823852539}
{'epoch': 7, 'batch': 375, 'loss': 3.4766480922698975}
{'epoch': 7, 'batch': 376, 'loss': 3.699756622314453}
{'epoch': 7, 'batch': 377, 'loss': 3.8779518604278564}
{'epoch': 7, 'batch': 378, 'loss': 3.964479684829712}
{'epoch': 7, 'batch': 379, 'loss': 3.995656728744507}
{'epoch': 7, 'batch': 380, 'loss': 3.993419647216797}
{'epoch': 7, 'batch': 

KeyboardInterrupt: 

In [10]:

# Sample predict()'s default of 100 words after the prompt and print one token per line.
print('\n'.join(predict(dataset, model, text='__bom__ توانا بود هر که')))

__bom__
توانا
بود
هر
که
دریا
سپاس
خاک
و
دل
__bos__
وزو
سر
بزرگی
دلی
را
کسی
__bom__
فدای
تو
روشن
نه
نیکوترست
__bos__
سواران
بدخواه
سر
نگسلی
__bom__
نخواهم
که
بیند
بدین
داستان
__bos__
نگیرم
به
جز
دل
دلارای
بود
__bom__
سیاوش
سر
شاه
بر
چشم
شاه
__bos__
بیاراستن
بیکران
با
سپاه
__bom__
گر
ایدونک
هم
بار
بشتافتم
__bos__
تو
گویی
جوانست
و
هوش
اندر
آی
__bom__
نپذرفت
با
من
برو
زر
به
هم
__bos__
ز
فرمان
ما
برتر
از
شهریار
__bom__
سیاووش
پردانش
و
چون
شیر
بدار
__bos__
تو
باشد
ز
ما
چون
ترا
بازگرد
__bom__
یکی
سان
که
راند
چنین
نابکار
__bos__


In [6]:
# Sanity check that PyTorch can see a CUDA device.
# NOTE(review): this cell shows execution count 6 although it sits after cell 10,
# so it was run out of order (or in another session); consider removing it.
torch.cuda.is_available()

True

In [11]:
# Save the trained model weights to a timestamped checkpoint file.
# Only the state dict is stored; the wandb config is not part of the checkpoint.
import os
import time

checkpoint_dir = '../data/checkpoints'
# torch.save does not create missing parent directories, so make sure it exists.
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save({'model_state_dict':model.state_dict()}, f'{checkpoint_dir}/model_{time.time()}.pt')

In [12]:
# Close the wandb run that was started by wandb.init() at the top of the notebook.
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,██▆▅▅▆▅▃▃▄▃▃▃▃▂▃▃▃▃▃▃▃▃▂▁▃▃▂▂▂▂▂▂▃▂▁▁▁▁▁

0,1
loss,4.09747
