In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import torch.nn.functional as F

In [None]:
class Graph_RNN_structure(nn.Module):
    """Convolutional GraphRNN-style model that produces one new node state per step.

    The model keeps a history list ``self.hidden_all`` of hidden states, each a
    tensor of shape ``(batch_size, hidden_size, 1)``. Every call to
    :meth:`forward`:

    1. concatenates that history along the last axis;
    2. runs a stack of length-preserving 1-D convolutions over it;
    3. emits one logit per history position (presumably the probability of an
       edge to that earlier node — TODO confirm against the training code);
    4. mean-pools the hidden states of the selected positions into a new
       state, truncates the history and appends the new state.

    Args:
        hidden_size: Channel count of every hidden state.
        batch_size: Batch size used by :meth:`init_hidden`.
        output_size: Stored on the instance; not used by this class.
        num_layers: Total conv depth. ``num_layers - 1`` hidden conv layers
            (each optionally followed by batch-norm, then ReLU) are followed
            by the single-channel ``conv_out``.
        is_dilation: If True, hidden conv layer ``i`` uses dilation ``2**i``;
            otherwise every layer uses dilation 1.
        is_bn: Apply ``BatchNorm1d`` after each hidden conv layer.
    """

    def __init__(self, hidden_size, batch_size, output_size, num_layers, is_dilation=True, is_bn=True):
        super(Graph_RNN_structure, self).__init__()

        # Model configuration
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.output_size = output_size  # kept for API compatibility; unused here
        self.num_layers = num_layers
        self.is_bn = is_bn

        # Model layers
        self.relu = nn.ReLU()

        # With kernel_size=3, padding equal to the dilation keeps the sequence
        # length unchanged, so position j of the output still refers to
        # history entry j.
        if is_dilation:
            self.conv_block = nn.ModuleList([nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=2**i, padding=2**i) for i in range(num_layers-1)])
        else:
            self.conv_block = nn.ModuleList([nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=1, padding=1) for i in range(num_layers-1)])

        self.bn_block = nn.ModuleList([nn.BatchNorm1d(hidden_size) for i in range(num_layers-1)])
        # Final projection to a single logit per history position.
        self.conv_out = nn.Conv1d(hidden_size, 1, kernel_size=3, dilation=1, padding=1)

        # Maps the pooled state of the selected positions to the new node's state.
        self.linear_transition = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )

        # History of hidden states, each (batch_size, hidden_size, 1). It must
        # be filled (e.g. with init_hidden(length=...)) before forward(),
        # because torch.cat fails on an empty list.
        self.hidden_all = []

        # Initialize weights
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv1d)):
                init.xavier_uniform_(m.weight, gain=init.calculate_gain('relu'))
            if isinstance(m, nn.BatchNorm1d):
                init.ones_(m.weight)
                init.zeros_(m.bias)

    def init_hidden(self, length=None):
        """Create all-ones hidden state(s) on the same device as the model.

        Args:
            length: If None, return a single tensor. Otherwise return a list of
                ``length`` independent tensors, suitable for ``self.hidden_all``.

        Returns:
            A tensor of shape ``(batch_size, hidden_size, 1)``, or a list of them.
        """
        # Follow the parameters' device instead of hard-coding .cuda(), so the
        # model also works on CPU and on a non-default GPU.
        device = next(self.parameters()).device

        def ones():
            return torch.ones(self.batch_size, self.hidden_size, 1, device=device)

        if length is None:
            return ones()
        return [ones() for _ in range(length)]

    @staticmethod
    def _first_active_index(mask):
        """Return the smallest history position that any batch row marks non-zero.

        The mask is reduced over every dimension except the last one (the
        history axis), so the result is valid for any batch size. Returns 0
        when nothing is marked, which keeps the full history instead of raising.
        """
        active = mask.detach().reshape(-1, mask.size(-1)).ne(0).any(dim=0)
        positions = torch.nonzero(active)
        if positions.numel() == 0:
            return 0
        return int(positions.min().item())

    def forward(self, x, teacher_forcing, temperature=0.5, bptt=True, bptt_len=20, flexible=True, max_prev_node=100):
        """Run one generation step over the current ``self.hidden_all`` history.

        Args:
            x: Selection mask over history positions. Its last axis must have
                length ``len(self.hidden_all)`` and it must broadcast against
                ``(batch, hidden_size, len(self.hidden_all))``. Presumably
                ``(batch, 1, len(self.hidden_all))`` holding a 0/1 ground-truth
                row — TODO confirm. Only read when ``teacher_forcing`` is True.
            teacher_forcing: If True, pool and truncate using ``x``. If False,
                use the model's own predictions.
            temperature, bptt, bptt_len: Accepted for interface compatibility;
                unused.
            flexible: If True, drop history before the earliest selected
                position. If False, drop exactly the oldest entry.
            max_prev_node: Without teacher forcing (and with ``flexible``),
                the maximum history length kept after appending the new state.

        Returns:
            ``(x_pred, x_pred_sample)``: the per-position logits, shape
            ``(batch, 1, len(self.hidden_all))`` as measured before this call's
            update, and their sigmoid.

        Side effects:
            Replaces ``self.hidden_all`` with the truncated history plus the
            new ``(batch, hidden_size, 1)`` state.
        """
        hidden_all_cat = torch.cat(self.hidden_all, dim=2)

        # Both ModuleLists hold num_layers - 1 layers.
        for conv, bn in zip(self.conv_block, self.bn_block):
            hidden_all_cat = conv(hidden_all_cat)
            if self.is_bn:
                hidden_all_cat = bn(hidden_all_cat)
            hidden_all_cat = self.relu(hidden_all_cat)

        x_pred = self.conv_out(hidden_all_cat)
        x_pred_sample = torch.sigmoid(x_pred)
        # Comparing against a Python scalar keeps the result on x_pred's
        # device (no hard-coded .cuda() threshold tensor).
        thresh = 0.5
        x_pred_sample_long = torch.gt(x_pred_sample, thresh).long()

        if teacher_forcing:
            hidden_all_cat_select = hidden_all_cat * x
            x_sum = torch.sum(x, dim=2, keepdim=True).float()
        else:
            # NOTE(review): the numerator is weighted by soft probabilities,
            # while the denominator counts hard (thresholded) selections.
            # Confirm this mix is intended.
            hidden_all_cat_select = hidden_all_cat * x_pred_sample
            x_sum = torch.sum(x_pred_sample_long, dim=2, keepdim=True).float()

        # Mean-pool the selected positions. A row with nothing selected would
        # give 0/0 = NaN, which would poison every later step through
        # hidden_all, so such rows are divided by 1 instead.
        x_sum = x_sum.masked_fill(x_sum == 0, 1.0)
        hidden_new = torch.sum(hidden_all_cat_select, dim=2, keepdim=True) / x_sum
        hidden_new = self.linear_transition(hidden_new.permute(0, 2, 1))
        hidden_new = hidden_new.permute(0, 2, 1)

        if flexible:
            if teacher_forcing:
                # Drop history before the earliest position selected by any row.
                x_id = self._first_active_index(x)
                self.hidden_all = self.hidden_all[x_id:]
            else:
                # Same rule using the model's hard predictions. Also cap the
                # history at max_prev_node entries once the new state is appended.
                x_id = self._first_active_index(x_pred_sample_long)
                start = max(len(self.hidden_all) - max_prev_node + 1, x_id)
                self.hidden_all = self.hidden_all[start:]
        else:
            # Fixed-size sliding window: drop the oldest state.
            self.hidden_all = self.hidden_all[1:]

        self.hidden_all.append(hidden_new)

        return x_pred, x_pred_sample
