In [14]:
import pprint
import copy

import numpy as np
import matplotlib.pyplot as plt
import IPython.display as disp
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import RandomSampler, DataLoader

from data import TrainTestSplitter, CurveTasks

In [15]:
# Pretty print.
pp = pprint.PrettyPrinter(indent=4)

In [16]:
# Device.
device = torch.device("cuda:0")

# Randomness.
np.random.seed(5)
torch.manual_seed(5)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [17]:
class Model(nn.Module):
    
    def __init__(self):
        super(Model, self).__init__()
        self.L1 = nn.Linear(1, 10)
        self.L2 = nn.Linear(10, 10)
        self.L3 = nn.Linear(10, 1)
    
    def forward(self, x):
        h1 = nn.Sigmoid()(self.L1(x))
        h2 = nn.Sigmoid()(self.L2(h1))
        out = self.L3(h2)
        return out

In [18]:
def compute_mse(y, y_pred):
    return (y_pred - y)**2

In [19]:
def random_initialisation(model, gain=1.0, torch_seed=42):
    torch.manual_seed(torch_seed)
    for name, param in model.named_parameters():
        if "weight" in name:
            torch.nn.init.xavier_uniform_(param, gain=gain)
        elif "bias" in name:
            torch.nn.init.ones_(param)
        else:
            raise ValueError("Unknown model parameter '{}' (not 'weight' or 'bias') found.".format(name))
    return model

In [20]:
def from_model_initialisation(model, from_model):
    parlist = sorted(list(model.named_parameters()))
    from_parlist = sorted(list(from_model.named_parameters()))
    for idx, (name, param) in enumerate(from_parlist):
        assert parlist[idx][0] == name, "Parameter mismatch."
        parlist[idx][1].data = param.data
    return model

In [21]:
# Data.
tts = TrainTestSplitter(test_frac=0.4)
meta_train = CurveTasks(train_test_splitter=tts, meta_train=True)
meta_test = CurveTasks(train_test_splitter=tts, meta_train=False)
dl_meta_train = DataLoader(meta_train, sampler=RandomSampler(meta_train, replacement=False))
dl_meta_test = DataLoader(meta_test, sampler=RandomSampler(meta_test, replacement=False))

### Evaluate performance after 1 gradient step 

In [22]:
# Training loop.

def training_loop(model, x_train, y_train, x_test, y_test, device, epochs=1, use_manual_sgd=False, display=False, retain_graph=False, create_graph=False):
    
    model = model.to(device)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    x_test = x_test.to(device)
    y_test = y_test.to(device)
    
    lr = 0.1

    x = x_train[0]
    y = y_train[0]

    x_test_ = x_test[0]
    y_test_ = y_test[0]

    if use_manual_sgd is False:
        opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0)

    epoch_losses_train = np.zeros((epochs,))
    epoch_losses_test = np.zeros((epochs,))

    for epoch in range(epochs):
        
        if display:
            disp.clear_output(wait=True)
            print("------------------------------\nepoch {}:\n------------------------------\n".format(epoch + 1))

        #######
        # Train
        #######

        model.train()  # Trian mode.

        y_preds = torch.zeros(len(x), requires_grad=False)
        mse_losses = torch.zeros(len(x), requires_grad=False)

        for it in range(len(x)):
            # Zero the gradients.
            model.zero_grad()

            # Predict.
            x_it = x[it].float().unsqueeze(0)
            y_pred = model(x_it)
            y_preds[it] = y_pred.detach()

            # Compute loss.
            mse_loss = compute_mse(y_pred.squeeze(), y[it])
            mse_losses[it] = mse_loss

            if use_manual_sgd is False:

                # Backprop.
                mse_loss.backward(torch.tensor(1.).to(device))

                # Gradient descent.
                opt.step()

            else:

                params_list = list(model.parameters())
                grads = torch.autograd.grad(outputs=mse_loss, inputs=params_list, grad_outputs=torch.tensor(1.).to(device), retain_graph=retain_graph, create_graph=create_graph)
                for idx, parameter in enumerate(params_list):
                    parameter.grad = grads[idx].data

                with torch.no_grad():  # Do not track.
                    for name, parameter in model.named_parameters():
                        parameter.data -= lr * parameter.grad.data

        ##########
        # Evaluate
        ##########

        model.eval()  # Evaluation mode.

        # Test loss:
        y_preds_test = torch.zeros(len(x_test_), requires_grad=False)
        mse_losses_test = torch.zeros(len(x_test_), requires_grad=False)
        for it in range(len(x_test_)):
            # Compute loss.
            x_it_test = x_test_[it].float().unsqueeze(0)
            y_pred_test = model(x_it_test)
            y_preds_test[it] = y_pred_test.detach()
            mse_loss_test = compute_mse(y_pred_test.squeeze(), y_test_[it])
            mse_losses_test[it] = mse_loss_test
        
        if display:
            f1 = show_plot(x, y, y_preds, x_test_, y_test_, y_preds_test, ylim=None)

        epoch_mse_train = mse_losses.sum()
        epoch_mse_test = mse_losses_test.sum()

        epoch_losses_train[epoch] = epoch_mse_train.item()
        epoch_losses_test[epoch] = epoch_mse_test.item()
        if display:
            f2 = plot_losses(epoch_losses_train[:epoch+1], epoch_losses_test[:epoch+1], xlim=(0, epochs-1))

        #########
        # Display
        #########
        
        if display:
            disp.display(f1)
            disp.display(f2)
            print()
            print("Training loss: {}".format(epoch_mse_train))
            print("Test loss: {}".format(epoch_mse_test))
        
    return epoch_mse_train, epoch_mse_test, model

In [23]:
def meta_evaluate(data_loader, model, param_init, param_init_kwargs, use_manual_sgd=False):
    
    df = pd.DataFrame(columns=("training_MSE", "test_MSE"), index=[0])
    
    torch.manual_seed(41)  # For data loader shuffling.
    for idx, ((x_train, y_train), (x_test, y_test)) in enumerate(data_loader):
        
        if "torch_seed" in param_init_kwargs:  # To vary random seed over meta test datasets.
            param_init_kwargs["torch_seed"] += 1
        
        model = param_init(model, **param_init_kwargs)
        
        tr, ts, _ = training_loop(model, x_train, y_train, x_test, y_test, device, epochs=1, use_manual_sgd=False, display=False, retain_graph=False, create_graph=False)
        df.loc[idx, ("training_MSE", "test_MSE")] = tr.item(), ts.item()
    
    return df

In [24]:
# Evaluate performance after 1 gradient step.
model = Model()

df = meta_evaluate(dl_meta_test, model, random_initialisation, {"gain": 1.0, "torch_seed": 41}, use_manual_sgd=True)

disp.display(df.mean())
# disp.display(df)

training_MSE    141.533891
test_MSE         59.671710
dtype: float64

### Meta-train 

In [25]:
meta_model = Model()
meta_model = meta_model.to(device)

torch.manual_seed(42)  # For parameter initialisation.
meta_model = random_initialisation(meta_model)  # Theta.

loss_track = []
model_track = []
for idx, ((x_train, y_train), (x_test, y_test)) in enumerate(dl_meta_train):
    fast_model = Model()
    fast_model = from_model_initialisation(model=fast_model, from_model=meta_model)
    tr, ts, fast_model = training_loop(fast_model, x_train, y_train, x_test, y_test, device, epochs=1, use_manual_sgd=True, display=False, retain_graph=True, create_graph=True)
#     print("FAST:")
#     pp.pprint(list(fast_model.named_parameters()))
#     print("\nMETA:")
#     pp.pprint(list(meta_model.named_parameters()))
    loss_track.append(tr)
    model_track.append(fast_model)

sum_losses = torch.stack(loss_track).sum().to(device)

### WORKING HERE:
# TODO: Now need to properly deal with the "One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior." situation.
params_list = list(meta_model.parameters())
grads = torch.autograd.grad(outputs=sum_losses, inputs=params_list, grad_outputs=torch.tensor(1.).to(device), retain_graph=False, create_graph=False)
for idx, parameter in enumerate(params_list):
    parameter.grad = grads[idx].data
with torch.no_grad():  # Do not track.
    for name, parameter in model.named_parameters():
        parameter.data -= lr * parameter.grad.data

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

In [None]:
# model = Model()

# torch.manual_seed(42)  # For parameter initialisation.
# model = random_initialisation(model)  # Theta.

# model2 = Model()
# print("BEFORE")
# pp.pprint(list(model2.named_parameters()))

# model2 = from_model_initialisation(model2, model)
# print("AFTER")
# pp.pprint(list(model2.named_parameters()))