In [2]:
from vit_pytorch import ViT
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import random
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [3]:
# Load FashionMNIST once; ToTensor turns each image into a float tensor scaled to [0, 1].
# (Previously these three statements were duplicated, building both datasets twice.)
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.FashionMNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=transform)

In [8]:

def fft_transform(data):
    """Compute the 2-D discrete Fourier transform of every image in ``data``.

    Args:
        data: iterable of ``(image_tensor, label)`` pairs; labels are ignored.
            Each image is converted to NumPy and squeezed before transforming,
            so singleton dimensions (e.g. a leading channel of size 1) are dropped.

    Returns:
        np.ndarray holding the complex spectra, stacked along a new leading axis
        in iteration order.
    """
    spectra = [np.fft.fft2(image.numpy().squeeze()) for image, _label in data]
    return np.array(spectra)

# Precompute the 2-D FFT of every image once, up front, for both splits.
train_fft, test_fft = fft_transform(train_data), fft_transform(test_data)


In [10]:
class ComplexDataset(Dataset):
    """Serve complex 2-D spectra as two-channel real tensors.

    Each item is centred with ``np.fft.fftshift``, split into its real and
    imaginary parts, and those parts are stacked along a new leading axis
    (channel 0 = real, channel 1 = imaginary) as ``float32`` tensors.
    """

    def __init__(self, data, labels, transform=None):
        # data: indexable collection of complex arrays (e.g. the output of fft_transform)
        # labels: indexable collection aligned with data
        # transform: optional callable applied to the stacked tensor
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        centred = np.fft.fftshift(self.data[idx])
        parts = [torch.from_numpy(part).float() for part in (centred.real, centred.imag)]
        item = torch.stack(parts, dim=0)

        if self.transform:
            item = self.transform(item)

        return item, self.labels[idx]

In [12]:
# Create the custom datasets.
# Labels are read from the torchvision dataset's `targets` tensor rather than by
# iterating the dataset, which would decode and ToTensor-transform every image
# just to read its label. `.tolist()` yields the same Python ints as iteration.
train_labels = train_data.targets.tolist()
train_dataset = ComplexDataset(train_fft, train_labels)

test_labels = test_data.targets.tolist()
test_dataset = ComplexDataset(test_fft, test_labels)

# Spectra and labels must stay aligned index-for-index.
assert len(train_labels) == len(train_fft)
assert len(test_labels) == len(test_fft)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(len(train_loader))

938


In [16]:
# Pull a single batch to check the tensor layout that will be fed to the model.
item = next(iter(train_loader))
batch_inputs = item[0]
batch_inputs.shape

torch.Size([64, 2, 28, 28])

In [61]:
# ViT over 2-channel (real, imaginary) 28x28 spectra, split into 4x4 patches.
model = ViT(
    image_size = 28,
    patch_size = 4,
    num_classes = 10,
    dim = 128,
    depth = 6,
    heads = 16,
    mlp_dim = 256,
    dropout = 0.5,
    emb_dropout = 0.5,
    channels = 2
)

# Smoke-test one forward pass on random input; no gradients are needed for this.
img = torch.randn(1, 2, 28, 28)
with torch.no_grad():
    preds = model(img)  # (1, 10): one logit per class
assert preds.shape == (1, 10), preds.shape

In [62]:
# Count only parameters with requires_grad=True (the trainable ones).
total_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f"Total trainable parameters: {total_params}")

Total trainable parameters: 3557706


In [63]:
# Multi-class objective on the model's raw logits, optimized with Adam.
LEARNING_RATE = 1e-4
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [64]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and return the per-sample loss and accuracy.

    Uses the notebook-level ``criterion`` as the loss function. The loss is
    assumed to use mean reduction over the batch (true for the default
    ``nn.CrossEntropyLoss()``).

    Args:
        model: classifier whose output has one score per class along dim 1.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only in the progress-bar label.

    Returns:
        ``(train_loss, train_accuracy)``. The loss is averaged over samples, not
        over batches, so a short final batch is not over-weighted. The accuracy
        is a percentage.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}", unit="batch")
    for data, target in progress_bar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # criterion returns the batch mean; re-weight by batch size so the
        # epoch figure is a true per-sample mean.
        loss_sum += loss.item() * batch_size
        _, predicted = output.max(1)
        correct += predicted.eq(target).sum().item()
        total += batch_size

        progress_bar.set_postfix({"Train Loss": loss_sum / total, "Train Acc": 100. * correct / total})

    if total == 0:
        raise ValueError("train_loader yielded no batches")
    train_loss = loss_sum / total
    train_accuracy = 100. * correct / total
    return train_loss, train_accuracy

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    
    test_loss /= len(test_loader)
    test_accuracy = 100. * correct / total
    return test_loss, test_accuracy


In [65]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cuda


In [66]:
model.to(device)

# NOTE: the recorded run took ~1.5 min per epoch, so 500 epochs is roughly 12-13 hours.
epochs = 500
# One dict of metrics per epoch, kept so learning curves can be plotted afterwards.
history = []

for epoch in range(1, epochs + 1):
    train_loss, train_accuracy = train(model, device, train_loader, optimizer, epoch)
    test_loss, test_accuracy = test(model, device, test_loader)
    history.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_accuracy": train_accuracy,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy,
    })

    print(f"Epoch {epoch}")
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    print()

Epoch 1:   0%|          | 0/938 [00:00<?, ?batch/s]

Epoch 1: 100%|██████████| 938/938 [01:32<00:00, 10.13batch/s, Train Loss=1.15, Train Acc=58.3]


Epoch 1
Train Loss: 1.1540, Train Accuracy: 58.34%
Test Loss: 0.5944, Test Accuracy: 79.26%



Epoch 2: 100%|██████████| 938/938 [01:32<00:00, 10.11batch/s, Train Loss=0.58, Train Acc=79.2] 


Epoch 2
Train Loss: 0.5804, Train Accuracy: 79.17%
Test Loss: 0.4724, Test Accuracy: 83.64%



Epoch 3: 100%|██████████| 938/938 [01:33<00:00, 10.08batch/s, Train Loss=0.501, Train Acc=82.2]


Epoch 3
Train Loss: 0.5014, Train Accuracy: 82.20%
Test Loss: 0.4461, Test Accuracy: 84.54%



Epoch 4: 100%|██████████| 938/938 [01:32<00:00, 10.10batch/s, Train Loss=0.459, Train Acc=83.5]


Epoch 4
Train Loss: 0.4592, Train Accuracy: 83.54%
Test Loss: 0.4325, Test Accuracy: 84.91%



Epoch 5: 100%|██████████| 938/938 [01:32<00:00, 10.11batch/s, Train Loss=0.434, Train Acc=84.4]


Epoch 5
Train Loss: 0.4340, Train Accuracy: 84.44%
Test Loss: 0.4225, Test Accuracy: 85.61%



Epoch 6: 100%|██████████| 938/938 [01:32<00:00, 10.09batch/s, Train Loss=0.416, Train Acc=84.9]


Epoch 6
Train Loss: 0.4161, Train Accuracy: 84.94%
Test Loss: 0.4120, Test Accuracy: 85.81%



Epoch 7: 100%|██████████| 938/938 [01:32<00:00, 10.11batch/s, Train Loss=0.399, Train Acc=85.5]


Epoch 7
Train Loss: 0.3991, Train Accuracy: 85.49%
Test Loss: 0.4110, Test Accuracy: 85.75%



Epoch 8: 100%|██████████| 938/938 [01:32<00:00, 10.11batch/s, Train Loss=0.389, Train Acc=85.8]


Epoch 8
Train Loss: 0.3894, Train Accuracy: 85.84%
Test Loss: 0.4052, Test Accuracy: 86.56%



Epoch 9: 100%|██████████| 938/938 [01:32<00:00, 10.10batch/s, Train Loss=0.38, Train Acc=86.2] 


Epoch 9
Train Loss: 0.3805, Train Accuracy: 86.18%
Test Loss: 0.4055, Test Accuracy: 86.60%



Epoch 10:  19%|█▉        | 177/938 [00:17<01:15, 10.02batch/s, Train Loss=0.371, Train Acc=86.6]