In [1]:
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
import numpy as np
import pickle
import networkx as nx
from dataloader import *
from RGCN import *

In [2]:
# Read the pickled DGL graph and the node embedding values from disk.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("DGL_graph.pkl", "rb") as graph_file:
    g = pickle.load(graph_file)

with open("data/conceptnet/embedding_values.pkl", "rb") as embedding_file:
    embedding_values = pickle.load(embedding_file)

# Store the embeddings as node feature "x".
g.ndata["x"] = embedding_values

# Subsample a strongly-connected subgraph: keep the components with more
# than two nodes and work with the first one found.
G_nx = g.to_networkx()
sub_G_nx = nx.strongly_connected_components(G_nx)
SCC = [members for members in sub_G_nx if len(members) > 2]
component = list(SCC[0])

# Induce the subgraph on that component.
# presumably copy_from_parent() pulls node/edge features from `g` — confirm against the DGL version in use
sub_graph = g.subgraph(component)
sub_graph.copy_from_parent()
sub_graph_G = sub_graph

# Build the training pairs (gen_training_set comes from a star import; its
# arguments 3 and 100 are not documented in this notebook).
X, y = gen_training_set(sub_graph_G, 3, 100)
print(len(X), len(y))

100 100


In [3]:
# --- R-GCN hyper-parameters -------------------------------------------------
# number of hidden units
n_hidden = 16
# -1: use number of relations as number of bases
n_bases = -1
# 0: 1 input layer, 1 output layer, no hidden layer
n_hidden_layers = 0
# epochs to train
n_epochs = 25
# learning rate
lr = 0.01
# L2 norm coefficient
l2norm = 0

# --- model dimensions -------------------------------------------------------
n_input = 300
n_output = 256
num_rels = 34

# --- graph preparation ------------------------------------------------------
# One norm value of 1 per edge, stored on the graph as a column vector.
edge_norm = torch.ones(g.edata['rel_type'].shape[0])
g.edata.update({'norm': edge_norm.view(-1, 1)})

# --- model ------------------------------------------------------------------
model = Model(
    n_input, n_hidden, n_output, num_rels,
    num_bases=n_bases,
    num_hidden_layers=n_hidden_layers,
)

In [4]:
# Cast the node features to float32 in place on the graph before encoding.
g.ndata['x'] = g.ndata['x'].float()
# Run the encoder on the full graph. The same `Model` class is unpacked as
# (node embeddings, graph embedding) inside graph_to_graph.
output, g_emb = model(g)

In [21]:
class graph_to_graph(nn.Module):
    """Take a commonsense graph as input and generate new nodes, phrases and edges.

    The graph is encoded with the R-GCN `Model`; three LSTM decoders
    (`LSTM_node_generator`, `LSTM_phrase_generator`, `LSTM_edge_generator`)
    are then decoded greedily, one token at a time, until `EOS_token` is
    produced or the length limit is reached.

    The decoder classes, `Model`, `SOS_token`, `EOS_token` and `device` are
    not defined in this notebook; they are presumably provided by the star
    imports at the top — confirm.
    """

    def __init__(self, input_size, hidden_size, node_output_size, phrase_output_size, edge_output_size, num_rels, n_hidden_layers, n_bases = -1):
        """Build the three decoders and the graph encoder.

        Args:
            input_size: size of the input node features fed to the encoder.
            hidden_size: encoder hidden/output size and decoder hidden size.
            node_output_size: output size of the node decoder.
            phrase_output_size: output size of the phrase decoder.
            edge_output_size: output size of the edge decoder.
            num_rels: number of relation types passed to the encoder.
            n_hidden_layers: number of hidden layers in the encoder.
            n_bases: number of bases for the encoder (default -1).
        """
        super(graph_to_graph, self).__init__()
        self.node = LSTM_node_generator(hidden_size, node_output_size)
        self.phrase = LSTM_phrase_generator(hidden_size, phrase_output_size) #TODO: replace by gpt-2
        self.edge = LSTM_edge_generator(hidden_size, edge_output_size)
        # USE R-GCN (vanilla GCN alternative was: Net(input_size, 256, hidden_size))
        self.graph_encoder = Model(input_size,
              hidden_size,
              hidden_size,
              num_rels,
              num_bases=n_bases,
              num_hidden_layers=n_hidden_layers)

    @staticmethod
    def _initial_lstm_state(embedding):
        """Return an (h0, c0) pair: `embedding` viewed as (1, 1, -1) and zeros of that shape."""
        h0 = embedding.view(1, 1, -1)
        return (h0, torch.zeros_like(h0))

    def _greedy_decode(self, decoder, decoder_hidden, max_length):
        """Greedily decode up to `max_length` tokens from `decoder`, stopping at EOS.

        Returns the list of chosen token tensors (EOS excluded).
        """
        decoder_input = torch.tensor([[SOS_token]], device=device)
        tokens = []
        for _ in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            _, top_index = decoder_output.topk(1)
            decoder_input = top_index.squeeze().detach()  # detach from history as input
            if decoder_input.item() == EOS_token:
                break
            tokens.append(decoder_input)
        return tokens

    def generate_graph_embedding(self, g):
        """Return the encoder output for `g` (unpacked elsewhere as node embeddings, graph embedding)."""
        return self.graph_encoder(g)

    def node_policy(self, *args):
        """Run one step of the node decoder; arguments are forwarded unchanged."""
        return self.node(*args)

    def generate_node_baseline(self, g):
        """Greedily decode node ids for `g` and return them as a list of ints.

        NOTE(review): the step limit is the module-level `max_length`, which is
        not defined in this cell — confirm where it comes from.
        """
        _, g_embedding = self.graph_encoder(g)

        node_decoder_input = torch.tensor([[SOS_token]], device=device)
        node_decoder_hidden = self._initial_lstm_state(g_embedding)

        new_node_list = []

        # TODO: implementing teacher-forcing
        for _ in range(max_length):
            node_decoder_output, node_decoder_hidden = self.node(
                node_decoder_input, node_decoder_hidden)
            _, topi = node_decoder_output.topk(1)
            node_decoder_input = topi.squeeze().detach()  # detach from history as input
            print(node_decoder_input)  # debug trace of every decoded token
            if node_decoder_input.item() == EOS_token:  # stop generating node
                break
            new_node_list.append(node_decoder_input.item())
        return new_node_list

    def generate_node(self, g):
        """Greedily decode new nodes for `g` and, for each one, a phrase.

        Returns a list with one entry per generated node; each entry is the
        list of phrase token tensors decoded for that node.

        NOTE(review): the step limit is the module-level `max_length`, which is
        not defined in this cell — confirm where it comes from.
        """
        _, g_embedding = self.graph_encoder(g)

        node_decoder_input = torch.tensor([[SOS_token]], device=device)
        node_decoder_hidden = self._initial_lstm_state(g_embedding)

        new_phrase_list = []

        # TODO: implementing teacher-forcing
        for _ in range(max_length):
            node_decoder_output, node_decoder_hidden = self.node(
                node_decoder_input, node_decoder_hidden)
            # The decoder state is an (h, c) tuple; the node embedding is h.
            # (Previously the whole tuple was used and `.view` failed on it.)
            new_node_embedding = node_decoder_hidden[0]
            _, topi = node_decoder_output.topk(1)
            node_decoder_input = topi.squeeze().detach()  # detach from history as input
            print(node_decoder_input)  # debug trace of every decoded token
            if node_decoder_input.item() == EOS_token:  # stop generating node
                break
            # Decode the phrase for this node, seeded with its embedding.
            new_phrase_list.append(self._greedy_decode(
                self.phrase, self._initial_lstm_state(new_node_embedding), max_length))

        return new_phrase_list

    def forward(self, g, max_length = 100):
        """Generate new nodes with phrases, then an edge label sequence for every node pair.

        Args:
            g: input graph accepted by the encoder.
            max_length: step limit for each of the node, phrase and edge decoders.

        Returns:
            (new_phrase_list, new_edge_list): the phrase token lists per new
            node, and a list of (i, j, edge_tokens) over all ordered pairs of
            existing + new nodes.
        """
        # get graph embedding
        node_embedding, g_embedding = self.graph_encoder(g)
        node_embedding = list(node_embedding)

        node_decoder_input = torch.tensor([[SOS_token]], device=device)
        node_decoder_hidden = self._initial_lstm_state(g_embedding)

        new_node_list = []
        new_phrase_list = []

        # TODO: implementing teacher-forcing
        for _ in range(max_length):
            node_decoder_output, node_decoder_hidden = self.node(
                node_decoder_input, node_decoder_hidden)
            new_node_embedding = node_decoder_hidden[0]
            _, topi = node_decoder_output.topk(1)
            node_decoder_input = topi.squeeze().detach()  # detach from history as input
            if node_decoder_input.item() == EOS_token:  # stop generating node
                break
            # add new node embedding to the list
            new_node_list.append(new_node_embedding)
            # Generate the phrase for this node. The phrase decoder state is
            # seeded from the node embedding (it was previously never
            # initialised here, which raised on the first phrase step).
            new_phrase_list.append(self._greedy_decode(
                self.phrase, self._initial_lstm_state(new_node_embedding), max_length))

        # generate edge between nodes
        all_nodes = node_embedding + new_node_list
        new_edge_list = []
        for i, node1 in enumerate(all_nodes):
            for j, node2 in enumerate(all_nodes):
                # NOTE(review): existing node embeddings come from list(node_embedding)
                # while new ones are LSTM hidden states, so their shapes may differ,
                # and the node/phrase decoders take an (h, c) tuple rather than a
                # single tensor — confirm what LSTM_edge_generator expects.
                edge_decoder_hidden = torch.cat([node1, node2])
                new_edge = self._greedy_decode(self.edge, edge_decoder_hidden, max_length)
                new_edge_list.append((i, j, new_edge))

        return new_phrase_list, new_edge_list

In [22]:
#from model import *
# Vocabulary size shared by the phrase and edge decoders.
CORPUS_SIZE = 10000

# Encoder dimensions.
input_size = 300
hidden_size = 256
num_rels = 34
n_hidden_layers = 2
n_bases = -1

# Decoder output sizes: one class per node of `g` plus two extra
# (presumably the SOS/EOS tokens — confirm).
node_output_size = g.ndata['x'].shape[0] + 2
phrase_output_size = CORPUS_SIZE
edge_output_size = CORPUS_SIZE

graph_generator = graph_to_graph(
    input_size=input_size,
    hidden_size=hidden_size,
    node_output_size=node_output_size,
    phrase_output_size=phrase_output_size,
    edge_output_size=edge_output_size,
    num_rels=num_rels,
    n_hidden_layers=n_hidden_layers,
    n_bases=n_bases,
)

In [23]:
# Subgraph induced by the first training sample.
graph_S = sub_graph_G.subgraph(X[0])
graph_S.copy_from_parent()

# Same preparation as for the full graph: float features and a norm of 1 on
# every edge, stored as a column vector.
graph_S.ndata["x"] = graph_S.ndata["x"].float()
edge_norm = torch.ones(graph_S.edata['rel_type'].shape[0])
graph_S.edata.update({'norm': edge_norm.view(-1, 1)})

# Greedy (untrained) node generation on the sample subgraph.
baseline_nodes = graph_generator.generate_node_baseline(graph_S)
print(baseline_nodes)

tensor(1785)
tensor(1722)
tensor(360)
tensor(360)
tensor(3174)
tensor(3170)
tensor(360)
tensor(3170)
tensor(600)
tensor(4091)
tensor(741)
tensor(1920)
tensor(531)
tensor(1722)
tensor(1577)
tensor(600)
tensor(636)
tensor(1722)
tensor(360)
tensor(360)
[1785, 1722, 360, 360, 3174, 3170, 360, 3170, 600, 4091, 741, 1920, 531, 1722, 1577, 600, 636, 1722, 360, 360]


In [139]:
# Policy gradient
from torch.distributions import Categorical

gamma = 0.99
class Policy(nn.Module):
    """Policy-gradient wrapper around the node decoder of a `graph_to_graph` model.

    Holds its own `graph_to_graph` instance (built from the module-level size
    settings) together with the per-episode and overall training histories.
    """

    def __init__(self):
        super().__init__()

        # Discount factor, taken from the module-level `gamma`.
        self.gamma = gamma
        self.graph_generator = graph_to_graph(
            input_size=input_size,
            hidden_size=hidden_size,
            node_output_size=node_output_size,
            phrase_output_size=phrase_output_size,
            edge_output_size=edge_output_size,
            num_rels=num_rels,
            n_hidden_layers=n_hidden_layers,
            n_bases=-1,
        )

        # Per-episode state: log-probabilities of chosen actions, rewards, actions.
        self.policy_history = torch.Tensor([])
        self.reward_episode = []
        self.action_history = []
        # Across-episode bookkeeping.
        self.reward_history = []
        self.loss_history = []

    def forward(self, *args):
        """Run one node-decoder step; arguments are passed through unchanged."""
        return self.graph_generator.node_policy(*args)

def select_action(policy, *args):
    """Sample one action from the policy and record its log-probability.

    `policy(*args)` must return `(scores, state)`. The scores are
    exponentiated and used as the (unnormalised) probabilities of a
    categorical distribution, from which the action is sampled.
    The log-probability of the sampled action is appended to
    `policy.policy_history`.

    Returns:
        (action, state): the sampled action tensor and the state returned by
        the policy (for the LSTM decoder, state is (h_i, c_i), h_i ~ (1,1,hid_dim)).
    """
    scores, state = policy(*args)
    distribution = Categorical(torch.exp(scores))
    action = distribution.sample()
    chosen_log_prob = distribution.log_prob(action)

    # A 0-dim history is replaced outright; otherwise the new entry is appended.
    if policy.policy_history.dim() == 0:
        policy.policy_history = chosen_log_prob
    else:
        policy.policy_history = torch.cat(
            [policy.policy_history, chosen_log_prob.view(-1)])
    return action, state


In [111]:
# Encode the sample subgraph and take one sampled decoding step.
# NOTE(review): `policy` is only created in a later cell (In [202]), so this
# cell depends on out-of-order execution; the embedding also comes from
# `graph_generator`, not from the policy's own encoder.
node_embedding, g_embedding = graph_generator.generate_graph_embedding(graph_S)

# Decoder start: SOS token, hidden state = graph embedding, cell state = zeros.
node_decoder_input = torch.tensor([[SOS_token]], device=device)
node_decoder_hidden = (
    g_embedding.view(1, 1, -1),
    torch.zeros_like(g_embedding).view(1, 1, -1),
)

output = select_action(policy, node_decoder_input, node_decoder_hidden)

In [178]:
def update_policy():
    """Run one REINFORCE update from the episode stored on the global `policy`.

    Reads `policy.reward_episode`, `policy.gamma` and `policy.policy_history`,
    steps the global `optimizer`, appends the loss and the summed episode
    reward to the policy's histories, and clears the per-episode state.

    Returns:
        The loss of this update as a Python float.
    """
    # Discount future rewards back to the present using gamma. Walk the
    # episode backwards and reverse once at the end (avoids insert(0, ...)).
    discounted_return = 0
    returns = []
    for reward in reversed(policy.reward_episode):
        discounted_return = reward + policy.gamma * discounted_return
        returns.append(discounted_return)
    returns.reverse()

    # Scale rewards. With a single step std() is NaN, which would turn the
    # loss and every updated weight into NaN, so normalisation is skipped then.
    rewards = torch.FloatTensor(returns)
    if rewards.numel() > 1:
        rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)

    # Calculate loss: negative sum of log-probability * scaled return.
    loss = torch.sum(torch.mul(policy.policy_history, rewards).mul(-1), -1)

    # Update network weights.
    optimizer.zero_grad()
    # NOTE(review): retain_graph=True is kept from the original; it looks
    # unnecessary since the history is reset below — confirm before removing.
    loss.backward(retain_graph=True)
    optimizer.step()

    # Save and initialize episode history counters. Store the float value,
    # not the bound `loss.item` method.
    loss_value = loss.item()
    policy.loss_history.append(loss_value)
    policy.reward_history.append(np.sum(policy.reward_episode))
    policy.policy_history = torch.Tensor()
    policy.reward_episode = []
    policy.action_history = []
    return loss_value

In [202]:
import itertools
# Fresh policy; it builds its own graph_to_graph model in __init__.
policy = Policy()
# Adam over all policy parameters. Must be created after `policy`.
# NOTE(review): `optim` is not imported in the visible cells — presumably it
# arrives through one of the star imports; confirm.
optimizer = optim.Adam(policy.parameters(), lr=1e-3)

def main(episodes):
    """Train the global `policy` on `graph_S` for `episodes` episodes.

    Each episode encodes `graph_S`, samples 20 node actions with
    `select_action`, scores the growing action list against `y[0]` after
    every step, and then calls `update_policy()`. Every 50th episode the
    sampled actions and a summary line are printed.
    """
    running_reward = 10
    for episode in range(episodes):
        # Initial decoder state: SOS token, h0 = graph embedding, c0 = zeros.
        _, g_embedding = policy.graph_generator.generate_graph_embedding(graph_S)
        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_state = (
            g_embedding.view(1, 1, -1),
            torch.zeros_like(g_embedding).view(1, 1, -1),
        )

        for step in range(20):
            action, decoder_state = select_action(policy, decoder_input, decoder_state)
            policy.action_history.append(action.item())
            # Reward is computed on the full action list so far.
            policy.reward_episode.append(compute_reward(policy.action_history, y[0]))
            decoder_input = action

        # Running average of the last step index.
        running_reward = (running_reward * 0.99) + (step * 0.01)
        # Keep a reference: update_policy() rebinds policy.action_history.
        episode_actions = policy.action_history
        loss = update_policy()
        if episode % 50 != 0:
            continue
        print(episode_actions)
        print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}\t Loss: {:.4f}'.format(episode, step, running_reward, loss))

In [203]:
def compute_reward(action_history, y):
    """Score the actions taken so far against the target `y`.

    The score is the number of distinct values shared by `action_history`
    and `y`, plus `len(y) - len(action_history)` (negative once more
    actions have been taken than there are targets).
    """
    shared = set(action_history) & set(y)
    length_gap = len(y) - len(action_history)
    return len(shared) + length_gap

In [204]:
# Train for 1000 episodes; main() prints the sampled actions and a summary every 50 episodes.
main(1000)

[1080, 3082, 1357, 1286, 2101, 1352, 1175, 3810, 2327, 1877, 270, 303, 2152, 3673, 4340, 2607, 1121, 3521, 2029, 3440]
Episode 0	Last length:    19	Average length: 10.09	 Loss: 0.4051
[2122, 1240, 4548, 956, 4011, 94, 1548, 4234, 3653, 2514, 2667, 966, 2773, 1767, 2260, 3834, 457, 289, 2679, 3701]
Episode 50	Last length:    19	Average length: 13.61	 Loss: -0.5620
[4371, 3408, 3587, 45, 1524, 4312, 3763, 573, 828, 2633, 1960, 71, 1844, 535, 1505, 896, 4163, 1922, 3244, 1570]
Episode 100	Last length:    19	Average length: 15.74	 Loss: 7.3640
[3162, 3729, 407, 2396, 4237, 4261, 1428, 3991, 2099, 1931, 215, 3982, 4438, 2708, 2875, 4238, 1807, 1843, 1975, 3993]
Episode 150	Last length:    19	Average length: 17.03	 Loss: 1.3405
[4369, 4002, 2540, 873, 3285, 3056, 1304, 4332, 770, 2182, 3162, 1551, 2279, 837, 653, 2406, 2043, 1918, 2844, 4290]
Episode 200	Last length:    19	Average length: 17.81	 Loss: -4.3198
[4036, 3813, 3271, 4389, 1847, 3993, 2456, 1287, 293, 2867, 3184, 1757, 6, 3818, 45

KeyboardInterrupt: 

In [193]:
# Target of the first training sample — the same y[0] that main() passes to compute_reward.
y[0]

[1061, 1827, 447]