In [1]:
import torch

from tqdm import tqdm
from time import time

from vit import VisionTransformer
from utils import *
from dataloader import *
from decatt import DecattLoss

# Seed torch's global RNG (documented to seed all devices, CPU and CUDA) so model
# initialisation is repeatable; presumably also data shuffling, if the loaders
# use torch's default generator. torch.manual_seed returns a Generator whose repr
# would otherwise be echoed as cell output, hence the trailing `;`.
torch.manual_seed(0);

<torch._C.Generator at 0x14c625da0870>

In [2]:
# Indices of every visible CUDA device, for DataParallel; [] on a CPU-only machine.
device_ids = list(range(torch.cuda.device_count()))
print(device_ids)

[0]


# CIFAR10

In [3]:
# CIFAR-10 configuration shared by the baseline and DeCAtt runs below.
# Input / patching: 32x32 RGB images split into 4x4 patches.
image_size, patch_size, in_channels = 32, 4, 3
num_classes = 10
# Transformer body (width is set per experiment via num_heads / embed_dim).
num_layers, mlp_dim, dropout = 12, 512, 0.1
# Optimisation.
batch_size = 256
num_epochs = 50
lr, weight_decay = 1e-3, 1e-3

trainloader, testloader = cifar10_loaders(image_size, batch_size)

Files already downloaded and verified
Files already downloaded and verified


### ViT Baseline

In [4]:
# ViT baseline: 6 attention heads of width 64 -> embed_dim 384.
num_heads = 6
embed_dim = 64 * num_heads

# Re-seed so this model's initialisation does not depend on which cells ran
# before it in the current kernel (out-of-order runs otherwise change the init).
torch.manual_seed(0)

model = VisionTransformer(
    image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout
)
# model.cuda() already moves the parameters to the current CUDA device; calling
# .cuda() again on the DataParallel wrapper was a no-op.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_baseline_cifar10", model, criterion, optimizer, num_epochs, trainloader, testloader)

# Recorded result (saved output of this cell):
# Best Validation Accuracy: 78.390%
# Time to Max Val Accuracy: 30.920 mins

100%|██████████| 50/50 [30:55<00:00, 37.11s/it, Epoch=50, Train Accuracy=88.4, Training Loss=0.329, Validation Accuracy=78.4]

Best Validation Accuracy: 78.390%
Time to Max Val Accuracy: 30.920 mins





In [None]:
# Number of scalar parameters in `model` (the 6-head CIFAR-10 baseline when run in order).
print("Total num params: {}".format(sum(param.numel() for param in model.parameters())))
# Recorded value: 11965642

### DeCAtt Loss

In [5]:
# DeCAtt run: 3 attention heads of width 64 -> embed_dim 192, half the baseline's
# 384, so this model is not capacity-matched to the 6-head baseline above.
num_heads = 3
embed_dim = 64 * num_heads

# Re-seed so initialisation does not depend on the RNG state left behind by
# earlier cells (e.g. the baseline's full training run).
torch.manual_seed(0)

model = VisionTransformer(
    image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout
)
# model.cuda() already places the parameters on the current CUDA device; no
# second .cuda() is needed on the DataParallel wrapper.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)

criterion = DecattLoss(num_heads)
# optimizer1 covers every parameter; optimizer2 covers only transformer1's, so
# those parameters are registered in BOTH optimizers, each with its own AdamW
# state. If train() steps both, transformer1 is updated (and weight-decayed)
# twice per step.
# NOTE(review): presumably intentional (optimizer2 driving the DeCAtt term only) -- confirm in train().
optimizer1 = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.AdamW(model.module.transformer1.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_decatt_cifar10", model, criterion, optimizer1, num_epochs, trainloader, testloader, optimizer2=optimizer2)

# Recorded result (saved output of this cell; obtained without the re-seed above,
# so a re-run may differ slightly):
# Best Validation Accuracy: 80.310%
# Time to Max Val Accuracy: 23.424 mins

100%|██████████| 50/50 [26:00<00:00, 31.22s/it, Epoch=50, Train Accuracy=92.4, Training Loss=50, Validation Accuracy=80]    

Best Validation Accuracy: 80.310%
Time to Max Val Accuracy: 23.424 mins





In [5]:
# Parameter count of `model` -- the 3-head DeCAtt CIFAR-10 model when run in order.
print("Total num params:", sum(p.numel() for p in model.parameters()))
# Recorded value: 4197994

Total num params: 4197994


# CIFAR100

In [3]:
# CIFAR-100 configuration: identical to the CIFAR-10 setup except for the
# 100-way classifier and the data loaders.
image_size = 32
patch_size = 4
in_channels = 3
num_classes = 100

mlp_dim, num_layers, dropout = 512, 12, 0.1

batch_size, num_epochs = 256, 50
lr = weight_decay = 1e-3

trainloader, testloader = cifar100_loaders(image_size, batch_size)

### ViT Baseline

In [4]:
# ViT baseline on CIFAR-100: 6 attention heads of width 64 -> embed_dim 384.
num_heads = 6
embed_dim = 64 * num_heads

# Re-seed so this model's initialisation does not depend on which cells ran
# before it in the current kernel.
torch.manual_seed(0)

model = VisionTransformer(
    image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout
)
# model.cuda() already moves the parameters to the current CUDA device; calling
# .cuda() again on the DataParallel wrapper was a no-op.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_baseline_cifar100", model, criterion, optimizer, num_epochs, trainloader, testloader)

# Result of an earlier completed run; this cell's saved output was interrupted
# (KeyboardInterrupt) before the first epoch finished, so it does not show these numbers:
# Best Validation Accuracy: 40.380%
# Time to Max Val Accuracy: 24.442 mins

  0%|          | 0/50 [00:08<?, ?it/s]


KeyboardInterrupt: 

In [6]:
# Parameter count of `model` -- the 6-head CIFAR-100 baseline when run in order.
print("Total num params: %d" % sum(weight.numel() for weight in model.parameters()))
# Recorded value: 11983012

Total num params: 11983012


### DeCAtt Loss

In [4]:
# DeCAtt run on CIFAR-100: 3 attention heads of width 64 -> embed_dim 192, half
# the baseline's 384, so it is not capacity-matched to the 6-head baseline.
num_heads = 3
embed_dim = 64 * num_heads

# Re-seed so initialisation does not depend on the RNG state left behind by
# earlier cells in the current kernel.
torch.manual_seed(0)

model = VisionTransformer(
    image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout
)
# model.cuda() already places the parameters on the current CUDA device; no
# second .cuda() is needed on the DataParallel wrapper.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)

criterion = DecattLoss(num_heads)
# optimizer1 covers every parameter; optimizer2 covers only transformer1's, so
# those parameters are registered in BOTH optimizers, each with its own AdamW
# state. If train() steps both, transformer1 is updated (and weight-decayed)
# twice per step.
# NOTE(review): presumably intentional (optimizer2 driving the DeCAtt term only) -- confirm in train().
optimizer1 = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.AdamW(model.module.transformer1.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_decatt_cifar100", model, criterion, optimizer1, num_epochs, trainloader, testloader, optimizer2=optimizer2)

# Result of an earlier completed run; this cell's saved output was interrupted
# (KeyboardInterrupt) at epoch 20 of 50, so it does not show these numbers:
# Best Validation Accuracy: 40.590%
# Time to Max Val Accuracy: 18.759 mins

 40%|████      | 20/50 [07:50<11:45, 23.53s/it, Epoch=20, Train Accuracy=27.3, Training Loss=52.7, Validation Accuracy=28.6]


KeyboardInterrupt: 

In [5]:
# Parameter count of `model` -- the 3-head DeCAtt CIFAR-100 model when run in order.
print("Total num params: " + str(sum(t.numel() for t in model.parameters())))
# Recorded value: 4206724

Total num params: 4206724
