In [None]:
# Install the third-party packages that the import cell below relies on.
# NOTE(review): no versions are pinned here, so the environment may differ between runs.
!pip install -q torch_geometric transformers gensim

In [None]:
import torch
import os
import json
import re
import gc
import torch
import gensim
import gc
import time 

import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data

from collections import defaultdict
from tqdm.notebook import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import ndcg_score, average_precision_score
from gensim.test.utils import common_texts
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from transformers import BertTokenizer, BertModel
from torch_geometric.data import Data, HeteroData, Batch
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn import to_hetero

# Test set performance

### Preparing data

In [None]:
li = []

# Retrieve ground truth values for each candidate-vacancy pair.
# The listing is sorted so the concatenation order (and thus which duplicate row
# wins when the per-candidate dicts are built) does not depend on the filesystem.
for truth in sorted(os.listdir("./graph_data/ground_truth")):
    if ".csv" in truth:
        df = pd.read_csv(f"./graph_data/ground_truth/{truth}", header=None)
        li.append(df)

# Create truth_dict: {candidate: {vacancy: label}} from columns 0, 1 and 2
truths = pd.concat(li, axis=0, ignore_index=True)
truth_dict = {key1: dict(group[[1, 2]].values) for key1, group in truths.groupby(0)}

# Some candidate-vacancy pairs are not available as a graph, so we filter those out here as well
with open("./data_subset.json") as f:
    data_subset = json.load(f)

# Filter to relevant graphs
truth_dict = {
    k: {graph: label for graph, label in v.items() if k in data_subset and graph in data_subset[k]}
    for k, v in truth_dict.items()
}
# Drop candidates that have no label above zero (this also drops candidates left without vacancies)
truth_dict = {k: v for k, v in truth_dict.items() if not all(label <= 0 for label in v.values())}

data = pd.read_csv("./data/cv-vacancy-pairs.csv")

relevant_candidates = set(truth_dict.keys())
relevant_vacancies = set().union(*(v.keys() for v in truth_dict.values()))


def _clean_text(text):
    """Replace every non-word character by a space, collapse whitespace runs and lowercase.

    Newlines are non-word characters as well, so no separate newline handling is needed.
    """
    return re.sub(r"\s+", " ", re.sub(r"\W", " ", text)).lower()


cv_data = {}
req_data = {}

# Rows are addressed by position: 4/5 hold the candidate/vacancy ids, 8/7 the matching texts
for row in data.itertuples():
    if row[4] not in cv_data and f"c{row[4]}" in relevant_candidates:
        cv_data[row[4]] = _clean_text(row[8])

    if row[5] not in req_data and f"r{row[5]}" in relevant_vacancies:
        req_data[row[5]] = _clean_text(row[7])

# 80% / 10% / 10% split over the candidates, in truth_dict order
truth_items = list(truth_dict.items())
train_end = int(len(truth_items) * 0.8)
val_end = int(len(truth_items) * 0.9)

training_slice = truth_items[:train_end]
val_slice = truth_items[train_end:val_end]
test_slice = truth_items[val_end:]

## Random

In [None]:
# The first 80% of the data will be used to train the TF-IDF vectorizer:
# each CV text is followed by the texts of the vacancies listed for that CV
training_set = []

for candidate, vacancies in training_slice:
    training_set.append(cv_data[int(candidate[1:])])
    training_set.extend(req_data[int(vacancy[1:])] for vacancy in vacancies)

# Last 10% is the test set. Every entry is a list whose first element is the
# (candidate, cv_text) pair, followed by one (vacancy, vacancy_text, label) triple
# per vacancy; negative labels are mapped to 0.
test_set = [
    [(candidate, cv_data[int(candidate[1:])])]
    + [
        (vacancy, req_data[int(vacancy[1:])], label if label >= 0 else 0)
        for vacancy, label in vacancies.items()
    ]
    for candidate, vacancies in test_slice
]

In [None]:
# Fixed seed so the random baseline reports the same numbers on every run
RANDOM_BASELINE_SEED = 42
rng = np.random.default_rng(RANDOM_BASELINE_SEED)

ndcg10_scores = []
ndcg5_scores = []
ndcg3_scores = []

# Assign a random embedding to each CV and its corresponding vacancies, calculate
# cosine similarity and evaluate using nDCG
for batch in tqdm(test_set):

    ground_truth = []
    y_pred = []

    cv_emb = rng.random(32)

    for _, vacancy, label in batch[1:]:
        req_emb = rng.random(32)

        ground_truth.append(label)
        y_pred.append(cosine_similarity([cv_emb], [req_emb])[0][0])

    ground_truth = np.array(ground_truth)
    y_pred = np.array(y_pred)

    ndcg10_scores.append(ndcg_score([ground_truth], [y_pred], k=10))
    ndcg5_scores.append(ndcg_score([ground_truth], [y_pred], k=5))
    ndcg3_scores.append(ndcg_score([ground_truth], [y_pred], k=3))

print("Random test set score:", f"nDCG@10: {np.mean(ndcg10_scores)}, nDCG@5: {np.mean(ndcg5_scores)}, nDCG@3: {np.mean(ndcg3_scores)}")

## TF-IDF

In [None]:
before = time.time()

# Initialize model
vectorizer = TfidfVectorizer()
tf_idf_model = vectorizer.fit(training_set)

print(time.time() - before)

ndcg10_scores = []
ndcg5_scores = []
ndcg3_scores = []


# Embed each CV and its corresponding vacancies using the model, calculate cosine similarity
# and evaluate using nDCG
for batch in tqdm(test_set):

    ground_truth = np.array([label for _, _, label in batch[1:]])

    cv_emb = tf_idf_model.transform([batch[0][1]])

    # Vectorise all vacancies of this CV in a single call instead of one transform per vacancy;
    # row i of req_embs belongs to the i-th vacancy, so the scores keep the label order
    req_embs = tf_idf_model.transform([vacancy for _, vacancy, _ in batch[1:]])
    y_pred = np.asarray(cosine_similarity(cv_emb, req_embs)[0])

    # Calculate nDCG score of current batch
    ndcg10_scores.append(ndcg_score([ground_truth], [y_pred], k=10))
    ndcg5_scores.append(ndcg_score([ground_truth], [y_pred], k=5))
    ndcg3_scores.append(ndcg_score([ground_truth], [y_pred], k=3))

print("TF-IDF test set score:", f"nDCG@10: {np.mean(ndcg10_scores)}, nDCG@5: {np.mean(ndcg5_scores)}, nDCG@3: {np.mean(ndcg3_scores)}")

## Doc2Vec

In [None]:
# Doc2Vec hyperparameters; the next cell passes them to the Doc2Vec constructor
best_config = {"min_count": 5, "window_size": 10, "vector_size": 32, "epochs": 40}

In [None]:
before = time.time()

# Tag every training text with its position in the corpus
documents = [
    TaggedDocument(gensim.utils.simple_preprocess(doc), [i])
    for i, doc in enumerate(training_set)
]

d2v_model = Doc2Vec(
    vector_size=best_config["vector_size"],
    window=best_config["window_size"],
    min_count=best_config["min_count"],
    epochs=best_config["epochs"],
    workers=4,
)
d2v_model.build_vocab(documents)
d2v_model.train(documents, total_examples=d2v_model.corpus_count, epochs=d2v_model.epochs)

print(time.time() - before)

ndcg10_scores, ndcg5_scores, ndcg3_scores = [], [], []

for batch in test_set:
    # The CV vector is inferred first, then one vector per vacancy in list order
    cv_emb = d2v_model.infer_vector(gensim.utils.simple_preprocess(batch[0][1]))

    ground_truth, y_pred = [], []
    for _, vacancy, label in batch[1:]:
        vacancy_emb = d2v_model.infer_vector(gensim.utils.simple_preprocess(vacancy))
        similarity = cosine_similarity([cv_emb], [vacancy_emb])[0][0]
        ground_truth.append(label)
        y_pred.append(similarity)

    # Same ranking, scored at the three cut-offs
    for _cutoff, _bucket in ((10, ndcg10_scores), (5, ndcg5_scores), (3, ndcg3_scores)):
        _bucket.append(ndcg_score([ground_truth], [y_pred], k=_cutoff))

print("D2V test set score:", f"nDCG@10: {np.mean(ndcg10_scores)}, nDCG@5: {np.mean(ndcg5_scores)}, nDCG@3: {np.mean(ndcg3_scores)}")

## e5

In [None]:
# Configuration for the e5 ranker; get_model_performance("e5", ...) reads the
# "pooling", "learning_rate" and "epochs" entries
e5_best_config = {"epochs": 1, "pooling": "mean", "learning_rate": 0.0000018572950835516}

In [None]:
class e5_ranker(torch.nn.Module):
    """
    Rank vacancies for a CV by cosine similarity of pooled multilingual-e5-small embeddings.

    - pooling: how token embeddings are reduced to one vector per text,
      one of "mean", "sum" or "max"
    """

    POOLING_METHODS = ("mean", "sum", "max")

    def __init__(self, pooling="mean"):
        super().__init__()

        # Fail early instead of hitting an unbound variable in forward()
        if pooling not in self.POOLING_METHODS:
            raise ValueError(f"Unknown pooling '{pooling}', expected one of {self.POOLING_METHODS}")

        self.model = AutoModel.from_pretrained("intfloat/multilingual-e5-small")
        self.pooling = pooling

    def _pool(self, hidden, attention_mask):
        """Reduce (texts, tokens, dim) token embeddings to (texts, dim), ignoring padding."""
        mask = attention_mask.unsqueeze(-1)

        if self.pooling == "max":
            # Padding must not take part in the maximum. Zeroing it (as multiplying by the
            # mask does) lets 0 win whenever all real activations of a dimension are
            # negative, so padded positions are pushed to the lowest representable value.
            lowest = torch.finfo(hidden.dtype).min
            return hidden.masked_fill(~mask.bool(), lowest).max(dim=1).values

        weights = mask.to(hidden.dtype)
        summed = (hidden * weights).sum(dim=1)

        if self.pooling == "sum":
            return summed
        if self.pooling == "mean":
            # Clamp avoids division by zero for a text without any real token
            return summed / weights.sum(dim=1).clamp(min=1e-9)

        raise ValueError(f"Unknown pooling '{self.pooling}', expected one of {self.POOLING_METHODS}")

    def forward(self, batch):
        """
        Score every vacancy in the batch against the CV.

        - batch: mapping with 'input_ids' and 'attention_mask'; row 0 is the CV,
          the remaining rows are the vacancies

        Returns the cosine similarity between the CV and each vacancy, squeezed
        (so a single vacancy yields a 0-dim tensor).
        """
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']

        # Process all embeddings in one go
        outputs = self.model(input_ids, attention_mask=attention_mask)
        embeddings = self._pool(outputs.last_hidden_state, attention_mask)

        # Extract CV and request embeddings
        cv_embedding = embeddings[0].unsqueeze(0)  # CV is the first in the batch
        req_embeddings = embeddings[1:]  # rest are requests

        # Use the cosine similarity as the score (based on the paper)
        return F.cosine_similarity(cv_embedding, req_embeddings).squeeze()

In [None]:
class TokenDataLoader(Dataset):
    """
    Dataset that yields, per candidate, the tokenized CV together with all of its vacancies.

    - truth_dict: {candidate: {vacancy: label}}
    - cv_data / req_data: texts keyed by the integer part of the candidate / vacancy id
    - query_size: maximum number of tokens per text
    - batch_size: stored for backwards compatibility; every item already holds
      one complete candidate
    """

    def __init__(self, truth_dict, cv_data, req_data, query_size=512, batch_size=32):
        self.ground_truths = list(truth_dict.items())
        self.cv_texts = cv_data
        self.req_texts = req_data

        self.query_size = query_size
        self.tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small")
        self.batch_size = batch_size

    def __len__(self):
        # __getitem__ indexes individual candidates, so the length has to be the number of
        # candidates. Reporting the number of batches instead made any length-driven
        # sampler stop after the first len / batch_size candidates.
        return len(self.ground_truths)

    def __getitem__(self, idx):
        """Return ((candidate, vacancy ids), tokens, labels) for the candidate at position idx."""
        candidate, vacancies = self.ground_truths[idx]
        req_ids, labels = vacancies.keys(), list(vacancies.values())

        # Texts carry the "query: " / "passage: " prefixes
        candidate_text = "query: " + self.cv_texts[int(candidate[1:])]
        vacancy_texts = ["passage: " + self.req_texts[int(v[1:])] for v in req_ids]
        input_texts = [candidate_text] + vacancy_texts

        # Tokenize together to ensure consistent padding
        tokens = self.tokenizer(input_texts, add_special_tokens=True, padding=True, truncation=True, max_length=self.query_size, return_tensors='pt').to(device)

        # Differentiating between -1 and 0 is practically impossible, so they are considered to be the same
        labels = [i if i >=0 else 0 for i in labels]

        return (candidate, list(req_ids)), tokens, torch.LongTensor(labels).to(device)

In [None]:
# Use the first GPU when available, otherwise fall back to the CPU
device = ("cuda:0" if torch.cuda.is_available() else "cpu")

# Only query the GPU name when there is a GPU: torch.cuda.get_device_name raises on
# CPU-only machines, which would defeat the fallback above
device, (torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")

In [None]:
def listwise_loss(scores, labels, sigma=1.0):

    """
    Compute the LambdaRank lambdas, i.e. the gradient of the listwise loss w.r.t. every score.

    scores: tensor with the N predicted scores of one query, shaped [N, 1] or [N]
            (the rankers in this notebook return the squeezed [N] form)
    labels: tensor with the N relevance labels
    sigma: steepness of the pairwise sigmoid (default 1)

    returns: a tensor of size [N, 1] on the device of `scores`. Queries with fewer than
             two documents, or without any positive gain, yield zeros.
    """
    num_docs = labels.numel()

    # Work on a column vector: with 1-D scores `scores - scores.T` is all zeros, which
    # silently removes the score differences from the lambdas
    scores = scores.reshape(-1, 1)
    labels = labels.reshape(-1).to(scores.device)

    # Nothing to rank (numel also covers 0-dim labels, on which size(0) would raise)
    if num_docs < 2:
        return torch.zeros((num_docs, 1), dtype=scores.dtype, device=scores.device)

    # Gain of every document and the DCG of the ideal ordering
    gains = torch.pow(2.0, labels.to(scores.dtype)) - 1.0
    positions = torch.arange(1, num_docs + 1, device=scores.device, dtype=scores.dtype)
    ideal_dcg = (torch.sort(gains, descending=True).values / torch.log(positions + 1)).sum()

    # Without any gain the nDCG is undefined; return zeros instead of NaN gradients
    if ideal_dcg <= 0:
        return torch.zeros((num_docs, 1), dtype=scores.dtype, device=scores.device)

    # Calculate lambda_{i, j} for every <i, j>.
    # S[i, j] is +1 / 0 / -1 when document i is more / equally / less relevant than j
    S = torch.sign(labels.unsqueeze(1) - labels.unsqueeze(0)).to(scores.dtype)
    score_diff = scores - scores.T
    # sigmoid(-x) == 1 / (1 + exp(x))
    lamda = sigma * (0.5 * (1 - S) - torch.sigmoid(-sigma * score_diff))

    # Rank of every document when sorting by descending score (0 = best)
    sorted_ind = torch.flip(scores.argsort(dim=0).flatten(), dims=[0])
    ranks = torch.empty_like(sorted_ind)
    ranks[sorted_ind] = torch.arange(num_docs, device=scores.device)
    discounts = 1.0 / torch.log(ranks.to(scores.dtype) + 2.0)

    # Calculate abs(Delta-NDCG) for each ordering <i, j> combination:
    # contribution of the pair at its current ranks versus with the two ranks swapped
    discounted_gains = gains * discounts
    current = discounted_gains.unsqueeze(1) + discounted_gains.unsqueeze(0)
    swapped = discounts.unsqueeze(1) * gains.unsqueeze(0) + gains.unsqueeze(1) * discounts.unsqueeze(0)
    delta_NDCG = ((swapped - current) / ideal_dcg).abs()

    lambda_rank_loss = (lamda * delta_NDCG).sum(dim=1, keepdim=True)

    return lambda_rank_loss

In [None]:
def train_model(model, trainloader, learning_rate, epochs):
    """
    Train a model using the optimal configuration.

    - model: the model to train
    - trainloader: the dataloader to use, which should provide tokens/embeddings/texts and labels
    - learning_rate: the learning rate to use for the optimizer
    - epochs: the number of epochs to train the model for

    Returns the trained model (the same object, updated in place).
    """

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        before = time.time()

        for i, batch in enumerate(trainloader):

            # An empty batch ends the current epoch
            if not batch:
                break

            # The LLM and graph models have slightly different datastructures
            if model.__class__.__name__ in ["e5_ranker", "conSultantBERT"]:
                # The token datasets yield (ids, tokens, labels); only the last two entries
                # are needed, which also keeps plain (tokens, labels) pairs working
                *_, batch_data, batch_labels = batch
            else:
                batch_data = batch.to(device)
                batch_labels = batch_data.y

            optimizer.zero_grad()

            # Make prediction
            y_pred = model(batch_data)

            # Calculate and propagate loss: listwise_loss returns d(loss)/d(score), which is
            # pushed through the network as the gradient of the predictions
            ground_truth = batch_labels.reshape(-1)
            lambda_i = listwise_loss(y_pred, ground_truth)
            # Align shape and device with y_pred (this also covers single-vacancy batches)
            score_grad = lambda_i.reshape(y_pred.shape).to(y_pred.device)
            torch.autograd.backward(y_pred, score_grad)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)

            optimizer.step()

            print(f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}/{len(trainloader)}", end="\r")
        print(f"\n\nTraining time: {time.time() - before}s")

    return model

In [None]:
def evaluate_model(model, testloader):
    """
    Evaluate a model on unseen data.

    - model: the model to evaluate
    - testloader: the dataloader to use, which should provide tokens/embeddings/texts and labels

    Returns a string with the mean nDCG@10, nDCG@5 and nDCG@3 over all evaluated queries.
    Queries with fewer than two documents are skipped, because nDCG is not meaningful for them.
    """

    ndcg10_scores = []
    ndcg5_scores = []
    ndcg3_scores = []

    # Make sure layers such as dropout are in inference mode
    model.eval()

    with torch.no_grad():
        for i, batch in enumerate(testloader):

            # An empty batch ends the evaluation
            if not batch:
                break

            # The LLM and graph models have slightly different datastructures
            if model.__class__.__name__ in ["e5_ranker", "conSultantBERT"]:
                *_, batch_data, batch_labels = batch
            else:
                batch_data = batch.to(device)
                batch_labels = batch.y

            print(f"Batch: {i + 1}/{len(testloader)}", end="\r")

            # Make predictions
            y_pred = model(batch_data)

            # ndcg_score expects (queries, documents) arrays; convert once per batch
            y_true = batch_labels.detach().cpu().reshape(1, -1)
            y_score = y_pred.detach().cpu().reshape(1, -1)

            if y_true.shape[1] < 2:
                continue

            # Calculate nDCG score of current batch
            ndcg10_scores.append(ndcg_score(y_true, y_score, k=10))
            ndcg5_scores.append(ndcg_score(y_true, y_score, k=5))
            ndcg3_scores.append(ndcg_score(y_true, y_score, k=3))

    return f"nDCG@10: {np.mean(ndcg10_scores)}, nDCG@5: {np.mean(ndcg5_scores)}, nDCG@3: {np.mean(ndcg3_scores)}"

In [None]:
def get_model_performance(model_type, best_config):
    """
    Build a model, train it (or load its cached weights) and print its test-set nDCG.

    - model_type: one of 'e5', 'consultantbert', 'baselineGNN', 'OKRA'
    - best_config: hyperparameters of the model; "learning_rate" and "epochs" are always
      read when training, "pooling" by the text models, "text_embedding_size" and
      "embedding_size" by the graph models

    Weights are cached in ./trained_models/<ModelClassName>.pt; training only happens when
    that file does not exist yet. Returns the trained / loaded model.

    NOTE: the dataloaders are restored with torch.load, which unpickles arbitrary objects;
    only load files you created yourself.
    """

    if model_type == "e5":
        trainloader = torch.load("./dataloaders/e5_trainloader.pth")
        testloader = torch.load("./dataloaders/e5_testloader.pth")

        model = e5_ranker(pooling=best_config["pooling"]).to(device)
    elif model_type == "consultantbert":
        trainloader = torch.load("./dataloaders/bert_trainloader.pth")
        testloader = torch.load("./dataloaders/bert_testloader.pth")

        model = conSultantBERT(pooling=best_config["pooling"]).to(device)
    elif model_type in ("baselineGNN", "OKRA"):
        trainloader = torch.load("./dataloaders/graph_trainloader.pth")
        testloader = torch.load("./dataloaders/graph_testloader.pth")

        # Data.metadata() is needed to initialize the heterodata
        data = next(iter(trainloader))

        # All the different node types
        typings = ["candidate", "request", "function_name", "isco_code",
                   "education", "language", "license", "skill", "company_name",
                   "function_id", "isco_level", "workgroup", "klass", "literal"]

        # Both graph models share the same constructor signature; only the selected
        # class name is looked up, so the other one does not have to be defined yet
        graph_model_cls = baselineGNNModel if model_type == "baselineGNN" else OKRA

        model = graph_model_cls(data,
                                typings,
                                text_embedding_size=best_config["text_embedding_size"],
                                embedding_size=best_config["embedding_size"]).to(device)
    else:
        raise ValueError("Please select one of ['e5', 'consultantbert', 'baselineGNN', 'OKRA']")

    model_dir = "./trained_models"
    checkpoint_path = os.path.join(model_dir, f"{model.__class__.__name__}.pt")

    if os.path.isfile(checkpoint_path):
        # map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        model = train_model(model, trainloader, learning_rate=best_config["learning_rate"], epochs=best_config["epochs"])
        os.makedirs(model_dir, exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)

    test_score = evaluate_model(model, testloader)

    print(f"{model.__class__.__name__} test set score:", test_score)

    torch.cuda.empty_cache()
    gc.collect()

    return model

In [None]:
# Train (or load) the e5 ranker and print its test-set score.
# NOTE(review): this rebinds the name `e5_ranker` from the class to the returned model
# instance. get_model_performance("e5", ...) looks up `e5_ranker` to build the model, so the
# cell cannot be run a second time in the same kernel without re-running the class definition.
e5_ranker = get_model_performance("e5", e5_best_config)

## conSultantBERT

In [None]:
# Configuration for conSultantBERT; get_model_performance("consultantbert", ...) reads
# the "pooling", "learning_rate" and "epochs" entries
bert_best_config = {"pooling": "mean", "learning_rate": 0.00000130695542009724, "epochs": 4}

In [None]:
class BERTTokenDataLoader(Dataset):
    """
    Dataset that yields, per candidate, the BERT-tokenized CV together with all of its vacancies.

    - truth_dict: {candidate: {vacancy: label}}
    - cv_data / req_data: texts keyed by the integer part of the candidate / vacancy id
    - query_size: number of tokens every text is padded / truncated to
    - batch_size: stored for backwards compatibility; every item already holds
      one complete candidate
    """

    def __init__(self, truth_dict, cv_data, req_data, query_size=512, batch_size=32):
        self.ground_truths = list(truth_dict.items())
        self.cv_texts = cv_data
        self.req_texts = req_data

        self.query_size = query_size
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
        self.batch_size = batch_size

    def __len__(self):
        # __getitem__ indexes individual candidates, so the length has to be the number of
        # candidates. Reporting the number of batches instead made any length-driven
        # sampler stop after the first len / batch_size candidates.
        return len(self.ground_truths)

    def __getitem__(self, idx):
        """Return ((candidate, vacancy ids), tokens, labels) for the candidate at position idx."""
        candidate, vacancies = self.ground_truths[idx]

        req_ids, labels = vacancies.keys(), list(vacancies.values())

        candidate_text = self.cv_texts[int(candidate[1:])]
        vacancy_texts = [self.req_texts[int(v[1:])] for v in req_ids]
        # The CV comes first, followed by its vacancies
        input_texts = [candidate_text] + vacancy_texts

        tokens = self.tokenizer(input_texts, max_length=self.query_size, padding='max_length', truncation=True, return_tensors="pt").to(device)

        # Differentiating between -1 and 0 is practically impossible, so they are considered to be the same
        labels = [i if i >=0 else 0 for i in labels]

        return (candidate, list(req_ids)), tokens, torch.LongTensor(labels).to(device)

In [None]:
class conSultantBERT(torch.nn.Module):
    """
    Rank vacancies for a CV by cosine similarity of pooled multilingual BERT embeddings.

    - pooling: how token embeddings are reduced to one vector per text,
      one of "mean", "sum" or "max"
    """

    POOLING_METHODS = ("mean", "sum", "max")

    def __init__(self, pooling):
        super().__init__()

        # Fail early instead of hitting an unbound variable in forward()
        if pooling not in self.POOLING_METHODS:
            raise ValueError(f"Unknown pooling '{pooling}', expected one of {self.POOLING_METHODS}")

        self.model = BertModel.from_pretrained("bert-base-multilingual-cased").to(device)
        self.pooling = pooling

    def _pool(self, hidden, attention_mask):
        """Reduce (texts, tokens, dim) token embeddings to (texts, dim), ignoring padding."""
        mask = attention_mask.unsqueeze(-1)

        if self.pooling == "max":
            # Padding must not take part in the maximum. Zeroing it (as multiplying by the
            # mask does) lets 0 win whenever all real activations of a dimension are
            # negative, so padded positions are pushed to the lowest representable value.
            lowest = torch.finfo(hidden.dtype).min
            return hidden.masked_fill(~mask.bool(), lowest).max(dim=1).values

        weights = mask.to(hidden.dtype)
        summed = (hidden * weights).sum(dim=1)

        if self.pooling == "sum":
            return summed
        if self.pooling == "mean":
            # Clamp avoids division by zero for a text without any real token
            return summed / weights.sum(dim=1).clamp(min=1e-9)

        raise ValueError(f"Unknown pooling '{self.pooling}', expected one of {self.POOLING_METHODS}")

    def forward(self, batch):
        """
        Score every vacancy in the batch against the CV.

        - batch: mapping with 'input_ids' and 'attention_mask'; row 0 is the CV,
          the remaining rows are the vacancies

        Returns the cosine similarity between the CV and each vacancy, squeezed
        (so a single vacancy yields a 0-dim tensor).
        """
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']

        # Process all embeddings in one go
        outputs = self.model(input_ids, attention_mask=attention_mask)
        embeddings = self._pool(outputs.last_hidden_state, attention_mask)

        # Extract CV and request embeddings
        cv_embedding = embeddings[0].unsqueeze(0)  # CV is the first in the batch
        req_embeddings = embeddings[1:]  # rest are requests

        # Use the cosine similarity as the score (based on the paper)
        return F.cosine_similarity(cv_embedding, req_embeddings).squeeze()

In [None]:
# Train conSultantBERT (or load its cached weights) and print its test-set nDCG scores
cBERT = get_model_performance("consultantbert", bert_best_config)

## BaseGNN

In [None]:
# Configuration for the baseline GNN; get_model_performance("baselineGNN", ...) reads the
# "text_embedding_size", "embedding_size", "learning_rate" and "epochs" entries
gnn_best_config = {"epochs": 1, "embedding_size": 128, "text_embedding_size": 32, "learning_rate": 0.0004643543481813991}

In [None]:
# We embed the textual nodes (candidates and requests) separately at first
class base_text_embedding_layer(torch.nn.Module):
    """
    Encode tokenized candidate and request texts with multilingual-e5-small and project
    the pooled, L2-normalised 384-dimensional embeddings to `text_embedding_size`.
    """

    def __init__(self, text_embedding_size=64):
        super().__init__()

        self.e5 = AutoModel.from_pretrained("intfloat/multilingual-e5-small").to(device)

        self.candidate_out = nn.Linear(in_features=384,
                                       out_features=text_embedding_size)

        # Projection applied to the request texts
        self.company_out = nn.Linear(in_features=384,
                                     out_features=text_embedding_size)

    def average_pool(self, last_hidden_states, attention_mask):
        """Mean of the token embeddings over the positions where attention_mask is non-zero.

        Rows without any real token yield a zero vector instead of NaN.
        """
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        # Clamp so that a fully masked row does not divide by zero
        token_counts = attention_mask.sum(dim=1).clamp(min=1)
        return last_hidden.sum(dim=1) / token_counts[..., None]

    def forward(self, x_can, x_req, att_mask_can, att_mask_req):
        """Return the projected (candidate, request) embeddings for the given token ids and masks."""

        # Feed tokens into model
        x_candidate = self.e5(x_can, att_mask_can)
        x_company = self.e5(x_req, att_mask_req)

        # Create embedding tensor
        candidate_embeddings = self.average_pool(x_candidate.last_hidden_state, attention_mask=att_mask_can)
        company_embeddings = self.average_pool(x_company.last_hidden_state, attention_mask=att_mask_req)

        # normalize embeddings
        candidate_embeddings = F.normalize(candidate_embeddings, p=2, dim=1)
        company_embeddings = F.normalize(company_embeddings, p=2, dim=1)

        # Run through MLP to match other embedding sizes
        x_candidate = self.candidate_out(candidate_embeddings).float()
        x_company = self.company_out(company_embeddings).float()

        return x_candidate, x_company
    
# Then, we embed all nodes initially
class base_embedding_layer(torch.nn.Module):
    """Two stacked TransformerConv layers with a ReLU in between."""

    def __init__(self, embedding_size=32):
        super().__init__()

        # (-1, -1) leaves the input sizes to be inferred lazily
        self.conv = geom_nn.TransformerConv((-1, -1), embedding_size)
        self.conv2 = geom_nn.TransformerConv((-1, -1), embedding_size)

    def forward(self, x, edge_index):
        """Apply conv -> ReLU -> conv2 over the given edges."""
        hidden = self.conv(x, edge_index)
        activated = hidden.relu()
        return self.conv2(activated, edge_index)
        
class baselineGNNModel(torch.nn.Module):
    """
    Baseline graph model: embed the textual candidate / request nodes, run the heterogeneous
    graph encoder over all node types, mean-pool the node embeddings per sub-graph and
    score every sub-graph with a linear layer.

    - data: a sample batch; its metadata() is used to convert the encoder with to_hetero
    - typings: the node types whose embeddings are pooled into the sub-graph embeddings
    - text_embedding_size: output size of the text embedder
    - embedding_size: output size of the graph encoder
    """

    def __init__(self, data, typings, text_embedding_size=16, embedding_size=32):
        super().__init__()

        self.typings = typings

        self.text_embedder = base_text_embedding_layer(text_embedding_size=text_embedding_size)

        self.embedder = base_embedding_layer(embedding_size=embedding_size)
        self.embedder = to_hetero(self.embedder, data.metadata(), aggr='sum')

        self.fc = nn.Linear(embedding_size, 1)

    def forward(self, data):
        """Return one score per sub-graph in `data`, ordered by sub-graph id (squeezed)."""
        # Embed textual features
        x_candidate, x_request = self.text_embedder(data.x_dict["candidate"], data.x_dict["request"], data["candidate"].att_mask, data["request"].att_mask)

        # Store the textual embeddings along with the rest of the node features.
        # HeteroData collects `x_dict` into a new dict on every access, so assigning into
        # `data.x_dict[...]` is lost before the encoder runs; a local dict is filled instead
        # (this also leaves the input batch untouched).
        # NOTE: checkpoints trained before this fix saw the raw token ids of the text nodes
        # and need to be retrained.
        x_dict = dict(data.x_dict)
        x_dict["candidate"] = x_candidate
        x_dict["request"] = x_request

        # Embed the graph as a whole
        embedded_data = self.embedder({k: v.float() for k, v in x_dict.items()}, data.edge_index_dict)

        # Each sub-graph gets its own embedding
        sub_graphs = defaultdict(list)

        # Find the sub-graph of each node in the embedding, and add it to the corresponding list
        for typing in self.typings:
            store = data[typing]

            # Some subgraphs do not have all data types (e.g., a graph might not include any education nodes)
            if not store:
                continue

            # Copy the ids to Python lists once instead of calling .item() for every node
            node_ids = store.unique_node_id.reshape(-1).tolist()
            graph_ids = store.sub_graph.reshape(-1).tolist()

            for i, emb in enumerate(embedded_data[typing]):
                # We were working with a dummy node
                if int(node_ids[i]) == 0:
                    continue

                # Add the node embedding to its sub-graph
                sub_graphs[int(graph_ids[i])].append(emb)

        # Mean pool every graph embedding (so the final embedding is the mean of all of the nodes)
        # and stack the results into a single matrix, ordered by sub-graph id
        pooled = [torch.stack(sub_graphs[sg]).mean(dim=0) for sg in sorted(sub_graphs)]
        sub_graph_matrix = torch.stack(pooled, dim=0)

        # Make predictions based on the sub-graph embeddings
        y_pred = self.fc(sub_graph_matrix)

        return y_pred.squeeze()

In [None]:
# Train the baseline GNN (or load its cached weights) and print its test-set nDCG scores
gTransformer = get_model_performance("baselineGNN", gnn_best_config)

## OKRA

In [None]:
# Configuration for OKRA. The "OKRA" branch of get_model_performance reads
# "text_embedding_size", "embedding_size", "learning_rate" and "epochs".
# NOTE(review): "text_pooling" and "pooling_method" are not read by any code visible in
# this part of the notebook - confirm where they are consumed.
okra_best_config = {"epochs": 3, "embedding_size": 32, "text_embedding_size": 128, 
                    "text_pooling": "token", "pooling_method": "mean", 
                    "learning_rate": 0.00005282859517546829}

In [None]:
import copy
import warnings
from typing import Any, Dict, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module, Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense import Linear
from torch_geometric.nn.fx import Transformer
from torch_geometric.typing import EdgeType, Metadata, NodeType, SparseTensor
from torch_geometric.utils.hetero import get_unused_node_types

try:
    from torch.fx import Graph, GraphModule, Node
except (ImportError, ModuleNotFoundError, AttributeError):
    GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node'


def to_hetero_with_bases(module: Module, metadata: Metadata, num_bases: int,
                         in_channels: Optional[Dict[str, int]] = None,
                         input_map: Optional[Dict[str, str]] = None,
                         debug: bool = False) -> GraphModule:
    """Build a ToHeteroWithBasesTransformer from the arguments and return its transform() result."""
    return ToHeteroWithBasesTransformer(
        module,
        metadata,
        num_bases,
        in_channels,
        input_map,
        debug,
    ).transform()



class ToHeteroWithBasesTransformer(Transformer):
    """FX transformer that rewrites a homogeneous GNN for heterogeneous input.

    The rewritten graph groups every type-wise input dictionary into a single
    tensor, replaces each ``MessagePassing`` sub-module by a
    ``HeteroBasisConv`` (which receives an extra ``edge_type`` argument) and
    splits the outputs back into type-wise dictionaries.

    Args:
        module: The module to transform.
        metadata: Pair ``(node_types, edge_types)``; both must be non-empty.
        num_bases: Number of copies created per message passing layer.
        in_channels: Optional mapping from placeholder name to the width its
            type-wise features are projected to (via ``LinearAlign``).
        input_map: Forwarded to the base ``Transformer``.
        debug: Forwarded to the base ``Transformer``.
    """
    def __init__(
        self,
        module: Module,
        metadata: Metadata,
        num_bases: int,
        in_channels: Optional[Dict[str, int]] = None,
        input_map: Optional[Dict[str, str]] = None,
        debug: bool = False,
    ):
        super().__init__(module, input_map, debug)

        self.metadata = metadata
        self.num_bases = num_bases
        self.in_channels = in_channels or {}
        assert len(metadata) == 2
        assert len(metadata[0]) > 0 and len(metadata[1]) > 0

        self.validate()

        # Compute IDs for each node and edge type:
        self.node_type2id = {k: i for i, k in enumerate(metadata[0])}
        self.edge_type2id = {k: i for i, k in enumerate(metadata[1])}

    def validate(self):
        """Warn about metadata that is likely to cause trouble.

        Two situations only trigger warnings (nothing is raised): node types
        that never occur as the destination of an edge type, and type names
        that are not valid Python identifiers.
        """
        unused_node_types = get_unused_node_types(*self.metadata)
        if len(unused_node_types) > 0:
            warnings.warn(
                f"There exist node types ({unused_node_types}) whose "
                f"representations do not get updated during message passing "
                f"as they do not occur as destination type in any edge type. "
                f"This may lead to unexpected behavior.")

        names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]
        for name in names:
            if not name.isidentifier():
                warnings.warn(
                    f"The type '{name}' contains invalid characters which "
                    f"may lead to unexpected behavior. To avoid any issues, "
                    f"ensure that your types only contain letters, numbers "
                    f"and underscores.")

    def transform(self) -> GraphModule:
        """Run the base transformation with per-run bookkeeping flags.

        The three flags record whether the helper nodes (`node_offset_dict`,
        `edge_offset_dict`, `edge_type`) were already inserted; they only
        exist for the duration of the call.
        """
        self._node_offset_dict_initialized = False
        self._edge_offset_dict_initialized = False
        self._edge_type_initialized = False
        out = super().transform()
        del self._node_offset_dict_initialized
        del self._edge_offset_dict_initialized
        del self._edge_type_initialized
        return out

    def placeholder(self, node: Node, target: Any, name: str):
        """Rewrite an input placeholder so it accepts a type-wise dictionary.

        Inserts, after the placeholder, the helper nodes that are still
        missing (offset dictionaries, `edge_type`), an optional `LinearAlign`
        projection, and finally the grouping call whose result replaces all
        uses of the original placeholder.
        """
        if node.type is not None:
            Type = EdgeType if self.is_edge_level(node) else NodeType
            node.type = Dict[Type, node.type]

        out = node

        # Create `node_offset_dict` and `edge_offset_dict` dictionaries in case
        # they are not yet initialized. These dictionaries hold the cumulated
        # sizes used to create a unified graph representation and to split the
        # output data.
        if self.is_edge_level(node) and not self._edge_offset_dict_initialized:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_function',
                                         target=get_edge_offset_dict,
                                         args=(node, self.edge_type2id),
                                         name='edge_offset_dict')
            self._edge_offset_dict_initialized = True

        # Node offsets can only be derived from a node-level input: feeding an
        # edge-level dictionary to `get_node_offset_dict` would look up node
        # types in a dictionary keyed by edge types.
        elif (not self.is_edge_level(node)
              and not self._node_offset_dict_initialized):
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_function',
                                         target=get_node_offset_dict,
                                         args=(node, self.node_type2id),
                                         name='node_offset_dict')
            self._node_offset_dict_initialized = True

        # Create a `edge_type` tensor used as input to `HeteroBasisConv`:
        if self.is_edge_level(node) and not self._edge_type_initialized:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_function', target=get_edge_type,
                                         args=(node, self.edge_type2id),
                                         name='edge_type')
            self._edge_type_initialized = True

        # Add `Linear` operation to align features to the same dimensionality:
        if name in self.in_channels:
            self.graph.inserting_after(out)
            out = self.graph.create_node('call_module',
                                         target=f'align_lin__{name}',
                                         args=(node, ),
                                         name=f'{name}__aligned')
            self._state[out.name] = self._state[name]

            lin = LinearAlign(self.metadata[int(self.is_edge_level(node))],
                              self.in_channels[name])
            setattr(self.module, f'align_lin__{name}', lin)

        # Perform grouping of type-wise values into a single tensor:
        if self.is_edge_level(node):
            self.graph.inserting_after(out)
            out = self.graph.create_node(
                'call_function', target=group_edge_placeholder,
                args=(out if name in self.in_channels else node,
                      self.edge_type2id,
                      self.find_by_name('node_offset_dict')),
                name=f'{name}__grouped')
            self._state[out.name] = 'edge'

        else:
            self.graph.inserting_after(out)
            out = self.graph.create_node(
                'call_function', target=group_node_placeholder,
                args=(out if name in self.in_channels else node,
                      self.node_type2id), name=f'{name}__grouped')
            self._state[out.name] = 'node'

        self.replace_all_uses_with(node, out)

    def call_message_passing_module(self, node: Node, target: Any, name: str):
        # Call the `HeteroBasisConv` wrapper instead of a single message
        # passing layer. We need to inject the `edge_type` as first argument
        # in order to do so.
        node.args = (self.find_by_name('edge_type'), ) + node.args

    def output(self, node: Node, target: Any, name: str):
        """Wrap every returned graph node in a `split_output` call."""
        # Split the output to dictionaries, holding either node type-wise or
        # edge type-wise data.
        def _recurse(value: Any) -> Any:
            if isinstance(value, Node) and self.is_edge_level(value):
                self.graph.inserting_before(node)
                return self.graph.create_node(
                    'call_function', target=split_output,
                    args=(value, self.find_by_name('edge_offset_dict')),
                    name=f'{value.name}__split')

            elif isinstance(value, Node):
                self.graph.inserting_before(node)
                return self.graph.create_node(
                    'call_function', target=split_output,
                    args=(value, self.find_by_name('node_offset_dict')),
                    name=f'{value.name}__split')

            elif isinstance(value, dict):
                return {k: _recurse(v) for k, v in value.items()}
            elif isinstance(value, list):
                return [_recurse(v) for v in value]
            elif isinstance(value, tuple):
                return tuple(_recurse(v) for v in value)
            else:
                return value

        if node.type is not None and isinstance(node.args[0], Node):
            output = node.args[0]
            Type = EdgeType if self.is_edge_level(output) else NodeType
            node.type = Dict[Type, node.type]
        else:
            node.type = None

        node.args = (_recurse(node.args[0]), )

    def init_submodule(self, module: Module, target: str) -> Module:
        """Return `module` unchanged unless it is a `MessagePassing` layer."""
        if not isinstance(module, MessagePassing):
            return module

        # Replace each `MessagePassing` module by a `HeteroBasisConv` wrapper:
        return HeteroBasisConv(module, len(self.metadata[1]), self.num_bases)


###############################################################################


class HeteroBasisConv(torch.nn.Module):
    """Wrapper that applies basis decomposition to a message passing layer.

    ``module`` is deep-copied ``num_bases`` times. Every copy owns a learnable
    ``edge_type_weight`` of shape ``(1, num_relations)`` that re-scales each
    message according to the type of the edge it travels along.

    This variant expects the wrapped layer to return
    ``(out, (edge_index, attention))`` (see ``forward``), i.e. it must be
    called so that it also yields its attention weights.

    Args:
        module: The message passing layer to duplicate.
        num_relations: Number of edge types.
        num_bases: Number of copies; must be at least 1.

    Raises:
        ValueError: If ``num_bases`` is smaller than 1.
    """
    def __init__(self, module: MessagePassing, num_relations: int,
                 num_bases: int):
        super().__init__()

        # Without at least one basis `forward` has nothing to aggregate.
        if num_bases < 1:
            raise ValueError(
                f"'num_bases' must be at least 1 (got {num_bases})")

        self.num_relations = num_relations
        self.num_bases = num_bases

        # We make use of a post-message computation hook to inject the
        # basis re-weighting for each individual edge type.
        # This currently requires us to set `conv.fuse = False`, which leads
        # to a materialization of messages.
        def hook(module, inputs, output):
            assert isinstance(module._edge_type, Tensor)
            if module._edge_type.size(0) != output.size(0):
                raise ValueError(
                    f"Number of messages ({output.size(0)}) does not match "
                    f"with the number of original edges "
                    f"({module._edge_type.size(0)}). Does your message "
                    f"passing layer create additional self-loops? Try to "
                    f"remove them via 'add_self_loops=False'")
            weight = module.edge_type_weight.view(-1)[module._edge_type]
            weight = weight.view([-1] + [1] * (output.dim() - 1))
            return weight * output

        params = list(module.parameters())
        device = params[0].device if len(params) > 0 else 'cpu'

        self.convs = torch.nn.ModuleList()
        for _ in range(num_bases):
            conv = copy.deepcopy(module)
            conv.fuse = False  # Disable `message_and_aggregate` functionality.
            # We learn a single scalar weight for each individual edge type,
            # which is used to weight the output message based on edge type:
            conv.edge_type_weight = Parameter(
                torch.empty(1, num_relations, device=device))
            conv.register_message_forward_hook(hook)
            self.convs.append(conv)

        if self.num_bases > 1:
            self.reset_parameters()
        else:
            # A single basis keeps the copied layer weights, but its
            # `edge_type_weight` was allocated with `torch.empty` and must
            # still be initialised before it is trained.
            for conv in self.convs:
                torch.nn.init.xavier_uniform_(conv.edge_type_weight)

    def reset_parameters(self):
        """Re-initialise every copy and its ``edge_type_weight``."""
        for conv in self.convs:
            if hasattr(conv, 'reset_parameters'):
                conv.reset_parameters()
            elif sum([p.numel() for p in conv.parameters()]) > 0:
                warnings.warn(
                    f"'{conv}' will be duplicated, but its parameters cannot "
                    f"be reset. To suppress this warning, add a "
                    f"'reset_parameters()' method to '{conv}'")
            torch.nn.init.xavier_uniform_(conv.edge_type_weight)

    def forward(self, edge_type: Tensor, *args, **kwargs) -> tuple:
        """Run all bases and aggregate their results.

        Args:
            edge_type: Edge-type id per edge, consumed by the message hook.
            *args, **kwargs: Passed unchanged to every wrapped layer.

        Returns:
            ``(out, (edge_type, edge_index, attention))`` where ``out`` is the
            sum of the layer outputs over all bases, ``edge_index`` is the one
            returned by the last basis and ``attention`` is the mean of the
            attention tensors returned by the bases.
        """
        out = None
        edge_ind_exp = None
        attention = []

        # Call message passing modules and perform aggregation:
        for conv in self.convs:
            conv._edge_type = edge_type
            try:
                res, (edge_ind_exp, att_weight_exp) = conv(*args, **kwargs)
            finally:
                # Never leave the temporary attribute behind, even on error.
                del conv._edge_type

            attention.append(att_weight_exp)

            out = res if out is None else out.add_(res)

        return out, (edge_type, edge_ind_exp, torch.mean(torch.stack(attention, dim=0), dim=0))

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(num_relations='
                f'{self.num_relations}, num_bases={self.num_bases})')


class LinearAlign(torch.nn.Module):
    """Project type-wise representations to one shared dimensionality.

    One bias-free ``Linear(-1, out_channels)`` is kept per key. These layers
    are lazy, so a forward pass is required before their parameters exist.

    Args:
        keys: Node or edge types that need a projection.
        out_channels: Common output width.
    """
    def __init__(self, keys: List[Union[NodeType, EdgeType]],
                 out_channels: int):
        super().__init__()
        self.out_channels = out_channels
        self.lins = torch.nn.ModuleDict()
        for type_key in keys:
            # `ModuleDict` needs string keys, hence the `key2str` conversion.
            layer_name = key2str(type_key)
            self.lins[layer_name] = Linear(-1, out_channels, bias=False)

    def forward(
        self, x_dict: Dict[Union[NodeType, EdgeType], Tensor]
    ) -> Dict[Union[NodeType, EdgeType], Tensor]:
        """Apply the matching projection to every entry of ``x_dict``."""
        aligned = {}
        for type_key, features in x_dict.items():
            projection = self.lins[key2str(type_key)]
            aligned[type_key] = projection(features)
        return aligned

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (f'{cls_name}(num_relations={len(self.lins)}, '
                f'out_channels={self.out_channels})')


###############################################################################

# These methods are used in order to receive the cumulated sizes of input
# dictionaries. We make use of them for creating a unified homogeneous graph
# representation, as well as to split the final output data once again.


def get_node_offset_dict(
    input_dict: Dict[NodeType, Union[Tensor, SparseTensor]],
    type2id: Dict[NodeType, int],
) -> Dict[NodeType, int]:
    """Return the start offset of every node type in the grouped tensor.

    Types are visited in the iteration order of ``type2id``; each offset is
    the summed ``size(0)`` of all types visited before it.
    """
    offsets: Dict[NodeType, int] = {}
    running_total = 0
    for node_type in type2id:
        offsets[node_type] = running_total
        running_total = running_total + input_dict[node_type].size(0)
    return offsets


def get_edge_offset_dict(
    input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],
    type2id: Dict[EdgeType, int],
) -> Dict[EdgeType, int]:
    """Return the start offset of every edge type in the grouped tensor.

    The number of edges per type is ``nnz()`` for a ``SparseTensor``, the last
    dimension for a ``torch.long`` tensor with two rows (an ``edge_index``),
    and ``size(0)`` for anything else.
    """
    offsets: Dict[EdgeType, int] = {}
    position = 0
    for edge_type in type2id:
        offsets[edge_type] = position
        entry = input_dict[edge_type]
        if isinstance(entry, SparseTensor):
            num_edges = entry.nnz()
        elif entry.dtype == torch.long and entry.size(0) == 2:
            num_edges = entry.size(-1)
        else:
            num_edges = entry.size(0)
        position += num_edges
    return offsets


###############################################################################

# This method computes the edge type of the final homogeneous graph
# representation. It will be used in the `HeteroBasisConv` wrapper.


def get_edge_type(
    input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],
    type2id: Dict[EdgeType, int],
) -> Tensor:
    """Build the per-edge type id vector of the grouped graph.

    For the i-th key of ``type2id`` a ``torch.long`` vector filled with ``i``
    is created, one entry per edge of that type; the vectors are concatenated
    in key order (a single vector is returned as is).
    """
    per_type = []
    for type_id, key in enumerate(type2id):
        entry = input_dict[key]
        if entry.size(0) == 2 and entry.dtype == torch.long:  # edge_index
            labels = entry.new_full((entry.size(-1), ), type_id,
                                    dtype=torch.long)
        elif isinstance(entry, SparseTensor):
            labels = torch.full((entry.nnz(), ), type_id, dtype=torch.long,
                                device=entry.device())
        else:
            labels = entry.new_full((entry.size(0), ), type_id,
                                    dtype=torch.long)
        per_type.append(labels)

    if len(per_type) == 1:
        return per_type[0]
    return torch.cat(per_type, dim=0)


###############################################################################

# These methods are used to group the individual type-wise components into a
# unified single representation.


def group_node_placeholder(input_dict: Dict[NodeType, Tensor],
                           type2id: Dict[NodeType, int]) -> Tensor:
    """Concatenate the type-wise node tensors along dim 0.

    Tensors are taken in the iteration order of ``type2id``; when there is
    only one type its tensor is returned unchanged (no copy).
    """
    ordered = tuple(input_dict[node_type] for node_type in type2id)
    if len(ordered) == 1:
        return ordered[0]
    return torch.cat(ordered, dim=0)


def group_edge_placeholder(
    input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],
    type2id: Dict[EdgeType, int],
    offset_dict: Dict[NodeType, int] = None,
) -> Union[Tensor, SparseTensor]:
    """Merge the type-wise edge-level inputs into one tensor.

    * A single edge type is returned unchanged.
    * ``edge_index`` tensors (two rows, ``torch.long``) are copied, shifted by
      the node offsets of their source / destination type and concatenated
      along the last dimension.
    * ``SparseTensor`` inputs are converted to one shifted ``edge_index``.
    * Anything else is concatenated along dim 0.

    Which branch is taken is decided by the first input only.

    Raises:
        AttributeError: If connectivity has to be shifted but ``offset_dict``
            is ``None``.
    """
    values = [input_dict[edge_type] for edge_type in type2id]

    if len(values) == 1:
        return values[0]

    first = values[0]

    # In case of grouping a graph connectivity tensor `edge_index` or `adj_t`,
    # we need to increment its indices:
    if first.size(0) == 2 and first.dtype == torch.long:
        if offset_dict is None:
            raise AttributeError(
                "Can not infer node-level offsets. Please ensure that there "
                "exists a node-level argument before the 'edge_index' "
                "argument in your forward header.")

        shifted = []
        for edge_index, (src_type, _, dst_type) in zip(values, type2id):
            moved = edge_index.clone()  # never touch the caller's tensor
            moved[0].add_(offset_dict[src_type])
            moved[1].add_(offset_dict[dst_type])
            shifted.append(moved)
        return torch.cat(shifted, dim=-1)

    if isinstance(first, SparseTensor):
        if offset_dict is None:
            raise AttributeError(
                "Can not infer node-level offsets. Please ensure that there "
                "exists a node-level argument before the 'SparseTensor' "
                "argument in your forward header.")

        # For grouping a list of SparseTensors, we convert them into a
        # unified `edge_index` representation in order to avoid conflicts
        # induced by re-shuffling the data.
        all_rows, all_cols = [], []
        for adj, (src_type, _, dst_type) in zip(values, type2id):
            # The first/second results of `coo()` are deliberately bound to
            # `col`/`row` (in that order), exactly as in the original code.
            col, row, edge_value = adj.coo()
            assert edge_value is None
            all_rows.append(row + offset_dict[src_type])
            all_cols.append(col + offset_dict[dst_type])

        return torch.stack([torch.cat(all_rows, dim=0),
                            torch.cat(all_cols, dim=0)], dim=0)

    return torch.cat(values, dim=0)


###############################################################################

# This method is used to split the output tensors into individual type-wise
# components:


def split_output(
    output: Tensor,
    offset_dict: Union[Dict[NodeType, int], Dict[EdgeType, int]],
) -> Union[Dict[NodeType, Tensor], Dict[EdgeType, Tensor]]:
    """Split a grouped output tensor back into type-wise chunks.

    Args:
        output: Tensor whose first dimension is laid out according to
            ``offset_dict``. Tuples are passed through untouched.
        offset_dict: Start offset per type, in layout order.

    Returns:
        Mapping from every key of ``offset_dict`` to its slice of ``output``;
        the last key receives everything from its offset to the end.
        If ``output`` is a tuple it is returned unchanged.
    """
    # Tuples (e.g. the extra attention information returned next to a layer
    # output) cannot be split by offsets, so they are handed back as they are.
    if isinstance(output, tuple):
        return output

    # Sometimes an edge index ends up here. TODO: determine which edge belongs
    # to which edge type. An edge index is stored as (2, num_edges), so it is
    # transposed to put the edges on dim 0. Only integer connectivity is
    # treated this way: a float feature matrix that merely happens to have two
    # rows (two nodes / two edges) must not be transposed, otherwise the
    # chunks below come out with the wrong shape.
    if (output.dim() == 2 and output.size(0) == 2
            and output.dtype == torch.long):
        output = output.T

    cumsums = list(offset_dict.values()) + [output.size(0)]
    sizes = [end - start for start, end in zip(cumsums, cumsums[1:])]
    return dict(zip(offset_dict, output.split(sizes)))


###############################################################################


def key2str(key: Union[NodeType, EdgeType]) -> str:
    """Turn a node or edge type into a plain string key.

    Tuples are joined with ``'__'``; afterwards every space, hyphen and colon
    is replaced by an underscore.
    """
    if isinstance(key, tuple):
        key = '__'.join(key)
    return key.translate(str.maketrans(' -:', '___'))

In [None]:
# We embed the textual nodes (candidates and requests) separately at first
class text_embedding_layer(torch.nn.Module):
    """Embed the textual nodes (candidates and requests) separately at first.

    Both sides are encoded by the same pre-trained
    ``intfloat/multilingual-e5-small`` model, pooled over the token axis and
    projected from 384 to ``text_embedding_size`` by a side-specific linear
    layer.

    Args:
        text_embedding_size: Output width of both projections.
        text_pooling: ``"token"`` (masked mean followed by L2 normalisation)
            or ``"sentence"`` (masked mean with a clamped denominator, no
            normalisation).

    Raises:
        ValueError: If ``text_pooling`` is not one of the supported modes.

    Note:
        The encoder is moved to the notebook-level ``device``.
    """
    _POOLING_MODES = ("token", "sentence")

    def __init__(self, text_embedding_size=64, text_pooling="token"):
        super().__init__()

        # Fail at construction time instead of with an `UnboundLocalError`
        # in the middle of the first forward pass.
        if text_pooling not in self._POOLING_MODES:
            raise ValueError(
                f"Unsupported text_pooling '{text_pooling}'; expected one of "
                f"{self._POOLING_MODES}")

        self.e5 = AutoModel.from_pretrained("intfloat/multilingual-e5-small").to(device)

        self.text_pooling = text_pooling

        self.candidate_out = nn.Linear(in_features=384,
                                       out_features=text_embedding_size)

        self.company_out = nn.Linear(in_features=384,
                                     out_features=text_embedding_size)

    def average_pool(self, last_hidden_states, attention_mask):
        """Masked mean over dim 1; masked positions are zeroed before summing."""
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def _pool(self, last_hidden_state, attention_mask):
        """Pool one side's token states according to ``self.text_pooling``."""
        if self.text_pooling == "token":
            pooled = self.average_pool(last_hidden_state, attention_mask=attention_mask)
            # normalize embeddings
            return F.normalize(pooled, p=2, dim=1)

        # "sentence": mean pooling
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size())
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)  # Avoid division by zero
        return sum_embeddings / sum_mask

    def forward(self, x_can, x_req, att_mask_can, att_mask_req):
        """Return ``(candidate_embeddings, request_embeddings)``.

        ``x_can`` / ``x_req`` are passed to the encoder as its first positional
        argument and the attention masks as its second.
        """
        # Feed tokens into model
        x_candidate = self.e5(x_can, att_mask_can)
        x_company = self.e5(x_req, att_mask_req)

        candidate_embeddings = self._pool(x_candidate.last_hidden_state, att_mask_can)
        company_embeddings = self._pool(x_company.last_hidden_state, att_mask_req)

        # Run through MLP to match other embedding sizes
        x_candidate = self.candidate_out(candidate_embeddings).float()
        x_company = self.company_out(company_embeddings).float()

        return x_candidate, x_company
    
# Then, we embed all nodes initially
class embedding_layer(torch.nn.Module):
    """Initial embedding of all nodes: two stacked ``TransformerConv`` layers.

    Both layers output ``embedding_size`` channels; their input sizes are
    given as ``(-1, -1)``. A ReLU sits between the two layers, the second
    layer's output is returned without activation.
    """
    def __init__(self, embedding_size=32):
        super().__init__()

        self.conv1 = geom_nn.TransformerConv((-1, -1), embedding_size)
        self.conv2 = geom_nn.TransformerConv((-1, -1), embedding_size)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)
        
class GNN(torch.nn.Module):
    """Two-sided attention network: a candidate side and a company side.

    Each side stacks two ``GATv2Conv`` layers ("pos" followed by "neg", the
    second one consuming the raw output of the first). The company side works
    on the negated input features and on the flipped edge index.

    ``forward`` indexes the third element of what every attention layer
    returns next to its output; that layout is the one produced by
    ``HeteroBasisConv`` once this module went through
    ``to_hetero_with_bases``.

    Args:
        embedding_size: Output channels of every attention layer.
        heads: Number of attention heads; also the output width of the two
            dense layers that produce the per-head node weights.
    """
    def __init__(self, embedding_size=64, heads=4):
        super().__init__()

        def attention_layer():
            # All four layers share the exact same configuration.
            return geom_nn.GATv2Conv((-1, -1),
                                     out_channels=embedding_size,
                                     add_self_loops=False,
                                     heads=heads,
                                     concat=False)

        self.can_pos = attention_layer()
        self.can_neg = attention_layer()
        self.com_pos = attention_layer()
        self.com_neg = attention_layer()

        # Different batch norm for each GATv2 output, as it includes learned parameters
        self.batch_norm1 = torch.nn.BatchNorm1d(embedding_size)
        self.batch_norm2 = torch.nn.BatchNorm1d(embedding_size)
        self.batch_norm3 = torch.nn.BatchNorm1d(embedding_size)
        self.batch_norm4 = torch.nn.BatchNorm1d(embedding_size)

        self.dense_can = nn.Linear(in_features=embedding_size, out_features=heads)
        self.dense_com = nn.Linear(in_features=embedding_size, out_features=heads)

        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, edge_index):
        """Return node embeddings, node/edge weights and raw attention info.

        Returns:
            ``(x_can, x_com, v_im_can, v_im_com, e_im_can, e_im_com,
            can_pos_exp, can_neg_exp, com_pos_exp, com_neg_exp)``.
        """
        candidate_edges = edge_index.long()
        # We flip the edge index to distinguish this as a 'company-side' graph
        company_edges = edge_index[[1, 0]].long()

        ### Candidate side: positive layer, then negative layer on its output
        x_can_pos, can_pos_exp = self.can_pos(x, candidate_edges,
                                              return_attention_weights=True)
        x_can_neg, can_neg_exp = self.can_neg(x_can_pos, candidate_edges,
                                              return_attention_weights=True)

        ### Company side: same scheme on negated features and flipped edges
        x_com_pos, com_pos_exp = self.com_pos(x * -1, company_edges,
                                              return_attention_weights=True)
        x_com_neg, com_neg_exp = self.com_neg(x_com_pos, company_edges,
                                              return_attention_weights=True)

        # Edge embedding: sigmoid of the summed attention of both layers
        e_im_can = self.sigmoid(
            torch.stack([can_pos_exp[2], can_neg_exp[2]], dim=0).sum(dim=0))
        e_im_com = self.sigmoid(
            torch.stack([com_pos_exp[2], com_neg_exp[2]], dim=0).sum(dim=0))

        x_can_pos = self.batch_norm1(x_can_pos)
        x_can_neg = self.batch_norm2(x_can_neg)

        x_com_pos = self.batch_norm3(x_com_pos)
        x_com_neg = self.batch_norm4(x_com_neg)

        # Mean pool the activated outputs of both layers per side
        x_can = torch.stack([self.elu(x_can_pos), self.elu(x_can_neg)]).mean(dim=0)
        x_com = torch.stack([self.elu(x_com_pos), self.elu(x_com_neg)]).mean(dim=0)

        # Node embedding
        v_im_can = self.dense_can(x_can).relu()
        v_im_com = self.dense_com(x_com).relu()

        return (x_can, x_com, v_im_can, v_im_com, e_im_can, e_im_com,
                can_pos_exp, can_neg_exp, com_pos_exp, com_neg_exp)
    
class OKRA(torch.nn.Module):
    """Full model: text embedding -> graph embedding -> two-sided GNN -> score.

    Per sub-graph, the embeddings of its head node, its tail node and the
    pooled embedding of all its nodes are concatenated (once for the candidate
    side, once for the company side) and scored by one linear layer per side.
    The final prediction combines both scores as ``2 * a * b / (a + b)``.

    Args:
        data: Graph object providing ``metadata()``; used to build the
            heterogeneous versions of the embedder and the GNN.
        typings: Node types to iterate over when collecting node embeddings.
        embedding_size: Channel width of the graph layers.
        text_embedding_size: Output width of the text embedder.
        pooling_method: ``"mean"``, ``"sum"`` or ``"max"`` pooling of the node
            embeddings of a sub-graph.
        text_pooling: Pooling mode forwarded to ``text_embedding_layer``.
        heads: Number of attention heads.

    Note:
        Relies on the notebook-level ``device``.
    """
    def __init__(self, data, typings, embedding_size=64, text_embedding_size=64, pooling_method="mean", text_pooling="token", heads=4):
        super().__init__()

        self.typings = typings
        self.num_heads = heads
        self.embedding_size = embedding_size

        self.pooling = {
            "mean": lambda x, dim: torch.mean(x, dim=dim),
            "sum": lambda x, dim: torch.sum(x, dim=dim),
            # Only return the values for max pooling, ignoring the indices
            "max": lambda x, dim: torch.max(x, dim=dim)[0]
        }[pooling_method]

        self.text_embedder = text_embedding_layer(text_embedding_size=text_embedding_size, text_pooling=text_pooling)

        self.embedder = embedding_layer(embedding_size=embedding_size)
        self.embedder = to_hetero(self.embedder, data.metadata(), aggr='sum')

        self.gnn = GNN(embedding_size=embedding_size, heads=heads)
        self.gnn = to_hetero_with_bases(self.gnn, data.metadata(), num_bases=3)

        # Each embedding is the size heads * embedding_size * 3, as there is one heads * embedding_size embedding for each (head node, tail node, sub-graph)
        self.mlp_candidate = nn.Linear(in_features=heads * embedding_size * 3,
                                       out_features=1)

        self.mlp_company = nn.Linear(in_features=heads * embedding_size * 3,
                                     out_features=1)

    def forward(self, data):
        """Score every sub-graph contained in ``data``.

        Returns:
            One prediction per sub-graph, ordered by sub-graph id and squeezed.

        Raises:
            RuntimeError: If a sub-graph does not contain exactly one head and
                one tail node.
        """
        # Embed textual features
        x_candidate, x_request = self.text_embedder(data.x_dict["candidate"], data.x_dict["request"], data["candidate"].att_mask, data["request"].att_mask)

        # Store the textual embeddings along with the rest of the graph.
        # `data.x_dict` is assembled anew on every access, so writing into it
        # is lost immediately; a local dictionary is built instead. This also
        # leaves the input `data` untouched.
        # NOTE(review): previously the raw token ids reached the graph
        # embedder instead of these embeddings, so models trained before this
        # change are not comparable and need to be retrained.
        x_dict = {k: v.float() for k, v in data.x_dict.items()}
        x_dict["candidate"] = x_candidate
        x_dict["request"] = x_request

        # Embed the graph as a whole
        embedded_data = self.embedder(x_dict, data.edge_index_dict)

        # Run the embedded graph through the GNN
        x_can, x_com, v_im_can1, v_im_com1, e_im_can, e_im_com, \
        can_pos_exp, can_neg_exp, com_pos_exp, com_neg_exp = self.gnn(embedded_data,
                                                                      data.edge_index_dict)

        # Combine the attention with the values, once per head, for both the candidate and company side
        h_can = defaultdict(lambda : torch.Tensor([]).to(device))
        h_com = defaultdict(lambda : torch.Tensor([]).to(device))

        # Store each node as a combination of its head embeddings
        for typing in self.typings:
            for k in range(self.num_heads):
                if typing in x_can:
                    h_can[typing] = torch.cat([h_can[typing], (x_can[typing].T * v_im_can1[typing][:,k]).T], dim=1)
                else:
                    # NOTE(review): fallback for types without GNN output, kept
                    # exactly as it was — its intended shape is unclear.
                    h_can[typing] = torch.cat([h_can[typing], torch.zeros_like(h_can[list(h_can.keys())[0]].T)])

                if typing in x_com:
                    h_com[typing] = torch.cat([h_com[typing], (x_com[typing].T * v_im_com1[typing][:,k]).T], dim=1)
                else:
                    h_com[typing] = torch.cat([h_com[typing], torch.zeros_like(h_com[list(h_com.keys())[0]].T)])

        # Each sub-graph gets its own embedding
        sub_graphs_candidate = defaultdict(list)
        sub_graphs_company = defaultdict(list)

        # Additionally, the head and tail node (candidate and vacancy) get stored separately as well
        main_nodes_candidate = defaultdict(list)
        main_nodes_company = defaultdict(list)

        # Find the sub-graph of each node in the embedding, and add it to the corresponding list
        for typing in self.typings:
            for i, emb in enumerate(h_can[typing]):
                # Some subgraphs do not have all data types (e.g., a graph might not include any education nodes)
                if data[typing]:
                    # Find the sub-graph the current node belongs to
                    current_node_id = int(data[typing].unique_node_id[i].item())

                    # We were working with a dummy node
                    if current_node_id == 0:
                        continue

                    sg = int(data[typing].sub_graph[i].item())

                    # If our node is a head/tail node, store it accordingly
                    if current_node_id in data.head_nodes[0] or current_node_id in data.tail_nodes[0]:
                        main_nodes_candidate[sg].append(emb)
                        main_nodes_company[sg].append(h_com[typing][i])

                    # Add its candidate embedding to its sub-graph embedding
                    sub_graphs_candidate[sg].append(emb.unsqueeze(0))

                    # Do the same on the company side
                    sub_graphs_company[sg].append(h_com[typing][i].unsqueeze(0))

        # Finally, pool every graph embedding (so the final embedding is the pooled value of all of the nodes)
        for sg in sub_graphs_candidate.keys():
            # The scoring layers expect head node + tail node + pooled
            # sub-graph, i.e. exactly two main nodes per sub-graph.
            num_main_nodes = len(main_nodes_candidate[sg])
            if num_main_nodes != 2:
                raise RuntimeError(
                    f"Sub-graph {sg} contains {num_main_nodes} head/tail "
                    f"node(s); exactly 2 (one head and one tail node) are "
                    f"required to build its embedding.")

            sub_graphs_candidate[sg] = self.pooling(torch.stack(sub_graphs_candidate[sg]).squeeze(1), dim=0)
            sub_graphs_company[sg] = self.pooling(torch.stack(sub_graphs_company[sg]).squeeze(1), dim=0)

            # Add the head and tail node to the full embedding
            sub_graphs_candidate[sg] = torch.cat([torch.cat(main_nodes_candidate[sg], dim=0).squeeze(), sub_graphs_candidate[sg]])
            sub_graphs_company[sg] = torch.cat([torch.cat(main_nodes_company[sg], dim=0).squeeze(), sub_graphs_company[sg]])

        # Stack all the sub-graph embeddings into a single matrix, both candidate- and company-sided
        sub_graphs_candidate = torch.stack([i[1] for i in sorted(sub_graphs_candidate.items())], dim=0)
        sub_graphs_company = torch.stack([i[1] for i in sorted(sub_graphs_company.items())], dim=0)

        # Make predictions based on the sub-graph embeddings
        y_candidate = torch.clamp(self.mlp_candidate(sub_graphs_candidate), min=-100, max=100)
        y_company = torch.clamp(self.mlp_company(sub_graphs_company), min=-100, max=100)

        # Final prediction is the harmonic mean of the candidate- and company-sided prediction
        # NOTE(review): both scores are unbounded in sign, so `a + b` can be
        # zero while `a * b` is not; the resulting +/-inf is mapped by
        # `nan_to_num` to a very large finite value, not to 0 — confirm this
        # is acceptable.
        y_pred = 2 * ((y_candidate * y_company) / (y_candidate + y_company))

        # The harmonic mean of X and 0 should be 0, not nan
        y_pred = torch.nan_to_num(y_pred).squeeze()

        return y_pred

In [None]:
# Run get_model_performance for the "OKRA" model with okra_best_config and keep its return value
okra = get_model_performance("OKRA", okra_best_config)

# Fairness evaluation

In [None]:
# Load the CV-vacancy pairs table from disk
# NOTE(review): the name `data` is rebound to a batch drawn from `testloader` in the model loop
# further down, so this frame is no longer reachable under this name after that cell has run
data = pd.read_csv("./data/cv-vacancy-pairs.csv")

In [None]:
# Request records
requests = pd.read_csv("./data/requests.csv")

# Company locations: only the company number and the postal code columns are kept
request_locations = pd.read_csv("./data/request_locations.csv")
request_locations = request_locations[
    ["client_mgmcompany_companynumber", "client_mgmcompany_cacheddetails_address_postalcode"]
]


In [None]:
# Company numbers become nullable integers, with -1 standing in for missing values
request_locations["client_mgmcompany_companynumber"] = (
    request_locations["client_mgmcompany_companynumber"]
    .astype('Int64')
    .fillna(-1)
)

# Truncate every postal code to its first four characters; missing values pass through untouched
request_locations["client_mgmcompany_cacheddetails_address_postalcode"] = (
    request_locations["client_mgmcompany_cacheddetails_address_postalcode"]
    .apply(lambda code: code if pd.isna(code) else code[:4])
)

In [None]:
# Only rows whose postal code consists of exactly four digits survive (missing codes are dropped)
has_four_digit_code = (
    request_locations['client_mgmcompany_cacheddetails_address_postalcode']
    .str.match(r'^\d{4}$', na=False)
)
request_locations = request_locations[has_four_digit_code]

# The remaining codes are digit-only strings, so each one can be converted to an integer
request_locations["client_mgmcompany_cacheddetails_address_postalcode"] = (
    request_locations["client_mgmcompany_cacheddetails_address_postalcode"].apply(int)
)

In [None]:
# Normalise company names: lower-case, trim, then replace every non-alphanumeric character by "_"
requests["request_company_name"] = (
    requests["request_company_name"]
    .str.lower()
    .str.strip()
    .str.replace('[^a-zA-Z0-9]', '_', regex=True)
    .str.strip()
)

# Log-log plot of the occurrence count per normalised company name, in ascending order
company_name_counts = requests["request_company_name"].value_counts().sort_values()
company_name_counts.plot(loglog=True)

In [None]:
# Print a summary of the requests frame (columns, non-null counts, dtypes)
requests.info()

In [None]:
# Attach the company postal code to every request through the company number.
# A left join keeps requests for which no location row exists.
request_postcodes = requests[["request_mondriaan_number", "request_company_number"]].merge(
    request_locations,
    how="left",
    left_on="request_company_number",
    right_on="client_mgmcompany_companynumber",
)

# Keep the request number and postal code only, with a single (first) row per request number
merged_df = request_postcodes[
    ["request_mondriaan_number", "client_mgmcompany_cacheddetails_address_postalcode"]
].drop_duplicates(subset="request_mondriaan_number")

In [None]:
# Load candidate locations
candidate_locs = pd.read_csv("./data/location.csv")

# Manual overrides of 'date_start' for three rows, addressed by index label, applied before
# the column is parsed (presumably the raw values cannot be parsed as dates -- TODO confirm)
candidate_locs.at[1108034, "date_start"] = "2018-02-01 00:00:00"
candidate_locs.at[1476463, "date_start"] = "2016-09-01 00:00:00"
candidate_locs.at[1596093, "date_start"] = "2020-10-26 00:00:00"

# Convert 'date_start' to datetime
candidate_locs['date_start'] = pd.to_datetime(candidate_locs['date_start'])

# Keep only the row(s) with the most recent 'date_start' for each 'candidate_id'.
# The explicit copy makes sure the column added further down lands on an independent frame
# instead of on a slice of the unfiltered one.
latest_start = candidate_locs.groupby('candidate_id')['date_start'].transform('max')
candidate_locs = candidate_locs[candidate_locs['date_start'] == latest_start].copy()

# Postal code to municipality conversion
zip_to_mun = pd.read_csv("./data/georef-netherlands-postcode-pc4.csv", encoding="utf-8", on_bad_lines="skip", delimiter=";")[["PC4", "Provincie name", "Gemeente name"]]

# Municipalities counted as Randstad. Names are matched exactly (case-insensitively) against
# the "Gemeente name" column, so a misspelled entry silently matches nothing. The spellings of
# Ouder-Amstel, Zeist, Stichtse Vecht, Alphen aan den Rijn and Pijnacker-Nootdorp were corrected,
# and the municipality names Zaanstad and Krimpen aan den IJssel were added next to the
# colloquial "Zaandam" / "Krimpen" entries.
# NOTE(review): this assumes the georef file uses the official municipality names -- verify
# against zip_to_mun["Gemeente name"].unique(). Correcting the names can enlarge the Randstad
# postcode set and therefore shift the fairness numbers computed from it.
randstad_municipalities = {"Amsterdam", "Rotterdam", "Utrecht", "Den Haag", "Haarlemmermeer", "Zaandam", "Zaanstad", "Heemskerk", "Beverwijk", "Velsen", "Oostzaan", "Landsmeer", "Haarlem", "Bloemendaal",
                           "Zandvoort", "Diemen", "Gooise Meren", "Almere", "Huizen", "Blaricum", "Laren", "Hilversum", "Baarn", "Wijdemeren", "Amstelveen", "Ouder-Amstel", "Aalsmeer", "Uithoorn",
                           "Heemstede", "Eemnes", "Bunschoten", "Amersfoort", "Soest", "Zeist", "De Bilt", "Stichtse Vecht", "De Ronde Venen", "Woerden", "Bunnik", "Houten", "Nieuwegein", "IJsselstein",
                           "Oudewater", "Montfoort", "Lopik", "Vijfherenlanden", "Nieuwkoop", "Kaag en Braassem", "Teylingen", "Lisse", "Hillegom", "Noordwijk", "Katwijk", "Oegstgeest",
                           "Leiden", "Leiderdorp", "Alphen aan den Rijn", "Wassenaar", "Voorschoten", "Zoeterwoude", "Bodegraven-Reeuwijk", "Waddinxveen", "Gouda", "Krimpenerwaard", "Molenlanden",
                           "Gorinchem", "Hardinxveld-Giessendam", "Sliedrecht", "Dordrecht", "Papendrecht", "Hoeksche Waard", "Nissewaard", "Voorne aan Zee", "Alblasserdam", "Ridderkerk", "Hendrik-Ido-Ambacht",
                           "Zwijndrecht", "Barendrecht", "Albrandswaard", "Krimpen", "Krimpen aan den IJssel", "Capelle", "Capelle aan den IJssel", "Zuidplas", "Lansingerland", "Zoetermeer", "Leidschendam-Voorburg",
                           "'s-Gravenhage", "Rijswijk", "Pijnacker-Nootdorp", "Delft", "Midden-delfland", "Westland", "Maassluis", "Vlaardingen", "Schiedam"}

randstad_municipalities = {i.lower() for i in randstad_municipalities}

# Find which zip codes fall in Randstad-municipalities
is_randstad_municipality = zip_to_mun["Gemeente name"].str.lower().isin(randstad_municipalities)
randstad_zips = set(zip_to_mun.loc[is_randstad_municipality, "PC4"].unique())

# Add True/False column: the first four characters of the candidate's postal code are compared
# with the Randstad PC4 codes
candidate_locs["in_randstad"] = candidate_locs["from_post_code"].apply(lambda x: int(x[:4])).isin(randstad_zips)

# Requests without a known postal code get -1, which is never a Randstad code, so they count as False
merged_df["in_randstad"] = merged_df["client_mgmcompany_cacheddetails_address_postalcode"].apply(lambda x: int(x) if not pd.isna(x) else -1).isin(randstad_zips)

# Convert to dictionary for easy look-up: "c<candidate_id>" -> in_randstad
ran_dict_candidates = {
    "c" + str(candidate_id): in_randstad
    for candidate_id, in_randstad in zip(candidate_locs["candidate_id"], candidate_locs["in_randstad"])
}

# Convert to dictionary for easy look-up: "r<request_mondriaan_number>" -> in_randstad
ran_dict_requests = {
    "r" + str(request_number): in_randstad
    for request_number, in_randstad in zip(merged_df["request_mondriaan_number"], merged_df["in_randstad"])
}

In [None]:
# Fraction of requests whose company postal code lies in the Randstad
merged_df["in_randstad"].mean()

In [None]:
# Fraction of candidate location rows whose postal code lies in the Randstad
candidate_locs["in_randstad"].mean()

In [None]:
def _make_text_embedder(model):
    """Return a function that maps a text to a 2-D embedding usable by cosine_similarity.

    Handles the non-torch baselines of eval_location_fairness: a TfidfVectorizer
    (``transform``), a Doc2Vec model (``infer_vector`` on the pre-processed text), and
    anything else -- in practice the string "random" -- for which a fresh 32-dimensional
    random vector is drawn on every call and the text is ignored.
    """
    model_name = model.__class__.__name__

    if model_name == "TfidfVectorizer":
        return lambda text: model.transform([text])
    if model_name == "Doc2Vec":
        return lambda text: [model.infer_vector(gensim.utils.simple_preprocess(text))]

    return lambda text: [np.random.random(32)]


def eval_location_fairness(model, test_set, ran_dict_candidates, ran_dict_requests, R_c, k=10):
    """Evaluate location-based fairness (Randstad vs. non-Randstad) of a ranking model.

    For every batch (one candidate with its requests) the nDCG@k is computed and stored under
    "urban" or "rural" depending on the candidate's location, and the share of non-Randstad
    requests among the k highest-scored requests is recorded.

    Args:
        model: One of the supported models, selected by class name ("e5_ranker",
            "conSultantBERT", "baselineGNNModel", "OKRA", "TfidfVectorizer", "Doc2Vec"),
            or the string "random".
        test_set: Batches for the non-torch baselines. Each batch is a sequence whose first
            element is (candidate, cv_text) and whose remaining elements are
            (request, vacancy_text, label). Ignored for torch models, whose dataloader is
            read from ./dataloaders.
        ran_dict_candidates: Mapping from candidate id to an in-Randstad flag. Candidates
            that are missing are treated as living outside the Randstad.
        ran_dict_requests: Mapping from request id to an in-Randstad flag. Requests that
            are missing are treated as located outside the Randstad.
        R_c: Reference share that is subtracted from the mean rural recommendation share.
        k: Cut-off used for nDCG and for the recommendation list (default 10).

    Returns:
        dict with the keys "Urban ndcg" and "Rural ndcg" (mean, std, count),
        "Performance disparity" (rural mean minus urban mean) and
        "Disparate visibility" (mean rural share minus R_c, std, count).

    Raises:
        Exception: If the model is not one of the supported kinds.
    """
    model_name = model.__class__.__name__
    is_random = isinstance(model, str) and model == "random"

    print(model if is_random else model_name)

    llm_models = ("e5_ranker", "conSultantBERT")
    loader_paths = {
        "e5_ranker": "./dataloaders/e5_testloader.pth",
        "conSultantBERT": "./dataloaders/bert_testloader.pth",
        "baselineGNNModel": "./dataloaders/graph_testloader.pth",
        "OKRA": "./dataloaders/graph_testloader.pth",
    }

    if model_name in loader_paths:
        # NOTE(review): torch.load unpickles the stored loader object -- only load trusted files
        dataloader = torch.load(loader_paths[model_name])
        torch_model = True

        # torch.no_grad() only disables gradient tracking; eval() is what switches layers such
        # as dropout and batch norm to inference behaviour
        model.eval()
    elif is_random or model_name in ("TfidfVectorizer", "Doc2Vec"):
        dataloader = test_set
        torch_model = False
        embed = _make_text_embedder(model)
    else:
        raise Exception("Invalid model provided")

    ndcg_urban = []
    ndcg_rural = []

    rec_share_rural = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if torch_model:
                # The LLM and graph models have slightly different datastructures
                if model_name in llm_models:
                    (candidate, requests), batch_data, batch_labels = batch
                else:
                    batch_data = batch.to(device)
                    batch_labels = batch.y
                    candidate = batch_data.tups[0][0][0]
                    requests = [pair[1] for pair in batch_data.tups[0]]

                # Progress is reported against the loader that is actually being iterated
                print(f"Batch: {batch_idx + 1}/{len(dataloader)}", end="\r")

                # Make predictions
                y_pred = model(batch_data)

                # Calculate nDCG score of current batch
                score = ndcg_score(batch_labels.detach().cpu().unsqueeze(0),
                                   y_pred.unsqueeze(0).detach().cpu(), k=k)

                request_scores = y_pred.detach().cpu().numpy()

            else:
                # First entry holds the candidate and its CV text, the rest are the requests
                candidate = batch[0][0]
                cv_emb = embed(batch[0][1])

                ground_truth = []
                request_scores = []
                requests = []

                for request, vacancy, label in batch[1:]:
                    req_emb = embed(vacancy)
                    ground_truth.append(label)
                    request_scores.append(cosine_similarity(cv_emb, req_emb)[0][0])
                    requests.append(request)

                ground_truth = np.array(ground_truth)
                request_scores = np.array(request_scores)

                score = ndcg_score([ground_truth], [request_scores], k=k)

            # We assume candidates for whom we do not have data live outside of the Randstad
            if ran_dict_candidates.get(candidate, False):
                ndcg_urban.append(score)
            else:
                ndcg_rural.append(score)

            # Calculate share of recommendations for items of the in-group
            request_flags = [ran_dict_requests.get(req, False) for req in requests]
            ranked_requests = sorted(zip(requests, request_flags, request_scores), key=lambda x: -x[2])
            top_k_requests = ranked_requests[:k]

            # Rural means NOT in_randstad
            rural_in_top_k = sum(not in_randstad for _, in_randstad, _ in top_k_requests)

            # Number of items in the protected group / number of recommendations. The list can
            # hold fewer than k items, so divide by its actual length.
            rec_share_rural.append(rural_in_top_k / len(top_k_requests))

    return {"Urban ndcg" : (np.mean(ndcg_urban), np.std(ndcg_urban), len(ndcg_urban)),
            "Rural ndcg": (np.mean(ndcg_rural), np.std(ndcg_rural), len(ndcg_rural)),
            "Performance disparity": np.mean(ndcg_rural) - np.mean(ndcg_urban),
            "Disparate visibility" : (np.mean(rec_share_rural) - R_c, np.std(rec_share_rural), len(rec_share_rural))}

In [None]:
# Load the stored graph test loader; the model loop below draws a sample batch from it
# NOTE(review): torch.load unpickles the stored object -- only load files from a trusted source
testloader = torch.load("./dataloaders/graph_testloader.pth")

In [None]:
# All the different node types (shared by both graph models)
typings = ["candidate", "request", "function_name", "isco_code",
           "education", "language", "license", "skill", "company_name",
           "function_id", "isco_level", "workgroup", "klass", "literal"]

# Data.metadata() is needed to initialize the heterodata
data = next(iter(testloader))

# R_c: share of requests that are NOT located in the Randstad (identical for every model)
rural_request_share = 1 - np.mean(merged_df["in_randstad"])

# Models with trained weights stored under ./trained_models/<ClassName>.pt
neural_model_types = ["e5_ranker", "conSultantBERT", "baselineGNNModel", "OKRA"]

for model_type in ["random", "tf-idf", "d2v"] + neural_model_types:
    if model_type == "random":
        model = "random"
    elif model_type == "tf-idf":
        model = tf_idf_model
    elif model_type == "d2v":
        model = d2v_model
    elif model_type == "e5_ranker":
        model = e5_ranker(pooling=e5_best_config["pooling"]).to(device)
    elif model_type == "conSultantBERT":
        model = conSultantBERT(pooling=bert_best_config["pooling"]).to(device)
    elif model_type == "baselineGNNModel":
        model = baselineGNNModel(data,
                                 typings,
                                 text_embedding_size=gnn_best_config["text_embedding_size"],
                                 embedding_size=gnn_best_config["embedding_size"]).to(device)
    elif model_type == "OKRA":
        model = OKRA(data,
                     typings,
                     text_embedding_size=okra_best_config["text_embedding_size"],
                     text_pooling=okra_best_config["text_pooling"],
                     embedding_size=okra_best_config["embedding_size"],
                     pooling_method=okra_best_config["pooling_method"],
                     heads=4).to(device)

    if model_type in neural_model_types:
        # map_location lets a checkpoint saved on another device (e.g. a GPU) load onto `device`
        checkpoint_path = f"./trained_models/{model.__class__.__name__}.pt"
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

        # Inference only: switch off training-time behaviour such as dropout
        model.eval()

    fairness_scores = eval_location_fairness(model, test_set, ran_dict_candidates, ran_dict_requests, R_c=rural_request_share)
    print(fairness_scores)
    print()