In [1]:
import os
os.chdir("../")

In [2]:
import wandb
import torch
import random
import time
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from torch import optim
from tqdm import tqdm
from models.vt_resnet20 import VTResNet20
from torch.utils.data import DataLoader, Subset
from typing import Any
from torchsummary import summary

from torchvision import datasets

In [3]:
%load_ext autoreload
%autoreload 2

# Utility Functions

In [4]:
def random_seed(seed):
    """Set seed"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)

def store_params(content, name):
    f = open(f'params/{name}.pkl','wb')
    pickle.dump(content, f)
    f.close()

def load_params(name):
    fl = open(f'params/{name}.pkl', "rb")
    loaded = pickle.load(fl)
    return loaded

def store_model(model, name):
    torch.save(model.state_dict(), f'./trained_models/{name}.pth')
                                

# Data Preparation

In [5]:
project_name = 'vt_resnet'
cores = 12
random_seed(8)
input_dim = 32
batch_size = 128
num_classes = 10

In [6]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((input_dim, input_dim)),
    #transforms.RandomResizedCrop(input_dim),
    transforms.RandomHorizontalFlip(),
    normalize
])

valid_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((input_dim, input_dim)),
    normalize
])

train_dataset = datasets.CIFAR10(root='/datasets/cifar', train = True, transform = train_transform, download = False)
valid_dataset = datasets.CIFAR10(root='/datasets/cifar', train = False, transform = valid_transform, download = False)

In [7]:
len(train_dataset), len(valid_dataset)

(50000, 10000)

In [8]:
total_valid_num = len(valid_dataset)
total_train_num = len(train_dataset)
valid_num = int(0.5 * total_valid_num)

valid_mask = list(range(valid_num))
test_mask = list(range(valid_num, total_valid_num))

valid_loader = DataLoader(Subset(valid_dataset, valid_mask), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(Subset(valid_dataset, test_mask), batch_size=batch_size, shuffle=True)

small_train_mask = random.sample(range(total_train_num), 1200)
medium_train_mask = random.sample(range(total_train_num), 5000)
small_valid_mask = random.sample(range(total_valid_num), 200)

small_train_loader = DataLoader(Subset(train_dataset, list(small_train_mask)), batch_size=batch_size, 
                                shuffle=True, num_workers=2)
small_valid_loader = DataLoader(Subset(valid_dataset, list(small_valid_mask)), batch_size=batch_size, 
                                shuffle=True, num_workers=2)

medium_loader = DataLoader(Subset(train_dataset, list(medium_train_mask)), batch_size=batch_size, shuffle=True)

# Training

In [29]:
hyperparameters = {
    'epochs': 20,
    'vt_num_layers':3,
    'resnet_pretrained': True,
    'freeze_resnet': False,
    'batch_size': 128,
    'learning_rate': 0.0005,
    'vt_channels': 64,
    'transformer_enc_layers': 2,
    'transformer_n_heads': 1,
    'transformer_fc_dims': 128,
    'transformer_dropout': 0.5,
    'tokens': 4,
    'token_dims': 128,
    'optimizer': 'adam',
    'weight_decay': 8e-5,
    'input_dim': input_dim,
    'num_classes': num_classes,
}

In [25]:
def evaluate(model: nn.Module, data_loader: Any, device: torch.device, comment: str = ""):
    
    model.eval()
    
    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0
    loss_history = []

    with torch.no_grad():
        for data, target in tqdm(data_loader):
            data = data.to(device)
            target = target.to(device)
            
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)

            total_loss += loss.item()
            correct_samples += pred.eq(target).sum()
    
    avg_loss = total_loss / total_samples
    wandb.log({'valid_loss': avg_loss})
    
    accuracy = 100.0 * correct_samples / total_samples
    return accuracy

In [26]:
def train(model, optimizer, epochs, data_loader, test_loader, device):
    
    wandb.watch(model, log="all", log_freq=10)
    
    full_start = time.time()
    for i in range(epochs):
        
        model.train()
        model.to(device)
        print(f"Starting Epoch {i}")
        
        total_loss = 0
        epoch_time = time.time()
        num_batches = 0
        for j, (data, label) in enumerate(data_loader):
            
            data, label = data.to(device), label.to(device)
            optimizer.zero_grad()
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, label)
            loss.backward()
            
            total_loss += loss.item()
            optimizer.step()
            num_batches += 1
            
            wandb.log({'batch_loss': loss.item()})
        print(f"Finished Epoch {i}")
        
        valid_accuracy = evaluate(model, test_loader, device)
        train_accuracy = evaluate(model, data_loader, device)
        
        print(f"Validation Accuracy: ", valid_accuracy)
        print(f"Training Accuracy: ", train_accuracy)
        
        wandb.log({
            'loss': total_loss / num_batches,
            'valid_accuracy': valid_accuracy,
            'train_accuracy': train_accuracy,
            'epoch_time_minutes': (time.time() - epoch_time) / 60
        })
    wandb.log({'full_run_time_minutes': (time.time() - full_start) / 60})
        

In [27]:
def train_model(hyperparameters):
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    
    with wandb.init(project=project_name, config=hyperparameters):
       
        config = wandb.config
        
        model = VTResNet20(
            vt_num_layers=config.vt_num_layers,
            tokens=config.tokens,
            token_channels=config.token_dims,
            input_dim=config.input_dim,
            vt_channels=config.vt_channels,
            transformer_enc_layers=config.transformer_enc_layers,
            transformer_heads=config.transformer_n_heads,
            transformer_fc_dim=config.transformer_fc_dims,
            transformer_dropout=config.transformer_dropout,
            num_classes=config.num_classes,
            resnet_pretrained=config.resnet_pretrained,
            freeze_resnet=config.freeze_resnet,
        )

        optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
           
        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=cores)    
        
        train(model, optimizer, config.epochs, train_loader, valid_loader, device)

        test_accuracy = evaluate(model, test_loader, device)
        
        sample = train_dataset[0][0].reshape(1, 3, config.input_dim, config.input_dim)
        
        wandb.log({'test_accuracy': test_accuracy})
    
    return model, test_accuracy

In [30]:
model, test_accuracy = train_model(hyperparameters)

cuda:0


Starting Epoch 0


  8%|▊         | 3/40 [00:00<00:01, 20.54it/s]

Finished Epoch 0


100%|██████████| 40/40 [00:01<00:00, 23.48it/s]
100%|██████████| 391/391 [00:09<00:00, 40.95it/s]

Validation Accuracy:  tensor(75.2800, device='cuda:0')
Training Accuracy:  tensor(77.6000, device='cuda:0')
Starting Epoch 1



  8%|▊         | 3/40 [00:00<00:01, 26.64it/s]

Finished Epoch 1


100%|██████████| 40/40 [00:01<00:00, 23.28it/s]
100%|██████████| 391/391 [00:10<00:00, 38.79it/s]

Validation Accuracy:  tensor(81.0200, device='cuda:0')
Training Accuracy:  tensor(83.4260, device='cuda:0')
Starting Epoch 2



  8%|▊         | 3/40 [00:00<00:01, 20.34it/s]

Finished Epoch 2


100%|██████████| 40/40 [00:01<00:00, 21.87it/s]
100%|██████████| 391/391 [00:09<00:00, 40.87it/s]

Validation Accuracy:  tensor(83.1400, device='cuda:0')
Training Accuracy:  tensor(86.1200, device='cuda:0')
Starting Epoch 3



  5%|▌         | 2/40 [00:00<00:02, 18.72it/s]

Finished Epoch 3


100%|██████████| 40/40 [00:01<00:00, 21.41it/s]
100%|██████████| 391/391 [00:09<00:00, 39.23it/s]


Validation Accuracy:  tensor(83.9200, device='cuda:0')
Training Accuracy:  tensor(87.4980, device='cuda:0')
Starting Epoch 4


  2%|▎         | 1/40 [00:00<00:07,  5.32it/s]

Finished Epoch 4


100%|██████████| 40/40 [00:02<00:00, 18.39it/s]
100%|██████████| 391/391 [00:10<00:00, 38.74it/s]


Validation Accuracy:  tensor(84.3200, device='cuda:0')
Training Accuracy:  tensor(88.3500, device='cuda:0')
Starting Epoch 5


  5%|▌         | 2/40 [00:00<00:01, 19.72it/s]

Finished Epoch 5


100%|██████████| 40/40 [00:01<00:00, 20.27it/s]
100%|██████████| 391/391 [00:09<00:00, 41.28it/s]

Validation Accuracy:  tensor(83.7600, device='cuda:0')
Training Accuracy:  tensor(87.3240, device='cuda:0')
Starting Epoch 6



  5%|▌         | 2/40 [00:00<00:01, 19.99it/s]

Finished Epoch 6


100%|██████████| 40/40 [00:01<00:00, 23.50it/s]
100%|██████████| 391/391 [00:09<00:00, 39.52it/s]

Validation Accuracy:  tensor(84.5800, device='cuda:0')
Training Accuracy:  tensor(89.3300, device='cuda:0')
Starting Epoch 7



  5%|▌         | 2/40 [00:00<00:01, 19.89it/s]

Finished Epoch 7


100%|██████████| 40/40 [00:01<00:00, 22.58it/s]
100%|██████████| 391/391 [00:09<00:00, 40.48it/s]

Validation Accuracy:  tensor(85.7000, device='cuda:0')
Training Accuracy:  tensor(90.4920, device='cuda:0')
Starting Epoch 8



  5%|▌         | 2/40 [00:00<00:02, 18.66it/s]

Finished Epoch 8


100%|██████████| 40/40 [00:01<00:00, 21.53it/s]
100%|██████████| 391/391 [00:09<00:00, 40.92it/s]

Validation Accuracy:  tensor(85.6400, device='cuda:0')
Training Accuracy:  tensor(90.9740, device='cuda:0')
Starting Epoch 9



  2%|▎         | 1/40 [00:00<00:04,  8.52it/s]

Finished Epoch 9


100%|██████████| 40/40 [00:01<00:00, 23.24it/s]
100%|██████████| 391/391 [00:09<00:00, 39.11it/s]

Validation Accuracy:  tensor(86.3000, device='cuda:0')
Training Accuracy:  tensor(91.0240, device='cuda:0')
Starting Epoch 10



  8%|▊         | 3/40 [00:00<00:01, 20.59it/s]

Finished Epoch 10


100%|██████████| 40/40 [00:01<00:00, 22.79it/s]
100%|██████████| 391/391 [00:09<00:00, 39.83it/s]

Validation Accuracy:  tensor(85.4200, device='cuda:0')
Training Accuracy:  tensor(91.4660, device='cuda:0')
Starting Epoch 11



  5%|▌         | 2/40 [00:00<00:02, 18.38it/s]

Finished Epoch 11


100%|██████████| 40/40 [00:02<00:00, 19.31it/s]
100%|██████████| 391/391 [00:10<00:00, 38.54it/s]

Validation Accuracy:  tensor(85.8200, device='cuda:0')
Training Accuracy:  tensor(91.2540, device='cuda:0')
Starting Epoch 12



  5%|▌         | 2/40 [00:00<00:02, 15.64it/s]

Finished Epoch 12


100%|██████████| 40/40 [00:01<00:00, 20.29it/s]
100%|██████████| 391/391 [00:09<00:00, 40.12it/s]

Validation Accuracy:  tensor(85.2000, device='cuda:0')
Training Accuracy:  tensor(91.2400, device='cuda:0')
Starting Epoch 13



  5%|▌         | 2/40 [00:00<00:02, 14.44it/s]

Finished Epoch 13


100%|██████████| 40/40 [00:01<00:00, 20.67it/s]
100%|██████████| 391/391 [00:09<00:00, 40.10it/s]

Validation Accuracy:  tensor(86.4400, device='cuda:0')
Training Accuracy:  tensor(92.9420, device='cuda:0')
Starting Epoch 14



  2%|▎         | 1/40 [00:00<00:07,  5.50it/s]

Finished Epoch 14


100%|██████████| 40/40 [00:01<00:00, 21.03it/s]
100%|██████████| 391/391 [00:09<00:00, 39.49it/s]

Validation Accuracy:  tensor(86.3400, device='cuda:0')
Training Accuracy:  tensor(92.7040, device='cuda:0')
Starting Epoch 15



  8%|▊         | 3/40 [00:00<00:01, 20.45it/s]

Finished Epoch 15


100%|██████████| 40/40 [00:01<00:00, 22.08it/s]
100%|██████████| 391/391 [00:10<00:00, 38.52it/s]

Validation Accuracy:  tensor(86.4400, device='cuda:0')
Training Accuracy:  tensor(92.4460, device='cuda:0')
Starting Epoch 16



  5%|▌         | 2/40 [00:00<00:01, 19.11it/s]

Finished Epoch 16


100%|██████████| 40/40 [00:02<00:00, 18.98it/s]
100%|██████████| 391/391 [00:10<00:00, 38.87it/s]


Validation Accuracy:  tensor(86.6400, device='cuda:0')
Training Accuracy:  tensor(93.4540, device='cuda:0')
Starting Epoch 17


  5%|▌         | 2/40 [00:00<00:02, 18.93it/s]

Finished Epoch 17


100%|██████████| 40/40 [00:01<00:00, 21.13it/s]
100%|██████████| 391/391 [00:09<00:00, 40.53it/s]

Validation Accuracy:  tensor(87., device='cuda:0')
Training Accuracy:  tensor(93.7740, device='cuda:0')
Starting Epoch 18



  5%|▌         | 2/40 [00:00<00:01, 19.22it/s]

Finished Epoch 18


100%|██████████| 40/40 [00:01<00:00, 21.38it/s]
100%|██████████| 391/391 [00:09<00:00, 39.29it/s]

Validation Accuracy:  tensor(86.5400, device='cuda:0')
Training Accuracy:  tensor(93.4360, device='cuda:0')
Starting Epoch 19



  2%|▎         | 1/40 [00:00<00:06,  6.50it/s]

Finished Epoch 19


100%|██████████| 40/40 [00:01<00:00, 21.03it/s]
100%|██████████| 391/391 [00:10<00:00, 39.01it/s]
  8%|▊         | 3/40 [00:00<00:01, 20.41it/s]

Validation Accuracy:  tensor(86.5600, device='cuda:0')
Training Accuracy:  tensor(93.6440, device='cuda:0')


100%|██████████| 40/40 [00:01<00:00, 22.26it/s]


VBox(children=(Label(value=' 0.02MB of 0.02MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch_loss,0.25741
_runtime,794.0
_timestamp,1615885155.0
_step,7882.0
valid_loss,0.45123
loss,0.20105
valid_accuracy,86.56
train_accuracy,93.644
epoch_time_minutes,0.66363
full_run_time_minutes,13.0903


0,1
batch_loss,█▄▄▄▃▂▃▃▂▃▃▃▂▂▂▂▂▃▂▂▂▁▂▂▂▁▂▂▁▂▂▂▁▂▂▂▁▂▁▂
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
valid_loss,█▇▆▅▅▄▅▃▅▃▅▃▄▃▄▂▄▂▄▂▄▂▄▂▅▂▄▁▄▁▄▁▄▁▄▁▄▁▄▄
loss,█▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
valid_accuracy,▁▄▆▆▆▆▇▇▇█▇▇▇███████
train_accuracy,▁▄▅▅▆▅▆▇▇▇▇▇▇██▇████
epoch_time_minutes,▄▅▅▄▇▃▆▁▃▃█▃▅▃▃█▄▅▂▇
full_run_time_minutes,▁


In [14]:
test_accuracy

tensor(84.1800, device='cuda:0')

In [15]:
# summary(model.to(torch.device('cuda:0')), (3, 224, 224))

RuntimeError: shape '[2, 64, 16, 16]' is invalid for input of size 1605632