In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions import Normal, kl_divergence
import numpy as np
import matplotlib.pyplot as plt
import collections

## Data Generation

In [3]:
class GPCurvesReader(object):
    """Generates curves using a Gaussian Process (GP).

    Supports vector inputs (x) and vector outputs (y). The kernel is a
    squared-exponential kernel on the l2 distance between x-values, with a
    length-scale ``l1_scale`` and an amplitude ``sigma_scale`` that are fixed
    (not sampled) in this implementation. Outputs are independent Gaussian
    processes.
    """

    def __init__(self,
                 batch_size,
                 max_num_context,
                 x_size=1,
                 y_size=1,
                 l1_scale=0.4,
                 sigma_scale=1.0,
                 testing=False):
        """Creates a regression dataset of functions sampled from a GP.

        Args:
          batch_size: An integer.
          max_num_context: The max number of observations in the context
              (exclusive upper bound of the sampled context size).
          x_size: Integer >= 1 for length of "x values" vector.
          y_size: Integer >= 1 for length of "y values" vector.
          l1_scale: Float; length-scale of the kernel distance function.
          sigma_scale: Float; amplitude (std) of the GP.
          testing: Boolean that indicates whether we are testing. If so there are
              more targets for visualization.
        """
        self._batch_size = batch_size
        self._max_num_context = max_num_context
        self._x_size = x_size
        self._y_size = y_size
        self._l1_scale = l1_scale
        self._sigma_scale = sigma_scale
        self._testing = testing

    def _gaussian_kernel(self, xdata, l1, sigma_f, sigma_noise=2e-2):
        """Applies the Gaussian kernel to generate curve data.

        Args:
          xdata: Tensor with shape `[batch_size, num_total_points, x_size]` with
              the values of the x-axis data.
          l1: Tensor with shape `[batch_size, y_size, x_size]`, the scale
              parameter of the Gaussian kernel.
          sigma_f: Float tensor with shape `[batch_size, y_size]`; the magnitude
              of the std.
          sigma_noise: Float, std of the noise that we add for stability.

        Returns:
          The kernel, a float tensor with shape
          `[batch_size, y_size, num_total_points, num_total_points]`.
        """
        num_total_points = xdata.shape[1]

        # Pairwise differences: [B, num_total_points, num_total_points, x_size]
        diff = xdata.unsqueeze(1) - xdata.unsqueeze(2)

        # Scaled squared distance per output dimension:
        # [B, y_size, num_total_points, num_total_points, x_size]
        norm = torch.square(diff[:, None, :, :, :] / l1[:, :, None, None, :])
        # [B, y_size, num_total_points, num_total_points]
        norm = norm.sum(dim=-1)

        # [B, y_size, num_total_points, num_total_points]
        kernel = torch.square(sigma_f)[:, :, None, None] * torch.exp(-0.5 * norm)

        # Add some noise to the diagonal so the Cholesky factorisation succeeds.
        kernel = kernel + (sigma_noise ** 2) * torch.eye(
            num_total_points, dtype=kernel.dtype, device=kernel.device)

        return kernel

    def generate_curves(self):
        """Samples one batch of curves and splits it into context and target sets.

        Generated functions are `float32` with x values in [-2, 2).

        Returns:
          A tuple `(context_x, context_y, target_x, target_y)` with shapes
          `[B, num_context, x_size]`, `[B, num_context, y_size]`,
          `[B, num_total_points, x_size]` and `[B, num_total_points, y_size]`.
          During training the first `num_context` target points are the context
          points; during testing the targets are 400 evenly spaced points
          (x_size is effectively 1) and the context is a random subset of them.
        """
        # Plain Python ints: these are used as tensor sizes and slice bounds.
        num_context = int(torch.randint(low=3, high=self._max_num_context, size=()).item())

        # If we are testing we want to have more targets and have them evenly
        # distributed in order to plot the function.
        if self._testing:
            num_target = 400
            num_total_points = num_target
            x_values = torch.arange(-2., 2., 1. / 100).unsqueeze(0).repeat(self._batch_size, 1)
            x_values = x_values.unsqueeze(-1)  # [B, 400, 1]
        # During training the number of target points and their x-positions are
        # selected at random.
        else:
            num_target = int(torch.randint(low=2, high=self._max_num_context, size=()).item())
            num_total_points = num_context + num_target
            x_values = torch.empty(
                self._batch_size, num_total_points, self._x_size).uniform_(-2, 2)

        # Kernel parameters (fixed, not sampled).
        l1 = torch.ones([self._batch_size, self._y_size, self._x_size]) * self._l1_scale
        sigma_f = torch.ones([self._batch_size, self._y_size]) * self._sigma_scale

        # [batch_size, y_size, num_total_points, num_total_points]
        kernel = self._gaussian_kernel(x_values, l1, sigma_f)

        # Cholesky in double precision for numerical stability.
        cholesky = torch.linalg.cholesky(kernel.double()).float()

        # Sample a curve: [batch_size, y_size, num_total_points, 1]
        noise = torch.randn(self._batch_size, self._y_size, num_total_points, 1)
        y_values = torch.matmul(cholesky, noise)

        # [batch_size, num_total_points, y_size]
        y_values = y_values.squeeze(3).permute(0, 2, 1)

        if self._testing:
            # All evenly spaced points are targets.
            target_x = x_values
            target_y = y_values

            # The context is a random subset of the target points.
            idx = torch.randperm(num_target)[:num_context]
            context_x = x_values[:, idx, :]
            context_y = y_values[:, idx, :]
        else:
            # Targets consist of the context points followed by new points.
            target_x = x_values[:, :num_target + num_context, :]
            target_y = y_values[:, :num_target + num_context, :]

            # The context is the leading num_context points.
            context_x = x_values[:, :num_context, :]
            context_y = y_values[:, :num_context, :]

        return context_x, context_y, target_x, target_y

## Utility Methods

In [4]:
def forward_pass(x, linears):
    """Applies an MLP independently to every point of a batch of point sets.

    Args:
      x: Tensor of shape [B, num_points, d_in] (any leading dimensions work,
          since nn.Linear acts on the last axis).
      linears: Sequence (e.g. nn.ModuleList) of nn.Linear layers. ReLU is
          applied after every layer except the last. An empty sequence is
          treated as the identity.

    Returns:
      Tensor of shape [B, num_points, d_out], where d_out is the output size of
      the last layer.
    """
    if len(linears) == 0:
        return x
    for layer in linears[:-1]:
        x = F.relu(layer(x))
    # Last layer without a ReLU so the output is unconstrained.
    return linears[-1](x)
    

## Encoder: Deterministic Path

In [5]:
class DeterministicEncoder(nn.Module):
    """The Deterministic Encoder."""

    def __init__(self, layer_sizes, attention):
        """(A)NP deterministic encoder.

        Args:
          layer_sizes: Sequence of layer widths for the encoding MLP, starting
              with the input size (d_x + d_y); each consecutive pair defines one
              nn.Linear layer. The last entry is the size of the representation.
          attention: The attention module, called as
              `attention(context_x, target_x, hidden)`.
        """
        super(DeterministicEncoder, self).__init__()
        self.Linears = nn.ModuleList([
            nn.Linear(s1, s2) for s1, s2 in zip(layer_sizes[:-1], layer_sizes[1:])
        ])
        self.attention = attention

    def forward(self, context_x, context_y, target_x):
        """Encodes the inputs into one representation per target point.

        Args:
          context_x: Tensor of shape [B,observations,d_x]. For this 1D regression
              task this corresponds to the x-values.
          context_y: Tensor of shape [B,observations,d_y]. For this 1D regression
              task this corresponds to the y-values.
          target_x: Tensor of shape [B,target_observations,d_x].
              For this 1D regression task this corresponds to the x-values.

        Returns:
          The encoded representation. Tensor of shape [B,target_observations,d]
        """
        # Concatenate x and y along the feature axis.
        x = torch.cat([context_x, context_y], dim=-1)
        x = forward_pass(x, self.Linears)
        # Aggregate the per-context encodings for each target location.
        return self.attention(context_x, target_x, x)

## Encoder: Latent Path

In [None]:
class LatentEncoder(nn.Module):
    """The Latent Encoder."""

    def __init__(self, layer_sizes, num_latent):
        """(A)NP latent encoder.

        Args:
          layer_sizes: Sequence of layer widths for the encoding MLP, starting
              with the input size (d_x + d_y); each consecutive pair defines one
              nn.Linear layer. The last entry is the MLP output width.
          num_latent: The latent dimensionality.
        """
        super(LatentEncoder, self).__init__()
        self.Linears = nn.ModuleList([
            nn.Linear(s1, s2)
            for s1, s2 in zip(layer_sizes[:-1], layer_sizes[1:])
        ])
        # Intermediate width halfway between the MLP output and the latent
        # size; integer division because nn.Linear requires int sizes.
        hidden_size = (layer_sizes[-1] + num_latent) // 2
        self.penultimate_layer = nn.Linear(layer_sizes[-1], hidden_size)
        self.mu_layer = nn.Linear(hidden_size, num_latent)
        self.std_layer = nn.Linear(hidden_size, num_latent)

    def forward(self, tx, ty):
        """Encodes the inputs into a distribution over the latent variable.

        Args:
          tx: Tensor of shape [B,observations,d_x]. For this 1D regression
              task this corresponds to the x-values.
          ty: Tensor of shape [B,observations,d_y]. For this 1D regression
              task this corresponds to the y-values.

        Returns:
          A Normal distribution with batch shape [B, num_latents].
        """
        # Concatenate x and y along the feature axis.
        x = torch.cat([tx, ty], dim=-1)
        x = forward_pass(x, self.Linears)

        # Aggregator: take the mean over all points -> [B, d].
        x = x.mean(dim=1)

        # Intermediate ReLU layer.
        x = F.relu(self.penultimate_layer(x))

        # Separate heads for the mean and the (pre-activation) scale.
        mu = self.mu_layer(x)
        log_sigma = self.std_layer(x)

        # Bound sigma to (0.1, 1.0).
        sigma = 0.1 + 0.9 * torch.sigmoid(log_sigma)

        return Normal(loc=mu, scale=sigma)

## Decoder

In [None]:
class Decoder(nn.Module):
    """The Decoder."""

    def __init__(self, layer_sizes):
        """(A)NP decoder.

        Args:
          layer_sizes: Sequence of layer widths for the decoder MLP, starting
              with the input size (representation size + d_x). The last entry
              must be 2 * d_y (mean and scale concatenated).
        """
        super(Decoder, self).__init__()
        self.Linears = nn.ModuleList([
            nn.Linear(s1, s2)
            for s1, s2 in zip(layer_sizes[:-1], layer_sizes[1:])
        ])

    def forward(self, representation, target_x):
        """Decodes the individual targets.

        Args:
          representation: The representation of the context for target predictions.
              Tensor of shape [B,target_observations,?].
          target_x: The x locations for the target query.
              Tensor of shape [B,target_observations,d_x].

        Returns:
          dist: A Gaussian over the target points, independent over the last
              axis; log_prob returns shape [B,target_observations].
          mu: The mean of the Gaussian. Tensor of shape [B,target_observations,d_y].
          sigma: The standard deviation of the Gaussian.
              Tensor of shape [B,target_observations,d_y].

        Raises:
          ValueError: if the MLP output size is not even (cannot be split into
              mean and scale halves).
        """
        # Concatenate representation and target_x along the feature axis.
        x = torch.cat([representation, target_x], dim=-1)

        x = forward_pass(x, self.Linears)

        if x.shape[-1] % 2:
            raise ValueError(
                "Decoder output size must be 2 * d_y, got %d" % x.shape[-1])
        # First half is the mean, second half parameterises the scale.
        mu, log_sigma = torch.chunk(x, 2, dim=-1)

        # Bound the scale from below for stability.
        sigma = 0.1 + 0.9 * F.softplus(log_sigma)

        # Treat the last (d_y) axis as part of the event.
        dist = torch.distributions.Independent(Normal(loc=mu, scale=sigma), 1)

        return dist, mu, sigma

## ANP Model

In [6]:
class LatentModel(nn.Module):
    """The (A)NP model."""
    def __init__(self,
                 latent_encoder_layer_sizes,
                 num_latents,
                 decoder_layer_sizes,
                 use_deterministic_path=True,
                 deterministic_encoder_layer_sizes=None,
                 attention=None):
        """Initialises the model.

        Args:
          latent_encoder_layer_sizes: Layer widths of the latent encoder MLP,
              starting with its input size (d_x + d_y).
          num_latents: The latent dimensionality.
          decoder_layer_sizes: Layer widths of the decoder MLP, starting with its
              input size. The last element should correspond to d_y * 2
              (it encodes both mean and variance concatenated).
          use_deterministic_path: a boolean that indicates whether the deterministic
              encoder is used or not.
          deterministic_encoder_layer_sizes: Layer widths of the deterministic
              encoder MLP, starting with its input size. The last one is the
              size of the deterministic representation r.
          attention: The attention module used in the deterministic encoder.
              Only relevant when use_deterministic_path=True.
        """
        super(LatentModel, self).__init__()
        self._latent_encoder = LatentEncoder(latent_encoder_layer_sizes, num_latents)
        self._decoder = Decoder(decoder_layer_sizes)
        self._use_deterministic_path = use_deterministic_path
        if use_deterministic_path:
            self._deterministic_encoder = DeterministicEncoder(
                deterministic_encoder_layer_sizes, attention)

    def forward(self, context_x, context_y, target_x, target_y=None):
        """Predicts the target outputs and, if target_y is given, the loss.

        Args:
          context_x: Tensor of shape [B,num_context,d_x].
          context_y: Tensor of shape [B,num_context,d_y].
          target_x: Tensor of shape [B,num_targets,d_x].
          target_y: Optional tensor of shape [B,num_targets,d_y]; None at test time.

        Returns:
          (mu, sigma, log_p, kl, loss). mu/sigma have shape [B,num_targets,d_y];
          log_p and kl have shape [B,num_targets]; loss is a scalar. The last
          three are None when target_y is None.
        """
        num_targets = target_x.shape[1]

        prior = self._latent_encoder(context_x, context_y)

        if target_y is None:
            # Testing: targets are unavailable, so sample from the prior
            # (encoded from the contexts).
            posterior = None
            latent_rep = prior.sample()
        else:
            # Training: targets contain the contexts by design, so encode them
            # for the posterior. rsample keeps the sample differentiable
            # (reparameterisation), so the reconstruction term trains the encoder.
            posterior = self._latent_encoder(target_x, target_y)
            latent_rep = posterior.rsample()

        # [B, num_latents] -> [B, num_targets, num_latents]
        latent_rep = latent_rep.unsqueeze(1).repeat(1, num_targets, 1)

        if self._use_deterministic_path:
            deterministic_rep = self._deterministic_encoder(context_x, context_y, target_x)
            representation = torch.cat([deterministic_rep, latent_rep], dim=-1)
        else:
            representation = latent_rep

        dist, mu, sigma = self._decoder(representation, target_x)

        # If we want to calculate the log_prob for training we will make use of the
        # target_y. At test time the target_y is not available so we return None.
        if target_y is not None:
            log_p = dist.log_prob(target_y)  # [B, num_targets]
            kl = kl_divergence(posterior, prior).sum(dim=-1, keepdim=True)  # [B, 1]
            kl = kl.repeat(1, num_targets)  # [B, num_targets]
            loss = -torch.mean(log_p - kl / num_targets)
        else:
            log_p = None
            kl = None
            loss = None

        return mu, sigma, log_p, kl, loss

### Cross-Attention Module

In [None]:
def uniform_attention(q, v):
    """Uniform attention: every query receives the mean of all values.

    This reproduces the plain Neural Process (mean) aggregation.

    Args:
      q: queries, tensor of shape [B,m,d_k]; only its second dimension (m)
          is used.
      v: values, tensor of shape [B,n,d_v].

    Returns:
      Tensor of shape [B,m,d_v].
    """
    num_queries = q.shape[1]
    mean_value = v.mean(dim=1, keepdim=True)  # [B,1,d_v]
    return mean_value.repeat(1, num_queries, 1)


def laplace_attention(q, k, v, scale, normalise):
    """Computes laplace exponential attention.

    Args:
      q: queries. tensor of shape [B,m,d_k].
      k: keys. tensor of shape [B,n,d_k].
      v: values. tensor of shape [B,n,d_v].
      scale: float that scales the L1 distance.
      normalise: Boolean that determines whether weights sum to 1 over the
          n context points (softmax) or are mapped to [0, 2) via 1 + tanh.

    Returns:
      tensor of shape [B,m,d_v].
    """
    k = k.unsqueeze(1)  # [B,1,n,d_k]
    q = q.unsqueeze(2)  # [B,m,1,d_k]
    unnorm_weights = -torch.abs((k - q) / scale)  # [B,m,n,d_k]
    unnorm_weights = unnorm_weights.sum(dim=-1)  # [B,m,n]
    if normalise:
        # Normalise over the context points (last axis), not the batch axis.
        weights = torch.softmax(unnorm_weights, dim=-1)
    else:
        weights = 1 + torch.tanh(unnorm_weights)
    return torch.einsum('bmn,bnd->bmd', weights, v)  # [B,m,d_v]


def dot_product_attention(q, k, v, normalise):
    """Computes scaled dot product attention.

    Args:
      q: queries. tensor of  shape [B,m,d_k].
      k: keys. tensor of shape [B,n,d_k].
      v: values. tensor of shape [B,n,d_v].
      normalise: Boolean that determines whether weights sum to 1 over the
          n context points (softmax) or are squashed independently (sigmoid).

    Returns:
      tensor of shape [B,m,d_v].
    """
    d_k = q.shape[-1]
    # d_k is a Python int, so scale with plain arithmetic.
    scale = d_k ** 0.5
    unnorm_weights = torch.einsum('bjk,bik->bij', k, q) / scale  # [B,m,n]
    if normalise:
        # Normalise over the context points (last axis), not the batch axis.
        weights = torch.softmax(unnorm_weights, dim=-1)
    else:
        weights = torch.sigmoid(unnorm_weights)
    return torch.einsum('bik,bkj->bij', weights, v)  # [B,m,d_v]


def create_multihead_network(q_last_shape, k_last_shape, v_last_shape, num_heads=8):
    """Builds the per-head projection layers used by `multihead_attention`.

    Each head owns four kernel-size-1 Conv1d layers without bias, i.e.
    per-position linear maps over the channel axis:
    `[q_conv: d_q -> head_size, k_conv: d_k -> head_size,
      v_conv: d_v -> head_size, rep_conv: head_size -> d_v]`,
    where head_size = d_v // num_heads. The q/k/v projections are initialised
    with std q_last_shape**-0.5 and the output projection with std
    v_last_shape**-0.5.

    Args:
      q_last_shape: channel size of the queries.
      k_last_shape: channel size of the keys.
      v_last_shape: channel size of the values (d_v).
      num_heads: number of heads; must be positive and divide v_last_shape.

    Returns:
      nn.ModuleList of `num_heads` nn.ModuleLists, each
      `[q_conv, k_conv, v_conv, rep_conv]`. Being modules, their parameters are
      registered when assigned to an nn.Module attribute.

    Raises:
      ValueError: if num_heads is not positive or does not divide v_last_shape.
    """
    if num_heads <= 0 or v_last_shape % num_heads:
        raise ValueError(
            "num_heads (%d) must be positive and divide v_last_shape (%d)"
            % (num_heads, v_last_shape))
    head_size = v_last_shape // num_heads
    key_std = q_last_shape ** -0.5
    value_std = v_last_shape ** -0.5

    heads = nn.ModuleList()
    for _ in range(num_heads):
        q_conv = nn.Conv1d(q_last_shape, head_size, 1, bias=False, padding=0)
        k_conv = nn.Conv1d(k_last_shape, head_size, 1, bias=False, padding=0)
        v_conv = nn.Conv1d(v_last_shape, head_size, 1, bias=False, padding=0)
        # The output projection consumes a single head's output (head_size).
        rep_conv = nn.Conv1d(head_size, v_last_shape, 1, bias=False, padding=0)
        for conv in (q_conv, k_conv, v_conv):
            nn.init.normal_(conv.weight, std=key_std)
        nn.init.normal_(rep_conv.weight, std=value_std)
        heads.append(nn.ModuleList([q_conv, k_conv, v_conv, rep_conv]))

    return heads

def multihead_attention(conv_layers, q, k, v, num_heads=8):
    """Computes multi-head attention.

    Args:
      conv_layers: per-head `[q_conv, k_conv, v_conv, rep_conv]` Conv1d layers,
          as built by `create_multihead_network`.
      q: queries. tensor of  shape [B,m,d_k].
      k: keys. tensor of shape [B,n,d_k].
      v: values. tensor of shape [B,n,d_v].
      num_heads: number of heads to use; must not exceed len(conv_layers).

    Returns:
      tensor of shape [B,m,d_v].
    """
    # Conv1d works on channels-first tensors [B, channels, length].
    q_t = q.transpose(1, 2)
    k_t = k.transpose(1, 2)
    v_t = v.transpose(1, 2)

    rep = 0.0
    for h in range(num_heads):
        q_conv, k_conv, v_conv, rep_conv = conv_layers[h]
        o = dot_product_attention(
            q_conv(q_t).transpose(1, 2),  # [B,m,head_size]
            k_conv(k_t).transpose(1, 2),  # [B,n,head_size]
            v_conv(v_t).transpose(1, 2),  # [B,n,head_size]
            normalise=True)  # [B,m,head_size]
        # Project each head back to d_v and sum over heads.
        rep = rep + rep_conv(o.transpose(1, 2)).transpose(1, 2)  # [B,m,d_v]
    return rep


class Attention(nn.Module):
    """The Attention module."""
    def __init__(self, rep, layer_sizes, att_type, scale=1., normalise=True,
                 num_heads=8, value_size=None):
        """Create attention module.

        Takes in context inputs, target inputs and
        representations of each context input/output pair
        to output an aggregated representation of the context data.
        Args:
          rep: transformation to apply to contexts before computing attention.
              One of: ['identity','mlp'].
          layer_sizes: layer widths of the MLP applied to x1/x2, starting with
              d_x; used only if rep == 'mlp'. For multihead attention the key
              size is taken as layer_sizes[-1] ('mlp') or layer_sizes[0]
              ('identity', assumed to equal d_x).
          att_type: type of attention. One of the following:
              ['uniform','laplace','dot_product','multihead']
          scale: scale of attention.
          normalise: Boolean determining whether to:
              1. apply softmax to weights so that they sum to 1 across context pts or
              2. apply custom transformation to have weights in [0,1].
          num_heads: number of heads for multihead.
          value_size: size d of the representations r; used only for
              multihead. Defaults to the key size.
        """
        super(Attention, self).__init__()
        self._rep = rep
        self.Linears = nn.ModuleList([
            nn.Linear(s1, s2)
            for s1, s2 in zip(layer_sizes[:-1], layer_sizes[1:])
        ])
        self._type = att_type
        self._scale = scale
        self._normalise = normalise
        if self._type == 'multihead':
            self._num_heads = num_heads
            # q/k are x2/x1 after the optional MLP, so their width is the MLP
            # output ('mlp') or the raw input width ('identity').
            key_size = layer_sizes[-1] if rep == 'mlp' else layer_sizes[0]
            if value_size is None:
                value_size = key_size
            heads = create_multihead_network(key_size, key_size, value_size, num_heads)
            # Wrap in ModuleLists so the projection weights are registered
            # parameters of this module.
            self._conv_layers = nn.ModuleList(nn.ModuleList(head) for head in heads)

    def forward(self, x1, x2, r):
        """Apply attention to create aggregated representation of r.

        Args:
          x1: tensor of shape [B,n1,d_x].
          x2: tensor of shape [B,n2,d_x].
          r: tensor of shape [B,n1,d].

        Returns:
          tensor of shape [B,n2,d]

        Raises:
          NameError: The argument for rep/type was invalid.
        """
        if self._rep == 'identity':
            k, q = (x1, x2)
        elif self._rep == 'mlp':
            # Pass keys and queries through the same MLP.
            k = forward_pass(x1, self.Linears)
            q = forward_pass(x2, self.Linears)
        else:
            raise NameError("'rep' not among ['identity','mlp']")

        if self._type == 'uniform':
            rep = uniform_attention(q, r)
        elif self._type == 'laplace':
            rep = laplace_attention(q, k, r, self._scale, self._normalise)
        elif self._type == 'dot_product':
            rep = dot_product_attention(q, k, r, self._normalise)
        elif self._type == 'multihead':
            rep = multihead_attention(self._conv_layers, q, k, r, self._num_heads)
        else:
            raise NameError(
                ("'att_type' not among ['uniform','laplace','dot_product'"
                 ",'multihead']"))

        return rep

In [None]:
# Sanity check: a kernel-size-1 Conv1d maps [B, C_in, L] -> [B, C_out, L],
# which is how the multihead projections act on channels-first inputs.
c = nn.Conv1d(10, 11, 1, bias=False, padding=0)

c(torch.randn(1, 10, 6)).shape
