In [1]:
from copy import deepcopy

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset

In [2]:
class NN_Classifier(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        # have to set layer weights to be doubles as usually initialised as floats
        self.model = nn.Sequential(
            nn.Linear(2, 40).double(),
            nn.ReLU(),
            nn.Linear(40, 40).double(),
            nn.ReLU(),
            nn.Linear(40, 2).double()
        )
    
    def forward(self, x):
        return self.model(x)

In [3]:
def sgd(model, task, lr, iterations, criterion):
    # load weights
    # model.load_state_dict(state_dict)
    
    # set to training mode
    model.train()

    opt = torch.optim.SGD(model.parameters(), lr)
    for i in range(iterations):
        # get training points
        for data in task:
            y, x = data[:, 0], data[:, 1:]
            
            # zero the parameter gradients
            opt.zero_grad()

            # forward + backward + optimize
            y_hat = model(x)
            loss = criterion(y_hat, y)
            loss.backward()
            opt.step()
        
    return model.state_dict()

In [4]:
def load_banana(file, num_splits, column='x1'):
    data_ = pd.read_csv(file)
    # replaces -1s with 0s
    data_.y.replace(-1.0, 0, inplace=True)
    # converts labels to int64 which is required by cross entropy loss
    data_['y'] = data_['y'].astype('int64')
    data_size = len(data_)

    # sort by x1
    data_ = data_.sort_values(column)

    # set seed
    np.random.seed(32)

    # split into train, test splits
    split_size = np.floor(len(data_) / num_splits).astype(int)
    ratio = 0.75
    idx = 0
    split_len = []
    
    data = {}
    for i in range(num_splits):
        if i < 2:
            # draw indices
            test_idx = np.random.choice(split_size, np.floor(ratio * split_size).astype(int), replace=False)
            # create boolean mask
            mask = np.array([True if j in test_idx else False for j in np.arange(split_size)])
            rev_mask = np.invert(mask)
            split_len.append(sum(rev_mask))
            data[str(i) + '_train'] = torch.from_numpy(data_[idx:split_size+idx].values[rev_mask, :])
            data[str(i) + '_test'] = torch.from_numpy(data_[idx:split_size+idx].values[mask, :])
            idx += split_size
        else:
             # draw indices
            test_idx = np.random.choice(len(data_[idx:]), np.floor(ratio * len(data_[idx:])).astype(int), replace=False)
            # create boolean mask
            mask = np.array([True if j in test_idx else False for j in np.arange(len(data_[idx:]))])
            rev_mask = np.invert(mask)
            split_len.append(sum(rev_mask))
            data[str(i) + '_train'] = torch.from_numpy(data_[idx:].values[rev_mask, :])
            data[str(i) + '_test'] = torch.from_numpy(data_[idx:].values[mask, :])
            
    return data, split_len

def gen_task_data(data, split_len, num_splits, batch_size):
    # select task
    task_num = int(torch.randint(num_splits, (1, )).item())
    
    # create dataloader to sample from task
    loader = DataLoader(
        data[str(task_num) + '_train'],
        batch_size=batch_size,
        shuffle=True,
        pin_memory=False
        )
    
    return loader, split_len[task_num] / batch_size

def gen_eval_data(data, num_splits):
    for split in np.arange(num_splits):
        if split == 0:
            test_data = data[str(split) + '_test']
        else:
            test_data = torch.cat([test_data, data[str(split) + '_test']])
    return test_data

In [5]:
def reptile_1_step(model, task, o_lr, i_lr, i_iters, criterion):
    # create a copy of weights to be used in outer loop
    weights_main = deepcopy(model.state_dict())
   
    # run inner loop
    for _ in torch.arange(i_iters):
        weights = sgd(model, task, i_lr, i_iters, criterion)
    
    # update model's state_dict()
    for key in weights_main.keys():
        weights_main[key] += o_lr * (weights[key] - weights_main[key])
        
    return weights_main

In [6]:
# Load data
# Set num_splits
num_splits = 3
banana, split_len = load_banana('banana_data.csv', num_splits)
test_data = gen_eval_data(banana, num_splits)
y_test, x_test = test_data[:, 0], test_data[:, 1:]

# Initialise model
model = NN_Classifier()
weights = model.state_dict()

# Set hyperparameters
o_iters = 1000
o_lr = 0.001
i_iters = 20
i_lr = 0.01
batch_size = 50

criterion = torch.nn.CrossEntropyLoss()

for _ in torch.arange(o_iters):
    # sample task
    task, scale = gen_task_data(banana, split_len, num_splits, batch_size)
    
    # run reptile_1_step
    weights = reptile_1_step(model, task, o_lr, scale * i_lr, i_iters, criterion)
    
    # re-assign weights to model
    model.load_state_dict(weights)
    
    if _ % 20 == 0:
        y_hat = model(x_test)
        test_loss = criterion(y_hat, y_test)
        print('Iteration ' + str(_.item()) + ': ', test_loss.item())
        
y_hat = model(x_test)
test_loss = criterion(y_hat, y_test)
print('Final Test Loss ' + str(_.item()) + ': ', test_loss.item())

Iteration tensor(0.):  tensor(0.6966, dtype=torch.float64)
Iteration tensor(20.):  tensor(0.6938, dtype=torch.float64)
Iteration tensor(40.):  tensor(0.6910, dtype=torch.float64)
Iteration tensor(60.):  tensor(0.6878, dtype=torch.float64)
Iteration tensor(80.):  tensor(0.6844, dtype=torch.float64)
Iteration tensor(100.):  tensor(0.6806, dtype=torch.float64)
Iteration tensor(120.):  tensor(0.6766, dtype=torch.float64)
Iteration tensor(140.):  tensor(0.6724, dtype=torch.float64)
Iteration tensor(160.):  tensor(0.6681, dtype=torch.float64)
Iteration tensor(180.):  tensor(0.6635, dtype=torch.float64)
Iteration tensor(200.):  tensor(0.6588, dtype=torch.float64)
Iteration tensor(220.):  tensor(0.6544, dtype=torch.float64)
Iteration tensor(240.):  tensor(0.6498, dtype=torch.float64)
Iteration tensor(260.):  tensor(0.6450, dtype=torch.float64)
Iteration tensor(280.):  tensor(0.6414, dtype=torch.float64)
Iteration tensor(300.):  tensor(0.6377, dtype=torch.float64)
Iteration tensor(320.):  tenso