In [14]:
import torch
from torch import nn, Tensor
import torch.functional as F
import matplotlib.pyplot as plt

In [15]:
class RNNScratch(nn.Module):
    """A vanilla recurrent neural network implemented from scratch.

    Each time step computes
    ``H_t = tanh(X_t @ W_xh + H_{t-1} @ W_hh + b_h)``.

    Args:
        num_inputs: Feature size of each input time step.
        num_hidden: Size of the hidden state.
        sigma: Scale applied to the standard-normal weight initialisation.
    """

    def __init__(self, num_inputs: int, num_hidden: int, sigma: float = 0.01):
        super().__init__()
        self.num_inputs: int = num_inputs
        self.num_hiddens: int = num_hidden
        # Stored so that models built on top of this RNN (e.g. an output
        # layer) can initialise their own weights with the same scale.
        self.sigma: float = sigma

        self.W_xh = nn.Parameter(torch.randn(num_inputs, num_hidden) * sigma)
        self.W_hh = nn.Parameter(torch.randn(num_hidden, num_hidden) * sigma)
        self.b_h = nn.Parameter(torch.zeros(num_hidden))

    def forward(self, inputs: Tensor, state: "Tensor | None" = None) -> tuple[list[Tensor], Tensor]:
        """Run the recurrence over every time step of ``inputs``.

        Args:
            inputs: Tensor iterated over its first axis (time steps); its
                second axis is the batch size and its last axis must be
                ``num_inputs``, i.e. ``(num_steps, batch_size, num_inputs)``.
            state: Optional initial hidden state of shape
                ``(batch_size, num_hiddens)``. Zeros are used when omitted.

        Returns:
            ``(outputs, state)``: a list with the hidden state after each
            step, and the final hidden state.
        """
        if state is None:
            # Match the input's device and the parameters' dtype so the
            # matmuls below do not mix devices/dtypes.
            state = torch.zeros(
                (inputs.shape[1], self.num_hiddens),
                device=inputs.device,
                dtype=self.W_hh.dtype,
            )

        outputs = []
        for X in inputs:
            # Forwards the input, adds the recurrent contribution of the
            # previous state and the bias, then squashes with tanh; without
            # the nonlinearity the whole network collapses to a linear map.
            state = torch.tanh(X @ self.W_xh + state @ self.W_hh + self.b_h)
            outputs.append(state)
        return outputs, state

In [16]:
# Testing that the dimensions are as we expect
batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
rnn = RNNScratch(num_inputs, num_hiddens)
X = torch.ones((num_steps, batch_size, num_inputs))
outputs, state = rnn(X)

# One hidden state per time step, each of shape (batch_size, num_hiddens),
# and the returned final state has that same shape.
assert len(outputs) == num_steps
assert all(H.shape == (batch_size, num_hiddens) for H in outputs)
assert state.shape == (batch_size, num_hiddens)

In [12]:
# Testing that the state tracking works: running all steps in a single call
# must give the same final state as running one step at a time and feeding
# each returned state back in as the initial state of the next call.
batch_size, num_inputs, num_hidden, num_steps = 2, 16, 32, 2
rnn = RNNScratch(num_inputs, num_hidden)
X = torch.ones((num_steps, batch_size, num_inputs))
outputs, state = rnn(X)

state2 = None  # None -> RNNScratch starts from a zero state
for X2 in X.split(1):  # each X2 has shape (1, batch_size, num_inputs)
    _, state2 = rnn(X2, state2)

assert torch.allclose(state, state2)

In [13]:
class Classifier(nn.Module):
    """Mixin-style base module providing classification loss and accuracy.

    Subclasses are expected to define ``self.net`` if they use
    ``layer_summary``.
    """

    def accuracy(self, Y_hat, Y, averaged=True):
        """Compute the number of correct predictions.

        Args:
            Y_hat: Logits/scores whose last axis indexes classes; all leading
                axes are flattened into one.
            Y: Integer targets; flattened to one dimension.
            averaged: Return the mean accuracy if True, otherwise the
                per-example 0/1 float tensor.

        Defined in :numref:`sec_classification`"""
        Y_hat = Y_hat.reshape(-1, Y_hat.shape[-1])
        preds = Y_hat.argmax(dim=1).to(Y.dtype)
        compare = (preds == Y.reshape(-1)).to(torch.float32)
        return compare.mean() if averaged else compare

    def loss(self, Y_hat, Y, averaged=True):
        """Cross-entropy loss between logits ``Y_hat`` and integer targets ``Y``.

        Leading axes of ``Y_hat`` (all but the last) and all axes of ``Y`` are
        flattened before the loss is computed. Returns the mean loss when
        ``averaged`` is True, otherwise the per-example losses.

        Defined in :numref:`sec_softmax_concise`"""
        Y_hat = Y_hat.reshape(-1, Y_hat.shape[-1])
        Y = Y.reshape(-1)
        # Use torch.nn.functional explicitly: `torch.functional` (imported as
        # F at the top of the notebook) has no `cross_entropy`.
        return nn.functional.cross_entropy(
            Y_hat, Y, reduction='mean' if averaged else 'none')

    def layer_summary(self, X_shape):
        """Print each layer's output shape for a random input of ``X_shape``.

        Requires ``self.net`` to be an iterable of layers.

        Defined in :numref:`sec_lenet`"""
        X = torch.randn(*X_shape)
        for layer in self.net:
            X = layer(X)
            print(layer.__class__.__name__, 'output shape:\t', X.shape)

class RNNLMScratch(Classifier):
    """Language model that projects an RNN's hidden states onto a vocabulary.

    Subclasses ``Classifier`` so that ``self.loss`` / ``self.accuracy``
    (used by ``training_set``) are available.

    Args:
        rnn: Recurrent module exposing ``num_hiddens`` (and optionally
            ``sigma`` for the weight-initialisation scale).
        vocab_size: Number of output classes (vocabulary size).
        lr: Learning rate; only stored here, not used inside this class.
    """

    def __init__(self, rnn: RNNScratch, vocab_size, lr = 0.01):
        super().__init__()
        self.rnn = rnn
        self.vocab_size = vocab_size
        self.lr = lr
        # Create the output-layer parameters up front so they are registered
        # on the module (and therefore returned by .parameters()).
        self.init_params()

    def init_params(self) -> None:
        """(Re)create the hidden-to-output weights ``W_hq`` and bias ``b_q``."""
        # Reuse the RNN's initialisation scale; fall back to RNNScratch's
        # default of 0.01 if the wrapped module does not expose `sigma`.
        sigma = getattr(self.rnn, "sigma", 0.01)
        self.W_hq = nn.Parameter(
            torch.randn(self.rnn.num_hiddens, self.vocab_size) * sigma
        )
        self.b_q = nn.Parameter(torch.zeros(self.vocab_size))

    def training_set(self, batch):
        """Return the loss on one batch.

        All elements of ``batch`` except the last are passed to ``self(...)``;
        the last element is the target.
        NOTE(review): no ``forward`` is defined on this class in the visible
        code, so calling ``self(...)`` requires one to be added.
        """
        l = self.loss(self(*batch[:-1]), batch[-1])
        return l

    def training_step(self, batch):
        """Alias of ``training_set`` under the conventional trainer hook name."""
        return self.training_set(batch)