## Demo of 2D regression with an Attentive Neural Process (ANP) model

This notebook demonstrates how to use an Attentive Neural Process (ANP) to perform 2D regression on images from the MNIST dataset.

**Note:** Training this model takes a long time, so a GPU is recommended. To use one, uncomment the `.cuda()` calls in the code below.
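
As an alternative to toggling `.cuda()` by hand, a device-agnostic pattern can be sketched as follows. The `device` variable and the stand-in `torch.nn.Linear` module are illustrative assumptions and are not part of this notebook's code; the same `.to(device)` call would apply to the ANP model and to each batch tensor.

```python
import torch

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in module (hypothetical); a real model would be moved the same way.
layer = torch.nn.Linear(2, 1).to(device)

# Input tensors must live on the same device as the module that consumes them.
x = torch.rand(16, 400, 2).to(device)
out = layer(x)

print(tuple(out.shape), out.device.type)
```

With this pattern the rest of the code does not need to change when a GPU is unavailable.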

First, we need to import all necessary packages and modules for our task:

In [None]:
import os
import sys
import torch
from matplotlib import pyplot as plt
import mnist
import random
import numpy as np

# Provide access to modules in repo.
sys.path.insert(0, os.path.abspath('neural_process_models'))

from neural_process_models.anp import ANP_Model

Each data point (image) in the MNIST dataset is a 2D (28 x 28) array whose pixel values lie in the range 0-255.

Let us retrieve and prepare this dataset for training:

In [None]:
test_images = mnist.test_images()  # (10000 x 28 x 28)
test_images = (test_images / 255.0)  # normalize pixel values

data_size = len(test_images)
test_images = test_images.reshape(data_size, 28, 28, 1)  # add a trailing channel axis

Notice that we normalized the pixel values above. The normalized pixel values serve as the y-values of the context and target points, while each x-value is a pair of pixel indices (row, column) giving the location of the pixel with the corresponding y-value. The training loop below divides these indices by 27 so that the coordinates also lie in the range 0-1.
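
To make the (x, y) representation concrete, here is a minimal NumPy sketch that turns one image into a context set and a target set. It uses a random array as a stand-in for a normalized MNIST image, and the names `image`, `context_x`, `context_y`, `target_x`, and `target_y` are illustrative assumptions rather than names from the repository.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for one normalized MNIST image (28 x 28, values in [0, 1]).
image = rng.random((28, 28))

num_context = 400  # same value this notebook uses for training

# Choose distinct flat pixel positions, then recover integer (row, col) indices.
flat_idx = rng.choice(28 * 28, size=num_context, replace=False)
rows, cols = np.divmod(flat_idx, 28)

# x-values: pixel coordinates scaled to [0, 1]; y-values: pixel intensities.
context_x = np.stack([rows, cols], axis=-1) / 27.0  # (num_context, 2)
context_y = image[rows, cols][:, None]              # (num_context, 1)

# Targets: every pixel of the same image, in row-major order.
grid_r, grid_c = np.meshgrid(np.arange(28), np.arange(28), indexing="ij")
target_x = np.stack([grid_r.ravel(), grid_c.ravel()], axis=-1) / 27.0  # (784, 2)
target_y = image.reshape(-1, 1)                                         # (784, 1)

print(context_x.shape, context_y.shape, target_x.shape, target_y.shape)
```

Keeping the integer indices (`rows`, `cols`) separate from the scaled coordinates matters: the integers index into the image, while the scaled values are what the model sees as inputs.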

Let us initialize our model. In this notebook the ANP is the `ANP_Model` class imported from `neural_process_models.anp` in the first cell.

We will use the following parameters for our example model:
* x-dimension of 2 and y-dimension of 1 (as explained above)
* 4 hidden layers of size 256 for the encoders and the decoder
* a latent dimension of 256 for the encoders and the decoder
* dot attention for the self-attention process (note that the cell below passes `use_self_attention=False`)
* multihead attention for the cross-attention process
* a deterministic path for the encoder
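
For intuition about the attention settings listed above, here is a minimal NumPy sketch of single-head scaled dot-product attention used as cross-attention: each target location (query) receives a weighted average of per-context representations (values), with weights given by the similarity between target and context locations (keys). This is a generic illustration, not the code in `neural_process_models`; the function name `dot_product_attention` and all shapes are assumptions. Multihead attention applies several such heads to learned linear projections of the inputs and combines their outputs.

```python
import numpy as np

def dot_product_attention(queries, keys, values):
    """Single-head scaled dot-product attention (illustrative sketch).

    queries: (num_target, d_k), keys: (num_context, d_k),
    values: (num_context, d_v).
    Returns the attended values (num_target, d_v) and the weights
    (num_target, num_context).
    """
    d_k = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d_k)
    # Subtract the row-wise maximum before exponentiating for numerical stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over context
    return weights @ values, weights

rng = np.random.default_rng(0)
keys = rng.random((400, 2))      # context pixel coordinates
queries = rng.random((784, 2))   # target pixel coordinates
values = rng.normal(size=(400, 8))  # hypothetical per-context representations

attended, weights = dot_product_attention(queries, keys, values)
print(attended.shape, weights.shape)
```

Each row of `weights` sums to 1, so every target representation is a convex combination of the context representations.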

In [None]:
model = ANP_Model(x_dim=2,  # x_dim: pixel index (0-27 x 0-27)
                  y_dim=1,  # y_dim: normalized pixel value (0-1)
                  mlp_hidden_size_list=[256, 256, 256, 256],
                  latent_dim=256,
                  use_rnn=False,
                  use_self_attention=False,
                  use_deter_path=True)#.cuda()

Next, let's set up the optimizer and the training hyperparameters:

In [None]:
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10000
batch_size = 16
num_context = 400

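Now, let us train our model. Each epoch here is a single optimization step on one randomly sampled batch of images. At every epoch we print the loss and save a plot of the model's prediction for the first image in the batch.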

In [None]:
print("Training...")

for epoch in range(1, num_epochs + 1):
    print("step = " + str(epoch))

    model.train()

    plt.clf()
    optim.zero_grad()

    ctt_x, ctt_y, tgt_x, tgt_y = list(), list(), list(), list()

    sample_context_indices = random.sample(range(data_size), batch_size)

    for context_idx in sample_context_indices:
        pixel_indices = random.sample(range(784), num_context)

        c_x, c_y = list(), list()
        for pixel_idx in pixel_indices:
            # Integer (row, col) locates the pixel; the x-value is scaled to [0, 1].
            row, col = divmod(pixel_idx, 28)

            c_x.append([row / 27.0, col / 27.0])
            c_y.append([float(test_images[context_idx, row, col, 0])])

        ctt_x.append(c_x)
        ctt_y.append(c_y)

    # Targets are all 784 pixels of the same images the context points come from,
    # with coordinates scaled to [0, 1] exactly like the context x-values.
    for target_idx in sample_context_indices:
        t_x, t_y = list(), list()
        for row in range(28):
            for col in range(28):
                t_x.append([row / 27.0, col / 27.0])
                t_y.append([float(test_images[target_idx, row, col, 0])])

        tgt_x.append(t_x)
        tgt_y.append(t_y)

    ctt_x = torch.FloatTensor(ctt_x)#.cuda()
    ctt_y = torch.FloatTensor(ctt_y)#.cuda()
    tgt_x = torch.FloatTensor(tgt_x)#.cuda()
    tgt_y = torch.FloatTensor(tgt_y)#.cuda()


    # ctt_x: (batch_size x num_context x 2), ctt_y: (batch_size x num_context x 1)
    # tgt_x: (batch_size x 784 x 2), tgt_y: (batch_size x 784 x 1)
    mu, sigma, log_p, kl, loss = model(ctt_x, ctt_y, tgt_x, tgt_y)

    # print('kl =', kl)
    print('loss =', loss.item())
    # print('mu.size() =', mu.size())
    # print('sigma.size() =', sigma.size())

    # tgt_x_np = tgt_x[0, :, :].squeeze(-1).numpy()
    # print('tgt_x_np.shape =', tgt_x_np.shape)

    loss.backward()
    optim.step()

    model.eval()
    plt.ion()
    # fig = plt.figure()
    
    # Visualize first target image.
    pred_y = mu[0].view(28, 28).detach().cpu().numpy()  # .cpu() is needed when training on a GPU

    plt.axis('off')
    #plt.imshow(torch.sigmoid(tgt_y).squeeze(0).view(-1, 28).detach().numpy())
    plt.imshow(pred_y)

    title_str = 'Training at epoch ' + str(epoch)
    plt.title(title_str)
    plt.savefig(title_str + ".png")
    plt.pause(0.1)

plt.ioff()
plt.show()
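
The model returns `loss` directly, so its exact definition lives in the repository code. As a rough guide to what a neural-process loss usually contains, the NumPy sketch below combines a Gaussian log-likelihood of the targets with a KL term between two diagonal Gaussians. The function names `gaussian_log_prob`, `kl_diag_gaussians`, and `np_style_loss` are hypothetical, and dividing the KL term by the number of targets is one common convention that the repository may or may not follow.

```python
import numpy as np

def gaussian_log_prob(y, mu, sigma):
    """Elementwise log-density of y under N(mu, sigma^2)."""
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2

def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """KL(q || p) for diagonal Gaussians, summed over the latent dimension."""
    return np.sum(
        np.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
        - 0.5,
        axis=-1,
    )

def np_style_loss(y, mu, sigma, kl):
    """Negative ELBO-style objective (assumed form, not the repository's code).

    y, mu, sigma: (batch, num_target, y_dim); kl: (batch,).
    """
    log_p = gaussian_log_prob(y, mu, sigma).sum(axis=-1)  # (batch, num_target)
    num_target = y.shape[1]
    return -np.mean(log_p - kl[:, None] / num_target)

# Toy check: perfect predictions with unit variance and zero KL.
y = np.zeros((2, 784, 1))
loss_value = np_style_loss(y, mu=y, sigma=np.ones_like(y), kl=np.zeros(2))
print(loss_value)

# KL between identical Gaussians is zero.
same = kl_diag_gaussians(np.zeros(4), np.ones(4), np.zeros(4), np.ones(4))
print(same)
```

Under this form, a falling loss reflects both sharper, more accurate pixel predictions and a posterior over the latent variable that stays close to the prior.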