In [8]:
# Optional one-time setup, left disabled: clone the mixed-resolution-vit
# repository, move its contents into the working directory, then remove the
# leftover clone folder. Uncomment to run.
# !git clone https://github.com/TomerRonen34/mixed-resolution-vit.git
# %mv mixed-resolution-vit/* ./
# %rm -rf mixed-resolution-vit

In [1]:
import torch

from tqdm import tqdm
from time import time

from vit import VisionTransformer
from utils import *
from dataloader import *
from decatt import DecattLoss

# Fix PyTorch's RNG seed so random initialisation starts from a known state.
torch.manual_seed(0)

<torch._C.Generator at 0x14f5a4fa94f0>

In [2]:
# Index of every CUDA device PyTorch can see; this list is handed to
# DataParallel when the models are built in later cells.
device_ids = list(range(torch.cuda.device_count()))
print(device_ids)

[0]


# CIFAR10

In [3]:
# --- CIFAR-10 experiment configuration ---
# Input / patching settings.
image_size, patch_size, in_channels = 32, 4, 3

# Model settings reused by the CIFAR-10 cells below.
num_layers, mlp_dim = 12, 512
num_classes = 10
dropout = 0.1

# Optimisation settings (learning rate and weight decay share the same value).
batch_size = 256
num_epochs = 50
lr = weight_decay = 1e-3

# Build the train / test loaders for the configured image size and batch size.
trainloader, testloader = cifar10_loaders(image_size, batch_size)

Files already downloaded and verified
Files already downloaded and verified


### DeCAtt & FlashAttention

In [4]:
# DeCAtt loss + FlashAttention on CIFAR-10.
# Relies on image_size, patch_size, the model/optimiser settings and the data
# loaders currently bound by the CIFAR-10 configuration cell.
num_heads = 3
embed_dim = num_heads * 64  # embedding width scales with the head count

model = VisionTransformer(
    image_size,
    patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
    useFlash=True,
)
# Move to GPU, then wrap in DataParallel over the detected devices.
model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()

print(f"Total num params: {sum(param.numel() for param in model.parameters())}")
# Recorded output for this run: Total num params: 4197994

criterion = DecattLoss(num_heads)
# optimizer1 covers every model parameter; optimizer2 covers only the
# `transformer1` sub-module, reached through DataParallel's `.module`.
optimizer1 = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=weight_decay
)
optimizer2 = torch.optim.AdamW(
    model.module.transformer1.parameters(), lr=lr, weight_decay=weight_decay
)

train(
    "vit_decatt_flash_cifar10",
    model,
    criterion,
    optimizer1,
    num_epochs,
    trainloader,
    testloader,
    optimizer2=optimizer2,
)

# Best Validation Accuracy: 79.620%
# Time to Max Val Accuracy: 18.886 mins

Total num params: 4197994


100%|██████████| 50/50 [20:30<00:00, 24.60s/it, Epoch=50, Train Accuracy=92.8, Training Loss=50, Validation Accuracy=78.9]  

Best Validation Accuracy: 79.620%
Time to Max Val Accuracy: 18.886 mins





### DeCAtt & MixedResolutionTokenizer

In [None]:
# DeCAtt loss + mixed-resolution tokenizer (MRT) on CIFAR-10.
# NOTE: this cell rebinds image_size (to 256) and the trainloader/testloader
# globals; re-run the CIFAR-10 configuration cell before any cell that expects
# the 32-pixel settings.
image_size = 256
min_patch_size, max_patch_size = 16, 64
quadtree_num_patches = 64
num_heads = 3
embed_dim = num_heads * 64  # embedding width scales with the head count

trainloader, testloader = cifar10_loaders(image_size, batch_size)

model = VisionTransformer(
    image_size,
    min_patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
    useMRT=True,
    max_patch_size=max_patch_size,
    quadtree_num_patches=quadtree_num_patches,
)
# Move to GPU, then wrap in DataParallel over the detected devices.
model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()

print(f"Total num params: {sum(param.numel() for param in model.parameters())}")
# Earlier inline note: Total num params: 4197994
# NOTE(review): likely stale - the Flash + MRT cell, which uses the same
# tokenizer settings, recorded 4714506. Confirm on the next full run.

criterion = DecattLoss(num_heads)
# optimizer1 covers every model parameter; optimizer2 covers only the
# `transformer1` sub-module, reached through DataParallel's `.module`.
optimizer1 = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=weight_decay
)
optimizer2 = torch.optim.AdamW(
    model.module.transformer1.parameters(), lr=lr, weight_decay=weight_decay
)

train(
    "vit_decatt_mrt_cifar10",
    model,
    criterion,
    optimizer1,
    num_epochs,
    trainloader,
    testloader,
    optimizer2=optimizer2,
)

# Best Validation Accuracy: 76.700%
# Time to Max Val Accuracy: 103.270 mins

 22%|██▏       | 11/50 [24:14<1:25:53, 132.15s/it, Epoch=11, Train Accuracy=57.8, Training Loss=51, Validation Accuracy=58.2]  

### FlashAttention & MixedResolutionTokenizer

In [None]:
# FlashAttention + mixed-resolution tokenizer (MRT) on CIFAR-10, trained with
# plain cross-entropy and a single optimizer (no DeCAtt loss).
# NOTE: this cell rebinds image_size (to 256) and the trainloader/testloader
# globals.
image_size = 256
min_patch_size, max_patch_size = 16, 64
quadtree_num_patches = 64
num_heads = 3
embed_dim = num_heads * 64  # embedding width scales with the head count

trainloader, testloader = cifar10_loaders(image_size, batch_size)

model = VisionTransformer(
    image_size,
    min_patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
    useMRT=True,
    max_patch_size=max_patch_size,
    quadtree_num_patches=quadtree_num_patches,
    useFlash=True,
)
# Move to GPU, then wrap in DataParallel over the detected devices.
model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()

print(f"Total num params: {sum(param.numel() for param in model.parameters())}")
# Recorded output for this run: Total num params: 4714506

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=weight_decay
)

train(
    "vit_flash_mrt_cifar10",
    model,
    criterion,
    optimizer,
    num_epochs,
    trainloader,
    testloader,
)

Files already downloaded and verified
Files already downloaded and verified
Total num params: 4714506


 70%|███████   | 35/50 [1:46:20<44:59, 179.96s/it, Epoch=35, Train Accuracy=83.2, Training Loss=0.474, Validation Accuracy=75.1]  

### DeCAtt, FlashAttention & MixedResolutionTokenizer

In [None]:
# DeCAtt loss + FlashAttention + mixed-resolution tokenizer (MRT) on CIFAR-10.
# NOTE: this cell rebinds image_size (to 256) and the trainloader/testloader
# globals.
image_size = 256
min_patch_size, max_patch_size = 16, 64
quadtree_num_patches = 64
num_heads = 3
embed_dim = num_heads * 64  # embedding width scales with the head count

trainloader, testloader = cifar10_loaders(image_size, batch_size)

model = VisionTransformer(
    image_size,
    min_patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
    useMRT=True,
    max_patch_size=max_patch_size,
    quadtree_num_patches=quadtree_num_patches,
    useFlash=True,
)
# Move to GPU, then wrap in DataParallel over the detected devices.
model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()

print(f"Total num params: {sum(param.numel() for param in model.parameters())}")
# Earlier inline note: Total num params: 4197994
# NOTE(review): likely stale - the Flash + MRT cell, built with the same
# constructor arguments, recorded 4714506. Confirm on the next full run.

criterion = DecattLoss(num_heads)
# optimizer1 covers every model parameter; optimizer2 covers only the
# `transformer1` sub-module, reached through DataParallel's `.module`.
optimizer1 = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=weight_decay
)
optimizer2 = torch.optim.AdamW(
    model.module.transformer1.parameters(), lr=lr, weight_decay=weight_decay
)

train(
    "vit_all_cifar10",
    model,
    criterion,
    optimizer1,
    num_epochs,
    trainloader,
    testloader,
    optimizer2=optimizer2,
)

# CIFAR100

In [5]:
# --- CIFAR-100 experiment configuration ---
# Input / patching settings.
image_size, patch_size, in_channels = 32, 4, 3

# Model settings reused by the CIFAR-100 cells below.
num_layers, mlp_dim = 12, 512
num_classes = 100
dropout = 0.1

# Optimisation settings (learning rate and weight decay share the same value).
batch_size = 256
num_epochs = 50
lr = weight_decay = 1e-3

# Build the train / test loaders for the configured image size and batch size.
trainloader, testloader = cifar100_loaders(image_size, batch_size)

### DeCAtt & FlashAttention

In [6]:
# DeCAtt loss + FlashAttention on CIFAR-100.
# Relies on image_size, patch_size, the model/optimiser settings and the data
# loaders currently bound by the CIFAR-100 configuration cell.
num_heads = 3
embed_dim = num_heads * 64  # embedding width scales with the head count

model = VisionTransformer(
    image_size,
    patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
    useFlash=True,
)
# Move to GPU, then wrap in DataParallel over the detected devices.
model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()

print(f"Total num params: {sum(param.numel() for param in model.parameters())}")
# Recorded output for this run: Total num params: 4206724

criterion = DecattLoss(num_heads)
# optimizer1 covers every model parameter; optimizer2 covers only the
# `transformer1` sub-module, reached through DataParallel's `.module`.
optimizer1 = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=weight_decay
)
optimizer2 = torch.optim.AdamW(
    model.module.transformer1.parameters(), lr=lr, weight_decay=weight_decay
)

train(
    "vit_decatt_flash_cifar100",
    model,
    criterion,
    optimizer1,
    num_epochs,
    trainloader,
    testloader,
    optimizer2=optimizer2,
)

# Best Validation Accuracy: 49.980%
# Time to Max Val Accuracy: 19.812 mins

Total num params: 4206724


100%|██████████| 50/50 [20:12<00:00, 24.26s/it, Epoch=50, Train Accuracy=82, Training Loss=50.4, Validation Accuracy=48.3]  

Best Validation Accuracy: 49.980%
Time to Max Val Accuracy: 19.812 mins





### DeCAtt & MixedResolutionTokenizer

In [None]:
# DeCAtt loss + mixed-resolution tokenizer (MRT) on CIFAR-100.
# NOTE: this cell rebinds image_size (to 256) and the trainloader/testloader
# globals; re-run the CIFAR-100 configuration cell before any cell that expects
# the 32-pixel settings.
image_size = 256
min_patch_size, max_patch_size = 16, 64
quadtree_num_patches = 64
num_heads = 3
embed_dim = num_heads * 64  # embedding width scales with the head count

# Fix: this is a CIFAR-100 experiment (num_classes = 100, run name ends in
# "cifar100"), so the loaders must come from cifar100_loaders. The previous
# cifar10_loaders call paired a 100-way head with the CIFAR-10 data.
trainloader, testloader = cifar100_loaders(image_size, batch_size)

model = VisionTransformer(
    image_size,
    min_patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
    useMRT=True,
    max_patch_size=max_patch_size,
    quadtree_num_patches=quadtree_num_patches,
)
# Move to GPU, then wrap in DataParallel over the detected devices.
model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()

print(f"Total num params: {sum(param.numel() for param in model.parameters())}")
# NOTE(review): no recorded output for this configuration yet. The earlier
# inline note (4197994) equals the CIFAR-10 non-MRT count and is probably a
# copy-paste leftover; update once this cell has been run.

criterion = DecattLoss(num_heads)
# optimizer1 covers every model parameter; optimizer2 covers only the
# `transformer1` sub-module, reached through DataParallel's `.module`.
optimizer1 = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=weight_decay
)
optimizer2 = torch.optim.AdamW(
    model.module.transformer1.parameters(), lr=lr, weight_decay=weight_decay
)

train(
    "vit_decatt_mrt_cifar100",
    model,
    criterion,
    optimizer1,
    num_epochs,
    trainloader,
    testloader,
    optimizer2=optimizer2,
)

### DeCAtt, FlashAttention & MixedResolutionTokenizer

In [None]:
# DeCAtt loss + FlashAttention + mixed-resolution tokenizer (MRT) on CIFAR-100.
# NOTE: this cell rebinds image_size (to 256) and the trainloader/testloader
# globals.
image_size = 256
min_patch_size, max_patch_size = 16, 64
quadtree_num_patches = 64
num_heads = 3
embed_dim = num_heads * 64  # embedding width scales with the head count

# Fix: this is a CIFAR-100 experiment (num_classes = 100, run name ends in
# "cifar100"), so the loaders must come from cifar100_loaders. The previous
# cifar10_loaders call paired a 100-way head with the CIFAR-10 data.
trainloader, testloader = cifar100_loaders(image_size, batch_size)

model = VisionTransformer(
    image_size,
    min_patch_size,
    in_channels,
    embed_dim,
    num_heads,
    mlp_dim,
    num_layers,
    num_classes,
    dropout=dropout,
    useMRT=True,
    max_patch_size=max_patch_size,
    quadtree_num_patches=quadtree_num_patches,
    useFlash=True,
)
# Move to GPU, then wrap in DataParallel over the detected devices.
model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()

print(f"Total num params: {sum(param.numel() for param in model.parameters())}")
# NOTE(review): no recorded output for this configuration yet. The earlier
# inline note (4197994) equals the CIFAR-10 non-MRT count and is probably a
# copy-paste leftover; update once this cell has been run.

criterion = DecattLoss(num_heads)
# optimizer1 covers every model parameter; optimizer2 covers only the
# `transformer1` sub-module, reached through DataParallel's `.module`.
optimizer1 = torch.optim.AdamW(
    model.parameters(), lr=lr, weight_decay=weight_decay
)
optimizer2 = torch.optim.AdamW(
    model.module.transformer1.parameters(), lr=lr, weight_decay=weight_decay
)

train(
    "vit_all_cifar100",
    model,
    criterion,
    optimizer1,
    num_epochs,
    trainloader,
    testloader,
    optimizer2=optimizer2,
)