In [None]:
import sys
import torch
from torch import nn
import os
import math

# Make the shared data helpers (two directories up, in `data/`) importable.
# The path is resolved relative to the current working directory.
DATA_DIR = os.path.abspath("../../data")
# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if DATA_DIR not in sys.path:
    sys.path.append(DATA_DIR)

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batch size passed to the dataset wrapper when building the data loaders.
BATCH_SIZE = 64

In [None]:
from regression_data import WeatherDatasetWrapper

class WeatherDataset(WeatherDatasetWrapper):
    """Configuration of the weather data used by this notebook.

    Only class-level settings are overridden here; all loading logic lives in
    ``WeatherDatasetWrapper``.
    """

    # Number of time steps per example; passed to the network as its sequence length.
    sequence_length = 7
    # Name of the value to predict (presumably a column name -- confirm against the wrapper).
    target = "tmax_tomorrow"
    # Input feature names; their count is used as the network's input width.
    predictors = ["tmax", "tmin", "rain"]

# Build the dataset wrapper on the selected device and ask it for the
# batched splits. `generate_datasets` returns a mapping that is indexed below
# by split name.
wrapper = WeatherDataset(DEVICE)
datasets = wrapper.generate_datasets(BATCH_SIZE)

# Pull out the two splits used by the training loop.
train, valid = datasets["train"], datasets["validation"]

In [None]:
class GRUCell(nn.Module):
    """A single gated recurrent unit (GRU) step built from raw parameters.

    The parameters for the three gates are stacked along the first axis of
    each tensor, in the order reset (0), update (1), candidate (2):

    * ``input_weights``  -- shape ``(3, input_units, hidden_units)``
    * ``input_biases``   -- shape ``(3, 1, hidden_units)``
    * ``hidden_weights`` -- shape ``(3, hidden_units, hidden_units)``
    * ``hidden_biases``  -- shape ``(3, 1, hidden_units)``

    Every parameter is initialised uniformly in ``[-k, k)`` with
    ``k = sqrt(1 / hidden_units)``.

    Args:
        input_units: Size of the last axis of the per-step input ``x``.
        hidden_units: Size of the hidden state.
        output_units: Stored on the instance; not used by the cell itself.
    """

    def __init__(self, input_units, hidden_units, output_units):
        super(GRUCell, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.output_units = output_units

        # torch.rand samples from [0, 1); scale and shift it into [-k, k).
        k = math.sqrt(1/hidden_units)
        self.input_weights = nn.Parameter(torch.rand(3, input_units, hidden_units) * 2 * k - k)
        self.input_biases = nn.Parameter(torch.rand(3, 1, hidden_units) * 2 * k - k)

        self.hidden_weights = nn.Parameter(torch.rand(3, hidden_units, hidden_units) * 2 * k - k)
        self.hidden_biases = nn.Parameter(torch.rand(3, 1, hidden_units) * 2 * k - k)

    def forward(self, x, prev_hidden):
        """Advance the hidden state by one time step.

        Args:
            x: Input for this step; its last axis must have ``input_units`` entries.
            prev_hidden: Hidden state from the previous step; its last axis
                must have ``hidden_units`` entries.

        Returns:
            The new hidden state, with the same shape as ``prev_hidden``.
        """
        w_in, b_in = self.input_weights, self.input_biases
        w_hid, b_hid = self.hidden_weights, self.hidden_biases

        # Reset gate: how much of the previous state feeds the candidate.
        reset_gate = torch.sigmoid(x @ w_in[0] + b_in[0] + prev_hidden @ w_hid[0] + b_hid[0])
        # Update gate: how much of the previous state is carried over unchanged.
        update_gate = torch.sigmoid(x @ w_in[1] + b_in[1] + prev_hidden @ w_hid[1] + b_hid[1])
        # Candidate state: the reset gate scales only the hidden-side contribution.
        new_gate = torch.tanh(x @ w_in[2] + b_in[2] + reset_gate * (prev_hidden @ w_hid[2] + b_hid[2]))

        # Blend candidate and previous state. The second term must use
        # prev_hidden: blending new_gate with itself would reduce the whole
        # expression to new_gate and make the update gate a no-op.
        hidden_x = (1 - update_gate) * new_gate + update_gate * prev_hidden
        return hidden_x

In [None]:
class Network(nn.Module):
    """Sequence model: linear encoder -> GRU cell unrolled over time -> linear decoder.

    Args:
        sequence_len: Number of time steps to unroll in ``forward``.
        input_units: Number of features per time step in the input.
        output_units: Number of values produced per time step.
        hidden_units: Width of the encoded input and of the GRU hidden state.
        layers: Stored on the instance; the model currently always uses a
            single GRU cell regardless of this value.
    """

    def __init__(self, sequence_len, input_units, output_units, hidden_units=512, layers=2):
        super(Network, self).__init__()
        self.sequence_len = sequence_len
        self.hidden_units = hidden_units
        self.input_units = input_units
        self.output_units = output_units
        self.layers = layers

        # Project each time step's features to the hidden width.
        self.linear_encode = nn.Linear(in_features=input_units, out_features=hidden_units)

        self.gru = GRUCell(input_units=hidden_units, hidden_units=hidden_units, output_units=hidden_units)
        # Map each hidden state to the per-step prediction.
        self.linear_decode = nn.Linear(in_features=hidden_units, out_features=output_units)

    def forward(self, x):
        """Run the model over a batch of sequences.

        Args:
            x: Tensor laid out as ``(batch, sequence, input_units)``; the
                sequence axis must hold at least ``sequence_len`` steps.

        Returns:
            A tuple ``(outputs, hiddens)`` with shapes
            ``(batch, sequence_len, output_units)`` and
            ``(batch, sequence_len, hidden_units)``.
        """
        batch_size = x.shape[0]
        # Encode every time step, then put the sequence axis first so that
        # encoded[j] is the whole batch at step j.
        encoded = self.linear_encode(x).swapaxes(0, 1)

        # Start from an all-zero hidden state. It is allocated from `encoded`
        # so it follows the input's device and dtype rather than a global.
        hidden = encoded.new_zeros((batch_size, self.hidden_units))

        # Collect per-step results and stack once at the end; concatenating
        # inside the loop would re-copy the whole history on every step.
        step_outputs = []
        step_hiddens = []
        for j in range(self.sequence_len):
            hidden = self.gru(encoded[j], hidden)
            step_hiddens.append(hidden)
            step_outputs.append(self.linear_decode(hidden))

        if not step_outputs:
            # sequence_len == 0: return empty sequences with the usual layout.
            out_output = encoded.new_zeros((batch_size, 0, self.output_units))
            out_hiddens = encoded.new_zeros((batch_size, 0, self.hidden_units))
            return out_output, out_hiddens

        # Stack along axis 1 so the batch stays on axis 0.
        out_hiddens = torch.stack(step_hiddens, dim=1)
        out_output = torch.stack(step_outputs, dim=1)
        return out_output, out_hiddens

In [None]:
from tqdm.auto import tqdm
# One prediction per time step, with one input feature per predictor.
model = Network(
    sequence_len=wrapper.sequence_length,
    input_units=len(wrapper.predictors),
    output_units=1,
    hidden_units=512,
    layers=1,
)
model = model.to(DEVICE)

# Mean squared error, optimised with AdamW.
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.01)

In [None]:
# Number of full passes over the training split.
EPOCHS = 100

for epoch in range(EPOCHS):
    # Training pass: one optimizer step per batch.
    train_loss = 0.0
    # tqdm wraps the loader itself (not enumerate(...)) so it can read
    # len(train) and show a total on the progress bar.
    for batch, (sequence, target) in enumerate(tqdm(train, desc=f"Epoch {epoch}")):
        optimizer.zero_grad()
        pred, hidden = model(sequence)
        # NOTE(review): pred is (batch, sequence, 1); confirm that target has a
        # matching (or intentionally broadcastable) shape.
        loss = loss_fn(pred, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    with torch.no_grad():
        # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
        valid_loss = 0.0
        for batch, (sequence, target) in enumerate(valid):
            # The full input sequence is fed in, exactly as in training.
            pred, hidden = model(sequence)
            loss = loss_fn(pred, target)
            valid_loss += loss.item()
        # Report the mean per-batch loss for both splits.
        print(f"Epoch {epoch} train loss: {train_loss / len(train)} valid loss: {valid_loss / len(valid)}")