In [12]:
from vit_pytorch import ViT
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [13]:

transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=transform)

In [14]:
class ComplexDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx][0]
        label = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)
        return sample, label

In [15]:
# Create the custom dataset
train_labels = [label for _, label in train_data]
train_dataset = ComplexDataset(train_data, train_labels, transform=None) 

test_labels = [label for _, label in test_data]
test_dataset = ComplexDataset(test_data, test_labels)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(len(train_loader))

938


In [16]:
item = next(iter(train_loader))

In [24]:
model = ViT(
    image_size = 28,
    patch_size = 7,
    num_classes = 10,
    dim = 256,
    depth = 6,
    heads = 12,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1,
    channels = 1
)

img = torch.randn(1, 1, 28, 28)

preds = model(img) # (1, 1000)

In [20]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")

Total trainable parameters: 6359698


In [21]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [22]:
writer = SummaryWriter('runs/experiment_1') 

def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}", unit="batch")
    for batch_idx, (data, target) in enumerate(progress_bar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        progress_bar.set_postfix({"Train Loss": train_loss / (batch_idx + 1), "Train Acc": 100. * correct / total})
        
        # Log batch-level metrics
        writer.add_scalar('Loss/Train Batch', loss.item(), epoch * len(train_loader) + batch_idx)
        writer.add_scalar('Accuracy/Train Batch', 100. * correct / total, epoch * len(train_loader) + batch_idx)
    
    train_loss /= len(train_loader)
    train_accuracy = 100. * correct / total
    
    # Log epoch-level metrics
    writer.add_scalar('Loss/Train Epoch', train_loss, epoch)
    writer.add_scalar('Accuracy/Train Epoch', train_accuracy, epoch)
    
    return train_loss, train_accuracy

def test(model, device, test_loader, criterion, epoch):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    
    test_loss /= len(test_loader)
    test_accuracy = 100. * correct / total
    
    # Log epoch-level metrics
    writer.add_scalar('Loss/Test', test_loss, epoch)
    writer.add_scalar('Accuracy/Test', test_accuracy, epoch)
    
    return test_loss, test_accuracy


In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [11]:


model.to(device)
epochs = 500
for epoch in range(1, epochs + 1):
    train_loss, train_accuracy = train(model, device, train_loader, optimizer, criterion, epoch)
    test_loss, test_accuracy = test(model, device, test_loader, criterion, epoch)
    print(f"Epoch {epoch}")
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    print()


writer.close()

Epoch 1: 100%|██████████| 938/938 [01:00<00:00, 15.59batch/s, Train Loss=0.525, Train Acc=81.3]


Epoch 1
Train Loss: 0.5248, Train Accuracy: 81.29%
Test Loss: 0.3776, Test Accuracy: 86.56%



Epoch 2: 100%|██████████| 938/938 [01:00<00:00, 15.48batch/s, Train Loss=0.336, Train Acc=87.8]


Epoch 2
Train Loss: 0.3360, Train Accuracy: 87.78%
Test Loss: 0.3317, Test Accuracy: 87.76%



Epoch 3: 100%|██████████| 938/938 [01:00<00:00, 15.51batch/s, Train Loss=0.293, Train Acc=89.3]


Epoch 3
Train Loss: 0.2929, Train Accuracy: 89.28%
Test Loss: 0.3174, Test Accuracy: 88.60%



Epoch 4: 100%|██████████| 938/938 [01:00<00:00, 15.48batch/s, Train Loss=0.265, Train Acc=90.3]


Epoch 4
Train Loss: 0.2652, Train Accuracy: 90.30%
Test Loss: 0.3208, Test Accuracy: 88.56%



Epoch 5: 100%|██████████| 938/938 [01:00<00:00, 15.49batch/s, Train Loss=0.241, Train Acc=91]  


Epoch 5
Train Loss: 0.2413, Train Accuracy: 91.05%
Test Loss: 0.3218, Test Accuracy: 88.90%



Epoch 6: 100%|██████████| 938/938 [01:00<00:00, 15.46batch/s, Train Loss=0.219, Train Acc=91.9]


Epoch 6
Train Loss: 0.2186, Train Accuracy: 91.87%
Test Loss: 0.3102, Test Accuracy: 89.10%



Epoch 7: 100%|██████████| 938/938 [01:00<00:00, 15.47batch/s, Train Loss=0.202, Train Acc=92.4]


Epoch 7
Train Loss: 0.2021, Train Accuracy: 92.45%
Test Loss: 0.3142, Test Accuracy: 89.20%



Epoch 8: 100%|██████████| 938/938 [01:00<00:00, 15.50batch/s, Train Loss=0.183, Train Acc=93.2]


Epoch 8
Train Loss: 0.1835, Train Accuracy: 93.18%
Test Loss: 0.3001, Test Accuracy: 90.23%



Epoch 9: 100%|██████████| 938/938 [01:00<00:00, 15.50batch/s, Train Loss=0.164, Train Acc=93.9]


Epoch 9
Train Loss: 0.1642, Train Accuracy: 93.88%
Test Loss: 0.3235, Test Accuracy: 89.17%



Epoch 10: 100%|██████████| 938/938 [01:00<00:00, 15.47batch/s, Train Loss=0.149, Train Acc=94.5]


Epoch 10
Train Loss: 0.1490, Train Accuracy: 94.55%
Test Loss: 0.3175, Test Accuracy: 89.87%



Epoch 11: 100%|██████████| 938/938 [01:00<00:00, 15.54batch/s, Train Loss=0.134, Train Acc=95]  


Epoch 11
Train Loss: 0.1345, Train Accuracy: 95.00%
Test Loss: 0.3347, Test Accuracy: 89.77%



Epoch 12: 100%|██████████| 938/938 [01:00<00:00, 15.51batch/s, Train Loss=0.122, Train Acc=95.5]


Epoch 12
Train Loss: 0.1224, Train Accuracy: 95.52%
Test Loss: 0.3244, Test Accuracy: 89.92%



Epoch 13: 100%|██████████| 938/938 [01:00<00:00, 15.56batch/s, Train Loss=0.11, Train Acc=95.9] 


Epoch 13
Train Loss: 0.1104, Train Accuracy: 95.93%
Test Loss: 0.3813, Test Accuracy: 89.13%



Epoch 14: 100%|██████████| 938/938 [01:00<00:00, 15.52batch/s, Train Loss=0.101, Train Acc=96.4] 


Epoch 14
Train Loss: 0.1013, Train Accuracy: 96.37%
Test Loss: 0.3556, Test Accuracy: 90.13%



Epoch 15:  94%|█████████▍| 880/938 [00:56<00:03, 15.46batch/s, Train Loss=0.0901, Train Acc=96.7]


KeyboardInterrupt: 