In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from time import time

from vit import VisionTransformer
from utils import save_stats
from dataloader import *
from decatt import DecattLoss

torch.manual_seed(0)

<torch._C.Generator at 0x14f6c4010870>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
def evaluate(model):
    model.eval()
    val_acc, total = 0, 0
    with torch.no_grad():
        for images, labels in testloader:

            outputs, _ = model(images.to(device))
            val_acc += (outputs.argmax(dim=1) == labels.to(device)).sum().item()
            total += labels.shape[0]
            
    val_acc = val_acc / total
    return val_acc


def train_baseline(path):
    torch.manual_seed(0)

    model = VisionTransformer(
        image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout
    ).to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Train the model
    best_val_acc = 0
    epochs_no_improve, max_patience = 0, 20
    early_stop = False
    step = -1

    pbar = tqdm(range(num_epochs))
    
    start = time()
    for epoch in pbar:

        epoch_acc, epoch_loss, total = 0.0, 0.0, 0
        model.train()
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            step += 1

            outputs, _ = model(inputs.to(device))
            loss = criterion(outputs, labels.to(device))

            loss.backward()
            optimizer.step()
#             scheduler.step()

            epoch_acc += (outputs.argmax(dim=1) == labels.to(device)).sum().item()
            epoch_loss += loss.item()
            total += labels.shape[0]
        
        epoch_loss = epoch_loss / len(trainloader)
        epoch_acc = epoch_acc / total
        val_acc = evaluate(model)
        
        save_stats(epoch, {
            f"stats/{path}_valacc.txt": val_acc,
            f"stats/{path}_trainacc.txt": epoch_acc,
            f"stats/{path}_trainloss.txt": epoch_loss,
            f"stats/{path}_traintime.txt": time() - start
        })
        
        pbar.set_postfix({"Epoch": epoch+1, "Train Accuracy": epoch_acc*100, "Training Loss": epoch_loss, "Validation Accuracy": val_acc*100})

        # Save the best model

        if val_acc > best_val_acc:
            epochs_no_improve = 0
            best_val_acc = val_acc
            tta = time() - start
            
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer,
#                 'scheduler' : scheduler,
            },  f'saved_models/{path}.pth')

        else:
            epochs_no_improve += 1

        if epoch > 100 and epochs_no_improve >= max_patience:
            print('Early stopping!')
            early_stop = True
            break
    
    print(f"Best Validation Accuracy: {best_val_acc:.3f}%")
    print(f"Time to Max Val Accuracy: {tta / 60:.3f} mins")

In [4]:
def train_decatt(path):
    torch.manual_seed(0)

    model = VisionTransformer(
        image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout
    ).to(device)
    
    criterion = DecattLoss(num_heads)
    optimizer1 = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer2 = torch.optim.Adam(model.transformer1.parameters(), lr=lr, weight_decay=weight_decay)
#     scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(trainloader), epochs=num_epochs)

    # Train the model
    best_val_acc = 0
    epochs_no_improve, max_patience = 0, 20
    early_stop = False
    step = -1

    pbar = tqdm(range(num_epochs))
    
    start = time()
    for epoch in pbar:

        epoch_acc, epoch_loss, total = 0.0, 0.0, 0
        model.train()
        for inputs, labels in trainloader:
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            step += 1

            outputs, attentions = model(inputs.to(device))
            ce_loss, decatt_loss = criterion(outputs, attentions, labels.to(device))

            ce_loss.backward(retain_graph=True)
            optimizer1.step()
            decatt_loss.backward()
            optimizer2.step()
#             scheduler.step()

            epoch_acc += (outputs.argmax(dim=1) == labels.to(device)).sum().item()
            epoch_loss += ce_loss.item() + decatt_loss.item()
            total += labels.shape[0]
        
        epoch_loss = epoch_loss / len(trainloader)
        epoch_acc = epoch_acc / total
        val_acc = evaluate(model)
        
        save_stats(epoch, {
            f"stats/{path}_valacc.txt": val_acc,
            f"stats/{path}_trainacc.txt": epoch_acc,
            f"stats/{path}_trainloss.txt": epoch_loss,
            f"stats/{path}_traintime.txt": time() - start
        })
        
        pbar.set_postfix({"Epoch": epoch+1, "Train Accuracy": epoch_acc*100, "Training Loss": epoch_loss, "Validation Accuracy": val_acc*100})

        # Save the best model

        if val_acc > best_val_acc:
            epochs_no_improve = 0
            best_val_acc = val_acc
            tta = time() - start
            
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer1': optimizer1,
                'optimizer2': optimizer2,
#                 'scheduler' : scheduler,
            },  f'saved_models/{path}.pth')

        else:
            epochs_no_improve += 1

        if epoch > 100 and epochs_no_improve >= max_patience:
            print('Early stopping!')
            early_stop = True
            break
    
    print(f"Best Validation Accuracy: {best_val_acc:.3f}%")
    print(f"Time to Max Val Accuracy: {tta / 60:.3f} mins")

# CIFAR10

In [5]:
image_size = 32
patch_size = 4
in_channels = 3
num_heads = 3
embed_dim = 64 * num_heads
mlp_dim = 512
num_classes = 10
num_layers = 12
dropout = 0.1
batch_size = 256

lr = 1e-3
weight_decay = 0.0001
num_epochs = 30

# image_size = 32
# patch_size = 4
# in_channels = 8
# num_heads = 3
# embed_dim = 512
# mlp_dim = 1024
# num_classes = 10
# num_layers = 4
# dropout = 0.1
# batch_size = 256

# lr = 3e-3
# weight_decay = 0.0001
# num_epochs = 30

In [6]:
trainloader, testloader = cifar10_loaders(image_size, batch_size)

Files already downloaded and verified
Files already downloaded and verified


### ViT Baseline

In [None]:
train_baseline("vit_baseline_cifar10")

  0%|          | 0/30 [00:00<?, ?it/s]

### DeCatt Loss

In [None]:
train_decatt("vit_decatt_cifar10")

 30%|███       | 9/30 [06:11<14:08, 40.38s/it, Epoch=9, Train Accuracy=57.6, Training Loss=51, Validation Accuracy=61.4]  