In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable, grad
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np 

In [2]:
# Load the MNIST train and test splits from the local 'mnist' directory.
# download=False, so the files must already be present there.
train_dataset, test_dataset = (
    datasets.MNIST('mnist', download=False, train=split_is_train, transform=ToTensor())
    for split_is_train in (True, False)
)

In [3]:
# Shape of the stored training images (the captured output is (60000, 28, 28)).
# NOTE(review): this reads `.train_data` directly, so presumably the ToTensor
# transform passed to the dataset is not applied here — confirm.
train_dataset.train_data.size()

torch.Size([60000, 28, 28])

In [4]:
# Shape of the training label tensor: one label per training image
# (the captured output is 60000, matching the image count above).
train_dataset.train_labels.size()

torch.Size([60000])

In [5]:
class SimpleMLP(nn.Module):
    """Two-hidden-layer ReLU MLP whose forward pass uses explicit weights.

    The ``nn.Linear`` submodules create and register the parameters, but
    ``forward`` evaluates the network with ``F.linear`` on a caller-supplied
    mapping of tensors keyed like ``named_parameters()``. This lets the same
    architecture run on modified copies of the weights (for example,
    gradient-updated tensors that are not registered parameters).
    """

    def __init__(self, input_size, output_size, layer_sizes=(128, 128)):
        """
        Args:
            input_size: number of input features per example.
            output_size: number of output units (classes).
            layer_sizes: widths of the two hidden layers; only the first two
                entries are used. A tuple default avoids a shared mutable list.
        """
        super(SimpleMLP, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.layer1_size = layer_sizes[0]
        self.layer2_size = layer_sizes[1]

        self.layer1 = nn.Linear(input_size, self.layer1_size)
        self.layer2 = nn.Linear(self.layer1_size, self.layer2_size)
        self.output_layer = nn.Linear(self.layer2_size, output_size)

    def forward(self, input_batch, weights=None):
        """Evaluate the MLP on ``input_batch`` with the given weights.

        Args:
            input_batch: input tensor whose last dimension is ``input_size``
                (a 2-D ``(batch, input_size)`` batch in this notebook).
            weights: mapping with keys ``'layer1.weight'``, ``'layer1.bias'``,
                ``'layer2.weight'``, ``'layer2.bias'``,
                ``'output_layer.weight'``, ``'output_layer.bias'``.
                If None, the module's own registered parameters are used.

        Returns:
            Softmax probabilities over the ``output_size`` outputs (for 1-D
            or 2-D input).
        """
        if weights is None:
            # Without this fallback, the default argument raised a TypeError on `None[...]`.
            weights = dict(self.named_parameters())
        hidden = F.relu(F.linear(input_batch, weights['layer1.weight'], weights['layer1.bias']))
        hidden = F.relu(F.linear(hidden, weights['layer2.weight'], weights['layer2.bias']))
        logits = F.linear(hidden, weights['output_layer.weight'], weights['output_layer.bias'])
        # NOTE(review): nn.CrossEntropyLoss expects unnormalised logits and
        # applies log-softmax itself, so passing these probabilities to it
        # normalises twice. The implicit-dim softmax is kept for compatibility
        # with the old PyTorch this notebook targets; newer versions want dim=-1.
        return F.softmax(logits)

In [6]:
# Flattened 28x28 images in, 10 digit-class outputs.
mlp = SimpleMLP(input_size=28 * 28, output_size=10)

In [7]:
# Name -> parameter mapping in registration order, consumed by SimpleMLP.forward.
base_weights = OrderedDict(mlp.named_parameters())

In [8]:
# mlp(Variable(torch.randn(64,784)), base_weights)
# testing

In [9]:
# train on 0, 2, 4, 6, 8

In [10]:
# meta train on 1, 3, 5, 7, 9

In [11]:
# Raw training images as a NumPy array.
# NOTE(review): this bypasses the dataset's ToTensor transform, so pixel values
# are presumably left unscaled (not in [0, 1]) — confirm dtype/range.
train_dataset_np = train_dataset.train_data.numpy()

In [12]:
# Flatten each image into a single row vector: (n_images, height * width).
train_dataset_np = train_dataset_np.reshape(len(train_dataset_np), -1)

In [13]:
# Training labels as a plain NumPy integer array (built via a Python list).
train_labels = np.asarray(train_dataset.train_labels.tolist())

In [14]:
def gen_minibatches(data, labels, task, train=True, batch_size=64, train_fraction=0.8):
    """Sample a random minibatch of examples that all belong to one class.

    The rows of ``data`` whose label equals ``task`` are split in order: the
    first ``int(train_fraction * n)`` rows form the training pool and the rest
    form the held-out pool. ``batch_size`` rows are drawn uniformly *with
    replacement* from the requested pool.

    Args:
        data: array indexed by example along axis 0.
        labels: 1-D array of class labels aligned with ``data``.
        task: the class label whose examples are sampled.
        train: sample from the training pool if True, otherwise from the
            held-out pool.
        batch_size: number of rows to draw (default 64).
        train_fraction: fraction of the task's rows assigned to the training
            pool (default 0.8).

    Returns:
        If ``train``: ``(Variable(FloatTensor), Variable(LongTensor))`` holding
        the batch and a label vector filled with ``task``.
        Otherwise: ``(numpy array, list)`` holding the raw rows and
        ``[task] * batch_size``.

    Raises:
        ValueError: if the requested pool contains no rows.
    """
    task_data = data[np.where(labels == task)[0]]
    split = int(train_fraction * len(task_data))
    if train:
        low, high = 0, split
    else:
        low, high = split, len(task_data)
    if high <= low:
        raise ValueError("no examples for task %r in the %s pool"
                         % (task, "train" if train else "held-out"))

    # randint's `high` is exclusive, so [low, high) covers the whole pool,
    # including its last row.
    rand_indices = np.random.randint(low, high, batch_size)
    batch = task_data[rand_indices]
    batch_labels = [task] * len(batch)

    if not train:
        return batch, batch_labels
    # NOTE(review): values are passed through unscaled (whatever range `data`
    # holds); normalise upstream if the model expects inputs in [0, 1].
    batch_tensor = torch.from_numpy(batch.astype(np.float32))
    return Variable(batch_tensor), Variable(torch.LongTensor(batch_labels), requires_grad=False)

In [15]:
# gen_minibatches(train_dataset_np, train_labels, 1)

In [16]:
# Digit classes used as inner-loop training tasks: the even digits 0-8.
TRAIN_TASKS = list(range(0, 10, 2))
# Step size of the per-task (inner-loop) gradient update.
ALPHA = 0.01
# NOTE(review): not used in the visible cells; presumably the outer/meta step size.
BETA = 0.01

In [17]:
# Batch-averaged cross-entropy. Averaging is the default, which is what
# size_average=True requested. It applies log-softmax to its input internally.
loss_function = nn.CrossEntropyLoss()

In [18]:
# Inner loop: for each training task, take one gradient step (size ALPHA) away
# from the shared base weights and keep the resulting per-task weights.
# create_graph=True keeps the graph of that step, so a later loss on the
# adapted weights can be differentiated back to base_weights.
base_weights = OrderedDict(mlp.named_parameters())
task_weights = []
for task in TRAIN_TASKS:
    train_img_batch, train_labels_batch = gen_minibatches(train_dataset_np, train_labels, task)
    output = mlp(train_img_batch, base_weights)
    # NOTE(review): `output` is softmax probabilities, but CrossEntropyLoss
    # applies log-softmax itself; feeding it logits would avoid double normalisation.
    loss = loss_function(output, train_labels_batch)
    print(loss)
    # Differentiate w.r.t. exactly the tensors in base_weights (same order),
    # so the zip below pairs every gradient with its own parameter name.
    grad_params = grad(loss, list(base_weights.values()), create_graph=True)
    task_weights.append(OrderedDict(
        (name, param - ALPHA * param_grad)
        for (name, param), param_grad in zip(base_weights.items(), grad_params)))

Variable containing:
 2.4457
[torch.FloatTensor of size 1]

Variable containing:
 2.4242
[torch.FloatTensor of size 1]

Variable containing:
 2.4395
[torch.FloatTensor of size 1]

Variable containing:
 2.1962
[torch.FloatTensor of size 1]

Variable containing:
 2.2197
[torch.FloatTensor of size 1]



In [28]:
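# Outer (meta) step — TODO:
# outer loss = sum over tasks of the loss computed with each task's adapted weights (task_weights)
# outer gradient = gradient of that summed loss w.r.t. base_weights (the sum of the per-task meta-gradients, taken through the inner step)
# outer update = step base_weights along the outer gradient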

In [46]:
# Element-wise sum of the adapted weights across all tasks, keyed by parameter name.
# Built from ordered (key, value) pairs so the OrderedDict keeps task_weights[0]'s
# key order; a dict comprehension has no guaranteed order before Python 3.7.
dummy = OrderedDict(
    (name, sum(weights[name] for weights in task_weights))
    for name in task_weights[0])