In [1]:
import torch

from tqdm import tqdm
from time import time

from vit import VisionTransformer
from utils import *
from dataloader import *
from decatt import DecattLoss

# Seed PyTorch's global RNG with 0. This is the only seeding call in the
# visible notebook, so the experiment cells below are not re-seeded
# individually; their random state depends on which cells ran before them.
# The call returns the torch Generator, which is what this cell displays.
torch.manual_seed(0)

<torch._C.Generator at 0x154ecd1c8890>

In [2]:
# Indices 0..N-1 of every CUDA device PyTorch can see; passed to
# torch.nn.DataParallel by the experiment cells below.
device_ids = list(range(torch.cuda.device_count()))
print(device_ids)

[0]


# CIFAR10

In [None]:
# --- CIFAR10 configuration, read by the CIFAR10 experiment cells below ---

# Data
num_classes = 10
in_channels = 3
image_size = 32
batch_size = 256

# Model
patch_size = 4
num_layers = 12
mlp_dim = 512
dropout = 0.1

# Optimisation
num_epochs = 50
lr = 1e-3
weight_decay = 1e-3

# Train/test loaders built from the image size and batch size above.
trainloader, testloader = cifar10_loaders(image_size, batch_size)

### ViT Baseline

In [None]:
# Baseline ViT on CIFAR10: 6 heads, embedding width = 64 * number of heads.
num_heads = 6
embed_dim = num_heads * 64

model = VisionTransformer(
    image_size,
    patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
)
# Move to GPU and wrap for data-parallel execution over all visible devices.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)
model = model.cuda()

param_count = sum(param.numel() for param in model.parameters())
print(f"Total num params: {param_count}")
# Total num params: 11965642

# Plain cross-entropy objective with a single AdamW optimizer over all weights.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

train(
    "vit_baseline_cifar10",
    model,
    criterion,
    optimizer,
    num_epochs,
    trainloader,
    testloader,
)

# Results noted in the original notebook:
# Best Validation Accuracy: 78.390%
# Time to Max Val Accuracy: 30.920 mins

### DeCAtt Loss

In [None]:
# DeCAtt-loss ViT on CIFAR10: 3 heads, embedding width = 64 * number of heads.
num_heads = 3
embed_dim = num_heads * 64

model = VisionTransformer(
    image_size,
    patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
)
# Move to GPU and wrap for data-parallel execution over all visible devices.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)
model = model.cuda()

param_count = sum(param.numel() for param in model.parameters())
print(f"Total num params: {param_count}")
# Total num params: 4197994

criterion = DecattLoss(num_heads)

# Two AdamW optimizers with identical hyperparameters: optimizer1 covers every
# parameter of the wrapped model, optimizer2 only the `transformer1` submodule
# (reached through `.module` because of the DataParallel wrapper).
# NOTE(review): how train() alternates between the two is defined in utils —
# confirm there.
adamw_kwargs = dict(lr=lr, weight_decay=weight_decay)
optimizer1 = torch.optim.AdamW(model.parameters(), **adamw_kwargs)
optimizer2 = torch.optim.AdamW(model.module.transformer1.parameters(), **adamw_kwargs)

train(
    "vit_decatt_cifar10",
    model,
    criterion,
    optimizer1,
    num_epochs,
    trainloader,
    testloader,
    optimizer2=optimizer2,
)

# Results noted in the original notebook:
# Best Validation Accuracy: 80.310%
# Time to Max Val Accuracy: 23.424 mins

### Mixed-Resolution Tokenizer

In [None]:
# Mixed-Resolution Tokenizer (MRT) ViT on CIFAR10.
#
# This experiment needs 256px inputs and its own loaders. They are bound to
# MRT-specific names so the notebook-wide `image_size`, `trainloader` and
# `testloader` that the other CIFAR10 cells read stay at their 32px values
# regardless of whether (or when) this cell is run.
mrt_image_size = 256
min_patch_size = 16
max_patch_size = 64
quadtree_num_patches = 64
num_heads = 3
embed_dim = 64 * num_heads

# Loaders built for the MRT image size; used only by this cell.
mrt_trainloader, mrt_testloader = cifar10_loaders(mrt_image_size, batch_size)

# `min_patch_size` is passed in the positional slot the other cells use for
# `patch_size`; the MRT-specific options are passed by keyword.
model = VisionTransformer(
    mrt_image_size, min_patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=dropout,
    useMRT=True, max_patch_size=max_patch_size, quadtree_num_patches=quadtree_num_patches
)
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids).cuda()

print(f"Total num params: {sum(p.numel() for p in model.parameters())}")
# Total num params: 4197994
# NOTE(review): this is the same count noted for the non-MRT 3-head CIFAR10
# cells; confirm it against an actual run of this cell.

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_mrt_cifar10", model, criterion, optimizer, num_epochs, mrt_trainloader, mrt_testloader)


### FlashAttention v1

In [None]:
# FlashAttention ViT on CIFAR10 (3 heads, embedding width = 64 * heads).
#
# Pin the 32px setup and rebuild the loaders here so this cell does not depend
# on whichever `image_size` / loaders an earlier cell (e.g. the 256px
# Mixed-Resolution Tokenizer experiment) left in the kernel. With
# `patch_size = 4`, silently inheriting a 256px image size would build a very
# different model from the one the results below were noted for.
image_size = 32
trainloader, testloader = cifar10_loaders(image_size, batch_size)

num_heads = 3
embed_dim = 64 * num_heads

model = VisionTransformer(
    image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=dropout, useFlash=True
)
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids).cuda()

print(f"Total num params: {sum(p.numel() for p in model.parameters())}")
# Total num params: 4197994

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_flash_cifar10", model, criterion, optimizer, num_epochs, trainloader, testloader)

# Results noted in the original notebook:
# Best Validation Accuracy: 79.190%
# Time to Max Val Accuracy: 18.252 mins

# CIFAR100

In [3]:
# --- CIFAR100 configuration, read by the CIFAR100 experiment cells below ---

# Data
num_classes = 100
in_channels = 3
image_size = 32
batch_size = 256

# Model
patch_size = 4
num_layers = 12
mlp_dim = 512
dropout = 0.1

# Optimisation
num_epochs = 50
lr = 1e-3
weight_decay = 1e-3

# Train/test loaders built from the image size and batch size above.
trainloader, testloader = cifar100_loaders(image_size, batch_size)

### ViT Baseline

In [4]:
# Baseline ViT on CIFAR100: 6 heads, embedding width = 64 * number of heads.
num_heads = 6
embed_dim = num_heads * 64

model = VisionTransformer(
    image_size,
    patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
)
# Move to GPU and wrap for data-parallel execution over all visible devices.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)
model = model.cuda()

param_count = sum(param.numel() for param in model.parameters())
print(f"Total num params: {param_count}")
# Total num params: 11983012

# Plain cross-entropy objective with a single AdamW optimizer over all weights.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

train(
    "vit_baseline_cifar100",
    model,
    criterion,
    optimizer,
    num_epochs,
    trainloader,
    testloader,
)

# Results (these match the captured output of this cell):
# Best Validation Accuracy: 45.720%
# Time to Max Val Accuracy: 30.498 mins

100%|██████████| 50/50 [31:07<00:00, 37.34s/it, Epoch=50, Train Accuracy=70.1, Training Loss=1, Validation Accuracy=45.3]   

Best Validation Accuracy: 45.720%
Time to Max Val Accuracy: 30.498 mins





### DeCAtt Loss

In [5]:
# DeCAtt-loss ViT on CIFAR100: 3 heads, embedding width = 64 * number of heads.
num_heads = 3
embed_dim = num_heads * 64

model = VisionTransformer(
    image_size,
    patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
)
# Move to GPU and wrap for data-parallel execution over all visible devices.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)
model = model.cuda()

param_count = sum(param.numel() for param in model.parameters())
print(f"Total num params: {param_count}")
# Total num params: 4206724

criterion = DecattLoss(num_heads)

# Two AdamW optimizers with identical hyperparameters: optimizer1 covers every
# parameter of the wrapped model, optimizer2 only the `transformer1` submodule
# (reached through `.module` because of the DataParallel wrapper).
# NOTE(review): how train() alternates between the two is defined in utils —
# confirm there.
adamw_kwargs = dict(lr=lr, weight_decay=weight_decay)
optimizer1 = torch.optim.AdamW(model.parameters(), **adamw_kwargs)
optimizer2 = torch.optim.AdamW(model.module.transformer1.parameters(), **adamw_kwargs)

train(
    "vit_decatt_cifar100",
    model,
    criterion,
    optimizer1,
    num_epochs,
    trainloader,
    testloader,
    optimizer2=optimizer2,
)

# Results (these match the captured output of this cell):
# Best Validation Accuracy: 51.330%
# Time to Max Val Accuracy: 25.256 mins

Total num params: 4206724


100%|██████████| 50/50 [26:19<00:00, 31.59s/it, Epoch=50, Train Accuracy=83.4, Training Loss=50.3, Validation Accuracy=50.5]

Best Validation Accuracy: 51.330%
Time to Max Val Accuracy: 25.256 mins





### Mixed-Resolution Tokenizer

In [None]:
# Mixed-Resolution Tokenizer (MRT) ViT on CIFAR100.
#
# This experiment needs 256px inputs, so it must build loaders for that size:
# the notebook-wide CIFAR100 loaders were created for the 32px configuration.
# The image size and loaders are bound to MRT-specific names so the
# notebook-wide `image_size`, `trainloader` and `testloader` that the other
# CIFAR100 cells read are left untouched.
mrt_image_size = 256
min_patch_size = 16
max_patch_size = 64
quadtree_num_patches = 64
num_heads = 3
embed_dim = 64 * num_heads

# Loaders built for the MRT image size; used only by this cell.
mrt_trainloader, mrt_testloader = cifar100_loaders(mrt_image_size, batch_size)

# `min_patch_size` is passed in the positional slot the other cells use for
# `patch_size`; the MRT-specific options are passed by keyword.
model = VisionTransformer(
    mrt_image_size, min_patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=dropout,
    useMRT=True, max_patch_size=max_patch_size, quadtree_num_patches=quadtree_num_patches
)
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids).cuda()

print(f"Total num params: {sum(p.numel() for p in model.parameters())}")
# NOTE(review): the count previously noted here (4197994) is the figure from
# the 10-class cells and was likely copied over; re-run this cell and record
# the printed value for the 100-class model.

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_mrt_cifar100", model, criterion, optimizer, num_epochs, mrt_trainloader, mrt_testloader)

### FlashAttention v1

In [None]:
# FlashAttention ViT on CIFAR100 (3 heads, embedding width = 64 * heads).
#
# Pin the 32px setup and rebuild the loaders here so this cell does not depend
# on whichever `image_size` / loaders an earlier cell (e.g. the 256px
# Mixed-Resolution Tokenizer experiment) left in the kernel. With
# `patch_size = 4`, silently inheriting a 256px image size would build a very
# different model from the one the results below were noted for.
image_size = 32
trainloader, testloader = cifar100_loaders(image_size, batch_size)

num_heads = 3
embed_dim = 64 * num_heads

model = VisionTransformer(
    image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=dropout, useFlash=True
)
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids).cuda()

print(f"Total num params: {sum(p.numel() for p in model.parameters())}")
# NOTE(review): the count previously noted here (4197994) matches the 10-class
# cells; the CIFAR100 DeCAtt cell with the same hyperparameters prints 4206724.
# Re-run this cell and record the printed value.

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_flash_cifar100", model, criterion, optimizer, num_epochs, trainloader, testloader)

# Results noted in the original notebook:
# Best Validation Accuracy: 50.500%
# Time to Max Val Accuracy: 17.197 mins