# Vision Transformer Training

In [7]:
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from utils import EarlyStopper
from models import SimplifiedVisionTransformer

# Seed PyTorch's global RNG so repeated runs start from the same random state.
torch.manual_seed(42)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Hyperparameters
# Mini-batch size handed to both the training and the test DataLoader.
batch_size = 1024
# Step size given to the Adam optimizer.
learning_rate = 1e-3
# Model width forwarded to SimplifiedVisionTransformer.
hidden_channels = 64
# Number of transformer layers forwarded to the model.
num_transformer_layers = 2
# Attention head count forwarded to the model.
# NOTE(review): presumably has to divide hidden_channels — confirm in models.py.
num_heads = 8
# NOTE(review): presumably the MLP hidden size as a multiple of
# hidden_channels — confirm in models.py.
mlp_ratio = 4

### Load Dataset

In [8]:
# Preprocessing pipeline: turn each image into a tensor, then normalize its
# single channel with mean 0.5 and std 0.5. ToTensor scales PIL image pixels
# to [0, 1], so the normalized values land in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# FashionMNIST train and test splits, stored under ./data and downloaded
# there if not already present.
train_dataset = datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.FashionMNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Only the training split is shuffled; the test split keeps a fixed order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Sanity check: fetch one training batch and show the shape of its image
# tensor (element 0 of the batch).
first_batch = next(iter(train_dataloader))
print(first_batch[0].shape)

torch.Size([1024, 1, 28, 28])


### Load Vision Transformer, Loss & Optimizer

In [6]:
# Classification loss used for training and evaluation.
criterion = nn.CrossEntropyLoss()

# Build the transformer for single-channel 28x28 inputs.
# NOTE(review): out_channels=10 is presumably one output per FashionMNIST
# class — confirm against models.SimplifiedVisionTransformer.
model = SimplifiedVisionTransformer(
    in_channels=1,
    hidden_channels=hidden_channels,
    out_channels=10,
    num_transformer_layers=num_transformer_layers,
    num_heads=num_heads,
    mlp_ratio=mlp_ratio,
    dropout=0.1,
    image_height=28,
    image_width=28,
)
# Move the model to the device selected during setup.
model = model.to(device)

# Adam over every model parameter, using the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Sanity check: push one training batch through the model and show the shape
# of its output. No gradients are needed for this throwaway forward pass, so
# autograd is disabled to avoid building a graph for the whole batch.
with torch.no_grad():
    print(model(next(iter(train_dataloader))[0].to(device)).shape)

In [None]:
# Upper bound on the number of epochs the training loop iterates over.
num_epochs = 1000

# Metric histories for the training and test splits.
# NOTE(review): presumably appended to once per epoch by the loop below —
# confirm.
training_losses, training_accuracies = [], []
testing_losses, testing_accuracies = [], []

# NOTE(review): presumably ends training after 10 epochs without
# improvement — confirm against utils.EarlyStopper.
early_stopper = EarlyStopper(patience=10)

for epoch in tqdm.tqdm(range(num_epochs), desc="Training..."):
  model.train()

  epoch_loss = 0.0
  epoch_correct = 0
  epoch_total = 0

  # Training loop
  for x, y in train_dataloader:
    optimizer.zero_grad()
    output = model(x.float())
    