In [None]:
import os

import torch 
import torch.nn as nn
from torch import Tensor

In [None]:
from typing import Tuple


class RNN(nn.Module):
    """Single-layer recurrent cell emitting log-probabilities at each step.

    At every step the current input is concatenated with the previous hidden
    state. A linear layer maps that to the new hidden state, and a second
    linear layer plus ``LogSoftmax`` over dim 1 turns the new hidden state
    into the output.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()

        self.hidden_size = hidden_size

        # Layers are created in the same order as before (i2h, then h2o) so that
        # seeded parameter initialisation is unaffected.
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input: Tensor, hidden: Tensor) -> Tuple[Tensor, Tensor]:
        """Advance the cell by one step.

        Args:
            input: current input; concatenated with ``hidden`` along dim 1.
            hidden: previous hidden state, e.g. from ``initHidden()``.

        Returns:
            ``(log_probs, next_hidden)``: log-softmax scores over dim 1 and
            the updated hidden state.
        """
        # NOTE(review): no nonlinearity (e.g. tanh) is applied to the new hidden
        # state, so the recurrence is purely linear -- confirm this is intended.
        next_hidden = self.i2h(torch.cat([input, hidden], dim=1))
        log_probs = self.softmax(self.h2o(next_hidden))
        return log_probs, next_hidden

    def initHidden(self) -> Tensor:
        """Return an all-zero hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros((1, self.hidden_size))


In [None]:
class Dictionary:
    """Two-way mapping between word strings and dense integer ids (0, 1, 2, ...)."""

    def __init__(self):
        self.word2idx = {}  # word -> id
        self.idx2word = []  # id -> word; a word's list position is its id

    def add_word(self, word):
        """Assign ``word`` the next free id; does nothing if it is already known."""
        if word in self.word2idx:
            return
        new_id = len(self.idx2word)
        self.idx2word.append(word)
        self.word2idx[word] = new_id

    def __len__(self):
        """Number of distinct words registered so far."""
        return len(self.word2idx)


class Corpus(object):
    """Word-level corpus built from ``train.txt``, ``test.txt`` and ``valid.txt``.

    All three splits share a single ``Dictionary``, so a word has the same id in
    every split.

    Args:
        data_path: directory containing the three split files.
    """

    def __init__(self, data_path):
        self.dictionary = Dictionary()
        # Splits are read in this order (train, test, valid). The order determines
        # the ids of words first seen outside the training split, so keep it stable.
        self.train = self.get_data(os.path.join(data_path, "train.txt"))
        self.test = self.get_data(os.path.join(data_path, "test.txt"))
        self.val = self.get_data(os.path.join(data_path, "valid.txt"))

    def get_data(self, path, encoding="utf-8"):
        """Tokenize a text file into a 1-D tensor of word ids.

        Each line is split on whitespace and terminated with an ``'<eos>'``
        token. Every token is registered in ``self.dictionary`` on the way.

        Args:
            path: text file to read.
            encoding: file encoding. Defaults to UTF-8 rather than the
                platform-dependent locale default that ``open()`` would use.

        Returns:
            A ``torch.long`` tensor of shape ``(num_tokens,)``; empty if the
            file is empty.
        """
        ids = []
        with open(path, "r", encoding=encoding) as f:
            for line in f:
                for word in line.split() + ["<eos>"]:
                    # add_word leaves existing entries untouched, so looking the id
                    # up right away gives the same result as a separate second pass.
                    self.dictionary.add_word(word)
                    ids.append(self.dictionary.word2idx[word])
        return torch.tensor(ids, dtype=torch.long)


def batchify(data: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Arrange a 1-D token stream into ``batch_size`` parallel columns.

    Trailing tokens that do not fill a complete row are dropped. Column ``j`` of
    the result is the ``j``-th contiguous chunk of ``data``.

    Args:
        data: 1-D tensor of token ids.
        batch_size: number of parallel streams (columns).

    Returns:
        A new contiguous tensor of shape ``(len(data) // batch_size, batch_size)``
        that never shares storage with ``data``. If ``data`` is shorter than
        ``batch_size``, the result has zero rows.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    num_batches = data.size(0) // batch_size
    # Trim before copying (cheaper than cloning the whole input). The clone is
    # still needed because contiguous() below can be a no-op (e.g. batch_size == 1),
    # which would otherwise return a view aliasing `data`.
    trimmed = data.narrow(0, 0, num_batches * batch_size).clone()
    # Use an explicit num_batches instead of -1: for an empty `trimmed`, -1 is
    # ambiguous and view() raises.
    return trimmed.view(batch_size, num_batches).t().contiguous()

In [None]:
class RNNLM(nn.Module):
    """LSTM language model: embedding -> (optionally bidirectional) LSTM -> linear layer over the vocabulary.

    Args:
        vocab_size: number of token ids (embedding rows and output classes).
        embed_size: embedding dimension.
        hidden_size: LSTM hidden units per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability applied to the embeddings and to the LSTM output.
        dropout_lstm: dropout between stacked LSTM layers (passed to ``nn.LSTM``).
        bidirectional: whether the LSTM is bidirectional.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout, dropout_lstm, bidirectional):
        super(RNNLM, self).__init__()
        self.dropout = nn.Dropout(dropout, inplace=False)
        self.encoder = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            embed_size, hidden_size, num_layers, dropout=dropout_lstm, bidirectional=bidirectional)

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # Despite its name, this stores the number of directions (1 or 2), not a bool.
        self.bidirectional = 2 if bidirectional else 1
        self.vocab_size = vocab_size

        # A bidirectional LSTM concatenates both directions' features.
        self.fc = nn.Linear(hidden_size * self.bidirectional, vocab_size)

    def init_weights(self, initrange=0.05):
        """Initialise embedding and output weights uniformly in [-initrange, initrange]; zero the output bias."""
        with torch.no_grad():
            self.encoder.weight.uniform_(-initrange, initrange)
            self.fc.bias.zero_()
            self.fc.weight.uniform_(-initrange, initrange)

    def forward(self, x: Tensor, hidden: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Run the model over a batch of token sequences.

        Args:
            x: token ids laid out as ``(seq_len, batch)``. The LSTM uses the
                default ``batch_first=False``.
            hidden: ``(h_0, c_0)`` pair, e.g. from ``init_hidden``.

        Returns:
            ``(scores, (h_n, c_n))``. ``scores`` are unnormalised output-layer
            values of shape ``(batch, seq_len, vocab_size)``; ``(h_n, c_n)`` is
            the final LSTM state.
        """
        embedding = self.dropout(self.encoder(x))
        output, hidden = self.lstm(embedding, hidden)
        output = self.dropout(output)
        # (seq_len, batch, features) -> (batch, seq_len, features)
        output = output.swapaxes(0, 1)
        # nn.Linear acts on the last dim, whose width is hidden_size * num_directions.
        # The previous reshape to hidden_size broke the bidirectional case.
        decoded = self.fc(output)
        return decoded, hidden

    def init_hidden(self, batch_size, device):
        """Return zero ``(h, c)`` tensors of shape ``(num_layers * num_directions, batch_size, hidden_size)`` on ``device``."""
        hidden = torch.zeros(self.num_layers * self.bidirectional, batch_size, self.hidden_size, device=device)
        # Separate storage for the cell state so h and c never alias.
        return hidden, hidden.clone()