In [None]:
from src.mamba_vit import MambaVit
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import ToTensor
from tqdm.notebook import tqdm

In [None]:
# Model, optimiser and loss for 10-class classification of 1 x 28 x 28 images.
# NOTE(review): the import cell imports `MambaVit`, but this cell originally called
# `MambaViT`, which is never defined (NameError on a fresh kernel). Use the imported
# name here; if `src.mamba_vit` actually exports `MambaViT`, fix the import instead.
torch.manual_seed(0)  # reproducible weight init and (via the global RNG) DataLoader shuffling
m = MambaVit(image_size=28, patch_size=4, num_classes=10, channels=1, n_layer=8, dim=32, pool="mean").to("cuda")
m = m.train()
optimizer = torch.optim.SGD(m.parameters(), lr=0.0005)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Load the MNIST train and test splits as tensors; root "" resolves relative to the
# current working directory, and the data is downloaded there if missing.
mnist_train = torchvision.datasets.MNIST("", download=True, train=True, transform=ToTensor())
mnist_test = torchvision.datasets.MNIST("", download=True, train=False, transform=ToTensor())
train_dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=1024, shuffle=True)
# Evaluate on the held-out split. The original passed `mnist_train` here, so the
# reported "validation" metrics were actually measured on training data.
test_dataloader = torch.utils.data.DataLoader(mnist_test, batch_size=1024, shuffle=False)

In [None]:
# Train `m` for NUM_EPOCHS epochs; after each epoch report the mean training loss
# and the validation loss / accuracy on `test_dataloader`.
NUM_EPOCHS = 100
device = next(m.parameters()).device  # follow wherever the setup cell placed the model

for epoch in range(NUM_EPOCHS):
    # Re-enter training mode every epoch because the validation pass below
    # switches the model to eval mode.
    m.train()
    train_loss_list = []
    for img, gt in tqdm(train_dataloader):
        img = img.to(device)
        gt = gt.to(device)
        # Clear gradients from the previous step; otherwise .backward() accumulates
        # into .grad and each update uses the sum of all past gradients.
        optimizer.zero_grad()
        pred = m(img)
        loss = loss_fn(input=pred, target=gt)
        loss.backward()
        optimizer.step()
        # Detach so the list does not keep every batch's autograd graph alive.
        train_loss_list.append(loss.detach())

    # Validation: eval mode + no_grad for the whole pass (forward and loss).
    m.eval()
    validation_loss_total = 0.0
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for img, gt in tqdm(test_dataloader):
            img = img.to(device)
            gt = gt.to(device)
            pred = m(img)
            batch_size = gt.size(0)
            # Assumes loss_fn uses reduction="mean" (the nn.CrossEntropyLoss default);
            # re-weight by batch size so a smaller final batch is not over-counted.
            validation_loss_total += loss_fn(input=pred, target=gt).item() * batch_size
            # argmax of the logits equals argmax of their softmax; skip the softmax.
            num_correct += (pred.argmax(-1) == gt).sum().item()
            num_seen += batch_size

    print("Training loss: ", torch.mean(torch.stack(train_loss_list)).item())
    print("Validation loss: ", validation_loss_total / num_seen)
    print("Validation accuracy: ", num_correct / num_seen)