In [None]:
import torch.nn as nn
import matplotlib.pyplot as plt
import torch
import importlib
import model
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import tqdm
import kan

In [None]:
# Re-import the local `model` and `kan` modules so edits to their source take
# effect without restarting the kernel. The modules are looked up by name
# instead of through the globals `model` / `kan`: a notebook name such as
# `model` is easily rebound (e.g. to a model instance), after which
# `importlib.reload(model)` raises TypeError. Cells that did
# `from model import ...` must be re-run to pick up the reloaded definitions.
importlib.reload(importlib.import_module("model"))
importlib.reload(importlib.import_module("kan"))

In [None]:
from kan import KAN

In [None]:
# Image pipeline: PIL image -> float tensor in [0, 1], then standardised with
# the widely used MNIST mean/std (0.1307, 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits share the same root, download behaviour and transform.
mnist_kwargs = dict(root="./data", download=True, transform=transform)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

# Only the training split is shuffled; evaluation iterates in a fixed order.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

In [None]:
from model import GRAMLayer


class MLP(nn.Module):
    """Fully connected baseline classifier for flattened images.

    Architecture: for each width in ``hidden_sizes`` a
    ``Linear -> LayerNorm -> activation`` stage, then a final ``Linear`` that
    returns unnormalised class logits (the training loop applies
    ``CrossEntropyLoss`` to them).

    Args:
        in_features: Size of a flattened input sample (default 28 * 28).
        hidden_sizes: Widths of the hidden stages (default ``(32, 16)``).
        num_classes: Number of output logits (default 10).
        activation: Module class instantiated after each LayerNorm (default
            ``nn.ReLU``). Pass ``None`` for the older activation-free stack.
    """

    def __init__(
        self,
        in_features=28 * 28,
        hidden_sizes=(32, 16),
        num_classes=10,
        activation=nn.ReLU,
    ):
        super().__init__()
        layers = []
        width = in_features
        for hidden in hidden_sizes:
            layers += [nn.Linear(width, hidden), nn.LayerNorm(hidden)]
            # LayerNorm only normalises; without an explicit activation the
            # stack is close to a linear model and a weak MLP baseline.
            if activation is not None:
                layers.append(activation())
            width = hidden
        layers.append(nn.Linear(width, num_classes))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        # flatten(1) keeps the batch dim and, unlike view, also accepts
        # non-contiguous inputs.
        return self.layers(x.flatten(1))


class GRAM(nn.Module):
    """Classifier built from a stack of ``GRAMLayer`` modules.

    With the defaults the widths are 784 -> 32 -> 16 -> 10, giving
    ``GRAMLayer(784, 32), GRAMLayer(32, 16), GRAMLayer(16, 10)``.
    ``GRAMLayer`` comes from the local ``model`` module; it is presumed to map
    ``(batch, in_features) -> (batch, out_features)`` — confirm in model.py.

    Args:
        in_features: Size of a flattened input sample (default 28 * 28).
        hidden_sizes: Widths of the intermediate layers (default ``(32, 16)``).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_features=28 * 28, hidden_sizes=(32, 16), num_classes=10):
        super().__init__()
        widths = [in_features, *hidden_sizes, num_classes]
        self.layers = nn.Sequential(
            *(GRAMLayer(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:]))
        )

    def forward(self, x):
        # flatten(1) keeps the batch dim and, unlike view, also accepts
        # non-contiguous inputs.
        return self.layers(x.flatten(1))

In [None]:
def train_and_test_model(
    model,
    epochs,
    lr=0.0001,
    train_data_loader=None,
    test_data_loader=None,
    progress=None,
):
    """Train ``model`` with Adam and evaluate it on the test split after every epoch.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        epochs: Number of passes over the training data.
        lr: Adam learning rate (default 1e-4, the previously hard-coded value).
        train_data_loader: Iterable of ``(inputs, labels)`` batches used for
            training. Defaults to the notebook-level ``train_loader``.
        test_data_loader: Iterable of ``(inputs, labels)`` batches used for
            evaluation. Defaults to the notebook-level ``test_loader``.
        progress: Callable that wraps an iterable for progress display.
            Defaults to ``tqdm.tqdm``; pass ``lambda it: it`` to disable bars.

    Returns:
        ``(train_losses, test_losses, accuracies)`` with one entry per epoch.
        Losses are per-sample cross-entropy means over the whole split
        (weighted by batch size, so a short final batch is not over-weighted);
        accuracies are fractions in [0, 1]. An empty split yields ``nan``.
    """
    if train_data_loader is None:
        train_data_loader = train_loader
    if test_data_loader is None:
        test_data_loader = test_loader
    if progress is None:
        progress = tqdm.tqdm

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()  # built once, reused for every batch

    train_losses = []  # per-epoch mean training loss
    test_losses = []  # per-epoch mean test loss
    accuracies = []  # per-epoch test accuracy

    for epoch in range(epochs):
        # Training mode: enables dropout / batch-statistic updates in any
        # layers that have them.
        model.train()
        train_loss_sum, train_count = 0.0, 0
        for images, labels in progress(train_data_loader):
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # criterion returns the batch mean; scale back to a sum so the
            # epoch average is per-sample.
            train_loss_sum += loss.item() * labels.size(0)
            train_count += labels.size(0)
        train_losses.append(
            train_loss_sum / train_count if train_count else float("nan")
        )

        # Evaluation mode disables dropout and makes normalisation layers use
        # their running statistics; no_grad skips building the autograd graph.
        model.eval()
        test_loss_sum, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in progress(test_data_loader):
                outputs = model(images)
                test_loss_sum += criterion(outputs, labels).item() * labels.size(0)
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
        test_losses.append(test_loss_sum / total if total else float("nan"))
        accuracies.append(correct / total if total else float("nan"))

        print(
            "Epoch: {}, Train Loss: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.2f}".format(
                epoch, train_losses[-1], test_losses[-1], accuracies[-1]
            )
        )

    return train_losses, test_losses, accuracies

In [None]:
# Compare GRAM against the MLP baseline under identical conditions: the same
# epoch budget and the same seed before each run. Seeding torch's global RNG
# fixes torch-based weight initialisation and, because `train_loader` has no
# explicit generator, the per-epoch shuffle order as well.
EPOCHS = 10
SEED = 0

# Distinct names instead of `model`: that name is also the imported project
# module, and rebinding it makes `importlib.reload(model)` fail on re-run.
torch.manual_seed(SEED)
gram_model = GRAM()
gram_train_losses, gram_test_losses, gram_accuracies = train_and_test_model(
    gram_model, EPOCHS
)

torch.manual_seed(SEED)
mlp_model = MLP()
mlp_train_losses, mlp_test_losses, mlp_accuracies = train_and_test_model(
    mlp_model, EPOCHS
)

# Loss curves: train and test for both models on one set of axes.
fig, ax = plt.subplots(figsize=(10, 5))
for label, values in [
    ("GRAM Train Loss", gram_train_losses),
    ("GRAM Test Loss", gram_test_losses),
    ("MLP Train Loss", mlp_train_losses),
    ("MLP Test Loss", mlp_test_losses),
]:
    ax.plot(values, label=label)
ax.set(title="Model Convergence", xlabel="Epochs", ylabel="Loss")
ax.legend()
plt.show()

# Test accuracy per epoch for both models.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(gram_accuracies, label="GRAM Test Acc")
ax.plot(mlp_accuracies, label="MLP Test Acc")
ax.set(title="Model Accuracy", xlabel="Epochs", ylabel="Accuracy")
ax.legend()
plt.show()