In [1]:
from vit import Vit
import torch
import torchvision
from torch.utils.data import DataLoader, Subset
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch import nn
from torchvision import transforms
from tqdm import tqdm
# Use the GPU when PyTorch reports one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # last expression of the cell, so the chosen device is displayed

'cuda'

In [2]:
# Per-image preprocessing pipeline: resize to 224x224, then convert to a tensor.
manual_transforms = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

In [3]:
# Build the FashionMNIST train split first, then the test split. Both are
# stored under ./data (downloaded if missing) and share the same transform.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train_split,
        download=True,
        transform=manual_transforms,
    )
    for is_train_split in (True, False)
)
batch_size = 16

# A dedicated, seeded generator makes the random subsets (and the per-epoch
# shuffle order of the training loader) identical after Restart & Run All.
SUBSET_SEED = 0
subset_rng = torch.Generator().manual_seed(SUBSET_SEED)

# Train on at most 10,000 randomly chosen images. Slicing a shorter
# permutation simply returns all of it, so smaller datasets are handled too.
subset_indices = torch.randperm(len(training_data), generator=subset_rng)[:10000]
train_subset = Subset(training_data, subset_indices)

# Same selection rule for evaluation (at most 10,000 test images).
test_subset_indices = torch.randperm(len(test_data), generator=subset_rng)[:10000]
test_subset = Subset(test_data, test_subset_indices)

# Reshuffle the training subset every epoch so batches are not replayed in the
# same order each time; evaluation order does not matter, so it stays fixed.
train_dataloader = DataLoader(
    train_subset,
    batch_size=batch_size,
    shuffle=True,
    generator=subset_rng,
)
test_dataloader = DataLoader(test_subset, batch_size=batch_size)

In [6]:
# Instantiate the `Vit` model: 224x224 single-channel inputs, 16-pixel patches,
# one transformer layer, and one output per class of the training dataset.
model = Vit(
    img_size=224,
    in_channels=1,
    patch_size=16,
    num_transformer_layers=1,
    embedding_dim=384,
    mlp_size=1024,
    num_heads=12,
    attn_dropout=0,
    mlp_dropout=0.1,
    embedding_dropout=0.1,
    num_classes=len(training_data.classes),
)

In [7]:
# Total number of scalar values across all of the model's parameters.
sum(weight.numel() for weight in model.parameters())

1560074

In [8]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct, test_loss

def train(dataloader,model,loss_fn,optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Uses the notebook-level ``device`` to place each batch.

    Args:
        dataloader: yields ``(X, y)`` batches; ``len(dataloader.dataset)`` must
            be defined (it is used only for the progress printout).
        model: module to optimise; it is switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer stepped and zeroed once per batch.

    Returns:
        float: the loss averaged over every batch of the epoch, or ``0.0`` if
        the dataloader yielded no batches.
    """
    size = len(dataloader.dataset)
    running_loss = 0.0
    num_batches = 0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: compute gradients, update the weights, then clear
        # the gradients so they do not accumulate into the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Accumulate every batch so the returned value is a true epoch average
        # (consistent with how test() averages over batches).
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        # Progress report every 100 batches.
        if batch % 100 == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {batch_loss:>7f}  [{current:>5d}/{size:>5d}]")

    return running_loss / num_batches if num_batches else 0.0
            

In [9]:
loss_fn = nn.CrossEntropyLoss()
epochs = 500

# Move the model to the target device before constructing the optimizer, as
# the torch.optim documentation advises, so the optimizer is built from the
# parameters as they will be trained.
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    # One pass over the training loader, then an evaluation on the test loader.
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("DONE!")

Epoch 1
-------------------------------
loss: 2.404041  [   16/10000]
loss: 2.429628  [  176/10000]
loss: 2.424626  [  336/10000]
loss: 2.199780  [  496/10000]
loss: 2.302136  [  656/10000]
loss: 2.120612  [  816/10000]
loss: 2.226849  [  976/10000]
loss: 2.191674  [ 1136/10000]
loss: 1.842687  [ 1296/10000]
loss: 1.873313  [ 1456/10000]
loss: 1.905916  [ 1616/10000]
loss: 1.553781  [ 1776/10000]
loss: 1.662016  [ 1936/10000]
loss: 1.164260  [ 2096/10000]
loss: 1.142552  [ 2256/10000]
loss: 1.208619  [ 2416/10000]
loss: 1.094178  [ 2576/10000]
loss: 1.331684  [ 2736/10000]
loss: 1.173684  [ 2896/10000]
loss: 1.398838  [ 3056/10000]
loss: 0.976620  [ 3216/10000]
loss: 0.876001  [ 3376/10000]
loss: 1.063350  [ 3536/10000]
loss: 1.174658  [ 3696/10000]
loss: 0.856165  [ 3856/10000]
loss: 1.534769  [ 4016/10000]
loss: 1.151906  [ 4176/10000]
loss: 0.729802  [ 4336/10000]
loss: 1.082898  [ 4496/10000]
loss: 0.870005  [ 4656/10000]
loss: 0.771620  [ 4816/10000]
loss: 1.047502  [ 4976/10000]
