# Building and Training MedGraphTrans - A Heterogeneous Graph Transformer-based model used to provide context for LLMs in the medical domain

## Setup

In [56]:
import torch
import time
import copy
import gc
import os
import pickle

import numpy as np
import pandas as pd
import torch.nn.functional as F

from typing import List
from datasets import load_dataset
from torch_geometric.nn import Linear
from torch_geometric.data import HeteroData
from torch_geometric.nn.conv import HGTConv
from sklearn.metrics import roc_auc_score
from dataclasses import dataclass

from config import ROOT_DIR
from src.utils import node_types, metadata
from src.medical_hgt.dataset_builder import MedicalQADatasetBuilder
from src.medical_hgt.ml_utils import compute_llm_confidence_diff, query_chatbot, find_subgraph_bfs, find_most_relevant_nodes

## HGT Model

In [57]:
class HGT(torch.nn.Module):
    """Heterogeneous Graph Transformer encoder over the project's node types.

    Every node type in the module-level ``node_types`` gets its own input
    projection to ``hidden_channels`` features, followed by ``num_layers``
    ``HGTConv`` layers built from the module-level ``metadata``.

    Args:
    hidden_channels (int): Feature size after the input projection and of every HGTConv layer.
    out_channels (int): Output size of ``self.lin``. Note that ``forward`` does not apply
        ``self.lin``, so this value does not affect the returned embeddings.
    num_heads (int): Passed to each ``HGTConv`` as its heads argument.
    num_layers (int): Number of stacked ``HGTConv`` layers.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # One projection per node type; in_channels=-1 leaves the input size
        # to be resolved by the Linear layer itself.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList(
            HGTConv(hidden_channels, hidden_channels, metadata, num_heads, group='sum')
            for _ in range(num_layers)
        )

        # Created for completeness but not used by forward().
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, data):
        """Return a dict mapping each node type in ``data.x_dict`` to its embeddings."""
        x_dict = {}
        for node_type, features in data.x_dict.items():
            # relu_ works in place on the freshly projected tensor
            x_dict[node_type] = self.lin_dict[node_type](features).relu_()

        edge_index_dict = data.edge_index_dict
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        return x_dict

In [58]:
class Decoder(torch.nn.Module):
    """Dot-product edge decoder for question -> answer link prediction."""

    @staticmethod
    def _edge_scores(x_question: torch.Tensor, x_answer: torch.Tensor, edge_label_index: torch.Tensor) -> torch.Tensor:
        """Return one dot-product score per (question, answer) pair listed in ``edge_label_index``."""
        # Convert node embeddings to edge-level representations:
        edge_feat_question = x_question[edge_label_index[0]]
        edge_feat_answer = x_answer[edge_label_index[1]]

        scores = (edge_feat_question * edge_feat_answer).sum(dim=-1)

        # Keep at least one dimension so callers can concatenate results across batches.
        if scores.dim() == 0:
            scores = scores.view(1)

        return scores

    def forward(self, x_question: torch.Tensor, x_answer: torch.Tensor, pos_edge_label_index: torch.Tensor, neg_edge_label_index: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """
        Our decoder applies the dot-product between source and destination node embeddings to derive edge-level predictions.

        The scores are returned as raw logits (no sigmoid): compute_link_prediction_loss passes them to
        ``binary_cross_entropy_with_logits``, which applies the sigmoid itself, so squashing them here would
        apply it twice. ROC AUC only depends on the ranking of the scores, so it is unaffected. Apply
        ``torch.sigmoid`` to the outputs if probabilities are needed.

        Args:
        x_question (torch.Tensor): Embeddings of 'question' nodes.
        x_answer (torch.Tensor): Embeddings of 'answer' nodes.
        pos_edge_label_index (torch.Tensor): Indices of positive edges (edges that exist).
        neg_edge_label_index (torch.Tensor): Indices of negative edges (edges that do not exist).

        Returns:
        tuple: A tuple containing two tensors (pos_pred, neg_pred) holding the logits
               for positive and negative edges, respectively.
        """
        pos_pred = self._edge_scores(x_question, x_answer, pos_edge_label_index)
        neg_pred = self._edge_scores(x_question, x_answer, neg_edge_label_index)

        return pos_pred, neg_pred

In [59]:
class MedicalHGT(torch.nn.Module):
    """HGT encoder followed by the dot-product Decoder for question/answer edges.

    Args:
    hidden_channels (int): Hidden size handed to the HGT encoder.
    """

    def __init__(self, hidden_channels=64):
        super().__init__()
        self.hgt = HGT(hidden_channels=hidden_channels, out_channels=64, num_heads=2, num_layers=1)
        self.decoder = Decoder()
        # Debugging aid: gradients of the embeddings from the latest forward pass, keyed by node type.
        self.grads = {}

    def forward(self, batch_data: HeteroData) -> (torch.Tensor, torch.Tensor, dict):
        """
        Encode the batch and score its labelled question -> answer edges.

        Returns:
        tuple: (pos_pred, neg_pred, z_dict) where pos_pred / neg_pred are the decoder's scores for the
               'question_correct_answer' and 'question_wrong_answer' edge_label_index entries and
               z_dict maps node type to embeddings.
        """
        # Start from a clean slate so stale gradients from a previous pass are not mixed in.
        self.grads = {}
        z_dict = self.hgt(batch_data)

        for node_type, embeddings in z_dict.items():
            if not embeddings.requires_grad:
                # hooks can only be attached to tensors that take part in autograd
                continue
            embeddings.register_hook(self.save_grad(node_type))

        correct_edges = batch_data["question", "question_correct_answer", "answer"]
        wrong_edges = batch_data["question", "question_wrong_answer", "answer"]

        pos_pred, neg_pred = self.decoder(
            z_dict["question"],
            z_dict["answer"],
            correct_edges.edge_label_index,
            wrong_edges.edge_label_index,
        )

        return pos_pred, neg_pred, z_dict

    def save_grad(self, name):
        """Build a backward hook that stores the incoming gradient in ``self.grads[name]``."""

        def _store(grad):
            self.grads[name] = grad

        return _store

## LLM

In [60]:
class LLM(torch.nn.Module):
    """Scores knowledge-graph nodes by how much they change the chatbot's confidence in the correct answer.

    The module defines no parameters of its own; every query goes through ``query_chatbot``.
    """

    def forward(self, knowledge_nodes_dict, nx_graph_data, dataset_question_dict, correct_answer):
        """
        Query the chatbot once without context and once per knowledge node with that node as context.

        Args:
        knowledge_nodes_dict (dict): Maps node type to an iterable of node uids; each uid must support ``.item()``.
        nx_graph_data: Graph whose ``nodes[uid]['name']`` gives the node's name.
        dataset_question_dict (dict): The question prompt. It is not modified; the context is added to a copy.
        correct_answer (int): Index of the correct option (0-3), mapped to 'opa'..'opd'.

        Returns:
        dict: A dict of dicts in the form {node_type_0: {node_uid_0: conf_diff_0, node_uid_1: conf_diff_1...}, ...},
              where the inner keys are the uid objects exactly as they appear in ``knowledge_nodes_dict`` and each
              value is the result of ``compute_llm_confidence_diff``.
        """
        correct_answer_dict = {0: 'opa', 1: 'opb', 2: 'opc', 3: 'opd'}
        output_instructions = f'how confident are you that the correct answer is {correct_answer_dict[correct_answer]}? Return a float between 0 and 1.'
        confidence_without_context = query_chatbot(str(dataset_question_dict), output_instructions)

        confidence_diffs_dict = {}
        for node_type, nodes_uids in knowledge_nodes_dict.items():
            node_type_diffs = {}
            confidence_diffs_dict[node_type] = node_type_diffs

            for node_uid in nodes_uids:
                node_name = nx_graph_data.nodes[node_uid.item()]['name']

                # Work on a copy so the caller's prompt never picks up a 'context' entry.
                prompt_with_context = copy.copy(dataset_question_dict)
                prompt_with_context['context'] = f'The {node_type} {node_name}.'

                confidence_with_context = query_chatbot(str(prompt_with_context), output_instructions)
                node_type_diffs[node_uid] = compute_llm_confidence_diff(
                    float(confidence_without_context),
                    float(confidence_with_context),
                )

        return confidence_diffs_dict

## Train Model

### Define helper classes and methods for logging and tracking the results

In [61]:
@dataclass(frozen=True)
class EpochResult:
    """Immutable record of the timing, loss and scores of a single training epoch."""

    # "index" of the epoch
    # (this is also discernible from the position in ModelResult.epoch_results)
    epoch_num: int

    # Unix timestamps (seconds) when the epoch started/finished training, but not
    # counting evaluation
    train_start_time: int
    train_end_time: int

    # mean train loss taken across all batches
    mean_train_loss: float

    # score on the training/validation set at the end of this epoch; despite the
    # "acc" naming, train_model in this file stores ROC AUC values here
    train_acc: float
    val_acc: float

In [62]:
@dataclass(frozen=True)
class ModelResult:
    """Immutable summary of one complete training run."""

    # Unix timestamp for when the model started training
    start_time: int
    # Unix timestamp for when the model completely finished (including evaluation
    # on the test set)
    end_time: int

    # one EpochResult per epoch, in training order
    epoch_results: list

    # model state for reloading
    state_dict: dict

    # final score on the full test set (after all epochs)
    test_acc: float

    def get_total_train_time_sec(self):
        """
        Total seconds spent training across all epochs, excluding evaluation:
        the sum of (train_end_time - train_start_time) over every epoch result.
        """
        total_seconds = 0
        for epoch_result in self.epoch_results:
            total_seconds += epoch_result.train_end_time - epoch_result.train_start_time
        return total_seconds

    def get_total_train_time_min(self):
        """Total training time in whole minutes (floor of get_total_train_time_sec / 60)."""
        seconds_per_minute = 60
        return self.get_total_train_time_sec() // seconds_per_minute

In [63]:
def get_time():
    """Return the current Unix (epoch) timestamp rounded to whole seconds."""
    now = time.time()
    return round(now)

### Loss Functions

In [64]:
def compute_link_prediction_loss(pos_preds: torch.Tensor, neg_preds: torch.Tensor, pos_labels: torch.Tensor, neg_labels: torch.Tensor) -> torch.Tensor:
    """
    Binary cross-entropy over positive and negative link predictions.

    Both prediction tensors go through ``binary_cross_entropy_with_logits``, so they must be raw
    logits rather than probabilities. Labels are flattened and cast to float before use.

    Args:
    pos_preds (torch.Tensor): Logits for positive links.
    neg_preds (torch.Tensor): Logits for negative links.
    pos_labels (torch.Tensor): Ground truth labels for positive links.
    neg_labels (torch.Tensor): Ground truth labels for negative links.

    Returns:
    torch.Tensor: The positive loss plus one third of the negative loss.
    """

    def _bce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        targets = labels.view(-1).float()
        return F.binary_cross_entropy_with_logits(logits, targets)

    positive_term = _bce(pos_preds, pos_labels)
    negative_term = _bce(neg_preds, neg_labels)

    # the negative term is down-weighted by a factor of three
    return positive_term + (negative_term / 3)

In [65]:
def compute_llm_relevancy_loss(batch, z_dict, gradients_per_questions_list):
    """
    Distance-based loss that pulls relevant nodes towards their question and pushes irrelevant ones away.

    For every scored node the L2 distance between its embedding and the question embedding is weighted by
    the node's relevance score:
      * relevance > 0: ``relevance * distance`` (penalise being far from the question);
      * relevance < 0: ``-relevance / (distance + 1e-6)`` (penalise being close to the question);
      * relevance == 0: no contribution.

    Args:
    batch: Batch where ``batch[node_type].node_uid`` holds the uids of the nodes in ``z_dict[node_type]``.
    z_dict (dict): Maps node type to that type's embedding tensor.
    gradients_per_questions_list (list): Tuples ``(question_embedding, scores)`` where ``scores`` has the form
        {node_type: {node_uid: relevance, ...}, ...}.

    Returns:
    The summed weighted loss divided by the number of scored nodes (nodes with zero relevance still count
    towards that number). Returns ``0.0`` when no node was scored at all, instead of dividing by zero.

    Raises:
    IndexError: If a scored node uid does not occur in ``batch[node_type].node_uid``.
    """
    loss = 0.0
    num_nodes = 0

    for question_embedding, grads_dict in gradients_per_questions_list:
        for node_type, grad_info_dict in grads_dict.items():
            if not grad_info_dict:
                continue

            batch_node_uids = batch[node_type].node_uid
            node_embeddings = z_dict[node_type]

            for node_uid, relevance in grad_info_dict.items():
                # position of this node inside the batch
                node_index = torch.where(batch_node_uids == node_uid)[0][0]
                current_node_embedding = torch.index_select(node_embeddings, 0, node_index)

                # Distance between the node embedding and the question (central node) embedding
                distance = torch.norm(current_node_embedding - question_embedding, p=2)

                if relevance > 0:
                    loss += relevance * distance
                elif relevance < 0:
                    # Invert the distance; the small constant avoids division by zero
                    loss += -relevance * (1 / (distance + 1e-6))

            num_nodes += len(grad_info_dict)

    if num_nodes == 0:
        # Nothing was scored (e.g. every question had an empty subgraph): contribute no loss.
        return loss

    return loss / num_nodes

### Training and Evaluation

In [66]:
def evaluate_model(medical_hgt, split_loaders, split_name, device, prime_kg, frac=1.0):
    """
    Evaluate the model on one split and collect the most relevant knowledge nodes per question.

    Args:
    medical_hgt (torch.nn.Module): The model to evaluate. It is switched to eval mode for the
        evaluation and to train mode before returning.
    split_loaders (dict): A dictionary containing the data loaders for different splits.
    split_name (str): The name of the split to evaluate (e.g., 'val', 'test').
    device (torch.device): The device to run the model on.
    prime_kg: Knowledge graph forwarded to find_most_relevant_nodes.
    frac (float): Fraction of the split's batches to use for evaluation.

    Returns:
    tuple: (roc_auc, knowledge_nodes_per_question_dict) where roc_auc is the ROC AUC score over all
           evaluated positive and negative edges, and the dict maps each evaluated question's
           ``node_uid`` entry to the result of find_most_relevant_nodes for that question.
    """
    medical_hgt.eval()

    pos_y_true_tensors = []
    neg_y_true_tensors = []
    pos_y_pred_tensors = []
    neg_y_pred_tensors = []

    # Collected across all evaluated batches so the caller sees every question, not just the last batch.
    knowledge_nodes_per_question_dict = {}

    loader = split_loaders[split_name]
    num_batches = round(frac * len(loader))

    for i, batch in enumerate(loader):
        batch_num = i + 1
        # carriage return so the progress line overwrites itself
        print(f'\r{split_name} batch {batch_num} / {num_batches}', end='')

        batch = batch.to(device)

        with torch.no_grad():
            pos_pred, neg_pred, z_dict = medical_hgt(batch)

            pos_eval_y = batch["question", "question_correct_answer", "answer"].edge_label.squeeze()
            neg_eval_y = batch["question", "question_wrong_answer", "answer"].edge_label.squeeze()

            # squeeze() collapses a single label to 0-dim; restore one dimension for torch.cat
            if pos_eval_y.dim() == 0:
                pos_eval_y = pos_eval_y.view(1)

            if neg_eval_y.dim() == 0:
                neg_eval_y = neg_eval_y.view(1)

            # Move to the CPU right away: Tensor.numpy() below only works on CPU tensors.
            pos_y_pred_tensors.append(pos_pred.detach().cpu())
            neg_y_pred_tensors.append(neg_pred.detach().cpu())
            pos_y_true_tensors.append(pos_eval_y.detach().cpu())
            neg_y_true_tensors.append(neg_eval_y.detach().cpu())

            for node_index, question_node_representation in enumerate(z_dict['question']):
                subgraph_nodes_uid_dict = find_subgraph_bfs(batch, node_index, 'question')
                question_node_uid = batch['question'].node_uid[node_index]
                most_relevant_nodes = find_most_relevant_nodes(batch, z_dict, question_node_representation, subgraph_nodes_uid_dict, prime_kg)
                knowledge_nodes_per_question_dict[question_node_uid] = most_relevant_nodes

        if batch_num >= num_batches:
            break

    medical_hgt.train()

    pos_pred = torch.cat(pos_y_pred_tensors, dim=0).numpy()
    neg_pred = torch.cat(neg_y_pred_tensors, dim=0).numpy()
    pos_true = torch.cat(pos_y_true_tensors, dim=0).numpy()
    neg_true = torch.cat(neg_y_true_tensors, dim=0).numpy()

    pred = np.concatenate([pos_pred, neg_pred])
    true = np.concatenate([pos_true, neg_true])

    return roc_auc_score(true, pred), knowledge_nodes_per_question_dict

In [67]:
def train_model(medical_hgt, llm, split_loaders, device, file_name, qa_dataset, prime_kg, num_epochs=30, lr=0.001,
                max_batches_per_epoch=1, max_llm_questions_per_batch=1):
    """
    Train ``medical_hgt`` with a link-prediction loss plus an LLM-derived relevancy loss, then test and save it.

    Args:
    medical_hgt (torch.nn.Module): Model returning (pos_pred, neg_pred, z_dict) for a batch.
    llm (torch.nn.Module): Module called as ``llm(subgraph_nodes, prime_kg, prompt_dict, correct_answer)``.
    split_loaders (dict): Data loaders keyed by 'train', 'val' and 'test'.
    device (torch.device): The device to train on.
    file_name: Where the resulting ModelResult is written with ``torch.save``.
    qa_dataset (pd.DataFrame): QA rows addressed by position via the question nodes' ``node_uid``;
        expected to contain the columns 'id', 'cop' and 'exp'.
    prime_kg: Knowledge graph forwarded to the LLM and to evaluate_model.
    num_epochs (int): Number of training epochs.
    lr (float): Adam learning rate.
    max_batches_per_epoch (int | None): Stop each epoch after this many training batches; ``None`` uses the
        whole loader. The default of 1 keeps the previously hard-coded single-batch epochs. Values below 1
        behave like 1.
    max_llm_questions_per_batch (int | None): Number of questions per batch for which LLM feedback is
        requested; ``None`` queries every question. The default of 1 keeps the previously hard-coded
        behaviour. Values below 1 behave like 1.

    Returns:
    ModelResult: Timings, per-epoch results, the final state dict and the test ROC AUC.
    """
    medical_hgt = medical_hgt.to(device)
    llm = llm.to(device)

    medical_hgt.train()
    llm.train()

    opt = torch.optim.Adam(medical_hgt.parameters(), lr=lr)

    start_time = get_time()
    print(f'start time: {start_time}; will save results to {file_name}')

    # Create the output folder up front so the final torch.save cannot fail on a missing
    # directory after all the training work is done.
    if isinstance(file_name, (str, os.PathLike)):
        output_dir = os.path.dirname(file_name)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    train_loader = split_loaders['train']

    epoch_results = []

    for epoch_num in range(1, num_epochs + 1):
        train_start_time = get_time()

        train_losses = []
        pos_y_pred_tensors = []
        neg_y_pred_tensors = []
        pos_y_true_tensors = []
        neg_y_true_tensors = []

        num_batches = len(train_loader)

        for i, batch in enumerate(train_loader):
            batch_num = i + 1

            # this is a carriage return trick for overwriting past lines
            print(f'\rEpoch {epoch_num}: batch {batch_num} / {num_batches}', end='')

            opt.zero_grad()
            batch = batch.to(device)

            # internally, the medical_hgt is applied using all the batch's edges (i.e.,
            # batch.edge_index) but only outputs predictions on edges to be labeled
            # (i.e., batch.edge_label_index).
            pos_train_pred, neg_train_pred, z_dict = medical_hgt(batch)

            pos_train_y = batch["question", "question_correct_answer", "answer"].edge_label.squeeze()
            neg_train_y = batch["question", "question_wrong_answer", "answer"].edge_label.squeeze()

            # squeeze() collapses a single label to 0-dim; restore one dimension
            if pos_train_y.dim() == 0:
                pos_train_y = pos_train_y.view(1)

            if neg_train_y.dim() == 0:
                neg_train_y = neg_train_y.view(1)

            confidence_diffs_per_question = []  # a list of tuples (question_embeddings, conf_diffs_dict_per_question)
            # compute the llm's feedback per question in the batch
            for node_index, question_node_representation in enumerate(z_dict['question']):
                qa_index = batch['question'].node_uid[node_index].item()
                subgraph_nodes_uid_dict = find_subgraph_bfs(batch, node_index, 'question')
                qa_row = qa_dataset.iloc[qa_index]
                prompt_dict = dict(qa_row.drop(['id', 'cop', 'exp']))
                correct_answer = qa_row['cop']
                current_confidence_diffs_dict = llm(subgraph_nodes_uid_dict, prime_kg, prompt_dict, correct_answer)
                confidence_diffs_per_question.append((question_node_representation, current_confidence_diffs_dict))

                # every question costs several chatbot queries, so the number is capped
                if max_llm_questions_per_batch is not None and node_index + 1 >= max_llm_questions_per_batch:
                    break

            link_prediction_loss = compute_link_prediction_loss(pos_train_pred, neg_train_pred, pos_train_y, neg_train_y)
            llm_relevancy_loss = compute_llm_relevancy_loss(batch, z_dict, confidence_diffs_per_question)

            total_loss = (link_prediction_loss + (llm_relevancy_loss * 0.01))
            total_loss.backward()

            opt.step()

            # Move to the CPU right away: Tensor.numpy() below only works on CPU tensors.
            pos_y_pred_tensors.append(pos_train_pred.detach().cpu())
            neg_y_pred_tensors.append(neg_train_pred.detach().cpu())
            pos_y_true_tensors.append(pos_train_y.detach().long().cpu())
            neg_y_true_tensors.append(neg_train_y.detach().long().cpu())

            train_losses.append(total_loss.detach().item())

            if max_batches_per_epoch is not None and batch_num >= max_batches_per_epoch:
                break

        train_end_time = get_time()

        pos_pred = torch.cat(pos_y_pred_tensors, dim=0).numpy()
        neg_pred = torch.cat(neg_y_pred_tensors, dim=0).numpy()
        pos_true = torch.cat(pos_y_true_tensors, dim=0).numpy()
        neg_true = torch.cat(neg_y_true_tensors, dim=0).numpy()

        pred = np.concatenate([pos_pred, neg_pred])
        true = np.concatenate([pos_true, neg_true])

        # the training ROC AUC is computed using all the predictions (and ground
        # truth labels) made during the entire epoch, across all batches. Note that
        # this is arguably a bit inconsistent with validation below since it doesn't
        # give the medical_hgt a "second try" for earlier batches, for which it couldn't
        # have yet applied anything it learned in later batches.
        train_acc = roc_auc_score(true, pred)

        # The validation ROC AUC is computed by running through the validation set
        # at the end of every epoch.
        val_acc, _ = evaluate_model(medical_hgt, split_loaders, 'val', device, prime_kg=prime_kg)

        epoch_result = EpochResult(
            epoch_num=epoch_num,
            train_start_time=train_start_time,
            train_end_time=train_end_time,
            mean_train_loss=round(np.mean(train_losses), 4),
            train_acc=round(train_acc, 4),
            val_acc=round(val_acc, 4)
        )

        epoch_results.append(epoch_result)
        print(f'\r{epoch_result}')

    # snapshot the weights before the final evaluation
    state_dict = copy.deepcopy(medical_hgt.state_dict())
    test_acc, _ = evaluate_model(medical_hgt, split_loaders, 'test', device, prime_kg=prime_kg)

    # evaluate_model switches back to train mode; leave the finished model in eval mode
    medical_hgt.eval()

    end_time = get_time()
    medical_hgt_result = ModelResult(start_time, end_time, epoch_results, state_dict, round(test_acc, 4))
    torch.save(medical_hgt_result, file_name)

    train_time_min = medical_hgt_result.get_total_train_time_min()
    print(f'\rTest Accuracy: {test_acc:.3f}; Total Train Time: {train_time_min} min')

    return medical_hgt_result

### Set up training experiment

In [76]:
def run_experiment(data_loader_params, device, runs=2, train_data_path=None, val_data_path=None, test_data_path=None):
    """
    Runs a multi-trial experiment using the given DataLoaderParams.

    Args:
    data_loader_params (DataLoaderParams): Split ratios, batch size, epoch count and result file naming.
    device (torch.device): The device to train on.
    runs (int): Number of independent training runs; each trains a fresh MedicalHGT.
    train_data_path, val_data_path, test_data_path (str | None): Pickled, preprocessed loaders for the
        three splits. When ``train_data_path`` is None the loaders are built with MedicalQADatasetBuilder
        instead; otherwise all three paths are read.

    Returns:
    list: The ModelResult returned by train_model for each run, in run order.
    """

    def _load_pickle(path):
        # NOTE(security): unpickling executes arbitrary code; only load files from a trusted source.
        with open(path, 'rb') as pickle_file:
            return pickle.load(pickle_file)

    if train_data_path is None:
        dataset_builder = MedicalQADatasetBuilder(
            ['datasets/graph_dataset_30_11_23/train'],
            val_ratio=data_loader_params.val_ratio,
            test_ratio=data_loader_params.test_ratio,
            disjoint_train_edges_ratio=data_loader_params.disjoint_train_edges_ratio,
            negative_sampling_ratio=data_loader_params.negative_sampling_ratio,
            # use the configured batch size so it matches the value encoded in the result file name
            batch_size=data_loader_params.batch_size)

        loaders = {'train': dataset_builder.train_mini_batches, 'val': dataset_builder.val_mini_batches, 'test': dataset_builder.test_mini_batches}
        qa_dataset = dataset_builder.qa_dataset
    else:
        loaders = {
            'train': _load_pickle(train_data_path),
            'val': _load_pickle(val_data_path),
            'test': _load_pickle(test_data_path),
        }
        qa_dataset = load_dataset("medmcqa")
        qa_dataset = pd.DataFrame(qa_dataset['train'])

    prime_kg = _load_pickle(os.path.join(ROOT_DIR, 'datasets/primeKG_nx_medium.pickle'))

    results = []
    for i in range(runs):
        file_name = data_loader_params.get_file_name() + f'_run{i + 1}.pth'
        medical_hgt = MedicalHGT(hidden_channels=64)
        llm = LLM()
        results.append(
            train_model(medical_hgt, llm, loaders, device, file_name, num_epochs=data_loader_params.num_epochs, qa_dataset=qa_dataset, prime_kg=prime_kg)
        )

    return results


In [77]:
@dataclass(frozen=True)
class DataLoaderParams:
    """Immutable bundle of the data-loading and training settings for one experiment.

    run_experiment reads the ratios when building the dataset, ``num_epochs`` when
    training, and ``get_file_name`` to decide where results are stored.
    """
    val_ratio: float
    test_ratio: float
    disjoint_train_edges_ratio: float
    negative_sampling_ratio: int
    batch_size: int
    num_epochs: int

    def get_file_name(self):
        """Return the result-file path prefix (no extension) under ``ROOT_DIR/experiments``.

        NOTE(review): negative_sampling_ratio is encoded twice while
        disjoint_train_edges_ratio and num_epochs are not encoded at all, so
        configurations that differ only in those fields share a file name.
        Confirm whether that is intended.
        """
        experiments_dir = os.path.join(ROOT_DIR, 'experiments')
        return f'{experiments_dir}/dataloader-{self.negative_sampling_ratio}-{self.val_ratio}-{self.test_ratio}-{self.negative_sampling_ratio}-{self.batch_size}'



### Prepare preprocessed dataloaders, QA Dataset (MedMCQA) and the Knowledge Graph (PrimeKG)

In [78]:
# Pickled, preprocessed loaders for the three splits, resolved against the project root.
train_data_path, val_data_path, test_data_path = (
    os.path.join(ROOT_DIR, 'datasets/train_data_01_12_23.pickle'),
    os.path.join(ROOT_DIR, 'datasets/val_data_01_12_23.pickle'),
    os.path.join(ROOT_DIR, 'datasets/test_data_01_12_23.pickle'),
)

### Run training experiment model

In [80]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [81]:
# One entry per experiment configuration to run.
data_loader_params_list = [
    # baseline
    DataLoaderParams(
        val_ratio=0.1,
        test_ratio=0.1,
        disjoint_train_edges_ratio=0.9,
        negative_sampling_ratio=3,
        batch_size=64,
        num_epochs=100,
    ),
]

In [82]:
# Run every configuration against the preprocessed, pickled splits.
for data_loader_params in data_loader_params_list:
    run_experiment(
        data_loader_params,
        device,
        train_data_path=train_data_path,
        val_data_path=val_data_path,
        test_data_path=test_data_path,
    )


Found cached dataset medmcqa (/Users/shiraben-david/.cache/huggingface/datasets/medmcqa/default/1.1.0/f2fdfa9ccfbf9d148c0639e6afe3379f3c7e95c4d52d5e68ec1156e5004bd880)


  0%|          | 0/3 [00:00<?, ?it/s]

start time: 1701616074; will save results to /Users/shiraben-david/Documents/TUB/Thesis/MedTransNet/experiments/dataloader-3-0.1-0.1-3-64_run1.pth
val batch 32 / 86/ 77

KeyboardInterrupt: 