In [None]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=1
import os, sys, time
import warnings
sys.path.insert(0, '..')
import lib

import numpy as np
from copy import deepcopy
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline

# For reproducibility
# A fresh seed is drawn on every run and applied to Python's `random`, NumPy and PyTorch.
# The value is printed (and later embedded in the experiment name) so a run can be
# repeated by hard-coding it here.
import random
seed = random.randint(0, 2 ** 32 - 1)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Request deterministic cuDNN kernels and turn off the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(seed)

In [None]:
from matplotlib.backends.backend_pdf import PdfPages
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names, so try the new name first and fall back to the
# legacy one on older installations (an unknown style name raises OSError).
try:
    plt.style.use('seaborn-v0_8-darkgrid')
except OSError:
    plt.style.use('seaborn-darkgrid')
# Font type 42 embeds TrueType fonts in PDF/PS output instead of Type 3 fonts.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

# Setting

In [None]:
# Identifier forwarded to lib.MAML together with the model instance.
model_type = 'fixup_resnet'

# Dataset 
data_dir = './data'
train_batch_size = 128
valid_batch_size = 128
test_batch_size = 64
num_workers = 3
pin_memory = True

num_classes = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'

loss_function = F.cross_entropy

# MAML
# The meta-training loop runs while trainer.total_steps <= max_epochs.
max_epochs = 2000
# Number of training batches collected for the periodic baseline-vs-ours evaluation.
maml_n_var_batches = 190

# Number of training batches sampled for every meta-update.
maml_steps = 190
# Gradient-norm clip used when training a model in eval_model.
max_grad_norm = 4.
# Gradient-norm clip applied to the initializer (meta) parameters.
max_meta_parameters_grad_norm = 10.

loss_kwargs={'reduction':'mean'}

first_valid_step = 38
loss_interval = 38

# The validation schedule must end exactly on the last inner step.
assert (maml_steps - first_valid_step) % loss_interval == 0
# Integer division: this is a batch count, and the assertion above guarantees the
# division is exact, so the value is unchanged but is now an int rather than a float.
validation_steps = (maml_steps - first_valid_step) // loss_interval + 1

# Optimizer
optimizer_type='momentum'
nesterov = True
learning_rate = 0.1
momentum = 0.9
weight_decay = 0.0005

# Meta optimizer
meta_betas = (0.9, 0.997)
meta_learning_rate = 0.0001

checkpoint_steps = 3
# Step to resume from; None means a fresh run (the log directory must then not exist).
recovery_step = None


# Extra keyword arguments forwarded to maml.forward by compute_maml_loss.
kwargs = dict(
    first_valid_step=first_valid_step,
    valid_loss_interval=loss_interval, 
    loss_kwargs=loss_kwargs, 
)

In [None]:
# Build a run identifier that encodes the main hyper-parameters and the random seed.
exp_name = f"FixupResNet18_CIFAR100_{optimizer_type}_lr{learning_rate}"
exp_name += f"_meta_rl{meta_learning_rate}_steps{maml_steps}_interval{loss_interval}"
exp_name += f"_var_batches{maml_n_var_batches}_tr_bs{train_batch_size}_val_bs{valid_batch_size}_seed_{seed}"
print("Experiment name: ", exp_name)

logs_path = "./logs/{}".format(exp_name)
# Refuse to start a fresh run (recovery_step is None) on top of an existing log directory.
assert recovery_step is not None or not os.path.exists(logs_path)
# !rm -rf {logs_path}

In [None]:
# Load train and valid data
from torchvision import transforms, datasets
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import TensorDataset, DataLoader

# Training-time augmentation: padded random crop and horizontal flip, then normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
])

# Evaluation uses the same normalization without augmentation.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
])

# train_dataset and valid_dataset both wrap the CIFAR-100 train split; they differ only in
# their transform and are separated below by disjoint index samplers.
train_dataset = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
valid_dataset = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=eval_transform)
test_set = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=eval_transform)

num_train = len(train_dataset)
indices = list(range(num_train))
# The first `split` shuffled indices go to training, the remainder to validation.
split = 40000

# Warn when the batches drawn for one evaluation run cover the whole training subset.
if maml_n_var_batches * train_batch_size >= split:
    warnings.warn("Your training process involves one entire epoch")
    
# Uses the NumPy global RNG, so the split depends on the seed set in the first cell.
np.random.shuffle(indices)
train_idx, valid_idx = indices[:split], indices[split:]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
    num_workers=num_workers, pin_memory=pin_memory,
)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=valid_batch_size, sampler=valid_sampler,
    num_workers=num_workers, pin_memory=pin_memory,
)

test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=test_batch_size, shuffle=False, 
    num_workers=num_workers, pin_memory=pin_memory
)

In [None]:
# Pick the in-graph (differentiable) inner-loop optimizer matching `optimizer_type`.
# NOTE(review): `beta`, `beta1`, `beta2` and `epsilon` are not defined in the visible
# configuration cell, so the 'rmsprop' and 'adam' branches need them defined first.
if optimizer_type == 'sgd':
    optimizer = lib.IngraphGradientDescent(learning_rate=learning_rate)
elif optimizer_type == 'momentum':
    optimizer = lib.IngraphMomentum(learning_rate=learning_rate, momentum=momentum,
                                    weight_decay=weight_decay, nesterov=nesterov)
elif optimizer_type == 'rmsprop':
    optimizer = lib.IngraphRMSProp(learning_rate=learning_rate, beta=beta, epsilon=epsilon)
elif optimizer_type == 'adam':
    optimizer = lib.IngraphAdam(learning_rate=learning_rate, beta2=beta2, beta1=beta1, epsilon=epsilon)
else: 
    # The original spelled the exception `NotImplemetedError`, which raised NameError instead.
    raise NotImplementedError("This optimizer is not implemeted")

# Wrap the network in the MAML module and move it to the selected device.
model = lib.models.fixup_resnet.FixupResNet18(num_classes=num_classes)
maml = lib.MAML(model, model_type, optimizer=optimizer, 
    checkpoint_steps=checkpoint_steps,
    loss_function=loss_function
).to(device)

In [None]:
import torch.nn.parallel
from concurrent.futures import Future, ThreadPoolExecutor, as_completed, TimeoutError
# Thread pool shared by every background task submitted from this notebook.
GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=16)

def run_in_background(func: callable, *args, **kwargs) -> Future:
    """Submit ``func(*args, **kwargs)`` to GLOBAL_EXECUTOR and return its Future."""
    pending = GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
    return pending


def samples_batches(dataloader, num_batches):
    """Collect up to ``num_batches`` (x, y) batches from ``dataloader``.

    Args:
        dataloader: iterable yielding ``(x_batch, y_batch)`` pairs.
        num_batches: maximum number of batches to take. Integer-valued floats are
            accepted, because iteration stops once the collected count is >= this value.

    Returns:
        ``(x_batches, y_batches)``: two lists of equal length. They are shorter than
        ``num_batches`` if the loader runs out first.
    """
    x_batches, y_batches = [], []
    # Nothing requested: do not even start iterating the loader.
    if num_batches <= 0:
        return x_batches, y_batches
    for x_batch, y_batch in dataloader:
        x_batches.append(x_batch)
        y_batches.append(y_batch)
        # Stop right after the last wanted batch, so no extra batch is fetched and
        # thrown away (the previous version pulled one more batch before breaking).
        if len(x_batches) >= num_batches:
            break
    return x_batches, y_batches


def compute_maml_loss(maml, x_batches, y_batches, x_val_batches, y_val_batches, device):
    """Run one MAML forward pass and reduce its loss histories to two scalars.

    Parameters are resampled, ``maml.forward`` is called with the module-level
    ``kwargs`` dict, and the returned histories are averaged.

    Returns:
        ``(train_loss, valid_loss)``: the mean of the train-loss history from index
        ``first_valid_step`` onwards, and the mean of the validation-loss history
        (``torch.zeros(1)`` when that history is empty).
    """
    with lib.training_mode(maml, is_train=True):
        maml.resample_parameters()
        forward_outputs = maml.forward(x_batches, y_batches, x_val_batches, y_val_batches,
                                       device=device, **kwargs)
        _updated_model, train_losses, valid_losses, *_rest = forward_outputs
        train_loss = torch.cat(train_losses[first_valid_step:]).mean()
        if len(valid_losses) > 0:
            valid_loss = torch.cat(valid_losses).mean()
        else:
            valid_loss = torch.zeros(1)
    return train_loss, valid_loss


@torch.no_grad()
def compute_test_loss(model, loss_function, test_loader, device='cuda'):
    """Evaluate ``model`` on every batch of ``test_loader``.

    Args:
        model: module to evaluate; it is put in eval mode for the duration of the call
            and restored to the mode it had on entry afterwards.
        loss_function: callable ``(preds, targets) -> scalar``; the per-batch value is
            weighted by the batch size, i.e. it is treated as a per-sample mean.
        test_loader: iterable of ``(x, y)`` batches.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, classification_error)`` as Python floats, both averaged over the
        samples actually seen (so loaders with a sampler or ``drop_last`` are
        normalised correctly).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss, total_errors, num_samples = 0., 0., 0
    try:
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            preds = model(x_test)
            batch_size = x_test.shape[0]
            total_loss = total_loss + loss_function(preds, y_test) * batch_size
            total_errors = total_errors + 1. * (y_test != preds.argmax(dim=-1)).sum()
            num_samples += batch_size
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    if num_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return (total_loss / num_samples).item(), (total_errors / num_samples).item()

In [None]:
class TrainerResNet(lib.Trainer):
    """Meta-trainer for the Fixup ResNet experiment.

    Relies on attributes that are not defined in this class and are presumably set up
    by ``lib.Trainer``: ``maml``, ``meta_optimizer``, ``writer``, ``logs``,
    ``total_steps``, ``device`` and ``record`` (TODO confirm against lib.Trainer).
    """

    def _clip_sanitize_and_step(self, prefix):
        """Clip and log the meta-gradients, zero non-finite entries, then step.

        Shared tail of ``train_on_batch`` and ``parallel_train_on_batch``; expects the
        meta-gradients to be populated already.

        Returns:
            dict mapping initializer parameter name to the L2 norm of its gradient,
            measured after clipping and before non-finite entries are zeroed.
        """
        global_grad_norm = nn.utils.clip_grad_norm_(list(self.maml.initializers.parameters()), 
                                                    max_meta_parameters_grad_norm)
        self.writer.add_scalar(prefix + "global_grad_norm", global_grad_norm, self.total_steps)

        # The original code wrote into an undefined global `grad_norms` (NameError on
        # the first step); the norms are now collected locally and returned.
        grad_norms = {}
        for name, param in self.maml.initializers.named_parameters():
            if param.grad is None:
                # Parameter did not take part in the loss; nothing to sanitize.
                continue
            if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                print("Outer Nan or inf in grads")
            grad_norms[name] = param.grad.norm(2).item()
            param.grad = torch.where(torch.isnan(param.grad), torch.zeros_like(param.grad), param.grad)
            param.grad = torch.where(torch.isinf(param.grad), torch.zeros_like(param.grad), param.grad)

        self.meta_optimizer.step()
        self.logs.append((self.total_steps, global_grad_norm))
        return grad_norms

    def train_on_batch(self, train_loader, valid_loader, prefix='train/', **kwargs):
        """ Performs a single gradient update and reports metrics

        Samples ``maml_steps`` training batches and ``validation_steps`` validation
        batches, back-propagates the validation loss into the initializers and takes
        one meta-optimizer step. Extra keyword arguments are accepted but unused;
        ``compute_maml_loss`` reads the module-level ``kwargs`` dict instead.
        """    
        x_batches, y_batches = samples_batches(train_loader, maml_steps)
        x_val_batches, y_val_batches = samples_batches(valid_loader, validation_steps)

        self.meta_optimizer.zero_grad()
        
        train_loss, valid_loss = compute_maml_loss(self.maml, x_batches, y_batches, 
                                                   x_val_batches, y_val_batches, self.device)
        valid_loss.backward()
        
        self._clip_sanitize_and_step(prefix)
        
        return self.record(train_loss=train_loss.item(),
                           valid_loss=valid_loss.item(), prefix=prefix)
    
    def parallel_train_on_batch(self, train_loader, valid_loader, prefix='train/', **kwargs):
        """Multi-GPU variant of ``train_on_batch``.

        The MAML module is replicated onto every visible CUDA device, each replica
        gets its own freshly sampled batches, the per-replica losses are computed in
        background threads and averaged, and a single meta-update is applied.
        With no visible CUDA device this falls back to ``train_on_batch`` (the
        original divided by an empty replica list).
        """
        # generate training/validation batches for each device
        replica_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
        if not replica_devices:
            return self.train_on_batch(train_loader, valid_loader, prefix=prefix, **kwargs)
        replicas = torch.nn.parallel.replicate(self.maml, devices=replica_devices, detach=False)

        replica_inputs = []
        for i, (replica, replica_device) in enumerate(zip(replicas, replica_devices)):
            replica.untrained_initializers = lib.switch_initializers_device(replica.untrained_initializers, 
                                                                            replica_device)

            x_batches, y_batches = samples_batches(train_loader, maml_steps)
            x_val_batches, y_val_batches = samples_batches(valid_loader, validation_steps)

            replica_inputs.append((replica, x_batches, y_batches,
                                   x_val_batches, y_val_batches, replica_device))
        
        replica_losses_futures = [run_in_background(compute_maml_loss, *replica_input)
                                  for replica_input in replica_inputs]
        
        replica_losses = [future.result() for future in replica_losses_futures]
        # The train loss is only reported, so plain floats suffice; the validation loss
        # stays a tensor on self.device because it is back-propagated below.
        train_loss = sum(train_loss.item() for train_loss, _ in replica_losses) / len(replica_losses)
        valid_loss = sum(valid_loss.to(self.device) for _, valid_loss in replica_losses) / len(replica_losses)
        
        self.meta_optimizer.zero_grad()
        valid_loss.backward()
            
        self._clip_sanitize_and_step(prefix)
        
        return self.record(train_loss=train_loss,
                           valid_loss=valid_loss.item(), prefix=prefix)
        
    @torch.no_grad()
    def plot_pdf(self):
        """Plot, per initializer, the untrained and the learned weight densities.

        Each entry of ``self.maml.initializers`` is drawn as two Normal densities built
        from the ``mean``/``std`` of its weight initializer: the untrained one (dashed)
        and the meta-learned one (green). Subplots are laid out on a fixed 6x4 grid and
        the legend is attached to the 12th subplot only.
        """
        plt.figure(figsize=[22, 34])
        i = 0
        for name, (weight_maml_init, bias_maml_init) in self.maml.initializers.items():
            weight_base_init, _ = self.maml.untrained_initializers[name]
            base_mean = weight_base_init.mean.item()
            base_std = weight_base_init.std.item()
            maml_mean = weight_maml_init.mean.item()
            maml_std = weight_maml_init.std.item()
            
            base_init = torch.distributions.Normal(base_mean, base_std)
            maml_init = torch.distributions.Normal(maml_mean, maml_std)
            i += 1
            plt.subplot(6, 4, i)
            # Cover three standard deviations around both means.
            xx = np.linspace(min([base_mean - 3.*base_std, maml_mean - 3.*maml_std]), 
                             max([base_mean + 3.*base_std, maml_mean + 3.*maml_std]), 1000)
    
            if i == 12:
                yy = base_init.log_prob(torch.tensor(xx)).exp().numpy()
                plt.plot(xx, yy, '--', label='Fixup')
                yy = maml_init.log_prob(torch.tensor(xx)).exp().numpy()
                plt.plot(xx, yy, c='g', label='Fixup + DIMAML')
                leg = plt.legend(loc=4, fontsize=14.5, frameon=False)
                for line in leg.get_lines():
                    line.set_linewidth(1.6)
            else:
                yy = base_init.log_prob(torch.tensor(xx)).exp().numpy()
                plt.plot(xx, yy, '--')
                yy = maml_init.log_prob(torch.tensor(xx)).exp().numpy()
                plt.plot(xx, yy, c='g')
            
            plt.xticks(fontsize=12)
            plt.yticks(fontsize=12)
            plt.title(name + '_weight', fontsize=14)
        plt.show()
        
    def evaluate_metrics(self, train_loader, test_loader, prefix='val/', **kwargs):
        """ Predicts and evaluates metrics over the entire dataset

        Trains two copies of the model for one epoch with the module-level
        ``eval_model``: one resampled from the untrained initializers ("Baseline") and
        one from the learned initializers ("Ours"). Both runs are plotted with
        ``draw_plots`` and the learned-initializer curves are logged to ``self.writer``.
        Extra keyword arguments are accepted but unused.
        """
        torch.cuda.empty_cache()
        
        print('Baseline')
        self.maml.resample_parameters(initializers=self.maml.untrained_initializers, is_final=True)
        base_model = deepcopy(self.maml.model)    
        base_train_loss_history, base_test_loss_history, base_test_error_history = \
            eval_model(base_model, train_loader, test_loader, epochs=1, device=self.device)
            
        print('Ours')
        self.maml.resample_parameters(is_final=True)
        maml_model = deepcopy(self.maml.model)
        maml_train_loss_history, maml_test_loss_history, maml_test_error_history = \
            eval_model(maml_model, train_loader, test_loader, epochs=1, device=self.device)
        
        draw_plots(base_train_loss_history, base_test_loss_history, base_test_error_history, 
                   maml_train_loss_history, maml_test_loss_history, maml_test_error_history)
        
        # The "AUC" values are plain sums of the recorded loss curves.
        self.writer.add_scalar(prefix + "train_AUC", sum(maml_train_loss_history), self.total_steps)
        self.writer.add_scalar(prefix + "test_AUC", sum(maml_test_loss_history), self.total_steps)
        self.writer.add_scalar(prefix + "test_loss", maml_test_loss_history[-1], self.total_steps)
        self.writer.add_scalar(prefix + "test_cls_error", maml_test_error_history[-1], self.total_steps)    

In [None]:
########################
# Generate Train Batch #
########################
            
def generate_train_batches(train_loader, batches_in_epoch=150):
    """Freeze the first ``batches_in_epoch`` batches of ``train_loader`` into a new loader.

    The collected batches are concatenated along dim 0 into an in-memory
    ``TensorDataset``, which is wrapped in a shuffling ``DataLoader`` that uses the
    module-level ``train_batch_size`` and ``num_workers``. An assertion fails if the
    source loader yields fewer than ``batches_in_epoch`` batches.
    """
    inputs, targets = [], []
    for index, (x_batch, y_batch) in enumerate(train_loader):
        if index >= batches_in_epoch:
            break
        inputs.append(x_batch)
        targets.append(y_batch)

    assert len(inputs) == len(targets) == batches_in_epoch

    frozen_dataset = TensorDataset(torch.cat(inputs, dim=0), torch.cat(targets, dim=0))
    return DataLoader(frozen_dataset, batch_size=train_batch_size,
                      shuffle=True, num_workers=num_workers)
        

##################
# Eval functions #
##################

def adjust_learning_rate(optimizer, epoch, milestones=[30, 50]):
    """Step-decay schedule: divide the learning rate by 10 at each of two milestones.

    Groups whose ``initial_lr`` equals the module-level ``learning_rate`` receive the
    decayed base rate; every other group has its own ``initial_lr`` scaled by the same
    factor (1, 1/10 or 1/100).

    Returns:
        The decayed base learning rate for ``epoch``.
    """
    first_milestone, second_milestone = milestones[0], milestones[1]

    lr = learning_rate
    for boundary in (first_milestone, second_milestone):
        if epoch >= boundary:
            lr /= 10

    for group in optimizer.param_groups:
        group_base = group['initial_lr']
        if group_base == learning_rate:
            group['lr'] = lr
        elif epoch < first_milestone:
            group['lr'] = group_base
        elif epoch < second_milestone:
            group['lr'] = group_base / 10.
        else:
            group['lr'] = group_base / 100.
    return lr


def eval_model(model, train_loader, test_loader, epochs=3, test_loss_interval=40, device='cuda'):
    """Train ``model`` with a regular PyTorch optimizer and record its learning curves.

    The optimizer is chosen from the module-level ``optimizer_type`` and its
    hyper-parameters. The learning rate follows ``adjust_learning_rate`` and gradients
    are clipped to ``max_grad_norm``. At the very first iteration and then every
    ``test_loss_interval`` iterations, the current train loss is stored and the model
    is evaluated on ``test_loader`` with ``compute_test_loss``.

    NOTE(review): a later cell redefines ``adjust_learning_rate``; running that cell
    first changes the schedule used here.

    Returns:
        ``(train_loss_history, test_loss_history, test_error_history)``: three lists
        of floats with one entry per evaluation point.

    Raises:
        NotImplementedError: if ``optimizer_type`` is not one of
            'sgd', 'momentum', 'rmsprop', 'adam'.
    """
    if optimizer_type == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    elif optimizer_type == 'momentum':
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, 
                                    momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
    elif optimizer_type == 'rmsprop':
        # torch's class is `RMSprop`; its smoothing constant is `alpha` and its epsilon
        # is `eps`. The original also passed an undefined name `parameters`.
        optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, alpha=beta, eps=epsilon)
    elif optimizer_type == 'adam':
        # Adam's epsilon keyword is `eps`, not `epsilon`.
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, 
                                     betas=(beta1, beta2), eps=epsilon)
    else: 
        raise NotImplementedError("{} optimizer is not implemeted".format(optimizer_type))
        
    # adjust_learning_rate reads 'initial_lr' from every parameter group.
    for param_group in optimizer.param_groups:
        param_group['initial_lr'] = learning_rate
        
    # Train loop
    train_loss_history = []
    test_loss_history = []
    test_error_history = []

    # Remember the caller's mode so it can be restored on exit.
    training_mode = model.training
    
    total_iters = 0
    for epoch in range(epochs):
        model.train()
        adjust_learning_rate(optimizer, epoch)
        for i, (x_batch, y_batch) in enumerate(train_loader):
            optimizer.zero_grad()
            preds = model(x_batch.to(device))
            loss = loss_function(preds, y_batch.to(device))
            loss.backward()
            
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            
            if (total_iters == 0) or (total_iters + 1) % test_loss_interval == 0:
                train_loss_history.append(loss.item())
                model.eval()
                test_loss, test_error = compute_test_loss(model, loss_function, test_loader, device=device)
                print("Epoch {} | Train Loss {:.4f} | Test Loss {:.4f} | Classification Error {:.4f}"\
                      .format(epoch, loss.item(), test_loss, test_error))
                test_loss_history.append(test_loss)
                test_error_history.append(test_error)
                model.train()
            
            total_iters += 1
    
    model.train(training_mode)
    return train_loss_history, test_loss_history, test_error_history

    
def draw_plots(base_train_loss, base_test_loss, base_test_error,
               maml_train_loss, maml_test_loss, maml_test_error):
    """Draw three side-by-side panels comparing the baseline run with ours.

    Panels are train loss (smoothed with ``moving_average(..., span=10)``), test loss
    and test classification error. The baseline uses the default colour and our run
    is drawn in green.
    """
    def smoothed(curve):
        return moving_average(curve, span=10)

    def unchanged(curve):
        return curve

    panels = (
        ("Train loss", smoothed, base_train_loss, maml_train_loss),
        ("Test loss", unchanged, base_test_loss, maml_test_loss),
        ("Test classification error", unchanged, base_test_error, maml_test_error),
    )

    plt.figure(figsize=(20, 6))
    for position, (panel_title, transform, base_curve, maml_curve) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(transform(base_curve), label='Baseline')
        plt.plot(transform(maml_curve), c='g', label='Ours')
        plt.legend(fontsize=14)
        plt.title(panel_title, fontsize=14)

In [None]:
from IPython.display import clear_output
from pandas import DataFrame

def moving_average(x, **kw):
    """Exponentially weighted moving average of ``x``; ``kw`` is passed to pandas ``ewm``."""
    series = DataFrame({'x': np.asarray(x)}).x
    return series.ewm(**kw).mean().values
# Per-step losses reported by the trainer; the training loop appends to and plots these lists.
train_loss_history = []
valid_loss_history = []

# NOTE(review): `meta_decay_interval` is not defined in any of the visible cells, so this
# line raises NameError on a fresh kernel unless it is defined elsewhere - confirm.
trainer = TrainerResNet(maml, decay_interval=meta_decay_interval, meta_lr=meta_learning_rate, 
                        meta_betas=meta_betas, exp_name=exp_name, recovery_step=recovery_step)

In [None]:
# Meta-training loop: one meta-update per iteration, with periodic plots and checkpoints.
t0 = time.time()

while trainer.total_steps <= max_epochs:
    lib.free_memory()
    # The keyword arguments below are absorbed by parallel_train_on_batch's **kwargs;
    # compute_maml_loss takes the same values from the module-level `kwargs` dict.
    metrics = trainer.parallel_train_on_batch(
        train_loader, valid_loader,
        first_valid_step=first_valid_step,
        valid_loss_interval=loss_interval, 
        loss_kwargs=loss_kwargs, 
    )
    train_loss = metrics['train_loss']
    train_loss_history.append(train_loss)
    
    valid_loss = metrics['valid_loss']
    valid_loss_history.append(valid_loss)
    
    # Every 10 steps (including step 0): refresh the loss plots, train baseline and
    # learned-init models for one epoch each, and redraw the initializer densities.
    if trainer.total_steps % 10 == 0:
        clear_output(True)
        print("Step: %d | Time: %f | Train Loss %.5f | Valid loss %.5f" 
              % (trainer.total_steps, time.time()-t0, train_loss, valid_loss))
        plt.figure(figsize=[16, 5])
        plt.subplot(1,2,1)
        plt.title('Train Loss over time')
        plt.plot(moving_average(train_loss_history, span=50))
        plt.scatter(range(len(train_loss_history)), train_loss_history, alpha=0.1)
        plt.subplot(1,2,2)
        plt.title('Valid Loss over time')
        plt.plot(moving_average(valid_loss_history, span=50))
        plt.scatter(range(len(valid_loss_history)), valid_loss_history, alpha=0.1)
        plt.show()
        local_train_loader = generate_train_batches(train_loader, maml_n_var_batches)
        # `test_interval` is swallowed by evaluate_metrics' **kwargs and has no effect there.
        trainer.evaluate_metrics(local_train_loader, test_loader, test_interval=20)
        trainer.plot_pdf()
        # Restart the timer so the printed time covers only the steps since the last report.
        t0 = time.time()
        
    if trainer.total_steps % 100 == 0:
        trainer.save_model()
        
    # The step counter is advanced here, by the loop, after each iteration.
    trainer.total_steps += 1

## Plot probability distributions

In [None]:
# Draw the untrained vs. learned initializer densities for the current trainer state.
trainer.plot_pdf()

# Evaluation

In [None]:
# Draw and apply a new seed for the evaluation runs. This rebinds the global `seed`
# that was set in the first cell; the new value is printed for the record.
seed = random.randint(0, 2 ** 32 - 1)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
print(seed)

In [None]:
def gradient_quotient(loss, params, eps=1e-5):
    """Average elementwise gradient-quotient of ``loss`` with respect to ``params``.

    With g the gradient of the loss and Hg the gradient of 0.5 * ||g||^2 (both taken
    with ``create_graph=True`` so the result is differentiable), this returns the mean
    over all parameter elements of ``|(g - Hg) / (g + eps * sign(g)) - 1|``, where
    sign(0) is taken as +1.
    """
    grads = torch.autograd.grad(loss, params, retain_graph=True, create_graph=True)

    half_squared_norm = 0
    for g in grads:
        half_squared_norm = half_squared_norm + (g ** 2).sum() / 2
    hessian_grads = torch.autograd.grad(half_squared_norm, params,
                                        retain_graph=True, create_graph=True)

    total = 0
    for g, hg in zip(grads, hessian_grads):
        # +eps where g >= 0 and -eps elsewhere keeps the denominator away from zero.
        signed_eps = eps * (2 * (g >= 0).float() - 1).detach()
        total = total + ((g - hg) / (g + signed_eps) - 1).abs().sum()

    num_elements = sum(p.data.nelement() for p in params)
    return total / num_elements


def metainit(model, criterion, x_size, y_size, lr=0.1, momentum=0.9, steps=200, eps=1e-5):
    """Rescale the norms of ``model``'s weight tensors by minimising the gradient quotient.

    For ``steps`` iterations a batch of Gaussian noise inputs and uniformly random
    integer targets is drawn, ``gradient_quotient`` is evaluated, and every selected
    parameter has its norm moved by a signed step of size ``lr`` with momentum. Only
    parameters that require grad, have at least two dimensions and a non-zero standard
    deviation are rescaled. The model is left in eval mode.

    Args:
        model: module to initialise, modified in place.
        criterion: loss callable ``(output, target) -> scalar``.
        x_size: shape of the random input batch; ``x_size[0]`` is the batch size.
        y_size: exclusive upper bound for the random integer targets.
        lr, momentum: step size and momentum of the sign update on the norms.
        steps: number of iterations.
        eps: stabiliser passed to ``gradient_quotient`` and used in the rescale ratio.
    """
    model.eval()
    params = [p for p in model.parameters() 
              if p.requires_grad and len(p.size()) >= 2 and p.std().item() != 0]
    # Put the random batch on the model's own device instead of calling .cuda(), so this
    # also works on CPU and on non-default GPUs. Data is still drawn from the CPU
    # generator, as before, so the random stream is unchanged.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    memory = [0] * len(params)
    for i in range(steps):
        inputs = torch.Tensor(*x_size).normal_(0, 1).to(device)
        target = torch.randint(0, y_size, (x_size[0],)).to(device)
        loss = criterion(model(inputs), target)
        gq = gradient_quotient(loss, list(model.parameters()), eps)
        
        grad = torch.autograd.grad(gq, params)
        for j, (p, g_all) in enumerate(zip(params, grad)):
            norm = p.data.norm().item()
            # Sign of the derivative of gq with respect to this parameter's norm.
            g = torch.sign((p.data * g_all).sum() / norm) 
            memory[j] = momentum * memory[j] - lr * g.item() 
            new_norm = norm + memory[j]
            p.data.mul_(new_norm / (norm + eps))
        print("%d/GQ = %.2f" % (i, gq.item()))

In [None]:
def genOrthgonal(dim):
    """Return a random ``dim x dim`` orthogonal matrix.

    A standard-normal matrix is QR-decomposed and each column of Q is multiplied by
    the sign of the matching diagonal entry of R.
    """
    gaussian = torch.zeros((dim, dim)).normal_(0, 1)
    q, r = torch.qr(gaussian)
    column_signs = torch.diag(r, 0).sign()
    # Broadcasting a (1, dim) row over Q scales column j by column_signs[j].
    q.mul_(column_signs.reshape(1, -1))
    return q

def makeDeltaOrthogonal(weights, gain):
    """Fill a 4-D weight tensor in place with a delta-orthogonal pattern.

    All entries are zeroed, then the centre spatial position ``[:, :, h // 2, w // 2]``
    receives the top-left ``size(0) x size(1)`` block of a random orthogonal matrix,
    and the whole tensor is multiplied by ``gain``. A message is printed (and execution
    continues) when ``size(0) < size(1)``.
    """
    dim0, dim1 = weights.size(0), weights.size(1)
    if dim0 < dim1:
        print("In_filters should not be greater than out_filters.")
    weights.data.fill_(0)
    orthogonal = genOrthgonal(max(dim0, dim1))
    centre_h, centre_w = weights.size(2) // 2, weights.size(3) // 2
    with torch.no_grad():
        weights[:, :, centre_h, centre_w] = orthogonal[:dim0, :dim1]
        weights.mul_(gain)

## Eval TinyImageNet

In [None]:
# NOTE(review): this rebinds the globals `data_dir` and `num_workers` from the CIFAR
# configuration cell. `num_workers` becomes a dict here, while generate_train_batches
# passes the global `num_workers` straight to DataLoader - re-running earlier cells
# after this one will therefore misbehave.
data_dir = 'data/tiny-imagenet-200/'
num_workers = {'train': 0, 'val': 0,'test': 0}
# Augmentation is applied to the training split only; val/test are just normalized.
data_transforms = {
    'train': transforms.Compose([
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    ])
}
# One ImageFolder per split, read from <data_dir>/<split>.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) 
                  for x in ['train', 'val','test']}
# All three loaders shuffle, including 'val' and 'test'.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=128, shuffle=True, num_workers=num_workers[x])
                  for x in ['train', 'val', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}

In [None]:
# Compare four initialization schemes on TinyImageNet over several independent reruns.
ti_batches_in_epoch = len(dataloaders['train'])
# Sanity check on the dataset size; 782 batches of 128 presumably corresponds to the
# 100,000-image training split - TODO confirm.
assert ti_batches_in_epoch == 782
num_reruns = 10

# One list entry per rerun; each entry is the per-evaluation history returned by eval_model.
reruns_base_test_loss_history = []
reruns_base_test_error_history = []

reruns_metainit_test_loss_history = []
reruns_metainit_test_error_history = []
    
reruns_maml_test_loss_history = []
reruns_maml_test_error_history = []

reruns_deltaorthogonal_test_loss_history = []
reruns_deltaorthogonal_test_error_history = []

for i in range(num_reruns):
    print(f"Rerun {i}")
    print("Ours")
    # Sample weights from the learned initializers and copy the resulting model.
    maml.resample_parameters(is_final=True)
    maml_model = deepcopy(maml.model)

    # Replace the classifier head with a zero-initialized 200-way layer.
    maml_model.fc = nn.Linear(in_features=512, out_features=200, bias=True).to(device)
    nn.init.constant_(maml_model.fc.weight, 0)
    nn.init.constant_(maml_model.fc.bias, 0)

    # test_loss_interval equals one epoch, so eval_model records one point at the first
    # iteration and one at the end of every epoch.
    maml_train_loss_history, maml_test_loss_history, maml_test_error_history = \
         eval_model(maml_model, dataloaders['train'], dataloaders['test'],  
                    epochs=70, test_loss_interval=ti_batches_in_epoch, device=device)
    
    reruns_maml_test_loss_history.append(maml_test_loss_history)
    reruns_maml_test_error_history.append(maml_test_error_history)
    
    print("Baseline")
    # Same procedure, but sampling from the untrained initializers.
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    base_model = deepcopy(maml.model)

    base_model.fc = nn.Linear(in_features=512, out_features=200, bias=True).to(device)
    nn.init.constant_(base_model.fc.weight, 0)
    nn.init.constant_(base_model.fc.bias, 0)

    base_train_loss_history, base_test_loss_history, base_test_error_history = \
        eval_model(base_model, dataloaders['train'], dataloaders['test'], 
                   epochs=70, test_loss_interval=ti_batches_in_epoch, device=device)

    reruns_base_test_loss_history.append(base_test_loss_history)
    reruns_base_test_error_history.append(base_test_error_history)
    
    print("MetaInit")
    # Only the shape of a 64-image batch is used; metainit draws its own random inputs.
    batch_x, _ = next(iter(dataloaders['train']))
    batch_x = batch_x[:64]
    metainit_model = lib.models.metainit_resnet.MetaInitFixupResNet18(num_classes=200).to(device)
    metainit(metainit_model, loss_function, batch_x.shape, 200)
    
    metainit_train_loss_history, metainit_test_loss_history, metainit_test_error_history = \
        eval_model(metainit_model, dataloaders['train'], dataloaders['test'], 
                   epochs=70, test_loss_interval=ti_batches_in_epoch, device=device)
    
    reruns_metainit_test_loss_history.append(metainit_test_loss_history)
    reruns_metainit_test_error_history.append(metainit_test_error_history)
    
    print("DeltaOrthogonal")
    deltaorthogonal_model = lib.models.FixupResNet18(num_classes=200).to(device)
    # Re-initialize every parameter with at least four dimensions.
    for param in deltaorthogonal_model.parameters():
        if len(param.size()) >= 4:
            makeDeltaOrthogonal(param, nn.init.calculate_gain('leaky_relu'))
    
    deltaorthogonal_train_loss_history, deltaorthogonal_test_loss_history, deltaorthogonal_test_error_history = \
        eval_model(deltaorthogonal_model, dataloaders['train'], dataloaders['test'], 
                   epochs=70, test_loss_interval=ti_batches_in_epoch, device=device)
    
    reruns_deltaorthogonal_test_loss_history.append(deltaorthogonal_test_loss_history)
    reruns_deltaorthogonal_test_error_history.append(deltaorthogonal_test_error_history)

In [None]:
def _mean_and_std(error_histories):
    """Per-evaluation-point mean and sample std (ddof=1) across reruns (axis 0)."""
    stacked = np.array(error_histories)
    return stacked.mean(0), stacked.std(0, ddof=1)


base_mean, base_std = _mean_and_std(reruns_base_test_error_history)

maml_mean, maml_std = _mean_and_std(reruns_maml_test_error_history)

metainit_mean, metainit_std = _mean_and_std(reruns_metainit_test_error_history)

deltaorthogonal_mean, deltaorthogonal_std = _mean_and_std(reruns_deltaorthogonal_test_error_history)

In [None]:
# Persist the (mean, std) classification-error curves of all four methods in one file.
torch.save({'base_mean_std': (base_mean, base_std),
            'maml_mean_std': (maml_mean, maml_std),
            'metainit_mean_std': (metainit_mean, metainit_std),
            'deltaorthogonal_mean_std': (deltaorthogonal_mean, deltaorthogonal_std),
           }, "nips_cls_errors_tinyimagenet.pt")

In [None]:
# Final comparison figure: smoothed classification error per epoch with +/- one std bands.
# NOTE(review): `span`, the `*_relu` arrays, `delthaorthogonal_mean` and
# `fixup_metainit_mean` are not defined in any visible cell, so this cell fails on a
# fresh kernel. `delthaorthogonal_mean` may be a misspelling of `deltaorthogonal_mean` -
# confirm where these values come from before relying on the figure.
plt.style.use('seaborn-darkgrid')

plt.figure(figsize=(8, 6))
# 71 x-positions: the initial evaluation plus one per epoch for 70 epochs.
x = np.arange(0, 71, 1)
plt.plot(x, moving_average(base_mean_relu, span=span), linewidth=1.1, label='Fixup')
plt.fill_between(x, moving_average(base_mean_relu, span=span) - moving_average(base_std_relu, span=span), 
                 moving_average(base_mean_relu, span=span) + moving_average(base_std_relu, span=span), alpha=0.12)

# NOTE(review): a constant 0.001 is added to this curve before smoothing; reason not visible here.
plt.plot(x, moving_average(delthaorthogonal_mean+0.001, span=span), linewidth=1.1, c='#7722dd', label='DeltaOrthogonal')

plt.plot(x, moving_average(metainit_mean_relu, span=span), linewidth=1.1, label='MetaInit')
plt.fill_between(x, moving_average(metainit_mean_relu, span=span) - moving_average(metainit_std_relu, span=span), 
                 moving_average(metainit_mean_relu, span=span) + moving_average(metainit_std_relu, span=span), alpha=0.12)

plt.plot(x, moving_average(fixup_metainit_mean, span=span), linewidth=1.1, c='r', label='Fixup $\\rightarrow$ MetaInit')

plt.plot(x, moving_average(maml_mean_relu, span=span), linewidth=1.1, c='g', label='Fixup $\\rightarrow$ DIMAML')
plt.fill_between(x, moving_average(maml_mean_relu, span=span) - moving_average(maml_std_relu, span=span), 
                 moving_average(maml_mean_relu, span=span) + moving_average(maml_std_relu, span=span), alpha=0.12, color='g')
# Only the first 20 epochs are shown.
plt.xlim([0, 20])
plt.ylim([0.4, 1.])
plt.yticks(np.arange(0.45, 1.01, 0.05), fontsize=11)
plt.xticks(x[:71][::5], fontsize=11)
plt.legend(fontsize=14)
plt.xlabel("Epoch", fontsize=14)
plt.ylabel("Classification Error", fontsize=14)

# Save the current figure to PDF before showing it.
pp = PdfPages("nips_fixup_resnet_tiny_imagenet_darkgrid.pdf")
pp.savefig(bbox_inches='tight')
pp.close()
plt.show()

## ImageNet

In [None]:
import random
import shutil
import time

import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
from torch.optim.lr_scheduler import CosineAnnealingLR

# ImageNet hyper-parameters.
# NOTE(review): `weight_decay` rebinds the global of the same name from the CIFAR
# configuration cell, which eval_model reads - re-running earlier sections after this
# cell will use 1e-4 instead of 0.0005.
base_lr = 0.1
batch_size = 256
weight_decay = 1e-4
# Learning rate scaled linearly with the batch size relative to a reference of 256.
base_learning_rate = base_lr * batch_size / 256.

In [None]:
def train(train_loader, model, criterion, optimizer, epoch, val_loader=None,accuracies=None):
    """Train ``model`` for one pass over ``train_loader`` on the current CUDA device.

    Args:
        train_loader: iterable of ``(inputs, targets)`` batches.
        model: module to train.
        criterion: loss callable ``(output, targets) -> scalar``.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: current epoch index (not used inside this function).
        val_loader: optional validation loader; when given, ``validate`` is run every
            500 iterations, starting with iteration 0.
        accuracies: optional list that receives the top-1 value of each validation run.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for i, (inputs, targets) in enumerate(train_loader):
        # switch to train mode (validate() below leaves the model in eval mode)
        model.train()
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        # compute output
        output = model(inputs)
        loss = criterion(output, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1[0], inputs.size(0))
        top5.update(acc5[0], inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        # Clip AFTER backward: the original clipped between zero_grad() and backward(),
        # i.e. on empty gradients, so the clipping never had any effect.
        global_grad_norm = nn.utils.clip_grad_norm_(list(model.parameters()), 5)
        print(global_grad_norm)
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Validation is optional: skip it without a loader, and only record the result
        # when a list was supplied (the defaults used to crash at iteration 0).
        if i % 500 == 0 and val_loader is not None:
            top1_avg = validate(val_loader, model, criterion)
            if accuracies is not None:
                accuracies.append(top1_avg.item())

def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return the averaged top-1 accuracy.

    The model is switched to eval mode (and left there), the whole pass runs
    under ``torch.no_grad()``, and a one-line Acc@1 / Acc@5 summary is printed.

    Returns:
        ``top1.avg`` of the top-1 meter after the last batch.
    """
    timer = AverageMeter()
    loss_meter = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    summary_template = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'

    model.eval()

    with torch.no_grad():
        tick = time.time()
        for images, labels in val_loader:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            n_samples = images.size(0)

            # forward pass and loss
            logits = model(images)
            batch_loss = criterion(logits, labels)

            # update the running loss / top-1 / top-5 statistics
            batch_top1, batch_top5 = accuracy(logits, labels, topk=(1, 5))
            loss_meter.update(batch_loss.item(), n_samples)
            top1.update(batch_top1[0], n_samples)
            top5.update(batch_top5[0], n_samples)

            # per-batch wall-clock time
            timer.update(time.time() - tick)
            tick = time.time()

        print(summary_template.format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename``; when ``is_best`` is truthy, also copy
    that file to ``model_best.pth.tar`` in the current working directory."""
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter:
    """Tracks the most recent value together with a weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, mean, weighted sum and weight count back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Decay every param group's learning rate by a factor of 10 every 30 epochs.

    Each group's ``lr`` becomes its ``initial_lr`` times ``0.1 ** (epoch // 30)``.
    Groups whose ``initial_lr`` equals the notebook-level ``base_learning_rate``
    are reported as "non-scalar", every other group as "scalar".

    ``initial_lr`` must already exist on every param group, otherwise a
    ``KeyError`` is raised (PyTorch LR schedulers record it on construction).
    """
    decay = 0.1 ** (epoch // 30)
    for group in optimizer.param_groups:
        starting_lr = group['initial_lr']
        if starting_lr == base_learning_rate:
            print("adjust non-scalar lr.")
            group['lr'] = base_learning_rate * decay
        else:
            print("adjust scalar lr.")
            group['lr'] = starting_lr * decay


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: score tensor indexed as (sample, class); the top-k is taken along dim 1.
        target: tensor of class indices, one per sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k``: the percentage (0-100)
        of samples whose target is among the ``k`` highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # indices of the maxk best classes per sample, transposed to (maxk, batch)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape() rather than view(): the comparison result inherits the
            # transposed layout, and view(-1) raises on non-contiguous data for k > 1
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [None]:
def train_model(model, epochs=100, data_root='YOUR_IMAGENET_PATH/imagenet'):
    """Train ``model`` on ImageNet with SGD and collect periodic validation accuracies.

    Args:
        model: network to train; expected to already live on ``device``.
        epochs: number of passes over the training set.
        data_root: directory containing the ``train`` and ``val`` image folders.

    Returns:
        ``np.array`` of the top-1 accuracies that ``train`` records every 500 batches.

    Relies on notebook-level names: ``device``, ``momentum``, ``weight_decay``,
    ``batch_size``, ``base_learning_rate``, ``transforms`` and ``datasets``.
    After every epoch a checkpoint is written through ``save_checkpoint`` with
    its default filename, so successive calls overwrite each other's files.
    """
    accuracies = []

    best_acc1 = 0
    criterion = nn.CrossEntropyLoss().to(device)

    # Parameters whose name contains 'bias' or 'scale' train with a 10x smaller
    # learning rate. That rate is derived from the batch-size-scaled
    # base_learning_rate so all three groups follow the same scaling.
    named_params = list(model.named_parameters())
    parameters_bias = [p for name, p in named_params if 'bias' in name]
    parameters_scale = [p for name, p in named_params if 'scale' in name]
    parameters_others = [p for name, p in named_params if not ('bias' in name or 'scale' in name)]
    reduced_lr = base_learning_rate / 10.
    optimizer = torch.optim.SGD(
        [{'params': parameters_bias, 'lr': reduced_lr},
         {'params': parameters_scale, 'lr': reduced_lr},
         {'params': parameters_others}],
        lr=base_learning_rate, momentum=momentum, weight_decay=weight_decay)

    # NOTE: this flips the global cuDNN flag that the notebook's first cell set
    # to False for reproducibility.
    torch.backends.cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(data_root, 'train')
    valdir = os.path.join(data_root, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=8, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    # The scheduler is never stepped. It is still constructed because doing so
    # records `initial_lr` on every param group, which adjust_learning_rate reads.
    sgdr = CosineAnnealingLR(optimizer, 100, eta_min=0, last_epoch=-1)

    for epoch in range(epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch (also appends periodic validation accuracies)
        train(train_loader, model, criterion, optimizer, epoch, val_loader=val_loader, accuracies=accuracies)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer' : optimizer.state_dict(),
        }, is_best)
    return np.array(accuracies)

In [None]:
# Number of independent repetitions of the full comparison.
reruns = 10

# One accuracy curve per rerun for each initialization scheme.
(reruns_base_accuracies,
 reruns_maml_accuracies,
 reruns_metainit_accuracies,
 reruns_metainit_fixup_accuracies,
 reruns_deltaorthogonal_accuracies) = ([] for _ in range(5))

# Input shape handed to metainit() in both MetaInit variants.
metainit_input_shape = torch.Size([64, 3, 224, 224])

for rerun_id in range(reruns):
    print(f"Rerun #{rerun_id}")

    # Fixup: the model exactly as constructed.
    print("Fixup")
    imagenet_base_model = lib.models.ImageNetFixupResNet18(num_classes=1000).to(device)
    base_accuracies = train_model(imagenet_base_model, epochs=3)
    reruns_base_accuracies.append(base_accuracies)
    lib.free_memory()

    # MetaInit: dedicated MetaInit model class, then 100 metainit() steps.
    print("MetaInit")
    imagenet_metainit_model = lib.models.MetaInitImageNetFixupResNet18(num_classes=1000).to(device)
    metainit(imagenet_metainit_model, loss_function, metainit_input_shape, 1000, steps=100)
    metainit_accuracies = train_model(imagenet_metainit_model, epochs=3)
    reruns_metainit_accuracies.append(metainit_accuracies)
    lib.free_memory()

    # Fixup --> MetaInit: plain Fixup model, then the same metainit() call.
    print("Fixup --> MetaInit")
    imagenet_metainit_model = lib.models.ImageNetFixupResNet18(num_classes=1000).to(device)
    metainit(imagenet_metainit_model, loss_function, metainit_input_shape, 1000, steps=100)
    metainit_fixup_accuracies = train_model(imagenet_metainit_model, epochs=3)
    reruns_metainit_fixup_accuracies.append(metainit_fixup_accuracies)
    lib.free_memory()

    # Fixup --> DIMAML: every Conv2d whose name lacks 'conv2' gets its weight
    # replaced by the matching initializer in `maml.initializers`, applied to
    # uniform noise of the same shape.
    print("Fixup --> DIMAML")
    imagenet_maml_model = lib.models.ImageNetFixupResNet18(num_classes=1000).to(device)
    for name, module in imagenet_maml_model.named_modules():
        if not isinstance(module, nn.Conv2d) or 'conv2' in name:
            continue
        key_name = name.replace('.', '_')
        weight_initializer = maml.initializers[key_name][0]
        weights = weight_initializer(torch.rand_like(module.weight))
        module.weight = nn.Parameter(weights, requires_grad=True)

    maml_accuracies = train_model(imagenet_maml_model, epochs=3)
    reruns_maml_accuracies.append(maml_accuracies)
    lib.free_memory()

    # DeltaOrthogonal: applied to every parameter with at least 4 dimensions.
    print("DeltaOrthogonal")
    deltaorthogonal_model = lib.models.ImageNetFixupResNet18(num_classes=1000).to(device)
    leaky_relu_gain = nn.init.calculate_gain('leaky_relu')
    for param in deltaorthogonal_model.parameters():
        if param.dim() >= 4:
            makeDeltaOrthogonal(param, leaky_relu_gain)

    deltaorthogonal_accuracies = train_model(deltaorthogonal_model, epochs=3)
    reruns_deltaorthogonal_accuracies.append(deltaorthogonal_accuracies)
    lib.free_memory()

In [None]:
def _imagenet_error_stats(accuracy_runs):
    """Convert per-rerun accuracy curves (in percent) to error rates and return
    their mean and sample standard deviation (ddof=1) across reruns."""
    error_rates = (100. - np.array(accuracy_runs)) / 100
    return error_rates.mean(0), error_rates.std(0, ddof=1)


base_mean_imagenet, base_std_imagenet = _imagenet_error_stats(reruns_base_accuracies)

metainit_mean_imagenet, metainit_std_imagenet = _imagenet_error_stats(reruns_metainit_accuracies)

maml_mean_imagenet, maml_std_imagenet = _imagenet_error_stats(reruns_maml_accuracies)

metainit_fixup_mean_imagenet, metainit_fixup_std_imagenet = _imagenet_error_stats(
    reruns_metainit_fixup_accuracies)

deltaorthogonal_mean_imagenet, deltaorthogonal_std_imagenet = _imagenet_error_stats(
    reruns_deltaorthogonal_accuracies)

In [None]:
plt.style.use('seaborn-darkgrid')

plt.figure(figsize=(8, 6))
span = 5
x = np.arange(0, 33, 1)

# (mean curve, std curve, legend label, explicit colour or None for the default cycle),
# listed in drawing order.
imagenet_curves = [
    (base_mean_imagenet, base_std_imagenet, 'Fixup', None),
    (deltaorthogonal_mean_imagenet, deltaorthogonal_std_imagenet, 'DeltaOrthogonal', '#7722dd'),
    (metainit_mean_imagenet, metainit_std_imagenet, 'MetaInit', None),
    (metainit_fixup_mean_imagenet, metainit_fixup_std_imagenet, 'Fixup $\\rightarrow$ MetaInit', 'r'),
    (maml_mean_imagenet, maml_std_imagenet, 'Fixup $\\rightarrow$ DIMAML', 'g'),
]

for curve_mean, curve_std, curve_label, curve_color in imagenet_curves:
    smoothed_mean = moving_average(curve_mean, span=span)
    smoothed_std = moving_average(curve_std, span=span)
    # Only pass a colour when one was chosen, so the remaining curves keep
    # taking theirs from matplotlib's default colour cycle.
    line_kwargs = {} if curve_color is None else {'c': curve_color}
    band_kwargs = {} if curve_color is None else {'color': curve_color}
    plt.plot(x, smoothed_mean, linewidth=1.1, label=curve_label, **line_kwargs)
    # Shaded band: smoothed mean +/- smoothed standard deviation.
    plt.fill_between(x, smoothed_mean - smoothed_std, smoothed_mean + smoothed_std,
                     alpha=0.12, **band_kwargs)

plt.xlim([0, 32])
plt.ylim([0.58, 1.])
plt.yticks(np.arange(0.6, 1.01, 0.05), fontsize=11)
# Every third evaluation point is labelled with an iteration count in steps of 1500.
plt.xticks(np.arange(0, 31, 3), np.arange(0, 15151, 1500), fontsize=11)

plt.legend(fontsize=14)
plt.xlabel("Iteration", fontsize=14)
plt.ylabel("Classification Error", fontsize=14)

# Export the figure to a PDF with a tight bounding box, then render it inline.
with PdfPages("nips_fixup_resnet_imagenet_darkgrid.pdf") as pp:
    pp.savefig(bbox_inches='tight')
plt.show()

In [None]:
# Persist the raw per-rerun ImageNet accuracy curves of every method in one file.
imagenet_accuracy_runs = {
    'base': reruns_base_accuracies,
    'maml': reruns_maml_accuracies,
    'metainit': reruns_metainit_accuracies,
    'metainit_fixup': reruns_metainit_fixup_accuracies,
    'deltaorthogonal': reruns_deltaorthogonal_accuracies,
}
torch.save(imagenet_accuracy_runs, 'nips_imagenet_accuracies.pt')