# Building and Training MedGraphTrans - A Heterogeneous Graph Transformer-based model used to provide context for LLMs in the medical domain


## Setup

In [None]:
# Mount Google Drive so the project folder (code + datasets) is reachable from the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

# Make the project folder the working directory; the `requirements.txt` installed by a
# later cell is referenced by a relative path.
%cd /content/drive/MyDrive/Thesis/MedTransNet

# Put the project root on the import path so `config` and `src.*` can be imported below.
# NOTE(review): re-running this cell appends the same path again (harmless, but it accumulates).
import sys
sys.path.append('/content/drive/MyDrive/Thesis/MedTransNet')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Thesis/MedTransNet


In [None]:
# Install the project's pinned requirements first, then override `transformers` with the
# GitHub build, then add the quantization / acceleration extras.
# NOTE(review): the captured output shows requirements.txt pinning transformers~=4.30.2 and
# uninstalling 4.37.0.dev0, which the next line then installs again from GitHub. The order of
# these lines therefore matters; consider reconciling the pin instead of installing twice.
# NOTE(review): none of the packages on the last two lines are version-pinned, so a fresh
# runtime may resolve different versions than the ones used for the recorded results.
!pip install -r requirements.txt
!pip -q install git+https://github.com/huggingface/transformers # need to install from github
!pip -q install bitsandbytes accelerate xformers einops

Collecting transformers~=4.30.2 (from -r requirements.txt (line 19))
  Using cached transformers-4.30.2-py3-none-any.whl (7.2 MB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers~=4.30.2->-r requirements.txt (line 19))
  Using cached tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.15.0
    Uninstalling tokenizers-0.15.0:
      Successfully uninstalled tokenizers-0.15.0
  Attempting uninstall: transformers
    Found existing installation: transformers 4.37.0.dev0
    Uninstalling transformers-4.37.0.dev0:
      Successfully uninstalled transformers-4.37.0.dev0
Successfully installed tokenizers-0.13.3 transformers-4.30.2
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building w

In [None]:
import torch
import copy
import os
import pickle

import numpy as np
import pandas as pd
import torch.nn.functional as F

from tqdm import tqdm
from typing import Tuple
from datasets import load_dataset
from torch_geometric.data import HeteroData
from torchmetrics.classification import BinaryPrecision
from torch.optim.lr_scheduler import _LRScheduler
from sklearn.metrics import roc_auc_score
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import logging

from config import ROOT_DIR
from src.medical_hgt.model import MedicalHGT
from src.medical_hgt.llm import LLM
from src.medical_hgt.ml_utils import find_most_relevant_nodes, EpochResult, ModelResult, get_time, compute_llm_relevancy_loss, compute_link_prediction_loss

# Set up logging
# INFO-level records from the root logger go to a log file under ROOT_DIR/datasets;
# `evaluate` uses `logging.info` to record questions the LLM only answers correctly with context.
# NOTE(review): `logging.basicConfig` is documented to do nothing when the root logger already
# has handlers - confirm the file is actually written in this runtime (or pass `force=True`).
log_file_path = os.path.join(ROOT_DIR, 'datasets', 'positive_llm_examples.log')
logging.basicConfig(filename=log_file_path, level=logging.INFO,
                    format='%(asctime)s %(levelname)s: %(message)s')

## Utils

In [None]:
@dataclass(frozen=True)
class ExperimentsParams:
    """Immutable bundle of the hyper-parameters that define one training run."""
    num_epochs: int  # number of training epochs
    lr: float        # optimizer learning rate
    channels: int    # forwarded to MedicalHGT as `channels`
    num_heads: int   # forwarded to MedicalHGT as `num_heads`
    num_layers: int  # forwarded to MedicalHGT as `num_layers`

    def get_file_name(self):
        """Return the extension-less results path for this run, derived from its parameters.

        The path lives under ``ROOT_DIR/experiments``; the trailing ``k-2_batch_norm``
        tag is a fixed suffix and is not derived from any field.
        """
        experiments_dir = os.path.join(ROOT_DIR, 'experiments')
        name_parts = [
            'experiment',
            f'{self.num_epochs}_epochs',
            f'{self.lr}_lr',
            f'{self.channels}_channels',
            f'{self.num_heads}_head',
            f'{self.num_layers}_layers_k',
            '2_batch_norm',
        ]
        return experiments_dir + '/' + '-'.join(name_parts)

## Load data

In [None]:
# Paths of the pre-built mini-batch pickles for each split.
# NOTE(review): the test file sits directly under `datasets/`, unlike the `train/` and
# `validation/` sub-folders used for the other two splits - confirm this is intentional.
train_data_path = os.path.join(ROOT_DIR, 'datasets/train/train_mini_batches_32_cpu.pickle')
val_data_path = os.path.join(ROOT_DIR, 'datasets/validation/val_mini_batches_32_cpu.pickle')
test_data_path = os.path.join(ROOT_DIR, 'datasets/test_mini_batches_32_cpu.pickle')

# Load each split inside a context manager so the file handle is closed deterministically.
# NOTE: unpickling can execute arbitrary code - only load files produced by this project.
with open(train_data_path, 'rb') as train_file:
    train_data = pickle.load(train_file)
with open(val_data_path, 'rb') as val_file:
    val_data = pickle.load(val_file)
with open(test_data_path, 'rb') as test_file:
    test_data = pickle.load(test_file)

# Split name -> batches; this dict is what `train` / `evaluate` receive as `split_loaders`.
loaders = {'train': train_data, 'val': val_data, 'test': test_data}

In [None]:
# Sanity check: number of entries in the unpickled validation split (presumably one per mini-batch).
len(val_data)

1485

In [None]:
# Fetch the MedMCQA question-answering dataset via `datasets.load_dataset`; the captured
# output below shows train / validation / test splits being generated.
qa_dataset = load_dataset("medmcqa")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/parquet/medmcqa-a004ab6a1cc08561/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/85.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.48M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/936k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/182822 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4183 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6150 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/parquet/medmcqa-a004ab6a1cc08561/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Knowledge-graph object later handed to `find_most_relevant_nodes` as `prime_kg`.
# Opened via a context manager so the file handle is closed deterministically.
# NOTE: unpickling can execute arbitrary code - only load files produced by this project.
with open(os.path.join(ROOT_DIR, 'datasets/prime_kg_nx_63960.pickle'), 'rb') as prime_kg_file:
    prime_kg = pickle.load(prime_kg_file)

In [None]:
# Pre-computed LLM feedback for the training questions (consumed by `compute_llm_relevancy_loss`).
# Opened via a context manager so the file handle is closed deterministically.
# NOTE: unpickling can execute arbitrary code - only load files produced by this project.
with open(os.path.join(ROOT_DIR, 'datasets/llm_feedbacks_train.pickle'), 'rb') as feedbacks_file:
    train_llm_feedbacks_dict = pickle.load(feedbacks_file)

In [None]:
# Sub-graph mapping for the training questions.
# NOTE(review): `train_subgraphs` is not referenced by any other cell visible in this notebook.
# Opened via a context manager so the file handle is closed deterministically.
# NOTE: unpickling can execute arbitrary code - only load files produced by this project.
with open(os.path.join(ROOT_DIR, 'datasets/subgraphs_dict_train.pickle'), 'rb') as subgraphs_file:
    train_subgraphs = pickle.load(subgraphs_file)

In [None]:
# Validation-time artefacts: pre-computed LLM feedback per question and the mapping from each
# question to its sub-graph nodes. Both are passed to `evaluate` through `train`.
# Opened via context managers so the file handles are closed deterministically.
# NOTE: unpickling can execute arbitrary code - only load files produced by this project.
with open(os.path.join(ROOT_DIR, 'datasets/llm_feedbacks_validation.pickle'), 'rb') as feedbacks_file:
    eval_llm_feedbacks_dict = pickle.load(feedbacks_file)
with open(os.path.join(ROOT_DIR, 'datasets/subgraphs_dict_val.pickle'), 'rb') as subgraphs_file:
    eval_question_to_subgraphs_mapping = pickle.load(subgraphs_file)

## Load LLM

In [None]:
# 4bit quantization config
# Loads the weights in 4-bit "nf4" format with double quantization enabled and bfloat16 as
# the compute dtype (exactly the options set below).

bnb_config = BitsAndBytesConfig(
          load_in_4bit=True,
          bnb_4bit_use_double_quant=True,
          bnb_4bit_quant_type="nf4",
          bnb_4bit_compute_dtype=torch.bfloat16
        )

# Hugging Face Hub id of the instruction-tuned model used to answer the questions.
# NOTE(review): no `revision` is pinned, so the downloaded weights may change over time.
model_name="mistralai/Mistral-7B-Instruct-v0.1"

# `device_map='auto'` leaves device placement to the library.
# NOTE(review): `trust_remote_code=True` permits running code shipped with the model repo -
# confirm it is actually required for this model, otherwise drop it.
model = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=True, quantization_config=bnb_config, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

In [None]:
# Wrap the model and tokenizer in the project's `LLM` helper; `evaluate` calls its
# `inference` and `get_confidence` methods.
llm = LLM(model, tokenizer)

## Training and evaluation

In [None]:
class LinearDecayLR(_LRScheduler):
    """Scheduler that lowers each learning rate by a fixed amount per step, down to a floor.

    After ``last_epoch`` scheduler steps, every parameter group's rate is
    ``max(base_lr - decay_rate * last_epoch, min_lr)``.

    Args:
        optimizer: the wrapped optimizer.
        decay_rate: absolute amount subtracted from the base rate per scheduler step.
        min_lr: lower bound for the learning rate.
        last_epoch: index of the last completed step, forwarded to the base class.
    """

    def __init__(self, optimizer, decay_rate=0.00005, min_lr=0.00001, last_epoch=-1):
        # Both attributes are read by get_lr(), so they are assigned before control is handed
        # to the base-class constructor (which may already query get_lr()).
        self.min_lr = min_lr
        self.decay_rate = decay_rate
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate for every parameter group."""
        total_decay = self.decay_rate * self.last_epoch
        current_lrs = []
        for base_lr in self.base_lrs:
            current_lrs.append(max(base_lr - total_decay, self.min_lr))
        return current_lrs

In [None]:
# Module-level accumulator: `evaluate` appends a log string for every question the LLM gets
# wrong without context but right with it, and pickles this list at the end of each call.
# NOTE(review): the list is never cleared, so entries accumulate across epochs and re-runs.
helpful_llm_examples = []

In [None]:
def evaluate(llm, medical_hgt, split_loaders, split_name, device, qa_dataset, prime_kg, llm_feedbacks_dict, question_to_subgraphs_mapping, frac=1.0):
    """
    Evaluate link prediction and context-aided LLM answering on one data split.

    For every processed batch the HGT is run without gradients and its positive / sampled
    negative link predictions are collected. For each question in the batch that has
    pre-computed LLM feedback, the LLM is re-queried with graph-derived context only when its
    confidence without context is below a threshold; otherwise the pre-computed result is
    reused for both the "vanilla" and the "context" statistics.

    Side effects: appends to the module-level ``helpful_llm_examples`` list, writes those
    examples to the log via ``logging.info`` and pickles the list under ``ROOT_DIR/datasets``.
    The model is switched to eval mode on entry and back to train mode before returning.

    Args:
        llm: a loaded LLM wrapper exposing ``inference`` and ``get_confidence``
        medical_hgt: an initialized MedicalHGT
        split_loaders: a dict {train: train_batches_list, val: val_batches_list, test: test_batches_list}
        split_name: key into ``split_loaders`` ('val' or 'test')
        device: 'cuda' if available, else 'cpu'
        qa_dataset: pandas DataFrame of MedMCQA rows; rows are looked up positionally with
            ``.iloc[question_uid]``
        prime_kg: nx graph object, a subset of PrimeKG
        llm_feedbacks_dict: a mapping from question uid to the pre-computed LLM feedback
            (answering the question without context)
        question_to_subgraphs_mapping: a mapping from question uid to the nodes of its
            heterogeneous sub-graph (in the form of tuples (node_type, node_uid))
        frac: fraction of the batches to process (at least one batch is always processed)

    Returns:
        pred: link prediction scores, positives followed by sampled negatives
        true: link prediction ground truths in the same order
        llm_results: dict with 'vanilla_accuracy' and 'context_accuracy', each the mean of
            the per-batch average accuracies

    """
    from itertools import islice  # local import keeps this cell runnable on its own

    medical_hgt.eval()

    pos_y_true_tensors = []
    neg_y_true_tensors = []
    pos_y_pred_tensors = []
    neg_y_pred_tensors = []
    average_llm_aided_confidence_list = []
    average_llm_aided_accuracy_list = []
    average_llm_vanilla_confidence_list = []
    average_llm_vanilla_accuracy_list = []

    # Loop-invariant lookup tables: dataset answer index -> letter, letter -> option column.
    correct_answer_map = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
    answer_letter_to_op_map = {'A': 'opa', 'B': 'opb', 'C': 'opc', 'D': 'opd'}

    # The LLM is only re-queried with context when its confidence without context is below
    # this value (presumably just above the 0.25 chance level for four options - TODO confirm).
    low_confidence_threshold = 0.26

    loader = split_loaders[split_name]

    # Always process at least one batch, and never more than the loader holds. Bounding the
    # iterator (instead of breaking out of the loop) lets the progress bar reach 100%.
    num_batches = min(len(loader), max(1, round(frac * len(loader))))

    print('Validation Batches...')
    for batch in tqdm(islice(loader, num_batches), total=num_batches):
        batch = batch.to(device)

        with torch.no_grad():

            # Forward pass
            pos_pred, neg_pred, z_dict = medical_hgt(batch)

            # reshape(-1) keeps the labels 1-D even when the batch holds a single edge.
            pos_eval_y = batch["question", "question_correct_answer", "answer"].edge_label.reshape(-1)

            # Dynamically sample negative examples
            neg_indices = batch["question", "question_wrong_answer", "answer"].edge_label_index
            neg_labels = batch["question", "question_wrong_answer", "answer"].edge_label.reshape(-1)

            # Randomly sample a subset of negative examples (at most 2 per positive).
            # NOTE(review): sampling is unseeded, so metrics vary slightly between calls.
            num_neg_samples = 2
            neg_sample_indices = torch.randperm(neg_indices.size(1))[:num_neg_samples * pos_eval_y.size(0)]

            neg_pred = neg_pred[neg_sample_indices]
            neg_eval_y = neg_labels[neg_sample_indices]

            pos_y_pred_tensors.append(pos_pred.detach())
            neg_y_pred_tensors.append(neg_pred.detach())
            pos_y_true_tensors.append(pos_eval_y.detach())
            neg_y_true_tensors.append(neg_eval_y.detach())

            # Retrieve the HGT's nodes representations and use them to create context for the validation questions
            vanilla_accuracy_list, vanilla_confidence_list, llm_aided_accuracy_list, llm_aided_confidence_list = [], [], [], []
            unseen_questions_indices = batch["question", "question_correct_answer", "answer"].edge_label_index[0]
            if unseen_questions_indices.dim() == 0:
                unseen_questions_indices = unseen_questions_indices.unsqueeze(-1)

            for question_index in unseen_questions_indices:

                question_node_representation = torch.index_select(z_dict['question'], 0, question_index)

                question_uid = batch['question'].node_uid[question_index].item()
                if question_uid not in llm_feedbacks_dict:
                    # No pre-computed feedback for this question: nothing to compare against.
                    continue

                llm_feedback_without_context = llm_feedbacks_dict[question_uid]
                if llm_feedback_without_context.cop_confidence_without_context < low_confidence_threshold:
                    subgraph_tuples = question_to_subgraphs_mapping[question_uid]
                    most_relevant_nodes = find_most_relevant_nodes(batch, z_dict, question_node_representation, subgraph_tuples, prime_kg, k=2)
                    dataset_row = qa_dataset.iloc[question_uid]
                    question_dict = dict(dataset_row.drop(['id', 'cop', 'exp']))
                    correct_answer = dataset_row['cop']
                    prompt = """Context: {}. Question: {} A. {} B. {} C. {} D. {}""".format(
                        ",".join(most_relevant_nodes),
                        question_dict['question'],
                        question_dict['opa'],
                        question_dict['opb'],
                        question_dict['opc'],
                        question_dict['opd']
                    )

                    # Process question with context
                    output_encodings, predictions = llm.inference(prompt)
                    llm_response_dict = llm.get_confidence(correct_answer_map[correct_answer], output_encodings, predictions)
                    if llm_response_dict['confidence'] == -1:
                        # Report the question id (the original printed the batch index here).
                        print(f'Wrong response format. Question {question_uid} ignored during eval')
                        continue

                    # Accumulate Results
                    llm_aided_confidence_list.append(llm_response_dict['cop_confidence'])
                    llm_aided_accuracy_list.append(llm_response_dict['accuracy'])
                    vanilla_confidence_list.append(llm_feedback_without_context.cop_confidence_without_context)
                    vanilla_accuracy_list.append(llm_feedback_without_context.is_correct_without_context)

                    # Record the cases where context flipped a wrong answer into a correct one.
                    if not llm_feedback_without_context.is_correct_without_context and llm_response_dict['accuracy']:
                        log_str = f"Question {question_uid}: {question_dict['question']}\nContext-enriched Prompt: {prompt}\nLLM's response without context: {llm_feedback_without_context.response_without_context}: {question_dict[answer_letter_to_op_map[llm_feedback_without_context.response_without_context]]} --> WRONG!\nLLM's response with context: {llm_response_dict['response']}: {question_dict[answer_letter_to_op_map[llm_response_dict['response']]]} --> CORRECT!"
                        logging.info(log_str)
                        helpful_llm_examples.append(log_str)

                else:
                    # Confident without context: reuse the pre-computed result for both statistics.
                    llm_aided_confidence_list.append(llm_feedback_without_context.cop_confidence_without_context)
                    llm_aided_accuracy_list.append(llm_feedback_without_context.is_correct_without_context)
                    vanilla_confidence_list.append(llm_feedback_without_context.cop_confidence_without_context)
                    vanilla_accuracy_list.append(llm_feedback_without_context.is_correct_without_context)

            # Calculate average performance of the batch
            batch_average_vanilla_confidence = sum(vanilla_confidence_list) / max(1, len(vanilla_confidence_list))
            batch_average_vanilla_accuracy = sum(vanilla_accuracy_list) / max(1, len(vanilla_accuracy_list))
            batch_average_context_confidence = sum(llm_aided_confidence_list) / max(1, len(llm_aided_confidence_list))
            batch_average_context_accuracy = sum(llm_aided_accuracy_list) / max(1, len(llm_aided_accuracy_list))

            # Batches whose average context confidence is not positive (e.g. no question was
            # scored) are left out of the split-level averages.
            if batch_average_context_confidence > 0:
                average_llm_aided_confidence_list.append(batch_average_context_confidence)
                average_llm_aided_accuracy_list.append(batch_average_context_accuracy)
                average_llm_vanilla_confidence_list.append(batch_average_vanilla_confidence)
                average_llm_vanilla_accuracy_list.append(batch_average_vanilla_accuracy)

    medical_hgt.train()

    pos_pred = torch.cat(pos_y_pred_tensors, dim=0).cpu().numpy()
    neg_pred = torch.cat(neg_y_pred_tensors, dim=0).cpu().numpy()
    pos_true = torch.cat(pos_y_true_tensors, dim=0).cpu().numpy()
    neg_true = torch.cat(neg_y_true_tensors, dim=0).cpu().numpy()

    pred = np.concatenate([pos_pred, neg_pred])
    true = np.concatenate([pos_true, neg_true])

    llm_results = {
        'vanilla_accuracy': sum(average_llm_vanilla_accuracy_list) / max(1, len(average_llm_vanilla_accuracy_list)),
        'context_accuracy': sum(average_llm_aided_accuracy_list) / max(1, len(average_llm_aided_accuracy_list)),
    }

    # Persist the accumulated examples; the context manager closes the file deterministically.
    examples_path = os.path.join(ROOT_DIR, 'datasets', 'positive_llms_examples_final.pickle')
    with open(examples_path, 'wb') as examples_file:
        pickle.dump(helpful_llm_examples, examples_file)

    return pred, true, llm_results

In [None]:
def train(llm,
          medical_hgt,
          split_loaders,
          device,
          file_name,
          qa_dataset,
          prime_kg,
          train_llm_feedbacks_dict,
          val_llm_feedbacks_dict,
          question_to_subgraphs_mapping,
          num_epochs=30,
          lr=0.001,
          link_prediction_loss_weight=0.3):
    """
    Train the MedicalHGT on a weighted sum of a link-prediction loss and an LLM-relevancy loss.

    After every epoch the validation split is evaluated; every 5th epoch the test split is
    evaluated as well and an intermediate ModelResult checkpoint is saved. After the last
    epoch the test split is evaluated once more and the final ModelResult is saved to
    ``file_name``.

    Args:
        llm: a loaded LLM
        medical_hgt: an initialized MedicalHGT
        split_loaders: a dict {train: train_batches_list, val: val_batches_list, test: test_batches_list}
        device: 'cuda' if available, else 'cpu'
        file_name: path the final ModelResult is saved to after training
        qa_dataset: the loaded MedMCQA dataset; its 'validation' split is converted to a
            pandas DataFrame and used for every evaluation
        prime_kg: nx graph object, a subset of PrimeKG
        train_llm_feedbacks_dict: a mapping from questions in the MedMCQA train dataset to the pre-computed LLM feedback
        val_llm_feedbacks_dict: a mapping from questions in the MedMCQA val dataset to the pre-computed LLM feedback
        question_to_subgraphs_mapping: a mapping from questions in the MedMCQA validation dataset to their corresponding heterogeneous graphs' nodes (in the form of tuples (node_type, node_uid))
        num_epochs: number of epochs to train for
        lr: initial learning rate (decayed once per epoch by LinearDecayLR with its defaults)
        link_prediction_loss_weight: weight of the link-prediction loss in the total loss;
            the LLM-relevancy loss gets ``1 - link_prediction_loss_weight``

    Returns:
        medical_hgt_result: a ModelResult object

    """

    # Fail fast: make sure both output locations exist before spending hours on training.
    checkpoint_dir = os.path.join(ROOT_DIR, 'experiments')
    os.makedirs(checkpoint_dir, exist_ok=True)
    final_result_dir = os.path.dirname(file_name)
    if final_result_dir:
        os.makedirs(final_result_dir, exist_ok=True)

    medical_hgt = medical_hgt.to(device)

    medical_hgt.train()

    opt = torch.optim.Adam(medical_hgt.parameters(), lr=lr)
    scheduler = LinearDecayLR(opt)

    precision = BinaryPrecision()

    # NOTE(review): the validation split's rows are also used when evaluating the 'test'
    # loader below - confirm the test batches were built from the same question ids.
    eval_qa_dataset = pd.DataFrame(qa_dataset['validation'])

    start_time = get_time()
    print(f'Saving results to {file_name}')

    train_loader = split_loaders['train']

    llm_relevancy_loss_weight = 1 - link_prediction_loss_weight

    epoch_results = []
    for epoch_num in range(1, num_epochs + 1):
        train_start_time = get_time()

        train_losses = []
        pos_y_pred_tensors = []
        neg_y_pred_tensors = []
        pos_y_true_tensors = []
        neg_y_true_tensors = []

        print("Train Batches...")
        for batch in tqdm(train_loader):
            batch = batch.to(device)

            opt.zero_grad()

            # HGT forward pass
            pos_train_pred, neg_train_pred, z_dict = medical_hgt(batch)

            # reshape(-1) keeps the labels 1-D even when the batch holds a single edge.
            pos_train_y = batch["question", "question_correct_answer", "answer"].edge_label.reshape(-1)

            # Dynamically sample negative examples
            neg_indices = batch["question", "question_wrong_answer", "answer"].edge_label_index
            neg_labels = batch["question", "question_wrong_answer", "answer"].edge_label.reshape(-1)

            # Randomly sample a subset of negative examples (at most 2 per positive)
            num_neg_samples = 2
            neg_sample_indices = torch.randperm(neg_indices.size(1))[:num_neg_samples * pos_train_y.size(0)]

            neg_train_pred = neg_train_pred[neg_sample_indices]
            neg_train_y = neg_labels[neg_sample_indices]

            link_prediction_loss = compute_link_prediction_loss(pos_train_pred, neg_train_pred, pos_train_y, neg_train_y, device=device)

            llm_relevancy_loss = compute_llm_relevancy_loss(batch, z_dict, train_llm_feedbacks_dict)

            # Weighted dual-task loss
            total_loss = link_prediction_loss_weight * link_prediction_loss + llm_relevancy_loss_weight * llm_relevancy_loss

            # Backward pass
            total_loss.backward()
            opt.step()

            # Store results
            pos_y_pred_tensors.append(pos_train_pred.detach())
            neg_y_pred_tensors.append(neg_train_pred.detach())
            pos_y_true_tensors.append(pos_train_y.detach().long())
            neg_y_true_tensors.append(neg_train_y.detach().long())

            train_losses.append(total_loss.detach().item())

        train_end_time = get_time()

        # Accumulate train results
        pos_pred = torch.cat(pos_y_pred_tensors, dim=0).cpu().numpy()
        neg_pred = torch.cat(neg_y_pred_tensors, dim=0).cpu().numpy()
        pos_true = torch.cat(pos_y_true_tensors, dim=0).cpu().numpy()
        neg_true = torch.cat(neg_y_true_tensors, dim=0).cpu().numpy()

        pred = np.concatenate([pos_pred, neg_pred])
        true = np.concatenate([pos_true, neg_true])

        # the training ROC AUC is computed using all the predictions (and ground
        # truth labels) made during the entire epoch, across all batches. Note that
        # this is arguably a bit inconsistent with validation below since it doesn't
        # give the medical_hgt a "second try" for earlier batches, for which it couldn't
        # have yet applied anything it learned in later batches.
        train_roc_auc = roc_auc_score(true, pred)
        train_precision = precision(torch.tensor(pred), torch.tensor(true))

        # The validation ROC AUC is computed by running through the validation set
        # at the end of every epoch.
        val_pred, val_true, val_llm_acc_dict = evaluate(llm, medical_hgt, split_loaders, 'val', device, eval_qa_dataset, prime_kg, val_llm_feedbacks_dict, question_to_subgraphs_mapping)

        val_roc_auc = roc_auc_score(val_true, val_pred)
        val_precision = precision(torch.tensor(val_pred), torch.tensor(val_true))

        epoch_result = EpochResult(
            epoch_num=epoch_num,
            train_start_time=train_start_time,
            train_end_time=train_end_time,
            mean_train_loss=round(np.mean(train_losses), 4),
            train_roc_aoc=train_roc_auc,
            train_precision=train_precision,
            val_roc_aoc=val_roc_auc,
            val_precision=val_precision,
            llm_results = val_llm_acc_dict
        )

        epoch_results.append(epoch_result)
        print(f'\r{epoch_result}')

        # One scheduler step per epoch.
        scheduler.step()

        # Every 5th epoch: evaluate on the test split and save an intermediate checkpoint.
        if epoch_num % 5 == 0:
            current_time = get_time()
            current_state_dict = copy.deepcopy(medical_hgt.state_dict())
            current_test_pred, current_test_true, current_test_llm_acc_dict = evaluate(llm, medical_hgt, split_loaders, 'test', device, eval_qa_dataset, prime_kg, val_llm_feedbacks_dict, question_to_subgraphs_mapping)

            current_test_roc_auc = roc_auc_score(current_test_true, current_test_pred)
            current_test_precision = precision(torch.tensor(current_test_pred), torch.tensor(current_test_true))
            current_hgt_result = ModelResult(start_time, current_time, epoch_results, current_state_dict, current_test_roc_auc, current_test_precision, current_test_llm_acc_dict)
            # NOTE(review): the heads / layers / channels / k values in this name are
            # hard-coded and do not follow the model actually being trained; runs with other
            # settings overwrite the same files. Consider deriving the name from `file_name`.
            torch.save(current_hgt_result, os.path.join(checkpoint_dir, f'experiment_{epoch_num}-epochs_2-heads_1-layer_32-channels_2-k.pth'))

    state_dict = copy.deepcopy(medical_hgt.state_dict())

    # Run through the test set
    test_pred, test_true, test_llm_acc_dict = evaluate(llm, medical_hgt, split_loaders, 'test', device, eval_qa_dataset, prime_kg, val_llm_feedbacks_dict, question_to_subgraphs_mapping)

    test_roc_auc = roc_auc_score(test_true, test_pred)
    test_precision = precision(torch.tensor(test_pred), torch.tensor(test_true))
    medical_hgt.eval()

    end_time = get_time()

    medical_hgt_result = ModelResult(start_time, end_time, epoch_results, state_dict, test_roc_auc, test_precision, test_llm_acc_dict)
    torch.save(medical_hgt_result, file_name)

    train_time_min = medical_hgt_result.get_total_train_time_min()
    # The reported number is the ROC AUC, so label it as such rather than as accuracy.
    print(f'\rTest ROC AUC: {test_roc_auc:.3f}; LLM Results: {test_llm_acc_dict}, Total Train Time: {train_time_min} min')

    return medical_hgt_result

## Run training experiment

In [None]:
def run_experiments(experiment_params, device, llm, train_llm_feedbacks_dict, val_llm_feedbacks_dict, question_to_subgraphs_mapping, prime_kg, qa_dataset, data_loaders):
    """Build a fresh MedicalHGT from `experiment_params`, train it once and return the result.

    The results path is the params' `get_file_name()` with a `.pth` suffix; the model is
    always created with `batch_size=32`. All remaining arguments are forwarded to `train`.

    Returns:
        the ModelResult produced by `train`.
    """
    results_path = f'{experiment_params.get_file_name()}.pth'

    hgt_model = MedicalHGT(
        channels=experiment_params.channels,
        num_heads=experiment_params.num_heads,
        num_layers=experiment_params.num_layers,
        batch_size=32,
    )

    return train(
        llm=llm,
        medical_hgt=hgt_model,
        split_loaders=data_loaders,
        device=device,
        file_name=results_path,
        qa_dataset=qa_dataset,
        prime_kg=prime_kg,
        train_llm_feedbacks_dict=train_llm_feedbacks_dict,
        val_llm_feedbacks_dict=val_llm_feedbacks_dict,
        question_to_subgraphs_mapping=question_to_subgraphs_mapping,
        num_epochs=experiment_params.num_epochs,
        lr=experiment_params.lr,
    )

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Is cuda available? ', torch.cuda.is_available())
# Ask PyTorch to release cached GPU memory before training starts.
torch.cuda.empty_cache()

Is cuda available?  True


In [None]:
# Define a list of experiments
experiments_list = [
    ExperimentsParams(num_epochs=50, lr=0.001, channels=32, num_heads=2, num_layers=1), # k = 2
]

In [None]:
# Collect one result per experiment. The list is created once, before the loop; creating it
# inside the loop would discard every result except the last one.
experiment_results_list = []
for i, experiment_params in enumerate(experiments_list):
    print(f"Running experiment {i}\n")
    print(f"Experiment peremeters: lr: {experiment_params.lr}, channels: {experiment_params.channels}, num_heads: {experiment_params.num_heads}, num_layers: {experiment_params.num_layers}")
    experiment_result = run_experiments(experiment_params,
                                        device=device,
                                        llm=llm,
                                        train_llm_feedbacks_dict=train_llm_feedbacks_dict,
                                        val_llm_feedbacks_dict=eval_llm_feedbacks_dict,
                                        question_to_subgraphs_mapping=eval_question_to_subgraphs_mapping,
                                        prime_kg=prime_kg,
                                        qa_dataset=qa_dataset,
                                        data_loaders=loaders)

    experiment_results_list.append(experiment_result)



Running experiment 0

Experiment peremeters: lr: 0.001, channels: 32, num_heads: 2, num_layers: 1
Saving results to /content/drive/MyDrive/Thesis/MedTransNet/experiments/experiment-50_epochs-0.001_lr-32_channels-2_head-1_layers_k-2_batch_norm.pth
Train Batches...


100%|██████████| 2969/2969 [09:52<00:00,  5.01it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:09<00:00,  4.79it/s]


EpochResult(epoch_num=1, train_start_time=1704888801, train_end_time=1704889393, mean_train_loss=0.3268, train_roc_aoc=0.5512286451654214, train_precision=tensor(0.3650), val_roc_aoc=0.5558873139759999, val_precision=tensor(0.3691), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.34940436053045626})
Train Batches...


100%|██████████| 2969/2969 [10:28<00:00,  4.72it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:06<00:00,  4.85it/s]


EpochResult(epoch_num=2, train_start_time=1704889703, train_end_time=1704890332, mean_train_loss=0.3229, train_roc_aoc=0.5810726775563105, train_precision=tensor(0.3860), val_roc_aoc=0.5530679165017323, val_precision=tensor(0.3651), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.3526635198921106})
Train Batches...


100%|██████████| 2969/2969 [10:30<00:00,  4.71it/s]


Validation Batches...


 59%|█████▉    | 880/1485 [03:00<02:18,  4.37it/s]

Wrong response format. Question 879 ignored during eval


100%|█████████▉| 1484/1485 [05:05<00:00,  4.86it/s]


EpochResult(epoch_num=3, train_start_time=1704890638, train_end_time=1704891269, mean_train_loss=0.3199, train_roc_aoc=0.6008886043802906, train_precision=tensor(0.3996), val_roc_aoc=0.5531586509324896, val_precision=tensor(0.3660), llm_results={'vanilla_accuracy': 0.3191632928475034, 'context_accuracy': 0.35678137651821856})
Train Batches...


100%|██████████| 2969/2969 [10:30<00:00,  4.71it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:08<00:00,  4.81it/s]


EpochResult(epoch_num=4, train_start_time=1704891575, train_end_time=1704892206, mean_train_loss=0.3167, train_roc_aoc=0.6196168170260029, train_precision=tensor(0.4109), val_roc_aoc=0.5448555296680716, val_precision=tensor(0.3612), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.3565407956844235})
Train Batches...


100%|██████████| 2969/2969 [10:30<00:00,  4.71it/s]


Validation Batches...


 79%|███████▉  | 1171/1485 [04:01<01:09,  4.55it/s]

Wrong response format. Question 1170 ignored during eval


100%|█████████▉| 1484/1485 [05:05<00:00,  4.85it/s]


EpochResult(epoch_num=5, train_start_time=1704892514, train_end_time=1704893145, mean_train_loss=0.3133, train_roc_aoc=0.6371497884343068, train_precision=tensor(0.4244), val_roc_aoc=0.5478538176235486, val_precision=tensor(0.3609), llm_results={'vanilla_accuracy': 0.3191632928475034, 'context_accuracy': 0.3593117408906882})
Validation Batches...


 18%|█▊        | 559/3095 [01:47<06:30,  6.50it/s]

Wrong response format. Question 558 ignored during eval


 25%|██▌       | 780/3095 [02:33<06:41,  5.77it/s]

Wrong response format. Question 780 ignored during eval


 29%|██▉       | 903/3095 [02:57<08:52,  4.11it/s]

Wrong response format. Question 902 ignored during eval


 30%|██▉       | 915/3095 [03:00<06:39,  5.46it/s]

Wrong response format. Question 914 ignored during eval


100%|█████████▉| 3080/3095 [09:58<00:02,  5.69it/s]

Wrong response format. Question 3079 ignored during eval


100%|█████████▉| 3094/3095 [10:00<00:00,  5.15it/s]


Train Batches...


100%|██████████| 2969/2969 [10:32<00:00,  4.69it/s]


Validation Batches...


 79%|███████▉  | 1171/1485 [04:03<01:08,  4.58it/s]

Wrong response format. Question 1170 ignored during eval


100%|█████████▉| 1484/1485 [05:08<00:00,  4.82it/s]


EpochResult(epoch_num=6, train_start_time=1704894051, train_end_time=1704894684, mean_train_loss=0.3099, train_roc_aoc=0.6520021152786954, train_precision=tensor(0.4339), val_roc_aoc=0.5498882824647547, val_precision=tensor(0.3701), llm_results={'vanilla_accuracy': 0.3191632928475034, 'context_accuracy': 0.35666891587944216})
Train Batches...


100%|██████████| 2969/2969 [10:36<00:00,  4.67it/s]


Validation Batches...


 79%|███████▉  | 1171/1485 [04:03<01:07,  4.65it/s]

Wrong response format. Question 1170 ignored during eval


100%|█████████▉| 1484/1485 [05:08<00:00,  4.81it/s]


EpochResult(epoch_num=7, train_start_time=1704894992, train_end_time=1704895628, mean_train_loss=0.3067, train_roc_aoc=0.6657407803567404, train_precision=tensor(0.4444), val_roc_aoc=0.5512450016256074, val_precision=tensor(0.3662), llm_results={'vanilla_accuracy': 0.3191632928475034, 'context_accuracy': 0.3544197031039136})
Train Batches...


100%|██████████| 2969/2969 [10:35<00:00,  4.67it/s]


Validation Batches...


 74%|███████▍  | 1097/1485 [03:49<00:59,  6.53it/s]

Wrong response format. Question 1098 ignored during eval


 79%|███████▉  | 1171/1485 [04:03<01:07,  4.64it/s]

Wrong response format. Question 1170 ignored during eval


100%|█████████▉| 1484/1485 [05:07<00:00,  4.82it/s]


EpochResult(epoch_num=8, train_start_time=1704895937, train_end_time=1704896573, mean_train_loss=0.3033, train_roc_aoc=0.6790211567265306, train_precision=tensor(0.4536), val_roc_aoc=0.5464670173591296, val_precision=tensor(0.3713), llm_results={'vanilla_accuracy': 0.3191632928475034, 'context_accuracy': 0.3527327935222671})
Train Batches...


100%|██████████| 2969/2969 [10:32<00:00,  4.69it/s]


Validation Batches...


 59%|█████▉    | 880/1485 [03:02<02:17,  4.41it/s]

Wrong response format. Question 879 ignored during eval


 79%|███████▉  | 1171/1485 [04:03<01:08,  4.57it/s]

Wrong response format. Question 1170 ignored during eval


100%|█████████▉| 1484/1485 [05:07<00:00,  4.82it/s]


EpochResult(epoch_num=9, train_start_time=1704896881, train_end_time=1704897513, mean_train_loss=0.3004, train_roc_aoc=0.6894849520809652, train_precision=tensor(0.4618), val_roc_aoc=0.5467056403174207, val_precision=tensor(0.3734), llm_results={'vanilla_accuracy': 0.3193787981093856, 'context_accuracy': 0.3522957461174882})
Train Batches...


100%|██████████| 2969/2969 [10:33<00:00,  4.69it/s]


Validation Batches...


 59%|█████▉    | 880/1485 [03:01<02:19,  4.33it/s]

Wrong response format. Question 879 ignored during eval


 79%|███████▉  | 1171/1485 [04:02<01:08,  4.56it/s]

Wrong response format. Question 1170 ignored during eval


100%|█████████▉| 1484/1485 [05:07<00:00,  4.83it/s]


EpochResult(epoch_num=10, train_start_time=1704897821, train_end_time=1704898455, mean_train_loss=0.2969, train_roc_aoc=0.7012688251978746, train_precision=tensor(0.4699), val_roc_aoc=0.5419052486982863, val_precision=tensor(0.3690), llm_results={'vanilla_accuracy': 0.3193787981093856, 'context_accuracy': 0.3522957461174882})
Validation Batches...


 18%|█▊        | 557/3095 [01:47<04:16,  9.89it/s]

Wrong response format. Question 558 ignored during eval


 29%|██▉       | 903/3095 [02:59<09:10,  3.99it/s]

Wrong response format. Question 902 ignored during eval


 30%|██▉       | 915/3095 [03:01<06:40,  5.45it/s]

Wrong response format. Question 914 ignored during eval


100%|█████████▉| 3080/3095 [10:02<00:02,  5.79it/s]

Wrong response format. Question 3079 ignored during eval


100%|█████████▉| 3094/3095 [10:04<00:00,  5.12it/s]


Train Batches...


100%|██████████| 2969/2969 [10:32<00:00,  4.70it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:09<00:00,  4.80it/s]


EpochResult(epoch_num=11, train_start_time=1704899367, train_end_time=1704899999, mean_train_loss=0.2934, train_roc_aoc=0.7130892474594498, train_precision=tensor(0.4788), val_roc_aoc=0.5408742649160528, val_precision=tensor(0.3663), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.3537311755450663})
Train Batches...


100%|██████████| 2969/2969 [10:38<00:00,  4.65it/s]


Validation Batches...


 59%|█████▉    | 880/1485 [03:02<02:20,  4.32it/s]

Wrong response format. Question 879 ignored during eval


100%|█████████▉| 1484/1485 [05:09<00:00,  4.80it/s]


EpochResult(epoch_num=12, train_start_time=1704900309, train_end_time=1704900947, mean_train_loss=0.291, train_roc_aoc=0.721069470028231, train_precision=tensor(0.4859), val_roc_aoc=0.5479816316187023, val_precision=tensor(0.3729), llm_results={'vanilla_accuracy': 0.3191632928475034, 'context_accuracy': 0.35048358074673863})
Train Batches...


100%|██████████| 2969/2969 [10:49<00:00,  4.57it/s]


Validation Batches...


 59%|█████▉    | 880/1485 [03:02<02:24,  4.20it/s]

Wrong response format. Question 879 ignored during eval


100%|█████████▉| 1484/1485 [05:08<00:00,  4.81it/s]


EpochResult(epoch_num=13, train_start_time=1704901256, train_end_time=1704901906, mean_train_loss=0.2878, train_roc_aoc=0.7305332828487088, train_precision=tensor(0.4926), val_roc_aoc=0.5350962988210172, val_precision=tensor(0.3617), llm_results={'vanilla_accuracy': 0.3191632928475034, 'context_accuracy': 0.3520017993702204})
Train Batches...


100%|██████████| 2969/2969 [10:42<00:00,  4.62it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:07<00:00,  4.82it/s]


EpochResult(epoch_num=14, train_start_time=1704902215, train_end_time=1704902857, mean_train_loss=0.2851, train_roc_aoc=0.7385102556434829, train_precision=tensor(0.4989), val_roc_aoc=0.5404986808515235, val_precision=tensor(0.3619), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.3527197122948977})
Train Batches...


100%|██████████| 2969/2969 [10:38<00:00,  4.65it/s]


Validation Batches...


 59%|█████▉    | 880/1485 [03:02<02:23,  4.22it/s]

Wrong response format. Question 879 ignored during eval


100%|█████████▉| 1484/1485 [05:09<00:00,  4.80it/s]


EpochResult(epoch_num=15, train_start_time=1704903165, train_end_time=1704903804, mean_train_loss=0.2824, train_roc_aoc=0.746176553401656, train_precision=tensor(0.5044), val_roc_aoc=0.5415059680909475, val_precision=tensor(0.3687), llm_results={'vanilla_accuracy': 0.3191632928475034, 'context_accuracy': 0.3567813765182186})
Validation Batches...


 18%|█▊        | 557/3095 [01:48<04:11, 10.10it/s]

Wrong response format. Question 558 ignored during eval


 29%|██▉       | 903/3095 [03:00<09:08,  4.00it/s]

Wrong response format. Question 902 ignored during eval


 30%|██▉       | 915/3095 [03:03<06:48,  5.33it/s]

Wrong response format. Question 914 ignored during eval


100%|█████████▉| 3080/3095 [10:07<00:02,  5.77it/s]

Wrong response format. Question 3079 ignored during eval


100%|█████████▉| 3094/3095 [10:09<00:00,  5.07it/s]


Train Batches...


100%|██████████| 2969/2969 [10:39<00:00,  4.64it/s]


Validation Batches...


 59%|█████▉    | 880/1485 [03:02<02:22,  4.25it/s]

Wrong response format. Question 879 ignored during eval


100%|█████████▉| 1484/1485 [05:08<00:00,  4.81it/s]


EpochResult(epoch_num=16, train_start_time=1704904724, train_end_time=1704905363, mean_train_loss=0.2802, train_roc_aoc=0.7525296854018468, train_precision=tensor(0.5101), val_roc_aoc=0.5431803191494586, val_precision=tensor(0.3720), llm_results={'vanilla_accuracy': 0.3191632928475034, 'context_accuracy': 0.3549820062977958})
Train Batches...


100%|██████████| 2969/2969 [10:40<00:00,  4.64it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:08<00:00,  4.81it/s]


EpochResult(epoch_num=17, train_start_time=1704905672, train_end_time=1704906313, mean_train_loss=0.2776, train_roc_aoc=0.7592813119931834, train_precision=tensor(0.5161), val_roc_aoc=0.538977252301143, val_precision=tensor(0.3628), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.3535625983367049})
Train Batches...


100%|██████████| 2969/2969 [10:39<00:00,  4.65it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:08<00:00,  4.80it/s]


EpochResult(epoch_num=18, train_start_time=1704906622, train_end_time=1704907261, mean_train_loss=0.2754, train_roc_aoc=0.7650591212130624, train_precision=tensor(0.5209), val_roc_aoc=0.5400152959342135, val_precision=tensor(0.3674), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.35221398066981346})
Train Batches...


100%|██████████| 2969/2969 [10:38<00:00,  4.65it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:09<00:00,  4.79it/s]


EpochResult(epoch_num=19, train_start_time=1704907570, train_end_time=1704908209, mean_train_loss=0.2736, train_roc_aoc=0.769634806918042, train_precision=tensor(0.5245), val_roc_aoc=0.5438325265867844, val_precision=tensor(0.3692), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.35356259833670484})
Train Batches...


100%|██████████| 2969/2969 [10:37<00:00,  4.66it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:09<00:00,  4.80it/s]


EpochResult(epoch_num=20, train_start_time=1704908518, train_end_time=1704909156, mean_train_loss=0.272, train_roc_aoc=0.773989299583837, train_precision=tensor(0.5287), val_roc_aoc=0.5389193615242795, val_precision=tensor(0.3627), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.3531130591144077})
Validation Batches...


 18%|█▊        | 557/3095 [01:49<04:14,  9.97it/s]

Wrong response format. Question 558 ignored during eval


 29%|██▉       | 903/3095 [03:01<09:18,  3.93it/s]

Wrong response format. Question 902 ignored during eval


 30%|██▉       | 915/3095 [03:03<06:46,  5.36it/s]

Wrong response format. Question 914 ignored during eval


100%|█████████▉| 3080/3095 [10:06<00:02,  5.83it/s]

Wrong response format. Question 3079 ignored during eval


100%|█████████▉| 3094/3095 [10:09<00:00,  5.08it/s]


Train Batches...


100%|██████████| 2969/2969 [10:39<00:00,  4.64it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:08<00:00,  4.82it/s]


EpochResult(epoch_num=21, train_start_time=1704910075, train_end_time=1704910714, mean_train_loss=0.2705, train_roc_aoc=0.7778877629296795, train_precision=tensor(0.5322), val_roc_aoc=0.5386731875950931, val_precision=tensor(0.3658), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.35345021353113054})
Train Batches...


100%|██████████| 2969/2969 [10:38<00:00,  4.65it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:08<00:00,  4.82it/s]


EpochResult(epoch_num=22, train_start_time=1704911023, train_end_time=1704911661, mean_train_loss=0.2706, train_roc_aoc=0.7778197079737292, train_precision=tensor(0.5324), val_roc_aoc=0.5404930329708539, val_precision=tensor(0.3667), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.35345021353113054})
Train Batches...


100%|██████████| 2969/2969 [10:39<00:00,  4.64it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:12<00:00,  4.75it/s]


EpochResult(epoch_num=23, train_start_time=1704911969, train_end_time=1704912609, mean_train_loss=0.2703, train_roc_aoc=0.778536958580609, train_precision=tensor(0.5324), val_roc_aoc=0.5436182754613827, val_precision=tensor(0.3730), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.35345021353113054})
Train Batches...


100%|██████████| 2969/2969 [10:38<00:00,  4.65it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:08<00:00,  4.81it/s]


EpochResult(epoch_num=24, train_start_time=1704912921, train_end_time=1704913560, mean_train_loss=0.2701, train_roc_aoc=0.7789081724070646, train_precision=tensor(0.5334), val_roc_aoc=0.5391969671371925, val_precision=tensor(0.3656), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.35345021353113054})
Train Batches...


100%|██████████| 2969/2969 [10:39<00:00,  4.64it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:10<00:00,  4.78it/s]


EpochResult(epoch_num=25, train_start_time=1704913868, train_end_time=1704914508, mean_train_loss=0.2704, train_roc_aoc=0.7780869543566427, train_precision=tensor(0.5325), val_roc_aoc=0.5434303606491035, val_precision=tensor(0.3676), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.3531130591144077})
Validation Batches...


 18%|█▊        | 557/3095 [01:49<04:19,  9.78it/s]

Wrong response format. Question 558 ignored during eval


 29%|██▉       | 903/3095 [03:01<09:11,  3.97it/s]

Wrong response format. Question 902 ignored during eval


 30%|██▉       | 915/3095 [03:04<06:50,  5.31it/s]

Wrong response format. Question 914 ignored during eval


100%|█████████▉| 3080/3095 [10:09<00:02,  5.76it/s]

Wrong response format. Question 3079 ignored during eval


100%|█████████▉| 3094/3095 [10:12<00:00,  5.05it/s]


Train Batches...


100%|██████████| 2969/2969 [10:46<00:00,  4.59it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:14<00:00,  4.72it/s]


EpochResult(epoch_num=26, train_start_time=1704915431, train_end_time=1704916078, mean_train_loss=0.2698, train_roc_aoc=0.7796121453010232, train_precision=tensor(0.5342), val_roc_aoc=0.5393236761122151, val_precision=tensor(0.3686), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.3537873679478534})
Train Batches...


100%|██████████| 2969/2969 [11:06<00:00,  4.46it/s]


Validation Batches...


100%|█████████▉| 1484/1485 [05:15<00:00,  4.70it/s]


EpochResult(epoch_num=27, train_start_time=1704916392, train_end_time=1704917059, mean_train_loss=0.2701, train_roc_aoc=0.7788589635080403, train_precision=tensor(0.5332), val_roc_aoc=0.5451365731213921, val_precision=tensor(0.3676), llm_results={'vanilla_accuracy': 0.3189480782198247, 'context_accuracy': 0.3531130591144077})
Train Batches...


 33%|███▎      | 993/2969 [03:56<07:23,  4.46it/s]