# Build heterogeneous graphs from MedMCQA questions and the PrimeKG knowledge graph

## Setup

In [None]:
# Mount Google Drive; the project directory used below lives on it.
from google.colab import drive
drive.mount('/content/drive')

# Work from the project directory so relative paths (e.g. requirements.txt in the next cell) resolve there.
%cd /content/drive/MyDrive/Thesis/MedTransNet

import sys
# Make the project's packages (config, KnowledgeExtraction, src) importable.
sys.path.append('/content/drive/MyDrive/Thesis/MedTransNet')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Thesis/MedTransNet


In [None]:
# Install project dependencies; requirements.txt is resolved against the directory set by %cd above.
!pip install -r requirements.txt



In [None]:
import os
import pickle
import torch

import networkx as nx
import pandas as pd

from tqdm import tqdm

from config import ROOT_DIR
from KnowledgeExtraction.subgraph_builder import SubgraphBuilder
from src.preprocess.build_raw_dataset import initiate_question_graph, embed_text
from src.preprocess.medical_ner import medical_ner
from src.utils import meta_relations_dict

## Create a SubgraphBuilder object

In [None]:
# Every artifact used here sits under <ROOT_DIR>/datasets/.
datasets_subdir = 'datasets/'

# Pickled knowledge graph and its precomputed node-embedding tensor.
kg_path = os.path.join(ROOT_DIR, datasets_subdir + 'prime_kg_nx_63960.pickle')
embeddings_path = os.path.join(ROOT_DIR, datasets_subdir + 'prime_kg_embeddings_tensor_63960.pt')

# No prebuilt trie is supplied to SubgraphBuilder.
trie_path = None
# NOTE(review): trie_save_path is not passed to SubgraphBuilder in this notebook; confirm it is still needed.
trie_save_path = os.path.join(ROOT_DIR, datasets_subdir)

In [None]:
# One builder bundles the KG, its node embeddings and the MedMCQA dataset;
# later cells read subgraph_builder.kg, .node_embeddings and .dataset.
subgraph_builder = SubgraphBuilder(
    kg_name_or_path=kg_path,
    kg_embeddings_path=embeddings_path,
    dataset_name_or_path='medmcqa',
    meta_relation_types_dict=meta_relations_dict,
    embedding_method=embed_text,
    trie_path=trie_path,
)



  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Collect every KG node's 'index' attribute (in graph iteration order) once,
# so medical_ner can map matched entities back to node indices.
node_indices_list = [node_attrs['index'] for _node, node_attrs in subgraph_builder.kg.nodes(data=True)]

In [None]:
# Turn the MedMCQA train split into a DataFrame so it can be sliced and iterated row by row.
train_split = subgraph_builder.dataset['train']
medmcqa_df = pd.DataFrame(train_split)

In [None]:
# Sanity check that a CUDA device is visible before running the extraction loop.
cuda_available = torch.cuda.is_available()
cuda_available  # last expression, so the notebook displays it

True

In [None]:
# Build one question graph per MedMCQA training row in [start_index, end_index):
# run medical NER over the question and its four answer choices, seed a graph
# with them via initiate_question_graph, then expand it with knowledge extracted
# from the KG (hops=2, neighbors_per_hop=5). Each graph is pickled on its own and
# the accumulated list is checkpointed periodically; rows that raise are
# reported, recorded in failed_indices, and skipped.
raw_data_list = []
failed_indices = []  # dataframe indices whose processing raised; inspect after the run
start_index = 74251
end_index = 74501  # exclusive
checkpoint_every = 250  # dump raw_data_list whenever the row index is a multiple of this

graphs_dir = os.path.join(ROOT_DIR, 'datasets/train/raw_train_dataset')
# Create the per-graph output directory up front; otherwise every row's dump would fail.
os.makedirs(graphs_dir, exist_ok=True)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file handle deterministically."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _split_answer_entities(answer_choices, entities_indices_list, num_entities_list):
    """Map each answer choice to the slice of entity indices NER found for it.

    medical_ner is called on ``[question] + answer_choices``. This assumes
    ``num_entities_list`` holds one count per input text in that order and
    ``entities_indices_list`` is their concatenation (TODO confirm against
    medical_ner). Choices with identical text share one key; the later choice
    wins, as with plain dict assignment.

    Raises:
        ValueError: if there is not one count per answer choice, rather than
            silently reusing the last count and mis-assigning entities.
    """
    choice_counts = num_entities_list[1:1 + len(answer_choices)]
    if len(choice_counts) != len(answer_choices):
        raise ValueError(
            f"expected {len(answer_choices)} answer entity counts, got {len(choice_counts)}"
        )
    answer_entities_dict = {}
    start = num_entities_list[0]  # skip past the question's own entities
    for choice, count in zip(answer_choices, choice_counts):
        answer_entities_dict[choice] = entities_indices_list[start:start + count]
        start += count
    return answer_entities_dict


def _build_question_graph(index, row):
    """Build and return the knowledge-expanded graph for one dataframe row.

    Resets ``subgraph_builder.nx_subgraph`` and builds on it. Returns None when
    NER finds no entities, so the caller skips the row.
    """
    subgraph_builder.nx_subgraph = nx.Graph()
    question = row['question']
    answer_choices = [row['opa'], row['opb'], row['opc'], row['opd']]
    correct_answer = row['cop']

    entities_list, entities_indices_list, num_entities_list = medical_ner(
        [question] + answer_choices, subgraph_builder.node_embeddings, node_indices_list, subgraph_builder.kg
    )
    if len(entities_list) == 0:
        return None

    answer_entities_dict = _split_answer_entities(answer_choices, entities_indices_list, num_entities_list)
    subgraph_builder.nx_subgraph = initiate_question_graph(
        subgraph_builder.nx_subgraph, question, answer_choices, correct_answer,
        entities_indices_list[:num_entities_list[0]], answer_entities_dict, subgraph_builder.kg,
        question_index=int(index),
    )

    _, extracted_edge_indices = subgraph_builder.extract_knowledge_from_kg(
        question, hops=2, neighbors_per_hop=5, entities_list=entities_list
    )
    if extracted_edge_indices is not None:
        subgraph_builder.expand_graph_with_knowledge(extracted_edge_indices)
    # Read nx_subgraph back only now, in case expand_graph_with_knowledge replaced it.
    return subgraph_builder.nx_subgraph


rows = medmcqa_df.iloc[start_index:end_index]
for i, row in tqdm(rows.iterrows(), total=len(rows)):
    try:
        graph = _build_question_graph(i, row)
        if graph is not None:
            _dump_pickle(graph, os.path.join(graphs_dir, f'graph_{i}.pickle'))
            raw_data_list.append(graph)
        # Checkpoint even when this row was skipped, so an entity-less row
        # cannot suppress the periodic save.
        if i % checkpoint_every == 0:
            _dump_pickle(raw_data_list, os.path.join(ROOT_DIR, f'datasets/raw_data_list_{start_index}-{i}.pickle'))
    except Exception as e:
        # Best effort: report and continue so one bad row does not abort the run.
        print(f"An error occurred on index {i}:", e)
        failed_indices.append(i)


250it [04:43,  1.14s/it]


In [None]:
# Final save of everything collected, including rows processed after the last
# periodic checkpoint. `i` is the last dataframe index the loop visited.
with open(os.path.join(ROOT_DIR, f'datasets/raw_data_list_{start_index}-{i}.pickle'), 'wb') as f:
    pickle.dump(raw_data_list, f)