In [1]:
import numpy as np
from MF_Attn_process.lib.model import *
import yaml
import torch
import glob

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = yaml.safe_load(open('config.yaml'))['model']

In [3]:
data_l1_x = "/data/allen/climate_data/SPCAM5/inputs_*"
data_l1_y = "/data/allen/climate_data/SPCAM5/outputs_*"
data_l2_x = "/data/allen/climate_data/CAM5/inputs_*"
data_l2_y = "/data/allen/climate_data/CAM5/outputs_*"

l1_x_data = sorted(glob.glob(data_l1_x))
l1_y_data = sorted(glob.glob(data_l1_y))
l2_x_data = sorted(glob.glob(data_l2_x))
l2_y_data = sorted(glob.glob(data_l2_y))

l1_x = torch.from_numpy(np.load(l1_x_data[0]))
l1_y = torch.from_numpy(np.load(l1_y_data[0]))
l2_x = torch.from_numpy(np.load(l2_x_data[0]))
l2_y = torch.from_numpy(np.load(l2_y_data[0]))

l1_x_all = l1_x[:10].unsqueeze(0)
l1_y_all = l1_y[:10, :26].unsqueeze(0)
l2_x_all = l2_x[:10].unsqueeze(0)
l2_y_all = l2_y[:10, :26].unsqueeze(0)

l1_x_context = l1_x[:10].unsqueeze(0)
l1_y_context = l1_y[:10, :26].unsqueeze(0)
l2_x_context = l2_x[:10].unsqueeze(0)
l2_y_context = l2_y[:10, :26].unsqueeze(0)

l1_x_target = l1_x[10:20].unsqueeze(0)
l1_y_target = l1_y[10:20, :26].unsqueeze(0)
l2_x_target = l2_x[10:20].unsqueeze(0)
l2_y_target = l2_y[10:20, :26].unsqueeze(0)

In [19]:
class MLP_Z1Z2_Encoder(nn.Module):

    def __init__(self, config):
        dim = config['hidden_dim']
        in_dim = dim * 2 # hidden, l_r/l_z
        hidden_dim = dim
        out_dim = dim
        hidden_layers = config['hidden_layers']
        
        nn.Module.__init__(self)
        layers = [nn.Linear(in_dim, hidden_dim), nn.ELU()]
        for _ in range(hidden_layers - 1):
            # layers.append(nn.LayerNorm(hidden_dim))
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ELU()]
        layers.append(nn.Linear(hidden_dim, hidden_dim))

        self.model = nn.Sequential(*layers)
        self.mean_out = nn.Linear(hidden_dim, out_dim)

    def forward(self, l_z, l_r):
        output = self.model(torch.cat([l_z, l_r], dim=-1))
        mean = self.mean_out(output)

        return mean

In [20]:
model = Model(config)
z2_z1_encoder = MLP_Z1Z2_Encoder(config)

In [37]:
class LatentEncoder(nn.Module):
    """
    Latent Encoder [For prior, posterior]
    """

    def __init__(self, config, level=1):
        super(LatentEncoder, self).__init__()
        if level == 1:
            input_dim = config['l1_input_dim']
            output_dim = config['l1_output_dim']
        if level == 2:
            input_dim = config['l2_input_dim']
            output_dim = config['l2_output_dim']
            self.l1z_l2z_encoder = MLP_Z1Z2_Encoder(config)
        hidden_dim = config['hidden_dim']

        attention_layers = config['attention_layers']

        self.input_projection = Linear(input_dim+output_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.self_attentions = nn.ModuleList(
            [Attention(config) for _ in range(attention_layers)])
        self.penultimate_layer = Linear(hidden_dim, hidden_dim, w_init='relu')
        self.mu = Linear(hidden_dim, hidden_dim)
        self.log_sigma = Linear(hidden_dim, hidden_dim)

    def forward(self, x, y, l_z=None):
        # concat location (x) and value (y)
        encoder_input = t.cat([x, y], dim=-1)
            

        # project vector with dimension 132+136 --> hidden_dim
        encoder_input = self.input_projection(encoder_input)
        encoder_input = self.layer_norm(encoder_input)

        # self attention layer
        for attention in self.self_attentions:
            encoder_input, _ = attention(
                encoder_input, encoder_input, encoder_input)

        # mean
        hidden = encoder_input.mean(dim=1)
        hidden = t.relu(self.penultimate_layer(hidden))

        # z_mu combine with hidden if level==2
        if l_z is not None:
            hidden = self.l1z_l2z_encoder(hidden, l_z)
            print(hidden.shape)


        # get mu and sigma
        mu = self.mu(hidden)
        log_sigma = self.log_sigma(hidden)

        # reparameterization trick
        std = t.exp(0.5 * log_sigma)
        eps = t.randn_like(std)
        z = eps.mul(std).add_(mu)

        # return distribution
        return mu, log_sigma, z

In [38]:
class DeterministicEncoder(nn.Module):
    """
    Deterministic Encoder [r]
    """

    def __init__(self, config, level=1):
        super(DeterministicEncoder, self).__init__()
        if level == 1:
            input_dim = config['l1_input_dim']
            output_dim = config['l1_output_dim']
        if level == 2:
            input_dim = config['l2_input_dim']
            output_dim = config['l2_output_dim']
            self.l1r_l2r_encoder = MLP_Z1Z2_Encoder(config)
        hidden_dim = config['hidden_dim']
        attention_layers = config['attention_layers']

        self.self_attentions = nn.ModuleList(
            [Attention(config) for _ in range(attention_layers)])
        self.cross_attentions = nn.ModuleList(
            [Attention(config) for _ in range(attention_layers)])
        self.input_projection = Linear(input_dim+output_dim, hidden_dim)
        self.context_projection = Linear(input_dim, hidden_dim)
        self.target_projection = Linear(input_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, context_x, context_y, target_x, l_r=None):
        # concat context location (x), context value (y)
        encoder_input = t.cat([context_x, context_y], dim=-1)

        # project vector with dimension 132+136 --> num_hidden
        encoder_input = self.input_projection(encoder_input)
        encoder_input = self.layer_norm(encoder_input)

        # self attention layer
        for attention in self.self_attentions:
            encoder_input, _ = attention(
                encoder_input, encoder_input, encoder_input)

        # query: target_x, key: context_x, value: representation
        query = self.target_projection(target_x)
        keys = self.context_projection(context_x)

        # cross attention layer
        for attention in self.cross_attentions:
            query, _ = attention(keys, encoder_input, query)

        if l_r is not None:
            query = self.l1r_l2r_encoder(query, l_r)

        return query

In [39]:
l1_latent = LatentEncoder(config, level=1)
l2_latent = LatentEncoder(config, level=2)
l1_determ = DeterministicEncoder(config, level=1)
l2_determ = DeterministicEncoder(config, level=2)
l1_decoder = Decoder(config, level=1)
l2_decoder = Decoder(config, level=2)

In [45]:
l1_y_target = None
l2_y_target = None

In [46]:
l1_z_mu_c, l1_z_cov_c, l1_prior_z = l1_latent(l1_x_context, l1_y_context)
l2_z_mu_c, l2_z_cov_c, l2_prior_z = l2_latent(l2_x_context, l2_y_context, l1_z_mu_c)

if l1_y_target is not None:
    l1_z_mu_all, l1_z_cov_all, l1_posterior_z = l1_latent(l1_x_target, l1_y_target)
    l1_z = l1_posterior_z
else:
    l1_z = l1_prior_z

if l2_y_target is not None:
    l2_z_mu_all, l2_z_cov_all, l2_posterior_z = l2_latent(l2_x_target, l2_y_target, l1_z_mu_all)
    l2_z = l2_posterior_z
else:
    l2_z = l2_prior_z


l1_r = l1_determ(l1_x_context, l1_y_context, l1_x_target)  # [B, T_target, H]
l2_r = l2_determ(l2_x_context, l2_y_context, l2_x_target, l1_r)

torch.Size([1, 64])


In [47]:
l1_z = l1_z_mu_c.unsqueeze(1).repeat(1, l1_x_target.size(1), 1)
l2_z = l2_z_mu_c.unsqueeze(1).repeat(1, l2_x_target.size(1), 1)
l1_output_mu, l1_output_cov = l1_decoder(l1_r, l1_z, l1_x_target)
l2_output_mu, l2_output_cov = l2_decoder(l2_r, l2_z, l2_x_target)