In [None]:
from pathlib import Path

import torch
from omegaconf import OmegaConf

In [None]:

from neural_processes.model import  Encoder, LatentDecoder, Aggregator, LatentNP
from neural_processes.model.attention import Attention, MLP
from dataset import GPCurvesReader
from neural_processes.train import train

In [None]:
# Reload edited modules (e.g. the local neural_processes package) automatically
# before each cell is executed; mode 2 reloads all modules.
%reload_ext autoreload
%autoreload 2

In [None]:
# Toy tensors used only for the attention shape check in the next cell.
# Seed the global torch RNG so the toy data, and any torch random ops that follow
# when the notebook is run top to bottom (e.g. parameter init), are reproducible.
SEED = 0
torch.manual_seed(SEED)

# NOTE: batch_size / x_size / y_size are reassigned from the YAML config further
# down; re-running this cell after the config cell changes what later cells see.
batch_size, num_samples, x_size, y_size = 100, 10, 1, 1
num_targets = 5
context_X = torch.randn(batch_size, num_samples, x_size)
context_y = torch.randn(batch_size, num_samples, y_size)
target_X = torch.randn(batch_size, num_targets, x_size)

In [None]:
# Smoke-test the attention module on the toy tensors. First, the context set
# attends over itself; then the targets attend over the resulting context outputs.
attn = Attention()
r = attn(key=context_X, value=context_y, query=context_X)
print(r.shape)
output = attn(key=context_X, value=r, query=target_X)
print(output.shape)


# Neural Processes
The latent variable $z$ models functional uncertainty: $z \sim \mathcal{N}(\mu(r), \sigma(r) I)$, where $\mu(r)$ and $\sigma(r)$ are computed from the aggregated context representation $r$.

In [None]:
# Load the latent-NP hyper-parameters from YAML.
with open("config/np_config.yaml") as file:
    cfg = OmegaConf.load(file)

# Per-module MLP settings: (num_layers, num_units, activation_cls).
encoder_num_layers, encoder_num_units, encoder_activation_cls = (
    cfg.encoder.num_layers,
    cfg.encoder.num_units,
    cfg.encoder.activation_cls,
)
decoder_num_layers, decoder_num_units, decoder_activation_cls = (
    cfg.decoder.num_layers,
    cfg.decoder.num_units,
    cfg.decoder.activation_cls,
)
aggregator_num_layers, aggregator_num_units, aggregator_activation_cls = (
    cfg.aggregator.num_layers,
    cfg.aggregator.num_units,
    cfg.aggregator.activation_cls,
)

# Representation size and dataset settings.
# NOTE: batch_size / x_size / y_size overwrite the toy values defined earlier.
r_dim = cfg.r_dim
y_size, x_size = cfg.dataset.y_size, cfg.dataset.x_size
max_num_context = cfg.dataset.max_num_context
batch_size = cfg.dataset.batch_size

In [None]:
# 1d regression dataset, sampled from a GP
# 1-D regression curves sampled from a Gaussian process. The two readers share
# batch_size and max_num_context; only the test reader passes testing=True.
data_train = GPCurvesReader(max_num_context=max_num_context, batch_size=batch_size)
data_test = GPCurvesReader(
    testing=True,
    batch_size=batch_size,
    max_num_context=max_num_context,
)

In [None]:
# Build each NP component on its own (presumably as a construction sanity check).
# NOTE(review): these instances are not passed to LatentNP below, which receives
# only hyper-parameters; drop this cell if the check is no longer needed.
latent_encoder = Encoder(
    x_size=x_size,
    y_size=y_size,
    r_dim=r_dim,
    num_layers=encoder_num_layers,
    num_units=encoder_num_units,
    activation_cls=encoder_activation_cls,
)
latent_decoder = LatentDecoder(
    x_size=x_size,
    y_size=y_size,
    r_dim=r_dim,
    num_layers=decoder_num_layers,
    num_units=decoder_num_units,
    activation_cls=decoder_activation_cls,
)
aggregator = Aggregator(
    r_dim=r_dim,
    num_layers=aggregator_num_layers,
    num_units=aggregator_num_units,
    activation_cls=aggregator_activation_cls,
)

In [None]:
# Full latent NP, configured from the hyper-parameters loaded from the config.
# NOTE(review): the `agggreagtor_*` keyword spelling presumably mirrors LatentNP's
# own parameter names; rename it in LatentNP and here together.
latent_np = LatentNP(
    # encoder
    encoder_num_layers=encoder_num_layers,
    encoder_num_units=encoder_num_units,
    encoder_activation_cls=encoder_activation_cls,
    # decoder
    decoder_num_layers=decoder_num_layers,
    decoder_num_units=decoder_num_units,
    decoder_activation_cls=decoder_activation_cls,
    # aggregator
    agggreagtor_num_layers=aggregator_num_layers,
    agggreagtor_num_units=aggregator_num_units,
    agggreagtor_activation_cls=aggregator_activation_cls,
    # shared dimensions
    r_dim=r_dim,
    x_size=x_size,
    y_size=y_size,
)
latent_np

In [None]:
# Train with the settings in `cfg`; the returned model replaces latent_np.
# NOTE(review): re-running this cell feeds the already-returned model back into
# train(); re-run the LatentNP cell first to start from a fresh model.
latent_np = train(
    config=cfg,
    model=latent_np,
    data_train=data_train,
    data_test=data_test,
)

In [None]:
# Persist the trained model. torch.save does not create missing parent
# directories, so make sure the output folder exists first.
# NOTE(review): this pickles the whole model object, so loading it needs the
# neural_processes classes importable (and weights_only=False on recent torch).
# Saving latent_np.state_dict() would be the more portable option.
MODEL_PATH = Path("neural_processes/trained_model/np_elbo.pt")
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(latent_np, MODEL_PATH)