In [1]:
from torch import nn
import torch

# Exercise 1

In [2]:
# Toy dimensions for the Exercise 1 RNN.
input_size, hidden_size = 2, 4  # features per time step, hidden units
batch, seq_len = 1, 5           # sequences per batch, time steps per sequence

In [3]:
# Single-layer RNN built from the Exercise 1 dimensions.
ex1_net = nn.RNN(input_size=input_size, hidden_size=hidden_size)

In [8]:
# Random input sequence laid out as (seq_len, batch, input_size).
input = torch.rand(seq_len, batch, input_size)
input.size()

torch.Size([5, 1, 2])

In [9]:
# Random initial hidden state laid out as (1, batch, hidden_size).
state = torch.rand(1, batch, hidden_size)
state.size()

torch.Size([1, 1, 4])

In [10]:
# Feed the whole sequence through the RNN, starting from `state`, and show
# the shapes of the returned output and final hidden state.
output, state_new = ex1_net(input, hx=state)
tuple(t.shape for t in (output, state_new))

(torch.Size([5, 1, 4]), torch.Size([1, 1, 4]))

In [11]:
# Display the RNN output tensor returned by the forward pass above.
output

tensor([[[-0.3161,  0.1390,  0.7439, -0.6749]],

        [[ 0.0449,  0.4974,  0.5711, -0.8761]],

        [[ 0.2063,  0.4397,  0.6833, -0.8539]],

        [[ 0.0434,  0.4223,  0.6619, -0.8188]],

        [[-0.0951,  0.5377,  0.7382, -0.8407]]], grad_fn=<StackBackward0>)

# Exercise 2

## Data

In [28]:
import collections
import os
import random
import re
from typing import Literal

import requests

def download(url: str) -> str:
    """Download ``url`` into ``../data`` once and return the local file path.

    The local file name is the last path segment of ``url``. If a file with
    that name already exists it is reused and no network request is made.

    Args:
        url: HTTP(S) address of the file to fetch.

    Returns:
        Path of the local copy (``'../data/<name>'``).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeout.
    """
    data_dir = '../data/'
    file_path = data_dir + url.split('/')[-1]
    if os.path.exists(file_path):
        return file_path

    print(f'Downloading {file_path} from {url}...')
    # A fresh checkout may not contain the data directory yet.
    os.makedirs(data_dir, exist_ok=True)
    # Write under a temporary name first so that an interrupted or failed
    # download never leaves a truncated file that later runs would reuse.
    partial_path = file_path + '.part'
    with requests.get(url, stream=True, verify=True, timeout=30) as res:
        # Fail loudly instead of caching an error page as the dataset.
        res.raise_for_status()
        with open(partial_path, 'wb') as f:
            for chunk in res.iter_content(chunk_size=1 << 16):
                f.write(chunk)
    os.replace(partial_path, file_path)
    return file_path

def read_time_machine() -> list[str]:
    """Return the Time Machine text as cleaned, lowercase, non-empty lines."""
    path = download('http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt')
    cleaned: list[str] = []
    with open(path, 'r') as handle:
        for raw in handle:
            # Collapse every run of non-letters into a single space, then
            # trim the ends and lowercase what is left.
            text = re.sub('[^A-Za-z]+', ' ', raw).strip().lower()
            if text:
                cleaned.append(text)
    return cleaned

def tokenize(lines: list[str], token: Literal['word', 'char']='word') -> list[str]:
    """Turn text lines into one flat list of word or character tokens.

    With ``token == 'word'`` every line is split on whitespace; any other
    value splits every line into its individual characters.
    """
    split_line = str.split if token == 'word' else list
    flat: list[str] = []
    for line in lines:
        flat.extend(split_line(line))
    return flat


In [None]:

class Vocab:
    """Vocabulary for text."""
    def __init__(self, tokens: list[str]=[], min_freq=0, reserved_tokens: list[str]=[]):
        # Sort according to frequencies
        counter = collections.Counter(tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                   reverse=True)
        # The index for the unknown token is 0
        self.idx_to_token = ['<unk>'] + reserved_tokens + [token for token, freq in self._token_freqs if freq >= min_freq]
        self.token_to_idx =  {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self) -> int:
        return len(self.idx_to_token)

    def __getitem__(self, tokens: str | list[str]) -> int | list[int]:
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.token_to_idx.get(token, self.unk) for token in tokens]

    def to_tokens(self, indices: int | list[int]) -> str | list[str]:
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self) -> Literal[0]:  # Index for the unknown token
        return 0

    @property
    def token_freqs(self) -> list[tuple[str, int]]:  # Token frequencies
        return self._token_freqs


class SeqDataLoader:
    """An iterator over character-level (input, target) mini-batches.

    The Time Machine text is tokenised into characters, mapped to vocabulary
    indices and truncated to ``max_tokens``. Each iteration pass yields
    ``(X, Y)`` tensor pairs of shape ``(batch_size, num_steps)`` where ``Y``
    is ``X`` shifted one position to the right in the corpus.
    """
    def __init__(self, batch_size: int, num_steps: int, max_tokens: int=10000):
        self.batch_size = batch_size
        self.num_steps = num_steps
        self.vocab = self.generate_vocab()
        self.corpus = self.generate_corpus(max_tokens)

    def _char_tokens(self) -> list[str]:
        """Return the flat character token list, loading the text only once."""
        cached = getattr(self, '_tokens', None)
        if cached is None:
            cached = tokenize(read_time_machine(), 'char')
            self._tokens = cached
        return cached

    def generate_vocab(self) -> Vocab:
        """Build the character vocabulary of the dataset."""
        return Vocab(self._char_tokens())

    def generate_corpus(self, max_tokens):
        """Return the first ``max_tokens`` characters as vocabulary indices."""
        # `tokenize` already returns a flat token list, so a single lookup per
        # token is enough. There are no nested "lines" left to iterate over.
        return self.vocab[self._char_tokens()[:max_tokens]]

    def __iter__(self):
        """Generate a mini-batch of subsequences using sequential partitioning."""
        # Start with a random offset to partition a sequence
        # (`random.randint` is inclusive on both ends).
        offset = random.randint(0, self.num_steps)
        # Largest multiple of `batch_size` that still leaves one token of
        # look-ahead for the shifted targets.
        num_tokens = ((len(self.corpus) - offset - 1) // self.batch_size) * self.batch_size
        Xs = torch.tensor(self.corpus[offset: offset + num_tokens])
        Ys = torch.tensor(self.corpus[offset + 1: offset + 1 + num_tokens])
        # One row per batch element; consecutive mini-batches continue each row.
        Xs, Ys = Xs.reshape(self.batch_size, -1), Ys.reshape(self.batch_size, -1)
        num_batches = Xs.shape[1] // self.num_steps
        for i in range(0, self.num_steps * num_batches, self.num_steps):
            X = Xs[:, i: i + self.num_steps]
            Y = Ys[:, i: i + self.num_steps]
            yield X, Y

In [13]:
# Mini-batch layout for Exercise 2: sequences per batch, time steps per sequence.
batch, steps = 30, 10

In [40]:
# Build the character-level data iterator and expose its vocabulary.
train_iter = SeqDataLoader(batch_size=batch, num_steps=steps)
vocab = train_iter.vocab

## Network

In [None]:
# Exercise 2 network hyperparameters.
hidden_size = 32
reccurent_layers = 2
# One input feature per vocabulary entry.
input_size = len(vocab)

In [None]:
class RNNModel(nn.Module):
    """Character-level language model built around a recurrent layer.

    A bare ``nn.RNN`` expects float inputs of shape
    ``(steps, batch, input_size)``, returns hidden states rather than
    vocabulary logits, and has no ``begin_state`` method. This wrapper
    one-hot encodes integer index batches, projects the hidden states to
    vocabulary logits and provides ``begin_state``.
    """
    def __init__(self, rnn_layer: nn.RNNBase, vocab_size: int):
        super().__init__()
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = rnn_layer.hidden_size
        self.num_directions = 2 if rnn_layer.bidirectional else 1
        self.linear = nn.Linear(self.num_hiddens * self.num_directions, vocab_size)

    def forward(self, inputs: torch.Tensor, state):
        """Map index batch ``(batch, steps)`` to logits ``(steps * batch, vocab_size)``."""
        # Time-major one-hot encoding: (steps, batch, vocab_size).
        X = nn.functional.one_hot(inputs.T.long(), self.vocab_size).to(torch.float32)
        Y, state = self.rnn(X, state)
        # Flatten time and batch so each row is one prediction.
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, batch_size: int = 1, device=None):
        """Return an all-zero initial hidden state (a tuple for ``nn.LSTM``)."""
        shape = (self.num_directions * self.rnn.num_layers, batch_size, self.num_hiddens)
        if isinstance(self.rnn, nn.LSTM):
            return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))
        return torch.zeros(shape, device=device)


ex2_net = RNNModel(nn.RNN(input_size, hidden_size, reccurent_layers), vocab_size=input_size)

## Training

In [None]:
def grad_clipping(net: nn.Module, theta: float):
    """Rescale gradients in place so that their global L2 norm is at most ``theta``.

    Only trainable parameters that actually hold a gradient take part, so
    parameters unused in the last backward pass (``grad is None``) are
    skipped. Nothing happens when no gradient exists.

    Args:
        net: Module whose parameter gradients are clipped.
        theta: Maximum allowed global gradient norm.
    """
    grads = [p.grad for p in net.parameters()
             if p.requires_grad and p.grad is not None]
    if not grads:
        return
    norm = torch.sqrt(sum(torch.sum(g ** 2) for g in grads))
    if norm > theta:
        scale = theta / norm
        for g in grads:
            # `mul_` also handles 0-dim gradients, which `g[:]` cannot slice.
            g.mul_(scale)

In [None]:
import math
from typing import Callable

# Type alias for a loss function: (predictions, targets) -> loss tensor.
# The builtin `callable` is a predicate function, not a type, so it cannot
# serve as an annotation.
Loss = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def train_epoch(net: nn.Module, train_iter: SeqDataLoader, loss, optimizer: torch.optim.Optimizer, device):
    """Train a net within one epoch and return the training perplexity.

    Args:
        net: Model called as ``net(X, state)`` and returning
            ``(predictions, new_state)``. It must also provide
            ``begin_state(batch_size=..., device=...)``.
            NOTE: a plain ``torch.nn.RNN`` does not define ``begin_state``.
        train_iter: Iterable yielding ``(X, Y)`` index tensors.
        loss: Callable ``loss(predictions, targets)`` returning a tensor.
        optimizer: Optimizer updating the parameters of ``net``.
        device: Device the batches are moved to.

    Returns:
        ``exp`` of the average per-token loss over the epoch.

    Raises:
        ValueError: If ``train_iter`` yields no batch.
    """
    state = None
    # Sum of training loss, no. of tokens
    total_loss: float = 0.0
    total_tokens: int = 0
    for X, Y in train_iter:
        if state is None:
            # Initialize `state` when it is the first iteration
            state = net.begin_state(batch_size=X.shape[0], device=device)
        elif isinstance(state, tuple):
            # `state` is a tuple of tensors for `nn.LSTM`
            for s in state:
                s.detach_()
        else:
            # `state` is a single tensor (e.g. `nn.RNN`, `nn.GRU`); detach it so
            # backpropagation stops at the mini-batch boundary.
            state.detach_()
        # Time-major flattening of the targets: (batch, steps) -> (steps * batch,)
        y = Y.T.reshape(-1)
        X, y = X.to(device), y.to(device)
        y_hat, state = net(X, state)
        l = loss(y_hat, y.long()).mean()
        optimizer.zero_grad()
        l.backward()
        grad_clipping(net, 1)
        optimizer.step()
        # `l` is a mean, so weight it by the token count before summing.
        total_loss += float(l * y.numel())
        total_tokens += y.numel()

    if total_tokens == 0:
        raise ValueError('train_iter produced no batches; cannot compute perplexity')
    return math.exp(total_loss / total_tokens)


In [None]:
def predict(prefix: str, num_preds: int, net: nn.Module, vocab: Vocab, device) -> str:
    """Generate new characters following the `prefix`.

    The characters of ``prefix`` are fed first to warm up the hidden state.
    Then ``num_preds`` further characters are chosen greedily (argmax), each
    one fed back as the next input.

    Args:
        prefix: Non-empty seed text; an empty string raises ``IndexError``.
        num_preds: Number of characters to generate after the prefix.
        net: Model called as ``net(X, state)`` that also provides
            ``begin_state(batch_size=..., device=...)``.
        vocab: Vocabulary used to map characters to indices and back.
        device: Device on which the input tensors are created.

    Returns:
        The prefix followed by the generated characters.
    """
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]

    def get_input() -> torch.Tensor:
        # The most recent token as a (batch=1, steps=1) index tensor.
        return torch.tensor([outputs[-1]], device=device).reshape((1, 1))

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for y in prefix[1:]:  # Warm-up period
            _, state = net(get_input(), state)
            outputs.append(vocab[y])
        for _ in range(num_preds):  # Predict `num_preds` steps
            y, state = net(get_input(), state)
            outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

In [None]:
def train(net: nn.Module, train_iter: SeqDataLoader, vocab: Vocab, lr: float, num_epochs: int, device):
    """Train a model with SGD and cross-entropy loss.

    Every 10th epoch a sample continuation of ``'time traveller'`` is printed
    and that epoch's perplexity is recorded.

    Args:
        net: Model to train (see ``train_epoch`` for what it must provide).
        train_iter: Data iterator yielding ``(X, Y)`` batches.
        vocab: Vocabulary used for the sample predictions.
        lr: Learning rate for SGD.
        num_epochs: Number of passes over ``train_iter``.
        device: Device used for the batches and the predictions.

    Returns:
        List of perplexities, one per 10 completed epochs.
    """
    loss = nn.CrossEntropyLoss()
    perplexities = []
    # Initialize
    optimizer = torch.optim.SGD(net.parameters(), lr)
    # Keeps the final report well-defined when `num_epochs` is 0.
    ppl = float('nan')
    # Train and predict
    for epoch in range(num_epochs):
        ppl = train_epoch(net, train_iter, loss, optimizer, device)
        if (epoch + 1) % 10 == 0:
            print(predict('time traveller', 50, net, vocab, device))
            perplexities.append(ppl)
    print(f'perplexity {ppl:.1f}, device {str(device)}')
    print(predict('time traveller', 50, net, vocab, device))
    print(predict('traveller', 50, net, vocab, device))

    return perplexities

In [None]:
num_epochs = 200
lr = 1.5
# Pick the training device here: use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The model has to live on the same device as the batches it receives.
ex2_net = ex2_net.to(device)
perplexities = train(ex2_net, train_iter, vocab, lr, num_epochs, device) #1 min