In [1]:
import cop
import torch
import deepsnap
import deepsnap.batch
from deepsnap.dataset import GraphDataset
from deepsnap.batch import Batch
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn

from sklearn.metrics import f1_score
from deepsnap.hetero_gnn import forward_op
from deepsnap.hetero_graph import HeteroGraph
from torch_sparse import SparseTensor, matmul

import pickle
import networkx as nx
import wandb
import optuna
import argparse

from hetero_gnn import HeteroGNN

In [None]:
# Hyperparameters and runtime settings shared by the cells below.
train_args = dict(
    # Run on the GPU when CUDA is available, otherwise fall back to the CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # Width of every hidden node embedding.
    hidden_size=81,
    # Number of full-graph training steps.
    epochs=10,
    # Adam weight decay and learning rate.
    weight_decay=0.00002203762357664057,
    lr=0.003873757421883433,
    # Width of the hidden layer inside the attention projection.
    attn_size=48,
    # Number of stacked heterogeneous convolution layers.
    num_layers=6,
    # How embeddings from different message types are merged: "mean" or "attn".
    aggr="attn",
    # NOTE(review): not referenced by any cell visible in this notebook.
    batch_size=2,
)

In [None]:
class HeteroGNNConv(pyg_nn.MessagePassing):
    """
    Message-passing layer for a single (source type, relation, destination type) message.

    Neighbour features are mean-aggregated, then the aggregated source features
    and the destination node's own features are each projected to
    ``out_channels``, concatenated and mixed by a final linear layer.
    """

    def __init__(self, in_channels_src, in_channels_dst, out_channels):
        """
        :param in_channels_src: Feature size of the source node type.
        :param in_channels_dst: Feature size of the destination node type.
        :param out_channels: Size of the produced embedding.
        """
        super().__init__(aggr="mean")

        self.in_channels_src = in_channels_src
        self.in_channels_dst = in_channels_dst
        self.out_channels = out_channels

        # Registration order (dst, src, update) is kept so that parameter
        # ordering and seeded initialisation stay the same.
        self.lin_dst = nn.Linear(in_channels_dst, out_channels)
        self.lin_src = nn.Linear(in_channels_src, out_channels)
        self.lin_update = nn.Linear(2 * out_channels, out_channels)

    def forward(
        self,
        node_feature_src,
        node_feature_dst,
        edge_index,
        size=None,
        res_n_id=None,
    ):
        """
        Run one round of message passing.

        :param node_feature_src: Features of the source nodes.
        :param node_feature_dst: Features of the destination nodes.
        :param edge_index: Connectivity handed to ``propagate``.
        :param size: Optional size hint forwarded to ``propagate``.
        :param res_n_id: Forwarded to ``update``; unused there.
        :return: Updated destination-node embeddings.
        """
        propagate_kwargs = dict(
            node_feature_src=node_feature_src,
            node_feature_dst=node_feature_dst,
            size=size,
            res_n_id=res_n_id,
        )
        return self.propagate(edge_index, **propagate_kwargs)

    def message_and_aggregate(self, edge_index, node_feature_src):
        """
        Fused message + aggregation step: mean of the neighbours' source features.

        ``edge_index`` is presumably a ``torch_sparse.SparseTensor`` adjacency,
        since it is passed straight to ``torch_sparse.matmul`` — confirm with caller.
        """
        return matmul(edge_index, node_feature_src, reduce='mean')

    def update(self, aggr_out, node_feature_dst, res_n_id):
        """
        Combine the aggregated neighbour message with the destination features.

        :param aggr_out: Aggregated source features per destination node.
        :param node_feature_dst: Destination-node features.
        :param res_n_id: Accepted for signature compatibility; not used.
        :return: Tensor whose last dimension is ``out_channels``.
        """
        projected_dst = self.lin_dst(node_feature_dst)
        projected_msg = self.lin_src(aggr_out)
        combined = torch.cat((projected_dst, projected_msg), dim=-1)
        return self.lin_update(combined)

In [None]:
class HeteroGNNWrapperConv(deepsnap.hetero_gnn.HeteroConv):
    """
    Heterogeneous layer that runs one convolution per message type and merges,
    per destination node type, the resulting embeddings by mean or by a learned
    attention over message types.
    """

    def __init__(self, convs, args, aggr="mean"):
        """
        Initializes the HeteroGNNWrapperConv instance.

        :param convs: Dictionary of convolution layers for each message type.
        :param args: Arguments dictionary containing hyperparameters like hidden_size and attn_size.
        :param aggr: Aggregation method, 'mean' or 'attn'; defaults to 'mean'.
        """

        # The parent is given aggr=None; the aggregation is implemented here.
        super(HeteroGNNWrapperConv, self).__init__(convs, None)
        self.aggr = aggr

        # Maps a position in a destination type's embedding list to its message type.
        self.mapping = {}

        # Detached torch tensor holding the most recent attention weights
        # (one weight per message type); None until "attn" aggregation has run.
        self.alpha = None

        self.attn_proj = None

        if self.aggr == "attn":
            # Scores each node embedding with a scalar: hidden -> attn -> 1.
            self.attn_proj = nn.Sequential(
                nn.Linear(args['hidden_size'], args['attn_size']),
                nn.Tanh(),
                nn.Linear(args['attn_size'], 1, bias=False)
            )

    def reset_parameters(self):
        """Re-initialise the wrapped convolutions and the attention projection."""
        super(HeteroGNNWrapperConv, self).reset_parameters()
        if self.aggr == "attn":
            for layer in self.attn_proj.children():
                # nn.Tanh has no parameters and no reset_parameters(); calling it
                # unconditionally would raise AttributeError.
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()

    def forward(self, node_features, edge_indices):
        """
        Forward pass of the layer.

        :param node_features: Dictionary of node features for each node type.
        :param edge_indices: Dictionary of edge indices for each message type,
            keyed by (source type, edge type, destination type).
        :return: Dictionary of aggregated node embeddings, containing only node
            types that appear as a destination of at least one message type.
        """

        # One embedding per message type, computed by that message type's conv.
        message_type_emb = {}
        for message_key, edge_index in edge_indices.items():
            src_type, _, dst_type = message_key
            message_type_emb[message_key] = self.convs[message_key](
                node_features[src_type],
                node_features[dst_type],
                edge_index,
            )

        # Group the embeddings by destination node type.
        node_emb = {dst: [] for _, _, dst in message_type_emb.keys()}
        mapping = {}

        for (src, edge_type, dst), item in message_type_emb.items():
            # NOTE(review): the key is the position inside this destination's
            # list, so entries of different destination types that share a
            # position overwrite each other.
            mapping[len(node_emb[dst])] = (src, edge_type, dst)
            node_emb[dst].append(item)
        self.mapping = mapping

        for node_type, embs in node_emb.items():
            if len(embs) == 1:
                # A single message type needs no aggregation.
                node_emb[node_type] = embs[0]
            else:
                node_emb[node_type] = self.aggregate(embs)

        return node_emb

    def aggregate(self, xs):
        """
        Aggregates node embeddings using the specified aggregation method.

        :param xs: List of node embeddings (one tensor per message type) to aggregate.
        :return: Aggregated node embeddings as a torch.Tensor.
        :raises ValueError: If ``self.aggr`` is neither 'mean' nor 'attn'.
        """

        if self.aggr == "mean":
            return torch.mean(torch.stack(xs), dim=0)

        if self.aggr == "attn":
            stacked = torch.stack(xs, dim=0)
            # One scalar score per node, averaged over nodes to get one score
            # per message type, then normalised across message types.
            scores = self.attn_proj(stacked).squeeze(-1)
            scores = torch.mean(scores, dim=-1)
            alpha = torch.softmax(scores, dim=0)

            # Only the stored copy is detached. The weights used below stay in
            # the autograd graph so that attn_proj actually receives gradients.
            self.alpha = alpha.detach()

            weighted = alpha.reshape(-1, 1, 1) * stacked
            return torch.sum(weighted, dim=0)

        raise ValueError(
            f"Unsupported aggregation method: {self.aggr!r} (expected 'mean' or 'attn')"
        )

In [None]:
def generate_convs(hetero_graph, conv, hidden_size, first_layer=False):
    """
    Build one convolution per message type of a heterogeneous graph.

    :param hetero_graph: Graph exposing ``message_types`` and ``num_node_features``.
    :param conv: Callable ``conv(in_channels_src, in_channels_dst, out_channels)``.
    :param hidden_size: Embedding size used for the output and, past the first
        layer, for both inputs.
    :param first_layer: When True, input sizes are read from the graph's raw
        node features instead of ``hidden_size``.

    :return: Dictionary mapping each message type to its convolution.
    """

    def input_width(node_type):
        # Raw feature size for the first layer; afterwards every node type
        # carries an embedding of size hidden_size.
        if first_layer:
            return hetero_graph.num_node_features(node_type)
        return hidden_size

    # A message type is indexed as (source type, edge type, destination type).
    return {
        message_type: conv(
            input_width(message_type[0]),
            input_width(message_type[2]),
            hidden_size,
        )
        for message_type in hetero_graph.message_types
    }

In [None]:
class HeteroGNN(torch.nn.Module):
    """
    Stack of heterogeneous graph convolution layers, each followed per node
    type by batch normalisation and LeakyReLU, with one linear prediction head
    per node type.
    """

    def __init__(self, hetero_graph, args, num_layers, aggr="mean", return_embedding=False, mask_unknown=True):
        """
        Initializes the HeteroGNN instance.
        :param hetero_graph: The heterogeneous graph for which convolutions are to be created.
        :param args: Arguments dictionary containing hyperparameters like hidden_size and attn_size.
        :param num_layers: Number of graph convolutional layers.
        :param aggr: Aggregation method 'mean' or 'attn', defaults to 'mean'.
        :param return_embedding: If True, forward() returns the hidden embeddings and skips the prediction heads.
        :param mask_unknown: If True, "event" nodes whose target is -1 are excluded from the loss.
        """
        super(HeteroGNN, self).__init__()

        self.aggr = aggr
        self.hidden_size = args["hidden_size"]
        self.num_layers = num_layers
        self.return_embedding = return_embedding
        self.mask_unknown = mask_unknown

        # Batch-norm and activation modules are keyed by "<kind>_<layer>_<node type>".
        self.bns = nn.ModuleDict()
        self.relus = nn.ModuleDict()
        self.convs = nn.ModuleList()
        self.fc = nn.ModuleDict()  # Prediction heads, one per node type

        # One wrapper convolution per layer; only the first layer reads the raw
        # node feature sizes from the graph.
        for i in range(self.num_layers):
            first_layer = i == 0
            conv = HeteroGNNWrapperConv(
                    generate_convs(hetero_graph, HeteroGNNConv, self.hidden_size, first_layer),
                    args,
                    self.aggr)
            self.convs.append(conv)

        # Batch normalization and LeakyReLU for each layer and node type
        all_node_types = hetero_graph.node_types
        for i in range(self.num_layers):
            for node_type in all_node_types:
                key_bn = f"bn_{i}_{node_type}"
                key_relu = f"relu_{i}_{node_type}"
                self.bns[key_bn] = nn.BatchNorm1d(self.hidden_size, eps=1.0)
                self.relus[key_relu] = nn.LeakyReLU()

        # Linear head mapping each node type's embedding to a single value
        for node_type in all_node_types:
            self.fc[node_type] = nn.Linear(self.hidden_size, 1)

    def forward(self, node_feature, edge_index):
        """
        Forward pass of the model.

        :param node_feature: Dictionary of node features for each node type.
        :param edge_index: Dictionary of edge indices for each message type.
        :return: Dictionary keyed by node type holding the hidden embeddings when
            ``return_embedding`` is set, otherwise the prediction-head outputs.
        """
        # Shallow copy so the in-place dict updates below can never write into
        # the caller's feature dictionary (relevant when num_layers == 0).
        x = dict(node_feature)

        # Graph convolution, then batch normalization and LeakyReLU per node type
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            for node_type in x:
                key_bn = f"bn_{i}_{node_type}"
                key_relu = f"relu_{i}_{node_type}"
                x[node_type] = self.bns[key_bn](x[node_type])
                x[node_type] = self.relus[key_relu](x[node_type])

        if self.return_embedding:
            return x

        # Apply the prediction head (linear layer)
        for node_type in x:
            x[node_type] = self.fc[node_type](x[node_type])

        return x

    def loss(self, preds, y, indices):
        """
        Computes the mean-squared-error loss on the "event" node type.

        :param preds: Dictionary of predictions made by the model.
        :param y: Dictionary of ground truth target values.
        :param indices: Dictionary of node indices for which loss should be calculated.

        :return: The computed loss value.
        """

        # Author's note: MAPE was tried instead of MSE; it produced better eval
        # results but worse predictions.
        loss_func = torch.nn.MSELoss()

        idx = indices["event"]
        if self.mask_unknown:
            # Keep only the nodes whose first target column is not the -1 marker.
            mask = y["event"][idx, 0] != -1
            idx = torch.masked_select(idx, mask)

        # Previously the unmasked branch indexed with the builtin `id` and passed
        # a single argument to the loss function, which always raised.
        return loss_func(preds["event"][idx], y["event"][idx])


In [None]:
def train(model, optimizer, hetero_graph, train_idx):
    """
    Run a single full-graph optimisation step.

    :param model: Model exposing ``train()``, ``__call__`` and ``loss``.
    :param optimizer: Optimizer updating the model's parameters.
    :param hetero_graph: Graph providing ``node_feature``, ``edge_index`` and ``node_target``.
    :param train_idx: Node indices handed to ``model.loss``.

    :return: The training loss of this step as a Python number.
    """

    # Training mode, with gradients from any previous step cleared.
    model.train()
    optimizer.zero_grad()

    # TODO: Use only train_idx instead of edge_index
    # TODO: Train only on events not on concepts
    # The forward pass covers the whole graph; train_idx only selects which
    # nodes contribute to the loss.
    features, edges = hetero_graph.node_feature, hetero_graph.edge_index
    predictions = model(features, edges)
    step_loss = model.loss(predictions, hetero_graph.node_target, train_idx)

    # Back-propagate, then apply the parameter update.
    step_loss.backward()
    optimizer.step()

    return step_loss.item()


def test(model, graph, indices, best_model, best_tvt_scores):
    """
    Tests the model on given indices and updates the best model based on validation loss.

    :param model: The trained graph neural network model.
    :param graph: The heterogeneous graph data.
    :param indices: List of indices for training, validation, and testing nodes.
    :param best_model: The current best model based on validation loss.
    :param best_val: The current best validation loss.

    :return: A tuple containing the list of losses for each dataset, the best model, and the best validation loss.
    """

    model.eval()  # Set the model to evaluation mode
    tvt_scores = []

    # Evaluate the model on each set of indices
    for index in indices:
        preds = model(graph.node_feature, graph.edge_index)

        idx = index["event"]

        # mask = y['event'][indices['event'], 0] != -1
        # non_zero_idx = torch.masked_select(indices['event'], mask)
        # preds['event'][non_zero_idx], y['event'][non_zero_idx]

        # non_zero_targets = torch.masked_select(graph.node_target['event'][indices['event']], mask)
        # non_zero_truth = torch.masked_select(graph.node_target['event'][indices['event']], mask)

        mask = graph.node_target["event"][idx, 0] != -1
        non_zero_idx = torch.masked_select(idx, mask)

        L1 = (
            torch.sum(
                torch.abs(
                    preds["event"][non_zero_idx]
                    - graph.node_target["event"][non_zero_idx]
                )
            )
            / non_zero_idx.shape[0]
        )

        meanRelative = (
            torch.sum(
                torch.abs(
                    (
                        preds["event"][non_zero_idx]
                        - graph.node_target["event"][non_zero_idx]
                    )
                    / graph.node_target["event"][non_zero_idx]
                )
            )
            / non_zero_idx.shape[0]
        )

        tvt_scores.append((L1, meanRelative))

    # Update the best model and validation loss if the current model performs better
    if tvt_scores[1][0] < best_tvt_scores[1][0]:
        best_tvt_scores = tvt_scores
        # torch.to_pickle(model, 'best_model.pkl')
        # model.to_pickle('best_model.pkl')

        # best_model = copy.deepcopy(model)
        torch.save(model.state_dict(), "./best_model.pkl")

    return tvt_scores, best_tvt_scores, best_model


In [None]:
def create_split(hetero_graph, train_frac=0.7, val_end_frac=0.8, device=None):
    """
    Split the "event" and "concept" nodes into contiguous train / validation / test index ranges.

    For a node type with ``n`` nodes the ranges are ``[0, int(n * train_frac))``,
    ``[int(n * train_frac), int(n * val_end_frac))`` and ``[int(n * val_end_frac), n)``.

    :param hetero_graph: Graph exposing ``num_nodes(node_type)``.
    :param train_frac: Fraction of nodes at which the training range ends.
    :param val_end_frac: Fraction of nodes at which the validation range ends.
    :param device: Device for the index tensors; defaults to ``train_args["device"]``.

    :return: List ``[train_idx, val_idx, test_idx]`` of dictionaries mapping
        "event" and "concept" to int64 index tensors.
    :raises ValueError: If the fractions are not ordered within [0, 1].
    """
    if not 0 <= train_frac <= val_end_frac <= 1:
        raise ValueError(
            "expected 0 <= train_frac <= val_end_frac <= 1, "
            f"got train_frac={train_frac}, val_end_frac={val_end_frac}"
        )
    if device is None:
        device = train_args["device"]

    train_idx, val_idx, test_idx = {}, {}, {}
    for node_type in ("event", "concept"):
        num_nodes = hetero_graph.num_nodes(node_type)
        train_end = int(num_nodes * train_frac)
        val_end = int(num_nodes * val_end_frac)

        # torch.arange keeps an integer dtype even for an empty range, so an
        # empty split can still be used to index tensors.
        train_idx[node_type] = torch.arange(0, train_end, device=device)
        val_idx[node_type] = torch.arange(train_end, val_end, device=device)
        test_idx[node_type] = torch.arange(val_end, num_nodes, device=device)

    return [train_idx, val_idx, test_idx]

In [None]:
def train_model(hetero_graph):
    """
    Train a HeteroGNN on the graph, report scores per epoch, then reload the
    best checkpoint and print its predictions on the test split.

    The checkpoint is written by ``test`` to "./best_model.pkl" whenever the
    validation absolute error improves.

    :param hetero_graph: The heterogeneous graph to train and evaluate on.
    """
    device = train_args["device"]

    def _report(header, scores):
        # float() accepts both 0-dim tensors and the plain-float "inf" sentinels
        # that remain in best_tvt_scores if validation never improved.
        (tr_abs, tr_rel), (va_abs, va_rel), (te_abs, te_rel) = [
            (float(abs_err), float(rel_err)) for abs_err, rel_err in scores
        ]
        print(
            f"""{header}
            Train: Abs={tr_abs:.4f} Rel={tr_rel:.4f}
            Val: Abs={va_abs:.4f} Rel={va_rel:.4f}
            Test: Abs={te_abs:.4f} Rel={te_rel:.4f}"""
        )

    best_model = None
    best_tvt_scores = (
        (float("inf"), float("inf")),
        (float("inf"), float("inf")),
        (float("inf"), float("inf")),
    )

    # The model is trained WITH its prediction head. Training with
    # return_embedding=True would fit the hidden embeddings to the target and
    # leave the head used for the final predictions at its random initialisation.
    model = HeteroGNN(
        hetero_graph,
        train_args,
        num_layers=train_args["num_layers"],
        aggr=train_args["aggr"],
    ).to(device)

    train_idx, val_idx, test_idx = create_split(hetero_graph)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=train_args["lr"], weight_decay=train_args["weight_decay"]
    )

    for epoch in range(train_args["epochs"]):
        # Train
        loss = train(model, optimizer, hetero_graph, train_idx)
        # Score the current model on all three splits (checkpoints on improvement)
        cur_tvt_scores, best_tvt_scores, best_model = test(
            model,
            hetero_graph,
            [train_idx, val_idx, test_idx],
            best_model,
            best_tvt_scores,
        )
        _report(f"Epoch: {epoch} Loss: {loss:.4f}", cur_tvt_scores)

    _report("Best model", best_tvt_scores)

    # Rebuild the architecture and restore the best checkpoint.
    model = HeteroGNN(
        hetero_graph,
        train_args,
        num_layers=train_args["num_layers"],
        aggr=train_args["aggr"],
    ).to(device)

    # map_location keeps the load working when the checkpoint was written on a
    # different device than the one in use now.
    model.load_state_dict(torch.load("./best_model.pkl", map_location=device))

    # Predict in evaluation mode and without autograd, so batch norm uses its
    # running statistics and is not updated by this pass.
    model.eval()
    with torch.no_grad():
        preds = model(hetero_graph.node_feature, hetero_graph.edge_index)

    cur_tvt_scores, best_tvt_scores, best_model = test(
        model, hetero_graph, [train_idx, val_idx, test_idx], best_model, best_tvt_scores
    )

    display_predictions(preds, hetero_graph, test_idx)


def display_predictions(preds, hetero_graph, test_idx):
    """
    Print position, prediction and target for every test "event" node with a known target.

    :param preds: Dictionary of model outputs keyed by node type.
    :param hetero_graph: Graph providing ``node_target``.
    :param test_idx: Dictionary of test indices keyed by node type.
    """
    event_idx = test_idx["event"]

    # Gather the test rows once instead of re-indexing the full tensors on
    # every loop iteration.
    event_preds = preds["event"][event_idx]
    event_targets = hetero_graph.node_target["event"][event_idx]

    for i in range(event_idx.shape[0]):
        # A target of -1 marks a node without a known value.
        if event_targets[i] != -1:
            print(i, event_preds[i], event_targets[i])
