In [1]:
import torch

from tqdm import tqdm
from time import time

from vit_flash import VisionTransformer
from utils import *
from dataloader import *
from decatt import DecattLoss

# Seed torch's RNG (torch.manual_seed seeds the CPU and all CUDA devices).
# The trailing ';' stops Jupyter from echoing the returned Generator object.
SEED = 0
torch.manual_seed(SEED);

<torch._C.Generator at 0x1496d4acd4f0>

In [2]:
# Indices of every visible CUDA device (0..N-1); passed to DataParallel below.
device_ids = list(range(torch.cuda.device_count()))
print(device_ids)

[0]


# CIFAR10

In [3]:
# --- Input / architecture (32x32 CIFAR images split into 4x4 patches) ---
image_size, patch_size, in_channels = 32, 4, 3
num_layers, mlp_dim = 12, 512
dropout = 0.1
num_classes = 10  # CIFAR-10

# --- Optimisation ---
batch_size = 256
lr, weight_decay = 1e-3, 1e-3
num_epochs = 50

# Train / test loaders from the project's star-imported loader helpers.
trainloader, testloader = cifar10_loaders(image_size, batch_size)

Files already downloaded and verified
Files already downloaded and verified


### FlashAttention v1 (CIFAR10)

In [4]:
# Width grows with the head count; assumes VisionTransformer splits embed_dim
# evenly across heads (64 dims each) -- TODO confirm in vit_flash.
num_heads = 3
head_dim = 64
embed_dim = head_dim * num_heads

model = VisionTransformer(
    image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes,
    dropout=dropout, useFlash=True,
)
# DataParallel requires the module's parameters on device_ids[0] before wrapping.
# The wrapper itself owns no extra parameters, so it needs no further .cuda() call.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_flash_cifar10", model, criterion, optimizer, num_epochs, trainloader, testloader)

# Recorded results from the saved run:
# Best Validation Accuracy: 79.190%
# Time to Max Val Accuracy: 18.252 mins

100%|██████████| 50/50 [19:50<00:00, 23.81s/it, Epoch=50, Train Accuracy=93.5, Training Loss=0.186, Validation Accuracy=78.8]

Best Validation Accuracy: 79.190%
Time to Max Val Accuracy: 18.252 mins





In [5]:
# Total parameter count of the current `model` (all parameters, trainable or not).
print(f"Total num params: {sum(map(torch.numel, model.parameters()))}")
# Recorded: Total num params: 4197994

Total num params: 4197994


# CIFAR100

In [7]:
# --- Input / architecture (32x32 CIFAR images split into 4x4 patches) ---
image_size, patch_size, in_channels = 32, 4, 3
num_layers, mlp_dim = 12, 512
dropout = 0.1
num_classes = 100  # CIFAR-100

# --- Optimisation ---
batch_size = 256
lr, weight_decay = 1e-3, 1e-3
num_epochs = 50

# Train / test loaders from the project's star-imported loader helpers.
trainloader, testloader = cifar100_loaders(image_size, batch_size)

### FlashAttention v1 (CIFAR100)

In [8]:
# Width grows with the head count; assumes VisionTransformer splits embed_dim
# evenly across heads (64 dims each) -- TODO confirm in vit_flash.
num_heads = 3
head_dim = 64
embed_dim = head_dim * num_heads

model = VisionTransformer(
    image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes,
    dropout=dropout, useFlash=True,
)
# DataParallel requires the module's parameters on device_ids[0] before wrapping.
# The wrapper itself owns no extra parameters, so it needs no further .cuda() call.
model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)

print(f"Total num params: {sum(p.numel() for p in model.parameters())}")
# Recorded: Total num params: 4206724

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

train("vit_flash_cifar100", model, criterion, optimizer, num_epochs, trainloader, testloader)

# Recorded results from the saved run:
# Best Validation Accuracy: 50.500%
# Time to Max Val Accuracy: 17.197 mins

100%|██████████| 50/50 [20:01<00:00, 24.03s/it, Epoch=50, Train Accuracy=85.5, Training Loss=0.458, Validation Accuracy=49.7]

Best Validation Accuracy: 50.500%
Time to Max Val Accuracy: 17.197 mins



