In [4]:
import numpy as np
import random
import torch
import torch.nn.functional as F
from data_generator import DataGenerator

# Make newly created float tensors default to CUDA float32; everything below
# therefore assumes a CUDA-capable GPU is present.
# NOTE(review): `torch.set_default_tensor_type` is deprecated in recent PyTorch
# releases — consider `torch.set_default_device("cuda")` instead.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

# FIXME: rewrite the comments in our own words and rename variables
# consistently (e.g. dim_input -> input_dim).


# NOTE(review): `device` is not referenced anywhere else in this notebook.
device = torch.device("cuda:0")


def flatten(parameters):
    """Concatenate every tensor in `parameters` into one column vector.

    Each tensor is flattened in row-major order and the pieces are joined in
    iteration order, giving a tensor of shape ``(total_numel, 1)``.
    """
    pieces = []
    for tensor in parameters:
        pieces.append(tensor.reshape(-1))
    column = torch.cat(pieces)
    return column.view(-1, 1)

def unflatten(flat, size_ids):
    """Split a flat tensor back into tensors of the given shapes (inverse of `flatten`).

    Args:
        flat: tensor whose leading axis holds the concatenated, row-major
            elements of every target tensor (e.g. the output of `flatten`
            or a 1-D vector).
        size_ids: sequence of shapes (``torch.Size``/tuple/list), in the
            order the elements were concatenated.

    Returns:
        List of views of ``flat``, one per entry of ``size_ids``.
    """
    # int() matters: np.prod(()) is the float 1.0 for 0-d shapes, and float
    # slice bounds are rejected by tensor indexing.
    sizes = [int(np.prod(shape)) for shape in size_ids]
    bounds = np.cumsum([0] + sizes)
    # view(tuple(shape)) also handles 0-d shapes, where view(*()) would fail.
    return [flat[int(start):int(stop)].view(tuple(shape))
            for start, stop, shape in zip(bounds[:-1], bounds[1:], size_ids)]

def hessian_matrix(grads, params):
    """Build the dense Hessian of a scalar loss from its first-order gradients.

    Args:
        grads: gradients of the loss w.r.t. ``params``, computed with
            ``create_graph=True`` so they can be differentiated again.
        params: the tensors the gradients were taken with respect to, in the
            same order as ``grads``.

    Returns:
        An ``(m, m)`` tensor, ``m`` being the total number of gradient
        elements. Column ``k`` is the derivative of the ``k``-th gradient
        element (row-major, concatenated over ``grads``) w.r.t. every element
        of ``params``, concatenated in the same order.

    Gradient elements that carry no graph (constant gradients) and parameters
    a gradient element does not depend on contribute zeros instead of raising.
    """
    params = list(params)
    if len(grads) == 0:
        return torch.zeros(0, 0)
    m = sum(g.numel() for g in grads)
    hessian = torch.zeros(m, m, dtype=grads[0].dtype, device=grads[0].device)

    col = 0
    for g in grads:
        g_flat = g.reshape(-1)
        for idx in range(g_flat.numel()):
            elem = g_flat[idx]
            # A gradient element without a graph (e.g. from a term linear in its
            # parameter) has an all-zero second derivative: leave the column at 0.
            if elem.requires_grad:
                second = torch.autograd.grad(elem, params, retain_graph=True, allow_unused=True)
                # allow_unused yields None for params this element does not depend on.
                hessian[:, col] = torch.cat([
                    (torch.zeros_like(p) if d is None else d).reshape(-1)
                    for d, p in zip(second, params)
                ])
            col += 1
    return hessian

def hessian_vector_products(grad, hessians, lr):
    """Back-propagate a gradient through a sequence of inner SGD steps.

    Computes ``(I - lr*H_0)(I - lr*H_1)...(I - lr*H_{K-1}) g``, where ``g`` is
    ``grad`` flattened and ``H_k`` is ``hessians[k]``. The last Hessian is
    applied first.

    Args:
        grad: sequence of gradient tensors (e.g. w.r.t. the adapted weights).
        hessians: list of ``(m, m)`` Hessians, one per inner step, where ``m``
            is the total number of elements in ``grad``. May be empty.
        lr: inner-loop learning rate used in those steps.

    Returns:
        List of tensors with the same shapes as ``grad``.
    """
    shapes = [g.shape for g in grad]
    sizes = [g.numel() for g in grad]
    vec = torch.cat([g.reshape(-1) for g in grad])

    for hes in reversed(hessians):
        # (I - lr*H) v without materialising an identity / m x m temporary.
        vec = vec - lr * torch.mv(hes, vec)
    return [chunk.view(shape) for chunk, shape in zip(torch.split(vec, sizes), shapes)]

def empty_gradient(params):
    """Return a zero gradient accumulator for each parameter.

    ``zeros_like`` matches each parameter's shape, dtype and device, so
    accumulated task gradients can be added without device/dtype mismatches.
    The zeros do not require grad.
    """
    return [torch.zeros_like(par) for par in params]
        

class Net(torch.nn.Module):
    """Three-layer ReLU MLP that can also run with externally supplied weights.

    Args:
        input_dim: size of the last axis of the input.
        output_dim: size of the last axis of the output.
        hidden_dim: width of both hidden layers (defaults to 40, the width
            previously hard-coded here).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=40):
        super(Net, self).__init__()
        self.layer1 = torch.nn.Linear(input_dim, hidden_dim)
        self.layer2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, weight=None):
        """Apply the network to ``x``.

        Args:
            x: input tensor whose last axis has size ``input_dim``.
            weight: optional sequence ``[w1, b1, w2, b2, w3, b3]`` (the order
                of ``self.parameters()``) used instead of the registered layer
                parameters, e.g. MAML fast weights. ``None`` or an empty
                sequence falls back to the module's own layers.
        """
        if weight:
            x = F.relu(F.linear(x, weight[0], weight[1]))
            x = F.relu(F.linear(x, weight[2], weight[3]))
            x = F.linear(x, weight[4], weight[5])
        else:
            x = F.relu(self.layer1(x))
            x = F.relu(self.layer2(x))
            x = self.layer3(x)
        return x


class MAML():
    """Model-Agnostic Meta-Learning (MAML) trainer for a regression model.

    A task is a 2-D tensor whose rows are examples laid out as
    ``[inputs..., targets...]``; the last ``dim_output`` columns are targets.
    The wrapped ``model`` must accept ``model(x, weight=list_of_tensors)`` so
    the inner loop can evaluate it with adapted ("fast") weights.

    Args:
        model: ``torch.nn.Module`` whose parameters are meta-learned.
        dim_input: number of input columns per example.
        dim_output: number of target columns per example.
        inner_lr: step size of the per-task (inner) gradient updates.
        outer_lr: learning rate of the Adam meta-optimizer.
        shots: rows sampled (with replacement) for each inner / query batch.
        epochs: default number of meta-iterations run by :meth:`outer`.
        task_updates: number of inner gradient steps per task.
        verbose: print the averaged loss at every logging interval.
    """

    def __init__(self, model, dim_input=1, dim_output=1, inner_lr=1e-2, outer_lr=1e-3, shots = 10, epochs = 100, task_updates=2, verbose=True):
        self.input_dim = dim_input
        self.output_dim = dim_output
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.shots = shots
        self.epochs = epochs
        self.task_updates = task_updates
        self.meta_loss_history = []  # averaged loss (float) per logging interval
        self.verbose = verbose

        self.model = model
        self.measure = torch.nn.MSELoss()
        self.weights = list(model.parameters())
        self.meta_optimizer = torch.optim.Adam(self.weights, self.outer_lr)

    def _sample_batch(self, task):
        """Draw ``self.shots`` rows of ``task`` uniformly, with replacement."""
        return task[np.random.randint(task.shape[0], size=self.shots)]

    def _split(self, batch):
        """Split rows into (inputs, targets); targets keep a trailing output axis.

        Keeping targets 2-D (``(shots, dim_output)``) matches the model output
        shape, so MSELoss compares element-wise instead of broadcasting
        ``(shots, 1)`` against ``(shots,)`` into a ``(shots, shots)`` matrix.
        """
        inputs = batch[:, :-self.output_dim].float()
        targets = batch[:, -self.output_dim:].float()
        return inputs, targets

    def inner(self, task, approx=False):
        """Adapt to one task and return its contribution to the meta-gradient.

        Args:
            task: 2-D tensor of examples for a single task.
            approx: ``True`` for first-order MAML (gradient at the adapted
                weights); ``False`` to back-propagate through the inner
                updates using explicit Hessians.

        Returns:
            ``(task_grad, loss)``: ``task_grad`` is aligned with
            ``self.weights``; ``loss`` is the detached query-batch loss.
        """
        hessians = []
        temp_weights = [w.clone() for w in self.weights]
        inputs, targets = self._split(self._sample_batch(task))

        for _ in range(self.task_updates):
            loss = self.measure(self.model(inputs, weight=temp_weights), targets)
            # Only second-order MAML needs a differentiable gradient (to build
            # Hessians); first-order MAML skips the extra graph.
            grad = torch.autograd.grad(loss, temp_weights, retain_graph=True, create_graph=not approx)
            if not approx:
                hessians.append(hessian_matrix(grad, temp_weights))
            temp_weights = [w - self.inner_lr * g for w, g in zip(temp_weights, grad)]

        # Evaluate the adapted weights on a freshly sampled query batch.
        # NOTE(review): sampled from the same rows as the support batch, so
        # support and query examples may overlap.
        inputs, targets = self._split(self._sample_batch(task))
        loss = self.measure(self.model(inputs, weight=temp_weights), targets)
        hold_grad = torch.autograd.grad(loss, temp_weights)
        if approx:
            task_grad = hold_grad
        else:
            task_grad = hessian_vector_products(hold_grad, hessians, self.inner_lr)
        # Detach so callers accumulating losses do not keep autograd graphs alive.
        return task_grad, loss.detach()

    def outer(self, data_generator, epochs=None, approx=True, log_every=50):
        """Run MAML meta-training.

        Args:
            data_generator: object whose ``generate()`` returns
                ``(sup_x, sup_y, _, _)`` arrays concatenable along axis 2 and
                which exposes ``batch_size`` (tasks per meta-batch).
            epochs: number of meta-iterations; a falsy value keeps ``self.epochs``.
            approx: ``True`` (default) uses first-order MAML; ``False`` uses
                the second-order (Hessian-based) meta-gradient.
            log_every: interval, in meta-iterations, at which the averaged loss
                is appended to ``meta_loss_history`` (and printed if verbose).
        """
        if epochs:
            self.epochs = epochs
        running_loss = 0.0

        for epoch in range(1, self.epochs + 1):
            sup_x, sup_y, _, _ = data_generator.generate()
            tasks = torch.tensor(np.concatenate((sup_x, sup_y), axis=2).astype(np.float32))
            meta_grad = empty_gradient(self.weights)
            meta_loss = 0.0

            # Inner loop over tasks (candidate for parallelisation).
            for t in range(data_generator.batch_size):
                task_grad, task_loss = self.inner(tasks[t], approx=approx)
                meta_grad = [acc + g for acc, g in zip(meta_grad, task_grad)]
                meta_loss += task_loss

            # NOTE(review): the meta-gradient is summed, not averaged, over tasks.
            # Assigning .grad directly replaces any previous value, so no
            # zero_grad() is needed.
            for weight, grad in zip(self.weights, meta_grad):
                weight.grad = grad
            self.meta_optimizer.step()

            running_loss += float(meta_loss) / data_generator.batch_size

            if epoch % log_every == 0:
                avg_loss = running_loss / log_every
                self.meta_loss_history.append(avg_loss)
                if self.verbose:
                    print("{}/{}. loss: {}".format(epoch, self.epochs, avg_loss))
                running_loss = 0.0
        
         
def main(seed=None):
    """Meta-train a small MLP with first-order MAML on sinusoid regression.

    Args:
        seed: optional integer; when given, seeds Python's ``random``, NumPy
            and torch so a run can be reproduced. ``None`` leaves RNGs unseeded.
    """
    datasource = 'sinusoid'
    update_batch_size = 10   # examples per task used for the inner update (K in K-shot)
    meta_batch_size = 500    # tasks sampled per meta-update
    meta_epochs = 50

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    data_generator = DataGenerator(datasource, meta_batch_size, update_batch_size)

    dim_input = data_generator.dim_input
    dim_output = data_generator.dim_output

    net = Net(dim_input, dim_output)
    # Pass the data dimensions and K through so MAML splits task rows into
    # inputs/targets correctly and samples K examples per inner batch.
    model = MAML(net, dim_input=dim_input, dim_output=dim_output,
                 shots=update_batch_size, verbose=True)
    model.outer(data_generator, meta_epochs)

    


In [None]:
# Start meta-training; the guard keeps an `import` of this code from training.
if __name__ == "__main__":
    main()

In [231]:
# Release cached-but-unused CUDA memory held by PyTorch's allocator (e.g. after a run).
torch.cuda.empty_cache() 

In [None]:




def construct_fc_weights(self):
    """Create TensorFlow variables for a fully-connected network.

    Layer sizes run ``dim_input -> *dim_hidden -> dim_output``. Weights are
    named ``w1, w2, ...`` and drawn from a truncated normal (stddev 0.01);
    biases ``b1, b2, ...`` start at zero. Returns the name -> variable dict.

    NOTE(review): uses `tf`, which this notebook never imports; this looks like
    reference code kept from a TensorFlow MAML implementation.
    """
    hidden = self.dim_hidden
    n_hidden = len(hidden)
    params = {}
    params['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, hidden[0]], stddev=0.01))
    params['b1'] = tf.Variable(tf.zeros([hidden[0]]))
    for layer in range(2, n_hidden + 1):
        fan_in, fan_out = hidden[layer - 2], hidden[layer - 1]
        params['w%d' % layer] = tf.Variable(tf.truncated_normal([fan_in, fan_out], stddev=0.01))
        params['b%d' % layer] = tf.Variable(tf.zeros([fan_out]))
    last = n_hidden + 1
    params['w%d' % last] = tf.Variable(tf.truncated_normal([hidden[-1], self.dim_output], stddev=0.01))
    params['b%d' % last] = tf.Variable(tf.zeros([self.dim_output]))
    return params



def task_metalearn(inp, reuse=True):
    """Perform gradient descent for one task in the meta-batch.

    Takes ``num_updates`` gradient steps on ``(inputa, labela)`` and evaluates
    the fast weights after every step on ``(inputb, labelb)``.

    Args:
        inp: tuple ``(inputa, inputb, labela, labelb)``.
        reuse: passed to ``self.forward`` for the first (pre-update) call only.

    Returns:
        ``[task_outputa, task_outputbs, task_lossa, task_lossesb]``.

    NOTE(review): relies on free names ``self``, ``weights``, ``FLAGS``,
    ``num_updates`` and ``tf`` that are not defined in this notebook — this is
    reference code lifted from an enclosing TensorFlow method and is not
    runnable on its own.
    """
    inputa, inputb, labela, labelb = inp
    task_outputbs, task_lossesb = [], []

    task_outputa = self.forward(inputa, weights, reuse=reuse)  # only reuse on the first iter
    task_lossa = self.loss_func(task_outputa, labela)
    grads = tf.gradients(task_lossa, list(weights.values()))

    # Optionally treat the inner gradients as constants (first-order MAML).
    if FLAGS.stop_grad:
        grads = [tf.stop_gradient(grad) for grad in grads]

    gradients = dict(zip(weights.keys(), grads))
    fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr*gradients[key] for key in weights.keys()]))
    output = self.forward(inputb, fast_weights, reuse=True)
    task_outputbs.append(output)
    task_lossesb.append(self.loss_func(output, labelb))

    for j in range(num_updates - 1):
        loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela)
        grads = tf.gradients(loss, list(fast_weights.values()))
        if FLAGS.stop_grad:
            grads = [tf.stop_gradient(grad) for grad in grads]
        gradients = dict(zip(fast_weights.keys(), grads))
        fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.update_lr*gradients[key] for key in fast_weights.keys()]))
        output = self.forward(inputb, fast_weights, reuse=True)
        task_outputbs.append(output)
        task_lossesb.append(self.loss_func(output, labelb))

    # Previously over-indented (IndentationError) and never returned.
    task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb]
    return task_output

In [154]:
# Sanity check: the code above sets the default tensor type to CUDA, so a GPU is required.
torch.cuda.is_available()

True