In [None]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m28.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import os
import joblib
import numpy as np

import torch
from torch_geometric.data import HeteroData

In [None]:
# Location of the preprocessed heterogeneous graph.
# NOTE(review): absolute Colab/Drive paths are hard-coded; adjust them when running elsewhere.
graph_path = "/content/drive/MyDrive/AIT/ML/Project/personalized_medical_recommendation/preprocessing/processed_graph.pt"

# weights_only=False unpickles arbitrary Python objects, so only load files from a trusted source.
# map_location="cpu" lets a graph that was saved from a GPU session load on a CPU-only runtime;
# the graph is moved to the selected device later in the notebook.
data = torch.load(graph_path, map_location="cpu", weights_only=False)

# Directory holding the trained model checkpoints.
ckpt_path = "/content/drive/MyDrive/AIT/ML/Project/personalized_medical_recommendation/checkpoints"

In [None]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv

class HeteroSAGELinkPredictor(nn.Module):
    """Heterogeneous GraphSAGE encoder with an MLP decoder for patient→medication links.

    Components:
      * ``conv1`` / ``conv2``: two ``HeteroConv`` layers holding one ``SAGEConv`` per
        edge type in ``metadata[1]``, combined across relations with ``aggr='sum'``.
      * ``patient_encoder``: MLP mapping raw patient features to ``out_channels``.
      * ``decoder``: MLP scoring the concatenation of five ``out_channels``-wide
        embeddings (patient, medication, disease, procedure, lab).

    Args:
        metadata: Graph metadata; ``metadata[1]`` is iterated as the edge types.
        output_dims: Not used by the model. Kept (now optional) so existing calls
            that pass it keep working.
        hidden_channels: Width of the first message-passing layer and of the
            hidden layer of ``patient_encoder``.
        out_channels: Width of the final node embeddings.
        patient_in_channels: Number of raw patient features fed to
            ``patient_encoder``. Defaults to 44, the value previously hard-coded.
    """

    def __init__(self, metadata, output_dims=None, hidden_channels=64, out_channels=32,
                 patient_in_channels=44):
        super().__init__()

        # Message passing layers. (-1, -1) leaves the input sizes to be inferred lazily.
        self.conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        }, aggr='sum')

        self.conv2 = HeteroConv({
            edge_type: SAGEConv((hidden_channels, hidden_channels), out_channels)
            for edge_type in metadata[1]
        }, aggr='sum')

        # MLP that maps raw patient features into the same width as the GNN output.
        self.patient_encoder = nn.Sequential(
            nn.Linear(patient_in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )

        # Decoder input is the concatenation of five embeddings of width out_channels.
        self.decoder = nn.Sequential(
            nn.Linear(out_channels * 5, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, 1)
        )

    def forward(self, x_dict, edge_index_dict):
        """Run both message-passing layers and return the per-node-type embeddings.

        ReLU is applied after the first layer only; the output of ``conv2`` is
        returned as-is.
        """
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

    def encode_patient(self, patient_features):
        """Embed raw patient feature rows with the ``patient_encoder`` MLP."""
        return self.patient_encoder(patient_features)

    def decode(
        self,
        z_patient,
        z_medication,
        z_disease,
        z_procedure,
        z_lab,
        edge_index,
        disease_ids=None,
        procedure_ids=None,
        lab_ids=None,
    ):
        """Score candidate patient→medication edges.

        Args:
            z_patient, z_medication, z_disease, z_procedure, z_lab: Embedding
                tables indexed by node id.
            edge_index: Two rows; row 0 indexes ``z_patient`` and row 1 indexes
                ``z_medication``, one column per candidate edge.
            disease_ids, procedure_ids, lab_ids: Optional per-edge indices into the
                matching context table. ``None`` substitutes zero vectors.

        Returns:
            One raw (pre-sigmoid) score per candidate edge, as a 1-D tensor.
        """
        src, dst = edge_index
        patient_emb = z_patient[src]

        # Default to zero vectors if ids are None (i.e., for inference without all context).
        # NOTE(review): ids are used directly as indices, so a negative sentinel such as -1
        # selects a row counted from the end of the table rather than "no context".
        def get_context_embeddings(z, ids):
            if ids is None:
                return torch.zeros_like(patient_emb)
            return z[ids]

        disease_emb = get_context_embeddings(z_disease, disease_ids)
        proc_emb    = get_context_embeddings(z_procedure, procedure_ids)
        lab_emb     = get_context_embeddings(z_lab, lab_ids)

        # Concatenate all embeddings along the feature axis.
        combined = torch.cat([
            patient_emb,
            z_medication[dst],
            disease_emb,
            proc_emb,
            lab_emb
        ], dim=1)

        # Drop only the trailing size-1 feature axis. A bare squeeze() would also drop the
        # edge axis when a single edge is scored, returning a 0-d tensor.
        return self.decoder(combined).squeeze(-1)

In [None]:
from torch_geometric.nn import HeteroConv, SAGEConv

# Move the graph to the selected device, then push it through one throwaway
# HeteroConv layer to read off the feature width produced for each node type.
data = data.to(device)

with torch.no_grad():
    probe_convs = {}
    for edge_type in data.edge_types:
        probe_convs[edge_type] = SAGEConv((-1, -1), out_channels=64)
    temp_conv1 = HeteroConv(probe_convs, aggr='sum').to(device)

    temp_output = temp_conv1(data.x_dict, data.edge_index_dict)

    output_dims = {}
    for node_type, feat in temp_output.items():
        output_dims[node_type] = feat.shape[1]

In [None]:
# Instantiate the model again with correct metadata and output dims
model = HeteroSAGELinkPredictor(data.metadata(), output_dims=output_dims).to(device)

# Load the weights stored in best_model_acc.pt. map_location remaps tensors that were
# saved from another device onto the current one, so a checkpoint written during a GPU
# run also loads on a CPU-only runtime.
best_ckpt_file = os.path.join(ckpt_path, "best_model_acc.pt")
model.load_state_dict(torch.load(best_ckpt_file, map_location=device))
model.eval()  # Set to eval mode (last expression, so the cell also displays the module)

HeteroSAGELinkPredictor(
  (conv1): HeteroConv(num_relations=8)
  (conv2): HeteroConv(num_relations=8)
  (patient_encoder): Sequential(
    (0): Linear(in_features=44, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=32, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=160, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=1, bias=True)
  )
)

In [None]:
from collections import defaultdict
import torch

def build_first_association_map(src_nodes, dst_nodes):
    """Map each source id to the destination id of its first edge.

    ``src_nodes`` and ``dst_nodes`` are parallel tensors (one entry per edge).
    Edges are scanned in order, and only the first destination seen for a
    source is kept.
    """
    first_dst = {}
    edges = zip(src_nodes.tolist(), dst_nodes.tolist())
    for source, target in edges:
        # setdefault stores the value only when the key is not present yet.
        first_dst.setdefault(source, target)
    return first_dst

def get_first_associated_node(assoc_map, patient_ids, default_val=-1):
    """Look up the first associated node for every entry of ``patient_ids``.

    Args:
        assoc_map: Dict from patient id to associated node id.
        patient_ids: Iterable of scalar tensors (typically a 1-D tensor of ids).
        default_val: Value used for patients missing from ``assoc_map``.
            NOTE(review): if the result is later used to index an embedding
            table, a negative value selects a row from the end of the table.

    Returns:
        A ``torch.long`` tensor with one entry per patient id, created on the
        same device as ``patient_ids`` (default device if it has none).
    """
    found = [assoc_map.get(pid.item(), default_val) for pid in patient_ids]
    # Follow the input's device so the result can be combined with it directly.
    target_device = getattr(patient_ids, "device", None)
    return torch.tensor(found, dtype=torch.long, device=target_device)

In [None]:
# Build the patient → first-context-node lookups once, one per relation.
edge_pd = data["patient", "has_disease", "disease"].edge_index
edge_pp = data["patient", "underwent", "procedure"].edge_index
edge_pl = data["patient", "has_lab", "lab"].edge_index

# Each edge_index unpacks into its two rows: patient (source) ids, then context ids.
disease_map = build_first_association_map(*edge_pd)
procedure_map = build_first_association_map(*edge_pp)
lab_map = build_first_association_map(*edge_pl)

In [None]:
# Step 1: Get node embeddings
# One forward pass over the full graph with gradient tracking disabled; z_dict maps
# each node type to the tensor of embeddings returned by the model.
with torch.no_grad():
    z_dict = model(data.x_dict, data.edge_index_dict)

In [None]:
# Example: score every medication node as a candidate for one patient node.
patient_id = 1500
num_meds = z_dict["medication"].shape[0]
med_ids = torch.arange(num_meds, device=device)

# Step 2: Replace patient embeddings using the patient_encoder MLP.
# Run under no_grad: this is inference only, so no autograd graph should be recorded.
patient_features = data["patient"].x.to(device)
with torch.no_grad():
    z_dict["patient"] = model.encode_patient(patient_features)  # <-- Encoder used here

# Repeat patient_id once per medication (same length, dtype and device as med_ids).
patients = torch.full_like(med_ids, fill_value=patient_id)

# Candidate edges: row 0 = patient (source), row 1 = medication (destination).
edge_index = torch.stack([patients, med_ids], dim=0)

# Per-edge context: the first disease / procedure / lab linked to the patient.
disease_ids   = get_first_associated_node(disease_map, patients)
procedure_ids = get_first_associated_node(procedure_map, patients)
lab_ids       = get_first_associated_node(lab_map, patients)

# Predict
with torch.no_grad():
    scores = model.decode(
        z_patient     = z_dict["patient"],
        z_medication  = z_dict["medication"],
        z_disease     = z_dict["disease"],
        z_procedure   = z_dict["procedure"],
        z_lab         = z_dict["lab"],
        edge_index    = edge_index,
        disease_ids   = disease_ids,
        procedure_ids = procedure_ids,
        lab_ids       = lab_ids
    )
    probs = torch.sigmoid(scores)

In [None]:
# Keep the ids of the `topk` medications with the highest predicted probability.
topk = 15
top_indices = torch.topk(probs, k=topk)[1]  # element 1 of the result holds the indices
recommended_meds = med_ids[top_indices].cpu().numpy()

print("Recommended Medication IDs:", recommended_meds)

Recommended Medication IDs: [  4  11  39   3  85  16   1  50 111 100 104 419  10  17 237]


## Get Medications' names

In [None]:
# Directory read by the cells below: the mappings/ JSON files and the saved
# patient encoder / age scaler pickles are resolved relative to it.
data_path = "/content/drive/MyDrive/AIT/ML/Project/personalized_medical_recommendation/preprocessing/processed"

In [None]:
import json

# Medication id → name mapping (keys are strings, as JSON object keys always are).
# Open as UTF-8 explicitly so names decode the same way whatever the platform's
# default locale encoding is.
with open(os.path.join(data_path, "mappings", "id_to_medication.json"), 'r', encoding="utf-8") as file:
    med_map = json.load(file)

In [None]:
# List the recommended medications by name; the mapping is keyed by the id as a string.
print(f"Top {topk} Recommended medications for patient {patient_id}:\n")
for med_id in recommended_meds:
    print(f"- {med_map[str(med_id)]}")

Top 15 Recommended medications for patient 1500:

- Furosemide
- Potassium Chloride
- Insulin
- D5W
- Docusate Sodium
- Acetaminophen
- Warfarin
- Metoprolol
- Sodium Chloride 0.9%  Flush
- Ranitidine
- Milk of Magnesia
- Metoprolol Tartrate
- Magnesium Sulfate
- Oxycodone-Acetaminophen
- Syringe


## Patient's Condition

In [None]:
import json

# Disease id → name mapping. Open as UTF-8 explicitly so the names decode the same
# way whatever the platform's default locale encoding is.
with open(os.path.join(data_path, "mappings", "id_to_disease.json"), "r", encoding="utf-8") as f:
    disease_mapping = json.load(f)

In [None]:
# All patient → disease edges; row 0 holds the source (patient) ids, row 1 the disease ids.
edge_index = data["patient", "has_disease", "disease"].edge_index

# Keep only the edges whose source is the selected patient.
mask = edge_index[0].eq(patient_id)
disease_ids = edge_index[1, mask].cpu().numpy()
disease_ids

array([ 19,   8,  62,  32,   8, 180, 156,  48,  35])

In [None]:
# Print the name of every disease linked to the patient (one line per edge).
print(f"Conditions of patient {patient_id}:\n")
for disease_id in disease_ids:
    disease_name = disease_mapping[str(disease_id)]
    print("-", disease_name)

Conditions of patient 1500:

- Coronary atherosclerosis and other heart disease
- Complications of surgical procedures or medical care
- Cardiac dysrhythmias
- Pleurisy; pneumothorax; pulmonary collapse
- Complications of surgical procedures or medical care
- Other hematologic conditions
- Acute bronchitis
- Thyroid disorders
- Essential hypertension


## Get Patient's Information

In [None]:
# Load the saved gender/ethnicity encoder and the age scaler.
# NOTE: joblib.load unpickles the files, so they must come from a trusted source.
encoder_file = os.path.join(data_path, "patient_gender_ethnicity_encoder.pkl")
age_scaler_file = os.path.join(data_path, "patient_age_scaler.pkl")

encoder = joblib.load(encoder_file)
age_scaler = joblib.load(age_scaler_file)

In [None]:
from sklearn.preprocessing import OneHotEncoder

# Assuming feature layout: [scaled_age, onehot_gender, onehot_ethnicity]
patient_vector = data["patient"].x[patient_id].cpu().numpy()

# First entry is the scaled age; everything after it is the encoded demographics.
scaled_age, encoded_demo = patient_vector[:1], patient_vector[1:]

# Both inverse transforms are given a single-row 2-D array.
original_age = age_scaler.inverse_transform(scaled_age[None, :])[0][0]
original_demo = encoder.inverse_transform(encoded_demo[None, :])[0]
original_gender, original_ethnicity = original_demo

print("Recovered Info:")
recovered_fields = (
    ("Age:", original_age),
    ("Gender:", original_gender),
    ("Ethnicity:", original_ethnicity),
)
for label, value in recovered_fields:
    print(label, value)

Recovered Info:
Age: 75.0
Gender: 1
Ethnicity: 1


## Patient's Procedures

In [None]:
# Procedure id → name mapping. Open as UTF-8 explicitly so the names decode the same
# way whatever the platform's default locale encoding is.
with open(os.path.join(data_path, "mappings", "id_to_procedure.json"), "r", encoding="utf-8") as f:
    procedure_mapping = json.load(f)

In [None]:
# Edge list of the patient → procedure ("underwent") relation.
edge_index_proc = data["patient", "underwent", "procedure"].edge_index

# Boolean mask over the edges whose source node is the selected patient.
mask_proc = edge_index_proc[0].eq(patient_id)

# Target (procedure) ids on those edges, as a NumPy array.
procedure_ids = edge_index_proc[1, mask_proc].cpu().numpy()
procedure_ids

array([ 5,  5,  6, 19, 24, 21,  2])

In [None]:
# Print the name of every procedure linked to the patient (one line per edge).
print(f"Procedures of patient {patient_id}:\n")
for pid in procedure_ids:
    print(f"- {procedure_mapping[str(pid)]}")

Procedures of patient 1500:

- Coronary artery bypass graft (CABG)
- Coronary artery bypass graft (CABG)
- Extracorporeal circulation auxiliary to open heart procedures
- Other OR procedures on vessels other than head and neck
- Diagnostic bronchoscopy and biopsy of bronchus
- Diagnostic ultrasound of heart (echocardiogram)
- Other vascular catheterization; not heart


## Get Patient's Lab IDs

In [None]:
# Edge list of the patient → lab ("has_lab") relation.
edge_index_lab = data["patient", "has_lab", "lab"].edge_index

# Boolean mask over the edges whose source node is the selected patient.
mask_lab = edge_index_lab[0].eq(patient_id)

# Target (lab) ids on those edges, as a NumPy array.
lab_ids = edge_index_lab[1, mask_lab].cpu().numpy()
lab_ids

array([13, 16, 17, 19, 20, 21, 24, 25, 26, 27,  1,  3,  6,  7,  8,  9, 11,
       12, 13, 16, 17, 19, 20, 21, 24, 25, 26, 27,  1,  3,  6,  7,  8,  9,
       11, 12, 13, 16, 17, 19, 20, 21, 24, 25, 26, 27, 35, 37, 38, 46, 47,
       28, 16, 24, 29, 30, 31, 54, 34, 35, 37, 38, 55, 29, 30, 54, 34, 35,
       37, 38, 29, 30, 54, 40, 41, 34, 35, 37, 38, 16, 29, 30, 31, 54, 34,
       35, 37, 38, 55, 16, 59, 60, 31, 54, 37, 55, 58, 29, 30, 35, 37, 38,
        1,  3,  6,  8,  9, 11, 12, 13, 98, 35, 36, 37, 38, 42, 29, 30, 31,
       40, 41, 34, 35, 37, 38, 55, 16, 31, 54, 37, 55, 29, 30, 31, 35, 37,
       38,  1,  3,  6,  7,  8,  9, 11, 12, 13, 16, 17, 19, 20, 21, 24, 25,
       26, 27, 16,  1,  3,  6,  7,  8,  9, 11, 12, 13, 16, 17, 19, 20, 21,
       24, 25, 26, 27,  1,  3,  5,  6,  7,  8,  9, 11, 12, 56, 29, 30, 31,
       54, 40, 41, 33, 35, 37, 38, 55, 57, 58, 39, 53, 46, 24, 47, 28, 16,
       46, 24, 47, 28, 29, 30, 31, 54, 35, 37, 38, 55, 29, 30, 54, 34, 35,
       37, 38, 11, 16, 24

## New Patient

### Add Patient Node

In [None]:
import pandas as pd

# Load the saved gender/ethnicity encoder and the age scaler.
encoder = joblib.load(os.path.join(data_path, "patient_gender_ethnicity_encoder.pkl"))
age_scaler = joblib.load(os.path.join(data_path, "patient_age_scaler.pkl"))

# New patient info (Input)
gender, ethnicity, age = 0, 1, 0

# Single-row DataFrames so the column names travel with the values.
new_patient_demo = pd.DataFrame({"gender": [gender], "ethnicity": [ethnicity]})
new_patient_age = pd.DataFrame({"age": [age]})

# Apply the fitted transformers.
scaled_age = age_scaler.transform(new_patient_age)
encoded_demo = encoder.transform(new_patient_demo)

# Feature vector layout: scaled age first, then the encoded demographics.
new_patient_features = np.hstack((scaled_age, encoded_demo))
new_patient_features.shape

(1, 44)

In [None]:
from torch_geometric.transforms import ToUndirected

# Context mapping builders
def build_first_association_map(patient_nodes, context_nodes):
    """Return a dict from each patient id to the context id of its first edge.

    Both arguments are parallel tensors with one entry per edge. Later edges
    of a patient that is already in the dict are ignored.
    """
    first_context = {}
    for patient, context in zip(patient_nodes.tolist(), context_nodes.tolist()):
        if patient in first_context:
            continue  # an earlier edge already claimed this patient
        first_context[patient] = context
    return first_context

# Lookup functions for context
def get_first_associated_node(context_map, patient_ids, default_val=0):
    """Look up the first context node of every patient in ``patient_ids``.

    Args:
        context_map: Dict from patient id to context node id.
        patient_ids: 1-D tensor of patient ids.
        default_val: Value used for patients missing from ``context_map``.
            Defaults to 0, the fallback this definition has always used; the
            parameter keeps the signature compatible with the earlier definition
            of the same name in this notebook.
            NOTE(review): 0 is also a valid node index, so a missing patient is
            indistinguishable from one whose first context node is node 0.

    Returns:
        A ``torch.long`` tensor on the same device as ``patient_ids``.
    """
    found = [context_map.get(pid.item(), default_val) for pid in patient_ids]
    return torch.tensor(found, dtype=torch.long, device=patient_ids.device)


def predict_with_node_addition(
    model,
    data,
    new_patient_features,     # shape: [1, input_dim]
    disease_ids,              # list or tensor of disease node indices
    procedure_ids,            # list or tensor of procedure node indices
    lab_ids,                  # list or tensor of lab node indices
    disease_map_fn,           # function to build disease_map
    procedure_map_fn,         # function to build procedure_map
    lab_map_fn,               # function to build lab_map
    med_map,                  # medication ID to name mapping
    device,
    topk=10
):
    """Recommend medications for a patient that is not in the graph yet.

    A copy of ``data`` gets one extra patient node plus edges from it to the
    given disease / procedure / lab nodes. The model is run on that copy and
    every medication node is scored as a candidate for the new patient. The
    input ``data`` is not modified.

    Args:
        model: Trained predictor exposing ``__call__``, ``encode_patient`` and ``decode``.
        data: The heterogeneous graph to extend.
        new_patient_features: Feature row of the new patient, shape ``[1, input_dim]``
            (a 1-D vector is accepted and treated as a single row).
        disease_ids, procedure_ids, lab_ids: Indices of existing nodes to link
            the new patient to (list, NumPy array or tensor).
        disease_map_fn, procedure_map_fn, lab_map_fn: Callables taking
            ``(patient_ids, context_ids)`` tensors and returning a patient → context dict.
            They receive only the edges of the new patient, the only one looked up here.
        med_map: Dict from medication id (as a string) to medication name.
        device: Device on which inference is run.
        topk: Number of medications to return; clamped to the number of medication nodes.

    Returns:
        ``(recommended_med_ids, top_scores)``: NumPy arrays with the ids and the
        sigmoid probabilities of the best-scoring medications, best first.
    """
    model.eval()
    with torch.no_grad():
        # Step 1: Clone the original graph to avoid mutation, and make sure the
        # copy lives on the inference device.
        inference_data = data.clone().to(device)

        # Step 2: Assign new patient ID (the next free row of the patient matrix).
        # as_tensor accepts arrays and tensors alike without an extra-copy warning.
        new_patient_tensor = torch.as_tensor(new_patient_features, dtype=torch.float, device=device)
        if new_patient_tensor.dim() == 1:
            new_patient_tensor = new_patient_tensor.unsqueeze(0)
        new_patient_id = inference_data["patient"].x.shape[0]

        # Step 3: Add patient node
        inference_data["patient"].x = torch.cat(
            [inference_data["patient"].x, new_patient_tensor], dim=0
        )

        # Step 4: Add edges from the new patient to each given target node.
        def add_edges(edge_type, target_ids):
            # Build both rows on `device`, so list / array / tensor inputs behave the same.
            targets = torch.as_tensor(target_ids, dtype=torch.long, device=device).reshape(-1)
            sources = torch.full_like(targets, new_patient_id)
            new_edges = torch.stack([sources, targets], dim=0)
            store = inference_data[edge_type]
            store.edge_index = torch.cat([store.edge_index, new_edges], dim=1)

        add_edges(("patient", "has_disease", "disease"), disease_ids)
        add_edges(("patient", "underwent", "procedure"), procedure_ids)
        add_edges(("patient", "has_lab", "lab"), lab_ids)

        # Step 5: Reapply ToUndirected to get reverse edges
        inference_data = ToUndirected()(inference_data)

        # Step 6: Rebuild context maps. Only the new patient is looked up below, so
        # pass just its own edges instead of converting every edge of the graph.
        def build_context_map(edge_type, map_fn):
            src, dst = inference_data[edge_type].edge_index
            own = src == new_patient_id
            return map_fn(src[own], dst[own])

        disease_map = build_context_map(("patient", "has_disease", "disease"), disease_map_fn)
        procedure_map = build_context_map(("patient", "underwent", "procedure"), procedure_map_fn)
        lab_map = build_context_map(("patient", "has_lab", "lab"), lab_map_fn)

        # Step 7: Forward pass
        z_dict = model(inference_data.x_dict, inference_data.edge_index_dict)

        # Replace patient embeddings using the patient_encoder MLP
        patient_features = inference_data["patient"].x.to(device)
        z_dict["patient"] = model.encode_patient(patient_features)  # <-- Encoder used here

        # One candidate edge per medication, all starting at the new patient.
        num_meds = z_dict["medication"].shape[0]
        med_ids = torch.arange(num_meds, device=device)
        patients = torch.full_like(med_ids, new_patient_id)

        edge_index = torch.stack([patients, med_ids], dim=0)

        disease_ids_batch   = get_first_associated_node(disease_map, patients)
        procedure_ids_batch = get_first_associated_node(procedure_map, patients)
        lab_ids_batch       = get_first_associated_node(lab_map, patients)

        # Step 8: Decode predictions
        scores = model.decode(
            z_patient     = z_dict["patient"],
            z_medication  = z_dict["medication"],
            z_disease     = z_dict["disease"],
            z_procedure   = z_dict["procedure"],
            z_lab         = z_dict["lab"],
            edge_index    = edge_index,
            disease_ids   = disease_ids_batch,
            procedure_ids = procedure_ids_batch,
            lab_ids       = lab_ids_batch
        )

        # Keep one score per medication even if the decoder returned a 0-d tensor.
        probs = torch.sigmoid(scores).reshape(-1)

        # Step 9: Top-k results (never ask for more entries than there are medications).
        k = min(topk, num_meds)
        best = probs.topk(k)
        recommended_med_ids = med_ids[best.indices].cpu().numpy()
        top_scores = best.values.cpu().numpy()

        print(f"\nTop {k} Recommended Medications for Patient {new_patient_id}:\n")
        for med_id, score in zip(recommended_med_ids, top_scores):
            print(f"- {med_map[str(med_id)]}: {score:.4f}")

        return recommended_med_ids, top_scores

In [None]:
# Score all medications for the new patient whose features were built above.
# NOTE: disease_ids, procedure_ids and lab_ids are whatever those globals currently
# hold. Here that is the NumPy arrays extracted for patient `patient_id` in the
# earlier sections, so those cells must have been run first.
# `scores` is rebound here to the returned top-k probabilities.
recommendations, scores = predict_with_node_addition(
    model=model,
    data=data,
    new_patient_features=new_patient_features,
    disease_ids=disease_ids,
    procedure_ids=procedure_ids,
    lab_ids=lab_ids,
    disease_map_fn=build_first_association_map,
    procedure_map_fn=build_first_association_map,
    lab_map_fn=build_first_association_map,
    med_map=med_map,  # {"0": "Drug A", ...}
    device=device,
    topk=15
)


Top 15 Recommended Medications for Patient 46520:

- Furosemide: 0.9993
- Potassium Chloride: 0.9990
- Insulin: 0.9988
- D5W: 0.9987
- Acetaminophen: 0.9985
- Docusate Sodium: 0.9983
- Sodium Chloride 0.9%  Flush: 0.9981
- Metoprolol: 0.9975
- Magnesium Sulfate: 0.9973
- Warfarin: 0.9972
- NS: 0.9970
- Heparin Sodium: 0.9966
- Iso-Osmotic Dextrose: 0.9965
- Bisacodyl: 0.9965
- Metoprolol Tartrate: 0.9960


| Medication            | Likely Purpose                                                                 | Alignment with Conditions/Procedures       |
|-----------------------|-------------------------------------------------------------------------------|--------------------------------------------|
| **Furosemide**        | Diuretic for fluid overload (post-CABG, heart disease, hypertension).         | ✅ High relevance.                          |
| **Potassium Chloride**| Prevents hypokalemia (common with diuretics like furosemide).                 | ✅ Necessary adjunct.                       |
| **Insulin**           | Manages hyperglycemia (stress-induced or pre-existing diabetes post-surgery). | ✅ CABG patients often need glucose control.|
| **Metoprolol** (and Tartrate) | Beta-blocker for hypertension, arrhythmias, post-CABG cardioprotection. | ✅ Standard post-CABG.                      |
| **Warfarin/Heparin**  | Anticoagulation (post-CABG, vascular procedures, prevent thromboembolism).    | ✅ Critical for vessel patency.             |
| **Magnesium Sulfate** | Corrects arrhythmias (e.g., torsades) or deficiency (common post-heart surgery). | ✅ Fits cardiac history.                  |
| **Acetaminophen**     | Pain management (safer than NSAIDs post-CABG due to bleeding risks).          | ✅ Appropriate.                            |
| **Docusate/Bisacodyl**| Prevents constipation (opioid-induced post-op or due to reduced mobility).    | ✅ Prophylactic measure.                   |
| **IV Fluids** (D5W, NS) | Maintain hydration/electrolytes (post-op, NPO status, or hemodynamic support). | ✅ Standard care.                         |