In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable, grad
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np 

In [3]:
train_dataset = datasets.MNIST('mnist', download=False, train=True, transform=ToTensor())
test_dataset = datasets.MNIST('mnist', download=False, train=False, transform=ToTensor())

In [4]:
train_dataset.train_data.size()

torch.Size([60000, 28, 28])

In [5]:
train_dataset.train_labels.size()

torch.Size([60000])

In [6]:
class SimpleMLP(nn.Module):
    
    def __init__(self, input_size, output_size, layer_sizes=[128,128]):
        super(SimpleMLP, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.layer1_size = layer_sizes[0]
        self.layer2_size = layer_sizes[1]
        self.output_size = output_size
        
        self.layer1 = nn.Linear(input_size, layer_sizes[0])
        self.layer2 = nn.Linear(layer_sizes[0], layer_sizes[1])
        self.output_layer = nn.Linear(layer_sizes[1], output_size)
        
    def forward(self, input_batch, weights=None):        
        x = F.relu(F.linear(input_batch,weights['layer1.weight'],weights['layer1.bias']))
        x = F.relu(F.linear(x,weights['layer2.weight'],weights['layer2.bias']))
        output = F.linear(x,weights['output_layer.weight'],weights['output_layer.bias'])
        return F.softmax(output)

In [7]:
mlp = SimpleMLP(784, 10)

In [8]:
base_weights = OrderedDict((name, param) for (name, param) in mlp.named_parameters())

In [9]:
# mlp(Variable(torch.randn(64,784)), base_weights)
# testing

In [10]:
# train on 0, 2, 4, 6, 8

In [11]:
# meta train on 1, 3, 5, 7, 9

In [12]:
train_dataset_np = train_dataset.train_data.numpy()

In [13]:
train_dataset_np = train_dataset_np.reshape((train_dataset_np.shape[0],-1))

In [14]:
train_labels  = np.array(train_dataset.train_labels.tolist())

In [21]:
def gen_minibatches(data, labels, task, train=True):
    task_indices = np.where(labels == task)[0]
    task_data = data[task_indices]
    n = len(task_data)
    if train:
        index_ = int(0.8 * n)
        train_data = task_data[0:index_]
        train_labels = labels[0:index_]
        rand_indices = np.random.randint(0,index_-1,64)
        train_data = train_data[rand_indices]
        train_labels = [task] * len(train_data)
        return Variable(torch.FloatTensor(train_data.tolist())), Variable(torch.LongTensor(train_labels), requires_grad = False)
    else:
        index_ = int(0.8 * n)
        rand_indices = np.random.randint(index_,len(task_data),64)
        test_data = task_data[rand_indices]
        test_labels = [task] * len(test_data)
        return Variable(torch.FloatTensor(test_data.tolist())), Variable(torch.LongTensor(test_labels), requires_grad = False)

In [16]:
# gen_minibatches(train_dataset_np, train_labels, 1)

In [17]:
TRAIN_TASKS = [0,2,4,6,8]
ALPHA = 0.01
BETA = 0.01

In [18]:
loss_function = nn.CrossEntropyLoss(size_average=True)

In [26]:
base_weights = OrderedDict((name, param) for (name, param) in mlp.named_parameters())
task_weights_list = []
val_loss_list = []
for task in TRAIN_TASKS:
    train_img_batch, train_labels_batch = gen_minibatches(train_dataset_np, train_labels, task)
    output = mlp(train_img_batch, base_weights)
    loss = loss_function(output, train_labels_batch)
#     print loss
    grad_params = grad(loss, mlp.parameters(), create_graph=True)    
    task_weights= OrderedDict((name, param - ALPHA*grad) for ((name, param), grad) in zip(base_weights.items(), grad_params))
    task_weights_list.append(task_weights)
    # forward pass on validation
    val_img_batch, val_labels_batch = gen_minibatches(train_dataset_np, train_labels, task, train=False)
    output = mlp(val_img_batch, task_weights)
    loss = loss_function(output, val_labels_batch)
    val_loss_list.append(loss)
    print loss
# meta_update
meta_grads = OrderedDict({k: sum(d[k] for d in task_weights_list) for k in task_weights_list[0].keys()})
meta_loss = sum(val_loss_list)
meta_grad_params = grad(meta_loss, mlp.parameters(), create_graph=True)    

meta_weights = OrderedDict((name, param - BETA*grad) for ((name, param), grad) in zip(base_weights.items(), meta_grad_params))

    

Variable containing:
 2.4386
[torch.FloatTensor of size 1]

Variable containing:
 1.4612
[torch.FloatTensor of size 1]

Variable containing:
 1.6472
[torch.FloatTensor of size 1]

Variable containing:
 2.4477
[torch.FloatTensor of size 1]

Variable containing:
 2.4481
[torch.FloatTensor of size 1]



In [28]:
# outer loss = all_task_loss
# outer gradient = summation_on_inner_gradients
# outer update

In [27]:
meta_weights

OrderedDict([('layer1.weight', Variable containing:
              -9.2867e-04  2.3334e-02  1.0286e-02  ...  -2.1942e-02 -4.7197e-04 -3.7485e-04
               2.0157e-02 -1.3466e-02 -3.0656e-02  ...  -2.3907e-02 -7.0711e-03  3.0730e-02
               5.2343e-03  6.5051e-03 -2.0429e-02  ...   6.8701e-03 -1.8204e-02  1.6440e-02
                              ...                   ⋱                   ...                
              -8.8931e-03 -1.0550e-02  2.3829e-04  ...  -3.2756e-02  9.2230e-03  1.7846e-02
               3.5419e-02 -2.5387e-02  7.2238e-03  ...   6.6081e-04 -1.0161e-03 -1.1414e-03
              -2.3213e-02  1.4009e-02  1.8974e-02  ...  -1.1989e-02 -8.0940e-03 -3.1297e-02
              [torch.FloatTensor of size 128x784]),
             ('layer1.bias', Variable containing:
              1.00000e-02 *
               -0.1601
               -2.5524
                0.9776
               -0.5021
               -2.9744
                1.1096
               -1.4821
             