<a href="https://colab.research.google.com/github/Kushal-Nandha/CIFAR-10/blob/master/CIFAR_10_ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
!pip install vit_pytorch
from vit_pytorch import ViT
import torch.optim as optim

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting vit_pytorch
  Downloading vit_pytorch-1.2.1-py3-none-any.whl (87 kB)
     87.3/87.3 kB 7.6 MB/s eta 0:00:00
Collecting einops>=0.6.1 (from vit_pytorch)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     42.2/42.2 kB 6.0 MB/s eta 0:00:00
Installing collected packages: einops, vit_pytorch
Successfully installed einops-0.6.1 vit_pytorch-1.2.1


In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Hyperparameters and run configuration: every value a reader might tune lives in this cell.

# Reproducibility: seed PyTorch's global RNG so repeated runs start from the same random
# state. (Some GPU kernels can still be nondeterministic.)
seed = 42
torch.manual_seed(seed)

# Data
batch_size = 4     # samples per mini-batch
img_size = 40      # side length, in pixels, of the images fed to the model
patch_size = 8     # side length of one ViT patch; img_size must be divisible by it

# Model (vit_pytorch.ViT)
dim = 64           # token embedding width
depth = 6          # number of transformer blocks
heads = 8          # attention heads per block
mlp_dim = 128      # hidden width of each block's feed-forward layer
dropout = 0.1      # dropout rate

# Optimisation
lr = 1e-3          # optimizer learning rate
epochs = 5         # full passes over the training set
# NOTE(review): not used by the Adam optimizer this notebook configures; presumably left
# over from an SGD setup. Kept in case other cells read it.
momentum = 0.9

In [4]:
# Training-time preprocessing and augmentation.
# Every size is derived from img_size so that changing the config cell keeps the resize
# and crop in agreement.
transform_train = transforms.Compose([
    # Scale the image so its shorter side equals img_size.
    transforms.Resize(img_size),
    # Pad 4 px on every border, then take a random img_size x img_size crop.
    transforms.RandomCrop(img_size, padding=4),
    transforms.RandomHorizontalFlip(),
    # PIL image -> float tensor with values in [0, 1].
    transforms.ToTensor(),
    # Per channel (x - 0.5) / 0.5, which maps [0, 1] to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [5]:
# Evaluation-time preprocessing: deterministic, with no augmentation.
# The resize target comes from img_size so it tracks the config cell instead of a literal.
transform_test = transforms.Compose([
    # Scale the image so its shorter side equals img_size.
    transforms.Resize(img_size),
    # PIL image -> float tensor with values in [0, 1].
    transforms.ToTensor(),
    # Per channel (x - 0.5) / 0.5, which maps [0, 1] to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


In [6]:
# CIFAR-10 splits, downloaded into ./data when not already present.
# The training split gets the augmenting pipeline, the test split the deterministic one.
trainset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
testset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:06<00:00, 27793274.91it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [7]:
# Mini-batch iterators over the two splits. Only the training loader shuffles; both use
# two worker processes and pinned host memory.
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
testloader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [8]:
# Vision Transformer classifier for the 10 CIFAR-10 classes.
# Every architectural value is read from the config cell rather than repeated as a literal,
# so editing a hyperparameter there actually changes the model.
model = ViT(
    image_size=img_size,
    patch_size=patch_size,
    num_classes=10,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    dropout=dropout,
    # The embedding dropout reuses the same rate (0.1 in the current config).
    emb_dropout=dropout,
    channels=3,  # RGB input
).to(device)

In [9]:
# Multi-class cross-entropy on the model's raw logits.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters; the learning rate comes from the config cell's `lr`
# instead of a duplicated literal.
optimizer = optim.Adam(model.parameters(), lr=lr)

In [10]:
# Training: one full pass over trainloader per epoch, then report that epoch's
# per-sample mean loss and accuracy.
for epoch in range(epochs):
    model.train()  # training mode, so dropout is active
    train_loss = 0.0
    train_correct = 0
    for data, target in trainloader:
        # The loaders pin host memory, so a non-blocking copy can overlap with GPU work.
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; weight it by the batch size so the
        # epoch figure is a true per-sample average even if the last batch is smaller.
        train_loss += loss.item() * data.size(0)
        # Index of the largest logit is the predicted class. argmax yields an integer
        # tensor outside the autograd graph, so the legacy `.data` access is not needed.
        predicted = output.argmax(dim=1)
        train_correct += (predicted == target).sum().item()

    train_loss /= len(trainloader.dataset)
    train_acc = 100. * train_correct / len(trainloader.dataset)
    print(f'Epoch {epoch + 1}/{epochs} Training Loss: {train_loss:.6f}, Training Accuracy: {train_acc:.2f}%')
    
    

Epoch 1/5 Training Loss: 1.939782, Training Accuracy: 26.97%
Epoch 2/5 Training Loss: 1.842453, Training Accuracy: 31.46%
Epoch 3/5 Training Loss: 1.822674, Training Accuracy: 32.68%
Epoch 4/5 Training Loss: 1.819042, Training Accuracy: 32.35%
Epoch 5/5 Training Loss: 1.811189, Training Accuracy: 32.85%
