In [24]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

In [68]:
#Hyperparameters initialisation:

#dimension of the vector beta_i: it has p features we wanrt to learn. Here p=beta_dim
beta_dim = 100

#The input x is a scalar and here its dimension is 1
input_dim = 1

#we need to set up a few parameters for the high dimensional feature function phi
#here phi is assumed to be a radial basis kernel function
#number of rbf centers
num_phi_rbf = 100
#sigma of the radial basis function
phi_rbf_sigma = 5
#we need to learn phi using a NN. We assume it has 2 hidden layers, each with 10 neurons
phi_hidden_layer_size = 10


#dimension of the latent variable z that embeds the inputs:
z_dim = 16

# we train our VAE with 1000 training functions:
#this also gives the numbers of betas to learn:
num_training_funcs = 1000


#each function f_i is evaluated at K locations. Here K=num_eval_points
num_eval_points = 20

#The observation standard deviation
obs_sigma = 0.01 

#the encoder and the decoder parts each have 3 hidden layers
encoder_h_dim_1 = 512
encoder_h_dim_2 = 512
encoder_h_dim_3 = 128

decoder_h_dim_1 = 128
decoder_h_dim_2 = 128
decoder_h_dim_3 = 128

#f_i=x_i is the function evaluation at the K locations: lower and upper bounds of x
function_xlims = [-5, 5]
function_slims=[-10,10]

# Algo 1: Prior training for piVAE
## 1. Draw N functions evaluated at K points: 
### For $i=1, \ldots, N,$ the N training functions are called $f_i$, the K locations are $s_i^k,$ and the function evaluations are $x_i^k=f_i(s_i^k).$

In [69]:
def generate_sine_fcts_dataset():
    s_locations=[]
    x_fcts_evals=[]
    
    #loop through the number of training functions N=num_training_funcs
    #each sin wave is differentiated by a uniformly sampled amplitude and phase
    
    for i in range(num_training_funcs):
        s=np.random.uniform(function_slims[0], function_slims[1], size=(num_eval_points, 1))
        amplitude=np.random.uniform(1.0,5.0)
        phase = np.random.uniform(0, np.pi)
        x=amplitude*np.sin(s+phase)
        s_locations.append(s)
        x_fcts_evals.append(x)
    return np.array(s_locations), np.array(x_fcts_evals)

In [70]:
dataset_s, dataset_x=generate_sine_fcts_dataset()

## 2. For each training function $i=1,\ldots, N$ and for each location $k=1, \ldots, K$, we compute  $\hat{x}_{e,i}^k=\beta_i^T\Phi(s_i^k).$ 

But before we can do so, we need to create a class Model  that will encapsulate the feature function $\Phi$ and also the encoder, the decoder and the loss function. 

In [71]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.phi_rbf_centers = nn.Parameter(torch.tensor(
            np.random.uniform(function_xlims[0], function_xlims[1],
            size=(num_phi_rbf, input_dim))))
        self.phi_nn_1 = nn.Linear(num_phi_rbf, phi_hidden_layer_size)
        self.phi_nn_2 = nn.Linear(phi_hidden_layer_size, beta_dim)
          
        self.encoder_nn_1=nn.Linear(beta_dim, encoder_h_dim_1)
        self.encoder_nn_2=nn.Linear(encoder_h_dim_1, encoder_h_dim_2)
        self.encoder_nn_3=nn.Linear(encoder_h_dim_2, encoder_h_dim_3)
        self.encoder_nn_4=nn.Linear(encoder_h_dim_3, z_dim*2) #the encoder output is z_mean and z_sigma 
         
        #we instantiate the hidden layers of the decoder neural network
        #the input of the decoder is z and once z is passed through the decoder, it outputs a reconstructed beta
        self.decoder_nn_1=nn.Linear(z_dim, decoder_h_dim_1)
        self.decoder_nn_2=nn.Linear(decoder_h_dim_1, decoder_h_dim_2)
        self.decoder_nn_3=nn.Linear(decoder_h_dim_2, decoder_h_dim_3)
        self.decoder_nn_4=nn.Linear(decoder_h_dim_3, beta_dim)
        

        # self.betas = nn.Parameter(torch.ones(num_training_funcs, beta_dim))
        self.betas = nn.Parameter(torch.tensor(
            np.random.uniform(-1, 1, size=(num_training_funcs, beta_dim))
        ))

        #instantiate the N(0,1) distribution
        self.normal_sampler = torch.distributions.normal.Normal(0.0, 1.0)
    
    
      
    #the model contains the feature function Phi
    #Phi takes as input the locations s_i^k, which need to be tensors
    def Phi(self, input):
        # Takes input (batch x dim_in) and gives Phi(input) (batch x dim_out)
        input_expand = torch.unsqueeze(input, 1)
        phi_expand = torch.unsqueeze(self.phi_rbf_centers, 0)
        M1 = input_expand - phi_expand
        M2 = torch.sum(M1 ** 2, 2)
        M3 = torch.exp(-M2/phi_rbf_sigma)
        M4 = F.sigmoid(self.phi_nn_1(M3))
        M5 = self.phi_nn_2(M4)
        return M5
    
    
    #encoder inputs beta and outputs z_mean and z_std:
    def encoder(self, input):
        #The input of the encoder is beta. Each beta has dimension beta_dim 
        #but we need one beta for each function evaluation, i.e. batch_size
        #thus the inputs dimensions are input (batch x beta_dim)
        #The outputs of the decoder are the mean and standard dev of the db of z. 
        #Dimensions of output: ((batch x z_dim), (batch x z_dim))
        
        #we apply relu to the linear hidden layers:
        M1=F.relu(self.encoder_nn_1(input))
        M2=F.relu(self.encoder_nn_2(M1))
        M3 = F.relu(self.encoder_nn_3(M2))
                  
        #the last layer of the encoder is simply linear          
        M4 = self.encoder_nn_4(M3) 
                  
        z_mean=M4[:, 0:z_dim]
        #we take the exponential of the weights to ensure the std is positive
        z_std=torch.exp(M4[:,z_dim:])
                  
        return z_mean, z_std  
    
    
    #decoder inputs the latent variable z and outputs a recontructed beta
    def decoder(self, input):
        #input dimensions: (batch x z_dim) 
        #output dimensions: (batch x beta_dim)
        M1=F.relu(self.decoder_nn_1(input))
        M2=F.relu(self.decoder_nn_2(M1))
        M3=F.relu(self.decoder_nn_3(M2))
        M4=self.decoder_nn_4(M3)
        return M4
    
    
    
    #the loss is made of 3 terms: 
    def get_loss(self, function_id, s, x, kl_factor, print_breakdown=False, return_breakdown=False):
        #the arguments of the loss functions are the function_id (to differentiate which fct eval is becing considered)
        #we also need the locations s_i^k and the function evaluations x_i^k
        #the batch_size is equal to the first dimension of the tensor s
        batch_size=s.shape[0] 
        
        #FIRST LOSS TERM:
        phi_s=self.Phi(s)
        beta=self.betas[function_id, :]
        x_enc=torch.matmul(phi_s,beta)
        
        loss_term_1=(x_enc-x)**2
        
        
        #SECOND LOSS TERM:
        #pass the beta through the encoder to get the mean and std of the latent db of z
        z_mean, z_std=self.encoder(beta.unsqueeze(0))
        
        #sample z: we draw one z for each function_id
        z_sample = z_mean + z_std * self.normal_sampler.rsample((1, z_dim))        
        
        #we then decode the z_sample to obtain the reconstructed beta
        hat_beta=self.decoder(z_sample)
        
        #we need to compute the reconstructed x from the hat_beta
        x_dec=torch.matmul(phi_s,hat_beta.squeeze())
        
        loss_term_2=(x_dec-x)**2
        
        #regularisation term between N(z_mean, z_std) and N(0,1):
        
        loss_term_3=0.5*torch.sum(z_mean**2+z_std**2-1-torch.log(z_std**2), dim=1)
    
        #DOES THE KL-FACTOR COME FROM THE FACT THAT WE NEED TO LEARN THE BETAS?
        #WHY DO WE DIVIDE BY Z_DIM?
        loss_term_3 = kl_factor * (loss_term_3/z_dim)
        
        
        if print_breakdown:
            #if print_breakdown is true, we print the mean of each of the first 2 loss terms: 
            #(mean over the training funcs)
            #loss_term_1 and loss_term_2 both depend on beta and hat_beta and thus on the training function considered
            #loss_term_3 is only one term (does not depend on beta_i)
            print("1", torch.mean(loss_term_1))
            print("2", torch.mean(loss_term_2))
            print("3", loss_term_3)
            
            
        if return_breakdown == False:
            return torch.mean(loss_term_1 + loss_term_2) + loss_term_3
        else:
            return torch.mean(loss_term_1 + loss_term_2) + loss_term_3, \
                torch.mean(loss_term_1), torch.mean(loss_term_2), loss_term_3
        
    
    #this function computes the predicted x at the location s when the value of z is given
    def predicted_x(self, s, z, return_hat_beta=False):
        phi_s=self.Phi(s)
        hat_beta=self.decoder(z)
        x_dec=torch.matmul(phi_s, hat_beta)
        if not return_hat_beta:
            return x_dec
        else:
            return x_dec, hat_beta
        
    #this function draws a standard normal latent z from N(0,1), passes z through the decoder and reconstructs 
    #the functions valuations x's
    
    def draw_sample_from_piVAE(self, s, num_samples):
        phi_s=self.Phi(s)
        #draw z from the standard normal N(0,1)
        z_samples=self.normal_sampler.rsample((num_samples, z_dim)).double()
        hat_betas=self.decoder(z_samples)
        x_dec=torch.matmul(hat_betas.unsqueeze(1).unsqueeze(1), phi_s.unsqueeze(2).unsqueeze(0))
        x_dec=x_dec.squeeze()
        
        return x_dec


### Loss 1 is the MSE between the observed function valuations $x_i$'s and the $x_{d,i}^k=\beta_i^T\Phi(s_i^k):$

### loss1 $=\frac{1}{2}\sum_i \sum_k (\beta_i^T\Phi(s_i^k)-x_i)^2.$

### Loss 2 is the MSE between the observed function valuations $x_i$'s and the  decoded  $x_{d,i}^k=\hat\beta_i^T\Phi(s_i^k):$

### loss2 $=\frac{1}{2}\sum_i \sum_k (\hat\beta_i^T\Phi(s_i^k)-x_i)^2.$

### Loss 3 is the regularisation term. It is the KL divergence between N(z; z_mean, z_std) and N(z; 0,1). See appendic B of Kingma and Welling: 
### Regularisation term between a Gaussian and N(0,1): $-D_{KL}=\frac{1}{2}\sum_j (1+\log \sigma_j^2-\mu_j^2-\sigma_j^2).$

In [72]:
#Instantiate a model

model=Model().double()

In [73]:
model

Model(
  (phi_nn_1): Linear(in_features=100, out_features=10, bias=True)
  (phi_nn_2): Linear(in_features=10, out_features=100, bias=True)
  (encoder_nn_1): Linear(in_features=100, out_features=512, bias=True)
  (encoder_nn_2): Linear(in_features=512, out_features=512, bias=True)
  (encoder_nn_3): Linear(in_features=512, out_features=128, bias=True)
  (encoder_nn_4): Linear(in_features=128, out_features=32, bias=True)
  (decoder_nn_1): Linear(in_features=16, out_features=128, bias=True)
  (decoder_nn_2): Linear(in_features=128, out_features=128, bias=True)
  (decoder_nn_3): Linear(in_features=128, out_features=128, bias=True)
  (decoder_nn_4): Linear(in_features=128, out_features=100, bias=True)
)

## Define the optimiser

In [74]:
#learning rate
lr=1e-3
optimizer=torch.optim.Adam(model.parameters(), lr=lr)

# Training the model: we optimise the loss using back propagation

In [75]:
num_epochs=10
#what is the difference between the number_funcs_to_consider and number of training functions (N in the paper)
#this gives us the number of functions we consider at each epoh
num_funcs_to_consider=500
#maybe the functions to consider is simply a subset of the training functions

interval=2
current_max=300
kl_factor=0.01

In [76]:
#here we have 5 training functions f_i which we evaluate at K=20 locations
dataset_s.shape

(1000, 20, 1)

In [77]:
#these are the 20 locations for the first training function
#dataset_s[0]

In [78]:
#dataset_s[1]

In [79]:
for epoch_id in range(num_epochs):
    print("epoch number:", epoch_id)

    
    #create empty storages for the 3 loss terms
    loss_1=[]
    loss_2=[]
    loss_3=[]
    
    
    for fct_id in range(num_funcs_to_consider):
        #ensure the gradients are back to 0 after each loop for a new function
        optimizer.zero_grad()
        
        #define the inputs s and x
        
        #extract the locations s for the function that is being considered 
        s=dataset_s[fct_id]
        #transform into a tensor
        s=torch.tensor(s)
        
        #extract the function valuations for the function that is being considered and transform into a tensor
        x=dataset_x[fct_id]
        x=torch.tensor(x)
        
        loss, l1, l2, l3=model.get_loss(fct_id, s, x, kl_factor, return_breakdown=True)
        
        #compute the backward gradients of loss which is equal to 
        #loss=mean(l1+l2)+l3
        loss.backward()
        optimizer.step()
        loss_1.append(l1.detach())
        loss_2.append(l2.detach())
        loss_3.append(l3.detach())
        
        
        
    if epoch_id % interval == 0:
        num_funcs_to_consider = min(num_funcs_to_consider+1, current_max)
        
        
    print('loss_term_1:', np.mean(np.array(loss_1)),
             'loss_term_2:',  np.mean(np.array(loss_2)),
             'loss_term_3:', np.mean(np.array(loss_3)), 
             'number fcts:', num_funcs_to_consider)
             

       

epoch number: 0
loss_term_1: 6.441614452799898 loss_term_2: 5.0318177541430895 loss_term_3: 0.0013364252973675711 number fcts: 300
epoch number: 1
loss_term_1: 5.248699688167946 loss_term_2: 5.045365981293769 loss_term_3: 0.00046946131670622126 number fcts: 300
epoch number: 2
loss_term_1: 5.065456590735872 loss_term_2: 5.044956092495338 loss_term_3: 0.00030560516593664824 number fcts: 300
epoch number: 3
loss_term_1: 4.981373301192847 loss_term_2: 5.036549702634773 loss_term_3: 0.0010370670381505053 number fcts: 300
epoch number: 4
loss_term_1: 4.922282770278792 loss_term_2: 5.034013153239072 loss_term_3: 0.0029106047145529518 number fcts: 300
epoch number: 5
loss_term_1: 4.872865830564001 loss_term_2: 5.005597086579388 loss_term_3: 0.004775487188987786 number fcts: 300
epoch number: 6
loss_term_1: 4.828866316538846 loss_term_2: 4.957449632951987 loss_term_3: 0.008394692355714918 number fcts: 300
epoch number: 7
loss_term_1: 4.795759406223443 loss_term_2: 4.9102018491432435 loss_term_

In [80]:
# ---- Draw some samples from the pivae
num_samples=5
locations = torch.arange(-10, 10, 0.2).unsqueeze(1).double()
locations = locations.detach()

In [81]:
#save the locations in a txt file
np.savetxt('locations.txt', locations)

In [82]:
piVAE_samples=model.draw_sample_from_piVAE(locations, num_samples)

piVAE_samples = piVAE_samples.detach()


In [83]:
#save the piVAE samples in a text file
np.savetxt('piVAE_samples.txt', piVAE_samples)

In [None]:
#QUESTION: HOW DO WE KNOW WHEN TO UNSQUEEZE. DO WE NEED TENSORS OF SAME DIMENSIONS?
#get Phi to work outside the class as a stand-alone function
# Takes input (batch x dim_in) and gives Phi(input) (batch x dim_out)

phi_rbf_centers=nn.Parameter(torch.tensor(
            np.random.uniform(function_xlims[0], function_xlims[1],
            size=(num_phi_rbf,input_dim) )))
phi_nn_1=nn.Linear(num_phi_rbf, phi_hidden_layer_size)
phi_nn_2=nn.Linear(phi_hidden_layer_size, beta_dim)


def Phi(input):
        
        input_expand = torch.unsqueeze(input, 1)
        phi_expand = torch.unsqueeze(phi_rbf_centers, 0)
        M1 = input_expand - phi_expand
        M2 = torch.sum(M1 ** 2, 2)
        M3 = torch.exp(-M2/phi_rbf_sigma)
        M4 = F.sigmoid(phi_nn_1(M3))
        M5 = phi_nn_2(M4)
        return M5
        