In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from rnn_dataset import Vocabulary
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence  # padding of every batch

In [2]:
class TextDataset(Dataset):
    def __init__(self, root_dir, filename, freq_threshold=1):
        self.root_dir = root_dir
        self.sentences = []
        with open(os.path.join(root_dir, filename)) as f:
            line = True
            while line:
                line = f.readline()
                if(line != "\n"):
                    self.sentences.append(line)

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.sentences)

        self.vocab_size = len(self.vocab)

    def __len__(self): return len(self.sentences)

    def one_hot_tensor(self, idx):
        tensor = np.zeros(self.vocab_size)
        tensor[idx] = 1
        return tensor

    def __getitem__(self, idx):
        encoded_text = [self.one_hot_tensor(self.vocab.stoi["<SOS>"])]
        encoded_text += [self.one_hot_tensor(encoded_token)
                         for encoded_token in self.vocab.encode(self.sentences[idx])]
        encoded_text.append(self.one_hot_tensor(self.vocab.stoi["<EOS>"]))

        return torch.tensor(encoded_text).float()


class CollateBatch:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        texts = [item for item in batch]
        texts = pad_sequence(texts, batch_first=False,
                             padding_value=self.pad_idx)

        return texts


def get_loader(root_dir, filename, batch_size=10, num_workers=1, shuffle=True, pin_memory=True):
    dataset = TextDataset(root_dir, filename)
    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers,
                        shuffle=shuffle, pin_memory=pin_memory, collate_fn=CollateBatch(pad_idx=pad_idx))
    return loader, dataset

In [4]:
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.whh = nn.Linear(hidden_size, hidden_size)
        self.wxh = nn.Linear(input_size, hidden_size)
    
    def forward(self, x, hidden_state, prev_state):
        state = self.whh(hidden_state) + self.wxh(x)
        filter_state = torch.sigmoid(state)
        new_state = prev_state * filter_state + torch.tanh(state) * filter_state
        hidden_state = torch.tanh(new_state) * filter_state
        return hidden_state, new_state

In [11]:
class LSTMGenerator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMGenerator, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.lstm = LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def init_hidden_state(self, batch_size):
        return torch.zeros(batch_size, self.hidden_size)
    
    def forward(self, x, hidden_state, prev_state):
        for i in range(x.shape[0]):
            hidden_state, prev_state = self.lstm(x[i], hidden_state, prev_state)
        
        return F.softmax(self.fc(hidden_state), dim=1)

    def fit(self, dataset, batch_size, epochs, lr=0.001):
        optimizer = torch.optim.SGD(self.parameters(), lr=lr)
        criterion = torch.nn.MSELoss() # ignore_index=pad_idx?
        
        for epoch in range(epochs):
            for k, (sentence) in enumerate(dataset):
                if(len(sentence) < 4): continue
                total_loss = 0
                for idx in range(2,len(sentence)):
                    h0 = self.init_hidden_state(batch_size=batch_size)
                                        
                    # forward
                    output = self.forward(sentence[0:idx], h0, h0)
                    loss = criterion(output, sentence[idx])

                    # backward
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1)

                    # gradient descent or Adam step
                    optimizer.step()

                    total_loss += loss
            
                if k%20 == 0: print(f"epoch: [{epoch+1} / {epochs}] | sentence: [{k} / {len(dataset)}] | total loss: {total_loss}")

In [28]:
def generate(generator, vocab, sentence, horizon=1):
    vocab_size = len(vocab)
    
    def one_hot_tensor(idx):
        tensor = [0] * vocab_size
        tensor[idx] = 1
        return tensor
    
    # encode input sentence
    encoded_text = []
    encoded_text.append([one_hot_tensor(vocab.stoi["<SOS>"])])
    encoded_text += [[one_hot_tensor(encoded_token)]
                     for encoded_token in vocab.encode(sentence)]

    encoded_text = torch.tensor(encoded_text).float()
    print(encoded_text.shape)
    h0 = generator.init_hidden_state(batch_size=1)

    new_words = []
    for _ in range(horizon):
        output = generator(encoded_text, h0, h0)
        word_index = torch.argmax(output).item()
        new_tensor = torch.tensor([[one_hot_tensor(word_index)]]).float()

        encoded_text = torch.cat((encoded_text, new_tensor), dim=0)
        new_words.append(vocab[word_index])
        
    return sentence + " " + " ".join(new_words)

In [36]:
BATCH_SIZE=1
HIDDEN_SIZE=124
LR=0.005
NUM_EPOCHS=20

In [34]:
dataloader,dataset = get_loader("../data/", "text.txt", batch_size=BATCH_SIZE)

In [37]:
generator = LSTMGenerator(input_size=dataset.vocab_size, hidden_size=HIDDEN_SIZE,
                                output_size=dataset.vocab_size)

In [38]:
generator.train()
generator.fit(dataloader, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS, lr=LR)

epoch: [1 / 20] | sentence: [0 / 67] | total loss: 0.047557100653648376
epoch: [1 / 20] | sentence: [20 / 67] | total loss: 0.04884263500571251
epoch: [1 / 20] | sentence: [40 / 67] | total loss: 0.0359894335269928
epoch: [1 / 20] | sentence: [60 / 67] | total loss: 0.08611826598644257
epoch: [2 / 20] | sentence: [0 / 67] | total loss: 0.021850328892469406
epoch: [2 / 20] | sentence: [20 / 67] | total loss: 0.02956286072731018
epoch: [2 / 20] | sentence: [40 / 67] | total loss: 0.061696507036685944
epoch: [2 / 20] | sentence: [60 / 67] | total loss: 0.02056482806801796
epoch: [3 / 20] | sentence: [0 / 67] | total loss: 0.08740343898534775
epoch: [3 / 20] | sentence: [20 / 67] | total loss: 0.08611826598644257
epoch: [3 / 20] | sentence: [40 / 67] | total loss: 0.012854075990617275
epoch: [3 / 20] | sentence: [60 / 67] | total loss: 0.07198044657707214
epoch: [4 / 20] | sentence: [0 / 67] | total loss: 0.026991358026862144
epoch: [4 / 20] | sentence: [20 / 67] | total loss: 0.0218503288

In [39]:
generator.eval()
generate(generator, dataset.vocab, "Do not think that", horizon=10)

torch.Size([5, 1, 777])


'Do not think that or made curation onslaught or onslaught or onslaught or onslaught'