In [1]:
import os

# Run from the directory that contains `models/` (needed by `from models.vt_resnet import ...`)
# so relative paths such as `params/` and `./trained_models/` resolve there too.
# Guarded so re-running this cell does not keep climbing the directory tree.
if not os.path.isdir("models"):
    os.chdir("../")

In [40]:
import pickle
import random
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import wandb
from torch import optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from tqdm import tqdm

from models.vt_resnet import vt_resnet18
# NOTE(review): ImageNetDataset is used in the data-preparation cell but is never imported; add its import here.

In [19]:
# Reload edited project modules (e.g. models.vt_resnet) before each cell executes,
# so changes to the model code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Utility Functions

In [27]:
def random_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch. When CUDA is
    available it also seeds every GPU, switches cuDNN to deterministic kernels
    and turns off cuDNN autotuning.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is only exported to the environment; it cannot change
        hash randomization of the interpreter that is already running.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)

def store_params(content, name, directory="params"):
    """Pickle ``content`` to ``<directory>/<name>.pkl``.

    Args:
        content: Any picklable object.
        name: File stem; ``.pkl`` is appended.
        directory: Target folder, relative to the working directory by default.
            It is created if it does not exist yet.
    """
    os.makedirs(directory, exist_ok=True)
    # `with` guarantees the handle is closed even if pickling fails part-way.
    with open(os.path.join(directory, f"{name}.pkl"), "wb") as fh:
        pickle.dump(content, fh)

def load_params(name, directory="params"):
    """Load an object pickled to ``<directory>/<name>.pkl``.

    Args:
        name: File stem; ``.pkl`` is appended.
        directory: Folder to read from, relative to the working directory by default.

    Returns:
        The unpickled object.

    Warning:
        Unpickling can execute arbitrary code; only load files this project wrote.
    """
    with open(os.path.join(directory, f"{name}.pkl"), "rb") as fh:
        return pickle.load(fh)

def store_model(model, name, directory="./trained_models"):
    """Save ``model.state_dict()`` to ``<directory>/<name>.pth`` via ``torch.save``.

    Args:
        model: Module whose parameters/buffers are saved.
        name: File stem; ``.pth`` is appended.
        directory: Target folder; created if missing, since ``torch.save`` would
            otherwise fail on a fresh checkout.
    """
    os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(directory, f"{name}.pth"))
                                

# Data Preparation

In [None]:
# --- run-wide settings ---
project_name = 'vt_resnet18'           # W&B project the sweep is registered under

random_seed(8)                          # seed first so class_list below is reproducible

input_dim = 224                         # side length of the square model input
num_classes = 20                        # number of classes drawn from the 1000 ImageNet ids
class_list = random.sample(population=range(1000), k=num_classes)

In [None]:
# Per-channel mean/std commonly used with torchvision's ImageNet-pretrained weights.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training augmentation: random-resized crop straight to `input_dim` (no fixed Resize
# beforehand, so the crop is taken from the full-resolution image) plus horizontal flips.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop(input_dim),
    transforms.RandomHorizontalFlip(),
    normalize,
])

# Validation: deterministic resize to `input_dim` x `input_dim`, no augmentation.
valid_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((input_dim, input_dim)),
    normalize,
])

train_dataset = ImageNetDataset(train=True, num_classes=num_classes, transform=train_transform, class_list=class_list)
valid_dataset = ImageNetDataset(train=False, num_classes=num_classes, transform=valid_transform, class_list=class_list)

In [None]:
# Sanity check: number of samples in the train / validation datasets.
len(train_dataset), len(valid_dataset)

In [None]:
# Batch size for the exploratory loaders in this cell (train() uses config.batch_size).
batch_size = 128

total_valid_num = len(valid_dataset)
total_train_num = len(train_dataset)
valid_num = int(0.4 * total_valid_num)

# Split the official validation set 40% / 60% into validation and test parts.
# Indices are shuffled first so the split does not depend on the order in which
# the dataset stores its samples.
shuffled_valid = random.sample(range(total_valid_num), total_valid_num)
valid_mask = shuffled_valid[:valid_num]
test_mask = shuffled_valid[valid_num:]

# Evaluation order does not affect the metrics, so these loaders are not shuffled.
valid_loader = DataLoader(Subset(valid_dataset, valid_mask), batch_size=batch_size, shuffle=False)
test_loader = DataLoader(Subset(valid_dataset, test_mask), batch_size=batch_size, shuffle=False)

# Small/medium subsets for quick debugging runs; sizes are clamped so small datasets
# do not make random.sample raise. The small validation subset is drawn from the
# validation split only, so it never includes test images.
small_train_mask = random.sample(range(total_train_num), min(1200, total_train_num))
medium_train_mask = random.sample(range(total_train_num), min(5000, total_train_num))
small_valid_mask = random.sample(valid_mask, min(200, len(valid_mask)))

small_train_loader = DataLoader(Subset(train_dataset, small_train_mask), batch_size=batch_size,
                                shuffle=True, num_workers=2)
small_valid_loader = DataLoader(Subset(valid_dataset, small_valid_mask), batch_size=batch_size,
                                shuffle=True, num_workers=2)

medium_loader = DataLoader(Subset(train_dataset, medium_train_mask), batch_size=batch_size, shuffle=True)

In [None]:
# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Sweep Configuration

In [None]:
# W&B sweep definition. 'grid' enumerates every combination of the values below
# (7,776 runs); switch to 'random' or 'bayes' to sample a subset instead.
sweep_config = {
    'method': 'grid',  # grid, random
    'metric': {
        'name': 'valid_accuracy',
        'goal': 'maximize',
    },
    'parameters': {
        'batch_size': {
            'values': [256, 128],
        },
        'learning_rate': {
            'values': [1e-2, 1e-3, 1e-4],
        },
        # NOTE(review): the grid used to list 128/256/512 here, which would mean hundreds of
        # encoder layers (the training default is 2); those look like channel widths.
        # A small range of layer counts around the default is used instead — confirm intent.
        'transformer_enc_layers': {
            'values': [1, 2, 4],
        },
        'transformer_n_heads': {
            'values': [1, 2, 4],
        },
        'transformer_fc_dims': {
            'values': [512, 1024],
        },
        'tokens': {
            'values': [4, 8, 16],
        },
        'token_dims': {
            'values': [256, 512, 1024],
        },
        'optimizer': {
            'values': ['adam', 'sgd'],
        },
        'weight_decay': {
            'values': [0, 4e-3, 4e-4, 4e-5],
        },
    },
}

# Initialize Sweep 

In [None]:
# Register the sweep with W&B; the returned id is what `wandb.agent` is given below.
# NOTE(review): each execution of this cell presumably creates a new sweep — re-use an
# existing sweep_id if the intent is to resume.
sweep_id = wandb.sweep(sweep_config, project=project_name)

# Training

In [41]:
def evaluate(model: nn.Module, data_loader: Any, device: torch.device, comment: str = "") -> float:
    """Compute the classification accuracy (percent) of ``model`` over ``data_loader``.

    The model is put in eval mode and run without gradients. The mean per-sample
    NLL loss is logged to W&B under ``"<comment>_loss"`` (``"valid_loss"`` when
    ``comment`` is empty).

    Args:
        model: Classifier producing ``(batch, num_classes)`` scores (required by ``nll_loss``).
        data_loader: Iterable of ``(data, target)`` batches exposing a sized ``.dataset``.
        device: Device each batch is moved to; the model is expected to already be there.
        comment: Optional split label such as ``"train"``; used as the log-key prefix
            and the progress-bar description.

    Returns:
        Accuracy as a Python float in ``[0, 100]``.

    Raises:
        ValueError: If ``data_loader.dataset`` is empty.
    """
    total_samples = len(data_loader.dataset)
    if total_samples == 0:
        raise ValueError("evaluate() received an empty dataset")

    model.eval()
    correct_samples = 0
    total_loss = 0.0

    with torch.no_grad():
        for data, target in tqdm(data_loader, desc=comment or "eval"):
            data = data.to(device)
            target = target.to(device)

            output = F.log_softmax(model(data), dim=1)
            # Summed (not averaged) per batch so dividing by total_samples gives a true
            # per-sample mean even when the last batch is smaller.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            # .item() keeps the counter a plain int instead of a device tensor.
            correct_samples += output.argmax(dim=1).eq(target).sum().item()

    avg_loss = total_loss / total_samples
    wandb.log({f"{comment or 'valid'}_loss": avg_loss})

    return 100.0 * correct_samples / total_samples

In [None]:
def _make_optimizer(model: nn.Module, config) -> optim.Optimizer:
    """Build Adam when ``config.optimizer == 'adam'``, otherwise SGD with momentum 0.9."""
    if config.optimizer == 'adam':
        return optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    return optim.SGD(model.parameters(), lr=config.learning_rate,
                     weight_decay=config.weight_decay, momentum=0.9)


def _train_one_epoch(model: nn.Module, loader, optimizer: optim.Optimizer, device: torch.device) -> float:
    """Run one optimisation pass over ``loader``; log each batch loss and return the mean batch loss."""
    model.train()  # evaluate() switches the model to eval mode, so restore training mode here
    running_loss = 0.0
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, label)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        wandb.log({'batch_loss': batch_loss})
    return running_loss / max(len(loader), 1)


def train():
    """Run one W&B sweep trial: build a VT-ResNet18 from the run config, train it and log metrics.

    Intended as the function handed to ``wandb.agent``; sweep values override the
    ``config_defaults`` below. Relies on notebook globals ``input_dim``,
    ``num_classes``, ``train_dataset``, ``valid_loader`` and ``device``.

    Logged per batch: ``batch_loss``. Logged per epoch: ``epoch``, ``loss`` (mean
    batch loss), ``valid_accuracy`` and ``train_accuracy`` (percent), plus the
    losses that ``evaluate`` logs.
    """
    config_defaults = {
        'epochs': 10,
        'batch_size': 128,
        'learning_rate': 0.001,
        'transformer_enc_layers': 2,
        'transformer_n_heads': 1,
        'transformer_fc_dims': 512,
        'transformer_dropout': 0.5,
        'tokens': 8,
        'token_dims': 512,
        # NOTE(review): vt_resnet18 takes `vt_channels` but no value was configured;
        # 512 is a placeholder — confirm against models/vt_resnet.py.
        'vt_channels': 512,
        'optimizer': 'adam',
        'weight_decay': 0,
        'input_dim': input_dim,
        'num_classes': num_classes,
    }

    wandb.init(config=config_defaults)
    config = wandb.config

    # Every architecture hyperparameter comes from the run config so the sweep actually varies it.
    model = vt_resnet18(
        pretrained=True,
        tokens=config.tokens,
        token_channels=config.token_dims,
        input_dim=config.input_dim,
        vt_channels=config.vt_channels,
        transformer_enc_layers=config.transformer_enc_layers,
        transformer_heads=config.transformer_n_heads,
        transformer_fc_dim=config.transformer_fc_dims,
        transformer_dropout=config.transformer_dropout,
        num_classes=config.num_classes,
    ).to(device)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    optimizer = _make_optimizer(model, config)

    for epoch in range(config.epochs):
        mean_batch_loss = _train_one_epoch(model, train_loader, optimizer, device)

        # float() keeps the logged metrics plain numbers even if evaluate returns a tensor.
        valid_accuracy = float(evaluate(model, valid_loader, device))
        # NOTE: this re-scans the full (augmented) training set every epoch, which is expensive.
        train_accuracy = float(evaluate(model, train_loader, device, comment="train"))

        wandb.log({
            'epoch': epoch,
            'loss': mean_batch_loss,
            'valid_accuracy': valid_accuracy,
            'train_accuracy': train_accuracy,
        })
        

# Run Sweep

In [None]:
# Start an in-process agent that calls `train` once for each configuration the sweep hands out.
wandb.agent(sweep_id, function=train)