In [None]:
!pip install -q torch-geometric==2.5.2 gensim transformers wandb optuna optuna_integration

In [None]:
import torch
import os 
import json
import torch_geometric
import re
import gc
import wandb
import optuna
import warnings
import time

from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler
from torch_geometric.data import Data, HeteroData, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn import to_hetero
from gensim.models import Word2Vec
from collections import defaultdict, Counter
from tqdm import tqdm
from sklearn.metrics import ndcg_score
from itertools import groupby, permutations
from transformers import AutoTokenizer, AutoModel
from optuna.integration.wandb import WeightsAndBiasesCallback

import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import networkx as nx
import numpy as np
import pandas as pd
import torch.nn as nn
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data

In [None]:
# Display the installed PyTorch Geometric and PyTorch versions as the cell output.
torch_geometric.__version__, torch.__version__

In [None]:
# Log in to Weights & Biases.
# NOTE(review): presumably prompts for an API key when none is cached — confirm for headless runs.
wandb.login()

In [None]:
# Select the compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Only query the GPU name when CUDA is actually selected; calling
# `torch.cuda.get_device_name(0)` on a CPU-only machine raises instead of returning a name.
device, (torch.cuda.get_device_name(0) if device != "cpu" else "no CUDA device available")

In [None]:
# Load the previously saved train/validation loaders from paths relative to the working directory.
# NOTE(review): `torch.load` unpickles arbitrary Python objects, so these files must come from a
# trusted source. Newer PyTorch releases may also require `weights_only=False` to load non-tensor
# objects such as data loaders — confirm against the installed version.
trainloader = torch.load("./dataloaders/graph_trainloader.pth")
valloader = torch.load("./dataloaders/graph_valloader.pth")

In [None]:
import copy
import warnings
from typing import Any, Dict, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module, Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense import Linear
from torch_geometric.nn.fx import Transformer
from torch_geometric.typing import EdgeType, Metadata, NodeType, SparseTensor
from torch_geometric.utils.hetero import get_unused_node_types

try:
    from torch.fx import Graph, GraphModule, Node
except (ImportError, ModuleNotFoundError, AttributeError):
    GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node'


def to_hetero_with_bases(module: Module, metadata: Metadata, num_bases: int,
                         in_channels: Optional[Dict[str, int]] = None,
                         input_map: Optional[Dict[str, str]] = None,
                         debug: bool = False) -> GraphModule:
    """Convert ``module`` into a heterogeneous model using basis decomposition.

    Convenience entry point: builds a :class:`ToHeteroWithBasesTransformer`
    from the given arguments and returns the result of its ``transform()``.

    Args:
        module: The module to convert.
        metadata: A pair ``(node_types, edge_types)``; edge types are
            ``(src, rel, dst)`` triples.
        num_bases: Number of copies made of every message passing submodule.
        in_channels: Optional mapping from a forward-argument name to the
            feature size its type-wise inputs are linearly aligned to.
        input_map: Forwarded unchanged to the transformer base class.
        debug: Forwarded unchanged to the transformer base class.

    Returns:
        The transformed graph module.
    """
    return ToHeteroWithBasesTransformer(
        module=module,
        metadata=metadata,
        num_bases=num_bases,
        in_channels=in_channels,
        input_map=input_map,
        debug=debug,
    ).transform()



class ToHeteroWithBasesTransformer(Transformer):
    """FX transformer that rewrites a module to run on a heterogeneous graph.

    The rewrite performed by the hooks below is:

    * every type-wise input dictionary (placeholder) is grouped into one
      tensor, optionally after a per-type linear alignment;
    * every ``MessagePassing`` submodule is replaced by a
      :class:`HeteroBasisConv` wrapper that receives the ``edge_type`` tensor
      as its first argument;
    * every output is split back into a type-wise dictionary.

    Args:
        module: The module to transform.
        metadata: A pair ``(node_types, edge_types)``; both must be non-empty
            and edge types are ``(src, rel, dst)`` triples.
        num_bases: Number of bases handed to each :class:`HeteroBasisConv`.
        in_channels: Optional mapping from a placeholder name to the output
            size of the :class:`LinearAlign` layer inserted for it.
        input_map: Forwarded unchanged to ``Transformer.__init__``.
        debug: Forwarded unchanged to ``Transformer.__init__``.
    """
    def __init__(
        self,
        module: Module,
        metadata: Metadata,
        num_bases: int,
        in_channels: Optional[Dict[str, int]] = None,
        input_map: Optional[Dict[str, str]] = None,
        debug: bool = False,
    ):
        super().__init__(module, input_map, debug)

        self.metadata = metadata
        self.num_bases = num_bases
        self.in_channels = in_channels or {}
        assert len(metadata) == 2
        assert len(metadata[0]) > 0 and len(metadata[1]) > 0

        self.validate()

        # Compute IDs for each node and edge type:
        self.node_type2id = {k: i for i, k in enumerate(metadata[0])}
        self.edge_type2id = {k: i for i, k in enumerate(metadata[1])}

    def validate(self):
        """Emit warnings for metadata that is likely to cause problems.

        Warns about node types that ``get_unused_node_types`` reports for the
        metadata, and about node type / relation names that are not valid
        Python identifiers. Nothing is raised.
        """
        unused_node_types = get_unused_node_types(*self.metadata)
        if len(unused_node_types) > 0:
            warnings.warn(
                f"There exist node types ({unused_node_types}) whose "
                f"representations do not get updated during message passing "
                f"as they do not occur as destination type in any edge type. "
                f"This may lead to unexpected behavior.")

        names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
        for name in names:
            if not name.isidentifier():
                warnings.warn(
                    f"The type '{name}' contains invalid characters which "
                    f"may lead to unexpected behavior. To avoid any issues, "
                    f"ensure that your types only contain letters, numbers "
                    f"and underscores.")

    def transform(self) -> GraphModule:
        """Run the base-class transformation with per-run bookkeeping flags.

        The three ``_*_initialized`` flags make sure the helper nodes
        (``node_offset_dict``, ``edge_offset_dict``, ``edge_type``) are
        created only once by :meth:`placeholder`. They exist only for the
        duration of this call.
        """
        self._node_offset_dict_initialized = False
        self._edge_offset_dict_initialized = False
        self._edge_type_initialized = False
        out = super().transform()
        del self._node_offset_dict_initialized
        del self._edge_offset_dict_initialized
        del self._edge_type_initialized
        return out

    def placeholder(self, node: Node, target: Any, name: str):
        """Rewrite an input placeholder so downstream nodes see a grouped tensor.

        In order, new graph nodes are inserted after ``node`` that (1) compute
        the offset dictionary for the placeholder's level if it does not exist
        yet, (2) compute the ``edge_type`` tensor for the first edge-level
        placeholder, (3) align features through a :class:`LinearAlign` module
        when ``name`` is listed in ``in_channels``, and (4) group the
        type-wise values into a single tensor. All uses of ``node`` are then
        redirected to the grouped node.
        """
        if node.type is not None:
            Type = EdgeType if self.is_edge_level(node) else NodeType
            node.type = Dict[Type, node.type]

        # `out` tracks the most recently inserted node, so that each new node
        # is placed directly after the previous one.
        out = node

        # Create `node_offset_dict` and `edge_offset_dict` dictionaries in case
        # they are not yet initialized. These dictionaries hold the cumulated
        # sizes used to create a unified graph representation and to split the
        # output data.
        if self.is_edge_level(node) and not self._edge_offset_dict_initialized:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_function',
                                         target=get_edge_offset_dict,
                                         args=(node, self.edge_type2id),
                                         name='edge_offset_dict')
            self._edge_offset_dict_initialized = True

        elif not self._node_offset_dict_initialized:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_function',
                                         target=get_node_offset_dict,
                                         args=(node, self.node_type2id),
                                         name='node_offset_dict')
            self._node_offset_dict_initialized = True

        # Create a `edge_type` tensor used as input to `HeteroBasisConv`:
        if self.is_edge_level(node) and not self._edge_type_initialized:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_function', target=get_edge_type,
                                         args=(node, self.edge_type2id),
                                         name='edge_type')
            self._edge_type_initialized = True

        # Add `Linear` operation to align features to the same dimensionality:
        if name in self.in_channels:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_module',
                                         target=f'align_lin__{name}',
                                         args=(node, ),
                                         name=f'{name}__aligned')
            self._state[out.name] = self._state[name]

            # `metadata[0]` holds node types and `metadata[1]` edge types, so
            # the boolean edge-level flag doubles as the index.
            lin = LinearAlign(self.metadata[int(self.is_edge_level(node))],
                              self.in_channels[name])
            setattr(self.module, f'align_lin__{name}', lin)

        # Perform grouping of type-wise values into a single tensor:
        if self.is_edge_level(node):
            self.graph.inserting_after(out)
            out = self.graph.create_node(
                'call_function', target=group_edge_placeholder,
                args=(out if name in self.in_channels else node,
                      self.edge_type2id,
                      self.find_by_name('node_offset_dict')),
                name=f'{name}__grouped')
            self._state[out.name] = 'edge'

        else:
            self.graph.inserting_after(out)
            out = self.graph.create_node(
                'call_function', target=group_node_placeholder,
                args=(out if name in self.in_channels else node,
                      self.node_type2id), name=f'{name}__grouped')
            self._state[out.name] = 'node'

        self.replace_all_uses_with(node, out)

    def call_message_passing_module(self, node: Node, target: Any, name: str):
        """Prepend the ``edge_type`` graph node to the call's arguments."""
        # Call the `HeteroBasisConv` wrapper instead of a single
        # message passing layer. We need to inject the `edge_type` as first
        # argument in order to do so.
        node.args = (self.find_by_name('edge_type'), ) + node.args

    def output(self, node: Node, target: Any, name: str):
        """Wrap every returned graph node in a ``split_output`` call.

        Nested dicts, lists and tuples in the output are traversed; values
        that are not graph nodes are returned unchanged.
        """
        # Split the output to dictionaries, holding either node type-wise or
        # edge type-wise data.
        def _recurse(value: Any) -> Any:
            if isinstance(value, Node) and self.is_edge_level(value):
                self.graph.inserting_before(node)
                return self.graph.create_node(
                    'call_function', target=split_output,
                    args=(value, self.find_by_name('edge_offset_dict')),
                    name=f'{value.name}__split')

            elif isinstance(value, Node):
                self.graph.inserting_before(node)
                return self.graph.create_node(
                    'call_function', target=split_output,
                    args=(value, self.find_by_name('node_offset_dict')),
                    name=f'{value.name}__split')

            elif isinstance(value, dict):
                return {k: _recurse(v) for k, v in value.items()}
            elif isinstance(value, list):
                return [_recurse(v) for v in value]
            elif isinstance(value, tuple):
                return tuple(_recurse(v) for v in value)
            else:
                return value

        # Only a single-node output keeps a (dictionary) type annotation;
        # composite outputs have their annotation dropped.
        if node.type is not None and isinstance(node.args[0], Node):
            output = node.args[0]
            Type = EdgeType if self.is_edge_level(output) else NodeType
            node.type = Dict[Type, node.type]
        else:
            node.type = None

        node.args = (_recurse(node.args[0]), )

    def init_submodule(self, module: Module, target: str) -> Module:
        """Return ``module`` unchanged unless it is a ``MessagePassing`` layer."""
        if not isinstance(module, MessagePassing):
            return module

        # Replace each `MessagePassing` module by a `HeteroBasisConv` wrapper:
        return HeteroBasisConv(module, len(self.metadata[1]), self.num_bases)


###############################################################################


class HeteroBasisConv(torch.nn.Module):
    """Apply the basis-decomposition technique to a heterogeneous graph.

    The wrapped message passing ``module`` is deep-copied ``num_bases`` times.
    Each copy gets a learnable ``edge_type_weight`` of shape
    ``(1, num_relations)`` and a message hook that scales every message by the
    weight of its edge type. The outputs of all copies are summed.

    This variant expects every copy to return
    ``(out, (edge_index, attention_weights))``, so the wrapped layer must be
    called in a way that makes it return attention weights.

    Args:
        module: The message passing layer to duplicate.
        num_relations: Number of edge types.
        num_bases: Number of copies (bases) to create.
    """
    def __init__(self, module: MessagePassing, num_relations: int,
                 num_bases: int):
        super().__init__()

        self.num_relations = num_relations
        self.num_bases = num_bases

        # We make use of a post-message computation hook to inject the
        # basis re-weighting for each individual edge type.
        # This currently requires us to set `conv.fuse = False`, which leads
        # to a materialization of messages.
        def hook(module, inputs, output):
            assert isinstance(module._edge_type, Tensor)
            if module._edge_type.size(0) != output.size(0):
                raise ValueError(
                    f"Number of messages ({output.size(0)}) does not match "
                    f"with the number of original edges "
                    f"({module._edge_type.size(0)}). Does your message "
                    f"passing layer create additional self-loops? Try to "
                    f"remove them via 'add_self_loops=False'")
            # Look up one scalar per message from its edge type, then reshape
            # so it broadcasts over all trailing message dimensions.
            weight = module.edge_type_weight.view(-1)[module._edge_type]
            weight = weight.view([-1] + [1] * (output.dim() - 1))
            return weight * output

        # Place the new per-relation weights on the wrapped module's device.
        params = list(module.parameters())
        param_device = params[0].device if len(params) > 0 else 'cpu'

        self.convs = torch.nn.ModuleList()
        for _ in range(num_bases):
            conv = copy.deepcopy(module)
            conv.fuse = False  # Disable `message_and_aggregate` functionality.
            # We learn a single scalar weight for each individual edge type,
            # which is used to weight the output message based on edge type:
            conv.edge_type_weight = Parameter(
                torch.empty(1, num_relations, device=param_device))
            conv.register_message_forward_hook(hook)
            self.convs.append(conv)

        if self.num_bases > 1:
            self.reset_parameters()
        else:
            # With a single basis the copied layer keeps its own parameters,
            # but `edge_type_weight` was allocated with `torch.empty` and would
            # otherwise hold arbitrary memory contents.
            for conv in self.convs:
                torch.nn.init.xavier_uniform_(conv.edge_type_weight)

    def reset_parameters(self):
        """Re-initialise every copy and its per-relation weights.

        Copies without a ``reset_parameters`` method that do own parameters
        trigger a warning, since all copies would then start identical.
        """
        for conv in self.convs:
            if hasattr(conv, 'reset_parameters'):
                conv.reset_parameters()
            elif sum([p.numel() for p in conv.parameters()]) > 0:
                warnings.warn(
                    f"'{conv}' will be duplicated, but its parameters cannot "
                    f"be reset. To suppress this warning, add a "
                    f"'reset_parameters()' method to '{conv}'")
            torch.nn.init.xavier_uniform_(conv.edge_type_weight)

    def forward(self, edge_type: Tensor, *args, **kwargs) -> Any:
        """Run all bases and combine their results.

        Args:
            edge_type: Edge type id of every edge; made available to the
                message hook through ``conv._edge_type`` during the call.
            *args, **kwargs: Passed unchanged to every copy.

        Returns:
            A tuple ``(out, (edge_type, edge_index, attention))`` where ``out``
            is the sum of all copies' outputs, ``edge_index`` is the one
            returned by the last copy, and ``attention`` is the mean of the
            attention weights over all copies.
        """
        out = None

        attention = []

        # Call message passing modules and perform aggregation:
        for conv in self.convs:
            conv._edge_type = edge_type

            res, (edge_ind_exp, att_weight_exp) = conv(*args, **kwargs)
            del conv._edge_type

            attention.append(att_weight_exp)

            out = res if out is None else out.add_(res)

        return out, (edge_type, edge_ind_exp, torch.mean(torch.stack(attention, dim=0), dim=0))

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(num_relations='
                f'{self.num_relations}, num_bases={self.num_bases})')


class LinearAlign(torch.nn.Module):
    """Project type-wise representations to a common dimensionality.

    One bias-free ``Linear(-1, out_channels)`` layer is created per key. The
    input size ``-1`` makes these layers lazy, so a forward pass is required
    before their parameters are initialized.
    """
    def __init__(self, keys: List[Union[NodeType, EdgeType]],
                 out_channels: int):
        super().__init__()
        self.out_channels = out_channels
        # Keys are converted with `key2str` so they can serve as module names.
        self.lins = torch.nn.ModuleDict({
            key2str(key): Linear(-1, out_channels, bias=False)
            for key in keys
        })

    def forward(
        self, x_dict: Dict[Union[NodeType, EdgeType], Tensor]
    ) -> Dict[Union[NodeType, EdgeType], Tensor]:
        """Apply each type's own linear layer; keys are kept unchanged."""
        aligned = {}
        for key, features in x_dict.items():
            aligned[key] = self.lins[key2str(key)](features)
        return aligned

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (f'{cls_name}(num_relations={len(self.lins)}, '
                f'out_channels={self.out_channels})')


###############################################################################

# These methods are used in order to receive the cumulated sizes of input
# dictionaries. We make use of them for creating a unified homogeneous graph
# representation, as well as to split the final output data once again.


def get_node_offset_dict(
    input_dict: Dict[NodeType, Union[Tensor, SparseTensor]],
    type2id: Dict[NodeType, int],
) -> Dict[NodeType, int]:
    """Return the start offset of every node type in the grouped tensor.

    Types are visited in the iteration order of ``type2id``; each type's
    offset is the summed ``size(0)`` of all types visited before it.
    """
    offsets: Dict[NodeType, int] = {}
    running_total = 0
    for node_type in type2id:
        offsets[node_type] = running_total
        running_total += input_dict[node_type].size(0)
    return offsets


def get_edge_offset_dict(
    input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],
    type2id: Dict[EdgeType, int],
) -> Dict[EdgeType, int]:
    """Return the start offset of every edge type in the grouped edge tensor.

    Types are visited in the iteration order of ``type2id``; each type's
    offset is the summed edge count of all types visited before it.
    """
    def _num_edges(value) -> int:
        # Sparse adjacency: count the stored entries.
        if isinstance(value, SparseTensor):
            return value.nnz()
        # A long tensor with two rows is treated as an `edge_index`.
        if value.dtype == torch.long and value.size(0) == 2:
            return value.size(-1)
        # Anything else is treated as one row per edge.
        return value.size(0)

    offsets: Dict[EdgeType, int] = {}
    running_total = 0
    for edge_type in type2id:
        offsets[edge_type] = running_total
        running_total += _num_edges(input_dict[edge_type])
    return offsets


###############################################################################

# This method computes the edge type of the final homogeneous graph
# representation. It will be used in the `HeteroBasisConv` wrapper.


def get_edge_type(
    input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],
    type2id: Dict[EdgeType, int],
) -> Tensor:
    """Build a long tensor holding the edge type id of every grouped edge.

    The id of a type is its position in the iteration order of ``type2id``.
    The per-type chunks are concatenated in that same order.
    """
    values = [input_dict[edge_type] for edge_type in type2id.keys()]

    chunks = []
    for type_id, value in enumerate(values):
        if value.size(0) == 2 and value.dtype == torch.long:
            # `edge_index`: one entry per column.
            chunk = value.new_full((value.size(-1), ), type_id,
                                   dtype=torch.long)
        elif isinstance(value, SparseTensor):
            # Sparse adjacency: one entry per stored element.
            chunk = torch.full((value.nnz(), ), type_id, dtype=torch.long,
                               device=value.device())
        else:
            # Anything else: one entry per row.
            chunk = value.new_full((value.size(0), ), type_id,
                                   dtype=torch.long)
        chunks.append(chunk)

    if len(chunks) == 1:
        return chunks[0]
    return torch.cat(chunks, dim=0)


###############################################################################

# These methods are used to group the individual type-wise components into a
# single unified representation.


def group_node_placeholder(input_dict: Dict[NodeType, Tensor],
                           type2id: Dict[NodeType, int]) -> Tensor:
    """Stack the per-type node tensors along dim 0 in ``type2id`` order.

    With a single node type its tensor is returned as-is (no copy).
    """
    per_type = [input_dict[node_type] for node_type in type2id.keys()]
    if len(per_type) == 1:
        return per_type[0]
    return torch.cat(per_type, dim=0)


def group_edge_placeholder(
    input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],
    type2id: Dict[EdgeType, int],
    offset_dict: Dict[NodeType, int] = None,
) -> Union[Tensor, SparseTensor]:
    """Group per-edge-type inputs into one tensor, in ``type2id`` order.

    The kind of input is decided from the first edge type's value:

    * a long tensor with two rows is treated as ``edge_index``; every type's
      indices are shifted by the node offsets of its source/destination type
      and the results are concatenated along the last dimension;
    * a ``SparseTensor`` is converted to a unified ``edge_index`` in the same
      shifted fashion;
    * anything else is concatenated along dim 0.

    With exactly one edge type, its value is returned unchanged.
    NOTE(review): in that case no node offsets are applied even when the
    source and destination types differ — confirm this is intended.

    Raises:
        AttributeError: If connectivity has to be shifted but ``offset_dict``
            is ``None``.
    """
    per_type = [input_dict[edge_type] for edge_type in type2id.keys()]

    if len(per_type) == 1:
        return per_type[0]

    head = per_type[0]

    # In case of grouping a graph connectivity tensor `edge_index` or `adj_t`,
    # we need to increment its indices:
    if head.size(0) == 2 and head.dtype == torch.long:
        if offset_dict is None:
            raise AttributeError(
                "Can not infer node-level offsets. Please ensure that there "
                "exists a node-level argument before the 'edge_index' "
                "argument in your forward header.")

        shifted = []
        for edge_index, (src_type, _, dst_type) in zip(per_type, type2id):
            # Work on a copy so the caller's tensor is left untouched.
            edge_index = edge_index.clone()
            edge_index[0, :] += offset_dict[src_type]
            edge_index[1, :] += offset_dict[dst_type]
            shifted.append(edge_index)

        return torch.cat(shifted, dim=-1)

    if isinstance(head, SparseTensor):
        if offset_dict is None:
            raise AttributeError(
                "Can not infer node-level offsets. Please ensure that there "
                "exists a node-level argument before the 'SparseTensor' "
                "argument in your forward header.")

        # For grouping a list of SparseTensors, we convert them into a
        # unified `edge_index` representation in order to avoid conflicts
        # induced by re-shuffling the data.
        src_rows, dst_cols = [], []
        for adj, (src_type, _, dst_type) in zip(per_type, type2id):
            # The first two values of `coo()` are bound to (col, row) in that
            # order. NOTE(review): presumably because the adjacency is stored
            # transposed — confirm.
            col, row, edge_value = adj.coo()
            assert edge_value is None
            src_rows.append(row + offset_dict[src_type])
            dst_cols.append(col + offset_dict[dst_type])

        return torch.stack(
            [torch.cat(src_rows, dim=0), torch.cat(dst_cols, dim=0)], dim=0)

    return torch.cat(per_type, dim=0)


###############################################################################

# This method is used to split the output tensors into individual type-wise
# components:


def split_output(
    output: Tensor,
    offset_dict: Union[Dict[NodeType, int], Dict[EdgeType, int]],
) -> Union[Dict[NodeType, Tensor], Dict[EdgeType, Tensor]]:
    """Split a grouped tensor back into its type-wise components.

    Args:
        output: The grouped tensor. Tuples are returned unchanged.
        offset_dict: Start offset of every type along dim 0, in order.

    Returns:
        A dictionary mapping every key of ``offset_dict`` to its slice of
        ``output``, or ``output`` itself when it is a tuple.
    """
    # Sometimes an edge index ends up here. Not sure why.
    # TODO: fix --> we should be able to determine which edge belongs to
    # which edge type.
    if isinstance(output, tuple):
        return output

    # Only a two-row long tensor is treated as an edge index and transposed to
    # one row per edge; the dtype check matches the other edge-index checks in
    # this file. Without it, a float feature or attention tensor that happens
    # to have exactly two rows would be transposed and split along the wrong
    # dimension.
    if (output.dim() == 2 and output.size(0) == 2
            and output.dtype == torch.long):
        output = output.T

    # Chunk sizes are the differences between consecutive offsets, with the
    # total number of rows closing the last chunk.
    boundaries = list(offset_dict.values()) + [output.size(0)]
    sizes = [boundaries[i + 1] - boundaries[i]
             for i in range(len(offset_dict))]
    chunks = output.split(sizes)

    return {key: chunk for key, chunk in zip(offset_dict, chunks)}


###############################################################################


def key2str(key: Union[NodeType, EdgeType]) -> str:
    """Turn a node or edge type into a string without spaces, '-' or ':'.

    Tuple keys are first joined with a double underscore; each of the three
    characters above is then replaced by a single underscore.
    """
    if isinstance(key, tuple):
        key = '__'.join(key)
    for char in (' ', '-', ':'):
        key = key.replace(char, '_')
    return key

In [None]:
def listwise_loss(scores, labels):
    """
    Compute the LambdaRank loss (sigma is fixed to 1).

    For every ordered document pair <i, j> the pairwise lambda is weighted by
    the absolute change in NDCG that swapping the two documents in the current
    ranking would cause; the weighted lambdas are summed per document.

    All intermediate tensors are created on the device of the inputs, so the
    function does not depend on any notebook-level ``device`` variable.

    scores: tensor of size [N, 1] (the output of a neural network), where
            N = number of <query, document> pairs
    labels: tensor of size [N], contains the relevance labels

    returns: a tensor of size [N, 1]; for N < 2 a zero tensor of size [1, 1]
             on the device of `scores`
    """
    if labels.size(0) < 2:
        # No document pair exists, so there is nothing to rank.
        return scores.new_zeros((1, 1))

    sigma = 1

    # S[i, j] = sign(label_i - label_j): +1 if i is more relevant than j,
    # -1 if less relevant, 0 for ties (including the diagonal).
    S = torch.sign(labels.unsqueeze(1) - labels.unsqueeze(0))

    # lambda_{i, j} for every <i, j>; `scores - scores.T` is s_i - s_j.
    lamda = sigma * (0.5 * (1 - S) - (1 / (1 + torch.exp(sigma * (scores - scores.T)))))

    # Document ids ordered by descending score.
    sorted_ind = torch.flip(scores.argsort(dim=0).flatten(), dims=[0])
    # Inverse permutation: ranks[i] is the 0-based rank of document i.
    ranks = sorted_ind.argsort()

    # DCG of the ideal ordering (labels sorted descending), used to normalize.
    ideal_labels = torch.sort(labels)[0].flip(dims=[0])
    positions = torch.arange(1, labels.size(0) + 1, device=labels.device)
    ideal_dcg = torch.sum((2**ideal_labels - 1) / torch.log(positions + 1))

    # Per-document gain and rank discount. Labels are cast to int32 here, as
    # in the previous implementation, so fractional grades are truncated.
    gains = 2**labels.int() - 1
    discounts = torch.log(ranks + 2)

    # DCG contribution of the pair <i, j> at the current ranks ...
    per_doc = gains / discounts
    dcg_current = per_doc.unsqueeze(1) + per_doc.unsqueeze(0)
    # ... and after swapping the two documents: gain_j at rank_i plus gain_i
    # at rank_j.
    cross = gains.unsqueeze(0) / discounts.unsqueeze(1)
    dcg_swapped = cross + cross.T

    # |Delta-NDCG|. When no document is relevant the ideal DCG is 0; a swap
    # cannot change NDCG then, so use 0 instead of producing 0 / 0 = NaN.
    delta_dcg = (dcg_swapped - dcg_current).abs()
    delta_NDCG = torch.where(ideal_dcg > 0,
                             delta_dcg / ideal_dcg,
                             torch.zeros_like(delta_dcg))

    lambda_rank_loss = (lamda * delta_NDCG).sum(dim=1, keepdim=True)

    return lambda_rank_loss

In [None]:
# We embed the textual nodes (candidates and requests) separately at first
class text_embedding_layer(torch.nn.Module):
    """Embed candidate and request token sequences with a shared text encoder.

    Both inputs are encoded with the same pretrained
    ``intfloat/multilingual-e5-small`` model, pooled over the sequence
    dimension and projected from 384 to ``text_embedding_size`` features by a
    separate linear layer per side.

    Args:
        text_embedding_size: Output size of the two projection layers.
        text_pooling: ``"token"`` (masked mean followed by L2 normalization)
            or ``"sentence"`` (masked mean with a clamped denominator and no
            normalization).

    Raises:
        ValueError: If ``text_pooling`` is not one of the supported modes.
    """

    _POOLING_MODES = ("token", "sentence")

    def __init__(self, text_embedding_size=64, text_pooling="token"):
        super().__init__()

        # Fail early: an unknown mode would otherwise only surface in
        # `forward` as an UnboundLocalError.
        if text_pooling not in self._POOLING_MODES:
            raise ValueError(
                f"Unknown text_pooling '{text_pooling}'; expected one of "
                f"{self._POOLING_MODES}")

        # Only the encoder is moved to the notebook-level `device` here; the
        # projection layers stay on the default device until the caller moves
        # the module.
        self.e5 = AutoModel.from_pretrained("intfloat/multilingual-e5-small").to(device)

        self.text_pooling = text_pooling

        self.candidate_out = nn.Linear(in_features=384,
                                       out_features=text_embedding_size)

        self.company_out = nn.Linear(in_features=384,
                                     out_features=text_embedding_size)

    def average_pool(self, last_hidden_states, attention_mask):
        """Mean over dim 1 of the positions whose attention mask is non-zero."""
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    @staticmethod
    def _clamped_mean_pool(last_hidden_states, attention_mask):
        """Mask-weighted mean over dim 1 with the denominator clamped to >= 1e-9."""
        mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size())
        summed = torch.sum(last_hidden_states * mask_expanded, 1)
        counts = torch.clamp(mask_expanded.sum(1), min=1e-9)  # Avoid division by zero
        return summed / counts

    def forward(self, x_can, x_req, att_mask_can, att_mask_req):
        """Return the projected ``(candidate, request)`` text embeddings.

        Args:
            x_can, x_req: Encoder inputs for candidates / requests.
            att_mask_can, att_mask_req: The matching attention masks.
        """
        # Feed tokens into model
        x_candidate = self.e5(x_can, att_mask_can)
        x_company = self.e5(x_req, att_mask_req)

        if self.text_pooling == "token":
            # Create embedding tensor
            candidate_embeddings = self.average_pool(x_candidate.last_hidden_state, attention_mask=att_mask_can)
            company_embeddings = self.average_pool(x_company.last_hidden_state, attention_mask=att_mask_req)

            # normalize embeddings
            candidate_embeddings = F.normalize(candidate_embeddings, p=2, dim=1)
            company_embeddings = F.normalize(company_embeddings, p=2, dim=1)
        elif self.text_pooling == "sentence":
            # Mean pooling
            candidate_embeddings = self._clamped_mean_pool(x_candidate.last_hidden_state, att_mask_can)
            company_embeddings = self._clamped_mean_pool(x_company.last_hidden_state, att_mask_req)
        else:
            # The attribute may have been changed after construction.
            raise ValueError(
                f"Unknown text_pooling '{self.text_pooling}'; expected one of "
                f"{self._POOLING_MODES}")

        # Run through MLP to match other embedding sizes
        x_candidate = self.candidate_out(candidate_embeddings).float()
        x_company = self.company_out(company_embeddings).float()

        return x_candidate, x_company
    
# Then, we embed all nodes initially
class embedding_layer(torch.nn.Module):
    """Two stacked ``TransformerConv`` layers producing initial node embeddings.

    Both layers use lazily inferred input sizes ``(-1, -1)`` and output
    ``embedding_size`` features; a ReLU is applied between them.
    """
    def __init__(self, embedding_size=32):
        super().__init__()

        self.conv1 = geom_nn.TransformerConv((-1, -1), embedding_size)
        self.conv2 = geom_nn.TransformerConv((-1, -1), embedding_size)

    def forward(self, x, edge_index):
        """Apply ``conv1``, ReLU, then ``conv2`` over the same ``edge_index``."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)
        
class GNN(torch.nn.Module):
    """Attention-based graph network with a candidate-side and a company-side branch.

    Each branch consists of a "positive" and a "negative" ``GATv2Conv`` layer
    (lazy input sizes, ``heads`` attention heads, head outputs not
    concatenated, no self-loops), one ``BatchNorm1d`` per conv output and a
    linear head mapping ``embedding_size`` features to ``heads`` outputs.

    NOTE(review): ``forward`` indexes the second return value of every conv
    with ``[2]``. That matches the 3-tuple ``(edge_type, edge_index,
    attention)`` returned by ``HeteroBasisConv.forward`` in this file, so this
    module presumably only works after being wrapped by
    ``to_hetero_with_bases`` — confirm.
    """
    def __init__(self, embedding_size=64, heads=4):
        super().__init__()

        self.can_pos = geom_nn.GATv2Conv((-1, -1),
                                         out_channels=embedding_size,
                                         add_self_loops=False,
                                         heads=heads,
                                         concat=False)
        
        self.can_neg = geom_nn.GATv2Conv((-1, -1),
                                         out_channels=embedding_size,
                                         add_self_loops=False,
                                         heads=heads,
                                         concat=False)
        self.com_pos = geom_nn.GATv2Conv((-1, -1),
                                         out_channels=embedding_size,
                                         add_self_loops=False,
                                         heads=heads,
                                         concat=False)

        self.com_neg = geom_nn.GATv2Conv((-1, -1),
                                         out_channels=embedding_size,
                                         add_self_loops=False,
                                         heads=heads,
                                         concat=False)
        
        # Different batch norm for each GATv2 output, as it includes learned parameters
        self.batch_norm1 = torch.nn.BatchNorm1d(embedding_size)
        self.batch_norm2 = torch.nn.BatchNorm1d(embedding_size)
        self.batch_norm3 = torch.nn.BatchNorm1d(embedding_size)
        self.batch_norm4 = torch.nn.BatchNorm1d(embedding_size)

        # Linear heads producing the node-level outputs `v_im_can` / `v_im_com`;
        # note that their output size is `heads`, not `embedding_size`.
        self.dense_can = nn.Linear(in_features=embedding_size,
                                   out_features=heads)
    
        self.dense_com = nn.Linear(in_features=embedding_size,
                                   out_features=heads)
    
        
        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x, edge_index):            
        """Run both branches and return embeddings plus attention information.

        Returns:
            A 10-tuple ``(x_can, x_com, v_im_can, v_im_com, e_im_can,
            e_im_com, can_pos_exp, can_neg_exp, com_pos_exp, com_neg_exp)``:
            the pooled branch embeddings, the ReLU-activated linear-head
            outputs, the sigmoid of the summed positive and negative attention
            per branch, and the second return value of each of the four conv
            calls.
        """
        ### Positive candidate-side attention
        x_can_pos, can_pos_exp = self.can_pos(x, 
                                              edge_index.long(), 
                                              return_attention_weights=True)
        
        ### Negative candidate-side attention        
        # The "negative" layer is stacked on the output of the "positive" layer,
        # not applied to the raw input `x`.
        x_can_neg, can_neg_exp = self.can_neg(x_can_pos, 
                                              edge_index.long(), 
                                              return_attention_weights=True)

        ### Positive company-side attention
        # We flip the edge index to distinguish this as a 'company-side' graph
        # The input features are additionally negated (`x * -1`).
        x_com_pos, com_pos_exp = self.com_pos(x * -1, 
                                              edge_index[[1, 0]].long(), 
                                              return_attention_weights=True)

        ### Negative company-side attention
        x_com_neg, com_neg_exp = self.com_neg(x_com_pos, 
                                              edge_index[[1, 0]].long(), 
                                              return_attention_weights=True)
        
        # Edge embedding
        # Sum of the positive and negative attention (element [2] of each conv's
        # second return value), squashed by a sigmoid.
        e_im_can = self.sigmoid(
                            torch.sum(
                                torch.stack([can_pos_exp[2], 
                                             can_neg_exp[2]], 
                                      dim=0), 
                                dim=0)
                    )        
        
        e_im_com = self.sigmoid(
                            torch.sum(
                                torch.stack([com_pos_exp[2], 
                                             com_neg_exp[2]], 
                                        dim=0), 
                                dim=0)
                    )
        
        x_can_pos = self.batch_norm1(x_can_pos)
        x_can_neg = self.batch_norm2(x_can_neg)
        
        x_com_pos = self.batch_norm3(x_com_pos)
        x_com_neg = self.batch_norm4(x_com_neg)
         
        # Mean pool
        # Element-wise mean of the ELU-activated positive and negative outputs.
        x_can = torch.mean(torch.stack([self.elu(x_can_pos), 
                                        self.elu(x_can_neg)]), dim=0)
        
        x_com = torch.mean(torch.stack([self.elu(x_com_pos), 
                                        self.elu(x_com_neg)]), dim=0)
            
        # Node embedding
        v_im_can = self.dense_can(x_can).relu()
        v_im_com = self.dense_com(x_com).relu()
        
        return x_can, x_com, v_im_can, v_im_com, e_im_can, e_im_com,\
               can_pos_exp, can_neg_exp, com_pos_exp, com_neg_exp
    
class OKRA(torch.nn.Module):
    """Heterogeneous graph ranker that scores every sub-graph from two sides.

    The forward pass embeds the textual node features, embeds the whole graph, runs the
    GNN, and then builds one embedding per sub-graph for the candidate side and one for
    the company side. Each side is scored by its own linear layer and the two scores are
    combined with the formula 2ab / (a + b).

    Args:
        data: Graph object providing `metadata()`; used to convert the embedder and GNN
            to their heterogeneous versions.
        typings: Iterable of node type names to include when building sub-graph embeddings.
        embedding_size: Embedding size handed to the embedder and the GNN.
        text_embedding_size: Embedding size handed to the text embedder.
        pooling_method: One of "mean", "sum" or "max"; any other value raises `KeyError`.
        text_pooling: Pooling mode forwarded to the text embedder.
        heads: Number of attention heads; also fixes the input width of the final linear layers.
    """

    def __init__(self, data, typings, embedding_size=64, text_embedding_size=64, pooling_method="mean", text_pooling="token", heads=4):
        super().__init__()

        self.typings = typings
        self.num_heads = heads
        self.embedding_size = embedding_size

        self.pooling = {
            "mean": lambda x, dim: torch.mean(x, dim=dim),
            "sum": lambda x, dim: torch.sum(x, dim=dim),
            # Only return the values for max pooling, ignoring the indices
            "max": lambda x, dim: torch.max(x, dim=dim)[0]
        }[pooling_method]

        self.text_embedder = text_embedding_layer(text_embedding_size=text_embedding_size, text_pooling=text_pooling)

        self.embedder = embedding_layer(embedding_size=embedding_size)
        self.embedder = to_hetero(self.embedder, data.metadata(), aggr='sum')

        self.gnn = GNN(embedding_size=embedding_size, heads=heads)
        self.gnn = to_hetero_with_bases(self.gnn, data.metadata(), num_bases=3)

        # Each embedding is the size heads * embedding_size * 3, as there is one heads * embedding_size embedding for each (head node, tail node, sub-graph)
        self.mlp_candidate = nn.Linear(in_features=heads * embedding_size * 3,
                                       out_features=1)

        self.mlp_company = nn.Linear(in_features=heads * embedding_size * 3,
                                     out_features=1)

    def forward(self, data):
        """Score every sub-graph contained in `data`.

        Args:
            data: Heterogeneous graph batch with "candidate" and "request" node types
                (each with `x` and `att_mask`), per-node `unique_node_id` and `sub_graph`
                attributes, and graph-level `head_nodes` / `tail_nodes`.

        Returns:
            Tuple `(y_pred, y_candidate, y_company, explanation)` where `explanation` is
            `(can_pos_exp, can_neg_exp, com_pos_exp, com_neg_exp, v_im_can1, v_im_com1,
            e_im_can, e_im_com)` as produced by the GNN.
        """
        # Embed textual features
        x_candidate, x_request = self.text_embedder(data.x_dict["candidate"], data.x_dict["request"], data["candidate"].att_mask, data["request"].att_mask)

        # PyG collects `x_dict` into a fresh dict on every access, so writing into
        # `data.x_dict[...]` is lost before the embedder reads it again. Build one local
        # feature dict that carries the textual embeddings instead.
        x_dict = dict(data.x_dict)
        x_dict["candidate"] = x_candidate
        x_dict["request"] = x_request

        # Embed the graph as a whole
        embedded_data = self.embedder({k: v.float() for k, v in x_dict.items()}, data.edge_index_dict)

        # Run the embedded graph through the GNN
        x_can, x_com, v_im_can1, v_im_com1, e_im_can, e_im_com, \
        can_pos_exp, can_neg_exp, com_pos_exp, com_neg_exp = self.gnn(embedded_data,
                                                                      data.edge_index_dict)

        # Combine the attention with the values, once per head, for both the candidate and company side
        h_can = defaultdict(lambda : torch.Tensor([]).to(device))
        h_com = defaultdict(lambda : torch.Tensor([]).to(device))

        # Store each node as a combination of its head embeddings
        for typing in self.typings:
            for k in range(self.num_heads):
                if typing in x_can:
                    h_can[typing] = torch.cat([h_can[typing], (x_can[typing].T * v_im_can1[typing][:,k]).T], dim=1)
                else:
                    h_can[typing] = torch.cat([h_can[typing], torch.zeros_like(h_can[list(h_can.keys())[0]].T)])

                if typing in x_com:
                    h_com[typing] = torch.cat([h_com[typing], (x_com[typing].T * v_im_com1[typing][:,k]).T], dim=1)
                else:
                    h_com[typing] = torch.cat([h_com[typing], torch.zeros_like(h_com[list(h_com.keys())[0]].T)])

        # Each sub-graph gets its own embedding
        sub_graphs_candidate = defaultdict(list)
        sub_graphs_company = defaultdict(list)

        # Additionally, the head and tail node (candidate and vacancy) get stored separately as well
        main_nodes_candidate = defaultdict(list)
        main_nodes_company = defaultdict(list)

        # Find the sub-graph of each node in the embedding, and add it to the corresponding list
        for typing in self.typings:
            for i, emb in enumerate(h_can[typing]):
                # Some subgraphs do not have all data types (e.g., a graph might not include any education nodes)
                if data[typing]:
                    # Find the sub-graph the current node belongs to
                    current_node_id = int(data[typing].unique_node_id[i].item())

                    # We were working with a dummy node
                    if current_node_id == 0:
                        continue

                    sg = int(data[typing].sub_graph[i].item())

                    # If our node is a head/tail node, store it accordingly
                    if (current_node_id in data.head_nodes[0]) or (current_node_id in data.tail_nodes[0]):
                        main_nodes_candidate[sg].append(emb)
                        main_nodes_company[sg].append(h_com[typing][i])

                    # Add its candidate embedding to its sub-graph embedding
                    sub_graphs_candidate[sg].append(emb.unsqueeze(0))

                    # Do the same on the company side
                    sub_graphs_company[sg].append(h_com[typing][i].unsqueeze(0))

        # Finally, pool every graph embedding with the configured pooling method
        for sg in sub_graphs_candidate.keys():
            sub_graphs_candidate[sg] = self.pooling(torch.stack(sub_graphs_candidate[sg]).squeeze(1), dim=0)
            sub_graphs_company[sg] = self.pooling(torch.stack(sub_graphs_company[sg]).squeeze(1), dim=0)

            # Add the head and tail node to the full embedding
            sub_graphs_candidate[sg] = torch.cat([torch.cat(main_nodes_candidate[sg], dim=0).squeeze(), sub_graphs_candidate[sg]])
            sub_graphs_company[sg] = torch.cat([torch.cat(main_nodes_company[sg], dim=0).squeeze(), sub_graphs_company[sg]])

        # Stack all the sub-graph embeddings into a single matrix, both candidate- and company-sided
        sub_graphs_candidate = torch.stack([i[1] for i in sorted(sub_graphs_candidate.items())], dim=0)
        sub_graphs_company = torch.stack([i[1] for i in sorted(sub_graphs_company.items())], dim=0)

        # Make predictions based on the sub-graph embeddings
        y_candidate = torch.clamp(self.mlp_candidate(sub_graphs_candidate), min=-100, max=100)
        y_company = torch.clamp(self.mlp_company(sub_graphs_company), min=-100, max=100)

        # Final prediction is the harmonic mean of the candidate- and company-sided prediction
        y_pred = 2 * ((y_candidate * y_company) / (y_candidate + y_company))

        # When both scores are 0 the division is 0/0 = nan; that case should score 0.
        # NOTE(review): scores that cancel out (a == -b != 0) divide by zero as well, and
        # nan_to_num then replaces the resulting +/-inf with the largest finite float.
        y_pred = torch.nan_to_num(y_pred).squeeze()

        return y_pred, y_candidate, y_company, (can_pos_exp, can_neg_exp, com_pos_exp, com_neg_exp, v_im_can1, v_im_com1, e_im_can, e_im_com)

In [None]:
def train_loop(model, optimizer, trainloader, valloader, epochs=10):
    """Train `model` for `epochs` epochs and validate after every epoch.

    No scalar loss is backpropagated: `listwise_loss` returns one value per prediction,
    and those values are handed to `torch.autograd.backward` as the gradient of `y_pred`.

    Args:
        model: Module called as `model(batch)`; must return a 4-tuple whose first item
            holds the predicted scores.
        optimizer: Optimizer that updates the parameters of `model`.
        trainloader: Sized iterable of training batches exposing `.y`.
        valloader: Loader forwarded to `val_loop` after each epoch.
        epochs: Number of passes over `trainloader`.

    Returns:
        The per-batch validation nDCG@10 scores of the last epoch, or an empty list
        when `epochs` is 0.
    """
    # Defined up front so that `epochs=0` returns an empty result instead of failing
    ndcg_val = []

    for epoch in range(epochs):
        # The caller or a previous validation pass may have left the model in eval mode
        model.train()
        ndcg_scores = []

        for i, data in enumerate(trainloader):

            # Make prediction
            y_pred, y_candidate, y_company, explanation = model(data.detach().clone().to(device))

            print("                                                                                                                    ", end="\r")
            print(f"Epoch: {epoch + 1}/{epochs}, batch (train): {i + 1}/{len(trainloader)}, y_pred mean: {y_pred.mean()}", end="\r")

            # Calculate and backpropagate gradients
            optimizer.zero_grad()
            lambda_i = listwise_loss(y_pred, data.y.to(device))
            torch.autograd.backward(y_pred, lambda_i.squeeze())
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()

            # Log loss in wandb as a plain float rather than a (device) tensor
            wandb.log({"loss": lambda_i.squeeze().mean().item()})

            # Calculate nDCG score of current batch
            ndcg_scores.append(ndcg_score(data.y.unsqueeze(0).cpu(),
                                          y_pred.unsqueeze(0).detach().cpu(), k=10))

        # Log epoch-level metrics to WandB
        wandb.log({"Epoch": epoch+1, "Training nDCG": np.mean(ndcg_scores)})
        print(f"\n\nTraining nDCG: {np.mean(ndcg_scores)}\n")

        # Evaluate model
        ndcg_val = val_loop(model, valloader)
        print(f"\nValidation nDCG: {np.mean(ndcg_val)}\n")
        wandb.log({"Validation nDCG": np.mean(ndcg_val)})

    # Return nDCG of final trained model
    return ndcg_val

def val_loop(model, valloader):
    """Compute the nDCG@10 of `model` on every batch of `valloader`.

    The model is switched to evaluation mode for the duration of the pass and put back
    into whatever mode it was in before, so the caller can keep training afterwards.

    Args:
        model: Module called as `model(batch)`; must return a 4-tuple whose first item
            holds the predicted scores.
        valloader: Sized iterable of validation batches exposing `.y`.

    Returns:
        List with one nDCG@10 score per validation batch.
    """
    ndcg_scores = []

    # Remember the current mode: evaluating must not silently end training mode
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for i, data_val in enumerate(valloader):
                print(f"Batch (val): {i + 1}/{len(valloader)}", end="\r")

                # Make prediction
                y_pred_val, y_candidate_val, y_company_val, explanation_val = model(data_val.detach().clone().to(device))

                # Calculate nDCG score of current batch
                ndcg_scores.append(ndcg_score(data_val.y.unsqueeze(0).cpu(),
                                              y_pred_val.unsqueeze(0).detach().cpu(), k=10))
    finally:
        model.train(was_training)

    return ndcg_scores


def optimize_model(trial, trainloader, valloader, epochs=10):
    """Run one hyperparameter trial: sample a configuration, train, and score it.

    A W&B run is opened for the trial and always closed again, also when training
    fails, so that every trial is tracked as its own run.

    Args:
        trial: Object providing `suggest_float` and `suggest_categorical`.
        trainloader: Training loader; its first batch is also used to initialize the model.
        valloader: Validation loader.
        epochs: Number of training epochs for this trial.

    Returns:
        Mean of the validation nDCG scores returned by `train_loop`.
    """
    # Set up WandB integration
    wandb.init(project="okra", job_type="optimize")

    try:
        # Data.metadata() is needed to initialize the heterodata
        data = next(iter(trainloader))

        # Search space
        learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
        text_embedding_size = trial.suggest_categorical('text_embedding_size', [32, 128, 256])
        embedding_size = trial.suggest_categorical('embedding_size', [32, 128, 256])
        pooling_method = trial.suggest_categorical('pooling_method', ["mean", "max", "sum"])
        text_pooling = trial.suggest_categorical('text_pooling', ["token", "sentence"])

        # Configuration for WandB
        config = {
            "learning_rate": learning_rate,
            "text_embedding_size": text_embedding_size,
            "embedding_size": embedding_size,
            "pooling_method": pooling_method,
            "text_pooling": text_pooling
        }
        wandb.config.update(config)

        print(f"""\nConfig:\n- learning_rate = {learning_rate}\n- text_embedding_size = {text_embedding_size}\n- embedding_size = {embedding_size}\n- pooling_method = {pooling_method}\n- text_pooling = {text_pooling}\n""")

        # All the different node types
        typings = ["candidate", "request", "function_name", "isco_code",
                   "education", "language", "license", "skill", "company_name",
                   "function_id", "isco_level", "workgroup", "klass", "literal"]

        # Initiate the model (number of heads is locked, as that is required for the multi-explanation component to function)
        model = OKRA(data,
                     typings,
                     text_embedding_size=text_embedding_size,
                     text_pooling=text_pooling,
                     embedding_size=embedding_size,
                     pooling_method=pooling_method,
                     heads=4).to(device)

        # Configure Adam
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        start_time = time.time()
        # Train and evaluate model
        ndcg_scores_val = train_loop(model, optimizer, trainloader, valloader, epochs=epochs)

        end_time = time.time()

        elapsed = end_time - start_time
        # Avoid a ZeroDivisionError in the timing report when no epoch was run
        seconds_per_epoch = elapsed / epochs if epochs else float("nan")
        print(f"Training for {epochs} epochs took {elapsed} seconds ({seconds_per_epoch} seconds per epoch)")

        return np.mean(ndcg_scores_val)
    finally:
        # Close this trial's run so it does not stay open after the trial ends or fails
        wandb.finish()

In [None]:
def objective_wrapper(trainloader, valloader):
    """Bind the data loaders into a single-argument objective.

    Args:
        trainloader: Training loader passed on to `optimize_model`.
        valloader: Validation loader passed on to `optimize_model`.

    Returns:
        A function taking only a trial and returning the result of `optimize_model`
        run for 6 epochs on the bound loaders.
    """
    def _run_trial(trial):
        # Only the trial arrives as an argument; the loaders come from the closure
        score = optimize_model(trial,
                               trainloader=trainloader,
                               valloader=valloader,
                               epochs=6)
        return score

    return _run_trial

In [None]:
# Release cached GPU memory and collect garbage before the search starts
torch.cuda.empty_cache()
gc.collect()

# Silence all warnings (user/future warnings included)
warnings.filterwarnings('ignore')

# The objective returns a validation nDCG, so the study maximizes it
study = optuna.create_study(direction='maximize')

# The objective receives only the trial, so the loaders are bound in beforehand
wrapped_objective = objective_wrapper(trainloader, valloader)

# Run the hyperparameter search
study.optimize(wrapped_objective, n_trials=32)

best_params = study.best_trial.params
print("Best hyperparameters:", best_params)

# Persist the best configuration as JSON
with open("okra_results.txt", "w+") as results_file:
    results_file.write(json.dumps(best_params))