# Day 34 - RNNs from Scratch

In [1]:
import torch
import einops
from torch import nn, optim

In [2]:
class Recurrent(nn.Module):
    """Simple recurrent network over scalar sequences, with separate input and recurrent weights.

    At every timestep the hidden state is updated as
    ``state = relu(step(state) + input(x_t))`` and the ``out`` MLP maps the new
    state to a single scalar prediction.

    Args:
        n_hiddens: Width of the hidden state and of every hidden layer in the
            ``input`` and ``out`` MLPs.
    """

    def __init__(self, n_hiddens=5):
        super().__init__()
        # Remembered so forward() sizes the initial state from this instance's
        # configuration rather than from a notebook-level global.
        self.n_hiddens = n_hiddens
        # Recurrent transition: previous hidden state -> hidden contribution.
        self.step = nn.Linear(n_hiddens, n_hiddens)
        # Lifts each scalar input x_t (shape (batch, 1)) into hidden space.
        self.input = nn.Sequential(
            nn.Linear(1, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_hiddens),
        )
        # Reads one scalar prediction out of the hidden state.
        self.out = nn.Sequential(
            nn.Linear(n_hiddens, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, 1),
        )

    def forward(self, x):
        """Run the network over every sequence in the batch.

        Args:
            x: 2-D tensor of shape (batch, length); each row is one sequence of
                scalars. Expected to be floating point, matching the layer weights.

        Returns:
            Tensor of shape (batch, length). Column ``i`` is the prediction made
            after reading ``x[:, :i + 1]``.

        Side effect:
            ``self.state`` is set to the final hidden state, shape
            (batch, n_hiddens).
        """
        batch_size, length = x.shape
        # Allocate on x's device and dtype so the module also works off-CPU.
        out = x.new_zeros(batch_size, length)
        state = x.new_zeros(batch_size, self.n_hiddens)

        # Walk over time: (batch, length) -> (length, batch, 1), so each x_i is
        # a (batch, 1) column that the first Linear(1, ...) layer can consume.
        for i, x_i in enumerate(x.transpose(0, 1).unsqueeze(-1)):
            state = torch.relu(self.step(state) + self.input(x_i))
            # self.out yields (batch, 1); drop the trailing unit axis.
            out[:, i] = self.out(state).squeeze(-1)

        self.state = state
        return out

In [3]:
class RecurrentSimplified(nn.Module):
    """Recurrent network whose hidden state and scalar input go through one shared MLP.

    At every timestep the previous hidden state and the scalar input are
    concatenated and passed through the ``input`` MLP, whose first layer
    multiplies both with a single weight matrix. The ``out`` MLP then maps the
    new state to a single scalar prediction.

    Args:
        n_hiddens: Width of the hidden state and of every hidden layer in the
            ``input`` and ``out`` MLPs.
    """

    def __init__(self, n_hiddens=5):
        super().__init__()
        # Remembered so forward() sizes the initial state from this instance's
        # configuration rather than from a notebook-level global.
        self.n_hiddens = n_hiddens
        self.input = nn.Sequential(
            # Input and hidden state will be concatenated
            # and multiplied by a shared weight matrix
            nn.Linear(1 + n_hiddens, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_hiddens),
        )
        # Reads one scalar prediction out of the hidden state.
        self.out = nn.Sequential(
            nn.Linear(n_hiddens, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, 1),
        )

    def forward(self, x):
        """Run the network over every sequence in the batch.

        Args:
            x: 2-D tensor of shape (batch, length); each row is one sequence of
                scalars. Expected to be floating point, matching the layer weights.

        Returns:
            Tensor of shape (batch, length). Column ``i`` is the prediction made
            after reading ``x[:, :i + 1]``.

        Side effect:
            ``self.state`` is set to the final hidden state, shape
            (batch, n_hiddens).
        """
        batch_size, length = x.shape
        # Allocate on x's device and dtype so the module also works off-CPU.
        out = x.new_zeros(batch_size, length)
        state = x.new_zeros(batch_size, self.n_hiddens)

        # Walk over time: (batch, length) -> (length, batch, 1), so each x_i is
        # a (batch, 1) column.
        for i, x_i in enumerate(x.transpose(0, 1).unsqueeze(-1)):
            # [state, x_i] along the feature axis -> (batch, n_hiddens + 1),
            # the input width of the first Linear(1 + n_hiddens, ...) layer.
            state = torch.relu(self.input(torch.cat([state, x_i], dim=-1)))
            # self.out yields (batch, 1); drop the trailing unit axis.
            out[:, i] = self.out(state).squeeze(-1)

        self.state = state
        return out

In [4]:
# Training data: eight sequences, each made of five consecutive integers.
# Row k is (start_k, start_k + 1, ..., start_k + 4) for the start values below.
seqs = (
    torch.tensor([-11, 35, -9, 1, 4, 7, 12, 28], dtype=torch.float32)[:, None]
    + torch.arange(5, dtype=torch.float32)
)

In [5]:
# Seed torch's RNG before any model is built, so weight initialisation -- and
# therefore every prediction and loss shown below -- is reproducible on
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Width of the hidden state and of every hidden layer in the model's MLPs.
n_hiddens = 128

In [6]:
# Build the shared-weight RNN and display its layer layout.
net = RecurrentSimplified(n_hiddens)
net

RecurrentSimplified(
  (input): Sequential(
    (0): Linear(in_features=129, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=128, bias=True)
  )
  (out): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [7]:
# Adam over all of the network's weights; a fairly high learning rate suits this tiny task.
optimizer = optim.Adam(params=net.parameters(), lr=0.02)

In [8]:
# Mean squared error, averaged over the batch dimension of each call.
loss_fn = nn.MSELoss(reduction="mean")

In [9]:
# Baseline: predictions of the freshly initialised network, before any optimiser step.
preds = net(seqs)
preds

tensor([[-0.0657, -0.0647, -0.0642, -0.0640, -0.0639],
        [-0.0690, -0.0692, -0.0697, -0.0700, -0.0705],
        [-0.0646, -0.0639, -0.0639, -0.0632, -0.0625],
        [-0.0608, -0.0618, -0.0628, -0.0638, -0.0648],
        [-0.0638, -0.0648, -0.0659, -0.0669, -0.0675],
        [-0.0664, -0.0675, -0.0681, -0.0684, -0.0688],
        [-0.0689, -0.0695, -0.0697, -0.0699, -0.0700],
        [-0.0682, -0.0680, -0.0680, -0.0682, -0.0684]], grad_fn=<CopySlices>)

In [10]:
# Objective: at each timestep, predict the next integer (value + 1).
# Per-timestep MSEs (one column of preds vs. one column of targets) are summed over the sequence.
loss = sum(loss_fn(pred_col, target_col) for pred_col, target_col in zip(preds.T, (seqs + 1).T))
loss

tensor(1827.2748, grad_fn=<AddBackward0>)

In [11]:
# One manual optimisation step on the loss computed above, before the full training loop.
optimizer.zero_grad()  # reset gradients so this step reflects only the current loss
loss.backward()  # backpropagate through the unrolled recurrence
optimizer.step()  # apply the Adam update to the network's weights

In [12]:
from tqdm.auto import tqdm

In [13]:
# Full-batch training: each epoch feeds every sequence through the RNN and takes one Adam step.
N_EPOCHS = 512
loss_history = []  # one scalar loss per epoch, for inspecting convergence afterwards

progress = tqdm(range(N_EPOCHS), desc="Epochs")
for _ in progress:
    preds = net(seqs)
    # Same objective as above: per-timestep MSE against the next integer, summed over timesteps.
    loss = sum(loss_fn(preds[:, i], seqs[:, i] + 1) for i in range(seqs.shape[1]))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    loss_history.append(loss_value)
    progress.set_postfix(loss=f"{loss_value:.4f}")

Epochs:   0%|          | 0/512 [00:00<?, ?it/s]

In [14]:
# Predictions of the trained network on the training sequences; no autograd graph is needed for inference.
with torch.no_grad():
    preds_trained = net(seqs)
preds_trained

tensor([[-9.9728, -9.0151, -7.8596, -6.9163, -5.9898],
        [35.9987, 37.0015, 37.9708, 39.0153, 40.0073],
        [-8.0230, -7.1070, -6.0036, -5.0618, -4.1351],
        [ 2.0284,  3.0049,  3.9906,  5.0017,  6.0045],
        [ 5.0033,  5.9802,  7.0068,  8.0066,  9.0069],
        [ 8.0049,  8.9851, 10.0031, 11.0075, 12.0069],
        [13.0005, 13.9891, 14.9971, 16.0090, 17.0070],
        [28.9986, 29.9979, 30.9787, 32.0134, 33.0072]], grad_fn=<CopySlices>)

In [15]:
# Extrapolation check: a batch of one unseen sequence whose values (50-54) lie
# above every value in the training data (maximum 39).
new_seq = torch.tensor(
    [[50, 51, 52, 53, 54]],
    dtype=torch.float32,
)
# Inference only, so skip building an autograd graph; round to the nearest integer prediction.
with torch.no_grad():
    new_preds = torch.round(net(new_seq))
new_preds

tensor([[51., 52., 53., 54., 55.]], grad_fn=<RoundBackward0>)