# LOCA Interpolation Test

In [None]:
import pickle
import math
from scipy import spatial

import numpy as np
import matplotlib.pyplot as plt
import random

import torch.nn.functional as F
import torch 

import torchvision
from torch.autograd import grad
from torch import nn, optim

## Data generation and transformation

In [None]:
def deformation(x):
    """Apply the nonlinear deformation that the LOCA encoder has to invert.

    A point ``(x0, x1)`` is mapped to ``(x0 + x1**3, x1 - x0**3)``; the
    decoder has to reapply this map.

    Args:
        x: A single point (sequence or array whose first two entries are the
            coordinates) or an array of points of shape ``(..., 2)`` with the
            coordinates on the last axis.

    Returns:
        A float64 ``numpy.ndarray`` with the deformed coordinates: shape
        ``(2,)`` for a single point, ``x.shape[:-1] + (2,)`` for a batch.
    """
    x = np.asarray(x)
    # Indexing the last axis makes the same code work for one point and for
    # whole batches of points.
    y = np.empty(x.shape[:-1] + (2,))
    y[..., 0] = x[..., 0] + x[..., 1] ** 3
    y[..., 1] = x[..., 1] - x[..., 0] ** 3
    return y

In [None]:
def samples_test(size= 200, samples=20000, sig_sq=0.0001):
    """Draw Gaussian point clouds around anchors scattered over a square.

    Every anchor is sampled uniformly from [-0.025, 1.025] in both
    coordinates and surrounded by ``size`` points drawn from an isotropic
    normal distribution with variance ``sig_sq``. This data is meant to be
    used as test data.

    Returns:
        ``(points, center)`` with shapes ``(samples, size, 2)`` and
        ``(samples, 2)``.
    """
    covariance = [[sig_sq, 0], [0, sig_sq]]
    center = np.empty(shape=(samples, 2))
    points = np.empty(shape=(samples, size, 2))

    for idx in range(samples):
        # Anchor first, then its cloud: the random stream is consumed in the
        # same order for every cloud.
        center[idx] = np.random.uniform(low=-0.025, high=1.025, size=(2))
        points[idx] = np.random.multivariate_normal(center[idx], covariance, size=size)

    return points, center

In [None]:
def samples_train(size= 200, samples=2000, sig_sq=0.0001):
    """Draw Gaussian point clouds around anchors on a hollow square frame.

    The frame is the unit square without its interior: every anchor falls
    into one of four rectangular strips of width 0.1 that together tile the
    border of [0, 1] x [0, 1]. The strip is picked with ``random.choice``,
    the anchor uniformly inside the strip, and ``size`` points are drawn
    around it from an isotropic normal distribution with variance
    ``sig_sq``. This data is meant to be used as training data.

    Returns:
        ``(points, center)`` with shapes ``(samples, size, 2)`` and
        ``(samples, 2)``.
    """
    # (low corner, high corner) of each strip as [x, y]. Bounds are written
    # with low <= high because numpy documents uniform(low > high) as
    # undefined behaviour.
    strips = (
        ([0.0, 0.1], [0.1, 1.0]),  # left edge
        ([0.1, 0.9], [1.0, 1.0]),  # top edge
        ([0.9, 0.0], [1.0, 0.9]),  # right edge
        ([0.0, 0.0], [0.9, 0.1]),  # bottom edge
    )
    covariance = [[sig_sq, 0], [0, sig_sq]]

    center = np.empty(shape=(samples, 2))
    points = np.empty(shape=(samples, size, 2))

    for i in range(samples):
        low, high = random.choice(strips)
        center[i] = np.random.uniform(low=low, high=high, size=(2))
        points[i] = np.random.multivariate_normal(center[i], covariance, size=size)

    return points, center

In [None]:
# Generate the training clouds and show them together with their anchor points.
points, center = samples_train(size=200, samples=2000, sig_sq=0.0001)
# Flatten (clouds, points per cloud, 2) into one long list of 2-D points for plotting.
points_plt = points.reshape(-1, 2)

fig, ax = plt.subplots()
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Untransformed Data')

# Anchors in opaque green, cloud points in nearly transparent blue.
ax.scatter(center[:, 0], center[:, 1], c='g', alpha=1, s=4)
ax.scatter(points_plt[:, 0], points_plt[:, 1], c='b', alpha=0.006)

In [None]:
# Deform the clouds and their anchors and shift the first coordinate by -1;
# the result is the final training data (the model input).
sig_sq = 0.0001
size = 200

tran_sq = np.empty_like(points)
tran_cent = np.empty_like(center)

for i, (anchor, cloud) in enumerate(zip(center, points)):
    tran_cent[i] = deformation(anchor)
    tran_cent[i, 0] -= 1
    for j, point in enumerate(cloud):
        tran_sq[i, j] = deformation(point)
        tran_sq[i, j, 0] -= 1

# Flat (number of points, 2) copy of the transformed clouds.
tran_sq_plt = tran_sq.reshape(-1, 2)

fig, ax = plt.subplots()
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Model Input')
ax.scatter(tran_cent[:, 0], tran_cent[:, 1], c='g', alpha=1, s=4)
ax.scatter(tran_sq[:, :, 0], tran_sq[:, :, 1], c='b', alpha=0.006)

### Defining the whitening loss

In [None]:
def cov(m, rowvar=False):
    """Estimate the covariance matrix needed by the encoder's whitening loss.

    Args:
        m: 2-D tensor of observations. With ``rowvar=False`` (default) each
            row is one observation and each column a variable. A tensor with
            a single row is not transposed and is therefore read as one
            variable spread over the columns.
        rowvar: If true, each row is a variable and each column an
            observation.

    Returns:
        The covariance matrix normalised by ``N - 1`` (``N`` observations),
        with singleton dimensions squeezed away. The result lives on the
        same device as ``m``, so the function also works without a GPU.

    Raises:
        ZeroDivisionError: If only one observation is given.
    """
    if not rowvar and m.size(0) != 1:
        m = m.t()
    # From here on variables are rows and observations are columns.
    fact = 1.0 / (m.size(1) - 1)
    m_sub_mean = m - torch.mean(m, dim=1, keepdim=True)
    mt = m_sub_mean.t()  # if complex: mt = m.t().conj()
    cov_mat = (fact * m_sub_mean.matmul(mt)).squeeze()
    return cov_mat

In [None]:
def loss_whiten(outputs_en):
    """Compute the whitening loss of the encoder for one batch.

    The encoder output is split into ``batch_cloud_num`` clouds of 2-D
    points. For every cloud the covariance ``C`` is estimated and the squared
    norm of ``C / sig_sq - I`` is accumulated; the sum is divided by the
    number of clouds.

    Relies on the module-level ``batch_cloud_num`` and ``sig_sq``.

    Args:
        outputs_en: Encoder output that can be reshaped to
            ``(batch_cloud_num, -1, 2)``.

    Returns:
        The mean whitening loss over the clouds of the batch.
    """
    L_un = 0
    clouds = torch.reshape(outputs_en, (batch_cloud_num, -1, 2))
    # Iterate over all clouds of the batch.
    for d in range(batch_cloud_num):
        # Covariance of the d-th cloud.
        C = cov(clouds[d, ::, ::])
        # The identity is created on the device of C instead of a hard-coded
        # GPU, so the loss can also be evaluated on a CPU-only machine.
        identity = torch.eye(2, device=C.device)
        L = ((torch.norm(1/sig_sq*C - identity))**2)
        L_un = L_un + L

    # Normalise the summed loss by the number of clouds in the batch.
    L_w = L_un/batch_cloud_num
    return L_w

In [None]:
def loss_recon(outputs_de, batch_features):
    """Compute the reconstruction loss of the autoencoder for one batch.

    Both tensors are viewed as ``batch_cloud_num`` clouds of 2-D points. The
    squared norm of their difference is divided by
    ``batch_cloud_num * size`` (both module-level values).

    Args:
        outputs_de: Decoder output for the batch.
        batch_features: The original batch that was fed to the encoder.

    Returns:
        The normalised squared reconstruction error.
    """
    cloud_shape = (batch_cloud_num, -1, 2)
    residual = outputs_de.reshape(cloud_shape) - batch_features.reshape(cloud_shape)

    # Squared norm over all clouds, then normalised by the batch size.
    squared_error = torch.norm(residual)**2
    return squared_error/(batch_cloud_num*size)

### Model

In [None]:
# Split off a validation set, flatten the clouds into single points and create the data loaders.
batch_cloud_num = 200
# One batch holds batch_cloud_num complete clouds.
batch_size = batch_cloud_num * tran_sq.shape[1]
# The first 200 clouds are held out for validation, the rest is used for training.
valdation, tran_sq = tran_sq[:200, ::, ::], tran_sq[200:, ::, ::]

val_dataset = valdation.reshape(-1, valdation.shape[2])
train_dataset = tran_sq.reshape(-1, tran_sq.shape[2])

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# shuffle stays off so that the points of one cloud remain adjacent in a batch;
# the loss functions reshape each batch back into clouds.
loader_options = dict(batch_size=batch_size, shuffle=False, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)
test_loader = torch.utils.data.DataLoader(val_dataset, **loader_options)

input_shape = train_dataset.shape[1]

In [None]:
# The LOCA network: a fully connected autoencoder with tanh activations.
class AE(nn.Module):
    """Autoencoder with a 2-D embedding used as the LOCA model.

    Encoder: ``input_shape -> 50 -> 50 -> 2 -> 2 -> 2``.
    Decoder: ``2 -> 2 -> 50 -> 50 -> input_shape -> input_shape``.
    In both halves a tanh follows the first three linear layers; the last two
    layers are purely linear.

    Args:
        **kwargs: Must contain ``input_shape``, the number of input (and
            output) features.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Every layer is sized from the constructor argument; no module-level
        # variable is needed to build the model.
        input_shape = kwargs["input_shape"]

        self.encoder_input_layer = nn.Linear(
            in_features=input_shape, out_features=50)
        self.encoder_hidden_layer1 = nn.Linear(
            in_features=50, out_features=50)
        self.encoder_hidden_layer2 = nn.Linear(
            in_features=50, out_features=2)
        self.encoder_hidden_layer3 = nn.Linear(
            in_features=2, out_features=2)
        self.encoder_output_layer = nn.Linear(
            in_features=2, out_features=2)

        self.decoder_input_layer = nn.Linear(
            in_features=2, out_features=2)
        self.decoder_hidden_layer1 = nn.Linear(
            in_features=2, out_features=50)
        self.decoder_hidden_layer2 = nn.Linear(
            in_features=50, out_features=50)
        self.decoder_hidden_layer3 = nn.Linear(
            in_features=50, out_features=input_shape)
        self.decoder_output_layer = nn.Linear(
            in_features=input_shape, out_features=input_shape)

        self.tanh = nn.Tanh()

    def forward_encoder(self, features):
        """Map input features to the 2-D embedding."""
        a_e1 = self.tanh(self.encoder_input_layer(features))
        a_e2 = self.tanh(self.encoder_hidden_layer1(a_e1))
        a_e3 = self.tanh(self.encoder_hidden_layer2(a_e2))

        # The last two encoder layers are linear (no activation).
        n_e4 = self.encoder_hidden_layer3(a_e3)
        outputs_en = self.encoder_output_layer(n_e4)

        return outputs_en

    def forward_decoder(self, outputs_en):
        """Map a 2-D embedding back to the input space."""
        a_d1 = self.tanh(self.decoder_input_layer(outputs_en))
        a_d2 = self.tanh(self.decoder_hidden_layer1(a_d1))
        a_d3 = self.tanh(self.decoder_hidden_layer2(a_d2))

        # The last two decoder layers are linear (no activation).
        n_d4 = self.decoder_hidden_layer3(a_d3)
        reconstructed = self.decoder_output_layer(n_d4)

        return reconstructed

    def forward(self, features):
        """Encode and decode ``features``; makes ``model(x)`` usable."""
        return self.forward_decoder(self.forward_encoder(features))

In [None]:
# Xavier (Glorot) initialisation for the linear layers. It is preferred here
# over PyTorch's default scheme, which is tuned for ReLU activations, because
# the network uses tanh.
def init_weights(net):
    """Re-initialise the weights of ``net`` with Xavier-uniform values.

    Only ``torch.nn.Linear`` layers (including subclasses) are touched; any
    other module is left unchanged, which makes the function suitable for
    ``nn.Module.apply``.

    Args:
        net: The module to initialise.
    """
    if isinstance(net, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(net.weight)


# NOTE: the function must be applied to a model instance, e.g.
# ``model.apply(init_weights)``. Passing the AE class itself has no effect,
# because a class object is not a Linear layer.

In [None]:
# Create the necessary objects for the neural network.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a model from the `AE` autoencoder class and move it to the selected
# device, either GPU or CPU.
model_ae = AE(input_shape=input_shape).to(device)

# Apply the Xavier initialisation to every linear layer of the model.
# `apply` visits each submodule, so `init_weights` actually receives the
# Linear layers.
model_ae.apply(init_weights)

# Adam optimizer with learning rate 1e-3 over all model parameters.
optimizer_ae = optim.Adam(model_ae.parameters(), lr=1e-3)

In [None]:
# Training loop of the network. Even epochs train the encoder on the whitening
# loss; odd epochs train the decoder on the reconstruction loss while the
# encoder is frozen. Every 50 epochs the model is validated, the loss curves
# are plotted and a checkpoint is written.
epochs = 100

# Validation metrics collected for the loss plot.
epoch_list = []
val_whiten = []
val_recon = []
val_total = []


def _set_encoder_trainable(model, trainable):
    """Switch gradient tracking of all encoder parameters on or off."""
    encoder_layers = (
        model.encoder_input_layer,
        model.encoder_hidden_layer1,
        model.encoder_hidden_layer2,
        model.encoder_hidden_layer3,
        model.encoder_output_layer,
    )
    for layer in encoder_layers:
        for param in layer.parameters():
            param.requires_grad = trainable


for epoch in range(epochs):
    train_loss_en = 0
    train_loss_de = 0
    loss = 0
    if epoch % 2 == 0:
        for batch_features in train_loader:
            # Encoder training: load the batch to the active device.
            batch_features = batch_features.to(device)
            # set_to_none=True leaves parameters that take no part in this
            # loss without a gradient, so the optimizer skips them instead of
            # moving them by momentum alone.
            optimizer_ae.zero_grad(set_to_none=True)
            # Compute the embedding and its whitening loss.
            outputs_en = model_ae.forward_encoder(batch_features.float())
            train_loss_en = loss_whiten(outputs_en)

            # Compute the gradients and update the parameters.
            train_loss_en.backward()
            optimizer_ae.step()

            # Add the mini-batch training loss to the epoch loss.
            loss += train_loss_en.item()
    else:
        # Decoder training: the encoder is frozen for the whole epoch and
        # released again afterwards, even if a batch raises.
        _set_encoder_trainable(model_ae, False)
        try:
            for batch_features in train_loader:
                # Load the batch to the active device.
                batch_features = batch_features.to(device)
                optimizer_ae.zero_grad(set_to_none=True)

                # Compute the reconstruction and its loss.
                outputs_en = model_ae.forward_encoder(batch_features.float())
                outputs_de = model_ae.forward_decoder(outputs_en.float())
                train_loss_de = loss_recon(outputs_de.double(), batch_features.double())

                # Compute the gradients and update the parameters.
                train_loss_de.backward()
                optimizer_ae.step()

                # Add the mini-batch training loss to the epoch loss.
                loss += train_loss_de.item()
        finally:
            _set_encoder_trainable(model_ae, True)

    if (epoch % 50) == 0:
        # Validate the model's performance. The losses are summed over all
        # validation batches and averaged; no gradients are needed here.
        val_loss_en = 0
        val_loss_de = 0

        with torch.no_grad():
            for batch_val in test_loader:
                batch_val = batch_val.to(device)
                outputs_en_val = model_ae.forward_encoder(batch_val.float())
                outputs_de_val = model_ae.forward_decoder(outputs_en_val.float())
                val_loss_en += loss_whiten(outputs_en_val).item()
                val_loss_de += loss_recon(outputs_de_val.double(), batch_val.double()).item()

        val_loss_en = val_loss_en/len(test_loader)
        val_loss_de = val_loss_de/len(test_loader)

        val_whiten.append(val_loss_en)
        val_recon.append(val_loss_de)
        val_total.append(val_loss_en + val_loss_de)
        epoch_list.append(epoch)

        fig, ax = plt.subplots()
        ax.plot(epoch_list, val_whiten, 'g', label = "Whitening Loss")
        ax.plot(epoch_list, val_recon, 'r', label = "Reconstruction Loss")
        ax.set(xlabel='Epochs', ylabel='Loss',
              title='Loss over time')
        ax.legend()
        ax.grid()

        fig.savefig("test.png")
        plt.show()

        print("Validation: epoch : {}/{}, loss_whiten = {:.4f}, loss_recon = {:.4f}".format(epoch, epochs, val_loss_en, val_loss_de))

        torch.save({
            'epoch': epoch,
            'model_state_dict': model_ae.state_dict(),
            'optimizer_state_dict': optimizer_ae.state_dict(),
            'loss': loss
            }, "/content/drive/My Drive/Colab Notebooks/62_Epoch_"+str(epoch))

        if epoch >= 50:
            # Stop early when none of the (up to nine) preceding validation
            # results is worse than the current one.
            if all(i <= val_total[-1] for i in val_total[-10:-1]):

                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_ae.state_dict(),
                    'optimizer_state_dict': optimizer_ae.state_dict(),
                    'loss': loss
                    }, "/content/drive/My Drive/Colab Notebooks/62_best")
                break

    # Compute the epoch training loss.
    loss = loss / len(train_loader)

    # Display the epoch training loss. Only the loss of the branch that ran
    # in this epoch is non-zero; it is the loss of the last batch.
    print("epoch : {}/{}, loss = {:.4f}, loss_whiten = {:.4f}, loss_recon = {:.4f}".format(epoch + 1, epochs, loss, float(train_loss_en), float(train_loss_de)))

In [None]:
# Optional: store the model and optimizer state together with the last epoch and loss.
torch.save(
    dict(
        epoch=epoch,
        model_state_dict=model_ae.state_dict(),
        optimizer_state_dict=optimizer_ae.state_dict(),
        loss=loss,
    ),
    "/content/drive/My Drive/Colab Notebooks/6_2",
)

In [None]:
# Optional: restore model and optimizer from a stored checkpoint.
model_ae = AE(input_shape=input_shape).to(device)
optimizer_ae = optim.Adam(model_ae.parameters(), lr=1e-4)

# map_location moves the stored tensors to the active device, so a checkpoint
# written on a GPU can also be loaded on a CPU-only machine.
checkpoint = torch.load("/content/drive/My Drive/Colab Notebooks/Plane_Epoch_1300", map_location=device)
model_ae.load_state_dict(checkpoint['model_state_dict'])
optimizer_ae.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Evaluation mode for inference.
model_ae.eval()
# - or -
#model_ae.train()

In [None]:
# Pass the training data set through the model's encoder and decoder.
# All points are processed in one batched forward pass on the active device;
# no gradients are needed for inference.
with torch.no_grad():
    inputs = torch.from_numpy(train_dataset).float().to(device)
    outputs_en = model_ae.forward_encoder(inputs)
    outputs_de = model_ae.forward_decoder(outputs_en.float())
# Back to numpy with the dtype of the input data.
pred = outputs_de.cpu().numpy().astype(train_dataset.dtype)
# Restore the (clouds, points per cloud, 2) layout.
pred = np.reshape(pred, (tran_sq.shape[0], tran_sq.shape[1], tran_sq.shape[2]))

In [None]:
# Show the decoder's reconstruction of the training data set.
data_sterio_flatt = pred.reshape(-1, 2)
fig, ax = plt.subplots()
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Decoder Output')

ax.scatter(data_sterio_flatt[:, 0], data_sterio_flatt[:, 1], c='g', alpha=0.1)

In [None]:
# Pass the training data set through the model's encoder.
# All points are processed in one batched forward pass on the active device;
# no gradients are needed for inference.
with torch.no_grad():
    inputs = torch.from_numpy(train_dataset).float().to(device)
    outputs_en = model_ae.forward_encoder(inputs)
# Back to numpy with the dtype of the input data.
pred_encode = outputs_en.cpu().numpy().astype(train_dataset.dtype)

In [None]:
# Show the encoder output for the training data set.
fig, ax = plt.subplots()
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Encoder Output')
ax.scatter(pred_encode[:, 0], pred_encode[:, 1], c='b', alpha=0.1)

In [None]:
def distance_plot(samples = 300, cloud_size = 200):
    """Compare pairwise distances before and after encoding.

    The module-level ``train_dataset`` is passed through the encoder of
    ``model_ae``. Input points and encoded points are grouped into clouds of
    ``cloud_size`` points and each cloud is replaced by its mean. For
    ``samples`` random pairs of clouds the Euclidean distance between the
    means is computed in both spaces, to check how close the learned
    transformation is to an isometry.

    NOTE(review): the "real" distances are measured on ``train_dataset``,
    i.e. on the data that is fed to the model - confirm this is the intended
    reference space.

    Args:
        samples: Number of random cloud pairs to evaluate.
        cloud_size: Number of points per cloud in ``train_dataset``.

    Returns:
        ``(real_distance, pred_distance)``: two lists with ``samples``
        distances each, in matching order.
    """
    # One batched forward pass on the active device; no gradients needed.
    with torch.no_grad():
        inputs = torch.from_numpy(train_dataset).float().to(device)
        outputs_en = model_ae.forward_encoder(inputs)
    pred_encode = outputs_en.cpu().numpy().astype(train_dataset.dtype)

    pred_data = np.reshape(pred_encode, (-1, cloud_size, 2))
    real_data = np.reshape(train_dataset, (-1, cloud_size, 2))

    # Represent every cloud by the mean of its points.
    pred_data_av = np.average(pred_data, axis = 1)
    real_data_av = np.average(real_data, axis = 1)
    cloud_count = pred_data_av.shape[0]

    real_distance = []
    pred_distance = []

    for _ in range(samples):
        # Pick a random pair of clouds (the two may coincide).
        point_1 = np.random.randint(0, cloud_count)
        point_2 = np.random.randint(0, cloud_count)

        dist_pred = spatial.distance.euclidean(pred_data_av[point_1], pred_data_av[point_2])
        dist_real = spatial.distance.euclidean(real_data_av[point_1], real_data_av[point_2])

        real_distance.append(dist_real)
        pred_distance.append(dist_pred)

    return real_distance, pred_distance

In [None]:
# Plot the distance between cloud means in the input data (x-axis) against the
# distance between the same clouds in the encoder output (y-axis).
real_distance, pred_distance = distance_plot(samples = 2000)
fig = plt.figure()
ax = fig.add_subplot(111)
# The axis labels are set once, matching the order of the scatter arguments.
ax.set(xlabel='Dist. in X', ylabel='Dist. in ρ',
      title='Distance Conservation')
ax.scatter(real_distance, pred_distance, c='g', alpha=0.5)

### Interpolation Demonstration

In [None]:
def samples_grid(size= 200, samples=2000, sig_sq=0.01, gaussian = False):
    """Create point clouds along a grid of six straight lines.

    The anchors lie on the lines x = 0.1, y = 0.9, x = 0.9, y = 0.1, x = 0.6
    and y = 0.6, each running from 0.1 towards 0.9 along the other
    coordinate. The grid thus covers areas contained in the training set as
    well as areas that are not. ``int(samples / 6)`` anchors are placed on
    every line; if ``samples`` is not a multiple of six, the remaining
    anchors extend the first lines by one further step each, so that every
    returned row is initialised.

    Args:
        size: Number of points per cloud.
        samples: Total number of anchors (clouds).
        sig_sq: Half-width of the square around each anchor from which the
            cloud points are drawn uniformly.
        gaussian: Accepted for compatibility; currently not used.

    Returns:
        ``(points, center)`` with shapes ``(samples, size, 2)`` and
        ``(samples, 2)``.
    """
    center = np.empty(shape = (samples, 2))
    points = np.empty(shape = (samples, size, 2))
    if samples == 0:
        # Nothing to place; also avoids dividing by zero below.
        return points, center

    per_line = int(samples/6)
    step_len = 0.8/(samples/6)

    # Each line is described by (index of the fixed coordinate, its value);
    # the other coordinate runs along the line.
    lines = ((0, 0.1), (1, 0.9), (0, 0.9), (1, 0.1), (0, 0.6), (1, 0.6))
    running = 0.1 + step_len*np.arange(per_line)
    for k, (fixed_axis, fixed_value) in enumerate(lines):
        rows = slice(k*per_line, (k + 1)*per_line)
        center[rows, fixed_axis] = fixed_value
        center[rows, 1 - fixed_axis] = running

    # Anchors left over by the integer division: continue the first lines by
    # one more step instead of leaving these rows uninitialised.
    for k in range(samples - 6*per_line):
        fixed_axis, fixed_value = lines[k]
        row = 6*per_line + k
        center[row, fixed_axis] = fixed_value
        center[row, 1 - fixed_axis] = 0.1 + step_len*per_line

    for i in range(samples):
        # Uniform cloud in the square of half-width sig_sq around the anchor.
        points[i] = np.random.uniform(center[i] - sig_sq, center[i] + sig_sq, (size, 2))

    return points, center

In [None]:
# Generate and plot the grid that the model has to reconstruct through its encoder.
points, center = samples_grid(size=200, samples=2000, sig_sq=0.01, gaussian=False)
# Flatten (clouds, points per cloud, 2) into one long list of 2-D points for plotting.
points_plt = points.reshape(-1, 2)

fig, ax = plt.subplots()
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Interpolation Test Grid')
ax.scatter(center[:, 0], center[:, 1], c='g', alpha=1, s=4)
ax.scatter(points_plt[:, 0], points_plt[:, 1], c='b', alpha=0.006)

In [None]:
# Apply the same transformation as for the training set to the grid:
# deformation followed by a shift of the first coordinate by -1.
tran_sq = np.empty_like(points)
tran_cent = np.empty_like(center)

for i, (anchor, cloud) in enumerate(zip(center, points)):
    tran_cent[i] = deformation(anchor)
    tran_cent[i, 0] -= 1
    for j, point in enumerate(cloud):
        tran_sq[i, j] = deformation(point)
        tran_sq[i, j, 0] -= 1

# Flat (number of points, 2) copy of the transformed grid for plotting.
tran_sq_plt = tran_sq.reshape(-1, 2)

fig, ax = plt.subplots()
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Model Input')

ax.scatter(tran_sq_plt[:, 0], tran_sq_plt[:, 1], c='b', alpha=0.05)
#ax.scatter(tran_cent[::, 0], tran_cent[::, 1], c='r', alpha=0.2)
# From here on train_dataset holds the transformed grid points.
train_dataset = tran_sq.reshape(-1, tran_sq.shape[2])

In [None]:
# Feed the transformed grid through the encoder.
# All points are processed in one batched forward pass on the active device;
# no gradients are needed for inference.
with torch.no_grad():
    inputs = torch.from_numpy(train_dataset).float().to(device)
    outputs_en = model_ae.forward_encoder(inputs)
# Back to numpy with the dtype of the input data.
pred_encode = outputs_en.cpu().numpy().astype(train_dataset.dtype)

In [None]:
# Show the encoder output for the grid.
fig, ax = plt.subplots()
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Encoder Output')

ax.scatter(pred_encode[:, 0], pred_encode[:, 1], c='b', alpha=0.1)

In [None]:
# Feed the grid through the encoder and the decoder.
# All points are processed in one batched forward pass on the active device;
# no gradients are needed for inference.
with torch.no_grad():
    inputs = torch.from_numpy(train_dataset).float().to(device)
    outputs_en = model_ae.forward_encoder(inputs)
    outputs_de = model_ae.forward_decoder(outputs_en.float())
# Back to numpy with the dtype of the input data.
pred = outputs_de.cpu().numpy().astype(train_dataset.dtype)
# Restore the (clouds, points per cloud, 2) layout.
pred = np.reshape(pred, (tran_sq.shape[0], tran_sq.shape[1], tran_sq.shape[2]))

In [None]:
# Show the decoder's reconstruction of the grid.
fig, ax = plt.subplots()
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Decoder Output')
ax.scatter(pred[:, :, 0], pred[:, :, 1], c='g', alpha=0.1)