In [204]:
import numpy as np
import matplotlib.pyplot as plt

In [205]:
import torch
import torch.nn as nn
import torch.nn.functional as F
#print(torch.cuda.is_available())
#print(torch.cuda.get_device_name())

In [206]:
#Initialise the Hyperparams

In [207]:
#dimension of beta_i - there are as many beta-i's as there are functions but this is the dimension of each beta_i
beta_dim = 100

In [208]:
#our inputs are scalars of dimension 1
input_dim = 1

In [209]:
#check what this correspons to - number of phi_rbfs'
num_phi_rbf = 100

In [210]:
#this is the variance of the radial basis function kernels
phi_rbf_sigma = 5

In [211]:
#number of neuron on each hidden layer of the function Phi
phi_hidden_layer_size = 10

In [212]:
#what does z_dim represent?
z_dim = 16

In [213]:
# Gives the numbers of betas to learn
num_training_funcs = 1000

In [214]:
# Number of points each function is evaluated at
#this is the K in the paper
num_eval_points = 20 

In [215]:
# The observation standard deviation
obs_sigma = 0.01 

In [216]:
#torch.nn.linear applies a linear transformation to the incoming data/inputs. 
#this defines the dimensions of the hidden layers inside the encoder?
encoder_h_dim_1 = 512
encoder_h_dim_2 = 512
encoder_h_dim_3 = 128

In [217]:
decoder_h_dim_1 = 128
decoder_h_dim_2 = 128
decoder_h_dim_3 = 128

In [218]:
#interval of definition for the x's (function evaluations)
function_xlims = [-5, 5]

In [219]:
#function that produces a cubic dataset: x^3+noise

def generate_cubic_dataset():
    #create 10 points uniformly sampled between -4 and -2
    x_points = np.random.uniform(low=-4, high=-2, size=(10,))
    #creaet 10 new points uniformly sampled between 2 and 4 and append them to the prvious 10 points
    x_points = np.append(x_points, np.random.uniform(low=2, high=4, size=(10,)))
    #compute y=x^3+Noise(normal(0,1)^3)
    y_points = x_points**3 + np.random.normal(size=(20,)) * 3
    #we return the 20 x's and the 20 y's
    return (x_points, y_points)

In [220]:
# From krasserm github io

#isotropic squared exp kernel
#np.reshape gives a new shape to an array without changing its data. 
def kernel(X1, X2, l=1.0, sigma_f=1.0):
    """
    Isotropic squared exponential kernel.
        
    Args:
        X1: Array of m points (m x d).
        X2: Array of n points (n x d).

    Returns:
        (m x n) matrix.
    """
    sqdist = np.sum(X1**2, 1).reshape(-1, 1) + np.sum(X2**2, 1) - 2 * np.dot(X1, X2.T)
    return sigma_f**2 * np.exp(-0.5 / l**2 * sqdist)

In [221]:
#generate the set of training functions - here prior is a GP
#f_i ~ GP(0, cov(X,X))
def generate_gp_1d_dataset():
    # X = np.arange(function_xlims[0], function_xlims[1], 0.1).reshape(-1, 1)
    output_X = []
    output_samples = []
    for n in range(num_training_funcs):
        X = np.random.uniform(function_xlims[0], function_xlims[1],
            size=(num_eval_points,1))
        mu = np.zeros(X.shape)
        cov = kernel(X, X)
        sample = np.random.multivariate_normal(mu.ravel(), cov, 1)
        output_X.append(X)
        output_samples.append(sample)
    return np.array(output_X), np.array(output_samples)

In [222]:
#generate a set of training functions - here the prior is based on a sine wave

#f_i=a. sqrt(|X|).sin(5X)+X^2

def generate_quadratic_sine_dataset():
    output_X = []
    output_samples = []
    for n in range(num_training_funcs):
        X = np.random.uniform(function_xlims[0], function_xlims[1],
            size=(num_eval_points, 1))
        a = np.random.uniform(0.0, 3.0)
        y = a * np.sqrt(np.abs(X)) * np.sin(X*5) + X**2
        output_X.append(X)
        output_samples.append(y)
    return np.array(output_X), np.array(output_samples)

In [223]:
# y=A*exp(B*x), where A ~ U(1,2) and B ~ U(-2,2)

def generate_exp_dataset():
    output_X = []
    output_samples = []
    for n in range(num_training_funcs):
        X = np.random.uniform(function_xlims[0], function_xlims[1],
            size=(num_eval_points, 1))
        A = np.random.uniform(1, 2)
        B = np.random.uniform(-2, 2)
        y = A * np.exp(B * X)
        output_X.append(X)
        output_samples.append(y)
    return np.array(output_X), np.array(output_samples)
        


In [224]:
#y= a sin(x + phase) where a ~ U(0,1) and phase~U(0, pi)

def generate_maml_sine_dataset():
    output_X = []
    output_samples = []
    for n in range(num_training_funcs):
        X = np.random.uniform(function_xlims[0], function_xlims[1],
            size=(num_eval_points, 1))
        amplitude = np.random.uniform(0, 1.0)
        phase = np.random.uniform(0, np.pi)
        y = amplitude * np.sin(X + phase)
        output_X.append(X)
        output_samples.append(y)
    return np.array(output_X), np.array(output_samples)
        

In [225]:
#we define a Model class which: 
    #initialises the arguments and parameters (what exactly needs initialisation)
    #contains various function definitions: 
        #the high dimensional function Phis, 
        #the encoder, 
        #the decoder, 
        #the get_loss function
        #the eval_at_z function: which gives predicted x values at the location points s, when the z value is given
        #the draw_sample function: which draws samples from the pi vae (s should be (num_eval_points, dim))
        # the get_unnormalized_log_posterior function which gets something proportional to p(z|x, s) where x and s are new test points 
            #s (batch x dim)
            # x (batch)
            # z (z_dim)
    

In [226]:
#copy of Model class without comments:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.phi_rbf_centers = nn.Parameter(torch.tensor(
            np.random.uniform(function_xlims[0], function_xlims[1],
            size=(num_phi_rbf, input_dim))))
        self.phi_nn_1 = nn.Linear(num_phi_rbf, phi_hidden_layer_size)
        self.phi_nn_2 = nn.Linear(phi_hidden_layer_size, beta_dim)

        self.encoder_nn_1 = nn.Linear(beta_dim, encoder_h_dim_1)
        self.encoder_nn_2 = nn.Linear(encoder_h_dim_1, encoder_h_dim_2)
        self.encoder_nn_3 = nn.Linear(encoder_h_dim_2, encoder_h_dim_3)
        self.encoder_nn_4 = nn.Linear(encoder_h_dim_3, z_dim * 2)

        self.decoder_nn_1 = nn.Linear(z_dim, decoder_h_dim_1)
        self.decoder_nn_2 = nn.Linear(decoder_h_dim_1, decoder_h_dim_2)
        self.decoder_nn_3 = nn.Linear(decoder_h_dim_2, decoder_h_dim_3)
        self.decoder_nn_4 = nn.Linear(decoder_h_dim_3, beta_dim)

        # self.betas = nn.Parameter(torch.ones(num_training_funcs, beta_dim))
        self.betas = nn.Parameter(torch.tensor(
            np.random.uniform(-1, 1, size=(num_training_funcs, beta_dim))
        ))

        self.normal_sampler = torch.distributions.normal.Normal(0.0, 1.0)

    def Phi(self, input):
        # Takes input (batch x dim_in) and gives Phi(input) (batch x dim_out)
        #torch.unsqueeze returns a new tensor (here called input_expand)
        #w ith a dimension of size 1 inserted at the specified position 1
        #so here input_expand is transformed into a one-column vector
        input_expand = torch.unsqueeze(input, 1)
        #same here except that the phi_rbf_centers are transformed into a row vector
        phi_expand = torch.unsqueeze(self.phi_rbf_centers, 0)
        M1 = input_expand - phi_expand
        M2 = torch.sum(M1 ** 2, 2)
        M3 = torch.exp(-M2/phi_rbf_sigma)
        M4 = F.sigmoid(self.phi_nn_1(M3))
        M5 = self.phi_nn_2(M4)
        return M5

    def encoder(self, input):
        # input (batch x beta_dim) output ((batch x z_dim), (batch x z_dim))
        M1 = F.relu(self.encoder_nn_1(input))
        M2 = F.relu(self.encoder_nn_2(M1))
        M3 = F.relu(self.encoder_nn_3(M2))
        M4 = self.encoder_nn_4(M3)
        z_mean = M4[:, 0:z_dim]
        z_std = torch.exp(M4[:, z_dim:]) # needs to be positive
        return z_mean, z_std

    def decoder(self, input):
        # input (batch x z_dim) output (batch x beta_dim)
        M1 = F.relu(self.decoder_nn_1(input))
        M2 = F.relu(self.decoder_nn_2(M1))
        M3 = F.relu(self.decoder_nn_3(M2))
        M4 = self.decoder_nn_4(M3)
        return M4

    def get_loss(self, function_id, s, x, kl_factor, print_breakdown=False, 
        return_breakdown=False):
        # function_id is just to know which beta to use
        # s are the inputs (batch x dim)
        # x are the observed outputs (batch)
        batch_size = s.shape[0]

        phi_s = self.Phi(s)
        beta = self.betas[function_id, :]
        x_enc = torch.matmul(phi_s, beta)

        loss_term_1 = (x - x_enc)**2

        z_mean, z_std = self.encoder(beta.unsqueeze(0))
        # Do we draw one z_sample for all x values of this function or one z_sample for each of them?
        # The pi-vae paper in Alg1 does one z-sample for all x-values of the function
        z_sample = z_mean + z_std * self.normal_sampler.rsample((1, z_dim))
        beta_hat = self.decoder(z_sample)
        x_dec = torch.matmul(phi_s, beta_hat.squeeze()) # double check this is actually doing what we want it to
        loss_term_2 = (x - x_dec)**2

        # z_samples = z_mean + z_std * self.normal_sampler.rsample((batch_size, z_dim))
        # z_samples = z_mean.repeat(batch_size, 1)
        # beta_hats = self.decoder(z_samples)
        # x_dec = torch.sum(beta_hats * phi_s, dim=1)
        # loss_term_2 = (x - x_dec)**2
        # beta_hat = self.decoder(z_mean)
        # beta_hat = beta_hat.reshape(beta.shape)
        # loss_term_2 = torch.mean((beta_hat - beta)**2)

        

        # You only get one value not batch_num values since there's only
        # one beta for the whole batch since they're all from the same function
        # But when you add all the losses together, it will get broadcasted
        # so that it is repeated for each item in the batch so the mean will
        # be ok
        loss_term_3 = 0.5 * torch.sum(z_std**2 + z_mean**2 - 1 - torch.log(z_std**2),
            dim=1)
        loss_term_3 = kl_factor * (loss_term_3/z_dim)

        if print_breakdown:
            # print("z_mean, std", z_mean, z_std)
            # print("z_samples", z_samples)
            print("1", torch.mean(loss_term_1))
            print("2", torch.mean(loss_term_2))
            print("3", loss_term_3)

        if return_breakdown == False:
            return torch.mean(loss_term_1 + loss_term_2) + loss_term_3
        else:
            return torch.mean(loss_term_1 + loss_term_2) + loss_term_3, \
                torch.mean(loss_term_1), torch.mean(loss_term_2), loss_term_3

    def eval_at_z(self, z, s, return_beta_hat=False):
        # Gives predicted x values at s points when the z value is given
        phi_s = self.Phi(s)
        beta_hat = self.decoder(z)
        x_dec = torch.matmul(phi_s, beta_hat)
        if not return_beta_hat:
            return x_dec
        else:
            return x_dec, beta_hat

    def draw_samples(self, s, num_samples):
        # draw samples from the pi vae
        # s should be (num_eval_points, dim)
        z_samples = self.normal_sampler.rsample((num_samples, z_dim)).double()
        beta_hats = self.decoder(z_samples)
        phi_s = self.Phi(s)
        x_dec = torch.matmul(beta_hats.unsqueeze(1).unsqueeze(1),
            phi_s.unsqueeze(2).unsqueeze(0))
        x_dec = x_dec.squeeze()

        return x_dec

    def get_unnormalized_log_posterior(self, s, x, z):
        # Gets something proportional to p(z|x, s) where x and s are new test points
        # s (batch x dim)
        # x (batch)
        # z (z_dim)

        log_prior = -0.5 * torch.sum(z**2)

        phi_s = self.Phi(s)
        beta_hat = self.decoder(z)
        x_dec = torch.matmul(phi_s, beta_hat)
        log_likelihoods = (-1 / (2 * obs_sigma**2)) * (x_dec - x)**2

        return log_prior + torch.sum(log_likelihoods)


In [227]:
#CLASS MODEL COMMENTED OUT   
class Model(nn.Module):
    #_init_ is a reserved method in python, the method _init_ is called when an object is created from a class
    #_init_ allows the class to initialise its attributes
    
    #the word self is used to represent an instance of a class
    #By using the keyword self, we can access the attributes and the methods of the class 

    def __init__(self):
        super().__init__()
       
        #initialise the parameters of the Phi radial basis kernel functions: 
            #the centers are uniformly sampled between xlims and put into a tensor of size num_phi_rbf and input_dim
            #and the 2 linear layers composing Phi: Phi takes as input the number of rbfs and outputs betas
        
        #PHI:    
        self.phi_rbf_centers = nn.Parameter(torch.tensor(
            np.random.uniform(function_xlims[0], function_xlims[1],
            size=(num_phi_rbf, input_dim))))
        self.phi_nn_1 = nn.Linear(num_phi_rbf, phi_hidden_layer_size)
        self.phi_nn_2 = nn.Linear(phi_hidden_layer_size, beta_dim)

        
        #initialise the parameters of the encoder which takes betas as inputs
        #it outputs something of dimensions z_dim*2 (i.e.mu_z and sigma_z, each of dim_z)
        #ENCODER:
        self.encoder_nn_1 = nn.Linear(beta_dim, encoder_h_dim_1)
        self.encoder_nn_2 = nn.Linear(encoder_h_dim_1, encoder_h_dim_2)
        self.encoder_nn_3 = nn.Linear(encoder_h_dim_2, encoder_h_dim_3)
        self.encoder_nn_4 = nn.Linear(encoder_h_dim_3, z_dim * 2)

        #the encoder takes as inputs z and outputs a reconstructed beta
        #DECODER:
        self.decoder_nn_1 = nn.Linear(z_dim, decoder_h_dim_1)
        self.decoder_nn_2 = nn.Linear(decoder_h_dim_1, decoder_h_dim_2)
        self.decoder_nn_3 = nn.Linear(decoder_h_dim_2, decoder_h_dim_3)
        self.decoder_nn_4 = nn.Linear(decoder_h_dim_3, beta_dim)

        #initialise the betas: they are sampled uniformly between -1 and 1. 
        #number of betas= number of functions trained=num_training_funcs
        
        #BETAS:
        self.betas = nn.Parameter(torch.tensor(np.random.uniform(-1, 1, size=(num_training_funcs, beta_dim))))

        #normal N(0,1) sampler
        self.normal_sampler = torch.distributions.normal.Normal(0.0, 1.0)
        
        
        #define the high dimensional feature function Phi as an rbf kernel
        #it takes the locations s_i^k as inputs
        # Takes input (batch x dim_in) and gives Phi(input) (batch x dim_out)
        
     
    def Phi(self, input):
        #unsqueeze adds a dimension of size 1 at the specified position
        #for the inputs we need to add a dimension along axis=1
        #torch.unsqueeze returns a new tensor (here called input_expand)
        #with a dimension of size 1 inserted at the specified position 1
        #so here input_expand is transformed into a one-column vector
        input_expand = torch.unsqueeze(input, 1)
        
        #same here except that the phi_rbf_centers are transformed into a row vector
        #for the rbf centers, an additional dimension needs to be added along x=0
        phi_expand = torch.unsqueeze(self.phi_rbf_centers, 0)
        
        #we now can compute x-x_rbf
        M1 = input_expand - phi_expand
        
        #we now compute sum(x-x')^2 but along axis 2 and I am not sure about what axis 2 represents. 
        M2 = torch.sum(M1 ** 2, 2)
        
        #this is exp(-sum(x-x_rbf)^2/sigma_rbf)
        M3 = torch.exp(-M2/phi_rbf_sigma)
        
        #apply a sigmoid function to the first linear layer of Phi
        #recall that F stands for torch.nn.functional
        M4 = F.sigmoid(self.phi_nn_1(M3))
        #apply a second linear layer to sigmoid
        M5 = self.phi_nn_2(M4)
        return M5

    
        
    def encoder(self, input):
    #this function defines the encoder network
    # input (batch x beta_dim) output ((batch x z_dim), (batch x z_dim))
    #it takes as inputs the betas so the inputs have dimensions (nber of x's, dim of beta)
    #the encoder outputs the mean and standard dev of the Normal from which z will be sampled

    #apply a relu activation fct to the first linear layer of the encoder
        M1 = F.relu(self.encoder_nn_1(input))
    #apply another relu to the second linear layer of the encoder. 
        M2 = F.relu(self.encoder_nn_2(M1))
    #apply another relu to the second linear layer of the encoder. 
        M3 = F.relu(self.encoder_nn_3(M2))
    #the fourth layer is simply linear (no activation fct)
        M4 = self.encoder_nn_4(M3)

    #the first output is the  mean of z: we take the output of the encoder M4
    #we retrieve all its rows but only its first z_dim columns
        z_mean = M4[:, 0:z_dim]
    # The std dev needs to be positive
    #for sigma_z, we take all the rows as well bu the subsequent z_dim columns
        z_std = torch.exp(M4[:, z_dim:]) 
        return z_mean, z_std

    def decoder(self, input):
    # input (batch x z_dim) output (batch x beta_dim)
    #the input of the encoder is a sampled z: we have batch x number of them and their dim is z_dim
    #we first apply a relu activation to the first linear layer
        M1 = F.relu(self.decoder_nn_1(input))
    #second and third relu activations on layer 2 and 3
        M2 = F.relu(self.decoder_nn_2(M1))
        M3 = F.relu(self.decoder_nn_3(M2))
    #the output of the decoder 
    #the last layer is simply linear
        M4 = self.decoder_nn_4(M3)
        return M4

    #REVIEW GET_LOSS FUNCTION

    #this function conmputes the loss that we wish to minimise
    #its arguments are

    def get_loss(self, function_id, s, x, kl_factor, print_breakdown=False, return_breakdown=False):
        # function_id is just to know which beta to use
        # s are the inputs (batch x dim)
        # x are the observed outputs (batch)
        batch_size = s.shape[0]
        #compute Phi(s)
        phi_s = self.Phi(s)
        beta = self.betas[function_id, :]
        x_enc = torch.matmul(phi_s, beta)

        loss_term_1 = (x - x_enc)**2

        z_mean, z_std = self.encoder(beta.unsqueeze(0))
        # Do we draw one z_sample for all x values of this function or one z_sample for each of them?
        # The pi-vae paper in Alg1 does one z-sample for all x-values of the function
        z_sample = z_mean + z_std * self.normal_sampler.rsample((1, z_dim))
        
        beta_hat = self.decoder(z_sample)
        
        x_dec = torch.matmul(phi_s, beta_hat.squeeze()) # double check this is actually doing what we want it to
        loss_term_2 = (x - x_dec)**2

    
        # You only get one value for beta_hat, not batch_num values, since there's only
        # one beta for the whole batch, as the are all from the same function.
        # But when you add all the losses together, it will get broadcasted
        # so that it is repeated for each item in the batch and the mean will
        # be ok
        
        loss_term_3 = 0.5 * torch.sum(z_std**2 + z_mean**2 - 1 - torch.log(z_std**2),
            dim=1)
        loss_term_3 = kl_factor * (loss_term_3/z_dim)

        if print_breakdown:
            # print("z_mean, std", z_mean, z_std)
            # print("z_samples", z_samples)
            print("1", torch.mean(loss_term_1))
            print("2", torch.mean(loss_term_2))
            print("3", loss_term_3)

        if return_breakdown == False:
            return torch.mean(loss_term_1 + loss_term_2) + loss_term_3
        else:
            return torch.mean(loss_term_1 + loss_term_2) + loss_term_3, \
                torch.mean(loss_term_1), torch.mean(loss_term_2), loss_term_3


    # Gives the predicted decoded x values at point locartions s, when the value of z is given
    #eval_at_z returns the decoded x's and the decoded betas
    def eval_at_z(self, z, s, return_beta_hat=False):
        #compute Phi(s_i^k)
        phi_s = self.Phi(s)
        beta_hat = self.decoder(z)
        x_dec = torch.matmul(phi_s, beta_hat)
        if not return_beta_hat:
            return x_dec
        else:
            return x_dec, beta_hat


    # draw samples from the pi vae: 
    #we sample z's from N(0,1), pass these z values through the decoder to get beta_hats decoded from z~N(0,1)
    #then reconstruct x_hat=beta_hat * Phi(s)
    
    def draw_samples(self, s, num_samples):

        # s should be (num_eval_points, dim)

        #dimensions of z_samples: we draw num_samples (nbr of rows) of dim z_dim
        #draw z ~ N(0,1)
        z_samples = self.normal_sampler.rsample((num_samples, z_dim)).double()
        
        #decode betas using the z sampled from N(0,1)
        beta_hats = self.decoder(z_samples)
        #compute Phi(s_i^k)
        phi_s = self.Phi(s)
        
        #compute the decoded x's=beta_hat *phi(s)
        #why we do unsqueeze beta_hats and phis_s and then resqueeze x_dec
        x_dec = torch.matmul(beta_hats.unsqueeze(1).unsqueeze(1), phi_s.unsqueeze(2).unsqueeze(0))
        x_dec = x_dec.squeeze()

        return x_dec

    # Gets something proportional to p(z|x, s) where x and s are new test points
    #p(z|x,s) \propto p(z)*p(x|z)
    # s (batch x dim), x (batch), z (z_dim)
    def get_unnormalized_log_posterior(self, s, x, z):

        log_prior = -0.5 * torch.sum(z**2)

        phi_s = self.Phi(s)
        beta_hat = self.decoder(z)
        x_dec = torch.matmul(phi_s, beta_hat)
        log_likelihoods = (-1 / (2 * obs_sigma**2)) * (x_dec - x)**2

        return log_prior + torch.sum(log_likelihoods)
    
    
#Below would be the code if we were drawing one z sample for each of the x's:
        # z_samples = z_mean + z_std * self.normal_sampler.rsample((batch_size, z_dim))
        # z_samples = z_mean.repeat(batch_size, 1)
        # beta_hats = self.decoder(z_samples)
        # x_dec = torch.sum(beta_hats * phi_s, dim=1)
        # loss_term_2 = (x - x_dec)**2
        

In [228]:
#uses Metropolis Hasting algo
#An object of the MCMC class returns a sample
class MCMC():
    def __init__(self, in_model):
        self.model = in_model
    
    def mcmc_draw_samples(self, num_samples, starting_point, proposal_sigma, s_star, x_star):
        
        z = starting_point
        #initialise the samples with a tensor of zeros of dimensions num_samples and z_dim
        samples = torch.zeros((num_samples, z_dim)).double()
        
        #initialise the acceptance probability to 0
        acc_prob_sum = 0
        
        #loop through the number of samples
        for t in range(num_samples):
            #torch.randn_like(input) returns a tensor of the same size  as the input (i.e dim_z here) 
            #and the tensor is filled with N(0,1) RVs
            #z_p=z + N(0,1)* proposal_sigma^2
            #generate a candidate z_p centered at z
            z_p = z + torch.randn_like(z) * proposal_sigma**2
            
            
            log_p_z = self.model.get_unnormalized_log_posterior(s_star, x_star, z)
            
            log_p_z_p = self.model.get_unnormalized_log_posterior(s_star, x_star, z_p)
            
            #the ratio compares the posterior at the current valuae of z
            #and at the new proposed value z_p: it compares p(z|z, s) and p(z_p|x, s)
            ratio = torch.exp(log_p_z_p - log_p_z)
            
            #We take the min(1,p(z_p|x)/p(z|x)):
            acc_prob = torch.min(torch.tensor(1.0).double(), ratio)
            
            #generate u ~ U(0.1)
            u = torch.rand(1)
            
            #tha algo compares the min between 1 and the ratio of the posteriors with a uniform RV u
            if u < acc_prob:
                z = z_p
                
            #fill in the tensor with the sampled z    
            samples[t, :] = z
            
            acc_prob_sum += acc_prob.detach().data
        print("mean acc prob", acc_prob_sum/num_samples)
        return samples


In [229]:
def check_beta(model, id):
    #define some new test points between -5 and 5 at K locations
    test_points = torch.arange(-5, 5, 0.1).reshape(100, 1)
    
    #compute Phi(s^k) at these test points
    phi_s = model.Phi(test_points)
    
    #use the betas method defined in the class model:
    beta = model.betas[id, :]
    
    #compute the encoded x's: x_i,e=beta^T * Phi(s^k)
    x_encs = torch.matmul(phi_s, beta)
    
    #first unsqueeze the betas by adding a dimension to them along the axis 0 and then pass them through the encoder
    #the encoder outputs the statistics of N(z; mu_z, std_z)
    z_mean, z_std = model.encoder(beta.unsqueeze(0))
    print(z_mean, z_std)

    #beta_hat is the output of the decoder applied to z_mean
    
    #why not applied to z_sdv??
    
    beta_hat = model.decoder(z_mean)
    
    #the decoded x's are equal to beta_hat^T * Phi(s_k)
    x_decs = torch.matmul(beta_hat, torch.transpose(phi_s, 0, 1))
    
    #plot test points against encoded x's
    plt.plot(test_points.detach().numpy(), x_encs.detach().numpy())
    
    #plot test points against decoded x's
    plt.plot(test_points.detach().numpy().reshape(100), x_decs.detach().numpy().reshape(100))
    
    #plot the original scattered points 
    plt.scatter(dataset_X[id].reshape(num_eval_points), dataset_f[id].reshape(num_eval_points))
    plt.show()

In [230]:
def plot_posterior_samples(model, samples, s_star, x_star):
    #define some test points between xlims
    test_points = torch.arange(function_xlims[0], function_xlims[1], 0.02).double()
    
    #loop through the number of samples
    #one plot for each sample i - so we get samples.shape[0] number of lines
    for i in range(samples.shape[0]):
        
        #method eval_at_z from object of class Model
        #the samples[i,:] are the sampled z's
        func = model.eval_at_z(samples[i,:], test_points.unsqueeze(1))
        
        plt.plot(test_points.detach().numpy(), func.detach().numpy(), alpha=0.1, color='black')
    
    #now we plot the scattered points (x*, s*)
    
    #EXPLAIN DIFF BETWEEN TESTS POINTS AND STAR POINTS
    plt.scatter(s_star.detach().numpy(), x_star.detach().numpy(), s=1000, marker="+")
    
    plt.show()

In [231]:
#initialise a dataset of x's and f(x) using the exponential generator
# here f(x)=A exp(Bx)

dataset_X, dataset_f = generate_maml_sine_dataset()

#dataset_X, dataset_f = generate_exp_dataset()
#dataset_X, dataset_f = generate_maml_sine_dataset()

In [232]:
dataset_X.shape

(1000, 20, 1)

In [233]:
dataset_f.shape

(1000, 20, 1)

In [234]:
#here we build the model from the model class
#it includes Phi made of 2 linear layers, an encoder composed of 4 layers, a decoder composed of 4 layers
model = Model().double()

In [235]:
model

Model(
  (phi_nn_1): Linear(in_features=100, out_features=10, bias=True)
  (phi_nn_2): Linear(in_features=10, out_features=100, bias=True)
  (encoder_nn_1): Linear(in_features=100, out_features=512, bias=True)
  (encoder_nn_2): Linear(in_features=512, out_features=512, bias=True)
  (encoder_nn_3): Linear(in_features=512, out_features=128, bias=True)
  (encoder_nn_4): Linear(in_features=128, out_features=32, bias=True)
  (decoder_nn_1): Linear(in_features=16, out_features=128, bias=True)
  (decoder_nn_2): Linear(in_features=128, out_features=128, bias=True)
  (decoder_nn_3): Linear(in_features=128, out_features=128, bias=True)
  (decoder_nn_4): Linear(in_features=128, out_features=100, bias=True)
)

In [236]:
#apply the class MCMC to your object model
mcmc = MCMC(model)

In [237]:
mcmc

<__main__.MCMC at 0x12378dcd0>

In [238]:
#we define the optimiser - using Adam algo and a learning rate of 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Training

In [239]:
#this is the number of input functions f_i
num_funcs_to_consider = 1
current_max = 32
interval = 5
nbr_epochs=300

In [240]:
#we now train the model which means we minimise the total of the encoder and decoder using get_loss
#define the inputs points as the dataset_X (points between 0 and 1) and the fcts evaluations at x as dataset_f where
#y=a*sin(x+phase)

In [241]:

#train for nbr_epochs
for epoch_id in range(nbr_epochs):
    # print(epoch_id)
    
    #create empty arrays for the losses l1, l2 and l3:
    l1s = []
    l2s = []
    l3s = []
    
    for function_id in range(num_funcs_to_consider):
        #because pytorch accunulates the gradients on successive backward passes, 
        #we need to re-initialise them to 0 at every loop
        optimizer.zero_grad()
        
        #the inputs are made of the dataset
        #what is dataset_X?
        input_points = torch.tensor(dataset_X[function_id])
        
        #the x's are the functions evaluations
        #what is dataset_f?
        x_vals = torch.tensor(dataset_f[function_id])
        
        #when breakdown is true, the get_loss method returns: 
        #torch.mean(loss_term_1 + loss_term_2) + loss_term_3, torch.mean(loss_term_1), torch.mean(loss_term_2), loss_term_3
        #so here: 
            #loss=mean(loss_term_1 + loss_term_2) + loss_term_3
            #l1=mean(loss_term_1)
            #l2=torch.mean(loss_term_2)
            #l3=torch.mean(loss_term_3)
        
        loss, l1, l2, l3 = model.get_loss(function_id, input_points, x_vals, 1.0, return_breakdown=True)
        
        #loss_backward() computes the gradients for every parameter and accumulates them 
        loss.backward()
        #optimizer.step() updates the value of x using the gradient of x computed by loss.backward()
        optimizer.step()
        
        
        l1s.append(l1.detach().numpy())
        l2s.append(l2.detach().numpy())
        l3s.append(l3.detach().numpy())

    if epoch_id % interval == 0:
        num_funcs_to_consider = min(num_funcs_to_consider+1, current_max)

    print("l1", np.mean(np.array(l1s)),
        "l2", np.mean(np.array(l2s)),
        "l3", np.mean(np.array(l3s)),
        "num funcs", num_funcs_to_consider)

    


l1 1.3820695022009997 l2 0.052775714592218546 l3 0.004431170350392485 num funcs 2
l1 0.8925321979329899 l2 0.2034786721123095 l3 0.0027938489521845576 num funcs 2
l1 0.33389845610400826 l2 0.13281665106220525 l3 0.0036218327554741426 num funcs 2
l1 0.20176670274472674 l2 0.136639346260383 l3 0.0030185765250162807 num funcs 2
l1 0.2714020084439108 l2 0.3267842603334868 l3 0.0022774983157294957 num funcs 2
l1 0.3617258994889493 l2 0.23468260428602428 l3 0.002376431706665003 num funcs 3
l1 2.2173972186136948 l2 0.2965026571054507 l3 0.002381565928200906 num funcs 3
l1 1.5484684603421923 l2 0.4408252790309571 l3 0.003071909056255345 num funcs 3
l1 0.9040581436470569 l2 0.2529603195230614 l3 0.005478291363399887 num funcs 3
l1 0.46956264213391136 l2 0.2587697666461247 l3 0.006997846081514207 num funcs 3
l1 0.27253251687084584 l2 0.28283969513164736 l3 0.005895407955848863 num funcs 4
l1 0.3316610603212666 l2 0.22526823880437194 l3 0.0034628195328627754 num funcs 4
l1 0.3335611902798643 l2 0

l1 0.17588688087131205 l2 0.1882596028289704 l3 6.008690834976606e-05 num funcs 21
l1 0.17535313640676572 l2 0.1863650780301302 l3 2.261247457823038e-05 num funcs 22
l1 0.27788531082511864 l2 0.1930180685462929 l3 1.0358353018787832e-05 num funcs 22
l1 0.2001308882518509 l2 0.19382929141563485 l3 7.084932241674933e-06 num funcs 22
l1 0.19099862278105692 l2 0.19426643009679437 l3 1.6768020124839745e-05 num funcs 22
l1 0.18577189452312634 l2 0.19423412856732708 l3 1.688241375970684e-05 num funcs 22
l1 0.1828499108196705 l2 0.1930955322121595 l3 1.8884367065848204e-05 num funcs 23
l1 0.18048251000779883 l2 0.18849834226303158 l3 1.2056768254236965e-05 num funcs 23
l1 0.18211680018924445 l2 0.18854734529019335 l3 2.7165607812194074e-05 num funcs 23
l1 0.17832163493784736 l2 0.18747635829333142 l3 3.970710682017221e-05 num funcs 23
l1 0.17628403786080768 l2 0.18770481125488966 l3 1.5943760845069788e-05 num funcs 23
l1 0.17549347616192265 l2 0.1886361509660579 l3 1.2040586275892078e-05 num f

l1 0.15147195366189042 l2 0.16484422116220337 l3 3.333914480196134e-07 num funcs 32
l1 0.15155649282898243 l2 0.16539427879898339 l3 2.796513051229967e-07 num funcs 32
l1 0.15221328846329896 l2 0.16483445936483224 l3 2.3819825409686202e-07 num funcs 32
l1 0.1534540621902379 l2 0.1653640265277412 l3 4.2320607430179534e-07 num funcs 32
l1 0.15636230197682416 l2 0.16496312029719956 l3 5.488500525305935e-07 num funcs 32
l1 0.16152291375301897 l2 0.16507487068273277 l3 4.7136123357144234e-07 num funcs 32
l1 0.16989295847000824 l2 0.16537533290490003 l3 6.638152331162704e-07 num funcs 32
l1 0.18047973895150002 l2 0.1647217804802022 l3 1.1935917626142794e-06 num funcs 32
l1 0.1861865697871236 l2 0.16572994674174496 l3 1.2040790126514996e-06 num funcs 32
l1 0.18623017263198854 l2 0.16463247741633058 l3 1.2164967909306784e-06 num funcs 32
l1 0.17929158408216073 l2 0.1655927205052228 l3 2.423652951427952e-06 num funcs 32
l1 0.169036097217095 l2 0.16463726005709967 l3 2.6340550602650044e-06 num f

l1 0.15014674410239015 l2 0.16454309761825844 l3 4.643400461677749e-09 num funcs 32
l1 0.1497277390480467 l2 0.16480275560865365 l3 6.062671191330679e-09 num funcs 32
l1 0.14949603930957078 l2 0.16454777457411607 l3 4.19918147021905e-09 num funcs 32
l1 0.14940394714883792 l2 0.16479038482171243 l3 4.077667339482123e-09 num funcs 32
l1 0.14945121683472778 l2 0.16454910269321726 l3 3.9688938554674545e-09 num funcs 32


In [242]:
#Draw some samples from the pivae:
#First define locations between function_xlims
locations = torch.arange(function_xlims[0], function_xlims[1], 0.2).unsqueeze(1).double()

In [243]:
np.savetxt('locations_output_sine_201120.txt', locations)

In [244]:
#draw samples using the method draw_samples() from the class model:
#here we draw z from z ~ N(0,1), pass the z values through the decoder to get beta_hat 
#and reconstruct x_hat=beta_hat*Phi(s_i^k)

#from the locations s_i^k, we draw samples, and obtain samples of decoded x's
samples = model.draw_samples(locations, 20)

In [245]:
samples = samples.detach().numpy()

In [246]:
locations = locations.detach().numpy()

In [247]:
np.savetxt('samples_output_sine_201120.txt', samples)

In [248]:
# ---- MCMC ----
#define the additional pairs of points (s,x) at which we will compute the log posterior log p(x^*,z) 
#to get the Metropolis Hasting algo to sample the z_p that satisfy the criterium of acceptance

s_star = torch.tensor([0,2,-2]).unsqueeze(1).double()
x_star = torch.tensor([0.5*np.sin(-1.5),0.5*np.sin(2-1.5),0.5*np.sin(-2-1.5)]).double()


#draw z from N(0,1)
z = torch.randn((z_dim,)).double()

#draw 10,000 samples from the mcmc algo
mcmc_samples = mcmc.mcmc_draw_samples(10000, z, 0.1, s_star, x_star)

mean acc prob tensor(0.9887, dtype=torch.float64)


In [249]:
all_samples = mcmc_samples[1000::100,:]
all_samples.shape

torch.Size([90, 16])

In [250]:
#apply MCMC to different starting points z ~ N(0,1)
for i in range(9):
    z = torch.randn((z_dim,)).double()
    mcmc_samples = mcmc.mcmc_draw_samples(1000, z, 0.1, s_star, x_star)
    
    all_samples[10*i:10*(i+1), :] = mcmc_samples[500::50,:]

mean acc prob tensor(0.9841, dtype=torch.float64)
mean acc prob tensor(0.9854, dtype=torch.float64)
mean acc prob tensor(0.9832, dtype=torch.float64)
mean acc prob tensor(0.9844, dtype=torch.float64)
mean acc prob tensor(0.9832, dtype=torch.float64)
mean acc prob tensor(0.9833, dtype=torch.float64)
mean acc prob tensor(0.9837, dtype=torch.float64)
mean acc prob tensor(0.9822, dtype=torch.float64)
mean acc prob tensor(0.9807, dtype=torch.float64)


In [251]:
all_samples.shape

torch.Size([90, 16])

In [252]:
np.savetxt('all_samples_output_sine_3pts.txt', all_samples)
np.savetxt('s_star_output_sine_3pts.txt', s_star)
np.savetxt('x_star_output_sine_3pts.txt', x_star)
np.savetxt('mcmc_samples_output_sine_3pts.txt', mcmc_samples)

In [253]:
#define test points between -5 and 5. 
test_points = torch.arange(-5, 5, 0.1).unsqueeze(1)
np.save('test_points', test_points)
np.savetxt('test_points_out.txt', test_points)

In [254]:
#compute the functions evaluations:

#first create an np array of zeros to store the fct evaluations: 
    #nbr of rows = nbr of sample fcts
    #nbr of columns = nbr of test points
all_funcs=np.zeros((all_samples.shape[0], test_points.shape[0]))

#eval_at_z computes: beta_hat^T * Phi(s) 
#where beta_hat=decoder(z) and here z=sample from mcmc
for i in range(all_samples.shape[0]):
        func = model.eval_at_z(all_samples[i,:], test_points)
        all_funcs[i,:]=func.detach().numpy()

In [255]:
np.savetxt('all_funcs_out_3pts.txt', all_funcs)

In [139]:
def get_loss(self, function_id, s, x, kl_factor, print_breakdown=False, return_breakdown=False):
         # function_id is just to know which beta to use
        # s are the inputs (batch x dim)
        # x are the observed outputs (batch)
        
            #the first dim of s is the number N of training functions (i=1, ..., N)
            #here: N=x batch I think?
            batch_size = s.shape[0]
        
            #we directly apply the feature function Phi to our location points s, this is Phi(s_i^k)
            phi_s = self.Phi(s)
        
            #beta is builds as an array. Not sure where betas method is defined
            beta = self.betas[function_id, :]
        
            #this corresponds to the equation: x_hat_i^k,e=beta_i^T * Phi(s_i^k)
            x_enc = torch.matmul(phi_s, beta)

            #the first loss term is the squared difference between the inputs x and their encoded form hat_x's
            loss_term_1 = (x - x_enc)**2

            #to get the output mu_z and sigma_z of the encoder, we apply the encoder to the betas. 
            #this corresponds to the  equation: [z_mu, z_sd]^T=e(eta_e, beta_i)
            # we add a second dimention to the betas - we unsqueeze them along axis 0 
        
            z_mean, z_std = self.encoder(beta.unsqueeze(0))
        
           # Do we draw one z_sample for all x values of this function or one z_sample for each of them?
        
            # The pi-vae paper in Alg1 does one z-sample for all x-values of the function
            #z ~ N(mu_z, sigma_z) -> z=mu_z+sigma_z * N(0,1)
        
            #here we draw ONLY ONE z_sample for all x_values of this particular function
            z_sample = z_mean + z_std * self.normal_sampler.rsample((1, z_dim))
        
            #then we apply the decoder to the sampled z in order to retrieve the reconstructed betas
            beta_hat = self.decoder(z_sample)
            
            #the reconstructed x's: x_hat_i^k,d=beta_hat_i^T * Phi(s_i^k)
            # double check this is actually doing what we want it to
            x_dec = torch.matmul(phi_s, beta_hat.squeeze()) # double check this is actually doing what we want it to
        
            #the second loss term compares the intial x's with the reconstructed x's
            loss_term_2 = (x - x_dec)**2

            # z_samples = z_mean + z_std * self.normal_sampler.rsample((batch_size, z_dim))
            # z_samples = z_mean.repeat(batch_size, 1)
            # beta_hats = self.decoder(z_samples)
            # x_dec = torch.sum(beta_hats * phi_s, dim=1)
            # loss_term_2 = (x - x_dec)**2
            # beta_hat = self.decoder(z_mean)
            # beta_hat = beta_hat.reshape(beta.shape)
            # loss_term_2 = torch.mean((beta_hat - beta)**2)



            # You only get one value not batch_num values since there's only
            # one beta for the whole batch since they're all from the same function
            # But when you add all the losses together, it will get broadcasted
            # so that it is repeated for each item in the batch so the mean will
            # be ok
            loss_term_3 = 0.5 * torch.sum(z_std**2 + z_mean**2 - 1 - torch.log(z_std**2),
                dim=1)
            loss_term_3 = kl_factor * (loss_term_3/z_dim)

            if print_breakdown:
                # print("z_mean, std", z_mean, z_std)
                # print("z_samples", z_samples)
                print("1", torch.mean(loss_term_1))
                print("2", torch.mean(loss_term_2))
                print("3", loss_term_3)

            if return_breakdown == False:
                return torch.mean(loss_term_1 + loss_term_2) + loss_term_3
            else:
                return torch.mean(loss_term_1 + loss_term_2) + loss_term_3, \
                    torch.mean(loss_term_1), torch.mean(loss_term_2), loss_term_3

