`Gradient with respect to W_ph:`

`∂L(T)/∂W_ph = ∂L(T)/∂p(T) * ∂p(T)/∂W_ph
= ∂L(T)/∂p(T) * h(T)^T`

`Gradient with respect to W_hh:`

`∂L(T)/∂W_hh = Σ(t=1 to T) ∂L(T)/∂h(T) * ∂h(T)/∂h(t) * ∂h(t)/∂W_hh`

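In general, the second gradient (with respect to W_hh) suffers from both the exploding and the vanishing gradient problem: the factor ∂h(T)/∂h(t) is itself a product of T − t Jacobians ∂h(k)/∂h(k−1), so it can grow or shrink exponentially with the distance T − t. The first gradient (with respect to W_ph) only involves the last hidden state h(T) and contains no such product. As a consequence, long-term dependencies are hard to learn: once the gradient has vanished, information about early time steps no longer reaches the weight update.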

import torch
import torch.nn as nn

class VanillaRNN(nn.Module):
    """Hand-unrolled vanilla (Elman) RNN that emits class logits at every step.

    Recurrence implemented by ``forward``::

        h_t = tanh(x_t W_hx^T + h_{t-1} W_hh^T + b_h)
        p_t = h_t W_ph^T + b_p

    Args:
        seq_length: Number of time steps ``forward`` unrolls.
        input_dim: Size of the feature vector at each time step.
        num_hidden: Size of the hidden state.
        num_classes: Number of output logits per time step.
        batch_size: Nominal batch size. Stored for interface compatibility
            only; ``forward`` sizes the hidden state from its actual input.
        device: Device on which the parameters are created.
    """

    def __init__(self, seq_length, input_dim, num_hidden, num_classes, batch_size, device="cpu"):
        super(VanillaRNN, self).__init__()
        self.seq_length = seq_length
        self.num_hidden = num_hidden
        self.batch_size = batch_size
        self.device = device

        # Input-to-hidden and hidden-to-hidden weights, plus the hidden bias.
        # NOTE(review): weights are drawn from an unscaled standard normal; with
        # a large num_hidden the tanh pre-activations may saturate — consider a
        # scaled initialisation.
        self.W_hx = nn.Parameter(torch.randn(num_hidden, input_dim, device=device))
        self.W_hh = nn.Parameter(torch.randn(num_hidden, num_hidden, device=device))
        self.b_h = nn.Parameter(torch.zeros(num_hidden, device=device))

        # Hidden-to-output projection.
        self.W_ph = nn.Parameter(torch.randn(num_classes, num_hidden, device=device))
        self.b_p = nn.Parameter(torch.zeros(num_classes, device=device))

    def forward(self, x):
        """Run the recurrence over the first ``seq_length`` steps of ``x``.

        Args:
            x: Tensor of shape (batch, seq_length, input_dim).

        Returns:
            Tensor of shape (batch, seq_length, num_classes) holding the
            logits of every time step.
        """
        # Size the initial state from the input itself so that batches smaller
        # or larger than the constructor's batch_size work, and follow the
        # parameters' device/dtype so the state stays valid after `.to(...)`.
        h = torch.zeros(
            x.size(0), self.num_hidden, device=self.W_hh.device, dtype=self.W_hh.dtype
        )

        outputs = []

        for t in range(self.seq_length):
            x_t = x[:, t, :]
            h = torch.tanh(torch.matmul(x_t, self.W_hx.t()) + torch.matmul(h, self.W_hh.t()) + self.b_h)
            p = torch.matmul(h, self.W_ph.t()) + self.b_p
            outputs.append(p)

        return torch.stack(outputs, dim=1)

In [None]:
# Instantiate the network selected by config.model_type, then place it on `device`.
# Both model classes take the same constructor arguments.
if config.model_type == "LSTM":
    model = LSTM(
        seq_length=config.input_length,
        input_dim=config.input_dim,
        num_hidden=config.num_hidden,
        num_classes=config.num_classes,
        batch_size=config.batch_size,
        device=device,
    )
    model = model.to(device)
elif config.model_type == "RNN":
    model = VanillaRNN(
        seq_length=config.input_length,
        input_dim=config.input_dim,
        num_hidden=config.num_hidden,
        num_classes=config.num_classes,
        batch_size=config.batch_size,
        device=device,
    )
    model = model.to(device)

In [None]:
################################################################################
# MIT License
#
# Copyright (c) 2018
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#


################################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import time
from datetime import datetime
import numpy as np

import torch
from torch.utils.data import DataLoader

from part1.dataset import PalindromeDataset
from part1.vanilla_rnn import VanillaRNN
from part1.lstm import LSTM

# You may want to look into tensorboardX for logging
# from tensorboardX import SummaryWriter

################################################################################


def train(config):
    """Train a VanillaRNN or LSTM on PalindromeDataset and log progress.

    The model is trained to classify each sequence from the logits of its
    last time step. Progress is printed every 10 steps; training stops once
    ``step`` reaches ``config.train_steps`` or the data loader is exhausted.

    Args:
        config: Parsed command-line namespace. Reads ``model_type``,
            ``device``, ``input_length``, ``input_dim``, ``num_hidden``,
            ``num_classes``, ``batch_size``, ``learning_rate``, ``max_norm``
            and ``train_steps``.
    """

    assert config.model_type in ("RNN", "LSTM")

    # Initialize the device which to run the model on
    device = torch.device(config.device)

    # Initialize the model that we are going to use
    if config.model_type == "RNN":
        model = VanillaRNN(
            seq_length=config.input_length,
            input_dim=config.input_dim,
            num_hidden=config.num_hidden,
            num_classes=config.num_classes,
            batch_size=config.batch_size,
            device=device
        ).to(device)
    elif config.model_type == "LSTM":
        model = LSTM(
            seq_length=config.input_length,
            input_dim=config.input_dim,
            num_hidden=config.num_hidden,
            num_classes=config.num_classes,
            batch_size=config.batch_size,
            device=device
        ).to(device)

    # Initialize the dataset and data loader (note the +1)
    dataset = PalindromeDataset(config.input_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer. `torch.nn` is spelled out because this
    # script only imports `torch`, not `torch.nn as nn`.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate)

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of step through network
        t1 = time.time()

        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()

        outputs = model(batch_inputs)
        # Only the prediction at the final time step is scored.
        logits = outputs[:, -1, :]
        loss = criterion(logits, batch_targets)
        loss.backward()

        ############################################################################
        # QUESTION: what happens here and why?
        ############################################################################
        # The global gradient norm is rescaled to at most `max_norm` to limit
        # exploding gradients. This must run after backward() (so the gradients
        # exist) and before optimizer.step() (so the clipped values are used).
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.max_norm)
        ############################################################################

        optimizer.step()

        # Convert to plain Python numbers for logging.
        accuracy = (logits.argmax(dim=1) == batch_targets).float().mean().item()
        loss = loss.item()

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size / float(t2 - t1)

        if step % 10 == 0:

            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"),
                    step,
                    config.train_steps,
                    config.batch_size,
                    examples_per_second,
                    accuracy,
                    loss,
                )
            )

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print("Done training.")


################################################################################
################################################################################

if __name__ == "__main__":

    # Parse training configuration
    parser = argparse.ArgumentParser()

    # Model params
    parser.add_argument(
        "--model_type",
        type=str,
        default="RNN",
        # Reject unsupported values at parse time instead of inside train().
        choices=("RNN", "LSTM"),
        help="Model type, should be 'RNN' or 'LSTM'",
    )
    parser.add_argument(
        "--input_length", type=int, default=10, help="Length of an input sequence"
    )
    parser.add_argument(
        "--input_dim", type=int, default=1, help="Dimensionality of input sequence"
    )
    parser.add_argument(
        "--num_classes", type=int, default=10, help="Dimensionality of output sequence"
    )
    parser.add_argument(
        "--num_hidden",
        type=int,
        default=128,
        help="Number of hidden units in the model",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="Number of examples to process in a batch",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=0.001, help="Learning rate"
    )
    parser.add_argument(
        "--train_steps", type=int, default=10000, help="Number of training steps"
    )
    parser.add_argument(
        "--max_norm",
        type=float,
        default=10.0,
        help="Maximum gradient norm used for gradient clipping",
    )
    parser.add_argument(
        "--device",
        type=str,
        # Fall back to the CPU so the default invocation also runs on machines
        # without a usable CUDA device.
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Training device 'cpu' or 'cuda:0'",
    )

    config = parser.parse_args()

    # Train the model
    train(config)

In [None]:
import matplotlib.pyplot as plt
from part1.dataset import PalindromeDataset
from part1.vanilla_rnn import VanillaRNN
from torch.utils.data import DataLoader

def train_and_evaluate(seq_length, config):
    """Train a fresh VanillaRNN on one sequence length and return its accuracy.

    Args:
        seq_length: Number of input time steps; the dataset is built with
            ``seq_length + 1``.
        config: Namespace providing ``input_dim``, ``num_hidden``,
            ``num_classes``, ``batch_size``, ``device``, ``learning_rate``,
            ``max_norm`` and ``train_steps``.

    Returns:
        Fraction of correct last-step predictions, measured on batches drawn
        from the same data loader until at least 1000 samples were seen.
    """
    # Initialize the model
    model = VanillaRNN(
        seq_length=seq_length,
        input_dim=config.input_dim,
        num_hidden=config.num_hidden,
        num_classes=config.num_classes,
        batch_size=config.batch_size,
        device=config.device
    ).to(config.device)

    # Initialize the dataset and data loader
    dataset = PalindromeDataset(seq_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # Setup the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate)

    # Training loop
    for step, (batch_inputs, batch_targets) in enumerate(data_loader):
        batch_inputs = batch_inputs.to(config.device)
        batch_targets = batch_targets.to(config.device)

        optimizer.zero_grad()
        outputs = model(batch_inputs)
        # Only the prediction at the final time step is scored.
        loss = criterion(outputs[:, -1, :], batch_targets)
        loss.backward()
        # Clip exactly as train() does, so the experiment uses the same
        # optimisation procedure and is not derailed by exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.max_norm)
        optimizer.step()

        if step >= config.train_steps:
            break

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in data_loader:
            batch_inputs = batch_inputs.to(config.device)
            batch_targets = batch_targets.to(config.device)
            outputs = model(batch_inputs)
            _, predicted = torch.max(outputs[:, -1, :], 1)
            total += batch_targets.size(0)
            correct += (predicted == batch_targets).sum().item()
            if total >= 1000:  # Stop once at least 1000 samples were evaluated
                break

    accuracy = correct / total
    return accuracy

# Sweep over odd input lengths 5, 7, ..., 19, training one fresh model per length.
seq_lengths = range(5, 21, 2)
accuracies = []

for seq_length in seq_lengths:
    accuracy = train_and_evaluate(seq_length, config)
    print(f"Sequence length: {seq_length}, Accuracy: {accuracy:.4f}")
    accuracies.append(accuracy)

# Accuracy as a function of sequence length: draw, save to disk, then display.
plt.figure(figsize=(10, 6))
plt.plot(seq_lengths, accuracies, marker='o')
plt.xlabel('Palindrome Length')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Palindrome Length for Vanilla RNN')
plt.grid(True)
plt.savefig('vanilla_rnn_accuracy.png')
plt.show()

In [None]:
import torch
import torch.nn as nn

class LSTM(nn.Module):
    """Hand-unrolled LSTM cell that emits class logits at every step.

    Per time step ``forward`` computes::

        g = tanh(x_t W_gx^T + h W_gh^T + b_g)       # candidate values
        i = sigmoid(x_t W_ix^T + h W_ih^T + b_i)    # input gate
        f = sigmoid(x_t W_fx^T + h W_fh^T + b_f)    # forget gate
        o = sigmoid(x_t W_ox^T + h W_oh^T + b_o)    # output gate
        c = g * i + c * f
        h = tanh(c) * o
        p = h W_ph^T + b_p

    Args:
        seq_length: Number of time steps ``forward`` unrolls.
        input_dim: Size of the feature vector at each time step.
        num_hidden: Size of the hidden and cell state.
        num_classes: Number of output logits per time step.
        batch_size: Nominal batch size. Stored for interface compatibility
            only; ``forward`` sizes the states from its actual input.
        device: Device on which the parameters are created.
    """

    def __init__(self, seq_length, input_dim, num_hidden, num_classes, batch_size, device="cpu"):
        super(LSTM, self).__init__()
        self.seq_length = seq_length
        self.num_hidden = num_hidden
        self.batch_size = batch_size
        self.device = device

        # Candidate ("g") parameters.
        self.W_gx = nn.Parameter(torch.randn(num_hidden, input_dim, device=device))
        self.W_gh = nn.Parameter(torch.randn(num_hidden, num_hidden, device=device))
        self.b_g = nn.Parameter(torch.zeros(num_hidden, device=device))

        # Input-gate parameters.
        self.W_ix = nn.Parameter(torch.randn(num_hidden, input_dim, device=device))
        self.W_ih = nn.Parameter(torch.randn(num_hidden, num_hidden, device=device))
        self.b_i = nn.Parameter(torch.zeros(num_hidden, device=device))

        # Forget-gate parameters.
        self.W_fx = nn.Parameter(torch.randn(num_hidden, input_dim, device=device))
        self.W_fh = nn.Parameter(torch.randn(num_hidden, num_hidden, device=device))
        self.b_f = nn.Parameter(torch.zeros(num_hidden, device=device))

        # Output-gate parameters.
        self.W_ox = nn.Parameter(torch.randn(num_hidden, input_dim, device=device))
        self.W_oh = nn.Parameter(torch.randn(num_hidden, num_hidden, device=device))
        self.b_o = nn.Parameter(torch.zeros(num_hidden, device=device))

        # Hidden-to-output projection.
        self.W_ph = nn.Parameter(torch.randn(num_classes, num_hidden, device=device))
        self.b_p = nn.Parameter(torch.zeros(num_classes, device=device))

    def forward(self, x):
        """Run the LSTM over the first ``seq_length`` steps of ``x``.

        Args:
            x: Tensor of shape (batch, seq_length, input_dim).

        Returns:
            Tensor of shape (batch, seq_length, num_classes) holding the
            logits of every time step.
        """
        # Size the initial states from the input itself so that batches smaller
        # or larger than the constructor's batch_size work, and follow the
        # parameters' device/dtype so the states stay valid after `.to(...)`.
        state_shape = (x.size(0), self.num_hidden)
        h = torch.zeros(state_shape, device=self.W_gh.device, dtype=self.W_gh.dtype)
        c = torch.zeros(state_shape, device=self.W_gh.device, dtype=self.W_gh.dtype)

        outputs = []

        for t in range(self.seq_length):
            x_t = x[:, t, :]

            g = torch.tanh(torch.matmul(x_t, self.W_gx.t()) + torch.matmul(h, self.W_gh.t()) + self.b_g)
            i = torch.sigmoid(torch.matmul(x_t, self.W_ix.t()) + torch.matmul(h, self.W_ih.t()) + self.b_i)
            f = torch.sigmoid(torch.matmul(x_t, self.W_fx.t()) + torch.matmul(h, self.W_fh.t()) + self.b_f)
            o = torch.sigmoid(torch.matmul(x_t, self.W_ox.t()) + torch.matmul(h, self.W_oh.t()) + self.b_o)

            # Cell state mixes the gated candidate with the gated previous state.
            c = g * i + c * f
            h = torch.tanh(c) * o

            p = torch.matmul(h, self.W_ph.t()) + self.b_p
            outputs.append(p)

        return torch.stack(outputs, dim=1)

In [None]:
import matplotlib.pyplot as plt
from part1.dataset import PalindromeDataset
from part1.vanilla_rnn import VanillaRNN
from part1.lstm import LSTM
from torch.utils.data import DataLoader

def train_and_evaluate(model_type, seq_length, config):
    """Train a fresh RNN or LSTM on one sequence length and return its accuracy.

    Args:
        model_type: Either ``"RNN"`` or ``"LSTM"``.
        seq_length: Number of input time steps; the dataset is built with
            ``seq_length + 1``.
        config: Namespace providing ``input_dim``, ``num_hidden``,
            ``num_classes``, ``batch_size``, ``device``, ``learning_rate``,
            ``max_norm`` and ``train_steps``.

    Returns:
        Fraction of correct last-step predictions, measured on batches drawn
        from the same data loader until at least 1000 samples were seen.

    Raises:
        ValueError: If ``model_type`` is neither ``"RNN"`` nor ``"LSTM"``.
    """
    if model_type == "RNN":
        model = VanillaRNN(seq_length, config.input_dim, config.num_hidden, config.num_classes, config.batch_size, config.device).to(config.device)
    elif model_type == "LSTM":
        model = LSTM(seq_length, config.input_dim, config.num_hidden, config.num_classes, config.batch_size, config.device).to(config.device)
    else:
        # Fail loudly instead of hitting an unbound `model` further down.
        raise ValueError(f"Unknown model_type: {model_type!r}")

    dataset = PalindromeDataset(seq_length + 1)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate)

    # Training loop
    for step, (batch_inputs, batch_targets) in enumerate(data_loader):
        batch_inputs = batch_inputs.to(config.device)
        batch_targets = batch_targets.to(config.device)

        optimizer.zero_grad()
        outputs = model(batch_inputs)
        # Only the prediction at the final time step is scored.
        loss = criterion(outputs[:, -1, :], batch_targets)
        loss.backward()
        # Clip exactly as train() does so both models are optimised the same way.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.max_norm)
        optimizer.step()

        if step >= config.train_steps:
            break

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in data_loader:
            batch_inputs = batch_inputs.to(config.device)
            batch_targets = batch_targets.to(config.device)
            outputs = model(batch_inputs)
            _, predicted = torch.max(outputs[:, -1, :], 1)
            total += batch_targets.size(0)
            correct += (predicted == batch_targets).sum().item()
            if total >= 1000:  # Stop once at least 1000 samples were evaluated
                break

    accuracy = correct / total
    return accuracy

# Sweep over input lengths 5, 10, ..., 30, training one fresh RNN and one
# fresh LSTM per length.
seq_lengths = range(5, 31, 5)
rnn_accuracies = []
lstm_accuracies = []

for seq_length in seq_lengths:
    rnn_accuracy = train_and_evaluate("RNN", seq_length, config)
    rnn_accuracies.append(rnn_accuracy)

    lstm_accuracy = train_and_evaluate("LSTM", seq_length, config)
    lstm_accuracies.append(lstm_accuracy)

    print(f"Sequence length: {seq_length}")
    print(f"RNN Accuracy: {rnn_accuracy:.4f}")
    print(f"LSTM Accuracy: {lstm_accuracy:.4f}")

# Both accuracy curves on one figure: draw, save to disk, then display.
plt.figure(figsize=(10, 6))
plt.plot(seq_lengths, rnn_accuracies, marker='o', label='RNN')
plt.plot(seq_lengths, lstm_accuracies, marker='s', label='LSTM')
plt.xlabel('Palindrome Length')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Palindrome Length: RNN vs LSTM')
plt.legend()
plt.grid(True)
plt.savefig('rnn_vs_lstm_accuracy.png')
plt.show()