<a href="https://colab.research.google.com/github/SashM9/proj_2022/blob/main/synth_dnri.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import argparse

def build_flags():
    """Create the command-line parser for the training / evaluation script.

    Options are declared as ordered ``(flag, add_argument kwargs)`` tables so
    that registration order (and therefore ``--help`` output) is preserved.

    Returns:
        argparse.ArgumentParser: parser with all run, training and model
        options registered. ``--working_dir`` and ``--mode`` are required.
    """
    switch = {'action': 'store_true'}

    run_options = [
        ('--working_dir', dict(required=True)),
        ('--gpu', switch),
        ('--seed', dict(type=int, default=1)),
        ('--mode', dict(choices=['train', 'eval', 'eval_fixedwindow'], required=True)),
        ('--load_model', {}),
        ('--load_best_model', switch),
        ('--continue_training', switch),
        ('--model_type', dict(choices=['nri', 'dnri', 'fc_baseline'])),
    ]

    # Training Params
    training_options = [
        ('--num_epochs', dict(type=int)),
        ('--lr', dict(type=float, default=1e-3)),
        ('--mom', dict(type=float, default=0)),
        ('--batch_size', dict(type=int, default=32)),
        ('--sub_batch_size', dict(type=int)),
        ('--val_batch_size', dict(type=int)),
        ('--val_interval', dict(type=int, default=5)),
        ('--test', switch),
        ('--use_adam', switch),
        ('--lr_decay_factor', dict(type=float)),
        ('--lr_decay_steps', dict(type=int)),
        ('--clip_grad_norm', dict(type=float)),
        ('--verbose', switch),
        ('--tune_on_nll', switch),
        ('--val_teacher_forcing', switch),
        ('--accumulate_steps', dict(type=int, default=1)),
        ('--max_burn_in_count', dict(type=int, default=-1)),
    ]

    # Model Params
    model_options = [
        ('--no_prior', switch),
        ('--avg_prior', switch),
        ('--add_uniform_prior', switch),
        ('--prior_num_layers', dict(type=int, default=1)),
        ('--prior_hidden_size', dict(type=int, default=256)),
        ('--use_learned_prior', switch),
        ('--graph_type', dict(choices=['static', 'dynamic'])),
        ('--avg_encoder_inputs', switch),
        ('--use_dynamic_graph', switch),
        ('--use_static_encoder', switch),
        ('--decoder_type', {}),
        ('--encoder_rnn_type', dict(choices=['lstm', 'gru'], default='lstm')),
        ('--decoder_rnn_type', dict(choices=['lstm', 'gru'], default='gru')),
        ('--encoder_hidden', dict(type=int, default=256)),
        ('--encoder_rnn_hidden', dict(type=int)),
        ('--num_edge_types', dict(type=int, default=2)),
        ('--encoder_dropout', dict(type=float, default=0.0)),
        ('--encoder_unidirectional', switch),
        ('--encoder_bidirectional', switch),
        ('--encoder_no_factor', dict(action='store_true', default=False)),
        ('--decoder_hidden', dict(type=int, default=256)),
        ('--decoder_msg_hidden', dict(type=int, default=256)),
        ('--decoder_dropout', dict(type=float, default=0.0)),
        ('--skip_first', dict(action='store_true', default=False)),
        ('--uniform_prior', switch),
        ('--no_edge_prior', dict(type=float)),
        ('--teacher_forcing_steps', dict(type=int, default=10)),
        ('--gumbel_temp', dict(type=float, default=0.5)),
        ('--train_hard_sample', switch),
        ('--normalize_kl', switch),
        ('--normalize_kl_per_var', switch),
        ('--normalize_nll', switch),
        ('--normalize_nll_per_var', switch),
        ('--kl_coef', dict(type=float, default=1.)),
        ('--no_encoder_bn', switch),
        ('--encoder_mlp_hidden', dict(type=int, default=256)),
        ('--encoder_mlp_num_layers', dict(type=int, default=1)),
        ('--rnn_hidden', dict(type=int, default=64)),
        ('--teacher_forcing_prior', switch),
        ('--decoder_rnn_hidden', dict(type=int)),
        ('--encoder_save_eval_memory', switch),
        ('--encoder_normalize_mode', dict(choices=[None, 'normalize_inp', 'normalize_all'])),
        ('--normalize_inputs', switch),
    ]

    parser = argparse.ArgumentParser('')
    for flag, options in run_options + training_options + model_options:
        parser.add_argument(flag, **options)
    return parser

In [2]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch


def encode_onehot(labels):
    """One-hot encode a sequence of hashable labels.

    Args:
        labels: re-iterable sequence of hashable labels (it is traversed
            twice), e.g. a 1-D NumPy integer array.

    Returns:
        np.ndarray of dtype int32 with one row per label and one column per
        distinct label. Columns follow the iteration order of ``set(labels)``.
    """
    classes = set(labels)
    # Build the identity matrix once instead of once per class.
    identity = np.identity(len(classes))
    classes_dict = {c: identity[i, :] for i, c in enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)
    return labels_onehot



class RefNRIMLP(nn.Module):
    """Linear -> ELU -> Dropout -> Linear -> ELU block, optionally followed by BatchNorm1d."""

    def __init__(self, n_in, n_hid, n_out, do_prob=0., no_bn=False):
        """Build the two linear layers and, unless ``no_bn`` is set, the batch norm.

        Args:
            n_in: size of the last input axis.
            n_hid: width of the hidden linear layer.
            n_out: size of the last output axis.
            do_prob: dropout probability applied between the two linear layers.
            no_bn: when True, no batch-norm layer is created or applied.
        """
        super().__init__()
        layers = [
            nn.Linear(n_in, n_hid),
            nn.ELU(inplace=True),
            nn.Dropout(do_prob),
            nn.Linear(n_hid, n_out),
            nn.ELU(inplace=True),
        ]
        self.model = nn.Sequential(*layers)
        self.bn = None if no_bn else nn.BatchNorm1d(n_out)

        self.init_weights()

    def init_weights(self):
        """Xavier-normal weights and 0.1 biases for Linear layers; identity affine for batch norm."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                nn.init.constant_(module.bias, 0.1)

    def batch_norm(self, inputs):
        """Apply batch norm over the last axis by flattening all leading axes."""
        flat = inputs.view(-1, inputs.size(-1))
        return self.bn(flat).view(inputs.shape)

    def forward(self, inputs):
        """Run the MLP on the last axis of ``inputs`` and batch-normalise if enabled."""
        # Input shape: [num_sims, num_things, num_features]
        hidden = self.model(inputs)
        if self.bn is None:
            return hidden
        return self.batch_norm(hidden)


def sample_gumbel(shape, eps=1e-10):
    """
    Draw Gumbel(0, 1) noise of the given shape (created with torch.rand, no device argument).

    NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)

    Args:
      shape: shape of the returned tensor
      eps: small constant added inside both logarithms to avoid log(0)
    """
    uniform = torch.rand(shape).float()
    inner_log = torch.log(uniform + eps)
    return torch.log(eps - inner_log).neg()


def gumbel_softmax_sample(logits, tau=1, eps=1e-10):
    """
    Draw a soft sample from the Gumbel-Softmax distribution.

    NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
    (MIT license)

    Args:
      logits: unnormalized log-probs; classes are on the last axis
      tau: temperature dividing the perturbed logits before the softmax
      eps: numerical-stability constant forwarded to sample_gumbel
    Returns:
      tensor shaped like logits whose last axis is a softmax output
    """
    gumbel_noise = sample_gumbel(logits.size(), eps=eps)
    # Move the noise to the device the logits actually live on; a bare
    # .cuda() targets the current CUDA device, which may be a different GPU.
    gumbel_noise = gumbel_noise.to(logits.device)
    y = logits + gumbel_noise
    return F.softmax(y / tau, dim=-1)


def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
    """
    NOTE: Stolen from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3
    Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
      logits: [..., n_class] unnormalized log-probs (classes on the last axis)
      tau: non-negative scalar temperature
      hard: if True, take argmax, but differentiate w.r.t. soft sample y
      eps: numerical-stability constant forwarded to the noise sampler
    Returns:
      Sample from the Gumbel-Softmax distribution, shaped like logits.
      If hard=True, then the returned sample will be one-hot, otherwise it will
      be a probability distribution that sums to 1 across classes
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)
    """
    y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps)
    if not hard:
        return y_soft

    shape = logits.size()
    _, k = y_soft.detach().max(-1)
    # this bit is based on
    # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
    # Allocate directly on y_soft's device rather than on the current CUDA
    # device, so inputs on any device (including non-default GPUs) work.
    y_hard = torch.zeros(shape, device=y_soft.device)
    y_hard = y_hard.scatter_(-1, k.unsqueeze(-1), 1.0)
    # Straight-through estimator:
    # - the forward value is exactly one-hot (y_soft is added, then its
    #   detached copy subtracted)
    # - the gradient equals the gradient of y_soft (all other terms carry
    #   no gradient)
    return y_hard - y_soft.detach() + y_soft


def get_graph_info(masks, num_vars, use_edge2node=True):
    """Enumerate the directed edges of a fully connected graph without self-loops.

    Args:
        masks: tensor used only for its ``device``; its values are not read.
        num_vars: number of nodes in the graph.
        use_edge2node: when True, also return the per-node incoming-edge index.

    Returns:
        ``(send_edges, recv_edges, edge2node_inds)`` if ``use_edge2node`` else
        ``(send_edges, recv_edges)``. ``send_edges`` / ``recv_edges`` are 1-D
        index tensors with one entry per edge; ``edge2node_inds`` has shape
        [num_vars, num_vars - 1] and row ``i`` lists the edges whose receiver
        is node ``i``. Every element is ``None`` when ``num_vars == 1`` (no edges).
    """
    if num_vars == 1:
        # Keep the tuple length consistent with the flag so callers can
        # always unpack the same number of values.
        return (None, None, None) if use_edge2node else (None, None)

    device = masks.device
    edges = torch.ones(num_vars, num_vars, device=device) - torch.eye(num_vars, device=device)
    send_edges, recv_edges = torch.where(edges)
    if not use_edge2node:
        return send_edges, recv_edges

    node_inds = torch.arange(num_vars, device=device, dtype=torch.long).unsqueeze(1)
    # Row i of the comparison marks the edges received by node i; column 1 of
    # nonzero() gives their edge indices in row-major order.
    incoming = (node_inds == recv_edges.unsqueeze(0)).nonzero()[:, 1]
    edge2node_inds = incoming.contiguous().view(-1, num_vars - 1)
    return send_edges, recv_edges, edge2node_inds

In [3]:
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np



class BaseEncoder(nn.Module):
    """Base class holding the edge bookkeeping of a fully connected graph.

    Stores the sender / receiver node index of every directed edge (self-loops
    excluded) and a fixed, non-trainable matrix that sums edge embeddings onto
    their receiving node. Subclasses implement ``forward``.
    """

    def __init__(self, num_vars, graph_type):
        """
        Args:
            num_vars: number of nodes in the graph.
            graph_type: graph mode; the value 'dynamic' selects the 4-D code paths.
        """
        super().__init__()
        self.num_vars = num_vars
        self.graph_type = graph_type
        self.dynamic = (graph_type == 'dynamic')

        # 1 for every ordered node pair except the diagonal.
        off_diagonal = np.ones(num_vars) - np.eye(num_vars)
        self.send_edges = np.where(off_diagonal)[0]
        self.recv_edges = np.where(off_diagonal)[1]
        recv_onehot = encode_onehot(self.recv_edges)
        # Stored as a frozen Parameter so it follows the module across devices.
        self.edge2node_mat = nn.Parameter(torch.FloatTensor(recv_onehot.transpose()), requires_grad=False)

    def node2edge(self, node_embeddings):
        """Concatenate sender and receiver embeddings for every edge along the last axis."""
        if not self.dynamic:
            # 3-D input, nodes on axis 1.
            senders = node_embeddings[:, self.send_edges, :]
            receivers = node_embeddings[:, self.recv_edges, :]
            return torch.cat([senders, receivers], dim=2)
        # 4-D input, nodes on axis 1 (presumably [batch, num_nodes, time, embed] -- TODO confirm).
        senders = node_embeddings[:, self.send_edges, :, :]
        receivers = node_embeddings[:, self.recv_edges, :, :]
        return torch.cat([senders, receivers], dim=3)

    def edge2node(self, edge_embeddings):
        """Sum edge embeddings onto their receiving node and divide by ``num_vars - 1``."""
        #TODO: there may be a more efficient way of doing this, but shrug
        if self.dynamic:
            in_shape = edge_embeddings.shape
            # Fold the trailing axes so the edge axis can be reduced with one matmul.
            flattened = edge_embeddings.view(in_shape[0], in_shape[1], -1)
            summed = torch.matmul(self.edge2node_mat, flattened)
            incoming = summed.view(in_shape[0], -1, in_shape[2], in_shape[3])
        else:
            incoming = torch.matmul(self.edge2node_mat, edge_embeddings)
        return incoming/(self.num_vars-1) #TODO: do we want this average?

    def forward(self, inputs, state=None, return_state=False):
        """Abstract; concrete encoders must override this."""
        raise NotImplementedError


class RefMLPEncoder(BaseEncoder):
    """MLP encoder mapping a fixed-length window of node inputs to per-edge logits.

    The window (``input_time_steps`` frames) is flattened per node, embedded,
    lifted to edges, optionally passed through one node <-> edge round trip
    (the "factor graph" variant) and projected to ``num_edge_types`` logits.
    """

    def __init__(self, params):
        """
        Args:
            params: dict read for the keys 'num_vars', 'input_size',
                'input_time_steps', 'encoder_hidden', 'num_edge_types',
                'encoder_no_factor', 'graph_type', 'encoder_dropout',
                'encoder_mlp_num_layers' and, when more than one output layer
                is requested, 'encoder_mlp_hidden'.
        """
        num_vars = params['num_vars']
        inp_size = params['input_size']*params['input_time_steps']
        hidden_size = params['encoder_hidden']
        num_edges = params['num_edge_types']
        factor = not params['encoder_no_factor']
        # NOTE(review): batch norm is always enabled here; the --no_encoder_bn
        # flag is not consulted by this class -- confirm this is intended.
        no_bn = False
        graph_type = params['graph_type']
        super(RefMLPEncoder, self).__init__(num_vars, graph_type)
        dropout = params['encoder_dropout']
        self.input_time_steps = params['input_time_steps']
        self.dynamic = self.graph_type == 'dynamic'
        self.factor = factor
        self.mlp1 = RefNRIMLP(inp_size, hidden_size, hidden_size, dropout, no_bn=no_bn)
        self.mlp2 = RefNRIMLP(hidden_size * 2, hidden_size, hidden_size, dropout, no_bn=no_bn)
        self.mlp3 = RefNRIMLP(hidden_size, hidden_size, hidden_size, dropout, no_bn=no_bn)
        if self.factor:
            self.mlp4 = RefNRIMLP(hidden_size * 3, hidden_size, hidden_size, dropout, no_bn=no_bn)
            print("Using factor graph MLP encoder.")
        else:
            self.mlp4 = RefNRIMLP(hidden_size * 2, hidden_size, hidden_size, dropout, no_bn=no_bn)
            print("Using MLP encoder.")
        num_layers = params['encoder_mlp_num_layers']
        if num_layers == 1:
            self.fc_out = nn.Linear(hidden_size, num_edges)
        else:
            tmp_hidden_size = params['encoder_mlp_hidden']
            layers = [nn.Linear(hidden_size, tmp_hidden_size), nn.ELU(inplace=True)]
            for _ in range(num_layers - 2):
                layers.append(nn.Linear(tmp_hidden_size, tmp_hidden_size))
                layers.append(nn.ELU(inplace=True))
            layers.append(nn.Linear(tmp_hidden_size, num_edges))
            self.fc_out = nn.Sequential(*layers)

        self.init_weights()

    def node2edge(self, node_embeddings):
        """Concatenate sender and receiver embeddings per edge (3-D input only)."""
        send_embed = node_embeddings[:, self.send_edges, :]
        recv_embed = node_embeddings[:, self.recv_edges, :]
        return torch.cat([send_embed, recv_embed], dim=2)

    def edge2node(self, edge_embeddings):
        """Sum edge embeddings onto receiving nodes and divide by ``num_vars - 1``."""
        incoming = torch.matmul(self.edge2node_mat, edge_embeddings)
        return incoming/(self.num_vars-1) #TODO: do we want this average?

    def init_weights(self):
        """Re-initialise every Linear layer: Xavier-normal weights, 0.1 biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)

    def merge_states(self, states):
        """Concatenate a list of state tensors along the batch axis."""
        return torch.cat(states, dim=0)

    def forward(self, inputs, state=None, return_state=False):
        """Compute edge-type logits for the most recent ``input_time_steps`` frames.

        Args:
            inputs: 4-D tensor with time on axis 1 and nodes on axis 2
                (presumably [batch, time, num_nodes, features] -- TODO confirm).
            state: optional window returned by a previous call; it is prepended
                to ``inputs`` along the time axis.
            return_state: unused; the state is always returned.

        Returns:
            dict with 'logits' (per-edge outputs of the final layer) and
            'state' (the exact window that was encoded).
        """
        # Prepend the previous window first. If padding/truncation ran before
        # this step, the window would already be input_time_steps long and the
        # final slice would discard `state` entirely.
        if state is not None:
            inputs = torch.cat([state, inputs], 1)
        if inputs.size(1) > self.input_time_steps:
            inputs = inputs[:, -self.input_time_steps:]
        elif inputs.size(1) < self.input_time_steps:
            # Too few frames: left-pad by repeating the first available frame.
            begin_inp = inputs[:, 0:1].expand(-1, self.input_time_steps-inputs.size(1), -1, -1)
            inputs = torch.cat([begin_inp, inputs], dim=1)
        x = inputs.transpose(1, 2).contiguous().view(inputs.size(0), inputs.size(2), -1)
        # New shape: [num_sims, num_atoms, num_timesteps*num_dims]
        x = self.mlp1(x)  # 2-layer ELU net per node

        x = self.node2edge(x)
        x = self.mlp2(x)
        x_skip = x

        if self.factor:
            # node <- edge -> node round trip before the skip connection
            x = self.edge2node(x)
            x = self.mlp3(x)
            x = self.node2edge(x)
        else:
            x = self.mlp3(x)
        x = torch.cat((x, x_skip), dim=-1)  # Skip connection
        x = self.mlp4(x)
        result = self.fc_out(x)
        return {
            'logits': result,
            'state': inputs,
        }

In [4]:
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np




class GraphRNNDecoder(nn.Module):
    """Recurrent message-passing decoder with a GRU-style node update.

    At every step one two-layer tanh MLP per edge type turns each
    (receiver, sender) hidden-state pair into a message. Messages are weighted
    by the edge-type weights, averaged onto the receiving nodes and used to
    update the hidden state, from which a residual prediction of the next
    input is produced.
    """

    def __init__(self, params):
        """
        Args:
            params: dict read for the keys 'num_vars', 'input_size', 'gpu',
                'decoder_hidden', 'num_edge_types', 'skip_first' and
                'decoder_dropout'.
        """
        super(GraphRNNDecoder, self).__init__()
        self.embedder = None
        self.num_vars = num_vars = params['num_vars']
        input_size = params['input_size']
        self.gpu = params['gpu']
        n_hid = params['decoder_hidden']
        edge_types = params['num_edge_types']
        skip_first = params['skip_first']
        out_size = params['input_size']
        do_prob = params['decoder_dropout']

        # One message MLP (fc1 -> fc2) per edge type.
        self.msg_fc1 = nn.ModuleList(
            [nn.Linear(2*n_hid, n_hid) for _ in range(edge_types)]
        )
        self.msg_fc2 = nn.ModuleList(
            [nn.Linear(n_hid, n_hid) for _ in range(edge_types)]
        )
        self.msg_out_shape = n_hid
        self.skip_first_edge_type = skip_first

        # GRU-style gates: hidden_* act on aggregated messages, input_* on inputs.
        self.hidden_r = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_i = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_h = nn.Linear(n_hid, n_hid, bias=False)

        self.input_r = nn.Linear(input_size, n_hid, bias=True)
        self.input_i = nn.Linear(input_size, n_hid, bias=True)
        self.input_n = nn.Linear(input_size, n_hid, bias=True)

        self.out_fc1 = nn.Linear(n_hid, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, out_size)

        print('Using learned recurrent interaction net decoder.')

        self.dropout_prob = do_prob

        edges = np.ones(num_vars) - np.eye(num_vars)
        self.send_edges = np.where(edges)[0]
        self.recv_edges = np.where(edges)[1]
        # Plain tensor (not a buffer), so it is moved manually when 'gpu' is set.
        self.edge2node_mat = torch.FloatTensor(encode_onehot(self.recv_edges))
        if self.gpu:
            self.edge2node_mat = self.edge2node_mat.cuda(non_blocking=True)

    def single_step_forward(self, inputs, rel_type, hidden):
        """Advance the decoder by one time step.

        Args:
            inputs: [batch, num_atoms, num_dims] current node inputs.
            rel_type: [batch_size, num_atoms*(num_atoms-1), num_edge_types]
                edge-type weights.
            hidden: [batch, num_atoms, msg_out] previous hidden state.

        Returns:
            (pred, hidden): the residual prediction ``inputs + delta`` and the
            updated hidden state.
        """
        # node2edge
        receivers = hidden[:, self.recv_edges, :]
        senders = hidden[:, self.send_edges, :]

        # pre_msg: [batch, num_edges, 2*msg_out]
        pre_msg = torch.cat([receivers, senders], dim=-1)

        # Allocate on the inputs' own device (works on CPU and on any GPU).
        all_msgs = torch.zeros(pre_msg.size(0), pre_msg.size(1),
                               self.msg_out_shape, device=inputs.device)

        if self.skip_first_edge_type:
            start_idx = 1
            norm = float(len(self.msg_fc2)) - 1
        else:
            start_idx = 0
            norm = float(len(self.msg_fc2))

        # Run separate MLP for every edge type
        # NOTE: to exclude one edge type, simply offset range by 1
        for i in range(start_idx, len(self.msg_fc2)):
            msg = torch.tanh(self.msg_fc1[i](pre_msg))
            # Dropout must follow the module's train/eval mode; F.dropout
            # defaults to training=True and would otherwise stay active in eval.
            msg = F.dropout(msg, p=self.dropout_prob, training=self.training)
            msg = torch.tanh(self.msg_fc2[i](msg))
            msg = msg * rel_type[:, :, i:i+1]
            all_msgs += msg/norm

        # This step sums all of the messages per node
        agg_msgs = all_msgs.transpose(-2, -1).matmul(self.edge2node_mat).transpose(-2, -1)
        agg_msgs = agg_msgs.contiguous() / (self.num_vars - 1) # Average

        # GRU-style gated aggregation
        inp_r = self.input_r(inputs).view(inputs.size(0), self.num_vars, -1)
        inp_i = self.input_i(inputs).view(inputs.size(0), self.num_vars, -1)
        inp_n = self.input_n(inputs).view(inputs.size(0), self.num_vars, -1)
        r = torch.sigmoid(inp_r + self.hidden_r(agg_msgs))
        i = torch.sigmoid(inp_i + self.hidden_i(agg_msgs))
        n = torch.tanh(inp_n + r*self.hidden_h(agg_msgs))
        hidden = (1 - i)*n + i*hidden

        # Output MLP
        pred = F.dropout(F.relu(self.out_fc1(hidden)), p=self.dropout_prob, training=self.training)
        pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob, training=self.training)
        pred = self.out_fc3(pred)

        # Residual connection: the network predicts the change of the input.
        pred = inputs + pred

        return pred, hidden

    def forward(self, inputs, sampled_edges, teacher_forcing=False, teacher_forcing_steps=-1, return_state=False,
                prediction_steps=-1, state=None, burn_in_masks=None):
        """Roll the decoder out over time.

        Args:
            inputs: [batch_size, num_timesteps, num_atoms, num_feats].
            sampled_edges: edge-type weights, either per sequence (3-D, shared
                by every step) or per step (4-D, time on axis 1).
            teacher_forcing: feed ground-truth inputs for the first
                ``teacher_forcing_steps`` steps instead of previous predictions.
            teacher_forcing_steps: -1 means all available input steps.
            return_state: also return the final hidden state.
            prediction_steps: number of steps to produce; values <= 0 mean one
                per input time step.
            state: optional initial hidden state; zeros when None.
            burn_in_masks: optional per-step mask; where it is 1 the
                ground-truth input is used, elsewhere the previous prediction.

        Returns:
            Predictions stacked on axis 1, plus the final hidden state when
            ``return_state`` is True.
        """
        time_steps = inputs.size(1)

        if prediction_steps > 0:
            pred_steps = prediction_steps
        else:
            pred_steps = time_steps

        if len(sampled_edges.shape) == 3:
            sampled_edges = sampled_edges.unsqueeze(1).expand(sampled_edges.size(0), pred_steps, sampled_edges.size(1), sampled_edges.size(2))
        # sampled_edges has shape:
        # [batch_size, num_time_steps, num_atoms*(num_atoms-1), num_edge_types]
        # represents the sampled edges in the graph

        # Hidden size: [batch, num_atoms, msg_out]
        if state is None:
            hidden = torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape,
                                 device=inputs.device)
        else:
            hidden = state
        if teacher_forcing_steps == -1:
            teacher_forcing_steps = inputs.size(1)

        pred_all = []
        for step in range(0, pred_steps):
            if burn_in_masks is not None and step != 0:
                current_masks = burn_in_masks[:, step, :]
                ins = inputs[:, step, :]*current_masks + pred_all[-1]*(1 - current_masks)
            elif step == 0 or (teacher_forcing and step < teacher_forcing_steps):
                ins = inputs[:, step, :]
            else:
                ins = pred_all[-1]
            edges = sampled_edges[:, step, :]
            pred, hidden = self.single_step_forward(ins, edges, hidden)

            pred_all.append(pred)
        preds = torch.stack(pred_all, dim=1)

        if return_state:
            return preds, hidden
        else:
            return preds

In [5]:
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np



class BaseNRI(nn.Module):
    """Shared loss machinery for NRI-style models (encoder -> sampled edges -> decoder).

    Holds the encoder / decoder pair, the categorical edge prior and the NLL
    and KL terms combined by ``calculate_loss``. Subclasses provide
    ``predict_future``.
    """

    def __init__(self, num_vars, encoder, decoder, params):
        """
        Args:
            num_vars: number of nodes in the graph.
            encoder: module returning a dict with a 'logits' entry.
            decoder: module mapping (inputs, edges, ...) to predictions.
            params: configuration dict. 'gpu' is required; every other key is
                read with ``params.get`` and may be absent.
        """
        super(BaseNRI, self).__init__()
        # Model Params
        self.num_vars = num_vars
        self.decoder = decoder
        self.encoder = encoder
        self.num_edge_types = params.get('num_edge_types')

        # Training params
        self.gumbel_temp = params.get('gumbel_temp')
        self.train_hard_sample = params.get('train_hard_sample')
        self.teacher_forcing_steps = params.get('teacher_forcing_steps', -1)

        no_edge_prior = params.get('no_edge_prior')
        if no_edge_prior is not None:
            # Edge type 0 gets probability `no_edge_prior`; the remainder is
            # split evenly over the other edge types.
            prior = np.zeros(self.num_edge_types)
            prior.fill((1 - no_edge_prior)/(self.num_edge_types - 1))
            prior[0] = no_edge_prior
        else:
            print("USING UNIFORM PRIOR")
            prior = np.zeros(self.num_edge_types)
            prior.fill(1.0/self.num_edge_types)
        # Shape [1, 1, num_edge_types] so it broadcasts over batch and edges.
        log_prior = torch.FloatTensor(np.log(prior))
        log_prior = torch.unsqueeze(log_prior, 0)
        log_prior = torch.unsqueeze(log_prior, 0)
        if params['gpu']:
            log_prior = log_prior.cuda(non_blocking=True)
        self.log_prior = log_prior
        if no_edge_prior is not None:
            print("USING NO EDGE PRIOR: ",self.log_prior)

        self.normalize_kl = params.get('normalize_kl', False)
        self.normalize_kl_per_var = params.get('normalize_kl_per_var', False)
        self.normalize_nll = params.get('normalize_nll', False)
        self.normalize_nll_per_var = params.get('normalize_nll_per_var', False)
        self.kl_coef = params.get('kl_coef', 1.)
        self.nll_loss_type = params.get('nll_loss_type', 'crossent')
        self.prior_variance = params.get('prior_variance')
        self.timesteps = params.get('timesteps', 0)
        self.extra_context = params.get('embedder_time_bins', 0)
        self.burn_in_steps = params.get('train_burn_in_steps')
        self.no_prior = params.get('no_prior', False)
        self.val_teacher_forcing_steps = params.get('val_teacher_forcing_steps', -1)

    def calculate_loss(self, inputs, is_train=False, teacher_forcing=True, return_edges=False, return_logits=False):
        """Compute the training objective ``nll + kl_coef * kl`` for one batch.

        Args:
            inputs: batch of sequences with time on axis 1.
            is_train: when False, edges are always sampled as hard one-hots and
                the validation teacher-forcing horizon is used.
            teacher_forcing: forwarded to the decoder.
            return_edges: also return the sampled edges.
            return_logits: also return the encoder logits and decoder output
                (only consulted when ``return_edges`` is False).

        Returns:
            ``(loss, loss_nll, loss_kl)`` followed by the optional extras;
            ``loss`` is the mean over the batch.
        """
        # Should be shape [batch, num_edges, edge_dim]
        encoder_results = self.encoder(inputs)
        logits = encoder_results['logits']
        old_shape = logits.shape
        hard_sample = (not is_train) or self.train_hard_sample
        # Same sampler the subclasses use in predict_future; it draws the
        # noise on the logits' own device.
        edges = F.gumbel_softmax(
            logits.view(-1, self.num_edge_types),
            tau=self.gumbel_temp,
            hard=hard_sample).view(old_shape)
        if not is_train and teacher_forcing:
            teacher_forcing_steps = self.val_teacher_forcing_steps
        else:
            teacher_forcing_steps = self.teacher_forcing_steps
        # The decoder sees every frame but the last and is scored against the
        # same sequence shifted by one step.
        output = self.decoder(inputs[:, self.extra_context:-1], edges,
                              teacher_forcing=teacher_forcing,
                              teacher_forcing_steps=teacher_forcing_steps)
        target = inputs[:, self.extra_context+1:]
        loss_nll = self.nll(output, target)
        prob = F.softmax(logits, dim=-1)
        if self.no_prior:
            # Zero KL term allocated next to the logits (CPU or any GPU).
            loss_kl = logits.new_zeros(1)
        elif self.log_prior is not None:
            loss_kl = self.kl_categorical(prob)
        else:
            loss_kl = self.kl_categorical_uniform(prob)
        loss = loss_nll + self.kl_coef*loss_kl
        loss = loss.mean()
        if return_edges:
            return loss, loss_nll, loss_kl, edges
        elif return_logits:
            return loss, loss_nll, loss_kl, logits, output
        else:
            return loss, loss_nll, loss_kl

    def predict_future(self, data_encoder, data_decoder):
        """Abstract; implemented by the concrete model classes."""
        raise NotImplementedError()

    def nll(self, preds, target):
        """Dispatch to the reconstruction loss selected by ``nll_loss_type``.

        Raises:
            ValueError: if ``nll_loss_type`` is not 'crossent', 'gaussian' or
                'poisson'.
        """
        if self.nll_loss_type == 'crossent':
            return self.nll_crossent(preds, target)
        elif self.nll_loss_type == 'gaussian':
            return self.nll_gaussian(preds, target)
        elif self.nll_loss_type == 'poisson':
            return self.nll_poisson(preds, target)
        raise ValueError('Unknown nll_loss_type: %r' % (self.nll_loss_type,))

    def nll_gaussian(self, preds, target, add_const=False):
        """Gaussian negative log-likelihood with fixed variance ``prior_variance``.

        The reduction depends on the normalisation flags: a scalar divided by
        ``batch * target.size(2)`` (per-var), a per-sample mean including the
        constant term (``normalize_nll``), or a scalar divided by
        ``target.size(1)`` otherwise. ``add_const`` is currently unused.
        """
        neg_log_p = ((preds - target) ** 2 / (2 * self.prior_variance))
        const = 0.5 * np.log(2 * np.pi * self.prior_variance)
        if self.normalize_nll_per_var:
            return neg_log_p.sum() / (target.size(0) * target.size(2))
        elif self.normalize_nll:
            return (neg_log_p.sum(-1) + const).view(preds.size(0), -1).mean(dim=1)
        else:
            return neg_log_p.view(target.size(0), -1).sum() / (target.size(1))

    def nll_crossent(self, preds, target):
        """Per-sample binary cross-entropy on logits (mean if ``normalize_nll``, else sum)."""
        elementwise = nn.BCEWithLogitsLoss(reduction='none')(preds, target).view(preds.size(0), -1)
        if self.normalize_nll:
            return elementwise.mean(dim=1)
        else:
            return elementwise.sum(dim=1)

    def nll_poisson(self, preds, target):
        """Per-sample Poisson NLL (mean if ``normalize_nll``, else sum)."""
        elementwise = nn.PoissonNLLLoss(reduction='none')(preds, target).view(preds.size(0), -1)
        if self.normalize_nll:
            return elementwise.mean(dim=1)
        else:
            return elementwise.sum(dim=1)

    def kl_categorical(self, preds, eps=1e-16):
        """KL divergence between edge probabilities ``preds`` and the stored prior."""
        kl_div = preds*(torch.log(preds+eps) - self.log_prior)
        if self.normalize_kl:
            return kl_div.sum(-1).view(preds.size(0), -1).mean(dim=1)
        elif self.normalize_kl_per_var:
            return kl_div.sum() / (self.num_vars * preds.size(0))
        else:
            return kl_div.view(preds.size(0), -1).sum(dim=1)

    def kl_categorical_uniform(self, preds, eps=1e-16):
        """KL divergence between ``preds`` and a uniform prior over edge types."""
        kl_div = preds*(torch.log(preds + eps) + np.log(self.num_edge_types))
        if self.normalize_kl:
            return kl_div.sum(-1).view(preds.size(0), -1).mean()
        elif self.normalize_kl_per_var:
            return kl_div.sum() / (self.num_vars * preds.size(0))
        else:
            return kl_div.view(preds.size(0), -1).sum(dim=1)/self.num_edge_types

    def save(self, path):
        """Write the model's state dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a state dict previously written by ``save``.

        NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
        from trusted sources.
        """
        self.load_state_dict(torch.load(path))


class StaticNRI(BaseNRI):
    """NRI variant that infers one edge set from the observed window and reuses it for every step."""

    def _hard_edges(self, logits):
        """Draw a one-hot Gumbel-Softmax edge sample with the same shape as ``logits``."""
        flat_logits = logits.view(-1, self.num_edge_types)
        sample = nn.functional.gumbel_softmax(flat_logits, tau=self.gumbel_temp, hard=True)
        return sample.view(logits.shape)

    def copy_state(self, decoder_state):
        """Clone a decoder state (a tensor, or the first two tensors of a tuple/list)."""
        if isinstance(decoder_state, (tuple, list)):
            return (decoder_state[0].clone(), decoder_state[1].clone())
        return decoder_state.clone()

    def predict_future(self, inputs, prediction_steps, return_edges=False, return_everything=False):
        """Encode ``inputs``, warm the decoder up on them and roll out ``prediction_steps`` steps.

        Returns the rolled-out predictions (prefixed by the warm-up predictions
        when ``return_everything`` is set), plus the sampled edges when
        ``return_edges`` is set.
        """
        edges = self._hard_edges(self.encoder(inputs)['logits'])
        # Teacher-forced pass over all but the last frame to obtain the decoder state.
        warmup_preds, decoder_state = self.decoder(
            inputs[:, :-1], edges,
            teacher_forcing=True, teacher_forcing_steps=-1, return_state=True)
        last_frame = inputs[:, -1].unsqueeze(1)
        predictions = self.decoder(
            last_frame, edges,
            prediction_steps=prediction_steps, teacher_forcing=False, state=decoder_state)
        if return_everything:
            predictions = torch.cat([warmup_preds, predictions], dim=1)
        if not return_edges:
            return predictions
        return predictions, edges

    def merge_states(self, states):
        """Concatenate decoder states along the batch axis (element-wise for tuple/list states)."""
        if isinstance(states[0], (tuple, list)):
            first = torch.cat([s[0] for s in states], dim=0)
            second = torch.cat([s[1] for s in states], dim=0)
            return (first, second)
        return torch.cat(states, dim=0)

    def predict_future_fixedwindow(self, inputs, burn_in_steps, prediction_steps, batch_size, return_edges=False):
        """Predict ``prediction_steps`` ahead from every frame after the burn-in.

        Edges are inferred once from the first ``burn_in_steps`` frames. Start
        frames are then processed in groups of ``batch_size``: each start frame
        gets one ground-truth-driven step, and the group is rolled out together
        for the remaining steps.
        NOTE(review): edges are expanded along the batch axis, which appears to
        assume a single input sequence -- confirm with callers.

        Returns the stacked rollouts with a leading axis of size 1, plus the
        edges when ``return_edges`` is set.
        """
        burn_in_inputs = inputs[:, :burn_in_steps]
        edges = self._hard_edges(self.encoder(burn_in_inputs)['logits'])
        _, decoder_state = self.decoder(
            burn_in_inputs[:, :-1], edges,
            teacher_forcing=True, teacher_forcing_steps=-1, return_state=True)

        total_steps = inputs.size(1)
        window_results = []
        for window_start in range(burn_in_steps - 1, total_steps - 1, batch_size):
            first_step_preds = []
            saved_states = []
            for offset in range(batch_size):
                frame = window_start + offset
                if frame >= total_steps:
                    break
                # One step driven by the ground-truth frame; the state carries over.
                step_pred, decoder_state = self.decoder(
                    inputs[:, frame:frame+1], edges,
                    teacher_forcing=False, prediction_steps=1, return_state=True, state=decoder_state)
                first_step_preds.append(step_pred)
                saved_states.append(self.copy_state(decoder_state))

            rollout_state = self.merge_states(saved_states)
            rollout_preds = torch.cat(first_step_preds, 0)
            rollout = [rollout_preds]
            rollout_edges = edges.expand(rollout_preds.size(0), -1, -1)
            # Remaining steps feed the decoder its own previous prediction.
            for _ in range(prediction_steps - 1):
                rollout_preds, rollout_state = self.decoder(
                    rollout_preds, rollout_edges,
                    teacher_forcing=False, prediction_steps=1, return_state=True, state=rollout_state)
                rollout.append(rollout_preds)
            window_results.append(torch.cat(rollout, dim=1))

        result = torch.cat(window_results, dim=0).unsqueeze(0)
        if not return_edges:
            return result
        return result, edges

    
class DynamicNRI(BaseNRI):
    """NRI variant that re-runs the encoder (carrying its state) at every step.

    Relies on attributes provided outside this class: ``self.encoder``
    (called as ``encoder(x)`` / ``encoder(x, state)`` and returning a dict
    with ``'logits'`` and ``'state'``), ``self.decoder``,
    ``self.num_edge_types`` and ``self.gumbel_temp``.
    """

    def _hard_gumbel_sample(self, logits):
        """Return a hard Gumbel-softmax sample shaped like ``logits``."""
        flat_logits = logits.view(-1, self.num_edge_types)
        sampled = nn.functional.gumbel_softmax(
            flat_logits, tau=self.gumbel_temp, hard=True)
        return sampled.view(logits.shape)

    def predict_future(self, inputs, prediction_steps, return_edges=False, return_everything=False):
        """Teacher-force over ``inputs``, then roll out ``prediction_steps`` steps.

        The first returned step is the decoder's last burn-in prediction;
        the remaining ``prediction_steps - 1`` steps feed predictions back in.

        Args:
            inputs: observed sequence; time is dimension 1.
            prediction_steps: number of predicted steps to return.
            return_edges: also return the sampled edges stacked on dim 1.
            return_everything: prepend the burn-in predictions (all but the
                last, which is already the first rollout step).

        Returns:
            ``preds`` or ``(preds, edges)`` when ``return_edges`` is set.
        """
        encoded = self.encoder(inputs)
        encoder_state = encoded['state']
        burn_in_edges = self._hard_gumbel_sample(encoded['logits'])
        burn_in_predictions, decoder_state = self.decoder(
            inputs, burn_in_edges,
            teacher_forcing=True, teacher_forcing_steps=-1, return_state=True)

        prev_preds = burn_in_predictions[:, -1].unsqueeze(1)
        preds = [prev_preds]
        all_edges = [burn_in_edges]
        for _ in range(prediction_steps - 1):
            encoded = self.encoder(prev_preds, encoder_state)
            encoder_state = encoded['state']
            edges = self._hard_gumbel_sample(encoded['logits'])
            if return_edges:
                all_edges.append(edges)
            prev_preds, decoder_state = self.decoder(
                prev_preds, edges,
                teacher_forcing=False, prediction_steps=1,
                return_state=True, state=decoder_state)
            preds.append(prev_preds)

        preds = torch.cat(preds, dim=1)
        if return_everything:
            preds = torch.cat([burn_in_predictions[:, :-1], preds], dim=1)
        if not return_edges:
            return preds
        return preds, torch.stack(all_edges, dim=1)

    def copy_states(self, encoder_state, decoder_state):
        """Clone both recurrent states so later updates cannot alias them.

        A tuple/list state is cloned as a 2-tuple of its first two entries;
        anything else is cloned directly.
        """
        def _clone(state):
            if isinstance(state, (tuple, list)):
                return (state[0].clone(), state[1].clone())
            return state.clone()

        return _clone(encoder_state), _clone(decoder_state)

    def merge_states(self, states):
        """Concatenate a list of states along dim 0 (pairwise for tuple states)."""
        if not isinstance(states[0], (tuple, list)):
            return torch.cat(states, dim=0)
        first = torch.cat([state[0] for state in states], dim=0)
        second = torch.cat([state[1] for state in states], dim=0)
        return (first, second)

    def predict_future_fixedwindow(self, inputs, burn_in_steps, prediction_steps, batch_size, return_edges=False):
        """Roll out ``prediction_steps`` predictions from each start index.

        After burning in on ``inputs[:, :burn_in_steps-1]``, ground-truth
        frames are consumed one at a time; the encoder/decoder states after
        each frame are snapshotted, merged along dim 0 in groups of up to
        ``batch_size`` start indices, and rolled forward together.

        Returns:
            Predictions concatenated over start indices with a leading dim
            of 1 added, plus the matching sampled edges when
            ``return_edges`` is set.
        """
        burn_in_inputs = inputs[:, :burn_in_steps-1]
        encoded = self.encoder(burn_in_inputs)
        encoder_state = encoded['state']
        burn_in_edges = self._hard_gumbel_sample(encoded['logits'])
        _, decoder_state = self.decoder(
            burn_in_inputs, burn_in_edges,
            teacher_forcing=True, teacher_forcing_steps=-1, return_state=True)

        total_steps = inputs.size(1)
        all_timestep_preds = []
        all_timestep_edges = []
        for window_ind in range(burn_in_steps - 1, total_steps - 1, batch_size):
            current_batch_preds = []
            current_batch_edges = []
            encoder_states = []
            decoder_states = []
            # Consume ground truth for every start index in this group.
            for tmp_ind in range(window_ind, window_ind + batch_size):
                if tmp_ind >= total_steps:
                    break
                predictions = inputs[:, tmp_ind:tmp_ind+1]
                encoded = self.encoder(predictions, encoder_state)
                encoder_state = encoded['state']
                edges = self._hard_gumbel_sample(encoded['logits'])
                predictions, decoder_state = self.decoder(
                    predictions, edges,
                    teacher_forcing=False, prediction_steps=1,
                    return_state=True, state=decoder_state)
                current_batch_preds.append(predictions)
                if return_edges:
                    current_batch_edges.append(edges)
                snap_encoder, snap_decoder = self.copy_states(encoder_state, decoder_state)
                encoder_states.append(snap_encoder)
                decoder_states.append(snap_decoder)

            batch_encoder_state = self.encoder.merge_states(encoder_states)
            batch_decoder_state = self.merge_states(decoder_states)
            current_batch_preds = torch.cat(current_batch_preds, 0)
            current_timestep_preds = [current_batch_preds]
            if return_edges:
                current_timestep_edges = [torch.cat(current_batch_edges, 0)]

            # Free-running rollout for all start indices of the group at once.
            for _ in range(prediction_steps - 1):
                encoded = self.encoder(current_batch_preds, batch_encoder_state)
                batch_encoder_state = encoded['state']
                edges = self._hard_gumbel_sample(encoded['logits'])
                if return_edges:
                    current_timestep_edges.append(edges)
                current_batch_preds, batch_decoder_state = self.decoder(
                    current_batch_preds, edges,
                    teacher_forcing=False, prediction_steps=1,
                    return_state=True, state=batch_decoder_state)
                current_timestep_preds.append(current_batch_preds)

            all_timestep_preds.append(torch.cat(current_timestep_preds, dim=1))
            if return_edges:
                all_timestep_edges.append(torch.cat(current_timestep_edges, dim=1))

        result = torch.cat(all_timestep_preds, dim=0)
        if not return_edges:
            return result.unsqueeze(0)
        edges = torch.cat(all_timestep_edges, dim=0)
        return result.unsqueeze(0), edges.unsqueeze(0)


In [6]:
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np

import math


class DNRI(nn.Module):
    """Dynamic NRI model: a ``DNRI_Encoder`` plus a step-wise decoder.

    The encoder yields per-timestep prior and posterior edge logits; the
    decoder consumes one timestep and one sampled edge tensor at a time.
    Inputs are indexed as ``inputs[:, step]`` with the variable axis at
    dimension 2, i.e. 4-D tensors laid out time-major after the batch axis.
    """

    def __init__(self, params):
        """Build the sub-modules and read hyper-parameters from ``params``.

        Args:
            params: dict of settings. ``'num_vars'`` is required here; the
                encoder/decoder constructors read their own keys. When
                ``'add_uniform_prior'`` is truthy, ``'gpu'`` must be present.
        """
        super(DNRI, self).__init__()
        # Model Params
        self.num_vars = params['num_vars']
        self.encoder = DNRI_Encoder(params)
        decoder_type = params.get('decoder_type', None)
        if decoder_type == 'ref_mlp':
            self.decoder = DNRI_MLP_Decoder(params)
        else:
            self.decoder = DNRI_Decoder(params)
        self.num_edge_types = params.get('num_edge_types')

        # Training params
        self.gumbel_temp = params.get('gumbel_temp')
        self.train_hard_sample = params.get('train_hard_sample')
        self.teacher_forcing_steps = params.get('teacher_forcing_steps', -1)

        self.normalize_kl = params.get('normalize_kl', False)
        self.normalize_kl_per_var = params.get('normalize_kl_per_var', False)
        self.normalize_nll = params.get('normalize_nll', False)
        self.normalize_nll_per_var = params.get('normalize_nll_per_var', False)
        self.kl_coef = params.get('kl_coef', 1.)
        self.nll_loss_type = params.get('nll_loss_type', 'crossent')
        self.prior_variance = params.get('prior_variance')
        self.timesteps = params.get('timesteps', 0)
        self.burn_in_steps = params.get('train_burn_in_steps')
        self.teacher_forcing_prior = params.get('teacher_forcing_prior', False)
        self.val_teacher_forcing_steps = params.get('val_teacher_forcing_steps', -1)
        self.add_uniform_prior = params.get('add_uniform_prior')
        if self.add_uniform_prior:
            no_edge_prior = params.get('no_edge_prior')
            if no_edge_prior is not None:
                # Edge type 0 gets `no_edge_prior`; the rest share the remainder.
                prior = np.zeros(self.num_edge_types)
                prior.fill((1 - no_edge_prior)/(self.num_edge_types - 1))
                prior[0] = no_edge_prior
            else:
                print("USING UNIFORM PRIOR")
                prior = np.zeros(self.num_edge_types)
                prior.fill(1.0/self.num_edge_types)
            # Two leading singleton dims so the prior broadcasts in
            # kl_categorical_avg.
            log_prior = torch.FloatTensor(np.log(prior)).unsqueeze(0).unsqueeze(0)
            if params['gpu']:
                log_prior = log_prior.cuda(non_blocking=True)
            self.log_prior = log_prior
            if no_edge_prior is not None:
                print("USING NO EDGE PRIOR: ",self.log_prior)

    def single_step_forward(self, inputs, decoder_hidden, edge_logits, hard_sample):
        """Sample edges from ``edge_logits`` and advance the decoder one step.

        Returns:
            ``(predictions, decoder_hidden, edges)``.
        """
        old_shape = edge_logits.shape
        edges = model_utils.gumbel_softmax(
            edge_logits.reshape(-1, self.num_edge_types),
            tau=self.gumbel_temp,
            hard=hard_sample).view(old_shape)
        predictions, decoder_hidden = self.decoder(inputs, decoder_hidden, edges)
        return predictions, decoder_hidden, edges

    def calculate_loss(self, inputs, is_train=False, teacher_forcing=True, return_edges=False, return_logits=False, use_prior_logits=False):
        """Compute ``nll + kl_coef * kl`` for one batch of sequences.

        Args:
            inputs: 4-D tensor; step ``t`` predicts ``inputs[:, t + 1]``.
            is_train: selects the train/val teacher-forcing horizon and
                whether soft samples are allowed (``train_hard_sample``).
            teacher_forcing: feed ground truth instead of predictions while
                within the teacher-forcing horizon (-1 means always).
            return_edges: additionally return the edges sampled at the LAST
                step only (kept as-is for existing callers).
            return_logits: additionally return the posterior logits and all
                predictions (ignored when ``return_edges`` is set).
            use_prior_logits: decode from prior rather than posterior logits.

        Returns:
            ``(loss, loss_nll, loss_kl)`` plus the extras described above;
            ``loss`` is reduced with ``.mean()``.
        """
        decoder_hidden = self.decoder.get_initial_hidden(inputs)
        num_time_steps = inputs.size(1)
        all_predictions = []
        hard_sample = (not is_train) or self.train_hard_sample
        prior_logits, posterior_logits, _ = self.encoder(inputs[:, :-1])
        if is_train:
            teacher_forcing_steps = self.teacher_forcing_steps
        else:
            teacher_forcing_steps = self.val_teacher_forcing_steps
        step_logits = prior_logits if use_prior_logits else posterior_logits
        for step in range(num_time_steps-1):
            within_horizon = teacher_forcing_steps == -1 or step < teacher_forcing_steps
            # The first step always has to start from ground truth.
            if (teacher_forcing and within_horizon) or step == 0:
                current_inputs = inputs[:, step]
            else:
                current_inputs = predictions
            predictions, decoder_hidden, edges = self.single_step_forward(
                current_inputs, decoder_hidden, step_logits[:, step], hard_sample)
            all_predictions.append(predictions)
        all_predictions = torch.stack(all_predictions, dim=1)
        target = inputs[:, 1:, :, :]
        loss_nll = self.nll(all_predictions, target)
        prob = F.softmax(posterior_logits, dim=-1)
        loss_kl = self.kl_categorical_learned(prob, prior_logits)
        if self.add_uniform_prior:
            loss_kl = 0.5*loss_kl + 0.5*self.kl_categorical_avg(prob)
        loss = loss_nll + self.kl_coef*loss_kl
        loss = loss.mean()

        if return_edges:
            return loss, loss_nll, loss_kl, edges
        elif return_logits:
            return loss, loss_nll, loss_kl, posterior_logits, all_predictions
        else:
            return loss, loss_nll, loss_kl

    def predict_future(self, inputs, prediction_steps, return_edges=False, return_everything=False):
        """Burn in on ``inputs`` with prior logits, then free-run the model.

        Args:
            inputs: observed sequence used as burn-in.
            prediction_steps: number of free-running steps to generate.
            return_edges: also return the sampled edges stacked on dim 1.
            return_everything: include the burn-in predictions/edges too.

        Returns:
            Stacked predictions, or ``(predictions, edges)``.
        """
        burn_in_timesteps = inputs.size(1)
        decoder_hidden = self.decoder.get_initial_hidden(inputs)
        all_predictions = []
        all_edges = []
        prior_logits, _, prior_hidden = self.encoder(inputs[:, :-1])
        for step in range(burn_in_timesteps-1):
            current_inputs = inputs[:, step]
            current_edge_logits = prior_logits[:, step]
            predictions, decoder_hidden, edges = self.single_step_forward(current_inputs, decoder_hidden, current_edge_logits, True)
            if return_everything:
                all_edges.append(edges)
                all_predictions.append(predictions)
        # Rollout starts from the last observed frame.
        predictions = inputs[:, burn_in_timesteps-1]
        for step in range(prediction_steps):
            current_edge_logits, prior_hidden = self.encoder.single_step_forward(predictions, prior_hidden)
            predictions, decoder_hidden, edges = self.single_step_forward(predictions, decoder_hidden, current_edge_logits, True)
            all_predictions.append(predictions)
            all_edges.append(edges)

        predictions = torch.stack(all_predictions, dim=1)
        if return_edges:
            edges = torch.stack(all_edges, dim=1)
            return predictions, edges
        else:
            return predictions

    def copy_states(self, state):
        """Clone a recurrent state (2-tuple for tuple/list states)."""
        if isinstance(state, (tuple, list)):
            return (state[0].clone(), state[1].clone())
        return state.clone()

    def merge_hidden(self, hidden):
        """Concatenate a list of states along dim 0 (pairwise for tuples)."""
        if isinstance(hidden[0], (tuple, list)):
            result0 = torch.cat([x[0] for x in hidden], dim=0)
            result1 = torch.cat([x[1] for x in hidden], dim=0)
            return (result0, result1)
        return torch.cat(hidden, dim=0)

    def predict_future_fixedwindow(self, inputs, burn_in_steps, prediction_steps, batch_size, return_edges=False):
        """Roll out ``prediction_steps`` predictions from every start index.

        Ground-truth frames from ``burn_in_steps - 1`` onward are consumed
        one at a time; the prior/decoder states after each frame are
        snapshotted, merged along dim 0 in groups of up to ``batch_size``,
        and rolled forward together.

        Returns:
            Predictions with a leading dim of 1 added; with ``return_edges``
            also the edge *logits* (moved to CPU), not sampled edges.
        """
        print("INPUT SHAPE: ",inputs.shape)
        prior_logits, _, prior_hidden = self.encoder(inputs[:, :-1])
        decoder_hidden = self.decoder.get_initial_hidden(inputs)
        for step in range(burn_in_steps-1):
            current_inputs = inputs[:, step]
            current_edge_logits = prior_logits[:, step]
            predictions, decoder_hidden, _ = self.single_step_forward(current_inputs, decoder_hidden, current_edge_logits, True)
        all_timestep_preds = []
        all_timestep_edges = []
        for window_ind in range(burn_in_steps - 1, inputs.size(1)-1, batch_size):
            current_batch_preds = []
            current_batch_edges = []
            prior_states = []
            decoder_states = []
            for step in range(batch_size):
                if window_ind + step >= inputs.size(1):
                    break
                predictions = inputs[:, window_ind + step]
                current_edge_logits, prior_hidden = self.encoder.single_step_forward(predictions, prior_hidden)
                predictions, decoder_hidden, _ = self.single_step_forward(predictions, decoder_hidden, current_edge_logits, True)
                current_batch_preds.append(predictions)
                # Snapshot so the batched rollout below cannot alias the
                # running states.
                prior_states.append(self.encoder.copy_states(prior_hidden))
                decoder_states.append(self.copy_states(decoder_hidden))
                if return_edges:
                    current_batch_edges.append(current_edge_logits.cpu())
            batch_prior_hidden = self.encoder.merge_hidden(prior_states)
            batch_decoder_hidden = self.merge_hidden(decoder_states)
            current_batch_preds = torch.cat(current_batch_preds, 0)
            current_timestep_preds = [current_batch_preds]
            if return_edges:
                current_batch_edges = torch.cat(current_batch_edges, 0)
                current_timestep_edges = [current_batch_edges]
            for step in range(prediction_steps - 1):
                current_batch_edge_logits, batch_prior_hidden = self.encoder.single_step_forward(current_batch_preds, batch_prior_hidden)
                current_batch_preds, batch_decoder_hidden, _ = self.single_step_forward(current_batch_preds, batch_decoder_hidden, current_batch_edge_logits, True)
                current_timestep_preds.append(current_batch_preds)
                if return_edges:
                    current_timestep_edges.append(current_batch_edge_logits.cpu())
            all_timestep_preds.append(torch.stack(current_timestep_preds, dim=1))
            if return_edges:
                all_timestep_edges.append(torch.stack(current_timestep_edges, dim=1))
        result = torch.cat(all_timestep_preds, dim=0)
        if return_edges:
            edge_result = torch.cat(all_timestep_edges, dim=0)
            return result.unsqueeze(0), edge_result.unsqueeze(0)
        else:
            return result.unsqueeze(0)

    def nll(self, preds, target):
        """Dispatch to the reconstruction loss named by ``nll_loss_type``.

        Raises:
            ValueError: if ``nll_loss_type`` is not ``'crossent'``,
                ``'gaussian'`` or ``'poisson'`` (previously this returned
                ``None`` and failed later with an unrelated error).
        """
        if self.nll_loss_type == 'crossent':
            return self.nll_crossent(preds, target)
        elif self.nll_loss_type == 'gaussian':
            return self.nll_gaussian(preds, target)
        elif self.nll_loss_type == 'poisson':
            return self.nll_poisson(preds, target)
        raise ValueError('Unknown nll_loss_type: %r' % (self.nll_loss_type,))

    def nll_gaussian(self, preds, target, add_const=False):
        """Gaussian NLL with fixed variance ``self.prior_variance``.

        ``add_const`` is accepted but currently ignored; the constant term
        is only added in the ``normalize_nll`` branch.
        """
        neg_log_p = ((preds - target) ** 2 / (2 * self.prior_variance))
        const = 0.5 * np.log(2 * np.pi * self.prior_variance)
        if self.normalize_nll_per_var:
            return neg_log_p.sum() / (target.size(0) * target.size(2))
        elif self.normalize_nll:
            return (neg_log_p.sum(-1) + const).view(preds.size(0), -1).mean(dim=1)
        else:
            return neg_log_p.view(target.size(0), -1).sum() / (target.size(1))

    def nll_crossent(self, preds, target):
        """Per-sample BCE-with-logits, averaged or summed over the rest."""
        per_elem = nn.BCEWithLogitsLoss(reduction='none')(preds, target).view(preds.size(0), -1)
        if self.normalize_nll:
            return per_elem.mean(dim=1)
        return per_elem.sum(dim=1)

    def nll_poisson(self, preds, target):
        """Per-sample Poisson NLL, averaged or summed over the rest."""
        per_elem = nn.PoissonNLLLoss(reduction='none')(preds, target).view(preds.size(0), -1)
        if self.normalize_nll:
            return per_elem.mean(dim=1)
        return per_elem.sum(dim=1)

    def kl_categorical_learned(self, preds, prior_logits):
        """KL between posterior probabilities and the learned prior logits."""
        log_prior = nn.LogSoftmax(dim=-1)(prior_logits)
        # The epsilon guards log(0) for hard zero probabilities.
        kl_div = preds*(torch.log(preds + 1e-16) - log_prior)
        if self.normalize_kl:
            return kl_div.sum(-1).view(preds.size(0), -1).mean(dim=1)
        elif self.normalize_kl_per_var:
            return kl_div.sum() / (self.num_vars * preds.size(0))
        else:
            return kl_div.view(preds.size(0), -1).sum(dim=1)

    def kl_categorical_avg(self, preds, eps=1e-16):
        """KL between the dim-2 average of ``preds`` and ``self.log_prior``.

        Requires ``self.log_prior``, which is only set when
        ``add_uniform_prior`` was enabled at construction.
        """
        avg_preds = preds.mean(dim=2)
        kl_div = avg_preds*(torch.log(avg_preds+eps) - self.log_prior)
        if self.normalize_kl:
            return kl_div.sum(-1).view(preds.size(0), -1).mean(dim=1)
        elif self.normalize_kl_per_var:
            return kl_div.sum() / (self.num_vars * preds.size(0))
        else:
            return kl_div.view(preds.size(0), -1).sum(dim=1)

    def save(self, path):
        """Write the model's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` saved by :meth:`save`.

        Tensors are deserialized onto the CPU so a checkpoint written on a
        GPU machine also loads on a CPU-only one; ``load_state_dict`` then
        copies the values onto whatever device the parameters live on.
        """
        self.load_state_dict(torch.load(path, map_location='cpu'))

class DNRI_Encoder(nn.Module):
    """Graph encoder that produces both prior and posterior edge logits.

    A node->edge->node->edge MLP stack embeds every ordered node pair; a
    forward RNN over time feeds the prior head, and forward + reverse RNN
    outputs concatenated feed the posterior ("encoder") head.
    """

    def __init__(self, params):
        """Build the MLP stack, the two RNNs and the two output heads.

        Args:
            params: dict with ``'num_vars'``, ``'num_edge_types'``,
                ``'encoder_dropout'``, ``'encoder_hidden'``,
                ``'encoder_rnn_hidden'``, ``'encoder_rnn_type'``
                (``'lstm'`` or ``'gru'``), ``'input_size'``,
                ``'encoder_mlp_num_layers'``, ``'prior_num_layers'`` and,
                for multi-layer heads, ``'encoder_mlp_hidden'`` /
                ``'prior_hidden_size'``.

        Raises:
            ValueError: if ``'encoder_rnn_type'`` is not supported.
        """
        super(DNRI_Encoder, self).__init__()
        num_vars = params['num_vars']
        self.num_vars = num_vars
        self.num_edges = params['num_edge_types']
        self.separate_prior_encoder = params.get('separate_prior_encoder', False)
        # Misspelled attribute kept as an alias for any existing reader.
        self.sepaate_prior_encoder = self.separate_prior_encoder
        no_bn = False
        dropout = params['encoder_dropout']
        # Every ordered pair of distinct nodes is an edge.
        edges = np.ones(num_vars) - np.eye(num_vars)
        self.send_edges = np.where(edges)[0]
        self.recv_edges = np.where(edges)[1]
        self.edge2node_mat = nn.Parameter(torch.FloatTensor(encode_onehot(self.recv_edges).transpose()), requires_grad=False)
        self.save_eval_memory = params.get('encoder_save_eval_memory', False)

        hidden_size = params['encoder_hidden']
        rnn_hidden_size = params['encoder_rnn_hidden']
        rnn_type = params['encoder_rnn_type']
        inp_size = params['input_size']
        self.mlp1 = RefNRIMLP(inp_size, hidden_size, hidden_size, dropout, no_bn=no_bn)
        self.mlp2 = RefNRIMLP(hidden_size * 2, hidden_size, hidden_size, dropout, no_bn=no_bn)
        self.mlp3 = RefNRIMLP(hidden_size, hidden_size, hidden_size, dropout, no_bn=no_bn)
        self.mlp4 = RefNRIMLP(hidden_size * 3, hidden_size, hidden_size, dropout, no_bn=no_bn)

        if rnn_hidden_size is None:
            rnn_hidden_size = hidden_size
        if rnn_type == 'lstm':
            rnn_cls = nn.LSTM
        elif rnn_type == 'gru':
            rnn_cls = nn.GRU
        else:
            raise ValueError('Unsupported encoder_rnn_type: %r' % (rnn_type,))
        self.forward_rnn = rnn_cls(hidden_size, rnn_hidden_size, batch_first=True)
        self.reverse_rnn = rnn_cls(hidden_size, rnn_hidden_size, batch_first=True)

        num_layers = params['encoder_mlp_num_layers']
        mlp_hidden = None if num_layers == 1 else params['encoder_mlp_hidden']
        self.encoder_fc_out = self._build_head(2*rnn_hidden_size, mlp_hidden, num_layers)

        num_layers = params['prior_num_layers']
        mlp_hidden = None if num_layers == 1 else params['prior_hidden_size']
        self.prior_fc_out = self._build_head(rnn_hidden_size, mlp_hidden, num_layers)

        self.init_weights()

    def _build_head(self, in_size, hidden_size, num_layers):
        """Return a Linear (1 layer) or an ELU MLP mapping to edge-type logits."""
        if num_layers == 1:
            return nn.Linear(in_size, self.num_edges)
        layers = [nn.Linear(in_size, hidden_size), nn.ELU(inplace=True)]
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ELU(inplace=True))
        layers.append(nn.Linear(hidden_size, self.num_edges))
        return nn.Sequential(*layers)

    def init_weights(self):
        """Xavier-normal weights and 0.1 biases for every ``nn.Linear``."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)

    def node2edge(self, node_embeddings):
        """Concatenate sender and receiver embeddings for every edge.

        Accepts ``[batch, num_vars, num_timesteps, embed]`` or the 3-D form
        without the time axis.
        """
        if len(node_embeddings.shape) == 4:
            send_embed = node_embeddings[:, self.send_edges, :, :]
            recv_embed = node_embeddings[:, self.recv_edges, :, :]
        else:
            send_embed = node_embeddings[:, self.send_edges, :]
            recv_embed = node_embeddings[:, self.recv_edges, :]
        return torch.cat([send_embed, recv_embed], dim=-1)

    def edge2node(self, edge_embeddings):
        """Average incoming edge embeddings per receiving node."""
        if len(edge_embeddings.shape) == 4:
            old_shape = edge_embeddings.shape
            # Fold time into the feature axis so one matmul handles all steps.
            tmp_embeddings = edge_embeddings.view(old_shape[0], old_shape[1], -1)
            incoming = torch.matmul(self.edge2node_mat, tmp_embeddings).view(old_shape[0], -1, old_shape[2], old_shape[3])
        else:
            incoming = torch.matmul(self.edge2node_mat, edge_embeddings)
        return incoming/(self.num_vars-1) #TODO: do we want this average?

    def copy_states(self, prior_state):
        """Clone a recurrent state (2-tuple for tuple/list states)."""
        if isinstance(prior_state, (tuple, list)):
            return (prior_state[0].clone(), prior_state[1].clone())
        return prior_state.clone()

    def merge_hidden(self, hidden):
        """Concatenate a list of states along dim 0 (pairwise for tuples)."""
        if isinstance(hidden[0], (tuple, list)):
            result0 = torch.cat([x[0] for x in hidden], dim=0)
            result1 = torch.cat([x[1] for x in hidden], dim=0)
            return (result0, result1)
        return torch.cat(hidden, dim=0)

    def _edge_embeddings(self, x):
        """Run the shared node->edge->node->edge MLP stack with skip connection."""
        x = self.mlp1(x)  # 2-layer ELU net per node
        x = self.node2edge(x)
        x = self.mlp2(x)
        x_skip = x
        x = self.edge2node(x)
        x = self.mlp3(x)
        x = self.node2edge(x)
        x = torch.cat((x, x_skip), dim=-1)  # Skip connection
        return self.mlp4(x)

    def forward(self, inputs):
        """Encode a whole sequence.

        Args:
            inputs: ``[batch, num_timesteps, num_vars, input_size]``.

        Returns:
            ``(prior_logits, posterior_logits, prior_state)``; both logit
            tensors are ``[batch, num_timesteps, num_edges, num_edge_types]``
            and ``prior_state`` is the forward RNN's final state.

        In eval mode with ``save_eval_memory`` set, timesteps are processed
        one at a time with intermediates parked on the CPU.
        """
        if self.training or not self.save_eval_memory:
            return self._forward_full(inputs)
        return self._forward_low_memory(inputs)

    def _forward_full(self, inputs):
        """Process all timesteps at once."""
        x = inputs.transpose(2, 1).contiguous()
        # New shape: [num_sims, num_atoms, num_timesteps, num_dims]
        x = self._edge_embeddings(x)

        # At this point, x should be [batch, num_edges, num_timesteps, hidden_size]
        old_shape = x.shape
        timesteps = old_shape[2]
        x = x.contiguous().view(-1, timesteps, old_shape[3])
        forward_x, prior_state = self.forward_rnn(x)
        # Run the second RNN back-to-front, then restore chronological order.
        reverse_x, _ = self.reverse_rnn(x.flip(1))
        reverse_x = reverse_x.flip(1)

        #x: [batch*num_edges, num_timesteps, hidden_size]
        prior_result = self.prior_fc_out(forward_x).view(old_shape[0], old_shape[1], timesteps, self.num_edges).transpose(1,2).contiguous()
        combined_x = torch.cat([forward_x, reverse_x], dim=-1)
        encoder_result = self.encoder_fc_out(combined_x).view(old_shape[0], old_shape[1], timesteps, self.num_edges).transpose(1,2).contiguous()
        return prior_result, encoder_result, prior_state

    def _forward_low_memory(self, inputs):
        """Process one timestep at a time, offloading intermediates to CPU."""
        device = inputs.device
        num_timesteps = inputs.size(1)
        all_x = []
        all_forward_x = []
        all_prior_result = []
        prior_state = None
        for timestep in range(num_timesteps):
            x = self._edge_embeddings(inputs[:, timestep])
            # x: [batch, num_edges, hidden_size]
            old_shape = x.shape
            x = x.contiguous().view(-1, 1, old_shape[-1])
            forward_x, prior_state = self.forward_rnn(x, prior_state)
            all_x.append(x.cpu())
            all_forward_x.append(forward_x.cpu())
            all_prior_result.append(self.prior_fc_out(forward_x).view(old_shape[0], 1, old_shape[1], self.num_edges).cpu())
        reverse_state = None
        # Filled by index: the reverse RNN walks time backwards, but the
        # output must be chronological to match the full path.
        all_encoder_result = [None] * num_timesteps
        for timestep in range(num_timesteps-1, -1, -1):
            x = all_x[timestep].to(device)
            reverse_x, reverse_state = self.reverse_rnn(x, reverse_state)
            forward_x = all_forward_x[timestep].to(device)
            combined_x = torch.cat([forward_x, reverse_x], dim=-1)
            all_encoder_result[timestep] = self.encoder_fc_out(combined_x).view(inputs.size(0), 1, -1, self.num_edges)
        prior_result = torch.cat(all_prior_result, dim=1).to(device, non_blocking=True)
        encoder_result = torch.cat(all_encoder_result, dim=1).to(device, non_blocking=True)
        return prior_result, encoder_result, prior_state

    def single_step_forward(self, inputs, prior_state):
        """Advance the prior by one timestep.

        Args:
            inputs: ``[batch, num_vars, input_size]``.
            prior_state: forward-RNN state; a tuple/list pair (LSTM) or a
                single tensor (GRU). Its first two dims are flattened before
                the RNN call, which lets states merged along dim 0 by
                ``merge_hidden`` be used directly.

        Returns:
            ``(prior_logits, new_state)`` with ``new_state`` restored to the
            incoming state's shape (and returned as a tuple for pair states).
        """
        x = self._edge_embeddings(inputs)

        old_shape = x.shape
        x = x.contiguous().view(-1, 1, old_shape[-1])

        is_pair = isinstance(prior_state, (tuple, list))
        reference = prior_state[0] if is_pair else prior_state
        old_prior_shape = reference.shape
        flat_shape = (1, old_prior_shape[0]*old_prior_shape[1], old_prior_shape[2])
        if is_pair:
            rnn_state = (prior_state[0].view(flat_shape), prior_state[1].view(flat_shape))
        else:
            rnn_state = prior_state.view(flat_shape)

        x, rnn_state = self.forward_rnn(x, rnn_state)
        prior_result = self.prior_fc_out(x).view(old_shape[0], old_shape[1], self.num_edges)
        if is_pair:
            new_state = (rnn_state[0].view(old_prior_shape), rnn_state[1].view(old_prior_shape))
        else:
            new_state = rnn_state.view(old_prior_shape)
        return prior_result, new_state

    
class DNRI_Decoder(nn.Module):
    """Recurrent interaction-network decoder with a GRU-style node update.

    Each call to :meth:`forward` consumes one timestep: messages are computed
    per edge type from the hidden states of the edge endpoints, weighted by
    ``edges``, aggregated per receiving node, and used to update the hidden
    state, from which a residual prediction is produced.
    """

    def __init__(self, params):
        """Create the per-edge-type message MLPs, gates and output MLP.

        Args:
            params: dict with ``'num_vars'``, ``'input_size'``, ``'gpu'``,
                ``'decoder_hidden'``, ``'num_edge_types'``, ``'skip_first'``
                and ``'decoder_dropout'``.
        """
        super(DNRI_Decoder, self).__init__()
        self.num_vars = num_vars = params['num_vars']
        input_size = params['input_size']
        self.gpu = params['gpu']
        n_hid = params['decoder_hidden']
        edge_types = params['num_edge_types']
        skip_first = params['skip_first']
        out_size = params['input_size']
        do_prob = params['decoder_dropout']

        self.msg_fc1 = nn.ModuleList(
            [nn.Linear(2*n_hid, n_hid) for _ in range(edge_types)]
        )
        self.msg_fc2 = nn.ModuleList(
            [nn.Linear(n_hid, n_hid) for _ in range(edge_types)]
        )
        self.msg_out_shape = n_hid
        self.skip_first_edge_type = skip_first

        self.hidden_r = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_i = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_h = nn.Linear(n_hid, n_hid, bias=False)

        self.input_r = nn.Linear(input_size, n_hid, bias=True)
        self.input_i = nn.Linear(input_size, n_hid, bias=True)
        self.input_n = nn.Linear(input_size, n_hid, bias=True)

        self.out_fc1 = nn.Linear(n_hid, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, out_size)

        print('Using learned recurrent interaction net decoder.')

        self.dropout_prob = do_prob

        # Every ordered pair of distinct nodes is an edge.
        edges = np.ones(num_vars) - np.eye(num_vars)
        self.send_edges = np.where(edges)[0]
        self.recv_edges = np.where(edges)[1]
        self.edge2node_mat = nn.Parameter(torch.FloatTensor(encode_onehot(self.recv_edges)), requires_grad=False)

    def get_initial_hidden(self, inputs):
        """Return a zero hidden state ``[inputs.size(0), inputs.size(2), n_hid]``."""
        return torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape, device=inputs.device)

    def init_weights(self):
        """Xavier-normal weights and 0.1 biases for every ``nn.Linear``.

        Layers created with ``bias=False`` (the hidden gates) have no bias
        tensor and are left without one.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.1)

    def forward(self, inputs, hidden, edges):
        """Advance the decoder by one timestep.

        Args:
            inputs: ``[batch, num_vars, input_size]``.
            hidden: ``[batch, num_vars, rnn_hidden]``.
            edges: ``[batch, num_edges, num_edge_types]`` edge weights.

        Returns:
            ``(pred, hidden)``: the residual prediction ``inputs + delta``
            and the updated hidden state.
        """
        # Dropout is only active in training mode.
        dropout_prob = self.dropout_prob if self.training else 0.

        # node2edge
        receivers = hidden[:, self.recv_edges, :]
        senders = hidden[:, self.send_edges, :]

        # pre_msg: [batch, num_edges, 2*msg_out]
        pre_msg = torch.cat([receivers, senders], dim=-1)

        all_msgs = torch.zeros(pre_msg.size(0), pre_msg.size(1),
                               self.msg_out_shape, device=inputs.device)

        num_types = len(self.msg_fc2)
        if self.skip_first_edge_type:
            start_idx = 1
            norm = float(num_types) - 1
        else:
            start_idx = 0
            norm = float(num_types)

        # Run separate MLP for every edge type
        # NOTE: to exclude one edge type, simply offset range by 1
        for i in range(start_idx, num_types):
            msg = torch.tanh(self.msg_fc1[i](pre_msg))
            msg = F.dropout(msg, p=dropout_prob)
            msg = torch.tanh(self.msg_fc2[i](msg))
            msg = msg * edges[:, :, i:i+1]
            all_msgs += msg/norm

        # This step sums all of the messages per node
        agg_msgs = all_msgs.transpose(-2, -1).matmul(self.edge2node_mat).transpose(-2, -1)
        agg_msgs = agg_msgs.contiguous() / (self.num_vars - 1) # Average

        # GRU-style gated aggregation
        inp_r = self.input_r(inputs).view(inputs.size(0), self.num_vars, -1)
        inp_i = self.input_i(inputs).view(inputs.size(0), self.num_vars, -1)
        inp_n = self.input_n(inputs).view(inputs.size(0), self.num_vars, -1)
        r = torch.sigmoid(inp_r + self.hidden_r(agg_msgs))
        i = torch.sigmoid(inp_i + self.hidden_i(agg_msgs))
        n = torch.tanh(inp_n + r*self.hidden_h(agg_msgs))
        hidden = (1 - i)*n + i*hidden

        # Output MLP
        pred = F.dropout(F.relu(self.out_fc1(hidden)), p=dropout_prob)
        pred = F.dropout(F.relu(self.out_fc2(pred)), p=dropout_prob)
        pred = self.out_fc3(pred)

        # Predict the change relative to the current input.
        pred = inputs + pred

        return pred, hidden


class DNRI_MLP_Decoder(nn.Module):
    """Non-recurrent interaction-network decoder.

    For every directed edge a per-edge-type two-layer MLP turns the
    concatenated (receiver, sender) node inputs into a message. Messages are
    weighted by the edge-type values in ``edges``, summed per receiving node,
    concatenated with the node inputs and passed through an output MLP whose
    result is added to the inputs (residual prediction).

    Keys read from ``params``: ``num_vars``, ``num_edge_types``,
    ``decoder_hidden``, ``skip_first``, ``input_size``, ``decoder_dropout``.
    """

    def __init__(self, params):
        super(DNRI_MLP_Decoder, self).__init__()
        num_vars = params['num_vars']
        edge_types = params['num_edge_types']
        n_hid = params['decoder_hidden']
        msg_hid = params['decoder_hidden']
        msg_out = msg_hid  # TODO: make this a param
        skip_first = params['skip_first']
        n_in_node = params['input_size']
        do_prob = params['decoder_dropout']

        # One message MLP (fc1 -> fc2) per edge type.
        in_size = n_in_node
        self.msg_fc1 = nn.ModuleList(
            [nn.Linear(2 * in_size, msg_hid) for _ in range(edge_types)])
        self.msg_fc2 = nn.ModuleList(
            [nn.Linear(msg_hid, msg_out) for _ in range(edge_types)])
        self.msg_out_shape = msg_out
        self.skip_first_edge_type = skip_first

        # Output MLP consumes [node input, aggregated messages].
        out_size = n_in_node
        self.out_fc1 = nn.Linear(in_size + msg_out, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, out_size)

        print('Using learned interaction net decoder.')

        self.dropout_prob = do_prob
        self.num_vars = num_vars
        # Fully connected graph without self loops; np.where enumerates the
        # non-zero (row, col) pairs in row-major order.
        edges = np.ones(num_vars) - np.eye(num_vars)
        self.send_edges = np.where(edges)[0]
        self.recv_edges = np.where(edges)[1]
        # Fixed (non-trainable) matrix used in forward() to sum edge messages
        # per receiving node. Stored as a Parameter so it follows the module
        # across devices and is part of the state dict.
        self.edge2node_mat = nn.Parameter(torch.FloatTensor(encode_onehot(self.recv_edges)), requires_grad=False)

    def get_initial_hidden(self, inputs):
        """Return ``None``: this decoder keeps no recurrent state."""
        return None

    def forward(self, inputs, hidden, edges):
        """Predict the next state of every variable.

        Args:
            inputs: node states, indexed as [batch, num_vars, input_size].
            hidden: ignored; accepted so the signature matches the other
                decoders in this file.
            edges: edge-type weights, indexed as
                [batch, num_vars*(num_vars-1), num_edge_types].

        Returns:
            Tuple ``(prediction, None)`` where ``prediction`` has the same
            shape as ``inputs``.
        """
        # Node2edge: gather both endpoints of every directed edge.
        receivers = inputs[:, self.recv_edges, :]
        senders = inputs[:, self.send_edges, :]
        pre_msg = torch.cat([receivers, senders], dim=-1)

        # Allocate the accumulator directly on the device of the inputs. The
        # legacy torch.cuda.FloatTensor constructor used previously always
        # targeted the *current* CUDA device, which need not be inputs.device.
        all_msgs = torch.zeros(pre_msg.size(0), pre_msg.size(1),
                               self.msg_out_shape, device=inputs.device)

        start_idx = 1 if self.skip_first_edge_type else 0
        # F.dropout defaults to training=True, so dropout is disabled in eval
        # mode by forcing the probability to zero.
        p = self.dropout_prob if self.training else 0

        # Run separate MLP for every edge type
        # NOTE: To exclude one edge type, simply offset range by 1
        for i in range(start_idx, len(self.msg_fc2)):
            msg = F.relu(self.msg_fc1[i](pre_msg))
            msg = F.dropout(msg, p=p)
            msg = F.relu(self.msg_fc2[i](msg))
            msg = msg * edges[:, :, i:i + 1]
            all_msgs += msg

        # Aggregate all msgs to receiver
        agg_msgs = all_msgs.transpose(-2, -1).matmul(self.edge2node_mat).transpose(-2, -1)
        agg_msgs = agg_msgs.contiguous()

        # Skip connection
        aug_inputs = torch.cat([inputs, agg_msgs], dim=-1)

        # Output MLP
        pred = F.dropout(F.relu(self.out_fc1(aug_inputs)), p=p)
        pred = F.dropout(F.relu(self.out_fc2(pred)), p=p)
        pred = self.out_fc3(pred)

        # Predict position/velocity difference
        return inputs + pred, None

In [7]:
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np
import math


class DNRI_DynamicVars(nn.Module):
    """dNRI model for sequences whose set of present variables changes over time.

    Wraps a ``DNRI_DynamicVars_Encoder`` (which produces both prior and
    posterior edge logits) and a ``DNRI_DynamicVars_Decoder``, and provides
    the training loss (reconstruction NLL + KL) and rollout utilities.

    NOTE(review): ``get_prior_posterior``, ``get_edge_probs`` and
    ``predict_future_fixedwindow`` use attributes (``self.prior_model``,
    ``self.fix_encoder_alignment``) and call signatures that nothing in this
    class defines; they look like leftovers from a fixed-variable model and
    are kept unchanged for interface compatibility.
    """

    def __init__(self, params):
        super(DNRI_DynamicVars, self).__init__()
        # Model Params
        self.encoder = DNRI_DynamicVars_Encoder(params)
        self.decoder = DNRI_DynamicVars_Decoder(params)
        self.num_edge_types = params.get('num_edge_types')

        # Training params
        self.gumbel_temp = params.get('gumbel_temp')
        self.train_hard_sample = params.get('train_hard_sample')
        self.teacher_forcing_steps = params.get('teacher_forcing_steps', -1)

        self.normalize_kl = params.get('normalize_kl', False)
        self.normalize_kl_per_var = params.get('normalize_kl_per_var', False)
        self.normalize_nll = params.get('normalize_nll', False)
        self.normalize_nll_per_var = params.get('normalize_nll_per_var', False)
        self.kl_coef = params.get('kl_coef', 1.)
        self.nll_loss_type = params.get('nll_loss_type', 'crossent')
        self.prior_variance = params.get('prior_variance')
        self.timesteps = params.get('timesteps', 0)

        self.burn_in_steps = params.get('train_burn_in_steps')
        self.no_prior = params.get('no_prior', False)
        self.avg_prior = params.get('avg_prior', False)
        self.learned_prior = params.get('use_learned_prior', False)
        self.anneal_teacher_forcing = params.get('anneal_teacher_forcing', False)
        self.teacher_forcing_prior = params.get('teacher_forcing_prior', False)
        self.steps = 0
        self.gpu = params.get('gpu')

    def get_graph_info(self, masks):
        """Build edge index tensors for a fully connected graph without self loops.

        Args:
            masks: tensor whose last dimension enumerates variables; it is
                indexed as ``masks[:, :, edge_index]``.

        Returns:
            ``(send_edges, recv_edges, edge2node_inds, edge_masks)`` where
            ``edge2node_inds[n]`` lists the ``num_vars - 1`` edge positions
            whose receiver is node ``n`` and ``edge_masks`` is the product of
            the masks of both endpoints of every edge.
        """
        num_vars = masks.size(-1)
        device = masks.device
        edges = torch.ones(num_vars, device=device) - torch.eye(num_vars, device=device)
        send_edges, recv_edges = edges.nonzero(as_tuple=True)
        # Column vector of node ids (int64), compared against every edge's receiver.
        node_ids = torch.arange(num_vars, device=device).unsqueeze(1)
        edge2node_inds = (node_ids == recv_edges.unsqueeze(0)).nonzero()[:, 1].contiguous().view(-1, num_vars-1)
        edge_masks = masks[:, :, send_edges]*masks[:, :, recv_edges] #TODO: gotta figure this one out still
        return send_edges, recv_edges, edge2node_inds, edge_masks

    def single_step_forward(self, inputs, node_masks, graph_info, decoder_hidden, edge_logits, hard_sample):
        """Sample edges from ``edge_logits`` and run one decoder step.

        Returns:
            ``(predictions, decoder_hidden, edges)``.
        """
        old_shape = edge_logits.shape
        edges = model_utils.gumbel_softmax(
            edge_logits.reshape(-1, self.num_edge_types), 
            tau=self.gumbel_temp, 
            hard=hard_sample).view(old_shape)
        predictions, decoder_hidden = self.decoder(inputs, decoder_hidden, edges, node_masks, graph_info)
        return predictions, decoder_hidden, edges


    def normalize_inputs(self, inputs, node_masks):
        """Delegate to the encoder's ``normalize_inputs``."""
        return self.encoder.normalize_inputs(inputs, node_masks)

    #profile
    def calculate_loss(self, inputs, node_masks, node_inds, graph_info, is_train=False, teacher_forcing=True, return_edges=False, return_logits=False, use_prior_logits=False, normalized_inputs=None):
        """Compute the training objective for one sequence.

        Args:
            inputs: observations, indexed as [batch, time, vars, features].
            node_masks: per-timestep presence mask, indexed as [batch, time, vars].
            node_inds: passed through to the encoder unchanged.
            graph_info: per-timestep graph info; ``graph_info[0][step]`` is
                handed to the decoder.
            is_train: when False, edges are always sampled with ``hard=True``.
            teacher_forcing: feed ground truth to the decoder (for all steps
                if ``teacher_forcing_steps == -1``, else for the first
                ``teacher_forcing_steps`` steps).
            return_edges: additionally return the edges sampled at the *last*
                decoded step.
            return_logits: additionally return posterior logits and all
                predictions (only consulted when ``return_edges`` is False).
            use_prior_logits: decode with prior instead of posterior logits.
            normalized_inputs: optional precomputed encoder features.

        Returns:
            ``(loss, loss_nll, loss_kl)`` plus the extras described above.
        """
        decoder_hidden = self.decoder.get_initial_hidden(inputs)
        num_time_steps = inputs.size(1)
        all_predictions = []
        hard_sample = (not is_train) or self.train_hard_sample
        prior_logits, posterior_logits, _ = self.encoder(inputs[:, :-1], node_masks[:, :-1], node_inds, graph_info, normalized_inputs)
        if self.anneal_teacher_forcing:
            # NOTE(review): self.train_percent is not set in __init__;
            # presumably the training loop assigns it - confirm.
            teacher_forcing_steps = math.ceil((1 - self.train_percent)*num_time_steps)
        else:
            teacher_forcing_steps = self.teacher_forcing_steps
        decode_logits = prior_logits if use_prior_logits else posterior_logits
        # Logits for all timesteps are packed along dim 1; edge_ind is the
        # offset of the current timestep's edges in that packed dimension.
        edge_ind = 0
        for step in range(num_time_steps-1):
            if (teacher_forcing and (teacher_forcing_steps == -1 or step < teacher_forcing_steps)) or step == 0:
                current_inputs = inputs[:, step]
            else:
                current_inputs = predictions
            current_node_masks = node_masks[:, step]
            # Local name kept distinct from the ``node_inds`` argument so the
            # caller-supplied value is not shadowed inside the loop.
            present_inds = current_node_masks.nonzero()[:, -1]
            num_edges = len(present_inds)*(len(present_inds)-1)
            current_graph_info = graph_info[0][step]
            current_p_logits = decode_logits[:, edge_ind:edge_ind+num_edges]
            if self.gpu:
                current_p_logits = current_p_logits.cuda(non_blocking=True)
            edge_ind += num_edges
            predictions, decoder_hidden, edges = self.single_step_forward(current_inputs, current_node_masks, current_graph_info, decoder_hidden, current_p_logits, hard_sample)
            all_predictions.append(predictions)
        all_predictions = torch.stack(all_predictions, dim=1)
        target = inputs[:, 1:, :, :]
        # A transition only counts when the variable is present at both ends.
        target_masks = ((node_masks[:, :-1] == 1)*(node_masks[:, 1:] == 1)).float()
        loss_nll = self.nll(all_predictions, target, target_masks)
        prob = F.softmax(posterior_logits, dim=-1)
        if self.gpu:
            prob = prob.cuda(non_blocking=True)
            prior_logits = prior_logits.cuda(non_blocking=True)
        loss_kl = self.kl_categorical_learned(prob, prior_logits)
        loss = loss_nll + self.kl_coef*loss_kl
        loss = loss.mean()
        if return_edges:
            return loss, loss_nll, loss_kl, edges
        elif return_logits:
            return loss, loss_nll, loss_kl, posterior_logits, all_predictions
        else:
            return loss, loss_nll, loss_kl

    def get_prior_posterior(self, inputs, student_force=False, burn_in_steps=None):
        # NOTE(review): relies on self.prior_model and on 5-argument /
        # 5-result single_step_forward, neither of which this class defines.
        self.eval()
        posterior_logits = self.encoder(inputs)
        posterior_probs = torch.softmax(posterior_logits, dim=-1)
        prior_hidden = self.prior_model.get_initial_hidden(inputs)
        all_logits = []
        if student_force:
            decoder_hidden = self.decoder.get_initial_hidden(inputs)
            for step in range(burn_in_steps):
                current_inputs= inputs[:, step]
                predictions, prior_hidden, decoder_hidden, _, prior_logits = self.single_step_forward(current_inputs, prior_hidden, decoder_hidden, None, True)
                all_logits.append(prior_logits)
            for step in range(inputs.size(1) - burn_in_steps):
                predictions, prior_hidden, decoder_hidden, _, prior_logits = self.single_step_forward(predictions, prior_hidden, decoder_hidden, None, True)
                all_logits.append(prior_logits)
        else:
            for step in range(inputs.size(1)):
                current_inputs = inputs[:, step]
                prior_logits, prior_hidden = self.prior_model(prior_hidden, current_inputs)
                all_logits.append(prior_logits)
        logits = torch.stack(all_logits, dim=1)
        prior_probs = torch.softmax(logits, dim=-1)
        return prior_probs, posterior_probs

    def get_edge_probs(self, inputs):
        # NOTE(review): relies on self.prior_model, which this class does not define.
        self.eval()
        prior_hidden = self.prior_model.get_initial_hidden(inputs)
        all_logits = []
        for step in range(inputs.size(1)):
            current_inputs = inputs[:, step]
            prior_logits, prior_hidden = self.prior_model(prior_hidden, current_inputs)
            all_logits.append(prior_logits)
        logits = torch.stack(all_logits, dim=1)
        edge_probs = torch.softmax(logits, dim=-1)
        return edge_probs

    def predict_future(self, inputs, masks, node_inds, graph_info, burn_in_masks):
        '''
        Roll the model forward one step at a time and return all predictions.

        Here, we assume the following:
        * inputs contains all of the gt inputs, including for the time steps we're predicting
        * masks keeps track of the variables that are being tracked
        * burn_in_masks is set to 1 whenever we're supposed to feed in that variable's state
          for a given time step

        Returns the predictions for steps 1..T-1 stacked along dim 1.
        '''
        total_timesteps = inputs.size(1)
        prior_hidden = self.encoder.get_initial_hidden(inputs)
        decoder_hidden = self.decoder.get_initial_hidden(inputs)
        predictions = inputs[:, 0]
        preds = []
        for step in range(total_timesteps-1):
            current_masks = masks[:, step]
            current_burn_in_masks = burn_in_masks[:, step].unsqueeze(-1).type(inputs.dtype)
            current_inps = inputs[:, step]
            current_node_inds = node_inds[0][step] #TODO: check what's passed in here
            current_graph_info = graph_info[0][step]
            # Ground truth where the burn-in mask is 1, previous prediction elsewhere.
            encoder_inp = current_burn_in_masks*current_inps + (1-current_burn_in_masks)*predictions
            current_edge_logits, prior_hidden = self.encoder.single_step_forward(encoder_inp, current_masks, current_node_inds, current_graph_info, prior_hidden)
            predictions, decoder_hidden, _ = self.single_step_forward(encoder_inp, current_masks, current_graph_info, decoder_hidden, current_edge_logits, True)
            preds.append(predictions)
        return torch.stack(preds, dim=1)

    def copy_states(self, prior_state, decoder_state):
        """Return cloned copies of the prior and decoder states.

        A tuple/list state is cloned as a 2-tuple of its first two entries;
        any other state is cloned directly.
        """
        if isinstance(prior_state, (tuple, list)):
            current_prior_state = (prior_state[0].clone(), prior_state[1].clone())
        else:
            current_prior_state = prior_state.clone()
        if isinstance(decoder_state, (tuple, list)):
            current_decoder_state = (decoder_state[0].clone(), decoder_state[1].clone())
        else:
            current_decoder_state = decoder_state.clone()
        return current_prior_state, current_decoder_state

    def merge_hidden(self, hidden):
        """Concatenate a list of hidden states along dim 0.

        If the entries are tuples/lists, their first and second components
        are concatenated separately and returned as a 2-tuple.
        """
        if isinstance(hidden[0], (tuple, list)):
            result0 = torch.cat([x[0] for x in hidden], dim=0)
            result1 = torch.cat([x[1] for x in hidden], dim=0)
            return (result0, result1)
        else:
            return torch.cat(hidden, dim=0)

    def predict_future_fixedwindow(self, inputs, burn_in_steps, prediction_steps, batch_size):
        # NOTE(review): relies on self.fix_encoder_alignment and on encoder /
        # single_step_forward call signatures that differ from the ones
        # defined for the dynamic-variable model in this file.
        if self.fix_encoder_alignment:
            prior_logits, _, prior_hidden = self.encoder(inputs)
        else:
            prior_logits, _, prior_hidden = self.encoder(inputs[:, :-1])
        decoder_hidden = self.decoder.get_initial_hidden(inputs)
        for step in range(burn_in_steps-1):
            current_inputs = inputs[:, step]
            current_edge_logits = prior_logits[:, step]
            predictions, decoder_hidden, _ = self.single_step_forward(current_inputs, decoder_hidden, current_edge_logits, True)
        all_timestep_preds = []
        for window_ind in range(burn_in_steps - 1, inputs.size(1)-1, batch_size):
            current_batch_preds = []
            prior_states = []
            decoder_states = []
            for step in range(batch_size):
                if window_ind + step >= inputs.size(1):
                    break
                predictions = inputs[:, window_ind + step] 
                current_edge_logits, prior_hidden = self.encoder.single_step_forward(predictions, prior_hidden)
                predictions, decoder_hidden, _ = self.single_step_forward(predictions, decoder_hidden, current_edge_logits, True)
                current_batch_preds.append(predictions)
                tmp_prior, tmp_decoder = self.copy_states(prior_hidden, decoder_hidden)
                prior_states.append(tmp_prior)
                decoder_states.append(tmp_decoder)
            batch_prior_hidden = self.merge_hidden(prior_states)
            batch_decoder_hidden = self.merge_hidden(decoder_states)
            current_batch_preds = torch.cat(current_batch_preds, 0)
            current_timestep_preds = [current_batch_preds]
            for step in range(prediction_steps - 1):
                current_batch_edge_logits, batch_prior_hidden = self.encoder.single_step_forward(current_batch_preds, batch_prior_hidden)
                current_batch_preds, batch_decoder_hidden, _ = self.single_step_forward(current_batch_preds, batch_decoder_hidden, current_batch_edge_logits, True)
                current_timestep_preds.append(current_batch_preds)
            all_timestep_preds.append(torch.stack(current_timestep_preds, dim=1))
        result =  torch.cat(all_timestep_preds, dim=0)
        return result.unsqueeze(0)

    def nll(self, preds, target, masks):
        """Dispatch to the reconstruction loss selected by ``nll_loss_type``.

        Raises:
            ValueError: if ``nll_loss_type`` is not one of ``'crossent'``,
                ``'gaussian'`` or ``'poisson'`` (previously ``None`` was
                returned silently, which only failed later in the loss sum).
        """
        if self.nll_loss_type == 'crossent':
            return self.nll_crossent(preds, target, masks)
        elif self.nll_loss_type == 'gaussian':
            return self.nll_gaussian(preds, target, masks)
        elif self.nll_loss_type == 'poisson':
            return self.nll_poisson(preds, target, masks)
        raise ValueError('Unknown nll_loss_type: %r' % (self.nll_loss_type,))

    def nll_gaussian(self, preds, target, masks, add_const=False):
        """Masked Gaussian NLL with fixed variance ``prior_variance``.

        Only the ``normalize_nll`` mode is implemented: the per-sample sum is
        divided by the number of unmasked entries (plus 1e-8). ``add_const``
        is accepted but not consulted; the constant term is always added for
        unmasked entries.
        """
        neg_log_p = ((preds - target) ** 2 / (2 * self.prior_variance))*masks.unsqueeze(-1)
        const = 0.5 * np.log(2 * np.pi * self.prior_variance)
        if self.normalize_nll_per_var:
            raise NotImplementedError()
        elif self.normalize_nll:
            return (neg_log_p.sum(-1) + const*masks).view(preds.size(0), -1).sum(dim=-1)/(masks.view(masks.size(0), -1).sum(dim=1)+1e-8)
        else:
            raise NotImplementedError()


    def nll_crossent(self, preds, target, masks):
        """Masked binary cross-entropy on logits, averaged over unmasked entries.

        Only the ``normalize_nll`` mode is implemented.
        """
        if self.normalize_nll:
            loss = nn.BCEWithLogitsLoss(reduction='none')(preds, target)
            # 1e-8 matches nll_gaussian and avoids 0/0 = NaN when a sample has
            # no valid (unmasked) entries.
            return (loss*masks.unsqueeze(-1)).view(preds.size(0), -1).sum(dim=-1)/(masks.view(masks.size(0), -1).sum(dim=1)+1e-8)
        else:
            raise NotImplementedError()

    def nll_poisson(self, preds, target, masks):
        """Masked Poisson NLL, averaged over unmasked entries.

        Only the ``normalize_nll`` mode is implemented.
        """
        if self.normalize_nll:
            loss = nn.PoissonNLLLoss(reduction='none')(preds, target)
            # 1e-8 matches nll_gaussian and avoids 0/0 = NaN when a sample has
            # no valid (unmasked) entries.
            return (loss*masks.unsqueeze(-1)).view(preds.size(0), -1).sum(dim=-1)/(masks.view(masks.size(0), -1).sum(dim=1)+1e-8)
        else:
            raise NotImplementedError()

    def kl_categorical_learned(self, preds, prior_logits):
        """KL(posterior || prior) for categorical edge distributions.

        Args:
            preds: posterior probabilities (last dim sums over edge types).
            prior_logits: unnormalised prior logits of the same shape.

        Only the ``normalize_kl`` mode is implemented: the KL is summed over
        the last dimension and averaged over everything but the batch dim.
        """
        log_prior = nn.LogSoftmax(dim=-1)(prior_logits)
        # 1e-16 keeps log() finite for zero-probability entries.
        kl_div = preds*(torch.log(preds + 1e-16) - log_prior)
        if self.normalize_kl:     
            return kl_div.sum(-1).view(preds.size(0), -1).mean(dim=1)
        elif self.normalize_kl_per_var:
            raise NotImplementedError()
        else:
            raise NotImplementedError()

    def save(self, path):
        """Write the state dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a state dict previously written by :meth:`save`.

        Tensors are mapped to CPU first so that a checkpoint written on a GPU
        machine can be restored on a CPU-only host; ``load_state_dict`` then
        copies the values into the module's existing parameters.

        NOTE: ``torch.load`` unpickles the file - only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(path, map_location='cpu'))


class DNRI_DynamicVars_Encoder(nn.Module):
    """Edge encoder for the dynamic-variable dNRI model.

    Node features are turned into per-edge embeddings by four MLPs with
    node->edge / edge->node passes (``compute_feat_transform``). A forward
    and a reverse LSTM are then run over time per edge; the forward hidden
    state alone feeds the prior head, forward+reverse feed the posterior
    (encoder) head. Here, encoder also produces prior.

    Recurrent state is kept in a table with one row per possible directed
    edge among ``max_num_vars`` variables; at each timestep only the rows of
    the edges between currently present variables are read and updated.

    NOTE(review): ``forward`` / ``single_step_forward`` pass an
    ``(h, c)`` tuple to the RNNs, which matches ``nn.LSTM``; the ``'gru'``
    option builds ``nn.GRU`` modules whose state is a single tensor -
    confirm that configuration is actually usable.
    """

    def __init__(self, params):
        super(DNRI_DynamicVars_Encoder, self).__init__()
        self.num_edges = params['num_edge_types']
        self.smooth_graph = params.get('smooth_graph', False)
        self.gpu_parallel = params.get('gpu_parallel', False)
        self.pool_edges = params.get('pool_edges', False)
        self.separate_prior_encoder = params.get('separate_prior_encoder', False)
        self.gpu = params.get('gpu')
        no_bn = params['no_encoder_bn']
        dropout = params['encoder_dropout']

        hidden_size = params['encoder_hidden']
        rnn_hidden_size = params['encoder_rnn_hidden']
        if rnn_hidden_size is None:
            rnn_hidden_size = hidden_size
        # Store the *resolved* size: the attribute is used to allocate state
        # tensors, so it must never be left as None.
        self.rnn_hidden_size = rnn_hidden_size
        rnn_type = params['encoder_rnn_type']
        inp_size = params['input_size']
        self.mlp1 = RefNRIMLP(inp_size, hidden_size, hidden_size, dropout, no_bn=no_bn)
        self.mlp2 = RefNRIMLP(hidden_size * 2, hidden_size, hidden_size, dropout, no_bn=no_bn)
        self.mlp3 = RefNRIMLP(hidden_size, hidden_size, hidden_size, dropout, no_bn=no_bn)
        self.mlp4 = RefNRIMLP(hidden_size * 3, hidden_size, hidden_size, dropout, no_bn=no_bn)
        # Chunk length (along time) used by compute_feat_transform.
        self.train_data_len = params.get('train_data_len', -1)

        if rnn_type == 'lstm':
            self.forward_rnn = nn.LSTM(hidden_size, rnn_hidden_size, batch_first=True)
            self.reverse_rnn = nn.LSTM(hidden_size, rnn_hidden_size, batch_first=True)
        elif rnn_type == 'gru':
            self.forward_rnn = nn.GRU(hidden_size, rnn_hidden_size, batch_first=True)
            self.reverse_rnn = nn.GRU(hidden_size, rnn_hidden_size, batch_first=True)

        # Posterior head: consumes concatenated forward + reverse states.
        out_hidden_size = 2*rnn_hidden_size
        num_layers = params['encoder_mlp_num_layers']
        if num_layers == 1:
            self.encoder_fc_out = nn.Linear(out_hidden_size, self.num_edges)
        else:
            tmp_hidden_size = params['encoder_mlp_hidden']
            layers = [nn.Linear(out_hidden_size, tmp_hidden_size), nn.ELU(inplace=True)]
            for _ in range(num_layers - 2):
                layers.append(nn.Linear(tmp_hidden_size, tmp_hidden_size))
                layers.append(nn.ELU(inplace=True))
            layers.append(nn.Linear(tmp_hidden_size, self.num_edges))
            self.encoder_fc_out = nn.Sequential(*layers)

        # Prior head: consumes forward states only.
        num_layers = params['prior_num_layers']
        if num_layers == 1:
            self.prior_fc_out = nn.Linear(rnn_hidden_size, self.num_edges)
        else:
            tmp_hidden_size = params['prior_hidden_size']
            layers = [nn.Linear(rnn_hidden_size, tmp_hidden_size), nn.ELU(inplace=True)]
            for _ in range(num_layers - 2):
                layers.append(nn.Linear(tmp_hidden_size, tmp_hidden_size))
                layers.append(nn.ELU(inplace=True))
            layers.append(nn.Linear(tmp_hidden_size, self.num_edges))
            self.prior_fc_out = nn.Sequential(*layers)

        # Possible options: None, 'normalize_inp', 'normalize_all'
        self.normalize_mode = params['encoder_normalize_mode']
        if self.normalize_mode == 'normalize_inp':
            self.bn = nn.BatchNorm1d(inp_size)

        self.init_weights()

    def init_weights(self):
        """Xavier-normal weights and 0.1 biases for every ``nn.Linear`` submodule."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)

    def node2edge(self, node_embeddings, send_edges, recv_edges):
        """Concatenate sender and receiver embeddings for every edge.

        node_embeddings: [batch, num_nodes, embed_size]
        """
        send_embed = node_embeddings[:, send_edges, :]
        recv_embed = node_embeddings[:, recv_edges, :]
        return torch.cat([send_embed, recv_embed], dim=-1) 

    def edge2node(self, edge_embeddings, edge2node_inds, num_vars):
        """Average incoming edge embeddings per node.

        ``edge2node_inds[n, k]`` is the index of the k-th incoming edge of
        node ``n``; the sum is divided by ``num_vars - 1``.
        """
        incoming = edge_embeddings[:, edge2node_inds[:, 0], :].clone()
        for i in range(1, edge2node_inds.size(1)):
            incoming += edge_embeddings[:, edge2node_inds[:, i]]
        return incoming/(num_vars-1)

    def get_initial_hidden(self, inputs):
        """Return zero ``(hidden, cell)`` tensors with one row per
        (batch element, directed edge) pair, sized from ``inputs.size(0)``
        and ``inputs.size(2)``."""
        batch = inputs.size(0)*inputs.size(2)*(inputs.size(2)-1)
        hidden = torch.zeros(1, batch, self.rnn_hidden_size, device=inputs.device)
        cell = torch.zeros(1, batch, self.rnn_hidden_size, device=inputs.device)
        return hidden, cell

    def batch_node2edge(self, x, send_edges, recv_edges):
        """Flat-batch variant of ``node2edge``: ``x`` is indexed along dim 0."""
        send_embed = x[send_edges]
        recv_embed = x[recv_edges]
        return torch.cat([send_embed, recv_embed], dim=-1)

    def batch_edge2node(self, x, result_shape, recv_edges):
        """Sum edge rows of ``x`` into a zero tensor of ``result_shape`` at
        the row given by each edge's receiver index."""
        result = torch.zeros(result_shape, device=x.device)
        result.index_add_(0, recv_edges, x)
        return result

    def normalize_inputs(self, inputs, node_masks):
        """Compute per-edge features for ``inputs``.

        Raises:
            NotImplementedError: for any ``normalize_mode`` other than
                ``'normalize_all'`` (other modes previously fell through to
                an ``UnboundLocalError``).
        """
        if self.normalize_mode == 'normalize_all':
            return self.compute_feat_transform(inputs, node_masks)
        raise NotImplementedError(
            'normalize_mode %r is not supported' % (self.normalize_mode,))

    def compute_feat_transform(self, inputs, node_masks):
        """Run the MLP / message-passing stack on all present variables.

        Inputs without a batch dimension (3-D) get one added. The sequence is
        processed in chunks of ``train_data_len`` timesteps; within a chunk
        all present nodes of all (batch, timestep) pairs are flattened into a
        single dimension and each pair's fully connected graph is addressed
        via index offsets.

        Returns:
            A list with one tensor per batch element holding that element's
            edge features for all timesteps, concatenated along dim 0.
        """
        # The following is to ensure we don't run on a sequence that's too long at once
        if len(inputs.shape) == 3:
            inputs = inputs.unsqueeze(0)
            node_masks = node_masks.unsqueeze(0)
        inp_list = inputs.split(self.train_data_len, dim=1)
        node_masks_list = node_masks.split(self.train_data_len, dim=1)
        final_result = [[] for _ in range(inputs.size(0))]
        for chunk_inputs, chunk_masks in zip(inp_list, node_masks_list):
            flat_masks = chunk_masks.nonzero(as_tuple=True)
            batch_num_vars = (chunk_masks != 0).sum(dim=-1)
            num_vars = batch_num_vars.view(-1)
            #TODO: is it faster to put this on cpu or gpu?
            edge_info = [torch.where(torch.ones(nvar, device=num_vars.device) - torch.eye(nvar, device=num_vars.device)) for nvar in num_vars]
            # Offset of each (batch, timestep) graph's first node in the flat node list.
            offsets = torch.cat([torch.tensor([0], dtype=torch.long, device=num_vars.device), num_vars.cumsum(0)[:-1]])
            send_edges = torch.cat([l[0] + offset for l,offset in zip(edge_info, offsets)])
            recv_edges = torch.cat([l[1] + offset for l, offset in zip(edge_info, offsets)])

            flat_inp = chunk_inputs[flat_masks]
            tmp_batch = flat_inp.size(0)
            x = self.mlp1(flat_inp)
            x = self.batch_node2edge(x, send_edges, recv_edges)
            x = self.mlp2(x)
            x_skip = x
            result_shape = (tmp_batch, x.size(-1))
            x = self.batch_edge2node(x, result_shape, recv_edges)
            x = self.mlp3(x)
            x = self.batch_node2edge(x, send_edges, recv_edges)
            x = torch.cat([x, x_skip], dim=-1)
            x = self.mlp4(x)
            # TODO: extract into batch-wise structure
            # Edges per batch element in this chunk, used to split the flat result.
            num_edges = batch_num_vars*(batch_num_vars-1)
            num_edges = num_edges.sum(dim=-1)
            batched_result = x.split(num_edges.tolist())
            for ind,tmp_result in enumerate(batched_result):
                final_result[ind].append(tmp_result)
        final_result = [torch.cat(tmp_result) for tmp_result in final_result]
        return final_result

    def _global_edge_inds(self, node_inds, send_edges, recv_edges, max_num_vars):
        """Map local edges among present nodes to rows of the full state table.

        Row index of edge (s, r) is ``s*(max_num_vars-1) + r - [r >= s]``,
        i.e. the position of (s, r) in the row-major list of off-diagonal pairs.
        """
        global_send_edges = node_inds[send_edges]
        global_recv_edges = node_inds[recv_edges]
        return global_send_edges*(max_num_vars-1) + global_recv_edges - (global_recv_edges >= global_send_edges).long()

    def _advance_state(self, rnn, edge_feats, state, global_edge_inds):
        """Run ``rnn`` for one timestep on the active edges.

        Returns the new hidden state of the active edges and an updated copy
        of the full ``(hidden, cell)`` state table in which only the rows in
        ``global_edge_inds`` have changed.
        """
        # Every edge is an independent length-1 sequence: [num_edges, 1, feat].
        step_inp = edge_feats.view(edge_feats.shape[-2], 1, edge_feats.shape[-1])
        active_state = (state[0][:, global_edge_inds], state[1][:, global_edge_inds])
        _, active_state = rnn(step_inp, active_state)
        # Clone before writing so earlier states stay intact for autograd.
        new_hidden = state[0].clone()
        new_hidden[:, global_edge_inds] = active_state[0]
        new_cell = state[1].clone()
        new_cell[:, global_edge_inds] = active_state[1]
        return active_state[0], (new_hidden, new_cell)

    def forward(self, inputs, node_masks, all_node_inds, all_graph_info, normalized_inputs=None):
        """Compute prior and posterior edge logits for a whole sequence.

        Args:
            inputs: [batch, num_timesteps, num_vars, input_size]; batch must be 1.
            node_masks: presence mask; ``node_masks.size(1)`` is the number of timesteps.
            all_node_inds: ``all_node_inds[0][t]`` holds the indices of the
                variables present at timestep ``t``.
            all_graph_info: ``all_graph_info[0][t]`` is a 3-tuple whose first
                two entries are local send/receive edge indices.
            normalized_inputs: optional precomputed result of
                ``normalize_inputs``; computed here when ``None``.

        Returns:
            ``(prior_result, encoder_result, forward_state)``. The first two
            pack the edges of all timesteps along dim 1 (timestep-major);
            ``forward_state`` is the forward RNN's ``(hidden, cell)`` table
            after the last timestep.

        Raises:
            ValueError: if the batch size is larger than 1.
            NotImplementedError: if ``normalize_mode != 'normalize_all'``.
        """
        if inputs.size(0) > 1:
            raise ValueError("Batching during forward not currently supported")
        if self.normalize_mode == 'normalize_all':
            if normalized_inputs is not None:
                x = torch.cat(normalized_inputs, dim=0)
            else:
                x = torch.cat(self.normalize_inputs(inputs, node_masks), dim=0)
        else:
            # Right now, we'll always want to do this
            raise NotImplementedError
        num_timesteps = node_masks.size(1)
        max_num_vars = inputs.size(2)
        max_num_edges = max_num_vars*(max_num_vars-1)
        forward_state = (torch.zeros(1, max_num_edges, self.rnn_hidden_size, device=inputs.device),
                         torch.zeros(1, max_num_edges, self.rnn_hidden_size, device=inputs.device))
        reverse_state = (torch.zeros(1, max_num_edges, self.rnn_hidden_size, device=inputs.device),
                         torch.zeros(1, max_num_edges, self.rnn_hidden_size, device=inputs.device))
        all_forward_states = []
        all_reverse_states = []

        # Forward pass: x_ind walks the packed edge features front to back.
        x_ind = 0
        for timestep in range(num_timesteps):
            node_inds = all_node_inds[0][timestep]
            if self.gpu:
                node_inds = node_inds.cuda(non_blocking=True)
            if len(node_inds) <= 1:
                # No edges at this timestep; keep list positions aligned with time.
                all_forward_states.append(torch.empty(1, 0, self.rnn_hidden_size, device=inputs.device))
                continue
            send_edges, recv_edges, _ = all_graph_info[0][timestep]
            if self.gpu:
                send_edges, recv_edges = send_edges.cuda(non_blocking=True), recv_edges.cuda(non_blocking=True)
            global_edge_inds = self._global_edge_inds(node_inds, send_edges, recv_edges, max_num_vars)
            num_step_edges = len(global_edge_inds)
            current_x = x[x_ind:x_ind+num_step_edges]
            if self.gpu:
                current_x = current_x.cuda(non_blocking=True)
            x_ind += num_step_edges
            # Write the updated rows back so the next timestep continues from
            # them (the updated table used to be computed and then dropped,
            # which reset every edge to zero state at every step).
            step_hidden, forward_state = self._advance_state(
                self.forward_rnn, current_x, forward_state, global_edge_inds)
            all_forward_states.append(step_hidden)

        # Reverse pass: x_ind walks the packed edge features back to front.
        x_ind = x.size(0)
        for timestep in range(num_timesteps-1, -1, -1):
            node_inds = all_node_inds[0][timestep]
            if self.gpu:
                node_inds = node_inds.cuda(non_blocking=True)

            if len(node_inds) <= 1:
                continue
            send_edges, recv_edges, _ = all_graph_info[0][timestep]
            if self.gpu:
                send_edges, recv_edges = send_edges.cuda(non_blocking=True), recv_edges.cuda(non_blocking=True)
            global_edge_inds = self._global_edge_inds(node_inds, send_edges, recv_edges, max_num_vars)
            num_step_edges = len(global_edge_inds)
            current_x = x[x_ind-num_step_edges:x_ind]
            if self.gpu:
                current_x = current_x.cuda(non_blocking=True)
            x_ind -= num_step_edges
            step_hidden, reverse_state = self._advance_state(
                self.reverse_rnn, current_x, reverse_state, global_edge_inds)
            all_reverse_states.append(step_hidden)

        all_forward_states = torch.cat(all_forward_states, dim=1)
        # Restore chronological order by reversing the *list of per-timestep
        # blocks*. Flipping the concatenated tensor along dim 1 would also
        # reverse the edge order inside each block and pair every forward
        # state with the reverse state of a different edge.
        all_reverse_states = torch.cat(all_reverse_states[::-1], dim=1)
        all_states = torch.cat([all_forward_states, all_reverse_states], dim=-1)
        prior_result = self.prior_fc_out(all_forward_states)
        encoder_result = self.encoder_fc_out(all_states)
        return prior_result, encoder_result, forward_state

    def single_step_forward(self, inputs, node_masks, node_inds, all_graph_info, forward_state):
        """Advance the prior by one timestep.

        Args:
            inputs: [batch, num_vars, input_size] for the current timestep.
            node_masks: presence mask for the current timestep.
            node_inds: indices of the present variables.
            all_graph_info: 3-tuple whose first two entries are local
                send/receive edge indices for this timestep.
            forward_state: ``(hidden, cell)`` state table from the previous step.

        Returns:
            ``(prior_result, forward_state)``: prior logits for the active
            edges (an empty ``[1, 0, num_edge_types]`` tensor when fewer than
            two variables are present) and the updated state table.
        """
        if self.normalize_mode == 'normalize_all':
            x = self.normalize_inputs(inputs, node_masks)[0]
        else:
            raise NotImplementedError
        max_num_vars = inputs.size(1)
        if len(node_inds) > 1:
            send_edges, recv_edges, _ = all_graph_info
            global_edge_inds = self._global_edge_inds(node_inds, send_edges, recv_edges, max_num_vars)
            step_hidden, forward_state = self._advance_state(
                self.forward_rnn, x, forward_state, global_edge_inds)
            prior_result = self.prior_fc_out(step_hidden)
        else:
            # Allocate on the input's device so downstream ops do not mix devices.
            prior_result = torch.empty(1, 0, self.num_edges, device=inputs.device)

        return prior_result, forward_state


class DNRI_DynamicVars_Decoder(nn.Module):
    """Recurrent message-passing decoder that updates only the active nodes.

    For the nodes flagged in ``node_masks`` it computes one message per edge
    and edge type, weights it by ``edges``, averages the incoming messages per
    node, applies a GRU-style gated update to the node hidden state and
    predicts a residual on top of the current inputs.
    """

    def __init__(self, params):
        """Build the decoder layers.

        Args:
            params: dict providing ``input_size``, ``gpu``,
                ``decoder_hidden``, ``num_edge_types``, ``skip_first`` and
                ``decoder_dropout``.
        """
        super(DNRI_DynamicVars_Decoder, self).__init__()
        input_size = params['input_size']
        self.gpu = params['gpu']
        n_hid = params['decoder_hidden']
        edge_types = params['num_edge_types']
        skip_first = params['skip_first']
        out_size = params['input_size']
        do_prob = params['decoder_dropout']

        # One two-layer message MLP per edge type.
        self.msg_fc1 = nn.ModuleList(
            [nn.Linear(2*n_hid, n_hid) for _ in range(edge_types)]
        )
        self.msg_fc2 = nn.ModuleList(
            [nn.Linear(n_hid, n_hid) for _ in range(edge_types)]
        )
        self.msg_out_shape = n_hid
        self.skip_first_edge_type = skip_first

        self.hidden_r = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_i = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_h = nn.Linear(n_hid, n_hid, bias=False)

        self.input_r = nn.Linear(input_size, n_hid, bias=True)
        self.input_i = nn.Linear(input_size, n_hid, bias=True)
        self.input_n = nn.Linear(input_size, n_hid, bias=True)

        self.out_fc1 = nn.Linear(n_hid, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, out_size)

        print('Using learned recurrent interaction net decoder.')

        self.dropout_prob = do_prob

    def get_initial_hidden(self, inputs):
        """Return a zero hidden state of shape [inputs.size(0), inputs.size(2), n_hid]."""
        return torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape, device=inputs.device)

    def init_weights(self):
        """Xavier-initialise every linear weight and set existing biases to 0.1."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                # hidden_r / hidden_i / hidden_h are created with bias=False,
                # so their ``bias`` attribute is None.
                if m.bias is not None:
                    m.bias.data.fill_(0.1)

    def forward(self, inputs, hidden, edges, node_masks, graph_info):
        """Run one decoding step.

        Args:
            inputs: [batch, num_vars, input_size]
            hidden: [batch, num_vars, rnn_hidden]
            edges: [batch, current_num_edges, num_edge_types]
            node_masks: mask whose non-zero entries (last index) select the
                active variables.
            graph_info: ``(send_edges, recv_edges, edge2node_inds)`` index
                tensors local to the active nodes; only read when more than
                one node is active.

        Returns:
            ``(pred_all, hidden)``: predictions for every variable slot
            (zeros for inactive ones) and the hidden state with the active
            rows replaced. With no active node, ``hidden`` is returned as is.
        """
        max_num_vars = inputs.size(1)
        node_inds = node_masks.nonzero()[:, -1]

        current_hidden = hidden[:, node_inds]
        current_inputs = inputs[:, node_inds]
        num_vars = current_hidden.size(1)

        if num_vars > 1:
            send_edges, recv_edges, edge2node_inds = graph_info
            if self.gpu:
                # Follow the inputs' device rather than the default CUDA device.
                device = inputs.device
                send_edges = send_edges.to(device, non_blocking=True)
                recv_edges = recv_edges.to(device, non_blocking=True)
                edge2node_inds = edge2node_inds.to(device, non_blocking=True)
            receivers = current_hidden[:, recv_edges]
            senders = current_hidden[:, send_edges]
            # pre_msg: [batch, num_edges, 2*msg_out]
            pre_msg = torch.cat([receivers, senders], dim=-1)

            all_msgs = torch.zeros(pre_msg.size(0), pre_msg.size(1),
                                   self.msg_out_shape, device=inputs.device)
            if self.skip_first_edge_type:
                start_idx = 1
                norm = float(len(self.msg_fc2)) - 1
            else:
                start_idx = 0
                norm = float(len(self.msg_fc2))

            # Run separate MLP for every edge type
            # NOTE: to exclude one edge type, simply offset range by 1
            for i in range(start_idx, len(self.msg_fc2)):
                msg = torch.tanh(self.msg_fc1[i](pre_msg))
                # F.dropout defaults to training=True; tie it to the module
                # mode so evaluation is deterministic.
                msg = F.dropout(msg, p=self.dropout_prob, training=self.training)
                msg = torch.tanh(self.msg_fc2[i](msg))
                msg = msg * edges[:, :, i:i+1]
                all_msgs += msg/norm

            # Column i of edge2node_inds indexes the i-th incoming edge of
            # each node; summing over columns aggregates per receiver.
            incoming = all_msgs[:, edge2node_inds[:, 0], :].clone()
            for i in range(1, edge2node_inds.size(1)):
                incoming += all_msgs[:, edge2node_inds[:, i], :]
            agg_msgs = incoming/(num_vars-1)
        elif num_vars == 0:
            pred_all = torch.zeros(inputs.size(0), max_num_vars, inputs.size(-1), device=inputs.device)
            return pred_all, hidden
        else:
            # A single active node receives no messages.
            agg_msgs = torch.zeros(current_inputs.size(0), num_vars, self.msg_out_shape, device=inputs.device)

        # GRU-style gated aggregation
        inp_r = self.input_r(current_inputs).view(current_inputs.size(0), num_vars, -1)
        inp_i = self.input_i(current_inputs).view(current_inputs.size(0), num_vars, -1)
        inp_n = self.input_n(current_inputs).view(current_inputs.size(0), num_vars, -1)
        r = torch.sigmoid(inp_r + self.hidden_r(agg_msgs))
        i = torch.sigmoid(inp_i + self.hidden_i(agg_msgs))
        n = torch.tanh(inp_n + r*self.hidden_h(agg_msgs))
        current_hidden = (1 - i)*n + i*current_hidden

        # Output MLP
        pred = F.dropout(F.relu(self.out_fc1(current_hidden)), p=self.dropout_prob, training=self.training)
        pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob, training=self.training)
        pred = self.out_fc3(pred)

        # The MLP predicts a delta on top of the current inputs.
        pred = current_inputs + pred
        hidden = hidden.clone()
        hidden[:, node_inds] = current_hidden
        pred_all = torch.zeros(inputs.size(0), max_num_vars, inputs.size(-1), device=inputs.device)
        # Index the whole batch dimension (previously hard-coded to row 0).
        pred_all[:, node_inds] = pred

        return pred_all, hidden

In [8]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class RecurrentBaseline(nn.Module):
    """Shared loss / rollout logic for step-wise recurrent baselines.

    Subclasses provide ``single_step_forward(inputs, hidden)`` returning
    ``(predictions, hidden)`` and ``get_initial_hidden(inputs)``.
    """

    def __init__(self, params):
        """Read the loss and teacher-forcing settings from ``params``."""
        super(RecurrentBaseline, self).__init__()
        self.num_vars = params['num_vars']
        self.teacher_forcing_steps = params.get('teacher_forcing_steps', -1)
        self.nll_loss_type = params.get('nll_loss_type', 'crossent')
        self.prior_variance = params.get('prior_variance')
        self.normalize_nll = params.get('normalize_nll', False)
        self.normalize_nll_per_var = params.get('normalize_nll_per_var', False)
        self.anneal_teacher_forcing = params.get('anneal_teacher_forcing', False)
        self.val_teacher_forcing_steps = params.get('val_teacher_forcing_steps', -1)
        self.kl_coef = 0
        self.steps = 0

    def single_step_forward(self, inputs, hidden):
        """Predict the next step from ``inputs``; implemented by subclasses."""
        raise NotImplementedError

    def get_initial_hidden(self, inputs):
        """Return the hidden state used before the first step; implemented by subclasses."""
        raise NotImplementedError

    def normalize_inputs(self, inputs):
        """Identity hook; returns ``inputs`` unchanged."""
        return inputs

    def calculate_loss(self, inputs, is_train=False, teacher_forcing=True, return_logits=False, use_prior_logits=False, normalized_inputs=None):
        """Roll the model over ``inputs`` and compute the one-step-ahead NLL.

        Args:
            inputs: tensor indexed as ``inputs[:, step]`` along time and
                sliced as ``inputs[:, 1:, :, :]`` for targets (4-D).
            is_train: selects ``teacher_forcing_steps`` (train) versus
                ``val_teacher_forcing_steps``.
            teacher_forcing: if True, ground truth is fed for the first
                ``teacher_forcing_steps`` steps (all steps when that is -1);
                otherwise the model's previous prediction is fed back.
                Step 0 always uses ground truth.
            return_logits: also return ``None`` and the stacked predictions.
            use_prior_logits, normalized_inputs: accepted but not used here.

        Returns:
            ``(loss, loss_nll, loss_kl)`` or, with ``return_logits``,
            ``(loss, loss_nll, loss_kl, None, all_predictions)``.
            ``loss_kl`` is always zeros shaped like ``loss_nll``.
        """
        hidden = self.get_initial_hidden(inputs)
        num_time_steps = inputs.size(1)
        all_predictions = []
        if not is_train:
            teacher_forcing_steps = self.val_teacher_forcing_steps
        else:
            teacher_forcing_steps = self.teacher_forcing_steps
        for step in range(num_time_steps-1):
            if (teacher_forcing and (teacher_forcing_steps == -1 or step < teacher_forcing_steps)) or step == 0:
                current_inputs = inputs[:, step]
            else:
                current_inputs = predictions
            predictions, hidden = self.single_step_forward(current_inputs, hidden)
            all_predictions.append(predictions)
        all_predictions = torch.stack(all_predictions, dim=1)
        target = inputs[:, 1:, :, :]
        loss_nll = self.nll(all_predictions, target)
        loss_kl = torch.zeros_like(loss_nll)
        loss = loss_nll.mean()
        if return_logits:
            return loss, loss_nll, loss_kl, None, all_predictions
        else:
            return loss, loss_nll, loss_kl

    def predict_future(self, inputs, prediction_steps, return_everything=False):
        """Burn in on ``inputs`` and then roll out ``prediction_steps`` steps.

        With ``return_everything`` the burn-in predictions are prepended to
        the returned tensor (stacked along dim 1).
        """
        burn_in_timesteps = inputs.size(1)
        hidden = self.get_initial_hidden(inputs)
        all_predictions = []
        for step in range(burn_in_timesteps-1):
            current_inputs = inputs[:, step]
            predictions, hidden = self.single_step_forward(current_inputs, hidden)
            if return_everything:
                all_predictions.append(predictions)
        # The last observed frame seeds the free-running rollout.
        predictions = inputs[:, burn_in_timesteps-1]
        for step in range(prediction_steps):
            predictions, hidden = self.single_step_forward(predictions, hidden)
            all_predictions.append(predictions)

        predictions = torch.stack(all_predictions, dim=1)
        return predictions

    def copy_states(self, state):
        """Clone a hidden state (a tensor, or the first two items of a tuple/list)."""
        if isinstance(state, (tuple, list)):
            current_state = (state[0].clone(), state[1].clone())
        else:
            current_state = state.clone()
        return current_state

    def merge_hidden(self, hidden):
        """Concatenate a list of hidden states along dim 0."""
        if isinstance(hidden[0], (tuple, list)):
            result0 = torch.cat([x[0] for x in hidden], dim=0)
            result1 = torch.cat([x[1] for x in hidden], dim=0)
            return (result0, result1)
        else:
            return torch.cat(hidden, dim=0)

    def predict_future_fixedwindow(self, inputs, burn_in_steps, prediction_steps, batch_size):
        """Predict ``prediction_steps`` ahead from successive start frames.

        After a burn-in, start frames are processed ``batch_size`` at a time:
        each is fed through one ground-truth step, the resulting hidden states
        are stacked along dim 0 and rolled out together. The result gets a
        leading dimension of size 1.
        """
        hidden = self.get_initial_hidden(inputs)
        for step in range(burn_in_steps-1):
            current_inputs = inputs[:, step]
            predictions, hidden = self.single_step_forward(current_inputs, hidden)
        all_timestep_preds = []
        for window_ind in range(burn_in_steps - 1, inputs.size(1)-1, batch_size):
            current_batch_preds = []
            states = []
            for step in range(batch_size):
                if window_ind + step >= inputs.size(1):
                    break
                predictions = inputs[:, window_ind + step]
                predictions, hidden = self.single_step_forward(predictions, hidden)
                current_batch_preds.append(predictions)
                # Snapshot the state so each start frame rolls out independently.
                tmp_decoder = self.copy_states(hidden)
                states.append(tmp_decoder)
            batch_hidden = self.merge_hidden(states)
            current_batch_preds = torch.cat(current_batch_preds, 0)
            current_timestep_preds = [current_batch_preds]
            for step in range(prediction_steps - 1):
                current_batch_preds, batch_hidden = self.single_step_forward(current_batch_preds, batch_hidden)
                current_timestep_preds.append(current_batch_preds)
            all_timestep_preds.append(torch.stack(current_timestep_preds, dim=1))
        results = torch.cat(all_timestep_preds, dim=0)
        return results.unsqueeze(0)

    def nll(self, preds, target):
        """Dispatch to the loss selected by ``self.nll_loss_type``.

        Raises:
            ValueError: if the loss type is not 'crossent', 'gaussian' or
                'poisson' (previously this silently returned None).
        """
        if self.nll_loss_type == 'crossent':
            return self.nll_crossent(preds, target)
        elif self.nll_loss_type == 'gaussian':
            return self.nll_gaussian(preds, target)
        elif self.nll_loss_type == 'poisson':
            return self.nll_poisson(preds, target)
        raise ValueError('Unknown nll_loss_type: %r' % (self.nll_loss_type,))

    def nll_gaussian(self, preds, target, add_const=False):
        """Gaussian NLL with fixed variance ``self.prior_variance``.

        ``add_const`` is accepted but unused; the normalising constant is
        only added in the ``normalize_nll`` branch.
        """
        neg_log_p = ((preds - target) ** 2 / (2 * self.prior_variance))
        const = 0.5 * np.log(2 * np.pi * self.prior_variance)
        if self.normalize_nll_per_var:
            return neg_log_p.sum() / (target.size(0) * target.size(2))
        elif self.normalize_nll:
            return (neg_log_p.sum(-1) + const).view(preds.size(0), -1).mean(dim=1)
        else:
            return neg_log_p.view(target.size(0), -1).sum() / (target.size(1))

    def nll_crossent(self, preds, target):
        """Per-sample binary cross-entropy on logits (mean if ``normalize_nll``, else sum)."""
        if self.normalize_nll:
            return nn.BCEWithLogitsLoss(reduction='none')(preds, target).view(preds.size(0), -1).mean(dim=1)
        else:
            return nn.BCEWithLogitsLoss(reduction='none')(preds, target).view(preds.size(0), -1).sum(dim=1)

    def nll_poisson(self, preds, target):
        """Per-sample Poisson NLL (mean if ``normalize_nll``, else sum)."""
        if self.normalize_nll:
            return nn.PoissonNLLLoss(reduction='none')(preds, target).view(preds.size(0), -1).mean(dim=1)
        else:
            return nn.PoissonNLLLoss(reduction='none')(preds, target).view(preds.size(0), -1).sum(dim=1)

    def save(self, path):
        """Write the model's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` saved by :meth:`save`.

        Tensors are mapped to CPU first so a checkpoint written on a GPU host
        can be restored anywhere; ``load_state_dict`` then copies the values
        into the existing parameters on whatever device they live on.
        """
        self.load_state_dict(torch.load(path, map_location='cpu'))


class SingleRNNBaseline(RecurrentBaseline):
    """Baseline that runs one shared RNN cell independently on every variable."""

    def __init__(self, params):
        """Build the RNN cell and output MLP.

        Args:
            params: dict providing ``decoder_hidden``, ``input_size``,
                ``decoder_dropout``, ``num_vars`` and ``decoder_rnn_type``
                ('lstm' or 'gru'), plus whatever the base class reads.

        Raises:
            ValueError: if ``decoder_rnn_type`` is neither 'lstm' nor 'gru'.
        """
        super(SingleRNNBaseline, self).__init__(params)
        self.n_hid = n_hid = params['decoder_hidden']
        out_size = params['input_size']
        do_prob = params['decoder_dropout']
        input_size = params['input_size']
        self.num_vars = params['num_vars']
        self.rnn_type = params['decoder_rnn_type']
        if self.rnn_type == 'lstm':
            self.rnn = nn.LSTMCell(input_size, n_hid)
        elif self.rnn_type == 'gru':
            self.rnn = nn.GRUCell(input_size, n_hid)
        else:
            # Without this, self.rnn would be missing and the first forward
            # pass would fail with an unrelated AttributeError.
            raise ValueError('Unsupported decoder_rnn_type: %r' % (self.rnn_type,))
        self.out = nn.Sequential(
            nn.Linear(n_hid, n_hid),
            nn.ReLU(inplace=True),
            nn.Dropout(do_prob),
            nn.Linear(n_hid, n_hid),
            nn.ReLU(inplace=True),
            nn.Dropout(do_prob),
            nn.Linear(n_hid, out_size),
        )

    def get_initial_hidden(self, inputs):
        """Return None; the RNN cells treat a missing state as all zeros."""
        return None

    def single_step_forward(self, inputs, hidden):
        """Advance every variable's RNN by one step.

        Args:
            inputs: [batch, num_vars, input_size]
            hidden: state returned by the previous call (a tensor for GRU, an
                ``(h, c)`` tuple for LSTM) or None on the first step.

        Returns:
            ``(outputs, hidden)`` where ``outputs`` is ``inputs`` plus the
            predicted residual and ``hidden`` is the new cell state.
        """
        # Fold batch and variables together so each variable is its own row.
        tmp_inp = inputs.reshape(-1, inputs.size(-1))
        # Feed the previous state back in; calling the cell without it would
        # restart from zeros at every step and discard all history.
        hidden = self.rnn(tmp_inp, hidden)
        if self.rnn_type == 'lstm':
            tmp = hidden[0].view(inputs.size(0), inputs.size(1), -1)
        else:
            tmp = hidden.view(inputs.size(0), inputs.size(1), -1)
        outputs = inputs + self.out(tmp)
        return outputs, hidden

class JointRNNBaseline(RecurrentBaseline):
    """Baseline that feeds all variables, concatenated, through a single RNN cell."""

    def __init__(self, params):
        """Build the RNN cell and output MLP.

        Args:
            params: dict providing ``decoder_hidden``, ``decoder_dropout``,
                ``num_vars``, ``input_size`` and ``decoder_rnn_type``
                ('lstm' or 'gru'), plus whatever the base class reads.

        Raises:
            ValueError: if ``decoder_rnn_type`` is neither 'lstm' nor 'gru'.
        """
        super(JointRNNBaseline, self).__init__(params)
        self.n_hid = n_hid = params['decoder_hidden']
        do_prob = params['decoder_dropout']
        self.num_vars = num_vars = params['num_vars']
        # The cell sees (and predicts) the features of all variables at once.
        out_size = input_size = params['input_size']*num_vars
        self.rnn_type = params['decoder_rnn_type']
        if self.rnn_type == 'lstm':
            self.rnn = nn.LSTMCell(input_size, n_hid)
        elif self.rnn_type == 'gru':
            self.rnn = nn.GRUCell(input_size, n_hid)
        else:
            # Without this, self.rnn would be missing and the first forward
            # pass would fail with an unrelated AttributeError.
            raise ValueError('Unsupported decoder_rnn_type: %r' % (self.rnn_type,))
        self.out = nn.Sequential(
            nn.Linear(n_hid, n_hid),
            nn.ReLU(inplace=True),
            nn.Dropout(do_prob),
            nn.Linear(n_hid, n_hid),
            nn.ReLU(inplace=True),
            nn.Dropout(do_prob),
            nn.Linear(n_hid, out_size),
        )

    def get_initial_hidden(self, inputs):
        """Return None; the RNN cells treat a missing state as all zeros."""
        return None

    def single_step_forward(self, inputs, hidden):
        """Advance the joint RNN by one step.

        Args:
            inputs: [batch, num_vars, input_size]
            hidden: state returned by the previous call (a tensor for GRU, an
                ``(h, c)`` tuple for LSTM) or None on the first step.

        Returns:
            ``(outputs, hidden)`` where ``outputs`` is ``inputs`` plus the
            predicted residual and ``hidden`` is the new cell state.
        """
        # reshape (not view) so non-contiguous slices are accepted as well.
        tmp_inp = inputs.reshape(inputs.size(0), -1)
        # Feed the previous state back in; calling the cell without it would
        # restart from zeros at every step and discard all history.
        hidden = self.rnn(tmp_inp, hidden)
        if self.rnn_type == 'lstm':
            tmp = hidden[0]
        else:
            tmp = hidden
        outputs = inputs + self.out(tmp).view(inputs.size(0), inputs.size(1), -1)
        return outputs, hidden
 
        


class FullyConnectedBaseline(RecurrentBaseline):
    """Recurrent message-passing baseline over a fixed, fully connected graph.

    Every ordered pair of distinct variables is an edge; a single message MLP
    is shared by all edges.
    """

    def __init__(self, params):
        """Create the layers and the static edge bookkeeping.

        ``params`` must provide ``decoder_hidden``, ``num_edge_types``,
        ``skip_first``, ``input_size``, ``decoder_dropout`` and ``num_vars``.
        """
        super().__init__(params)
        hid_dim = params['decoder_hidden']
        _unused_edge_types = params['num_edge_types']  # read but not needed here
        skip_first_flag = params['skip_first']
        target_dim = params['input_size']
        drop_rate = params['decoder_dropout']
        feat_dim = params['input_size']
        var_count = params['num_vars']
        self.num_vars = var_count

        # Shared two-layer edge message network.
        self.msg_fc1 = nn.Linear(2 * hid_dim, hid_dim)
        self.msg_fc2 = nn.Linear(hid_dim, hid_dim)
        self.msg_out_shape = hid_dim
        self.skip_first_edge_type = skip_first_flag

        # Gate projections applied to the aggregated messages.
        self.hidden_r = nn.Linear(hid_dim, hid_dim, bias=False)
        self.hidden_i = nn.Linear(hid_dim, hid_dim, bias=False)
        self.hidden_h = nn.Linear(hid_dim, hid_dim, bias=False)

        # Gate projections applied to the raw inputs.
        self.input_r = nn.Linear(feat_dim, hid_dim, bias=True)
        self.input_i = nn.Linear(feat_dim, hid_dim, bias=True)
        self.input_n = nn.Linear(feat_dim, hid_dim, bias=True)

        # Output head.
        self.out_fc1 = nn.Linear(hid_dim, hid_dim)
        self.out_fc2 = nn.Linear(hid_dim, hid_dim)
        self.out_fc3 = nn.Linear(hid_dim, target_dim)

        print('Using learned recurrent interaction net decoder.')

        self.dropout_prob = drop_rate

        self.num_vars = var_count
        # Non-zero entries of (ones - identity) enumerate all ordered pairs
        # of distinct variables.
        off_diagonal = np.ones(var_count) - np.eye(var_count)
        self.send_edges, self.recv_edges = np.where(off_diagonal)
        receiver_onehot = torch.FloatTensor(encode_onehot(self.recv_edges))
        self.edge2node_mat = nn.Parameter(receiver_onehot, requires_grad=False)

    def get_initial_hidden(self, inputs):
        """Zero hidden state shaped [inputs.size(0), inputs.size(2), n_hid]."""
        batch, n_nodes = inputs.size(0), inputs.size(2)
        return torch.zeros(batch, n_nodes, self.msg_out_shape, device=inputs.device)

    def single_step_forward(self, inputs, hidden):
        """One message-passing + gated-update step.

        inputs: [batch, num_vars, input_size]
        hidden: [batch, num_vars, rnn_hidden]
        Returns ``(pred, hidden)`` with ``pred = inputs + MLP(hidden)``.
        """
        drop_p = self.dropout_prob if self.training else 0.

        # node2edge: gather both endpoints' hidden state for every edge.
        recv_h = hidden[:, self.recv_edges, :]
        send_h = hidden[:, self.send_edges, :]
        edge_in = torch.cat([recv_h, send_h], dim=-1)  # [batch, num_edges, 2*msg_out]

        edge_msg = torch.tanh(self.msg_fc1(edge_in))
        edge_msg = F.dropout(edge_msg, p=drop_p)
        edge_msg = torch.tanh(self.msg_fc2(edge_msg))

        # edge2node: sum the messages arriving at each node, then average.
        summed = edge_msg.transpose(-2, -1).matmul(self.edge2node_mat)
        summed = summed.transpose(-2, -1)
        node_msg = summed.contiguous() / (self.num_vars - 1)

        # GRU-style gated aggregation
        batch = inputs.size(0)
        proj_r = self.input_r(inputs).view(batch, self.num_vars, -1)
        proj_i = self.input_i(inputs).view(batch, self.num_vars, -1)
        proj_n = self.input_n(inputs).view(batch, self.num_vars, -1)
        reset_gate = torch.sigmoid(proj_r + self.hidden_r(node_msg))
        update_gate = torch.sigmoid(proj_i + self.hidden_i(node_msg))
        candidate = torch.tanh(proj_n + reset_gate*self.hidden_h(node_msg))
        hidden = (1 - update_gate)*candidate + update_gate*hidden

        # Output MLP
        delta = F.relu(self.out_fc1(hidden))
        delta = F.dropout(delta, p=drop_p)
        delta = F.relu(self.out_fc2(delta))
        delta = F.dropout(delta, p=drop_p)
        delta = self.out_fc3(delta)

        return inputs + delta, hidden

In [9]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

class RecurrentBaseline_DynamicVars(nn.Module):
    """Shared loss / rollout logic for baselines whose set of variables changes over time.

    ``calculate_loss`` and ``predict_future`` call
    ``self.single_step_forward(inputs, node_masks, graph_info, hidden)``;
    subclasses must implement it with that four-argument signature.
    """

    def __init__(self, params):
        """Read the loss and teacher-forcing settings from ``params``."""
        super(RecurrentBaseline_DynamicVars, self).__init__()
        self.teacher_forcing_steps = params.get('teacher_forcing_steps', -1)
        self.nll_loss_type = params.get('nll_loss_type', 'crossent')
        self.prior_variance = params.get('prior_variance')
        self.normalize_nll = params.get('normalize_nll', False)
        self.normalize_nll_per_var = params.get('normalize_nll_per_var', False)
        self.anneal_teacher_forcing = params.get('anneal_teacher_forcing', False)
        self.val_teacher_forcing_steps = params.get('val_teacher_forcing_steps', -1)
        self.kl_coef = 0
        self.steps = 0

    def single_step_forward(self, inputs, hidden):
        """Placeholder; always raises NotImplementedError."""
        raise NotImplementedError

    def get_initial_hidden(self, inputs):
        """Return the hidden state used before the first step; implemented by subclasses."""
        raise NotImplementedError

    def normalize_inputs(self, inputs, masks):
        """Identity hook; returns ``inputs`` unchanged."""
        return inputs

    def calculate_loss(self, inputs, node_masks, node_inds, graph_info, is_train=False, teacher_forcing=True, return_logits=False, use_prior_logits=False, normalized_inputs=None):
        """Roll the model over ``inputs`` and compute the masked one-step NLL.

        Args:
            inputs: tensor indexed as ``inputs[:, step]`` along time and
                sliced as ``inputs[:, 1:, :, :]`` for targets (4-D).
            node_masks: indexed as ``node_masks[:, step]``; entries equal to
                1 mark variables present at that step.
            node_inds: accepted for interface compatibility; not used.
            graph_info: ``graph_info[0][step]`` is handed to
                ``single_step_forward`` for each step.
            is_train: selects ``teacher_forcing_steps`` (train) versus
                ``val_teacher_forcing_steps``.
            teacher_forcing: if True, ground truth is fed for the first
                ``teacher_forcing_steps`` steps (all steps when that is -1);
                otherwise the previous prediction is fed back. Step 0 always
                uses ground truth.
            return_logits: also return ``None`` and the stacked predictions.
            use_prior_logits, normalized_inputs: accepted but not used here.

        Returns:
            ``(loss, loss_nll, loss_kl)`` or, with ``return_logits``,
            ``(loss, loss_nll, loss_kl, None, all_predictions)``.
        """
        hidden = self.get_initial_hidden(inputs)
        num_time_steps = inputs.size(1)
        all_predictions = []
        if not is_train:
            teacher_forcing_steps = self.val_teacher_forcing_steps
        else:
            teacher_forcing_steps = self.teacher_forcing_steps
        for step in range(num_time_steps-1):
            if (teacher_forcing and (teacher_forcing_steps == -1 or step < teacher_forcing_steps)) or step == 0:
                current_inputs = inputs[:, step]
            else:
                current_inputs = predictions
            current_node_masks = node_masks[:, step]
            current_graph_info = graph_info[0][step]
            predictions, hidden = self.single_step_forward(current_inputs, current_node_masks, current_graph_info, hidden)
            all_predictions.append(predictions)
        all_predictions = torch.stack(all_predictions, dim=1)
        target = inputs[:, 1:, :, :]
        # A transition only counts when the variable is present both at the
        # input step and at the target step.
        target_masks = ((node_masks[:, :-1] == 1)*(node_masks[:, 1:] == 1)).float()
        loss_nll = self.nll(all_predictions, target, target_masks)
        loss_kl = torch.zeros_like(loss_nll)
        loss = loss_nll.mean()
        if return_logits:
            return loss, loss_nll, loss_kl, None, all_predictions
        else:
            return loss, loss_nll, loss_kl

    def predict_future(self, inputs, masks, node_inds, graph_info, burn_in_masks):
        '''
        Here, we assume the following:
        * inputs contains all of the gt inputs, including for the time steps we're predicting
        * masks keeps track of the variables that are being tracked
        * burn_in_masks is set to 1 whenever we're supposed to feed in that variable's state
          for a given time step

        node_inds is accepted for interface compatibility but is not used.
        Returns the per-step predictions stacked along dim 1.
        '''
        total_timesteps = inputs.size(1)
        hidden = self.get_initial_hidden(inputs)
        predictions = inputs[:, 0]
        preds = []
        for step in range(total_timesteps-1):
            current_masks = masks[:, step]
            current_burn_in_masks = burn_in_masks[:, step].unsqueeze(-1).type(inputs.dtype)
            current_inps = inputs[:, step]
            current_graph_info = graph_info[0][step]
            # Ground truth where the burn-in mask is 1, own prediction elsewhere.
            decoder_inp = current_burn_in_masks*current_inps + (1-current_burn_in_masks)*predictions
            predictions, hidden = self.single_step_forward(decoder_inp, current_masks, current_graph_info, hidden)
            preds.append(predictions)
        return torch.stack(preds, dim=1)

    def copy_states(self, state):
        """Clone a hidden state (a tensor, or the first two items of a tuple/list)."""
        if isinstance(state, (tuple, list)):
            current_state = (state[0].clone(), state[1].clone())
        else:
            current_state = state.clone()
        return current_state

    def merge_hidden(self, hidden):
        """Concatenate a list of hidden states along dim 0."""
        if isinstance(hidden[0], (tuple, list)):
            result0 = torch.cat([x[0] for x in hidden], dim=0)
            result1 = torch.cat([x[1] for x in hidden], dim=0)
            return (result0, result1)
        else:
            return torch.cat(hidden, dim=0)

    def predict_future_fixedwindow(self, inputs, burn_in_steps, prediction_steps, batch_size):
        """Fixed-window rollout using a two-argument ``single_step_forward``.

        NOTE(review): this calls ``single_step_forward(inputs, hidden)``,
        whereas ``calculate_loss`` and ``predict_future`` pass masks and graph
        info as well. It therefore only works with a subclass that accepts
        the two-argument form -- confirm whether it is still needed.
        """
        hidden = self.get_initial_hidden(inputs)
        for step in range(burn_in_steps-1):
            current_inputs = inputs[:, step]
            predictions, hidden = self.single_step_forward(current_inputs, hidden)
        all_timestep_preds = []
        for window_ind in range(burn_in_steps - 1, inputs.size(1)-1, batch_size):
            current_batch_preds = []
            states = []
            for step in range(batch_size):
                if window_ind + step >= inputs.size(1):
                    break
                predictions = inputs[:, window_ind + step]
                predictions, hidden = self.single_step_forward(predictions, hidden)
                current_batch_preds.append(predictions)
                tmp_decoder = self.copy_states(hidden)
                states.append(tmp_decoder)
            batch_hidden = self.merge_hidden(states)
            current_batch_preds = torch.cat(current_batch_preds, 0)
            current_timestep_preds = [current_batch_preds]
            for step in range(prediction_steps - 1):
                current_batch_preds, batch_hidden = self.single_step_forward(current_batch_preds, batch_hidden)
                current_timestep_preds.append(current_batch_preds)
            all_timestep_preds.append(torch.stack(current_timestep_preds, dim=1))
        results = torch.cat(all_timestep_preds, dim=0)
        return results.unsqueeze(0)

    def nll(self, preds, target, masks):
        """Dispatch to the masked loss selected by ``self.nll_loss_type``.

        Raises:
            ValueError: if the loss type is not 'crossent', 'gaussian' or
                'poisson' (previously this silently returned None).
        """
        if self.nll_loss_type == 'crossent':
            return self.nll_crossent(preds, target, masks)
        elif self.nll_loss_type == 'gaussian':
            return self.nll_gaussian(preds, target, masks)
        elif self.nll_loss_type == 'poisson':
            return self.nll_poisson(preds, target, masks)
        raise ValueError('Unknown nll_loss_type: %r' % (self.nll_loss_type,))

    def nll_gaussian(self, preds, target, masks, add_const=False):
        """Masked Gaussian NLL averaged over the unmasked entries of each sample.

        Only the ``normalize_nll`` mode is implemented. ``add_const`` is
        accepted but unused.
        """
        neg_log_p = ((preds - target) ** 2 / (2 * self.prior_variance))*masks.unsqueeze(-1)
        const = 0.5 * np.log(2 * np.pi * self.prior_variance)
        if self.normalize_nll_per_var:
            raise NotImplementedError()
        elif self.normalize_nll:
            return (neg_log_p.sum(-1) + masks*const).view(preds.size(0), -1).sum(dim=-1)/(masks.view(masks.size(0), -1).sum(dim=1)+1e-8)
        else:
            raise NotImplementedError()

    def nll_crossent(self, preds, target, masks):
        """Masked BCE-with-logits averaged over the unmasked entries of each sample."""
        if self.normalize_nll:
            loss = nn.BCEWithLogitsLoss(reduction='none')(preds, target)
            # The epsilon matches nll_gaussian and avoids 0/0 when a sample
            # has no valid entries.
            return (loss*masks.unsqueeze(-1)).view(preds.size(0), -1).sum(dim=-1)/(masks.view(masks.size(0), -1).sum(dim=1)+1e-8)
        else:
            raise NotImplementedError()

    def nll_poisson(self, preds, target, masks):
        """Masked Poisson NLL averaged over the unmasked entries of each sample."""
        if self.normalize_nll:
            loss = nn.PoissonNLLLoss(reduction='none')(preds, target)
            # The epsilon matches nll_gaussian and avoids 0/0 when a sample
            # has no valid entries.
            return (loss*masks.unsqueeze(-1)).view(preds.size(0), -1).sum(dim=-1)/(masks.view(masks.size(0), -1).sum(dim=1)+1e-8)
        else:
            raise NotImplementedError

    def save(self, path):
        """Write the model's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` saved by :meth:`save`.

        Tensors are mapped to CPU first so a checkpoint written on a GPU host
        can be restored anywhere; ``load_state_dict`` then copies the values
        into the existing parameters on whatever device they live on.
        """
        self.load_state_dict(torch.load(path, map_location='cpu'))


class FullyConnectedBaseline_DynamicVars(RecurrentBaseline_DynamicVars):
    """Fully connected message-passing baseline that updates only the active nodes.

    A single message MLP is shared by all edges; nodes absent from
    ``node_masks`` keep their hidden state and receive a zero prediction.
    """

    def __init__(self, params):
        """Build the layers.

        Args:
            params: dict providing ``decoder_hidden``, ``input_size`` and
                ``decoder_dropout``, plus whatever the base class reads.
        """
        super(FullyConnectedBaseline_DynamicVars, self).__init__(params)
        n_hid = params['decoder_hidden']
        out_size = params['input_size']
        do_prob = params['decoder_dropout']
        input_size = params['input_size']

        self.msg_fc1 = nn.Linear(2*n_hid, n_hid)
        self.msg_fc2 = nn.Linear(n_hid, n_hid)
        self.msg_out_shape = n_hid

        self.hidden_r = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_i = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_h = nn.Linear(n_hid, n_hid, bias=False)

        self.input_r = nn.Linear(input_size, n_hid, bias=True)
        self.input_i = nn.Linear(input_size, n_hid, bias=True)
        self.input_n = nn.Linear(input_size, n_hid, bias=True)

        self.out_fc1 = nn.Linear(n_hid, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, out_size)

        print('Using learned recurrent interaction net decoder.')

        self.dropout_prob = do_prob

    def get_initial_hidden(self, inputs):
        """Return a zero hidden state of shape [inputs.size(0), inputs.size(2), n_hid]."""
        return torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape, device=inputs.device)

    def single_step_forward(self, inputs, node_masks, graph_info, hidden):
        """Run one step for the currently active nodes.

        Args:
            inputs: [batch, num_vars, input_size]
            node_masks: mask whose non-zero entries (last index) select the
                active variables.
            graph_info: ``(send_edges, recv_edges, edge2node_inds)`` index
                tensors local to the active nodes; only read when more than
                one node is active.
            hidden: [batch, num_vars, rnn_hidden]

        Returns:
            ``(pred_all, hidden)``: predictions for every variable slot
            (zeros for inactive ones) and the hidden state with the active
            rows replaced. With no active node, ``hidden`` is returned as is.
        """
        if self.training:
            dropout_prob = self.dropout_prob
        else:
            dropout_prob = 0.

        max_num_vars = inputs.size(1)
        node_inds = node_masks.nonzero()[:, -1]
        current_hidden = hidden[:, node_inds]
        current_inputs = inputs[:, node_inds]
        num_vars = current_hidden.size(1)
        if num_vars > 1:
            send_edges, recv_edges, edge2node_inds = graph_info
            # Move the index tensors next to the data instead of forcing
            # CUDA, so the model also runs on CPU-only hosts.
            device = inputs.device
            send_edges = send_edges.to(device, non_blocking=True)
            recv_edges = recv_edges.to(device, non_blocking=True)
            edge2node_inds = edge2node_inds.to(device, non_blocking=True)

            receivers = current_hidden[:, recv_edges]
            senders = current_hidden[:, send_edges]

            # pre_msg: [batch, num_edges, 2*msg_out]
            pre_msg = torch.cat([receivers, senders], dim=-1)

            msg = torch.tanh(self.msg_fc1(pre_msg))
            msg = F.dropout(msg, p=dropout_prob)
            msg = torch.tanh(self.msg_fc2(msg))
            all_msgs = msg

            # Column i of edge2node_inds indexes the i-th incoming edge of
            # each node; summing over columns aggregates per receiver.
            incoming = all_msgs[:, edge2node_inds[:, 0], :].clone()
            for i in range(1, edge2node_inds.size(1)):
                incoming += all_msgs[:, edge2node_inds[:, i], :]
            agg_msgs = incoming/(num_vars-1)
        elif num_vars == 0:
            pred_all = torch.zeros(inputs.size(0), max_num_vars, inputs.size(-1), device=inputs.device)
            return pred_all, hidden
        else:
            # A single active node receives no messages.
            agg_msgs = torch.zeros(current_inputs.size(0), num_vars, self.msg_out_shape, device=inputs.device)

        # GRU-style gated aggregation
        inp_r = self.input_r(current_inputs).view(current_inputs.size(0), num_vars, -1)
        inp_i = self.input_i(current_inputs).view(current_inputs.size(0), num_vars, -1)
        inp_n = self.input_n(current_inputs).view(current_inputs.size(0), num_vars, -1)
        r = torch.sigmoid(inp_r + self.hidden_r(agg_msgs))
        i = torch.sigmoid(inp_i + self.hidden_i(agg_msgs))
        n = torch.tanh(inp_n + r*self.hidden_h(agg_msgs))
        current_hidden = (1 - i)*n + i*current_hidden

        # Output MLP
        pred = F.dropout(F.relu(self.out_fc1(current_hidden)), p=dropout_prob)
        pred = F.dropout(F.relu(self.out_fc2(pred)), p=dropout_prob)
        pred = self.out_fc3(pred)

        # The MLP predicts a delta on top of the current inputs.
        pred = current_inputs + pred
        hidden = hidden.clone()
        hidden[:, node_inds] = current_hidden
        pred_all = torch.zeros(inputs.size(0), max_num_vars, inputs.size(-1), device=inputs.device)
        # Index the whole batch dimension (previously hard-coded to row 0).
        pred_all[:, node_inds] = pred

        return pred_all, hidden

In [10]:

import os


def build_model(params):
    """Instantiate the model selected by ``params['model_type']``.

    'dnri' and 'fc_baseline' each pick a dynamic-variable or fixed-variable
    implementation based on ``params.get('dynamic_vars', False)``. Any other
    value builds an encoder/decoder pair wrapped in ``DynamicNRI`` when
    ``params['graph_type'] == 'dynamic'`` and in ``StaticNRI`` otherwise.

    After construction, weights are optionally loaded (``load_best_model``
    takes precedence over ``load_model``) and the model is moved to the GPU
    when ``params['gpu']`` is set.
    """
    model_type = params['model_type']
    if model_type == 'dnri':
        model_cls = DNRI_DynamicVars if params.get('dynamic_vars', False) else DNRI
        model = model_cls(params)
        print("dNRI MODEL: ",model)
    elif model_type == 'fc_baseline':
        if params.get('dynamic_vars', False):
            model_cls = FullyConnectedBaseline_DynamicVars
        else:
            model_cls = FullyConnectedBaseline
        model = model_cls(params)
        print("FCBaseline: ",model)
    else:
        num_vars = params['num_vars']
        graph_type = params['graph_type']

        # Encoder first, then decoder, each echoed to stdout.
        encoder = RefMLPEncoder(params)
        print("ENCODER: ",encoder)
        decoder = GraphRNNDecoder(params)
        print("DECODER: ",decoder)

        wrapper_cls = DynamicNRI if graph_type == 'dynamic' else StaticNRI
        model = wrapper_cls(num_vars, encoder, decoder, params)

    # Optional checkpoint restore.
    if params['load_best_model']:
        print("LOADING BEST MODEL")
        model.load(os.path.join(params['working_dir'], 'best_model'))
    elif params['load_model']:
        print("LOADING MODEL FROM SPECIFIED PATH")
        model.load(params['load_model'])

    if params['gpu']:
        model.cuda()
    return model



In [11]:
import numpy as np


# Code from NRI.
def normalize(data, data_max, data_min):
    """Linearly map ``data`` from ``[data_min, data_max]`` onto ``[-1, 1]``."""
    span = data_max - data_min
    offset = data - data_min
    return offset * 2 / span - 1


def unnormalize(data, data_max, data_min):
    """Invert ``normalize``: map ``data`` from ``[-1, 1]`` back to ``[data_min, data_max]``."""
    span = data_max - data_min
    return (data + 1) * span / 2. + data_min


def get_edge_inds(num_vars):
    """Return every ordered pair ``[i, j]`` with ``i != j`` for ``num_vars`` nodes.

    Pairs are listed row-major: all pairs starting at node 0, then node 1, etc.
    """
    return [
        [src, dst]
        for src in range(num_vars)
        for dst in range(num_vars)
        if src != dst
    ]

In [12]:
import numpy as np
import torch
from torch.utils.data import Dataset


import argparse, os


class SmallSynthData(Dataset):
    """Dataset over saved synthetic trajectories (``*_feats``) and edge labels (``*_edges``).

    Args:
        data_path: Directory holding ``train_feats``/``train_edges`` and the
            matching ``val_*`` / ``test_*`` files written with ``torch.save``.
        mode: Split to load: ``'train'``, ``'val'`` or ``'test'``.
        params: Mapping with the entries ``'same_data_norm'`` and
            ``'no_data_norm'`` controlling normalization.

    Raises:
        ValueError: If ``mode`` is not one of the supported splits.
    """

    _SPLITS = ('train', 'val', 'test')

    def __init__(self, data_path, mode, params):
        # Fail early with a clear message instead of hitting an unbound
        # ``path`` variable further down.
        if mode not in self._SPLITS:
            raise ValueError(
                "mode must be one of %s, got %r" % (self._SPLITS, mode))
        self.mode = mode
        self.data_path = data_path
        self.feats = torch.load(os.path.join(data_path, '%s_feats' % mode))
        self.edges = torch.load(os.path.join(data_path, '%s_edges' % mode))
        self.same_norm = params['same_data_norm']
        self.no_norm = params['no_data_norm']
        if not self.no_norm:
            self._normalize_data()

    def _normalize_data(self):
        """Rescale ``self.feats`` to ``[-1, 1]`` using training-split statistics.

        With ``same_norm`` a single min/max over all features is used;
        otherwise the first two channels of the last axis and the remaining
        channels are normalized with separate min/max values.
        """
        # Statistics always come from the training split; when that split is
        # the one already loaded, reuse it rather than reading the file again.
        if self.mode == 'train':
            train_data = self.feats
        else:
            train_data = torch.load(os.path.join(self.data_path, 'train_feats'))
        if self.same_norm:
            self.feat_max = train_data.max()
            self.feat_min = train_data.min()
            self.feats = (self.feats - self.feat_min)*2/(self.feat_max-self.feat_min) - 1
        else:
            # All four statistics are computed before the in-place writes below,
            # so they are unaffected even when ``train_data`` is ``self.feats``.
            self.loc_max = train_data[:, :, :, :2].max()
            self.loc_min = train_data[:, :, :, :2].min()
            self.vel_max = train_data[:, :, :, 2:].max()
            self.vel_min = train_data[:, :, :, 2:].min()
            self.feats[:,:,:, :2] = (self.feats[:,:,:,:2]-self.loc_min)*2/(self.loc_max - self.loc_min) - 1
            self.feats[:,:,:,2:] = (self.feats[:,:,:,2:]-self.vel_min)*2/(self.vel_max-self.vel_min)-1

    def unnormalize(self, data):
        """Map normalized ``data`` back to the original feature scale.

        Returns ``data`` unchanged when normalization is disabled.
        """
        if self.no_norm:
            return data
        elif self.same_norm:
            return (data + 1) * (self.feat_max - self.feat_min) / 2. + self.feat_min
        else:
            result1 = (data[:, :, :, :2] + 1) * (self.loc_max - self.loc_min) / 2. + self.loc_min
            result2 = (data[:, :, :, 2:] + 1) * (self.vel_max - self.vel_min) / 2. + self.vel_min
            # NOTE(review): unlike the two branches above, this goes through
            # ``np.concatenate``; kept as-is so existing callers see the same
            # return type.
            return np.concatenate([result1, result2], axis=-1)

    def __getitem__(self, idx):
        """Return the features and edge labels of sample ``idx`` as a dict."""
        return {'inputs': self.feats[idx], 'edges':self.edges[idx]}

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.feats)
        


if __name__ == '__main__':
    # Generates the small synthetic dataset: three particles per simulation,
    # where particles 1 and 2 move with constant velocity and particle 3 is
    # pushed away from either of them whenever it comes within distance 1.
    # Features and per-step edge labels are written as train/val/test splits.
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--num_train', type=int, default=100)
    parser.add_argument('--num_val', type=int, default=100)
    parser.add_argument('--num_test', type=int, default=100)
    parser.add_argument('--num_time_steps', type=int, default=50)
    # NOTE(review): --pull_factor is accepted but never used below.
    parser.add_argument('--pull_factor', type=float, default=0.1)
    parser.add_argument('--push_factor', type=float, default=0.05)

    args = parser.parse_args()
    # torch.save does not create missing directories.
    os.makedirs(args.output_dir, exist_ok=True)

    np.random.seed(1)
    all_data = []
    all_edges = []
    num_sims = args.num_train + args.num_val + args.num_test
    for sim in range(num_sims):
        # The draw order is kept fixed so a given seed reproduces the same data.
        p1_loc = np.random.uniform(-2, -1, size=(2))
        p1_vel = np.random.uniform(0.05, 0.1, size=(2))
        p2_loc = np.random.uniform(1, 2, size=(2))
        p2_vel = np.random.uniform(-0.05, -0.1, size=(2))
        p3_loc = np.random.uniform(-1, 1, size=(2))
        p3_vel = np.random.uniform(-0.05, 0.05, size=(2))

        current_feats = []
        current_edges = []
        for time_step in range(args.num_time_steps):
            # Six slots, one per ordered particle pair; only slots 1 and 3
            # (the interactions involving particle 3) are ever switched on.
            current_edge = np.array([0,0,0,0,0,0])
            dist_1 = np.linalg.norm(p3_loc - p1_loc)
            if dist_1 < 1:
                # Push strength grows linearly as particle 3 gets closer.
                dir_1 = (p3_loc - p1_loc)/dist_1
                p3_vel += args.push_factor*(1 - dist_1)*dir_1
                current_edge[1] = 1
            dist_2 = np.linalg.norm(p3_loc - p2_loc)
            if dist_2 < 1:
                dir_2 = (p3_loc - p2_loc)/dist_2
                p3_vel += args.push_factor*(1 - dist_2)*dir_2
                current_edge[3] = 1
            current_edges.append(current_edge)

            p1_loc += p1_vel
            p2_loc += p2_vel
            p3_loc += p3_vel
            # np.concatenate copies, so later in-place updates do not alter
            # the stored snapshots.
            p1_feat = np.concatenate([p1_loc, p1_vel])
            p2_feat = np.concatenate([p2_loc, p2_vel])
            p3_feat = np.concatenate([p3_loc, p3_vel])
            new_feat = np.stack([p1_feat, p2_feat, p3_feat])
            current_feats.append(new_feat)
        all_data.append(np.stack(current_feats))
        all_edges.append(np.stack(current_edges))

    all_data = np.stack(all_data)
    # Stack the edges like the features, so tensors are built from one
    # contiguous array instead of a Python list of arrays.
    all_edges = np.stack(all_edges).astype(np.float32)
    val_end = args.num_train + args.num_val

    train_data = torch.FloatTensor(all_data[:args.num_train])
    val_data = torch.FloatTensor(all_data[args.num_train:val_end])
    test_data = torch.FloatTensor(all_data[val_end:])
    train_path = os.path.join(args.output_dir, 'train_feats')
    torch.save(train_data, train_path)
    val_path = os.path.join(args.output_dir, 'val_feats')
    torch.save(val_data, val_path)
    test_path = os.path.join(args.output_dir, 'test_feats')
    torch.save(test_data, test_path)

    train_edges = torch.FloatTensor(all_edges[:args.num_train])
    val_edges = torch.FloatTensor(all_edges[args.num_train:val_end])
    test_edges = torch.FloatTensor(all_edges[val_end:])
    train_path = os.path.join(args.output_dir, 'train_edges')
    torch.save(train_edges, train_path)
    val_path = os.path.join(args.output_dir, 'val_edges')
    torch.save(val_edges, val_path)
    test_path = os.path.join(args.output_dir, 'test_edges')
    torch.save(test_edges, test_path)

usage: ipykernel_launcher.py [-h] --output_dir OUTPUT_DIR
                             [--num_train NUM_TRAIN] [--num_val NUM_VAL]
                             [--num_test NUM_TEST]
                             [--num_time_steps NUM_TIME_STEPS]
                             [--pull_factor PULL_FACTOR]
                             [--push_factor PUSH_FACTOR]
ipykernel_launcher.py: error: the following arguments are required: --output_dir


SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [13]:
import torch
import numpy as np 
import random


def seed(seed_val):
    """Seed the NumPy, PyTorch and Python RNGs (and all CUDA devices, if any)."""
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(seed_val)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_val)

In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter

import os


def build_scheduler(opt, params):
    """Return a ``StepLR`` scheduler for ``opt``, or ``None`` when no decay is configured.

    Decay is enabled only when ``params['lr_decay_factor']`` is truthy; the
    step size is taken from ``params['lr_decay_steps']``.
    """
    decay_factor = params.get('lr_decay_factor')
    decay_steps = params.get('lr_decay_steps')
    if not decay_factor:
        return None
    return torch.optim.lr_scheduler.StepLR(opt, decay_steps, decay_factor)


class build_writers:
    """Context manager that opens TensorBoard writers under ``<working_dir>/logs/``.

    Entering yields ``(train_writer, val_writer)``, plus a third
    ``test_writer`` when ``is_test`` is set. All opened writers are closed on
    exit.
    """

    def __init__(self, working_dir, is_test=False):
        self.is_test = is_test
        self.writer_dir = os.path.join(working_dir, 'logs/')

    def __enter__(self):
        self.train_writer = SummaryWriter(os.path.join(self.writer_dir, 'train'))
        self.val_writer = SummaryWriter(os.path.join(self.writer_dir, 'val'))
        if not self.is_test:
            return self.train_writer, self.val_writer
        self.test_writer = SummaryWriter(os.path.join(self.writer_dir, 'test'))
        return self.train_writer, self.val_writer, self.test_writer

    def __exit__(self, type, value, traceback):
        open_writers = [self.train_writer, self.val_writer]
        if self.is_test:
            open_writers.append(self.test_writer)
        for writer in open_writers:
            writer.close()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter



import time, os

import random
import numpy as np

def train(model, train_data, val_data, params, train_writer, val_writer):
    """Run the training loop with per-epoch validation and checkpointing.

    Each epoch trains over ``train_data`` (optionally accumulating gradients
    over several batches), evaluates on ``val_data``, writes scalars to the
    given writers, saves a rolling checkpoint, and saves the model to
    ``<working_dir>/best_model`` whenever the tuning loss improves.

    Args:
        model: Object providing ``calculate_loss``, ``train``/``eval``,
            ``parameters``, ``save``/``load`` and a ``kl_coef`` attribute.
        train_data: Dataset whose items are dicts with an ``'inputs'`` entry.
        val_data: Validation dataset with the same item layout.
        params: Configuration mapping; ``'lr'`` and ``'working_dir'`` are
            required, everything else falls back to the defaults read below.
        train_writer: Summary writer for training scalars, or ``None``.
        val_writer: Summary writer for validation scalars, or ``None``.
    """
    gpu = params.get('gpu', False)
    batch_size = params.get('batch_size', 1000)
    val_batch_size = params.get('val_batch_size', batch_size)
    if val_batch_size is None:
        val_batch_size = batch_size
    # These keys may be present with a value of None (e.g. an argparse option
    # without a default); fall back instead of failing on `%` / `+` below.
    accumulate_steps = params.get('accumulate_steps')
    if accumulate_steps is None:
        accumulate_steps = 1
    num_epochs = params.get('num_epochs', 100)
    if num_epochs is None:
        num_epochs = 100
    clip_grad = params.get('clip_grad', None)
    clip_grad_norm = params.get('clip_grad_norm', None)
    normalize_nll = params.get('normalize_nll', False)
    tune_on_nll = params.get('tune_on_nll', False)
    verbose = params.get('verbose', False)
    val_teacher_forcing = params.get('val_teacher_forcing', False)
    continue_training = params.get('continue_training', False)
    train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    val_data_loader = DataLoader(val_data, batch_size=val_batch_size)
    lr = params['lr']
    wd = params.get('wd', 0.)
    mom = params.get('mom', 0.)

    model_params = [param for param in model.parameters() if param.requires_grad]
    if params.get('use_adam', False):
        opt = torch.optim.Adam(model_params, lr=lr, weight_decay=wd)
    else:
        opt = torch.optim.SGD(model_params, lr=lr, weight_decay=wd, momentum=mom)

    working_dir = params['working_dir']
    best_path = os.path.join(working_dir, 'best_model')
    checkpoint_dir = os.path.join(working_dir, 'model_checkpoint')
    training_path = os.path.join(working_dir, 'training_checkpoint')
    if continue_training:
        print("RESUMING TRAINING")
        model.load(checkpoint_dir)
        train_params = torch.load(training_path)
        start_epoch = train_params['epoch']
        opt.load_state_dict(train_params['optimizer'])
        best_val_result = train_params['best_val_result']
        best_val_epoch = train_params['best_val_epoch']
        print("STARTING EPOCH: ",start_epoch)
    else:
        start_epoch = 1
        best_val_epoch = -1
        best_val_result = 10000000

    # `build_scheduler` and `seed` are the helpers defined at top level in
    # this file; there are no `train_utils` / `misc` modules to go through.
    training_scheduler = build_scheduler(opt, params)
    end = start = 0
    seed(1)
    for epoch in range(start_epoch, num_epochs+1):
        # Reports the duration of the previous epoch (0 for the first one).
        print("EPOCH", epoch, (end-start))
        model.train()
        model.train_percent = epoch / num_epochs
        start = time.time()
        for batch_ind, batch in enumerate(train_data_loader):
            inputs = batch['inputs']
            if gpu:
                inputs = inputs.cuda(non_blocking=True)
            loss, loss_nll, loss_kl, _, _ = model.calculate_loss(inputs, is_train=True, return_logits=True)
            loss.backward()
            if verbose:
                print("\tBATCH %d OF %d: %f, %f, %f"%(batch_ind+1, len(train_data_loader), loss.item(), loss_nll.mean().item(), loss_kl.mean().item()))
            # Step the optimizer once every `accumulate_steps` batches
            # (or after every batch when accumulate_steps == -1).
            if accumulate_steps == -1 or (batch_ind+1)%accumulate_steps == 0:
                if verbose and accumulate_steps > 0:
                    print("\tUPDATING WEIGHTS")
                if clip_grad is not None:
                    nn.utils.clip_grad_value_(model.parameters(), clip_grad)
                elif clip_grad_norm is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                opt.step()
                opt.zero_grad()
                # Stop when the remaining batches cannot fill another
                # accumulation group.
                if accumulate_steps > 0 and accumulate_steps > len(train_data_loader) - batch_ind - 1:
                    break

        if training_scheduler is not None:
            training_scheduler.step()

        # Training scalars reflect the last processed batch of the epoch.
        if train_writer is not None:
            train_writer.add_scalar('loss', loss.item(), global_step=epoch)
            if normalize_nll:
                train_writer.add_scalar('NLL', loss_nll.mean().item(), global_step=epoch)
            else:
                train_writer.add_scalar('NLL', loss_nll.mean().item()/(inputs.size(1)*inputs.size(2)), global_step=epoch)

            train_writer.add_scalar("KL Divergence", loss_kl.mean().item(), global_step=epoch)
        model.eval()
        opt.zero_grad()

        total_nll = 0
        total_kl = 0
        if verbose:
            print("COMPUTING VAL LOSSES")
        with torch.no_grad():
            for batch_ind, batch in enumerate(val_data_loader):
                inputs = batch['inputs']
                if gpu:
                    inputs = inputs.cuda(non_blocking=True)
                loss, loss_nll, loss_kl, _, _ = model.calculate_loss(inputs, is_train=False, teacher_forcing=val_teacher_forcing, return_logits=True)
                total_kl += loss_kl.sum().item()
                total_nll += loss_nll.sum().item()
                if verbose:
                    print("\tVAL BATCH %d of %d: %f, %f"%(batch_ind+1, len(val_data_loader), loss_nll.mean(), loss_kl.mean()))

        # Per-sample averages over the whole validation set.
        total_kl /= len(val_data)
        total_nll /= len(val_data)
        total_loss = model.kl_coef*total_kl + total_nll
        if val_writer is not None:
            val_writer.add_scalar('loss', total_loss, global_step=epoch)
            val_writer.add_scalar("NLL", total_nll, global_step=epoch)
            val_writer.add_scalar("KL Divergence", total_kl, global_step=epoch)
        if tune_on_nll:
            tuning_loss = total_nll
        else:
            tuning_loss = total_loss
        if tuning_loss < best_val_result:
            best_val_epoch = epoch
            best_val_result = tuning_loss
            print("BEST VAL RESULT. SAVING MODEL...")
            model.save(best_path)
        model.save(checkpoint_dir)
        # 'epoch' stores the NEXT epoch to run so a resumed run continues
        # from where this one stopped.
        torch.save({
                    'epoch':epoch+1,
                    'optimizer':opt.state_dict(),
                    'best_val_result':best_val_result,
                    'best_val_epoch':best_val_epoch,
                   }, training_path)
        print("EPOCH %d EVAL: "%epoch)
        print("\tCURRENT VAL LOSS: %f"%tuning_loss)
        print("\tBEST VAL LOSS:    %f"%best_val_result)
        print("\tBEST VAL EPOCH:   %d"%best_val_epoch)
        end = time.time()

    

In [None]:
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch

import os
import numpy as np


def eval_forward_prediction(model, dataset, burn_in_steps, forward_pred_steps, params, return_total_errors=False):
    """Measure per-step MSE of ``model.predict_future`` after a burn-in prefix.

    The first ``burn_in_steps`` time steps of every sequence are fed to the
    model, which predicts the next ``forward_pred_steps`` steps; squared
    errors are averaged over everything except the batch and time axes.

    Returns:
        With ``return_total_errors`` set, the per-sample, per-step errors of
        all batches concatenated along the batch axis. Otherwise the per-step
        error summed over samples and divided by ``len(dataset)``.
    """
    dataset.return_edges = False
    use_gpu = params.get('gpu', False)
    loader = DataLoader(dataset, batch_size=params.get('batch_size', 1000), pin_memory=use_gpu)
    model.eval()

    pred_end = burn_in_steps + forward_pred_steps
    summed_errors = 0
    per_sample_errors = []
    with torch.no_grad():
        for batch in loader:
            sequence = batch['inputs']
            context = sequence[:, :burn_in_steps]
            target = sequence[:, burn_in_steps:pred_end]
            if use_gpu:
                context = context.cuda(non_blocking=True)
            forecast = model.predict_future(context, forward_pred_steps).cpu()

            # Collapse everything after (batch, time) and average it away.
            squared = F.mse_loss(forecast, target, reduction='none')
            step_errors = squared.view(forecast.size(0), forecast.size(1), -1).mean(dim=-1)
            if return_total_errors:
                per_sample_errors.append(step_errors)
            else:
                summed_errors += step_errors.sum(dim=0)

    if return_total_errors:
        return torch.cat(per_sample_errors, dim=0)
    return summed_errors / len(dataset)

def eval_forward_prediction_fixedwindow(model, dataset, burn_in_steps, forward_pred_steps, params, return_total_errors=False):
    """Per-step MSE of sliding-window forecasts, one sequence at a time.

    For each sequence, ``model.predict_future_fixedwindow`` returns one
    forecast per window. Window ``w`` is compared against the ground truth
    starting at ``burn_in_steps + w``; where the sequence ends before the
    forecast horizon, the missing steps are zero-padded and masked out of
    both the error sum and the per-step counts.

    Returns:
        The per-step error sums divided by the per-step counts of valid
        comparisons. ``return_total_errors`` is accepted but not used.
    """
    dataset.return_edges = False
    use_gpu = params.get('gpu', False)
    window_batch_size = params.get('batch_size', 1000)
    # Sequences are always fed one by one; `batch_size` is forwarded to the model.
    loader = DataLoader(dataset, batch_size=1)
    model.eval()

    summed_errors = 0
    step_counts = torch.zeros(forward_pred_steps)
    for batch_ind, batch in enumerate(loader):
        sequence = batch['inputs']
        print("BATCH IND %d OF %d"%(batch_ind+1, len(loader)))
        with torch.no_grad():
            if use_gpu:
                sequence = sequence.cuda(non_blocking=True)
            forecasts = model.predict_future_fixedwindow(sequence, burn_in_steps, forward_pred_steps, window_batch_size).cpu()
            for window_ind in range(forecasts.size(1)):
                window_preds = forecasts[:, window_ind]
                first_step = burn_in_steps + window_ind
                target = sequence[:, first_step:first_step + forward_pred_steps].cpu()
                available = target.size(1)
                missing = forward_pred_steps - available
                if missing > 0:
                    # Pad the truncated target and mask the padded steps.
                    valid = torch.cat([torch.ones(available), torch.zeros(missing)])
                    padding = torch.zeros(target.size(0), missing, target.size(2), target.size(3))
                    target = torch.cat([target, padding], dim=1)
                else:
                    valid = torch.ones(forward_pred_steps)
                squared = F.mse_loss(window_preds, target, reduction='none')
                step_errors = squared.view(window_preds.size(0), window_preds.size(1), -1).mean(dim=-1).sum(dim=0).cpu()
                summed_errors += step_errors*valid
                step_counts += valid

    return summed_errors / step_counts


def eval_forward_prediction_dynamicvars(model, dataset, params):
    """Accumulate per-step forecast MSE for data with a varying set of variables.

    Each data point is processed on its own (the loader uses ``batch_size=1``).
    For every variable, the masked per-step errors are aligned so that index 0
    is that variable's first predicted step, then summed across variables and
    data points.

    Args:
        model: Object providing ``eval()`` and
            ``predict_future(inputs, masks, node_inds, graph_info, burn_in_masks)``.
        dataset: Dataset whose batches contain ``'inputs'``, ``'masks'`` and
            ``'burn_in_masks'``, and optionally ``'node_inds'`` / ``'graph_info'``.
        params: Mapping read for ``'gpu'``, ``'batch_size'`` and ``'collate_fn'``.

    Returns:
        A pair ``(final_errors / final_counts, final_counts)`` indexed by the
        number of steps since a variable's first predicted step.
    """
    gpu = params.get('gpu', False)
    # NOTE(review): batch_size, total_se and batch_count are assigned but never used.
    batch_size = params.get('batch_size', 1000)
    collate_fn = params.get('collate_fn', None)
    data_loader = DataLoader(dataset, batch_size=1, pin_memory=gpu, collate_fn=collate_fn)
    model.eval()
    total_se = 0
    batch_count = 0
    # Both accumulators start empty and are grown on demand below.
    final_errors = torch.zeros(0)
    final_counts = torch.zeros(0)
    # Counts variables whose predicted steps are not one contiguous run.
    bad_count = 0
    for batch_ind, batch in enumerate(data_loader):
        print("DATA POINT ",batch_ind)
        inputs = batch['inputs']
        # Targets are the single item's sequence without its first time step.
        gt_preds = inputs[0, 1:]
        masks = batch['masks']
        node_inds = batch.get('node_inds', None)
        graph_info = batch.get('graph_info', None)
        burn_in_masks = batch['burn_in_masks']
        # Mask minus burn-in mask, taken on the CPU copies before any GPU move
        # (presumably 1 where a step is scored as a prediction; confirm against the dataset).
        pred_masks = (masks.float() - burn_in_masks)[0, 1:]
        with torch.no_grad():
            if gpu:
                inputs = inputs.cuda(non_blocking=True)
                masks = masks.cuda(non_blocking=True)
                burn_in_masks = burn_in_masks.cuda(non_blocking=True)
            model_preds = model.predict_future(inputs, masks, node_inds, graph_info, burn_in_masks)[0].cpu()
            # Largest number of predicted steps of any variable in this data point.
            max_len = pred_masks.sum(dim=0).max().int().item()
            if max_len > len(final_errors):
                final_errors = torch.cat([final_errors, torch.zeros(max_len - len(final_errors))])
                final_counts = torch.cat([final_counts, torch.zeros(max_len - len(final_counts))])
            for var in range(masks.size(-1)):
                var_gt = gt_preds[:, var]
                var_preds = model_preds[:, var]
                var_pred_masks = pred_masks[:, var]
                # Per-step MSE over the feature axis, zeroed where the mask is 0.
                var_losses = F.mse_loss(var_preds, var_gt, reduction='none').mean(dim=-1)*var_pred_masks
                tmp_inds = torch.nonzero(var_pred_masks)
                if len(tmp_inds) == 0:
                    continue
                # Only counts the gap; the variable is still accumulated below.
                for i in range(len(tmp_inds)-1):
                    if tmp_inds[i+1] - tmp_inds[i] != 1:
                        bad_count += 1
                        break
                num_entries = var_pred_masks.sum().int().item()
                # Take num_entries steps starting at the first predicted step, so
                # index 0 of the accumulators is every variable's first prediction.
                final_errors[:num_entries] += var_losses[tmp_inds[0].item():tmp_inds[0].item()+num_entries]
                final_counts[:num_entries] += var_pred_masks[tmp_inds[0]:tmp_inds[0]+num_entries]
    print("FINAL BAD COUNT: ",bad_count)
    return final_errors/final_counts, final_counts

In [None]:


from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torch

import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.colors as mcolors
import numpy as np

def eval_edges(model, dataset, params):
    """Score the model's predicted edge types against the dataset's edge labels.

    Predicted edges are the ``argmax`` over the last axis of the fourth value
    returned by ``model.calculate_loss``. When only one of predictions and
    labels carries an extra middle axis, the other is broadcast along it; when
    the predictions are exactly one step shorter, the last label step is dropped.

    Returns:
        ``(f1, overall accuracy, accuracy on label-0 edges,
        accuracy on label-1 edges, all predicted edges concatenated)``.
    """
    use_gpu = params.get('gpu', False)
    loader_batch = params.get('batch_size', 1000)
    # Both keys are required (a missing one raises KeyError) although their
    # values are not used here.
    _ = params['num_edge_types'], params['skip_first']
    loader = DataLoader(dataset, batch_size=loader_batch, pin_memory=use_gpu)
    model.eval()

    label_total = 0.
    hits_all = 0.
    hits_zero = 0.
    zero_total = 0.
    hits_one = 0.
    one_total = 0.
    overlap = predicted_total = gt_total = 0
    collected = []
    with torch.no_grad():
        for batch in loader:
            inputs = batch['inputs']
            gt_edges = batch['edges'].long()
            if use_gpu:
                inputs = inputs.cuda(non_blocking=True)
                gt_edges = gt_edges.cuda(non_blocking=True)

            _, _, _, edge_logits, _ = model.calculate_loss(inputs, is_train=False, return_logits=True)
            edges = edge_logits.argmax(dim=-1)
            collected.append(edges.cpu())

            # Broadcast whichever side lacks the extra middle axis.
            pred_rank, gt_rank = len(edges.shape), len(gt_edges.shape)
            if pred_rank == 3 and gt_rank == 2:
                gt_edges = gt_edges.unsqueeze(1).expand(gt_edges.size(0), edges.size(1), gt_edges.size(1))
            elif gt_rank == 3 and pred_rank == 2:
                edges = edges.unsqueeze(1).expand(edges.size(0), gt_edges.size(1), edges.size(1))
            if edges.size(1) == gt_edges.size(1) - 1:
                gt_edges = gt_edges[:, :-1]

            matches = edges == gt_edges
            gt_is_zero = gt_edges == 0
            gt_is_one = gt_edges == 1
            label_total += gt_edges.numel()
            hits_all += matches.sum().item()
            zero_total += gt_is_zero.sum().item()
            one_total += gt_is_one.sum().item()
            hits_zero += (matches & gt_is_zero).sum().item()
            hits_one += (matches & gt_is_one).sum().item()
            overlap += (edges*gt_edges).sum().item()
            predicted_total += edges.sum().item()
            gt_total += gt_edges.sum().item()

    # Small epsilons keep the ratios finite when a denominator is zero.
    prec = overlap / (predicted_total + 1e-8)
    rec = overlap / (gt_total + 1e-8)
    f1 = 2*prec*rec / (prec+rec+1e-6)
    return (
        f1,
        hits_all / (label_total + 1e-8),
        hits_zero / (zero_total + 1e-8),
        hits_one / (one_total + 1e-8),
        torch.cat(collected),
    )

def plot_sample(model, dataset, num_samples, params):
    """Render ground-truth vs. predicted trajectories as one MP4 per batch.

    The model sees the first 10 steps of each batch and predicts the next 40.
    For the first item of each batch, the three particles are drawn per frame
    (solid: ground truth, translucent: prediction once past the burn-in) and
    the animation is saved to
    ``<working_dir>/pred_trajectory_<batch index>.mp4``. Stops after
    ``num_samples`` batches.
    """
    use_gpu = params.get('gpu', False)
    use_gt_edges = params.get('use_gt_edges')
    loader = DataLoader(dataset, batch_size=params.get('batch_size', 1))
    model.eval()

    burn_in_steps = 10
    forward_pred_steps = 40
    total_frames = burn_in_steps + forward_pred_steps
    # One marker style per particle, in particle order.
    marker_styles = ('bo', 'ro', 'go')
    batches_done = 0
    for batch_ind, batch in enumerate(loader):
        inputs = batch['inputs']
        gt_edges = batch.get('edges', None)
        with torch.no_grad():
            model_inputs = inputs[:, :burn_in_steps]
            gt_predictions = inputs[:, burn_in_steps:total_frames]
            if use_gpu:
                model_inputs = model_inputs.cuda(non_blocking=True)
                if gt_edges is not None and use_gt_edges:
                    gt_edges = gt_edges.cuda(non_blocking=True)
            if not use_gt_edges:
                gt_edges = None
            model_preds = model.predict_future(model_inputs, forward_pred_steps).cpu()
            print("MSE: ", torch.nn.functional.mse_loss(model_preds, gt_predictions).item())
            batches_done += 1

        fig, ax = plt.subplots()
        unnormalized_preds = dataset.unnormalize(model_preds)
        unnormalized_gt = dataset.unnormalize(inputs)

        def update(frame):
            ax.clear()
            for particle, style in enumerate(marker_styles):
                ax.plot(unnormalized_gt[0, frame, particle, 0], unnormalized_gt[0, frame, particle, 1], style)
            if frame >= burn_in_steps:
                # Predictions are indexed relative to the end of the burn-in.
                pred_frame = frame - burn_in_steps
                for particle, style in enumerate(marker_styles):
                    ax.plot(unnormalized_preds[0, pred_frame, particle, 0], unnormalized_preds[0, pred_frame, particle, 1], style, alpha=0.5)
            ax.set_xlim(-6, 6)
            ax.set_ylim(-6, 6)

        ani = animation.FuncAnimation(fig, update, interval=100, frames=total_frames)
        path = os.path.join(params['working_dir'], 'pred_trajectory_%d.mp4'%batch_ind)
        ani.save(path, codec='mpeg4')
        if batches_done >= num_samples:
            break




if __name__ == '__main__':
    # Entry point for the small synthetic experiment: parses flags, loads the
    # train/val splits, builds the model, then trains, evaluates or records
    # predictions depending on --mode.
    parser = build_flags()
    parser.add_argument('--data_path')
    parser.add_argument('--same_data_norm', action='store_true')
    parser.add_argument('--no_data_norm', action='store_true')
    parser.add_argument('--error_out_name', default='prediction_errors_%dstep.npy')
    parser.add_argument('--prior_variance', type=float, default=5e-5)
    parser.add_argument('--test_burn_in_steps', type=int, default=10)
    parser.add_argument('--error_suffix')
    parser.add_argument('--subject_ind', type=int, default=-1)

    args = parser.parse_args()
    params = vars(args)
    seed(args.seed)

    # Fixed properties of the synthetic dataset.
    params['num_vars'] = 3
    params['input_size'] = 4
    params['input_time_steps'] = 50
    params['nll_loss_type'] = 'gaussian'
    train_data = SmallSynthData(args.data_path, 'train', params)
    val_data = SmallSynthData(args.data_path, 'val', params)

    model = build_model(params)
    if args.mode == 'train':
        with build_writers(args.working_dir) as (train_writer, val_writer):
            # `train` is the function defined in this file, not a module, so
            # it is called directly.
            train(model, train_data, val_data, params, train_writer, val_writer)

    elif args.mode == 'eval':
        test_data = SmallSynthData(args.data_path, 'test', params)
        forward_pred = 50 - args.test_burn_in_steps
        test_mse  = eval_forward_prediction(model, test_data, args.test_burn_in_steps, forward_pred, params)
        path = os.path.join(args.working_dir, args.error_out_name%args.test_burn_in_steps)
        np.save(path, test_mse.cpu().numpy())
        test_mse_1 = test_mse[0].item()
        test_mse_15 = test_mse[14].item()
        test_mse_25 = test_mse[24].item()
        print("FORWARD PRED RESULTS:")
        print("\t1 STEP: ",test_mse_1)
        print("\t15 STEP: ",test_mse_15)
        print("\t25 STEP: ",test_mse_25)


        f1, all_acc, acc_0, acc_1, edges = eval_edges(model, val_data, params)
        print("Val Edge results:")
        print("\tF1: ",f1)
        print("\tAll predicted edge accuracy: ",all_acc)
        print("\tFirst Edge Acc: ",acc_0)
        print("\tSecond Edge Acc: ",acc_1)
        out_dir = os.path.join(args.working_dir, 'preds')
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, 'encoder_edges.npy')
        np.save(out_path, edges.numpy())

        # NOTE(review): the burn-in step count is passed as plot_sample's
        # `num_samples` argument; confirm this is intended.
        plot_sample(model, test_data, args.test_burn_in_steps, params)
    elif args.mode == 'record_predictions':
        # NOTE(review): reachable only if the --mode flag accepts
        # 'record_predictions'; verify against build_flags.
        model.eval()
        burn_in = args.test_burn_in_steps
        forward_pred = 50 - args.test_burn_in_steps
        test_data = SmallSynthData(args.data_path, 'test', params)
        if args.subject_ind == -1:
            # Record predictions for the whole test split.
            val_data_loader = DataLoader(test_data, batch_size=params['batch_size'])
            all_predictions = []
            all_edges = []
            for batch_ind,batch in enumerate(val_data_loader):
                print("BATCH %d of %d"%(batch_ind+1, len(val_data_loader)))
                inputs = batch['inputs']
                if args.gpu:
                    inputs = inputs.cuda(non_blocking=True)
                with torch.no_grad():
                    predictions, edges = model.predict_future(inputs[:, :burn_in], forward_pred, return_edges=True, return_everything=True)
                    all_predictions.append(predictions)
                    all_edges.append(edges)
            if args.error_suffix is not None:
                out_path = os.path.join(args.working_dir, 'preds/', 'all_test_subjects_%s.npy'%args.error_suffix)
            else:
                out_path = os.path.join(args.working_dir, 'preds/', 'all_test_subjects.npy')

            predictions = torch.cat(all_predictions, dim=0)
            edges = torch.cat(all_edges, dim=0)

        else:
            # Record predictions for a single test item.
            data = test_data[args.subject_ind]
            inputs = data['inputs'].unsqueeze(0)
            if args.gpu:
                inputs = inputs.cuda(non_blocking=True)
            with torch.no_grad():
                predictions, edges = model.predict_future(inputs[:, :burn_in], forward_pred, return_edges=True, return_everything=True)
                predictions = predictions.squeeze(0)
                edges = edges.squeeze(0)
            out_path = os.path.join(args.working_dir, 'preds/', 'subject_%d.npy'%args.subject_ind)
        tmp_dir = os.path.join(args.working_dir, 'preds/')
        os.makedirs(tmp_dir, exist_ok=True)
        torch.save([predictions.cpu(), edges.cpu()], out_path)