In [1]:
!nvidia-smi

Mon Jun  6 23:04:19 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.129.06   Driver Version: 470.129.06   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| N/A   62C    P0    N/A /  N/A |    710MiB /  2002MiB |     15%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [None]:
! pip install -q wandb
! git clone "https://github.com/amnghd/Persian_poems_corpus.git"
! mkdir "corpus"
! cp "Persian_poems_corpus/normalized/ferdousi_norm.txt" "Persian_poems_corpus/normalized/hafez_norm.txt" "Persian_poems_corpus/normalized/moulavi_norm.txt"
"./corpus/"


fatal: destination path 'Persian_poems_corpus' already exists and is not an empty directory.


In [56]:
import torch
from torch import nn, optim
import wandb
import pandas as pd
from collections import Counter
import os
import itertools
import numpy as np
from torch.utils.data import DataLoader
import time
import copy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device, torch.cuda.is_available())


cpu False


In [None]:
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [57]:
class Config:
    pass


wandb_active = False
project_name = 'poem_generator'
run_name = 'all_poem_train'
checkpoints_dir = '../data/checkpoints/'
corpus_dir = '../data/poems/'
vocab_path = '../data/vocabulary.txt'

if wandb_active:
    wandb.init(project=project_name, name=run_name)
    config = wandb.config
else:
    config = Config()
config.batch_size = 256
config.embedding_size = 512
config.lstm_num_layers = 3
config.lstm_hidden_size = 512
config.sequence_length = 10
config.log_interval = 10
config.learning_rate = 0.001
config.vocab_size = 38590
config.lstm_dropout = 0.2

In [58]:
class Model(nn.Module):
    def __init__(self, config, device=torch.device('cpu')):
        super(Model, self).__init__()
        self.lstm_size = config.embedding_size
        self.lstm_hidden_size = config.lstm_hidden_size
        self.lstm_dropout = 0.2
        self.embedding_dim = config.embedding_size
        self.num_layers = config.lstm_num_layers
        self.device = device
        self.vocab_size = config.vocab_size
        self.embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_hidden_size,
            num_layers=self.num_layers,
            dropout=self.lstm_dropout,
        )
        self.fc = nn.Linear(self.lstm_size, self.vocab_size)
        self.to(device)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(self.device),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(self.device))

In [59]:
class PoemDataset(torch.utils.data.Dataset):
    def __init__(
            self,
            config,
            device=torch.device('cpu'),
            poet='ferdousi',
            corpus_dir='./Persian_poems_corpus/normalized',
            vocab_path='./vocabulary.txt'
    ):
        self.config = config
        self.device = device
        self.corpus_dir = corpus_dir
        self.vocab_path = vocab_path

        self.words_by_poet = self.load_words(corpus_dir)
        self.vocabulary = self.load_vocabulary()
        self.index_to_word = {index: word for index, word in enumerate(self.vocabulary)}
        self.word_to_index = {word: index for index, word in enumerate(self.vocabulary)}
        self.poet = poet

    def preprocess_lines(self, lines, mask_key):
        lines = [line.strip() for line in lines]
        lines = filter(lambda line: len(line) > 0, lines)
        lines = map(lambda line: line.replace('\n', ''), lines)
        lines = map(lambda line: line.replace('\t', ''), lines)
        lines = map(lambda line: line.replace('\r', ''), lines)
        lines = map(
            lambda index_line:
            f'[BOM_{mask_key}] ' + index_line[1] + ' [EOS]' if index_line[0] % 2 == 1
            else f'[BOM_{mask_key}] ' + index_line[1],
            enumerate(lines)
        )
        words = itertools.chain.from_iterable(map(lambda line: line.split(' '), lines))
        words = filter(lambda word: len(word) > 0, words)
        words = list(words)
        return words

    def load_words(self, corpus_dir):
        words_by_poet = {}
        for filename in os.listdir(corpus_dir):
            with open(os.path.join(corpus_dir, filename)) as f:
                poet_name = filename.split('_')[0]
                lines = f.readlines()
                words_by_poet[poet_name] = self.preprocess_lines(lines, poet_name)
        return words_by_poet

    def load_vocabulary(self):
        with open(self.vocab_path) as f:
            vocabulary = f.readlines()
        vocabulary = [word.strip() for word in vocabulary]
        return vocabulary

    @property
    def all_poets(self):
        return self.words_by_poet.keys()

    @property
    def poet(self):
        return self._poet

    @poet.setter
    def poet(self, poet):
        self._poet = poet
        if poet == 'all':
            self.words = list(itertools.chain.from_iterable(self.words_by_poet.values()))
        else:
            self.words = self.words_by_poet[poet]
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def __len__(self):
        return len(self.words_indexes) - self.config.sequence_length

    def __getitem__(self, index):
        tensors = (
            torch.tensor(self.words_indexes[index:index + self.config.sequence_length]).to(self.device),
            torch.tensor(self.words_indexes[index + 1:index + self.config.sequence_length + 1]).to(self.device),
        )
        return tensors


dataset = PoemDataset(config, device=torch.device('cpu'), poet='all', corpus_dir=corpus_dir, vocab_path=vocab_path)

for i in range(10):
    print(dataset[i])

38590

In [None]:
# run this cell to generate new vocabulary

all_words = list(itertools.chain.from_iterable(dataset.words_by_poet.values()))
word_counts = Counter(all_words)
vocab = sorted(list(word_counts))
with open(vocab_path, 'w') as f:
    for word in vocab:
        f.write(word + '\n')

In [66]:
def train(dataset, model, config, checkpoint_path='../data/checkpoints', max_epochs=10, ):
    if wandb_active:
        wandb.watch(model)
    model.train()

    dataloader = DataLoader(dataset, batch_size=config.batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    print({'batch_count': len(dataloader), 'epoch_count': max_epochs})
    for epoch in range(max_epochs):
        state_h, state_c = model.init_state(config.sequence_length)
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1, 2), y)

            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()
            print({'epoch': epoch, 'batch': batch, 'loss': loss.item()})
            if wandb_active and batch % config.log_interval == 0:
                wandb.log({"loss": loss})
        try:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, os.path.join(checkpoint_path, f'{run_name}_checkpoint_{epoch}_{time.time()}.pt'))
        except:
            pass

In [67]:
def predict(dataset, model, text, max_predict_length=12):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    i = 0
    while words[-1] != '[EOS]' and i < max_predict_length:
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])
        i += 1
    return words

# Model Usage

## fresh model training

In [68]:
model = Model(config, device)
dataset = PoemDataset(config, device=device, poet='all', corpus_dir=corpus_dir, vocab_path=vocab_path)

In [69]:
# train on generic dataset
train(dataset, model, config, checkpoint_path=checkpoints_dir)
if wandb_active:
    wandb.finish()

{'batch_count': 4789, 'epoch_count': 10}
{'epoch': 0, 'batch': 0, 'loss': 10.556665420532227}
{'epoch': 0, 'batch': 1, 'loss': 10.528473854064941}
{'epoch': 0, 'batch': 2, 'loss': 10.473825454711914}
{'epoch': 0, 'batch': 3, 'loss': 10.087091445922852}
{'epoch': 0, 'batch': 4, 'loss': 9.208667755126953}
{'epoch': 0, 'batch': 5, 'loss': 8.61754035949707}
{'epoch': 0, 'batch': 6, 'loss': 8.090521812438965}


KeyboardInterrupt: 

In [None]:
model_ferdousi = copy.deepcopy(model)
dataset.poet = 'ferdousi'
run_name = 'ferdousi_fine_tune'
if wandb_active:
    wandb.init(project=project_name, name=run_name, reinit=True)
train(dataset, model_ferdousi, config, checkpoint_path=checkpoints_dir)
if wandb_active:
    wandb.finish()

In [None]:
model_hafez = copy.deepcopy(model)
dataset.poet = 'hafez'
run_name = 'hafez_fine_tune'
if wandb_active:
    wandb.init(project=project_name, name=run_name, reinit=True)
train(dataset, model_hafez, config, checkpoint_path=checkpoints_dir)
if wandb_active:
    wandb.finish()

In [None]:
model_moulavi = copy.deepcopy(model)
dataset.poet = 'moulavi'
run_name = 'moulavi_fine_tune'
if wandb_active:
    wandb.init(project=project_name, name=run_name, reinit=True)
train(dataset, model_moulavi, config, checkpoint_path=checkpoints_dir)
if wandb_active:
    wandb.finish()

## load from checkpoint

In [None]:
model = Model(config, device)
chechkpoint = torch.load('/content/drive/MyDrive/NLP Class/checkpoints/model_checkpoint_1654448961.97679.pt',
                         map_location=torch.device('cpu'))
print(chechkpoint['epoch'])
model.load_state_dict(chechkpoint['model_state_dict'])

9


<All keys matched successfully>

In [47]:
print('\n'.join(predict(dataset, model, text='[BOM_ferdousi] توانا بود هر که')))

[BOM_ferdousi]
توانا
بود
هر
که
جانند
مستوی
سنگیان
پنداشت
مالامالیم
انگاشتیم
مقراضی
خضرم
چفسیده
التقا
شهرهار
خاکا
