In [None]:
import sys

# Add the directory two levels up to the module search path; presumably that
# is where the local `flashrnn` package imported below lives (TODO confirm).
# NOTE: the entry is relative, so it resolves against the current working
# directory rather than the location of this file.
sys.path.append("../..")

In [None]:
from flashrnn import flashrnn

import torch
import torch.nn as nn
import torch.optim as optim

# from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# import torchaudio

# Training hyperparameters.
num_epochs = 1  # full passes over the training set
learning_rate = 1e-2  # step size handed to both Adam optimizers
batch_size = 32  # NOTE: reassigned to 512 before the DataLoaders are built


class ParityDataset(Dataset):
    """Synthetic dataset of random bit strings labelled with their parity.

    Every sample is a pair ``(bits, parity)``: ``bits`` is a 1-D integer
    tensor of zeros and ones, and ``parity`` is 1 when it holds an odd number
    of ones and 0 otherwise. All samples are drawn once, at construction
    time, from torch's global random number generator.
    """

    def __init__(self, sequence_length, num_samples):
        """
        Args:
            sequence_length (int): Number of bits in each sequence.
            num_samples (int): How many sequences to draw.
        """
        self.sequence_length = sequence_length
        self.num_samples = num_samples
        bits, parities = self.generate_data()
        self.data = bits
        self.labels = parities

    def generate_data(self):
        """Draw the bit matrix and compute one parity label per row."""
        shape = (self.num_samples, self.sequence_length)
        # The upper bound of randint is exclusive, so every entry is 0 or 1.
        bits = torch.randint(low=0, high=2, size=shape)
        # The count of ones modulo 2 is the parity (0 = even, 1 = odd).
        ones_per_row = bits.sum(dim=1)
        return bits, ones_per_row % 2

    def __len__(self):
        """Return the number of samples requested at construction."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(bits, parity)`` pair stored at position ``idx``."""
        return self.data[idx], self.labels[idx]


# Dataset geometry.
sequence_length = 16  # bits per sequence
num_samples = 2**20  # sequences in each split
batch_size = 512  # replaces the earlier value; used by both loaders below

# Training split. The datasets are built in train-then-validation order so
# they draw from torch's global RNG in a fixed sequence.
parity_dataset = ParityDataset(sequence_length, num_samples)
train_loader = DataLoader(parity_dataset, shuffle=True, batch_size=batch_size)

# Validation split: an independent draw with the same size and sequence length.
parity_dataset_val = ParityDataset(sequence_length, num_samples)
val_loader = DataLoader(parity_dataset_val, shuffle=True, batch_size=batch_size)

In [None]:
# Feature width shared by both models built below: the embedding dimension,
# the LSTM input and hidden size, and the input size of the final linear layer.
hidden_size = 64


class LastPool(nn.Module):
    """Pool a sequence by keeping only its final step."""

    def forward(self, x):
        """Return the last entry along the second-to-last dimension of ``x``.

        That dimension is dropped from the result, so an input shaped
        ``(..., steps, features)`` yields ``(..., features)``.
        """
        return x.select(-2, -1)


class LSTMWrap(nn.Module):
    """Single-layer, batch-first ``nn.LSTM`` that returns only the output sequence.

    After the default initialization, the second ``hidden_dim``-sized chunk of
    both bias vectors is set to 4.0. In ``nn.LSTM``'s stacked gate layout
    (input, forget, cell, output) that chunk is the forget gate.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the LSTM parameters, then overwrite the forget-gate biases."""
        self.lstm.reset_parameters()
        forget_gate = slice(self.hidden_dim, 2 * self.hidden_dim)
        # In-place writes to parameters must not be tracked by autograd.
        with torch.no_grad():
            for bias in (self.lstm.bias_hh_l0, self.lstm.bias_ih_l0):
                bias[forget_gate] = 4.0

    def forward(self, x):
        """Run the LSTM over ``x`` and return its per-step outputs.

        The final ``(h_n, c_n)`` state returned by ``nn.LSTM`` is discarded.
        """
        outputs, _final_state = self.lstm(x)
        return outputs


class LSTMWrapFlashRNN(nn.Module):
    """LSTM layer evaluated by ``flashrnn``, parameterized like ``LSTMWrap``.

    The input projection lives in ``W`` (a linear layer producing all four
    gate pre-activations at once), the recurrent weights in ``R`` and a
    second bias in ``b``. ``R`` and ``b`` start at zero; call
    ``_set_parameters_from`` to load the values of an ``LSTMWrap``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, 4 * hidden_dim)
        self.hidden_dim = hidden_dim
        self.R = nn.Parameter(torch.zeros(4 * hidden_dim, hidden_dim))
        self.b = nn.Parameter(torch.zeros(4 * hidden_dim))

    def _set_parameters_from(self, mod: "LSTMWrap"):
        """Copy the LSTM parameters of ``mod`` into this module exactly.

        Args:
            mod: Module whose ``lstm`` attribute is a single-layer ``nn.LSTM``
                with the same input and hidden sizes as this module.
        """
        source = mod.lstm
        with torch.no_grad():
            # copy_ transfers the values bit for bit. Deriving them
            # arithmetically from the current values (p + (q - p)) rounds in
            # floating point and leaves the two modules slightly different.
            self.W.weight.copy_(source.weight_ih_l0)
            self.W.bias.copy_(source.bias_ih_l0)
            self.R.copy_(source.weight_hh_l0)
            self.b.copy_(source.bias_hh_l0)

    def forward(self, x):
        """Apply the LSTM to ``x`` of shape ``(batch, steps, input_dim)``.

        Returns:
            Tensor of shape ``(batch, steps, hidden_dim)``.
        """
        B, S, _ = x.shape
        # Split the stacked 4 * hidden_dim axis into (4, 1, hidden_dim).
        # NOTE(review): assumes flashrnn's gate order and recurrent-weight
        # layout match nn.LSTM's -- confirm against the flashrnn docs.
        gates_in = self.W(x).view(B, S, 4, 1, self.hidden_dim)
        recurrent = self.R.view(4, 1, self.hidden_dim, self.hidden_dim)
        bias = self.b.view(4, 1, self.hidden_dim)
        result = flashrnn(
            gates_in,
            recurrent,
            bias,
            dtype="float32",
            function="lstm",
        )
        # Presumably result[0] holds the per-step states and its first entry
        # is the hidden state -- TODO confirm.
        return result[0][0].reshape(B, S, self.hidden_dim)


def _build_classifier(rnn_cls):
    """Return embedding -> ``rnn_cls`` layer -> last-step pooling -> 2-way linear head."""
    return torch.nn.Sequential(
        nn.Embedding(2, hidden_size),
        rnn_cls(hidden_size, hidden_size),
        LastPool(),
        nn.Linear(hidden_size, 2),
    )


# Reference model backed by torch's nn.LSTM.
model = _build_classifier(LSTMWrap).to(device="cuda")

# Same architecture with the recurrent layer evaluated by flashrnn.
model2 = _build_classifier(LSTMWrapFlashRNN).to(device="cuda")

print(model)
print(model2)

print("\n\nParameters: \n")
print({name: p.numel() for name, p in model.named_parameters()})
print({name: p.numel() for name, p in model2.named_parameters()})

# Give both models identical parameters so their losses can be compared.
# load_state_dict copies the values exactly; computing them arithmetically
# from model2's current values would introduce floating-point rounding.
model2[0].load_state_dict(model[0].state_dict())
model2[1]._set_parameters_from(model[1])
model2[3].load_state_dict(model[3].state_dict())

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
optimizer2 = optim.Adam(model2.parameters(), lr=learning_rate)

# Set to a positive number N to wait for <Enter> every N batches, e.g. to
# inspect the progress bar. 0 disables pausing so training runs unattended.
pause_every = 0

# Train both models in lockstep on identical batches and track how far their
# losses drift apart.
for epoch in range(num_epochs):
    model.train()  # Set the models to training mode
    model2.train()
    running_loss = 0.0
    running_loss2 = 0.0
    max_loss_diff = 0.0

    with tqdm(train_loader, unit="batch") as tepoch:
        # The description only depends on the epoch, so set it once.
        tepoch.set_description(f"Epoch {epoch + 1}")

        for idx, (inputs, labels) in enumerate(tepoch):
            inputs, labels = inputs.to("cuda"), labels.to("cuda")
            # Separate copies of the batch for the second model.
            inputs2, labels2 = inputs.clone(), labels.clone()

            # Zero the parameter gradients
            optimizer.zero_grad()
            optimizer2.zero_grad()

            # Forward pass
            outputs = model(inputs)
            outputs2 = model2(inputs2)

            # Compute the loss
            loss = criterion(outputs, labels)
            loss2 = criterion(outputs2, labels2)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            loss2.backward()
            optimizer2.step()

            # .item() copies the value off the GPU, so read each loss once.
            loss_value = loss.item()
            loss2_value = loss2.item()
            loss_diff = abs(loss_value - loss2_value)

            # Update running statistics
            running_loss += loss_value
            running_loss2 += loss2_value
            max_loss_diff = max(max_loss_diff, loss_diff)

            # Show the current batch losses in the progress bar
            tepoch.set_postfix(
                loss=loss_value,
                loss2=loss2_value,
                loss_diff=loss_diff,
                max_loss_diff=max_loss_diff,
            )
            if pause_every > 0 and idx % pause_every == 0:
                input()

    # Print the mean loss of both models for the epoch
    num_batches = len(train_loader)
    print(
        f"Epoch [{epoch+1}/{num_epochs}], "
        f"Loss: {running_loss/num_batches:.4f}, "
        f"Loss2: {running_loss2/num_batches:.4f}, "
        f"Max loss diff: {max_loss_diff:.3e}"
    )

print("Training complete!")