In [1]:
import torch
from torch import nn

In [2]:
import collections
import os
import random
import re
from typing import Literal

import requests

def download(url: str) -> str:
    """Download a file into ``../data`` and return the local filename.

    The local name is the last path component of ``url``. If a file with
    that name already exists it is reused and no request is made.

    Args:
        url: Address of the file to fetch.

    Returns:
        Path of the local copy (``'../data/<name>'``), relative to the
        current working directory.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    file_path = '../data/' + url.split('/')[-1]
    if os.path.exists(file_path):
        return file_path

    # The target directory may not exist yet (e.g. on a fresh checkout).
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    print(f'Downloading {file_path} from {url}...')
    # Write under a temporary name and rename on success, so a failed or
    # interrupted transfer can never be mistaken for a cached file by the
    # existence check above.
    partial_path = file_path + '.part'
    res = requests.get(url, stream=True, verify=True, timeout=30)
    try:
        # Without this an error page would be saved as if it were the data.
        res.raise_for_status()
        with open(partial_path, 'wb') as f:
            for chunk in res.iter_content(chunk_size=1 << 16):
                f.write(chunk)
        os.replace(partial_path, file_path)
    finally:
        res.close()
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return file_path

def read_time_machine() -> list[str]:
    """Load the time machine dataset into a list of text lines.

    Each line is lower-cased, every run of characters other than ASCII
    letters is collapsed into a single space, and surrounding whitespace
    is stripped. Lines that end up empty are dropped.

    Returns:
        The cleaned, non-empty lines in file order.
    """
    file = download('http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt')
    # Decode with a fixed codec instead of the platform default so the result
    # is the same on every machine. Undecodable bytes become U+FFFD, which the
    # regex below turns into a space like any other non-letter.
    with open(file, 'r', encoding='utf-8', errors='replace') as f:
        lines = f.readlines()

    lines_transformed = [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]
    return [line for line in lines_transformed if line]

def tokenize(lines: list[str], token: Literal['word', 'char']='word') -> list[str]:
    """Flatten text lines into one list of word or character tokens.

    With ``token='word'`` each line is split on whitespace; with any other
    value each line is broken into its individual characters.
    """
    flat: list[str] = []
    for line in lines:
        pieces = line.split() if token == 'word' else list(line)
        flat.extend(pieces)
    return flat
    
def count_corpus(tokens: list[str]) -> collections.Counter[str]:
    """Tally how many times each token occurs."""
    frequencies: collections.Counter[str] = collections.Counter()
    frequencies.update(tokens)
    return frequencies

In [3]:
class Vocab:
    """Vocabulary for text."""
    def __init__(self, tokens: list[str]=[], min_freq=0, reserved_tokens: list[str]=[]):
        # Sort according to frequencies
        counter = count_corpus(tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                   reverse=True)
        # The index for the unknown token is 0
        self.idx_to_token = ['<unk>'] + reserved_tokens + [token for token, freq in self._token_freqs if freq >= min_freq]
        self.token_to_idx =  {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self) -> int:
        return len(self.idx_to_token)

    def __getitem__(self, tokens: str | list[str]) -> int | list[int]:
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.token_to_idx.get(token, self.unk) for token in tokens]

    def to_tokens(self, indices: int | list[int]) -> str | list[str]:
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self) -> Literal[0]:  # Index for the unknown token
        return 0

    @property
    def token_freqs(self) -> list[tuple[str, int]]:  # Token frequencies
        return self._token_freqs


class SeqDataLoader:
    """An iterator to load sequence data.

    Builds a character-level vocabulary and an index-encoded corpus from the
    time machine dataset. Iterating yields ``(X, Y)`` mini-batches of shape
    ``(batch_size, num_steps)``, where ``Y`` is ``X`` shifted by one token.
    """
    def __init__(self, batch_size: int, num_steps: int, max_tokens: int=10000):
        """
        Args:
            batch_size: Number of subsequences per mini-batch.
            num_steps: Number of tokens in every subsequence.
            max_tokens: The corpus is cut to ``corpus[:max_tokens]``.
        """
        self.batch_size = batch_size
        self.num_steps = num_steps
        self.vocab = self.generate_vocab()
        self.corpus = self.generate_corpus(max_tokens)

    def _load_tokens(self) -> list[str]:
        """Return the dataset's character tokens, reading the file only once."""
        tokens = getattr(self, '_tokens', None)
        if tokens is None:
            tokens = self._tokens = tokenize(read_time_machine(), 'char')
        return tokens

    def generate_vocab(self) -> 'Vocab':
        """Build the vocabulary over the character tokens of the dataset."""
        return Vocab(self._load_tokens())

    def generate_corpus(self, max_tokens: int) -> list[int]:
        """Encode the tokens as vocabulary indices, cut to ``[:max_tokens]``.

        Uses ``self.vocab``, so the vocabulary must have been generated first.
        """
        # The token list is already flat, so it is indexed in one call.
        corpus = self.vocab[self._load_tokens()]
        return corpus[:max_tokens]

    def __iter__(self):
        """Generate a mini-batch of subsequences using sequential partitioning."""
        # Start with a random offset to partition a sequence
        # (random.randint includes both end points).
        offset = random.randint(0, self.num_steps)
        num_tokens = ((len(self.corpus) - offset - 1) // self.batch_size) * self.batch_size
        if num_tokens <= 0:
            # Not even one token per batch row: yield nothing, as already
            # happens when a row is shorter than num_steps, instead of
            # failing inside reshape on an empty tensor.
            return
        Xs = torch.tensor(self.corpus[offset: offset + num_tokens])
        # Targets are the inputs shifted one token to the right.
        Ys = torch.tensor(self.corpus[offset + 1: offset + 1 + num_tokens])
        Xs, Ys = Xs.reshape(self.batch_size, -1), Ys.reshape(self.batch_size, -1)
        num_batches = Xs.shape[1] // self.num_steps
        for i in range(0, self.num_steps * num_batches, self.num_steps):
            X = Xs[:, i: i + self.num_steps]
            Y = Ys[:, i: i + self.num_steps]
            yield X, Y

In [4]:
# Mini-batch size and subsequence length passed to the data loader.
batch, steps = 50, 10

In [5]:
# The loader builds its own vocabulary; expose it for the cells below.
train_iter = SeqDataLoader(batch_size=batch, num_steps=steps)
vocab = train_iter.vocab

In [None]:
# The input width equals the vocabulary size.
vocab_size = len(vocab)
input_size = vocab_size
hidden_size = 64

In [None]:
class RNNModel(nn.Module):
    """Language model: one-hot input -> recurrent layer -> linear decoder.

    torch's ``nn.RNN`` takes integer sizes, not a layer, so the LSTM layer is
    wrapped in this module instead.

    NOTE(review): the ``forward(inputs, state) -> (logits, state)`` and
    ``begin_state`` interface is an assumption; `train` is not defined in the
    visible cells, so confirm it matches what `train` expects.
    """
    def __init__(self, rnn_layer: nn.Module, vocab_size: int):
        super().__init__()
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        # Assumes a unidirectional recurrent layer (true for the LSTM below).
        self.linear = nn.Linear(rnn_layer.hidden_size, vocab_size)

    def forward(self, inputs: torch.Tensor, state=None):
        # (batch, steps) indices -> (steps, batch, vocab) one-hot floats,
        # the layout a recurrent layer built with batch_first=False consumes.
        X = nn.functional.one_hot(inputs.T.long(), self.vocab_size).to(torch.float32)
        Y, state = self.rnn(X, state)
        # One row of logits per (step, batch) position.
        return self.linear(Y.reshape(-1, Y.shape[-1])), state

    def begin_state(self, device, batch_size: int=1):
        # An LSTM state is the pair (hidden state, cell state).
        shape = (self.rnn.num_layers, batch_size, self.rnn.hidden_size)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))


lstm_layer = nn.LSTM(input_size, hidden_size)
net = RNNModel(lstm_layer, vocab_size)

In [None]:
# A positional argument may not follow keyword arguments, so `device` is
# passed by keyword.
# NOTE(review): `train` and `device` are not defined in the visible cells;
# confirm both exist before this cell and that the parameter is named `device`.
perplexities = train(net, train_iter, vocab, lr=1, num_epochs=100, device=device)