In [1]:
import glob
import os

import numpy as np
import torch
import yaml

from model import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed the RNGs the notebook uses so weight init and the sampled latent z are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load only the 'model' section of the YAML config; `with` closes the file handle.
with open('config.yaml') as config_file:
    config = yaml.safe_load(config_file)['model']

In [3]:
# Root of the climate dataset. Set the CLIMATE_DATA_DIR environment variable to point
# elsewhere instead of editing the notebook; the default is the original location.
DATA_DIR = os.environ.get("CLIMATE_DATA_DIR", "/data/allen/climate_data")

# Level 1 -> SPCAM5, level 2 -> CAM5; each has inputs_* and outputs_* files.
data_l1_x = f"{DATA_DIR}/SPCAM5/inputs_*"
data_l1_y = f"{DATA_DIR}/SPCAM5/outputs_*"
data_l2_x = f"{DATA_DIR}/CAM5/inputs_*"
data_l2_y = f"{DATA_DIR}/CAM5/outputs_*"

# Sorted so inputs/outputs are listed in a deterministic order.
# NOTE(review): pairing of inputs_* with outputs_* relies on matching sort order — confirm.
l1_x_data = sorted(glob.glob(data_l1_x))
l1_y_data = sorted(glob.glob(data_l1_y))
l2_x_data = sorted(glob.glob(data_l2_x))
l2_y_data = sorted(glob.glob(data_l2_y))

In [4]:
def _load_first_tensor(paths, label):
    """Load ``paths[0]`` with ``np.load`` and wrap it via ``torch.from_numpy``.

    Args:
        paths: list of file paths (as produced by the glob cell).
        label: name used in the error message.

    Raises:
        FileNotFoundError: if ``paths`` is empty (the glob matched nothing);
            clearer than the bare IndexError ``paths[0]`` would raise.
    """
    if not paths:
        raise FileNotFoundError(f"no files found for {label}; check the data glob patterns")
    return torch.from_numpy(np.load(paths[0]))


# Only the first file of each sorted list is loaded.
l1_x = _load_first_tensor(l1_x_data, "l1_x_data")
l1_y = _load_first_tensor(l1_y_data, "l1_y_data")
l2_x = _load_first_tensor(l2_x_data, "l2_x_data")
l2_y = _load_first_tensor(l2_y_data, "l2_y_data")

In [5]:
# Build the full model from the 'model' config section. `Model` is presumably provided by
# `from model import *`. NOTE(review): `model` is not used by the visible cells below,
# which exercise a locally defined LatentEncoder instead.
model = Model(config)

In [7]:
# Split sizes (previously repeated magic numbers 10 / 20 / 26).
N_CONTEXT = 10   # leading rows used as context points
N_TARGET = 10    # following rows used as target points
N_Y_COLS = 26    # leading output columns kept for y — TODO document what these columns are

context_rows = slice(0, N_CONTEXT)
target_rows = slice(N_CONTEXT, N_CONTEXT + N_TARGET)

# unsqueeze(0) adds a leading batch dimension of size 1.
l1_x_context = l1_x[context_rows].unsqueeze(0)
l1_y_context = l1_y[context_rows, :N_Y_COLS].unsqueeze(0)
l2_x_context = l2_x[context_rows].unsqueeze(0)
l2_y_context = l2_y[context_rows, :N_Y_COLS].unsqueeze(0)

l1_x_target = l1_x[target_rows].unsqueeze(0)
l1_y_target = l1_y[target_rows, :N_Y_COLS].unsqueeze(0)
l2_x_target = l2_x[target_rows].unsqueeze(0)
l2_y_target = l2_y[target_rows, :N_Y_COLS].unsqueeze(0)

In [14]:

class LatentEncoder(nn.Module):
    """
    Latent Encoder [For prior, posterior]

    Encodes a set of (x, y) points into one Gaussian latent per batch element:
    concatenates x and y, projects to ``hidden_dim``, applies
    ``attention_layers`` self-attention layers, mean-pools over dim 1 and
    predicts ``mu`` / ``log_sigma``. A sample ``z`` is drawn with the
    reparameterization trick.

    Args:
        config: model config dict. Reads ``l{level}_input_dim``,
            ``l{level}_output_dim``, ``hidden_dim`` and ``attention_layers``;
            the whole dict is also passed to each ``Attention`` layer.
        level: fidelity level, 1 or 2, selecting which input/output dims to use.

    Raises:
        ValueError: if ``level`` is not 1 or 2.
    """

    def __init__(self, config, level=1):
        super().__init__()
        # Previously an unknown level fell through both `if`s and crashed later
        # with an UnboundLocalError on input_dim; fail fast with a clear message.
        if level == 1:
            input_dim = config['l1_input_dim']
            output_dim = config['l1_output_dim']
        elif level == 2:
            input_dim = config['l2_input_dim']
            output_dim = config['l2_output_dim']
        else:
            raise ValueError(f"level must be 1 or 2, got {level!r}")
        hidden_dim = config['hidden_dim']
        attention_layers = config['attention_layers']

        self.input_projection = Linear(input_dim + output_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.self_attentions = nn.ModuleList(
            [Attention(config) for _ in range(attention_layers)])
        self.penultimate_layer = Linear(hidden_dim, hidden_dim, w_init='relu')
        self.mu = Linear(hidden_dim, hidden_dim)
        self.log_sigma = Linear(hidden_dim, hidden_dim)

    def forward(self, x, y):
        """Encode a set of points into a latent Gaussian and draw one sample.

        Args:
            x: point locations; assumed (batch, n_points, x_dim) — the mean
                over dim=1 below pools across points. TODO confirm.
            y: point values with the same leading dims as ``x``.

        Returns:
            Tuple ``(mu, log_sigma, z)``, each (batch, hidden_dim) under the
            shape assumption above.
        """
        # concat location (x) and value (y) along the feature axis
        encoder_input = torch.cat([x, y], dim=-1)

        # project (x_dim + y_dim) features --> hidden_dim
        encoder_input = self.input_projection(encoder_input)
        encoder_input = self.layer_norm(encoder_input)

        # self-attention layers; the second return value of Attention is discarded
        for attention in self.self_attentions:
            encoder_input, _ = attention(
                encoder_input, encoder_input, encoder_input)

        # pool over points, then one ReLU hidden layer
        hidden = encoder_input.mean(dim=1)
        hidden = torch.relu(self.penultimate_layer(hidden))

        mu = self.mu(hidden)
        log_sigma = self.log_sigma(hidden)

        # Reparameterization trick.
        # NOTE(review): exp(0.5 * log_sigma) treats `log_sigma` as a log-variance,
        # not a log standard deviation; confirm this matches the KL term in the loss.
        std = torch.exp(0.5 * log_sigma)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return mu, log_sigma, z

In [18]:
# Level-1 encoder: uses the l1_* input/output dims from config.
latent = LatentEncoder(level=1, config=config)

In [19]:
# Encode the level-1 context set: latent mean, log_sigma, and one reparameterized sample.
mu, log_sigma, z = latent(x=l1_x_context, y=l1_y_context)

In [20]:
# Invariant: one hidden_dim-sized latent per batch element (batch of 1 from unsqueeze(0)).
assert mu.shape == log_sigma.shape == z.shape == (1, config['hidden_dim'])
mu.shape, log_sigma.shape, z.shape

(torch.Size([1, 64]), torch.Size([1, 64]), torch.Size([1, 64]))