# Autoencoder model

Training is on Tiny ImageNet. Evaluation is on CelebA and AnimeFaces.

In [None]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=1
import os, sys, time
import warnings
sys.path.insert(0, '..')
import lib

import math
import numpy as np
from copy import deepcopy
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline

# Reproducibility: draw a fresh seed, print it, and feed it to every RNG in use
# (Python, NumPy, PyTorch). Hard-coding the printed value reproduces the seeding.
import random

seed = random.randint(0, 2 ** 32 - 1)
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)

# Ask cuDNN for deterministic kernels and disable autotuning.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print(seed)

In [None]:
from matplotlib.backends.backend_pdf import PdfPages
# 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in Matplotlib 3.6 and the
# old name was removed in 3.8; use whichever name this Matplotlib version provides.
plt.style.use('seaborn-darkgrid' if 'seaborn-darkgrid' in plt.style.available
              else 'seaborn-v0_8-darkgrid')
# Embed TrueType (Type 42) fonts in PDF/PS output instead of Type 3 fonts.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

## Setting

In [None]:
model_type = 'AE'

# Dataset
data_dir = './data'
train_batch_size = 128
valid_batch_size = 256
test_batch_size = 128
num_workers = 3
pin_memory = True

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# AE
latent_dim = 64
loss_function = F.mse_loss
num_quantiles = 100

# MAML
max_epochs = 3000
batches_in_epoch = 200
maml_epochs = 3
# Number of training batches fed to each MAML forward pass.
maml_steps = batches_in_epoch * maml_epochs

max_meta_parameters_grad_norm = 10.

loss_kwargs = {'reduction': 'mean'}
loss_interval = 50
first_valid_step = 200

# validation_steps counts the points first_valid_step, first_valid_step + loss_interval,
# ..., maml_steps (one validation batch is sampled per point), so the span has to be a
# whole number of intervals.
if (maml_steps - first_valid_step) % loss_interval != 0:
    raise ValueError("maml_steps - first_valid_step must be a multiple of loss_interval")
validation_steps = (maml_steps - first_valid_step) // loss_interval + 1

# Optimizer
optimizer_type = 'momentum'
nesterov = False

learning_rate = 0.01
momentum = 0.9
weight_decay = None

# Only read when optimizer_type is 'rmsprop' or 'adam' (they were previously undefined,
# so those branches raised NameError). Values mirror the torch.optim defaults.
beta = 0.99      # RMSProp squared-gradient decay
beta1 = 0.9      # Adam first-moment decay
beta2 = 0.999    # Adam second-moment decay
epsilon = 1e-8

# Meta optimizer
meta_learning_rate = 0.001
meta_betas = (0.9, 0.997)
meta_decay_interval = max_epochs

checkpoint_steps = 5
recovery_step = None

# Extra keyword arguments forwarded to maml.forward by compute_maml_loss.
kwargs = dict(
    first_valid_step=first_valid_step,
    valid_loss_interval=loss_interval,
    loss_kwargs=loss_kwargs,
)

In [None]:
exp_name = (
    f"PLIF_{model_type}{latent_dim}_tiny_imagenet_{optimizer_type}_lr{learning_rate}"
    f"_meta_lr{meta_learning_rate}_steps{maml_steps}_interval{loss_interval}"
    f"_tr_bs{train_batch_size}_val_bs{valid_batch_size}_seed_{seed}"
)
print("Experiment name: ", exp_name)

logs_path = f"./logs/{exp_name}"
# Refuse to reuse an existing log directory unless recovery_step is set
# (presumably when resuming an earlier run).
if recovery_step is None:
    assert not os.path.exists(logs_path)
# !rm -rf {logs_path}

In [None]:
class PixelNormalize:
    """Standardize an image with per-pixel statistics: (image - mean_image) / std_image."""

    def __init__(self, mean_image, std_image):
        self.mean_image, self.std_image = mean_image, std_image

    def __call__(self, image):
        centered = image - self.mean_image
        return centered / self.std_image

    
class Flip:
    """Random flip: roughly half the time, reverse the image along its last axis
    (the width axis for CHW images)."""

    def __call__(self, image):
        should_flip = random.random() > 0.5
        return image.flip(-1) if should_flip else image
        
        
class CustomTensorDataset(torch.utils.data.Dataset):
    """TensorDataset variant that applies an optional transform to each sample.

    All tensors must share their first dimension, but only the first tensor is
    indexed: ``__getitem__`` returns a single (optionally transformed) sample,
    not a tuple.
    """

    def __init__(self, *tensors, transform=None):
        assert all(t.size(0) == tensors[0].size(0) for t in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        sample = self.tensors[0][index]
        return self.transform(sample) if self.transform else sample

    def __len__(self):
        return self.tensors[0].size(0)

In [None]:
# Load the Tiny ImageNet train / val / test splits into memory as stacked image tensors.
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

data_dir = 'data/tiny-imagenet-200/'


def _stack_image_folder(image_folder):
    """Stack every image of an ImageFolder (ToTensor transform) into one tensor; labels are dropped."""
    return torch.stack([image_folder[idx][0] for idx in range(len(image_folder))], dim=0)


train_image_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transforms.ToTensor())
train_images = _stack_image_folder(train_image_dataset)
# Per-pixel statistics of the training split are reused to normalize val and test.
mean_image = train_images.mean(0)
std_image = train_images.std(0)

train_transforms = transforms.Compose([Flip(), PixelNormalize(mean_image, std_image)])
eval_transforms = transforms.Compose([PixelNormalize(mean_image, std_image)])

train_dataset = CustomTensorDataset(train_images, transform=train_transforms)

if len(train_dataset) <= batches_in_epoch * train_batch_size:
    warnings.warn("Your training process involves one entire epoch")

valid_image_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transforms.ToTensor())
valid_images = _stack_image_folder(valid_image_dataset)
valid_dataset = CustomTensorDataset(valid_images, transform=eval_transforms)

test_image_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), transforms.ToTensor())
test_images = _stack_image_folder(test_image_dataset)
test_dataset = CustomTensorDataset(test_images, transform=eval_transforms)

# Create data loaders (train and valid are shuffled, test is not).
_tiny_loader_opts = dict(num_workers=num_workers, pin_memory=pin_memory)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, **_tiny_loader_opts)
valid_loader = DataLoader(valid_dataset, batch_size=valid_batch_size, shuffle=True, **_tiny_loader_opts)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **_tiny_loader_opts)

In [None]:
@torch.no_grad()
def compute_test_loss(model, loss_function, test_loader, device='cuda'):
    """Return the sample-weighted average of ``loss_function(model(x), x)`` over ``test_loader``.

    The per-batch loss is assumed to be a batch mean (e.g. reduction='mean'), so it is
    re-weighted by the batch size before averaging. The model is run in eval mode and its
    previous train/eval mode is restored afterwards.

    Args:
        model: module whose output is compared against its own input.
        loss_function: callable(preds, targets) -> scalar tensor.
        test_loader: iterable of tensors, or of lists/tuples whose first element is the input.
        device: device the inputs are moved to.

    Returns:
        The average loss as a Python float.

    Raises:
        TypeError: if a batch is neither a tensor nor a list/tuple.
        ValueError: if ``test_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.
    num_samples = 0
    try:
        for batch_test in test_loader:
            if isinstance(batch_test, (list, tuple)):
                x_test = batch_test[0].to(device)
            elif isinstance(batch_test, torch.Tensor):
                x_test = batch_test.to(device)
            else:
                raise TypeError("Wrong batch")
            preds = model(x_test)
            batch_size = x_test.shape[0]
            total_loss += loss_function(preds, x_test) * batch_size
            num_samples += batch_size
    finally:
        # Restore the caller's mode instead of unconditionally switching to train mode.
        model.train(was_training)
    if num_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return (total_loss / num_samples).item()

In [None]:
# In-graph inner-loop optimizer (lib.Ingraph*) selected by optimizer_type and handed to PLIF_MAML.
if optimizer_type == 'sgd':
    optimizer = lib.IngraphGradientDescent(learning_rate=learning_rate)
elif optimizer_type == 'momentum':
    optimizer = lib.IngraphMomentum(learning_rate=learning_rate, momentum=momentum,
                                    weight_decay=weight_decay, nesterov=nesterov)
elif optimizer_type == 'rmsprop':
    optimizer = lib.IngraphRMSProp(learning_rate=learning_rate, beta=beta, epsilon=epsilon)
elif optimizer_type == 'adam':
    optimizer = lib.IngraphAdam(learning_rate=learning_rate, beta2=beta2, beta1=beta1, epsilon=epsilon)
else:
    # Unknown optimizer_type; NotImplementedError is the builtin (the old misspelling raised NameError).
    raise NotImplementedError("{} optimizer is not implemented".format(optimizer_type))

model = lib.models.AE(latent_dim)
maml = lib.PLIF_MAML(
    model, model_type, optimizer=optimizer,
    checkpoint_steps=checkpoint_steps,
    loss_function=loss_function,
    num_quantiles=num_quantiles,
).to(device)

## Parallel Utilities

In [None]:
import torch.nn.parallel
from concurrent.futures import Future, ThreadPoolExecutor, as_completed, TimeoutError
# Shared thread pool; used below to run per-device MAML replicas concurrently.
GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=16)


def run_in_background(func: callable, *args, **kwargs) -> Future:
    """Schedule ``func(*args, **kwargs)`` on GLOBAL_EXECUTOR.

    Returns:
        A Future whose ``result()`` is the value returned by ``func``.
    """
    future = GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
    return future


def samples_batches(dataloader, num_batches):
    """Return a list with the first ``num_batches`` batches drawn from ``dataloader``.

    Fewer batches are returned if the loader is shorter; a non-positive
    ``num_batches`` gives an empty list. ``islice`` stops after exactly
    ``num_batches`` items, so no extra batch is loaded (the previous
    enumerate/break loop fetched one batch more than it kept).
    """
    from itertools import islice

    return list(islice(dataloader, max(num_batches, 0)))


def compute_maml_loss(maml, x_batches, y_batches, x_val_batches, y_val_batches, device):
    """Run one MAML inner loop and return ``(train_loss, valid_loss)``.

    Parameters are resampled from the current initializers, then ``maml.forward`` is called
    with the module-level ``kwargs``. ``train_loss`` is the mean of the train-loss history
    from ``first_valid_step`` on. ``valid_loss`` is the mean of the validation-loss history,
    or ``torch.zeros(1)`` if no validation loss was recorded.
    """
    with lib.training_mode(maml, is_train=True):
        maml.resample_parameters()
        _, train_history, valid_history, *_unused = maml.forward(
            x_batches, y_batches, x_val_batches, y_val_batches,
            device=device, **kwargs)
        train_loss = torch.cat(train_history[first_valid_step:]).mean()
        if len(valid_history) > 0:
            valid_loss = torch.cat(valid_history).mean()
        else:
            valid_loss = torch.zeros(1)
    return train_loss, valid_loss

In [None]:
class TrainerAE(lib.Trainer):
    """Meta-trainer for the autoencoder experiments.

    Each meta-step samples task batches (training batches double as targets), runs the
    MAML inner loop through compute_maml_loss, and updates ``self.maml.initializers``
    with ``self.meta_optimizer``. Attributes such as ``meta_optimizer``, ``writer``,
    ``logs``, ``record``, ``total_steps`` and ``device`` are expected to come from
    lib.Trainer (not visible here).
    """

    def _sample_task_batches(self, train_loader, valid_loader):
        """Sample ``maml_epochs`` passes of up to ``batches_in_epoch`` training batches, plus
        up to ``validation_steps`` validation batches."""
        x_batches = []
        for _ in range(maml_epochs):
            x_batches.extend(samples_batches(train_loader, batches_in_epoch))
        x_val_batches = samples_batches(valid_loader, validation_steps)
        return x_batches, x_val_batches

    def _meta_update(self, valid_loss, prefix):
        """Backprop ``valid_loss`` into the initializers, sanitize and clip the grads, and
        step the meta-optimizer.

        Returns:
            The global gradient norm reported by clip_grad_norm_ (measured after
            non-finite entries are zeroed, before clipping).
        """
        valid_loss.backward()
        params = [p for p in self.maml.initializers.parameters() if p.grad is not None]

        # Zero non-finite entries BEFORE clipping. clip_grad_norm_ computes one global norm,
        # so a single NaN/inf would otherwise make that norm non-finite and corrupt the
        # rescaling of every parameter's gradient.
        for param in params:
            finite = torch.isfinite(param.grad)
            if not bool(finite.all()):
                print("Outer Nan or inf in grads")
                param.grad = torch.where(finite, param.grad, torch.zeros_like(param.grad))

        global_grad_norm = nn.utils.clip_grad_norm_(params, max_meta_parameters_grad_norm)
        self.writer.add_scalar(prefix + "global_grad_norm", global_grad_norm, self.total_steps)

        self.meta_optimizer.step()
        self.logs.append((self.total_steps, global_grad_norm))
        return global_grad_norm

    def train_on_batch(self, train_loader, valid_loader, prefix='train/', **kwargs):
        """Perform a single meta-gradient update and report metrics.

        Extra ``kwargs`` are accepted but not used here; compute_maml_loss reads the
        module-level ``kwargs`` instead.
        """
        x_batches, x_val_batches = self._sample_task_batches(train_loader, valid_loader)

        self.meta_optimizer.zero_grad()
        # Autoencoder: inputs are also the reconstruction targets.
        train_loss, valid_loss = compute_maml_loss(self.maml, x_batches, x_batches,
                                                   x_val_batches, x_val_batches, self.device)
        self._meta_update(valid_loss, prefix)

        return self.record(train_loss=train_loss.item(),
                           valid_loss=valid_loss.item(), prefix=prefix)

    def parallel_train_on_batch(self, train_loader, valid_loader, prefix='train/', **kwargs):
        """Like train_on_batch, but runs one MAML replica per visible CUDA device in
        background threads and averages their losses before the meta-update."""
        replica_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
        if not replica_devices:
            raise RuntimeError("parallel_train_on_batch requires at least one CUDA device")
        replicas = torch.nn.parallel.replicate(self.maml, devices=replica_devices, detach=False)

        # Generate independent training/validation batches for each device.
        replica_inputs = []
        for replica, replica_device in zip(replicas, replica_devices):
            replica.untrained_initializers = lib.switch_initializers_device(
                replica.untrained_initializers, replica_device)
            x_batches, x_val_batches = self._sample_task_batches(train_loader, valid_loader)
            replica_inputs.append((replica, x_batches, x_batches,
                                   x_val_batches, x_val_batches, replica_device))

        self.meta_optimizer.zero_grad()
        replica_losses_futures = [run_in_background(compute_maml_loss, *replica_input)
                                  for replica_input in replica_inputs]
        replica_losses = [future.result() for future in replica_losses_futures]

        train_loss = sum(t_loss.item() for t_loss, _ in replica_losses) / len(replica_losses)
        # Keep valid_loss differentiable: gather it on the master device before averaging.
        valid_loss = sum(v_loss.to(self.device) for _, v_loss in replica_losses) / len(replica_losses)

        self._meta_update(valid_loss, prefix)

        return self.record(train_loss=train_loss,
                           valid_loss=valid_loss.item(), prefix=prefix)

    @torch.no_grad()
    def plot_quantile_functions(self):
        """Plot, per layer, the untrained (dashed) vs. learned (green) quantile functions of
        the weight and, if present, bias initializers on a 6x4 subplot grid."""
        plt.figure(figsize=[20, 35])
        xx = torch.linspace(0., 1., 1000).to(self.device)
        xx_np = lib.check_numpy(xx)
        i = 0
        for name, (weight_quantile_function, bias_quantile_function) in self.maml.initializers.items():
            wq_init, bq_init = self.maml.untrained_initializers[name]
            i += 1
            plt.subplot(6, 4, i)
            plt.plot(xx_np, lib.check_numpy(wq_init(xx)), '--')
            plt.plot(xx_np, lib.check_numpy(weight_quantile_function(xx)), c='g')
            plt.xlim([0, 1])
            plt.title(name + '_weight', fontsize=14)
            plt.yticks(fontsize=12)
            plt.xticks(fontsize=12)

            if i in [9, 10, 11, 12]:
                plt.xlabel("U(0,1) samples", fontsize=14)

            if bias_quantile_function is not None:
                i += 1
                plt.subplot(6, 4, i)
                # The dashed baseline must be the untrained *bias* initializer in every
                # subplot, including the one that carries the legend.
                bias_init_curve = lib.check_numpy(bq_init(xx))
                bias_learned_curve = lib.check_numpy(bias_quantile_function(xx))
                if i == 12:
                    plt.plot(xx_np, bias_init_curve, '--', label='Kaiming')
                    plt.plot(xx_np, bias_learned_curve, c='g', label='DIMAML')
                    leg = plt.legend(loc=4, fontsize=15, frameon=False)
                    for line in leg.get_lines():
                        line.set_linewidth(1.6)
                else:
                    plt.plot(xx_np, bias_init_curve, '--')
                    plt.plot(xx_np, bias_learned_curve, c='g')
                plt.xlim([0, 1])
                plt.title(name + '_bias', fontsize=14)
                plt.yticks(fontsize=12)
                plt.xticks(fontsize=12)

                if i in [9, 10, 11, 12]:
                    plt.xlabel("U(0,1) samples", fontsize=14)
        plt.show()

    def evaluate_metrics(self, train_loader, test_loader, prefix='val/', **kwargs):
        """Train fresh models from the untrained (baseline) and learned initializers, plot
        both loss curves, and log the learned model's test-loss sum ("test_AUC") and
        final test loss. ``kwargs`` (e.g. epochs, test_loss_interval) go to eval_model."""
        torch.cuda.empty_cache()

        print('Baseline')
        self.maml.resample_parameters(initializers=self.maml.untrained_initializers, is_final=True)
        base_model = deepcopy(self.maml.model)
        base_train_loss_history, base_test_loss_history = eval_model(base_model, train_loader, test_loader,
                                                                     device=self.device, **kwargs)
        print('Ours')
        self.maml.resample_parameters(is_final=True)
        maml_model = deepcopy(self.maml.model)
        maml_train_loss_history, maml_test_loss_history = eval_model(maml_model, train_loader, test_loader,
                                                                     device=self.device, **kwargs)
        draw_plots(base_train_loss_history, base_test_loss_history,
                   maml_train_loss_history, maml_test_loss_history)

        self.writer.add_scalar(prefix + "test_AUC", sum(maml_test_loss_history), self.total_steps)
        self.writer.add_scalar(prefix + "test_loss", maml_test_loss_history[-1], self.total_steps)

In [None]:
########################
# Generate Train Batch #
########################
            
def generate_train_batches(train_loader, batches_in_epoch=150):
    """Materialize ``batches_in_epoch`` batches of ``train_loader`` into a fixed local loader.

    The sampled batches are concatenated and wrapped in a new shuffled DataLoader
    (``train_batch_size``, ``num_workers``, ``pin_memory`` from the settings), so every
    pass over the result sees the same images in a new order.

    Raises:
        ValueError: if ``train_loader`` yields fewer than ``batches_in_epoch`` batches.
    """
    from itertools import islice

    # islice stops after exactly batches_in_epoch batches (no extra batch is loaded).
    x_batches = list(islice(train_loader, batches_in_epoch))
    if len(x_batches) != batches_in_epoch:
        raise ValueError("train_loader yielded {} batches, expected {}".format(
            len(x_batches), batches_in_epoch))
    local_x = torch.cat(x_batches, dim=0)
    return DataLoader(local_x, batch_size=train_batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=pin_memory)

##################
# Eval functions #
##################


def eval_model(model, train_loader, test_loader, batches_in_epoch=150,
               epochs=3, test_loss_interval=50, device='cuda', **kwargs):
    """Train ``model`` from its current weights with a standard torch optimizer.

    The optimizer type and hyperparameters come from the module-level settings
    (``optimizer_type``, ``learning_rate``, ...). Parameters are obtained with
    ``maml.get_parameters(model)``.

    Args:
        model: module to train in place.
        train_loader: iterable of input tensors; each epoch is one full pass over it.
        test_loader: passed to compute_test_loss.
        batches_in_epoch: currently unused (every epoch iterates the whole
            ``train_loader``); kept for call-site compatibility.
        epochs: number of passes over ``train_loader``.
        test_loss_interval: test loss is measured after the first step and after
            every ``test_loss_interval``-th step.
        device: device inputs are moved to.
        **kwargs: ignored.

    Returns:
        ``(train_loss_history, test_loss_history)`` as lists of floats.

    Raises:
        NotImplementedError: for an unknown ``optimizer_type``.
    """
    params = maml.get_parameters(model)
    if optimizer_type == 'sgd':
        optimizer = torch.optim.SGD(params, lr=learning_rate)
    elif optimizer_type == 'momentum':
        torch_weight_decay = 0.0 if weight_decay is None else weight_decay
        optimizer = torch.optim.SGD(params, lr=learning_rate, momentum=momentum,
                                    weight_decay=torch_weight_decay, nesterov=nesterov)
    elif optimizer_type == 'rmsprop':
        # torch names the squared-gradient decay `alpha` (there is no `beta` kwarg);
        # assumed to correspond to lib's `beta` — TODO confirm.
        optimizer = torch.optim.RMSprop(params, lr=learning_rate, alpha=beta, eps=epsilon)
    elif optimizer_type == 'adam':
        optimizer = torch.optim.Adam(params, lr=learning_rate, betas=(beta1, beta2), eps=epsilon)
    else:
        raise NotImplementedError("{} optimizer is not implemented".format(optimizer_type))

    # Train loop
    train_loss_history = []
    test_loss_history = []

    training_mode = model.training
    total_iters = 0
    for epoch in range(1, epochs + 1):
        model.train()
        for x_batch in train_loader:
            optimizer.zero_grad()
            x_batch = x_batch.to(device)
            preds = model(x_batch)
            loss = loss_function(preds, x_batch)
            loss.backward()
            optimizer.step()
            train_loss_history.append(loss.item())

            if total_iters == 0 or (total_iters + 1) % test_loss_interval == 0:
                model.eval()
                test_loss = compute_test_loss(model, loss_function, test_loader, device=device)
                print("Epoch {} | Total Iteration {} | Loss {}".format(epoch, total_iters + 1, test_loss))
                test_loss_history.append(test_loss)
                model.train()

            total_iters += 1
    model.train(training_mode)
    return train_loss_history, test_loss_history


def draw_plots(base_train_loss, base_test_loss, maml_train_loss, maml_test_loss):
    """Side-by-side comparison of the baseline and the learned ("Ours") initialization.

    Left panel: training loss smoothed with moving_average(span=10).
    Right panel: test loss at each evaluation point, unsmoothed.
    """
    panels = (
        ("Train loss", moving_average(base_train_loss, span=10), moving_average(maml_train_loss, span=10)),
        ("Test loss", base_test_loss, maml_test_loss),
    )
    plt.figure(figsize=(16, 6))
    for position, (title, baseline_curve, ours_curve) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(baseline_curve, label='Baseline')
        plt.plot(ours_curve, c='g', label='Ours')
        plt.legend(fontsize=14)
        plt.title(title, fontsize=14)
        plt.xlabel("Steps", fontsize=14)
        plt.ylabel("MSE", fontsize=14)

In [None]:
from IPython.display import clear_output
from pandas import DataFrame

def moving_average(x, **kw):
    """Exponentially weighted moving average of ``x``; ``kw`` (e.g. span=50) goes to pandas ``ewm``."""
    series = DataFrame({'x': np.asarray(x)}).x
    return series.ewm(**kw).mean().values
# Per-meta-step losses, appended to by the training loop.
train_loss_history, valid_loss_history = [], []

trainer = TrainerAE(
    maml,
    decay_interval=meta_decay_interval,
    meta_lr=meta_learning_rate,
    meta_betas=meta_betas,
    exp_name=exp_name,
    recovery_step=recovery_step,
)

In [None]:
lib.free_memory()
t0 = time.time()

# Meta-training: one meta-update per iteration; every 10 steps plot progress and run a
# full baseline-vs-ours evaluation, every 100 steps save a checkpoint.
while trainer.total_steps <= max_epochs:
    local_train_loader = generate_train_batches(train_loader, batches_in_epoch)

    with lib.activate_context_batchnorm(maml.model):
        metrics = trainer.train_on_batch(
            local_train_loader,
            valid_loader,
            first_valid_step=first_valid_step,
            valid_loss_interval=loss_interval,
            loss_kwargs=loss_kwargs,
        )
    train_loss, valid_loss = metrics['train_loss'], metrics['valid_loss']
    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)

    if trainer.total_steps % 10 == 0:
        clear_output(True)
        print("Step: %d | Time: %f | Train Loss %.5f | Valid loss %.5f"
              % (trainer.total_steps, time.time() - t0, train_loss, valid_loss))
        plt.figure(figsize=[16, 5])
        history_panels = [('Train Loss over time', train_loss_history),
                          ('Valid Loss over time', valid_loss_history)]
        for panel_idx, (panel_title, history) in enumerate(history_panels, start=1):
            plt.subplot(1, 2, panel_idx)
            plt.title(panel_title)
            plt.plot(moving_average(history, span=50))
            plt.scatter(range(len(history)), history, alpha=0.1)
        plt.show()
        trainer.evaluate_metrics(local_train_loader, test_loader, epochs=maml_epochs,
                                 test_loss_interval=loss_interval)
        trainer.plot_quantile_functions()
        t0 = time.time()

    if trainer.total_steps % 100 == 0:
        trainer.save_model()

    trainer.total_steps += 1

## Quantile Functions 

In [None]:
# Plot each layer's learned (green) vs. untrained (dashed) initializer quantile functions.
trainer.plot_quantile_functions()

# Evaluation

In [None]:
# For evaluation, let cuDNN autotune kernels for speed instead of forcing
# deterministic ones (reverses the reproducibility settings used for training).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

In [None]:
def gradient_quotient(loss, params, eps=1e-5):
    """Gradient quotient used by metainit: mean over all entries of ``|(g - Hg) / g - 1|``.

    ``g`` is the gradient of ``loss`` w.r.t. ``params`` and ``Hg`` the gradient of
    ``sum(g**2) / 2`` (a Hessian-vector product). ``eps``, signed like ``g``, keeps the
    denominator away from zero. The graph is kept (create_graph=True) so the result
    can itself be differentiated.
    """
    grads = torch.autograd.grad(loss, params, retain_graph=True, create_graph=True)
    half_sq_norm = sum((g ** 2).sum() / 2 for g in grads)
    hvps = torch.autograd.grad(half_sq_norm, params, retain_graph=True, create_graph=True)

    total = 0
    for g, hvp in zip(grads, hvps):
        signed_eps = eps * (2 * (g >= 0).float() - 1).detach()
        total = total + ((g - hvp) / (g + signed_eps) - 1).abs().sum()
    num_entries = sum(p.data.nelement() for p in params)
    return total / num_entries


def metainit(model, criterion, x_size, lr=0.1, momentum=0.9, steps=200, eps=1e-5, device=None):
    """MetaInit: rescale ``model``'s multi-dimensional weights to reduce the gradient quotient.

    Each step draws a standard-normal input of shape ``x_size``, computes the
    reconstruction loss ``criterion(model(x), x)`` and its gradient_quotient, and
    multiplies every parameter with >= 2 dims by a scalar that moves its norm by a
    signed momentum step of size ``lr``. 1-D parameters (e.g. biases) are not
    rescaled. The model is left in eval mode.

    Args:
        model: module whose weights are rescaled in place.
        criterion: callable(output, target) -> scalar loss.
        x_size: shape of the random input batch.
        lr, momentum: step size and momentum of the norm updates.
        steps: number of updates.
        eps: numerical guard passed to gradient_quotient and used in the rescaling.
        device: where the random inputs are placed. Defaults to the device of the
            model's first parameter ('cuda' if it has none). Previously this was
            hard-coded to CUDA, which broke CPU runs.
    """
    model.eval()
    params = [p for p in model.parameters()
              if p.requires_grad and len(p.size()) >= 2]
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cuda'
    memory = [0] * len(params)
    for i in range(steps):
        # Sample with the CPU generator, then move (same RNG stream as before).
        inputs = torch.Tensor(*x_size).normal_(0, 1).to(device)
        loss = criterion(model(inputs), inputs)
        gq = gradient_quotient(loss, list(model.parameters()), eps)

        grad = torch.autograd.grad(gq, params)
        for j, (p, g_all) in enumerate(zip(params, grad)):
            norm = p.data.norm().item()
            # Sign of the derivative of GQ under rescaling p. sign(x / norm) == sign(x) for
            # norm > 0, and dropping the division avoids NaN for an all-zero weight.
            g = torch.sign((p.data * g_all).sum())
            memory[j] = momentum * memory[j] - lr * g.item()
            new_norm = norm + memory[j]
            p.data.mul_(new_norm / (norm + eps))
        print("%d/GQ = %.2f" % (i, gq.item()))

In [None]:
def genOrthgonal(dim):
    """Return a random (dim x dim) orthogonal matrix.

    Built from the QR decomposition of a standard-normal matrix. Each column of Q is
    multiplied by the sign of the matching diagonal entry of R, so the equivalent
    R factor has a non-negative diagonal.
    """
    gaussian = torch.zeros((dim, dim)).normal_(0, 1)
    q, r = torch.qr(gaussian)
    column_signs = torch.diag(r, 0).sign()
    return q * column_signs.view(1, -1)

def makeDeltaOrthogonal(weights, gain):
    """Delta-orthogonal initialization of a rank-4 kernel, in place.

    All entries are zeroed except the spatial centre ``[:, :, h // 2, w // 2]``. That
    slice receives the top-left (dim0 x dim1) block of a random orthogonal matrix, and
    the whole tensor is then scaled by ``gain``. Dims 0/1 are presumably out/in channels
    of a Conv2d weight; a message is printed when dim1 > dim0.
    """
    n_rows, n_cols = weights.size(0), weights.size(1)
    if n_rows < n_cols:
        print("In_filters should not be greater than out_filters.")
    weights.data.fill_(0)
    ortho = genOrthgonal(max(n_rows, n_cols))
    center_h, center_w = weights.size(2) // 2, weights.size(3) // 2
    with torch.no_grad():
        weights[:, :, center_h, center_w] = ortho[:n_rows, :n_cols]
        weights.mul_(gain)

# Eval CelebA

In [None]:
# 1. Download this file into `dataset_directory` and unzip it:
#  https://drive.google.com/open?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM
# 2. Put the `img_align_celeba` directory and `list_eval_partition.csv` into the `celeba` directory.
#    The images are read from `celeba/img_align_celeba/img_align_celeba/`.
# 3. The cell below then copies the images into this layout (required by ImageFolder from torchvision):
#  +- `dataset_directory`
#     +- celeba
#        +- train
#           +- images
#              +- 000001.jpg
#              +- 000002.jpg
#              +- ...
#        +- val
#           +- images
#              +- 000001.jpg
#              +- 000002.jpg
#              +- ...
#        +- test
#           +- images
#              +- 000001.jpg
#              +- 000002.jpg
#              +- ...

import pandas as pd
import shutil

celeba_data_dir = 'data/celeba/'
data = pd.read_csv(os.path.join(celeba_data_dir, 'list_eval_partition.csv'))

# Uncomment if you want to copy data again
# for partition in ['train', 'val', 'test']:
#     image_path = os.path.join(celeba_data_dir, partition)
#     !rm -rf {image_path}

# `partition` code in list_eval_partition.csv -> destination split folder.
_celeba_split_names = {0: 'train', 1: 'val', 2: 'test'}
_celeba_src_dir = os.path.join(celeba_data_dir, 'img_align_celeba/img_align_celeba')

try:
    # makedirs raises FileExistsError when a split folder already exists, which skips the copy.
    for split_name in ['train', 'val', 'test']:
        os.makedirs(os.path.join(celeba_data_dir, split_name))
        os.makedirs(os.path.join(celeba_data_dir, split_name, 'images'))

    for row in data.itertuples(index=False):
        split_name = _celeba_split_names.get(row.partition)
        if split_name is not None:
            shutil.copyfile(os.path.join(_celeba_src_dir, row.image_id),
                            os.path.join(celeba_data_dir, split_name, 'images', row.image_id))

except FileExistsError:
    print('\'train\', \'val\', \'test\' already exist. Probably, you do not want to copy data again')

In [None]:
celeba_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])


def _load_celeba_split(split):
    """Return (ImageFolder, stacked image tensor) for one split folder under celeba_data_dir."""
    folder = datasets.ImageFolder(celeba_data_dir + split, transform=celeba_transforms)
    images = torch.stack([folder[idx][0] for idx in range(len(folder))])
    return folder, images


celeba_train_dataset, celeba_train_images = _load_celeba_split('train')

# Per-pixel statistics of the train split normalize both splits.
celeba_mean_image = celeba_train_images.mean(0)
celeba_std_image = celeba_train_images.std(0)
celeba_train_images = (celeba_train_images - celeba_mean_image) / celeba_std_image

celeba_test_dataset, celeba_test_images = _load_celeba_split('test')
celeba_test_images = (celeba_test_images - celeba_mean_image) / celeba_std_image

# The loaders iterate the normalized tensors directly, so every batch is a plain image tensor.
celeba_train_loader = torch.utils.data.DataLoader(
    celeba_train_images, batch_size=train_batch_size, shuffle=True,
    pin_memory=pin_memory, num_workers=num_workers,
)
celeba_test_loader = torch.utils.data.DataLoader(
    celeba_test_images, batch_size=test_batch_size,
    pin_memory=pin_memory, num_workers=num_workers,
)

In [None]:
###################
# Evaluate models #
###################

num_reruns = 10
celeba_batches_in_epoch = len(celeba_train_loader)  # 1272 - full epoch

# Test loss after 10 / 50 / 100 training epochs, one entry per rerun, per init scheme.
celeba_base_runs_10 = []
celeba_base_runs_50 = []
celeba_base_runs_100 = []

celeba_maml_runs_10 = []
celeba_maml_runs_50 = []
celeba_maml_runs_100 = []

celeba_deltaorthogonal_runs_10 = []
celeba_deltaorthogonal_runs_50 = []
celeba_deltaorthogonal_runs_100 = []

celeba_metainit_runs_10 = []
celeba_metainit_runs_50 = []
celeba_metainit_runs_100 = []


def _eval_on_celeba(model):
    """Train ``model`` on CelebA for 100 full epochs; return (train_loss_history, test_loss_history).

    Test loss is measured after the first step (index 0) and then every 10 epochs, so
    ``test_loss_history[k]`` for k >= 1 is the test loss after 10 * k epochs.
    """
    return eval_model(model, celeba_train_loader, celeba_test_loader,
                      batches_in_epoch=celeba_batches_in_epoch, epochs=100,
                      test_loss_interval=10 * celeba_batches_in_epoch, device=device)


# Note: a former extra run on an undefined `kaiming_model` (NameError on the first rerun,
# results never used) was removed; "Baseline" already trains from the untrained initializers.
for _ in range(num_reruns):
    print("Baseline")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    base_model = deepcopy(maml.model)
    base_train_loss_history, base_test_loss_history = _eval_on_celeba(base_model)

    print("Ours")
    maml.resample_parameters(is_final=True)
    maml_model = deepcopy(maml.model)
    maml_train_loss_history, maml_test_loss_history = _eval_on_celeba(maml_model)

    celeba_base_runs_10.append(base_test_loss_history[1])
    celeba_base_runs_50.append(base_test_loss_history[5])
    celeba_base_runs_100.append(base_test_loss_history[10])

    celeba_maml_runs_10.append(maml_test_loss_history[1])
    celeba_maml_runs_50.append(maml_test_loss_history[5])
    celeba_maml_runs_100.append(maml_test_loss_history[10])

    print("MetaInit")
    batch_x = next(iter(celeba_train_loader))
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    metainit_model = deepcopy(maml.model)
    metainit(metainit_model, loss_function, batch_x.shape, steps=200)
    metainit_train_loss_history, metainit_test_loss_history = _eval_on_celeba(metainit_model)

    print("Delta Orthogonal")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    deltaorthogonal_model = deepcopy(maml.model)
    for param in deltaorthogonal_model.parameters():
        if len(param.size()) >= 4:
            makeDeltaOrthogonal(param, nn.init.calculate_gain('relu'))
    deltaorthogonal_train_loss_history, deltaorthogonal_test_loss_history = \
        _eval_on_celeba(deltaorthogonal_model)

    celeba_deltaorthogonal_runs_10.append(deltaorthogonal_test_loss_history[1])
    celeba_deltaorthogonal_runs_50.append(deltaorthogonal_test_loss_history[5])
    celeba_deltaorthogonal_runs_100.append(deltaorthogonal_test_loss_history[10])

    celeba_metainit_runs_10.append(metainit_test_loss_history[1])
    celeba_metainit_runs_50.append(metainit_test_loss_history[5])
    celeba_metainit_runs_100.append(metainit_test_loss_history[10])

In [None]:
# Mean and sample std (ddof=1) of the test loss over reruns, per method and epoch count.
# Groups are separated by a blank line (the MetaInit group was previously missing one).
_celeba_summary = [
    ("Baseline", celeba_base_runs_10, celeba_base_runs_50, celeba_base_runs_100),
    ("DIMAML", celeba_maml_runs_10, celeba_maml_runs_50, celeba_maml_runs_100),
    ("MetaInit", celeba_metainit_runs_10, celeba_metainit_runs_50, celeba_metainit_runs_100),
    ("DeltaOrthogonal", celeba_deltaorthogonal_runs_10, celeba_deltaorthogonal_runs_50,
     celeba_deltaorthogonal_runs_100),
]
for group_idx, (method, *runs_per_checkpoint) in enumerate(_celeba_summary):
    if group_idx > 0:
        print()
    for num_epochs, runs in zip((10, 50, 100), runs_per_checkpoint):
        print(f"{method} {num_epochs} epoch: ", np.mean(runs), np.std(runs, ddof=1))

## AnimeFaces

In [None]:
animefaces_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

# Positional split: the first 19000 images (ImageFolder order) are train, the rest test.
_animefaces_num_train = 19000

animefaces_dataset = datasets.ImageFolder('data/AnimeFaces/', transform=animefaces_transforms)
animefaces_images = torch.stack([animefaces_dataset[idx][0] for idx in range(len(animefaces_dataset))])
animefaces_train_images = animefaces_images[:_animefaces_num_train]
animefaces_test_images = animefaces_images[_animefaces_num_train:]

# Per-pixel statistics come from the train split only and are applied to both splits.
animefaces_mean_image = animefaces_train_images.mean(0)
animefaces_std_image = animefaces_train_images.std(0)
animefaces_train_images = (animefaces_train_images - animefaces_mean_image) / animefaces_std_image
animefaces_test_images = (animefaces_test_images - animefaces_mean_image) / animefaces_std_image

# Only the training images get random horizontal flips.
animefaces_train_dataset = CustomTensorDataset(animefaces_train_images, transform=Flip())

animefaces_train_loader = torch.utils.data.DataLoader(
    animefaces_train_dataset, batch_size=train_batch_size, shuffle=True,
    pin_memory=pin_memory, num_workers=num_workers,
)
animefaces_test_loader = torch.utils.data.DataLoader(
    animefaces_test_images, batch_size=test_batch_size,
    pin_memory=pin_memory, num_workers=num_workers,
)

In [None]:
###################
# Evaluate models #
###################

num_reruns=10
# One full pass over the training split: ceil(19000 / train_batch_size) batches.
animefaces_batches_in_epoch = len(animefaces_train_loader)

# Test losses after 10 / 50 / 100 epochs, one entry per rerun, per initialization.
animefaces_base_runs_10 = []
animefaces_base_runs_50 = []
animefaces_base_runs_100 = []

animefaces_maml_runs_10 = []
animefaces_maml_runs_50 = []
animefaces_maml_runs_100 = []

animefaces_deltaorthogonal_runs_10 = []
animefaces_deltaorthogonal_runs_50 = []
animefaces_deltaorthogonal_runs_100 = []

animefaces_metainit_runs_10 = []
animefaces_metainit_runs_50 = []
animefaces_metainit_runs_100 = []


def _eval_on_animefaces(model):
    """Train and test `model` on AnimeFaces for 100 epochs via `eval_model`.

    `test_loss_interval` spans 10 epochs' worth of batches. Returns
    eval_model's (train_loss_history, test_loss_history) pair.
    """
    return eval_model(model, animefaces_train_loader, animefaces_test_loader,
                      batches_in_epoch=animefaces_batches_in_epoch, epochs=100,
                      test_loss_interval=10*animefaces_batches_in_epoch, device=device)


def _append_checkpoint_losses(test_loss_history, runs_10, runs_50, runs_100):
    """Append entries 1, 5 and 10 of `test_loss_history` to the three run lists.

    Presumably entry 0 is the evaluation before training, so these are the
    10-, 50- and 100-epoch test losses -- TODO confirm against eval_model.
    """
    runs_10.append(test_loss_history[1])
    runs_50.append(test_loss_history[5])
    runs_100.append(test_loss_history[10])


# MetaInit only needs the batch *shape*, so probe the loader once up front
# instead of starting a fresh shuffled, multi-worker iterator on every rerun.
batch_x = next(iter(animefaces_train_loader))

for _ in range(num_reruns):
    # Baseline, MetaInit and DeltaOrthogonal start from maml.untrained_initializers;
    # DIMAML resamples without `initializers=`, presumably using the meta-learned ones.
    print("Baseline")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    base_model = deepcopy(maml.model)
    base_train_loss_history, base_test_loss_history = _eval_on_animefaces(base_model)

    print("DIMAML")
    maml.resample_parameters(is_final=True)
    maml_model = deepcopy(maml.model)
    maml_train_loss_history, maml_test_loss_history = _eval_on_animefaces(maml_model)

    print("MetaInit")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    metainit_model = deepcopy(maml.model)
    metainit(metainit_model, loss_function, batch_x.shape, steps=200)
    metainit_train_loss_history, metainit_test_loss_history = _eval_on_animefaces(metainit_model)

    print("Delta Orthogonal")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    deltaorthogonal_model = deepcopy(maml.model)
    # Only parameters with >= 4 dims (presumably conv kernels) are re-initialized.
    for param in deltaorthogonal_model.parameters():
        if len(param.size()) >= 4:
            makeDeltaOrthogonal(param, nn.init.calculate_gain('relu'))
    deltaorthogonal_train_loss_history, deltaorthogonal_test_loss_history = \
        _eval_on_animefaces(deltaorthogonal_model)

    _append_checkpoint_losses(base_test_loss_history, animefaces_base_runs_10,
                              animefaces_base_runs_50, animefaces_base_runs_100)
    _append_checkpoint_losses(maml_test_loss_history, animefaces_maml_runs_10,
                              animefaces_maml_runs_50, animefaces_maml_runs_100)
    _append_checkpoint_losses(deltaorthogonal_test_loss_history, animefaces_deltaorthogonal_runs_10,
                              animefaces_deltaorthogonal_runs_50, animefaces_deltaorthogonal_runs_100)
    _append_checkpoint_losses(metainit_test_loss_history, animefaces_metainit_runs_10,
                              animefaces_metainit_runs_50, animefaces_metainit_runs_100)

In [None]:
def _print_mean_std(label, runs):
    """Print `label`, then the mean and sample standard deviation (ddof=1) of `runs`."""
    print(label, np.mean(runs), np.std(runs, ddof=1))


_print_mean_std("Baseline 10 epoch: ", animefaces_base_runs_10)
_print_mean_std("Baseline 50 epoch: ", animefaces_base_runs_50)
_print_mean_std("Baseline 100 epoch: ", animefaces_base_runs_100)
print()
_print_mean_std("DIMAML 10 epoch: ", animefaces_maml_runs_10)
_print_mean_std("DIMAML 50 epoch: ", animefaces_maml_runs_50)
_print_mean_std("DIMAML 100 epoch: ", animefaces_maml_runs_100)
print()
_print_mean_std("MetaInit 10 epoch: ", animefaces_metainit_runs_10)
_print_mean_std("MetaInit 50 epoch: ", animefaces_metainit_runs_50)
_print_mean_std("MetaInit 100 epoch: ", animefaces_metainit_runs_100)
print()
_print_mean_std("DeltaOrthogonal 10 epoch: ", animefaces_deltaorthogonal_runs_10)
_print_mean_std("DeltaOrthogonal 50 epoch: ", animefaces_deltaorthogonal_runs_50)
_print_mean_std("DeltaOrthogonal 100 epoch: ", animefaces_deltaorthogonal_runs_100)

## Permuted AnimeFaces

In [None]:
# Fixed pixel permutation loaded from disk so every run uses the same one
# (presumably generated once with torch.randperm(64*64)).
permutation = torch.load('nips_animefaces_permutation.pt')


def _permute_pixels(images, permutation):
    """Return a copy of `images` with its pixel positions permuted.

    `images` is indexed as (N, C, H, W). The H*W positions of each channel are
    flattened row-major, and the same `permutation` (length H*W) is applied to
    every channel of every image.

    The input is NOT modified. Writing the result back through per-image views
    would also permute `images` itself, and with it every dataset/loader built
    on that tensor (e.g. the unpermuted AnimeFaces loaders).
    """
    n, c, h, w = images.shape
    # Advanced indexing gathers into a new tensor, so `images` stays untouched.
    return images.reshape(n, c, h * w)[:, :, permutation].reshape(n, c, h, w)


permuted_animefaces_train_images = _permute_pixels(animefaces_train_images, permutation)
permuted_animefaces_test_images = _permute_pixels(animefaces_test_images, permutation)

permuted_animefaces_train_dataset = CustomTensorDataset(permuted_animefaces_train_images, transform=Flip())

permuted_animefaces_train_loader = torch.utils.data.DataLoader(permuted_animefaces_train_dataset, 
                                                               batch_size=train_batch_size, shuffle=True,
                                                               pin_memory=pin_memory, num_workers=num_workers)
permuted_animefaces_test_loader = torch.utils.data.DataLoader(permuted_animefaces_test_images, 
                                                              batch_size=test_batch_size, 
                                                              pin_memory=pin_memory, num_workers=num_workers)

In [None]:
###################
# Evaluate models #
###################

num_reruns=10
# Measure the loader that is actually trained on. It has the same length as the
# unpermuted one, since the permutation keeps one output image per input:
# ceil(19000 / train_batch_size) batches.
animefaces_batches_in_epoch = len(permuted_animefaces_train_loader)

# Test losses after 10 / 50 / 100 epochs, one entry per rerun, per initialization.
permuted_animefaces_base_runs_10 = []
permuted_animefaces_base_runs_50 = []
permuted_animefaces_base_runs_100 = []

permuted_animefaces_maml_runs_10 = []
permuted_animefaces_maml_runs_50 = []
permuted_animefaces_maml_runs_100 = []

permuted_animefaces_deltaorthogonal_runs_10 = []
permuted_animefaces_deltaorthogonal_runs_50 = []
permuted_animefaces_deltaorthogonal_runs_100 = []

permuted_animefaces_metainit_runs_10 = []
permuted_animefaces_metainit_runs_50 = []
permuted_animefaces_metainit_runs_100 = []


def _eval_on_permuted_animefaces(model):
    """Train and test `model` on pixel-permuted AnimeFaces for 100 epochs via `eval_model`.

    `test_loss_interval` spans 10 epochs' worth of batches. Returns
    eval_model's (train_loss_history, test_loss_history) pair.
    """
    return eval_model(model, permuted_animefaces_train_loader, permuted_animefaces_test_loader,
                      batches_in_epoch=animefaces_batches_in_epoch, epochs=100,
                      test_loss_interval=10*animefaces_batches_in_epoch, device=device)


def _append_checkpoint_losses(test_loss_history, runs_10, runs_50, runs_100):
    """Append entries 1, 5 and 10 of `test_loss_history` to the three run lists.

    Presumably entry 0 is the evaluation before training, so these are the
    10-, 50- and 100-epoch test losses -- TODO confirm against eval_model.
    """
    runs_10.append(test_loss_history[1])
    runs_50.append(test_loss_history[5])
    runs_100.append(test_loss_history[10])


# MetaInit only needs the batch *shape*. Take it from the loader being trained
# on, once, instead of starting a fresh multi-worker iterator every rerun.
batch_x = next(iter(permuted_animefaces_train_loader))

for _ in range(num_reruns):
    # MetaInit, DeltaOrthogonal and Baseline start from maml.untrained_initializers;
    # DIMAML resamples without `initializers=`, presumably using the meta-learned ones.
    print("MetaInit")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    metainit_model = deepcopy(maml.model)
    metainit(metainit_model, loss_function, batch_x.shape, steps=200)
    metainit_train_loss_history, metainit_test_loss_history = \
        _eval_on_permuted_animefaces(metainit_model)

    print("Delta Orthogonal")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    deltaorthogonal_model = deepcopy(maml.model)
    # Only parameters with >= 4 dims (presumably conv kernels) are re-initialized.
    for param in deltaorthogonal_model.parameters():
        if len(param.size()) >= 4:
            makeDeltaOrthogonal(param, nn.init.calculate_gain('relu'))
    deltaorthogonal_train_loss_history, deltaorthogonal_test_loss_history = \
        _eval_on_permuted_animefaces(deltaorthogonal_model)

    print("Baseline")
    maml.resample_parameters(initializers=maml.untrained_initializers, is_final=True)
    base_model = deepcopy(maml.model)
    base_train_loss_history, base_test_loss_history = _eval_on_permuted_animefaces(base_model)

    print("DIMAML")
    maml.resample_parameters(is_final=True)
    maml_model = deepcopy(maml.model)
    maml_train_loss_history, maml_test_loss_history = _eval_on_permuted_animefaces(maml_model)

    _append_checkpoint_losses(base_test_loss_history, permuted_animefaces_base_runs_10,
                              permuted_animefaces_base_runs_50, permuted_animefaces_base_runs_100)
    _append_checkpoint_losses(maml_test_loss_history, permuted_animefaces_maml_runs_10,
                              permuted_animefaces_maml_runs_50, permuted_animefaces_maml_runs_100)
    _append_checkpoint_losses(deltaorthogonal_test_loss_history, permuted_animefaces_deltaorthogonal_runs_10,
                              permuted_animefaces_deltaorthogonal_runs_50, permuted_animefaces_deltaorthogonal_runs_100)
    _append_checkpoint_losses(metainit_test_loss_history, permuted_animefaces_metainit_runs_10,
                              permuted_animefaces_metainit_runs_50, permuted_animefaces_metainit_runs_100)

In [None]:
def _print_mean_std(label, runs):
    """Print `label`, then the mean and sample standard deviation (ddof=1) of `runs`."""
    print(label, np.mean(runs), np.std(runs, ddof=1))


_print_mean_std("Baseline 10 epoch: ", permuted_animefaces_base_runs_10)
_print_mean_std("Baseline 50 epoch: ", permuted_animefaces_base_runs_50)
_print_mean_std("Baseline 100 epoch: ", permuted_animefaces_base_runs_100)
print()
_print_mean_std("DIMAML 10 epoch: ", permuted_animefaces_maml_runs_10)
_print_mean_std("DIMAML 50 epoch: ", permuted_animefaces_maml_runs_50)
_print_mean_std("DIMAML 100 epoch: ", permuted_animefaces_maml_runs_100)
print()
_print_mean_std("MetaInit 10 epoch: ", permuted_animefaces_metainit_runs_10)
_print_mean_std("MetaInit 50 epoch: ", permuted_animefaces_metainit_runs_50)
_print_mean_std("MetaInit 100 epoch: ", permuted_animefaces_metainit_runs_100)
print()
_print_mean_std("DeltaOrthogonal 10 epoch: ", permuted_animefaces_deltaorthogonal_runs_10)
_print_mean_std("DeltaOrthogonal 50 epoch: ", permuted_animefaces_deltaorthogonal_runs_50)
_print_mean_std("DeltaOrthogonal 100 epoch: ", permuted_animefaces_deltaorthogonal_runs_100)