In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from conv_norm import PreConv

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F

import numpy as np
from timeit import default_timer as timer
from utils import ImportanceSampler

USE_GPU = True
# Single floating-point precision used for model inputs across this notebook.
dtype = torch.float32

# Run on CUDA only when it is both requested (USE_GPU) and actually available;
# otherwise fall back to the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

print('using device:', device)

In [None]:
from utils import get_accuracy, load_dataset
from models import get_model
def check_accuracy(loader, model):
    """Evaluate ``model`` on ``loader`` via ``get_accuracy`` using this notebook's ``device`` and ``dtype``.

    Callers unpack the result as ``acc, loss = check_accuracy(loader, model)``.
    """
    return get_accuracy(loader, model, device, dtype)

In [None]:
def train_model(model_name, dataset_name, model_params=None, hyperparams=None):
    """Train ``model_name`` on ``dataset_name`` with Adam and print per-epoch metrics.

    Metrics are printed as ``Plot: <Train|Val>, <epoch>, <loss>, <acc>, <seconds>``
    (epoch 0 is the untrained model). A final ``Plot: Test, <loss>, <acc>, <seconds>``
    line reports the test split.

    Args:
        model_name: Model identifier forwarded to ``get_model``.
        dataset_name: Dataset identifier forwarded to ``load_dataset``.
        model_params: Optional dict forwarded to ``get_model``. Keys read here:
            ``importance_sampling`` (default False) enables loss-based resampling
            of the training set from the second epoch on; ``gamma`` (default 0.9)
            is the decay factor of the per-sample loss moving average used as
            sampling weight.
        hyperparams: Optional dict with ``lr`` (1e-3), ``num_epochs`` (10),
            ``weight_decay`` (0), ``train_ratio`` (0.8), ``batch_size`` (64)
            and ``seed`` (0).

    Returns:
        The trained model (left in eval mode).
    """
    # None defaults avoid a shared mutable default dict (model_params is handed to get_model).
    model_params = {} if model_params is None else model_params
    hyperparams = {} if hyperparams is None else hyperparams

    learning_rate = hyperparams.get('lr', 1e-3)
    num_epochs = hyperparams.get('num_epochs', 10)
    weight_decay = hyperparams.get('weight_decay', 0)
    train_ratio = hyperparams.get('train_ratio', 0.8)
    batch_size = hyperparams.get('batch_size', 64)
    seed = hyperparams.get('seed', 0)
    imp_sampling = model_params.get('importance_sampling', False)
    gamma = model_params.get('gamma', 0.9)

    torch.manual_seed(seed)
    np.random.seed(seed)

    loader_train, loader_val, loader_test, num_train, num_channels = load_dataset(dataset_name, train_ratio, batch_size)
    model = get_model(model_name, model_params, learning_rate, loader_train, num_channels, device)

    print("Model architecture:")
    print(model)

    print(f'INFO: Training {model_name} on {dataset_name} with lr {learning_rate}, num_epochs={num_epochs}, weight_decay={weight_decay}')

    # Honour the configured weight decay (it is reported in the INFO line above).
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # One sampling weight per training example; all start equal (uniform sampling).
    weight = torch.ones(num_train)

    # Baseline metrics of the untrained model, measured in eval mode.
    model.eval()
    t_acc, t_loss = check_accuracy(loader_train, model)
    val_acc, val_loss = check_accuracy(loader_val, model)

    start = timer()
    c_time = timer() - start

    print(f'Plot: Train, {0}, {t_loss:.3f}, {t_acc:.2f}, {c_time:.1f}')
    print(f'Plot: Val, {0}, {val_loss:.3f}, {val_acc:.2f}, {c_time:.1f}')

    for e in range(num_epochs):
        model.train()
        # The first epoch is always uniform; importance sampling (if enabled)
        # kicks in from the second epoch on.
        do_uniform = e == 0 or not imp_sampling
        loader_train_sampled = loader_train
        if not do_uniform:
            train_sampler = ImportanceSampler(num_train, weight, batch_size)
            # NOTE(review): this rebuilds every loader each epoch. It assumes
            # load_dataset's train/val split is identical across calls — confirm,
            # otherwise validation examples could leak into training.
            loader_train_sampled, _, _, _, _ = load_dataset(dataset_name, train_ratio, batch_size, train_sampler)

        # Release cached GPU memory once per epoch rather than on every step.
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        for batch in loader_train_sampled:
            x = batch[0].to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = batch[1].to(device=device, dtype=torch.long)

            scores = model(x)
            # Per-sample losses let the importance weights differ within a batch;
            # their mean equals F.cross_entropy's default 'mean' reduction.
            per_sample_loss = F.cross_entropy(scores, y, reduction='none')
            loss = per_sample_loss.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if not do_uniform:
                # Assumes the importance-sampled loader yields sample indices as the
                # third element (as the original code did) — TODO confirm in utils.
                idx = batch[2]
                sample_loss = per_sample_loss.detach().cpu()
                weight[idx] = gamma * weight[idx] + (1 - gamma) * sample_loss

        # Evaluate both splits in eval mode so BatchNorm/dropout behave as at inference.
        model.eval()
        t_acc, t_loss = check_accuracy(loader_train, model)
        val_acc, val_loss = check_accuracy(loader_val, model)
        c_time = timer() - start

        print(f'Plot: Train, {e+1}, {t_loss:.3f}, {t_acc:.2f}, {c_time:.1f}')
        print(f'Plot: Val, {e+1}, {val_loss:.3f}, {val_acc:.2f}, {c_time:.1f}')

    test_acc, test_loss = check_accuracy(loader_test, model)
    # Report the test metrics (previously the validation numbers were printed here).
    print(f'Plot: Test, {test_loss:.3f}, {test_acc:.2f}, {c_time:.1f}')

    return model

In [None]:
# GradInit settings; passed to get_model nested under model_params["gradinit"].
gradinit_params = {
    "gradinit_iters": 200,
    "gradinit_alg": "adam",
    "gradinit_lr": 1e-2,
    "gradinit_grad_clip": 1,
}

# Model options forwarded to get_model. The meaning of the "convnorm" and
# "approx_mult" sub-options (e.g. 'first_frac', 'last_num') is defined in the
# models module, not here.
model_params = {
    "gradinit": gradinit_params,
    "convnorm": {
        "mode_conv": [('first_frac', 0.25)],
        "mode_bn": [('first_frac', 0.25)],
    },
    "approx_mult": {
        'mult_val': 0.8,
        'mode_linear': [('last_num', 1)],
        'mode_conv': [('last_num', 1)],
    },
    # Optional: uncomment to enable importance sampling in train_model.
    # "importance_sampling": True,
}

# Training hyperparameters read by train_model (seed falls back to its default of 0).
hyperparams = {
    "lr": 3e-3,
    "num_epochs": 25,
    "weight_decay": 0,
    "train_ratio": 0.8,
    "batch_size": 256,
}

def test_setup():
    """Train ResNet-18 on CIFAR-100 with the module-level ``model_params``/``hyperparams``.

    Returns:
        The trained model, so later cells can evaluate or save it.
    """
    return train_model('Resnet18', 'CIFAR100', model_params, hyperparams)

# Bound to a name (not left as the cell's last expression) so Jupyter does not
# echo the full model repr after training.
trained_model = test_setup()