In [25]:
import time

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
import matplotlib.pyplot as plt

from src.tasks import Sine_Task, Sine_Task_Distribution

In [26]:
class MAMLModel(nn.Module):
    """Small fully connected regressor: 1 -> 40 -> 40 -> 1 with ReLU in between.

    Besides the usual ``forward``, the network can be evaluated with an
    externally supplied list of weights via ``parameterised``.
    """

    def __init__(self):
        super().__init__()
        # The layer names end up in the state_dict keys, so they are kept stable.
        layers = OrderedDict()
        layers['l1'] = nn.Linear(1, 40)
        layers['relu1'] = nn.ReLU()
        layers['l2'] = nn.Linear(40, 40)
        layers['relu2'] = nn.ReLU()
        layers['l3'] = nn.Linear(40, 1)
        self.model = nn.Sequential(layers)

    def forward(self, x):
        """Run ``x`` through the network using the module's own parameters."""
        return self.model(x)

    def parameterised(self, x, weights):
        """Run ``x`` through the same architecture using ``weights``.

        ``weights`` is indexed as ``[W1, b1, W2, b2, W3, b3]``, i.e. the order
        in which ``self.parameters()`` yields them.
        """
        hidden = x
        # Two hidden layers: linear followed by ReLU.
        for first in (0, 2):
            hidden = nn.functional.relu(
                nn.functional.linear(hidden, weights[first], weights[first + 1])
            )
        # Output layer has no activation.
        return nn.functional.linear(hidden, weights[4], weights[5])
                        

In [31]:
class VMAML():
    """MAML-style meta-learner whose adapted task weights are kept near the meta-weights.

    For every meta-batch, each task starts from a copy of the meta-weights and
    is adapted by gradient steps on its training split. After each adaptation
    pass the joint Euclidean distance between *all* task weights and the
    meta-weights is measured; if it exceeds ``radius`` every task's weights are
    pulled back along the straight line towards the meta-weights so that the
    joint distance equals ``radius``. The meta-weights are then updated with
    Adam on the summed evaluation loss.

    Args:
        model: ``nn.Module`` exposing ``parameterised(x, weights)``.
        tasks: object exposing ``sample_task(n)``; the result must expose
            ``sample_data(size, device=...)`` returning ``(X, y)`` indexable by
            task number.
        inner_lr: step size of the per-task adaptation.
        meta_lr: learning rate of the Adam meta-optimiser.
        K: number of training samples per task (``2*K`` samples are drawn, the
            second half is used for the evaluation loss).
        inner_steps: gradient steps per call to ``inner_loop``.
        tasks_per_meta_batch: number of tasks per meta-update.
        radius: radius of the ball the adapted weights are projected onto.
        num_ve_iterations: number of adapt-then-project passes per meta-update.
        device: device the model and the sampled data live on.
    """

    def __init__(self, model, tasks, inner_lr, meta_lr, K=10, inner_steps=1, tasks_per_meta_batch=1000, radius=0.05, num_ve_iterations=5,
                       device="cuda"):

        # GPU: move the model first so the optimiser below is constructed over
        # parameters that already live on the target device.
        self.device = device
        self.model = model
        self.model.to(device)

        # important objects
        self.tasks = tasks
        self.weights = list(model.parameters()) # the maml weights we will be meta-optimising
        self.criterion = nn.MSELoss()
        self.meta_optimiser = torch.optim.Adam(self.weights, meta_lr)

        # hyperparameters
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.K = K
        self.inner_steps = inner_steps # with the current design of MAML, >1 is unlikely to work well
        self.tasks_per_meta_batch = tasks_per_meta_batch
        self.radius = radius
        self.num_ve_iterations = num_ve_iterations

        # metrics
        self.plot_every = 10
        self.print_every = 10
        self.save_every = 100
        self.meta_losses = []

    def inner_loop(self, X, y, temp_weights, compute_loss=False):
        """Adapt ``temp_weights`` on one task and optionally evaluate them.

        Args:
            X, y: samples of a single task; the first ``K`` entries are the
                training split, everything after that the evaluation split.
            temp_weights: list of tensors to adapt (must be part of an
                autograd graph so gradients can be taken with respect to them).
            compute_loss: when true, also return the loss of the adapted
                weights on the evaluation split.

        Returns:
            ``(adapted_weights, loss)`` where ``loss`` is ``0.`` unless
            ``compute_loss`` is true.
        """
        train_X, test_X = X[:self.K], X[self.K:]
        train_y, test_y = y[:self.K], y[self.K:]

        for _ in range(self.inner_steps):
            # NOTE(review): nn.MSELoss() already averages over elements, so the
            # extra / K rescales the loss (and hence the effective inner step).
            # Kept as-is to preserve training behaviour - confirm it is intended.
            loss = self.criterion(self.model.parameterised(train_X, temp_weights), train_y) / self.K

            # The gradient is taken without create_graph, so it enters the
            # update below as a constant.
            grad = torch.autograd.grad(loss, temp_weights)
            temp_weights = [w - self.inner_lr * g for w, g in zip(temp_weights, grad)]

        # evaluate the adapted weights on the held-out half of the samples
        if compute_loss:
            loss = self.criterion(self.model.parameterised(test_X, temp_weights), test_y) / self.K
        else:
            loss = 0.

        return temp_weights, loss

    def main_loop(self, num_iterations):
        """Run ``num_iterations`` meta-updates.

        Appends the running average loss to ``self.meta_losses`` every
        ``plot_every`` iterations, prints it every ``print_every`` iterations
        and writes the model's ``state_dict`` to ``./vmaml.pth`` every
        ``save_every`` iterations.
        """
        epoch_loss = 0
        r = self.radius

        for iteration in tqdm(range(1, num_iterations+1)):

            # sample a batch of tasks and 2*K points for each of them
            tasks = self.tasks.sample_task(self.tasks_per_meta_batch)
            X, y = tasks.sample_data(self.K*2, device=self.device)

            # every task starts from its own copy of the meta-weights
            task_weights_list = [[w.clone() for w in self.weights] for _ in range(self.tasks_per_meta_batch)]

            for ve_iteration in range(self.num_ve_iterations):
                # Accumulator lives on the same device as the weights; a
                # CPU-only scalar cannot be updated in place with CUDA values.
                dist_square = torch.zeros((), device=self.device)
                for i in range(self.tasks_per_meta_batch):
                    task_weights, _ = self.inner_loop(X[i], y[i], task_weights_list[i])
                    task_weights_list[i] = task_weights
                    dist_square = dist_square + sum(
                        torch.sum(torch.square(meta_w - task_w))
                        for task_w, meta_w in zip(task_weights, self.weights)
                    )

                d = torch.sqrt(dist_square)

                # Joint projection: move every task's weights towards the
                # meta-weights so the overall distance shrinks from d to r.
                if d > r:
                    task_weights_list = [
                        [(r*task_w + (d-r)*meta_w)/d for task_w, meta_w in zip(task_weights, self.weights)]
                        for task_weights in task_weights_list
                    ]

            # compute meta loss
            # NOTE(review): inner_loop performs another (unprojected) adaptation
            # step before evaluating - confirm this is intended.
            meta_loss = 0
            for i in range(self.tasks_per_meta_batch):
                _, loss = self.inner_loop(X[i], y[i], task_weights_list[i], True)
                meta_loss = meta_loss + loss

            # compute meta gradient of loss with respect to maml weights
            meta_grads = torch.autograd.grad(meta_loss, self.weights)

            # assign meta gradient to weights and take optimisation step
            for w, g in zip(self.weights, meta_grads):
                w.grad = g
            self.meta_optimiser.step()

            # log metrics
            epoch_loss += meta_loss.item() / self.tasks_per_meta_batch

            # NOTE(review): the printed average divides by plot_every, which is
            # only the right window while print_every == plot_every.
            if iteration % self.print_every == 0:
                print("{}/{}. loss: {}".format(iteration, num_iterations, epoch_loss / self.plot_every))

            if iteration % self.plot_every == 0:
                self.meta_losses.append(epoch_loss / self.plot_every)
                epoch_loss = 0

            if iteration % self.save_every == 0:
                torch.save(self.model.state_dict(), './vmaml.pth')

In [32]:
# Training runs on the CPU by default. Set USE_CUDA to True to opt in to the
# first CUDA device; the availability check keeps the notebook runnable on
# CPU-only machines. (Previously the detected device was computed and then
# unconditionally overwritten with "cpu".)
USE_CUDA = False
device = "cuda:0" if USE_CUDA and torch.cuda.is_available() else "cpu"

In [33]:
# Seed before building the model: nn.Linear draws its initial weights from
# torch's global RNG, so seeding only right before training leaves the starting
# point of meta-training different on every fresh-kernel run.
torch.manual_seed(222)

# NOTE(review): the six positional arguments presumably bound the amplitude,
# phase and x ranges of the sine tasks - confirm against src.tasks.
tasks = Sine_Task_Distribution(0.1, 5, 0, np.pi, -5, 5)
maml = VMAML(MAMLModel(), tasks, inner_lr=0.01, meta_lr=0.001, device=device)

In [34]:
# Fix every RNG in use (torch CPU, torch CUDA, numpy) to one seed, then start
# meta-training.
SEED = 222
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
maml.main_loop(num_iterations=10000)

  0%|          | 0/10000 [00:00<?, ?it/s]

tensor([[-1.5172e-03],
        [ 8.4290e-03],
        [ 1.1041e-02],
        [ 8.8467e-04],
        [-4.0512e-05],
        [ 0.0000e+00],
        [-1.1802e-02],
        [ 2.0560e-03],
        [-5.4742e-03],
        [-3.0210e-04],
        [-1.4834e-02],
        [-2.6851e-04],
        [ 3.5012e-04],
        [ 3.6112e-03],
        [ 1.6357e-02],
        [-8.1136e-03],
        [ 1.1172e-02],
        [-9.8972e-03],
        [-9.0433e-04],
        [ 1.2853e-04],
        [-1.1148e-03],
        [ 7.9425e-04],
        [-2.8441e-03],
        [ 8.3994e-03],
        [ 6.4778e-04],
        [ 8.4160e-03],
        [ 9.4961e-04],
        [-2.1924e-02],
        [ 7.1422e-03],
        [ 4.5636e-03],
        [-1.0154e-03],
        [-6.3502e-04],
        [ 3.9301e-03],
        [-1.9618e-02],
        [-1.4322e-02],
        [ 7.9646e-03],
        [ 2.4447e-03],
        [ 1.0705e-03],
        [ 5.7342e-04],
        [ 2.2088e-03]])
tensor([[-0.0014],
        [ 0.0131],
        [-0.0295],
        [ 0.0021],
   

  0%|          | 0/10000 [00:06<?, ?it/s]

tensor([[ 0.0026],
        [-0.0022],
        [-0.0238],
        [-0.0104],
        [ 0.0088],
        [ 0.0000],
        [ 0.0104],
        [-0.0016],
        [-0.0093],
        [ 0.0086],
        [ 0.0139],
        [ 0.0006],
        [-0.0034],
        [-0.0067],
        [-0.0142],
        [ 0.0123],
        [-0.0129],
        [ 0.0069],
        [ 0.0179],
        [ 0.0007],
        [-0.0037],
        [ 0.0045],
        [-0.0023],
        [-0.0084],
        [-0.0051],
        [-0.0105],
        [-0.0190],
        [ 0.0210],
        [-0.0058],
        [-0.0322],
        [-0.0011],
        [ 0.0022],
        [-0.0074],
        [ 0.0196],
        [ 0.0223],
        [-0.0074],
        [ 0.0017],
        [-0.0037],
        [ 0.0032],
        [-0.0242]])
tensor([[ 7.5810e-04],
        [ 7.4478e-03],
        [-7.2190e-03],
        [ 7.4055e-04],
        [-1.0781e-03],
        [ 0.0000e+00],
        [-1.0028e-03],
        [ 2.4983e-03],
        [ 4.2648e-03],
        [-1.9817e-03],
        [




KeyboardInterrupt: 