In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.optim.lr_scheduler import MultiStepLR

import numpy as np
import time as tm

from generator_utils import *
# from args import *

In [3]:
class GRU_plain(nn.Module):
    """GRU wrapped with an optional input embedding and an optional output MLP.

    The recurrent state is kept on the instance (``self.hidden``) and is carried
    over from one ``forward`` call to the next, so callers must reset it
    (e.g. ``net.hidden = net.init_hidden(batch_size)``) before starting a new
    sequence. Leaving it as ``None`` makes ``nn.GRU`` start from zeros.

    Args:
        input_size: Feature size of the raw input.
        embedding_size: Width of the input embedding and of the hidden layer of
            the output MLP.
        hidden_size: GRU hidden state size.
        num_layers: Number of stacked GRU layers.
        has_input: If True, the input goes through ``Linear -> ReLU`` before the
            GRU; otherwise the GRU consumes the raw input directly.
        has_output: If True, GRU outputs go through
            ``Linear -> ReLU -> Linear`` producing ``output_size`` features.
        output_size: Output feature size; required when ``has_output`` is True.

    Raises:
        ValueError: If ``has_output`` is True and ``output_size`` is None.
    """

    def __init__(self, input_size, embedding_size, hidden_size, num_layers, has_input=True, has_output=False, output_size=None):
        super(GRU_plain, self).__init__()
        if has_output and output_size is None:
            # Fail early with a readable message instead of deep inside nn.Linear.
            raise ValueError("output_size must be provided when has_output=True")

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.has_input = has_input
        self.has_output = has_output

        if has_input:
            self.input = nn.Linear(input_size, embedding_size)
            self.rnn = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                              batch_first=True)
        else:
            self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        if has_output:
            self.output = nn.Sequential(
                nn.Linear(hidden_size, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, output_size)
            )

        self.relu = nn.ReLU()
        # Recurrent state; must be (re)initialised by the caller before a new sequence.
        self.hidden = None

        # GRU parameters: constant biases, Xavier-uniform weights with sigmoid gain.
        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.25)
            elif 'weight' in name:
                nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain('sigmoid'))
        # Every Linear layer (input embedding and output MLP): Xavier-uniform with ReLU gain.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))

    def init_hidden(self, batch_size):
        """Return an all-zero hidden state of shape (num_layers, batch_size, hidden_size).

        The tensor is created on the same device and with the same dtype as the
        module's own parameters, so it is always compatible with ``self.rnn``
        regardless of where the module has been moved.
        """
        reference = next(self.parameters())
        return torch.zeros(self.num_layers, batch_size, self.hidden_size,
                           device=reference.device, dtype=reference.dtype)

    def forward(self, input_raw, pack=False, input_len=None):
        """Run the GRU over a batch-first sequence, updating ``self.hidden``.

        Args:
            input_raw: Batch-first input, (batch, seq_len, input_size).
            pack: If True, the (embedded) input is packed with ``input_len``
                before the GRU and the result is padded back afterwards.
            input_len: Sequence lengths used for packing; only read when
                ``pack`` is True. ``pack_padded_sequence`` is called with its
                default ``enforce_sorted=True``, so lengths must be in
                decreasing order.

        Returns:
            The per-time-step GRU outputs, passed through the output MLP when
            ``has_output`` is True.
        """
        if self.has_input:
            input = self.input(input_raw)
            input = self.relu(input)
        else:
            input = input_raw
        if pack:
            input = pack_padded_sequence(input, input_len, batch_first=True)
        # The new state is stored so the next call continues the same sequence.
        output_raw, self.hidden = self.rnn(input, self.hidden)
        if pack:
            output_raw = pad_packed_sequence(output_raw, batch_first=True)[0]
        if self.has_output:
            output_raw = self.output(output_raw)
        # return hidden state at each time step
        return output_raw

In [4]:
class GraphRNN(nn.Module):
    """Graph generator built from two ``GRU_plain`` networks.

    ``self.rnn`` runs once per generated node and produces the initial hidden
    state for ``self.output``, which then emits that node's edge entries one at
    a time. ``forward`` turns a batch of latent vectors into a batch of decoded
    adjacency matrices.

    Args:
        args: Configuration object. Attributes read here: ``max_prev_node``,
            ``max_num_node``, ``embedding_size_rnn``, ``hidden_size_rnn``,
            ``embedding_size_rnn_output``, ``hidden_size_rnn_output``,
            ``num_layers``, ``milestones``, ``load`` and, when ``load`` is
            truthy, ``model_save_path``, ``fname`` and ``load_epoch``.
        device: Device both sub-networks are moved to. Note that the default is
            evaluated once, when the class is defined.
    """

    def __init__(self, args, device=choose_device()) -> None:
        super().__init__()
        self.args = args
        self.device = device
        self.rnn = GRU_plain(input_size=self.args.max_prev_node, embedding_size=self.args.embedding_size_rnn,
                        hidden_size=self.args.hidden_size_rnn, num_layers=self.args.num_layers, has_input=True,
                        has_output=True, output_size=self.args.hidden_size_rnn_output).to(self.device)
        self.output = GRU_plain(input_size=1, embedding_size=self.args.embedding_size_rnn_output,
                            hidden_size=self.args.hidden_size_rnn_output, num_layers=self.args.num_layers, has_input=True,
                            has_output=True, output_size=1).to(self.device)

        # load data state
        if args.load:
            # map_location lets a checkpoint saved on another device (e.g. GPU)
            # be restored onto the device this model actually uses.
            fname = args.model_save_path + args.fname + 'lstm_' + str(args.load_epoch) + '.dat'
            self.rnn.load_state_dict(torch.load(fname, map_location=self.device))
            fname = args.model_save_path + args.fname + 'output_' + str(args.load_epoch) + '.dat'
            self.output.load_state_dict(torch.load(fname, map_location=self.device))

            # Side effect kept from the original: resuming overrides the
            # learning rate stored on the shared ``args`` object.
            args.lr = 0.00001
            print('model loaded!, lr: {}'.format(args.lr))

    # ====Call these in training loop====
    def init_optimizer(self, lr):
        """Create one Adam optimizer and one MultiStepLR scheduler per sub-network.

        Must be called before ``clear_gradient_opts`` or ``all_steps``, which
        read the attributes set here.

        Returns:
            ``(optimizer_rnn, optimizer_output, scheduler_rnn, scheduler_output)``.
        """
        self.optimizer_rnn = optim.Adam(list(self.rnn.parameters()), lr=lr)
        self.optimizer_output = optim.Adam(list(self.output.parameters()), lr=lr)
        self.scheduler_rnn = MultiStepLR(self.optimizer_rnn, milestones=self.args.milestones)
        self.scheduler_output = MultiStepLR(self.optimizer_output, milestones=self.args.milestones)
        return self.optimizer_rnn, self.optimizer_output, self.scheduler_rnn, self.scheduler_output

    def clear_gradient_models(self):
        """Zero the gradients of both sub-networks through the modules themselves."""
        self.rnn.zero_grad()
        self.output.zero_grad()

    def train(self, flag=True):
        """Switch the model and both sub-networks between training and eval mode.

        Compatible with ``nn.Module.train``: ``flag`` defaults to True so a bare
        ``model.train()`` works, the wrapper's own ``training`` attribute is
        updated, and ``self`` is returned to allow chaining. ``flag`` is
        coerced to ``bool`` so truthy/falsy values keep working.
        """
        # ``rnn`` and ``output`` are registered sub-modules, so the base
        # implementation propagates the mode to both of them.
        super().train(bool(flag))
        return self

    def clear_gradient_opts(self):
        """Zero the gradients through both optimizers (requires ``init_optimizer``)."""
        self.optimizer_rnn.zero_grad()
        self.optimizer_output.zero_grad()

    def all_steps(self):
        """Step both optimizers, then both schedulers (requires ``init_optimizer``).

        The schedulers advance on every call, so ``args.milestones`` is counted
        in calls to this method.
        """
        self.optimizer_rnn.step()
        self.optimizer_output.step()
        self.scheduler_rnn.step()
        self.scheduler_output.step()

    # ======================================

    def forward(self, X):
        """Generate one adjacency matrix per row of ``X``.

        Args:
            X: Noise/latent tensor whose first dimension is the batch. It is
                replicated across the GRU layers and used as the initial hidden
                state of ``self.rnn``, so its last dimension must equal
                ``args.hidden_size_rnn``.

        Returns:
            A float CPU tensor of shape ``(batch, n, m)`` where ``(n, m)`` is
            the shape returned by ``decode_adj``. The result is built from
            detached integer samples and carries no gradient.
        """
        # The number of generated graphs follows the latent batch. This is the
        # only size the GRU accepts, since the hidden state is built from X.
        output_batch_size = X.size(0)
        max_num_node = self.args.max_num_node
        max_prev_node = self.args.max_prev_node

        # expected shape: (num_layer, batch_size, hidden_size)
        self.rnn.hidden = torch.stack(self.rnn.num_layers * [X]).to(self.device)

        # Discrete prediction buffer, one row of edge entries per generated node.
        y_pred_long = torch.zeros(output_batch_size, max_num_node, max_prev_node, device=self.device)
        # The first step is fed an all-ones row.
        x_step = torch.ones(output_batch_size, 1, max_prev_node, device=self.device)

        # iterative graph generation
        for i in range(max_num_node):
            # for each node
            # 1. we use rnn to create new node embedding
            # 2. we use output to create new edges

            # (1)
            h = self.rnn(x_step)
            # Only the first layer of the edge-level GRU is seeded with h;
            # the remaining layers start from zeros.
            hidden_null = torch.zeros(self.args.num_layers - 1, h.size(0), h.size(2), device=self.device)
            x_step = torch.zeros(output_batch_size, 1, max_prev_node, device=self.device)
            output_x_step = torch.ones(output_batch_size, 1, 1, device=self.device)
            # (2)
            self.output.hidden = torch.cat((h.permute(1, 0, 2), hidden_null), dim=0).to(self.device)
            # Node i is given at most i + 1 edge entries.
            for j in range(min(max_prev_node, i + 1)):
                output_y_pred_step = self.output(output_x_step)
                output_x_step = sample_sigmoid(output_y_pred_step, sample=True, sample_time=1, device=self.device)
                x_step[:, :, j:j + 1] = output_x_step
            y_pred_long[:, i:i + 1, :] = x_step

        # One device-to-host transfer for the whole batch.
        y_pred_long_data = y_pred_long.detach().long().cpu()

        # Decode every sample exactly once, then copy into a float tensor.
        decoded = [decode_adj(y_pred_long_data[i]) for i in range(output_batch_size)]
        adj_pred_list = torch.zeros((output_batch_size, decoded[0].size(0), decoded[0].size(1)))
        for i, adj_pred in enumerate(decoded):
            adj_pred_list[i, :, :] = adj_pred

        return adj_pred_list