In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from time import time

from vit import VisionTransformer
from utils import save_stats
from dataloader import *
from decatt import DecattLoss

# Seed PyTorch's global RNG with a fixed value. Because this call is the last
# expression in the cell, the Generator it returns is echoed as the cell output.
# NOTE(review): only torch's RNG is seeded here; if the project's dataloader uses
# `random` or NumPy for shuffling/augmentation, those would need seeding too — confirm.
torch.manual_seed(0)

<torch._C.Generator at 0x15467be8c850>

In [2]:
# Pick the compute device once for the whole notebook: the GPU when CUDA is
# usable, otherwise the CPU. `device` is read as a global by evaluate()/train().
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
def evaluate(model):
    """Return the fraction of correctly classified samples over ``testloader``.

    The model is switched to eval mode and run without gradient tracking.
    It is expected to return a 2-tuple whose first element holds per-class
    scores with classes along dim 1; the second element is ignored.

    Reads the module-level globals ``testloader`` and ``device``.

    Args:
        model: callable network returning ``(outputs, <ignored>)``.

    Returns:
        float: correct predictions divided by the number of samples seen
        (a fraction, not a percentage).
    """
    model.eval()
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            scores, _ = model(batch_images.to(device))
            predictions = scores.argmax(dim=1)
            targets = batch_labels.to(device)

            num_correct += (predictions == targets).sum().item()
            num_seen += batch_labels.shape[0]

    return num_correct / num_seen


def train(criterion, path, max_patience=20, early_stop_after=100):
    """Train a fresh VisionTransformer and checkpoint the best-validation epoch.

    A new model is built from the module-level hyperparameters (``image_size``,
    ``patch_size``, ``in_channels``, ``embed_dim``, ``num_heads``, ``mlp_dim``,
    ``num_layers``, ``num_classes``, ``dropout``) and optimised with Adam using
    the globals ``lr`` and ``weight_decay`` for up to ``num_epochs`` epochs over
    ``trainloader``. After every epoch the model is scored with ``evaluate`` and
    three stats entries are written via ``save_stats``. Whenever validation
    accuracy strictly improves, a checkpoint is written to
    ``saved_models/{path}.pth``.

    Args:
        criterion: callable invoked as ``criterion(outputs, attentions, labels)``
            that returns a scalar loss tensor supporting ``.backward()``.
        path: run name used as the stem of the stats files and the checkpoint.
        max_patience: number of consecutive non-improving epochs tolerated
            before early stopping may trigger.
        early_stop_after: early stopping is only considered once the zero-based
            epoch index is strictly greater than this value.

    Returns:
        None. Results are reported through the progress bar, the stats files,
        the checkpoint, and two summary lines printed at the end.
    """
    import os  # local import so this cell does not depend on an edit to the import cell

    model = VisionTransformer(
        image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
#     scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(trainloader), epochs=num_epochs)

    # torch.save does not create missing parent directories, so make sure the
    # checkpoint directory exists before the (long) training loop starts.
    # save_stats is defined outside this notebook; its directory is created
    # defensively as well, which is harmless if save_stats already does so.
    checkpoint_path = f'saved_models/{path}.pth'
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    os.makedirs(os.path.dirname(f"stats/{path}"), exist_ok=True)

    # Train the model
    best_val_acc = 0
    tta = None  # seconds until the best validation accuracy; None until a first improvement
    epochs_no_improve = 0
    step = -1  # global optimisation-step counter (0-based after the first batch)

    pbar = tqdm(range(num_epochs))

    start = time()
    for epoch in pbar:

        epoch_acc, epoch_loss, total = 0.0, 0.0, 0
        model.train()  # evaluate() leaves the model in eval mode, so re-enable training each epoch
        for inputs, labels in trainloader:
            labels = labels.to(device)  # move once; reused for both the loss and the accuracy
            optimizer.zero_grad()
            step += 1

            outputs, attentions = model(inputs.to(device))
            loss = criterion(outputs, attentions, labels)

            loss.backward()
            optimizer.step()
#             scheduler.step()

            epoch_acc += (outputs.argmax(dim=1) == labels).sum().item()
            epoch_loss += loss.item()
            total += labels.shape[0]

        # Loss is averaged per batch, accuracy per sample.
        epoch_loss = epoch_loss / len(trainloader)
        epoch_acc = epoch_acc / total
        val_acc = evaluate(model)

        save_stats(epoch, val_acc, f"stats/{path}_valacc.txt")
        save_stats(epoch, epoch_acc, f"stats/{path}_trainacc.txt")
        # The loss is keyed by the global step count rather than the epoch index.
        save_stats(step, epoch_loss, f"stats/{path}_trainloss.txt")

        pbar.set_postfix({"Epoch": epoch+1, "Train Accuracy": epoch_acc*100, "Training Loss": epoch_loss, "Validation Accuracy": val_acc*100})

        # Save the best model
        if val_acc > best_val_acc:
            epochs_no_improve = 0
            best_val_acc = val_acc
            tta = time() - start

            # NOTE(review): this pickles the whole optimizer object rather than
            # optimizer.state_dict(). Left unchanged so existing checkpoint
            # loaders keep working; consider switching once they are updated.
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer,
#                 'scheduler' : scheduler,
            }, checkpoint_path)

        else:
            epochs_no_improve += 1

        if epoch > early_stop_after and epochs_no_improve >= max_patience:
            print('Early stopping!')
            break

    # Accuracies are tracked as fractions; scale to match the '%' suffix and
    # the percentages shown in the progress bar.
    print(f"Best Validation Accuracy: {best_val_acc * 100:.3f}%")
    if tta is None:
        # No epoch ever improved on the initial 0 (or no epoch ran at all).
        print("Time to Max Val Accuracy: n/a (validation accuracy never improved)")
    else:
        print(f"Time to Max Val Accuracy: {tta / 60:.3f} mins")

# CIFAR10

In [5]:
# --- Model architecture (passed positionally to VisionTransformer in train()) ---
image_size = 32
patch_size = 4
in_channels = 3
embed_dim = 512
num_heads = 8
mlp_dim = 1024
num_layers = 4
num_classes = 10
dropout = 0.1

# --- Data loading (passed to cifar10_loaders together with image_size) ---
batch_size = 256

# --- Optimisation (Adam's lr / weight_decay and the epoch budget in train()) ---
lr = 3e-3
weight_decay = 1e-4
num_epochs = 150

# Alternative configuration, currently disabled; differs from the active one in
# embed_dim, num_heads, num_layers, batch_size, lr and weight_decay:
#   image_size = 32, patch_size = 4, in_channels = 3
#   embed_dim = 256, num_heads = 4, mlp_dim = 1024
#   num_classes = 10, num_layers = 6, dropout = 0.1
#   batch_size = 128
#   lr = 0.0006, weight_decay = 0.1, num_epochs = 150

In [6]:
# Build the two data loaders that train() and evaluate() read as module-level
# globals (`trainloader` for optimisation, `testloader` for validation accuracy).
# cifar10_loaders is not defined in this notebook; presumably it is provided by
# `from dataloader import *` — confirm.
trainloader, testloader = cifar10_loaders(image_size, batch_size)

Files already downloaded and verified
Files already downloaded and verified


### ViT Baseline

In [None]:
path = "vit_baseline_cifar10"

# Hold the cross-entropy module under a name that no other cell rebinds.
# `loss_func` resolves its loss at call time, and the DeCatt cell reassigns the
# global `criterion`; referring to `criterion` here would make a later re-run
# of the baseline silently call the DeCatt loss with the wrong arguments.
baseline_criterion = nn.CrossEntropyLoss()
criterion = baseline_criterion  # kept so code that reads `criterion` after this cell still works


def loss_func(outputs, attentions, labels):
    """Adapt plain cross-entropy to the three-argument signature train() uses.

    Args:
        outputs: class scores produced by the model.
        attentions: attention values returned by the model; ignored by the baseline.
        labels: ground-truth class indices.

    Returns:
        The cross-entropy loss between ``outputs`` and ``labels``.
    """
    return baseline_criterion(outputs, labels)


train(loss_func, path)

  2%|▏         | 3/150 [01:57<1:31:52, 37.50s/it, Epoch=3, Train Accuracy=22.2, Training Loss=2.01, Validation Accuracy=26.6]

### DeCatt Loss

In [None]:
# Run name: train() uses it as the stem of the stats files and the checkpoint.
path = "vit_decatt_cifar10"

# DecattLoss comes from the project-local `decatt` module and is built from the
# head count. train() calls the criterion as criterion(outputs, attentions, labels),
# so the instance is passed directly without an adapter.
# Note: this rebinds the module-level name `criterion`.
criterion = DecattLoss(num_heads)
train(criterion, path)