In [1]:
import copy
import torch
import deepsnap
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn

from sklearn.metrics import f1_score
from deepsnap.hetero_gnn import forward_op
from deepsnap.hetero_graph import HeteroGraph
from torch_sparse import SparseTensor, matmul

import pickle
import networkx as nx

In [2]:
class HeteroGNNConv(pyg_nn.MessagePassing):
    """Message passing for a single (src_type, edge_type, dst_type) relation.

    Source-node features are mean-aggregated over the incoming adjacency and
    projected by ``lin_src``; the destination node's own features are projected
    by ``lin_dst``; the two projections are concatenated and mixed back down to
    ``out_channels`` by ``lin_update``.
    """

    def __init__(self, in_channels_src, in_channels_dst, out_channels):
        super(HeteroGNNConv, self).__init__(aggr="mean")

        self.in_channels_src = in_channels_src
        self.in_channels_dst = in_channels_dst
        self.out_channels = out_channels

        # Registered in the order dst, src, update so that parameter
        # initialisation draws from the RNG in a fixed sequence.
        self.lin_dst = nn.Linear(in_channels_dst, out_channels)
        self.lin_src = nn.Linear(in_channels_src, out_channels)
        self.lin_update = nn.Linear(2 * out_channels, out_channels)

    def forward(
        self,
        node_feature_src,
        node_feature_dst,
        edge_index,
        size=None,
        res_n_id=None,
        ):
        """Run one round of message passing for this relation.

        :param node_feature_src: Features of the source node type.
        :param node_feature_dst: Features of the destination node type.
        :param edge_index: Adjacency handed to ``message_and_aggregate``
            (expected to be a transposed ``SparseTensor``, dst x src).
        :param size: Optional ``(num_src, num_dst)`` passed to ``propagate``.
        :param res_n_id: Forwarded to ``update`` (unused there).
        :return: Updated destination-node embeddings of width ``out_channels``.
        """
        return self.propagate(
            edge_index,
            size=size,
            node_feature_src=node_feature_src,
            node_feature_dst=node_feature_dst,
            res_n_id=res_n_id,
        )

    def message_and_aggregate(self, edge_index, node_feature_src):
        # Fused message + mean aggregation via a sparse-dense product; the
        # parameter names must stay as-is because PyG matches them to the
        # keyword arguments given to propagate().
        return matmul(edge_index, node_feature_src, reduce='mean')

    def update(self, aggr_out, node_feature_dst, res_n_id):
        # Project the node's own features and the aggregated neighbourhood
        # separately, then mix the concatenation into the output embedding.
        self_part = self.lin_dst(node_feature_dst)
        neighbour_part = self.lin_src(aggr_out)
        return self.lin_update(torch.cat([self_part, neighbour_part], -1))

In [3]:
class HeteroGNNWrapperConv(deepsnap.hetero_gnn.HeteroConv):
    """Runs one convolution per message type and merges results per node type.

    Each entry of ``convs`` (keyed by ``(src_type, edge_type, dst_type)``)
    produces an embedding for its destination node type. When a node type is
    the destination of several message types, those embeddings are merged by
    :meth:`aggregate` using ``"mean"`` or ``"attn"``.
    """

    def __init__(self, convs, args, aggr="mean"):
        """
        Initializes the HeteroGNNWrapperConv instance.

        :param convs: Dictionary of convolution layers for each message type.
        :param args: Arguments dictionary; reads 'hidden_size' and 'attn_size'
            (only when aggr == "attn").
        :param aggr: Aggregation method across message types, "mean" (default)
            or "attn".
        """

        super(HeteroGNNWrapperConv, self).__init__(convs, None)
        self.aggr = aggr

        # Maps position in a destination type's embedding list -> message type.
        # NOTE(review): keys are only positions, so entries for different
        # destination types overwrite each other; only the last one survives.
        self.mapping = {}

        # Detached tensor holding the most recent attention probabilities (one
        # per message type), set by aggregate() when aggr == "attn". It is
        # overwritten for every node type that is aggregated.
        self.alpha = None

        self.attn_proj = None

        if self.aggr == "attn":
            # Scores a message-type embedding: hidden_size -> attn_size -> 1.
            self.attn_proj = nn.Sequential(
                nn.Linear(args['hidden_size'], args['attn_size']),
                nn.Tanh(),
                nn.Linear(args['attn_size'], 1, bias=False)
            )

    def reset_parameters(self):
        """Re-initialise the wrapped convolutions and, if present, the attention scorer."""
        super(HeteroGNNWrapperConv, self).reset_parameters()
        if self.aggr == "attn":
            for layer in self.attn_proj.children():
                # nn.Tanh has no parameters and no reset_parameters(); skip it.
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()

    def forward(self, node_features, edge_indices):
        """
        Forward pass of the model.

        :param node_features: Dictionary of node features for each node type.
        :param edge_indices: Dictionary of edge indices for each message type.
        :return: Aggregated node embeddings for each destination node type.
        """

        # One embedding per message type, for that message's destination type.
        message_type_emb = {}
        for message_key, edge_index in edge_indices.items():
            src_type, _, dst_type = message_key
            message_type_emb[message_key] = self.convs[message_key](
                node_features[src_type],
                node_features[dst_type],
                edge_index,
            )

        # Group the embeddings by destination node type.
        node_emb = {dst: [] for _, _, dst in message_type_emb.keys()}
        mapping = {}
        for (src, edge_type, dst), item in message_type_emb.items():
            mapping[len(node_emb[dst])] = (src, edge_type, dst)
            node_emb[dst].append(item)
        self.mapping = mapping

        # Merge only when a node type receives more than one embedding.
        for node_type, embs in node_emb.items():
            if len(embs) == 1:
                node_emb[node_type] = embs[0]
            else:
                node_emb[node_type] = self.aggregate(embs)

        return node_emb

    def aggregate(self, xs):
        """
        Aggregates node embeddings using the specified aggregation method.

        :param xs: List of equally-shaped node embeddings to aggregate, one per
            message type.
        :return: Aggregated node embeddings as a torch.Tensor.
        :raises ValueError: If ``self.aggr`` is neither "mean" nor "attn".
        """

        if self.aggr == "mean":
            return torch.mean(torch.stack(xs), dim=0)

        if self.aggr == "attn":
            stacked = torch.stack(xs, dim=0)               # (M message types, N nodes, D)
            scores = self.attn_proj(stacked).squeeze(-1)   # (M, N)
            scores = torch.mean(scores, dim=-1)            # (M,) one score per message type
            alpha = torch.softmax(scores, dim=0)
            # Keep a detached copy for inspection, but weight with the
            # non-detached alpha so gradients reach attn_proj.
            self.alpha = alpha.detach()
            return torch.sum(alpha.reshape(-1, 1, 1) * stacked, dim=0)

        raise ValueError(
            f"Unsupported aggregation '{self.aggr}'; expected 'mean' or 'attn'."
        )

In [4]:
def generate_convs(hetero_graph, conv, hidden_size, first_layer=False):
    """
    Build one convolution per message type of a heterogeneous graph.

    :param hetero_graph: Graph exposing ``message_types`` and
        ``num_node_features(node_type)``.
    :param conv: Callable ``conv(in_channels_src, in_channels_dst, out_channels)``
        returning a layer.
    :param hidden_size: Output width of every layer, and input width of every
        layer after the first.
    :param first_layer: When True, input widths are the raw feature sizes of
        the source/destination node types instead of ``hidden_size``.

    :return: Dict mapping each message type to its convolution layer.
    """

    def _input_widths(message_type):
        # After the first layer every node type carries hidden_size-wide
        # embeddings, so raw feature sizes only matter for layer one.
        if not first_layer:
            return hidden_size, hidden_size
        return (
            hetero_graph.num_node_features(message_type[0]),
            hetero_graph.num_node_features(message_type[2]),
        )

    return {
        message_type: conv(*_input_widths(message_type), hidden_size)
        for message_type in hetero_graph.message_types
    }

In [5]:
class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GNN with a per-node-type scalar output head.

    Each layer applies HeteroGNNWrapperConv, then per-node-type BatchNorm1d and
    LeakyReLU; a final per-node-type ``nn.Linear(hidden_size, 1)`` produces one
    value per node.
    """

    def __init__(self, hetero_graph, args, num_layers, aggr="mean"):
        """
        :param hetero_graph: Graph providing ``message_types``, ``node_types``
            and per-type feature sizes.
        :param args: Hyperparameter dict; 'hidden_size' is read here (and
            'attn_size' by the wrapper when aggr == "attn").
        :param num_layers: Accepted but currently unused; depth is fixed at two
            convolution layers.
        :param aggr: Cross-message-type aggregation ("mean" or "attn").
        """
        super(HeteroGNN, self).__init__()

        self.aggr = aggr
        self.hidden_size = args['hidden_size']

        # Per-node-type containers, registered before the convolutions so the
        # module/parameter ordering stays fixed.
        self.bns1 = nn.ModuleDict()
        self.bns2 = nn.ModuleDict()
        self.relus1 = nn.ModuleDict()
        self.relus2 = nn.ModuleDict()
        self.post_mps = nn.ModuleDict()  # never populated or used
        self.fc = nn.ModuleDict()

        # Layer 1 consumes raw node features; layer 2 consumes hidden_size embeddings.
        self.convs1 = HeteroGNNWrapperConv(
            generate_convs(hetero_graph, HeteroGNNConv, self.hidden_size, first_layer=True),
            args, self.aggr)
        self.convs2 = HeteroGNNWrapperConv(
            generate_convs(hetero_graph, HeteroGNNConv, self.hidden_size, first_layer=False),
            args, self.aggr)

        for node_type in hetero_graph.node_types:
            # NOTE(review): eps=1.0 is far above BatchNorm's default of 1e-5; confirm intended.
            self.bns1[node_type] = nn.BatchNorm1d(self.hidden_size, eps=1.0)
            self.bns2[node_type] = nn.BatchNorm1d(self.hidden_size, eps=1.0)
            self.relus1[node_type] = nn.LeakyReLU()
            self.relus2[node_type] = nn.LeakyReLU()
            self.fc[node_type] = nn.Linear(self.hidden_size, 1)

    def forward(self, node_feature, edge_index):
        """
        Run both GNN layers and the output head.

        :param node_feature: Dict of node features keyed by node type.
        :param edge_index: Dict of adjacencies keyed by message type.
        :return: Dict of per-node outputs (width 1) keyed by node type.
        """
        hidden = node_feature
        stages = (
            (self.convs1, self.bns1, self.relus1),
            (self.convs2, self.bns2, self.relus2),
        )
        for conv, norms, activations in stages:
            hidden = conv(hidden, edge_index)
            hidden = forward_op(hidden, norms)
            hidden = forward_op(hidden, activations)
        return forward_op(hidden, self.fc)

    def loss(self, preds, y, indices):
        """
        Mean-squared error on the selected 'event' nodes that have a label.

        Event nodes whose first target column equals -1 are treated as
        unlabelled and excluded.

        :param preds: Dict of predictions keyed by node type.
        :param y: Dict of targets keyed by node type.
        :param indices: Dict of node-index tensors keyed by node type.

        :return: Scalar loss tensor.
        """
        event_idx = indices['event']
        event_targets = y['event']
        labelled = event_targets[event_idx, 0] != -1
        kept_idx = torch.masked_select(event_idx, labelled)
        criterion = torch.nn.MSELoss()
        return criterion(preds['event'][kept_idx], event_targets[kept_idx])

In [6]:
def train(model, optimizer, hetero_graph, train_idx):
    """
    Run a single full-graph optimisation step.

    :param model: Model exposing ``forward(node_feature, edge_index)`` and
        ``loss(preds, targets, indices)``.
    :param optimizer: Optimizer over the model's parameters.
    :param hetero_graph: Graph providing ``node_feature``, ``edge_index`` and
        ``node_target``.
    :param train_idx: Per-node-type indices used by the loss.

    :return: The training loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    predictions = model(hetero_graph.node_feature, hetero_graph.edge_index)
    step_loss = model.loss(predictions, hetero_graph.node_target, train_idx)

    step_loss.backward()
    optimizer.step()

    return step_loss.item()

def test(model, graph, indices, best_model=None, best_val=0):
    """
    Tests the model on given indices and updates the best model based on validation loss.

    :param model: The trained graph neural network model.
    :param graph: The heterogeneous graph data.
    :param indices: List of indices for training, validation, and testing nodes.
    :param best_model: The current best model based on validation loss.
    :param best_val: The current best validation loss.
    
    :return: A tuple containing the list of losses for each dataset, the best model, and the best validation loss.
    """
    
    model.eval() # Set the model to evaluation mode
    accs = []
    
    # Evaluate the model on each set of indices
    for index in indices:
        preds = model(graph.node_feature, graph.edge_index)
        
        idx = index['event']

        L1 = torch.sum(torch.abs(preds['event'][idx] - graph.node_target['event'][idx]))
        
        accs.append(L1)
    
    # Update the best model and validation loss if the current model performs better
    if accs[1] < best_val:
        best_val = accs[1]
        best_model = copy.deepcopy(model)
    
    return accs, best_model, best_val

In [7]:
# Hyperparameters and runtime device shared by the cells below.
args = dict(
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    hidden_size=64,      # embedding width of every GNN layer
    epochs=200,
    weight_decay=1e-5,
    lr=0.201,            # NOTE(review): unusually high for Adam; confirm this is not a typo
    attn_size=32,        # only read when aggr == "attn"
)

In [8]:
"""
This cell creates a small heterogeneous graph, primarily for testing purposes.
"""

S_node_feature = {
    "event": torch.tensor([
                [1, 1, 1],   # event 0
                [2, 2, 2]    # event 1
    ], dtype=torch.float32),
    "concept": torch.tensor([
                [2, 2, 2],   # concept 0
                [3, 3, 3]    # concept 1
    ], dtype=torch.float32)
}

# S_node_label = {
#     "event": torch.tensor([0, 1], dtype=torch.long), # Class 0, Class 1
#     "concept": torch.tensor([0, 1], dtype=torch.long)  # Class 0, Class 1
# }

S_node_targets = {
    "event": torch.tensor([[50], [2000]], dtype=torch.float32),
    # "concept": torch.tensor([[0], [0]], dtype=torch.float32)
}

S_edge_index = {
    ("event", "similar", "event"): torch.tensor([[0,1],[1,0]], dtype=torch.int64),
    ("event", "related", "concept"): torch.tensor([[0,1],[0,1]], dtype=torch.int64),
    ("concept", "related", "event"): torch.tensor([[0,1],[0,1]], dtype=torch.int64)
}

# Testing
hetero_graph = HeteroGraph(
    node_feature=S_node_feature,
    node_target=S_node_targets,
    edge_index=S_edge_index
)

train_idx = {"event": torch.tensor([0, 1]).to(args['device']), "concept": torch.tensor([0, 1]).to(args['device'])}
val_idx = {"event": torch.tensor([0, 1]).to(args['device']), "concept": torch.tensor([0, 1]).to(args['device'])}
test_idx = {"event": torch.tensor([0, 1]).to(args['device']), "concept": torch.tensor([0, 1]).to(args['device'])}

In [9]:
# Load the real event/concept graph and wrap it as a directed deepsnap
# HeteroGraph (presumably a networkx graph, given netlib=nx).
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
GRAPH_PICKLE_PATH = "./10_concepts_similar_llm_noUnknown_CHEAT.pkl"

with open(GRAPH_PICKLE_PATH, "rb") as graph_file:
    G = pickle.load(graph_file)

hetero_graph = HeteroGraph(G, netlib=nx, directed=True)

In [10]:
"""
This code block ensures that all tensors in a heterogeneous graph are transferred to the same 
computational device, as specified in the 'args' dictionary.
"""

for message_type in hetero_graph.message_types:
    print("TYPE", message_type)
    print("\t Feature", hetero_graph.num_node_features(message_type[0]))
    print("\t Feature", hetero_graph.num_node_features(message_type[2]))

# Send node features to device
for key in hetero_graph.node_feature:
    hetero_graph.node_feature[key] = hetero_graph.node_feature[key].to(args['device'])

# for key in hetero_graph.node_label:
#     hetero_graph.node_label[key] = hetero_graph.node_label[key].to(args['device'])

# Create a torch.SparseTensor from edge_index and send it to device
for key in hetero_graph.edge_index:
    print("KEY", key, type(key))
    print("KEY NUMS", key, hetero_graph.num_nodes(key[0]), hetero_graph.num_nodes(key[2]))
    
    edge_index = hetero_graph.edge_index[key]

    print("MAX EDGES", edge_index[0].max(), edge_index[1].max(), hetero_graph.num_nodes(key[0]), hetero_graph.num_nodes(key[2]))
    adj = SparseTensor(row=edge_index[0].long(), col=edge_index[1].long(), sparse_sizes=(hetero_graph.num_nodes(key[0]), hetero_graph.num_nodes(key[2])))
    hetero_graph.edge_index[key] = adj.t().to(args['device'])
    
# Send node targets to device
for key in hetero_graph.node_target:
    hetero_graph.node_target[key] = hetero_graph.node_target[key].to(args['device'])

TYPE ('concept', 'related', 'event')
	 Feature 1
	 Feature 770
TYPE ('event', 'related', 'concept')
	 Feature 770
	 Feature 1
TYPE ('event', 'similar', 'event')
	 Feature 770
	 Feature 770
KEY ('concept', 'related', 'event') <class 'tuple'>
KEY NUMS ('concept', 'related', 'event') 36989 5000
MAX EDGES tensor(36988) tensor(4999) 36989 5000
KEY ('event', 'related', 'concept') <class 'tuple'>
KEY NUMS ('event', 'related', 'concept') 5000 36989
MAX EDGES tensor(4999) tensor(36988) 5000 36989
KEY ('event', 'similar', 'event') <class 'tuple'>
KEY NUMS ('event', 'similar', 'event') 5000 5000
MAX EDGES tensor(4986) tensor(4986) 5000 5000


AttributeError: 'HeteroGraph' object has no attribute 'node_target'

In [None]:
"""
This code block creates a basic split of a graph's nodes into training, validation, and testing sets. 
It uses predefined ratios to divide 'event' and 'concept' nodes in the heterogeneous graph for a simple 
dataset split, mainly for testing purposes.
"""

nEvents = hetero_graph.num_nodes("event")
nConcepts = hetero_graph.num_nodes("concept")

s1 = 0.7
s2 = 0.8

train_idx = {   "event": torch.tensor(range(0, int(nEvents * s1))).to(args["device"]), 
                "concept": torch.tensor(range(0, int(nConcepts * s1))).to(args["device"])
            }
val_idx = {   "event": torch.tensor(range(int(nEvents * s1), int(nEvents * s2))).to(args["device"]), 
                "concept": torch.tensor(range(int(nConcepts * s1), int(nConcepts * s2))).to(args["device"])
            }
test_idx = {   "event": torch.tensor(range(int(nEvents * s2), nEvents)).to(args["device"]), 
                "concept": torch.tensor(range(int(nConcepts * s2), nConcepts)).to(args["device"])
            }

print(train_idx["event"].shape)
print(test_idx["event"].shape)
print(val_idx["event"].shape)


# TODO: Add node labels to the nodes and try to make the deepsnap split work even for regression!

# dataset = deepsnap.dataset.GraphDataset([hetero_graph], task='node')

# dataset_train, dataset_val, dataset_test = dataset.split(transductive=True, split_ratio=[0.4, 0.3, 0.3])
# datasets = {'train': dataset_train, 'val': dataset_val, 'test': dataset_test}

# datasets

In [None]:
"""
Creates a HeteroGNN model from the previously constructed hetero graph and trains it.
"""

best_model = None
best_val = float("inf")

model = HeteroGNN(hetero_graph, args, num_layers=2, aggr="mean").to(args['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

for epoch in range(args['epochs']):
    # Train
    loss = train(model, optimizer, hetero_graph, train_idx)
    # Test for the accuracy of the model
    accs, best_model, best_val = test(model, hetero_graph, [train_idx, val_idx, test_idx], best_model, best_val)
    print(f"Epoch {epoch} Loss {loss} Accs {accs}")

# Get the accuracy of the best model
best_accs, _, _ = test(best_model, hetero_graph, [train_idx, val_idx, test_idx])

print("Best accs", best_accs)

In [None]:
# Print best-model predictions next to ground truth for (up to) the first
# MAX_PRINTED test events, skipping events whose target is -1 (unlabelled).
MAX_PRINTED = 1000

best_model.eval()
with torch.no_grad():  # inspection only; no autograd graph needed
    preds = best_model(hetero_graph.node_feature, hetero_graph.edge_index)

# Index the test slice once instead of re-indexing on every iteration.
test_event_preds = preds['event'][test_idx['event']]
test_event_targets = hetero_graph.node_target['event'][test_idx['event']]

# Slicing caps the loop at the test-set size, so smaller splits do not raise IndexError.
for pred, target in zip(test_event_preds[:MAX_PRINTED], test_event_targets[:MAX_PRINTED]):
    if target != -1:
        print(pred, target)