In [2]:
# !pip install torch_geometric

In [3]:
import os
import joblib
import numpy as np

import torch
from torch_geometric.data import HeteroData

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load the pre-processed heterogeneous graph.
# - weights_only=False is required because the file holds a pickled graph object,
#   not just tensors. Unpickling can execute arbitrary code: only load files you
#   produced yourself or otherwise trust.
# - map_location="cpu" lets a graph that was saved from another accelerator
#   (e.g. CUDA) be loaded on this machine; it is moved to `device` later.
data = torch.load("./models/processed_graph.pt", map_location="cpu", weights_only=False)

# Directory that holds the trained model checkpoints.
ckpt_path = "./checkpoints"

In [5]:
# Device configuration: prefer CUDA, then Apple-silicon Metal (MPS), then CPU,
# so the same notebook runs unchanged on any of the three.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [6]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv

class HeteroSAGELinkPredictor(nn.Module):
    """Heterogeneous GraphSAGE encoder plus an MLP decoder for patient-medication links.

    ``forward`` runs two ``HeteroConv`` layers (one ``SAGEConv`` per edge type,
    summed per destination node type). ``encode_patient`` maps raw patient
    features through a separate MLP into the ``out_channels`` space, and
    ``decode`` scores (patient, medication) pairs from five concatenated
    ``out_channels``-wide embeddings.

    Args:
        metadata: ``(node_types, edge_types)`` pair; only ``metadata[1]`` (the
            edge types) is used.
        output_dims: Accepted for backward compatibility; the model does not
            read it.
        hidden_channels: Output width of the first convolution layer and hidden
            width of the patient encoder.
        out_channels: Output width of the second convolution layer, of the
            patient encoder, and hidden width of the decoder.
        patient_in_channels: Width of the raw patient feature vector fed to the
            patient encoder. Defaults to 44, the value that used to be
            hard-coded, so existing checkpoints keep loading.
    """

    def __init__(self, metadata, output_dims, hidden_channels=64, out_channels=32,
                 patient_in_channels=44):
        super().__init__()
        edge_types = metadata[1]

        # Message passing layers. (-1, -1) lets SAGEConv infer the source and
        # destination input widths lazily on the first forward pass.
        self.conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in edge_types
        }, aggr='sum')

        self.conv2 = HeteroConv({
            edge_type: SAGEConv((hidden_channels, hidden_channels), out_channels)
            for edge_type in edge_types
        }, aggr='sum')

        # MLP that maps raw patient features into the same space as the GNN output
        self.patient_encoder = nn.Sequential(
            nn.Linear(patient_in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )

        # Decoder input is the concatenation of five embeddings:
        # patient, medication, disease, procedure and lab.
        self.decoder = nn.Sequential(
            nn.Linear(out_channels * 5, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, 1)
        )

    def forward(self, x_dict, edge_index_dict):
        """Return the per-node-type embeddings after two message-passing layers.

        ReLU is applied between the two layers only; the second layer's output
        is returned as-is.
        """
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

    def encode_patient(self, patient_features):
        """Embed raw patient features with the patient MLP (no graph involved)."""
        return self.patient_encoder(patient_features)

    def decode(
        self,
        z_patient,
        z_medication,
        z_disease,
        z_procedure,
        z_lab,
        edge_index,
        disease_ids=None,
        procedure_ids=None,
        lab_ids=None,
    ):
        """Score candidate (patient, medication) edges.

        Args:
            z_patient, z_medication, z_disease, z_procedure, z_lab: Embedding
                matrices indexed by node id along dim 0.
            edge_index: Pair of index rows ``(src, dst)``; ``src`` indexes
                ``z_patient`` and ``dst`` indexes ``z_medication``.
            disease_ids, procedure_ids, lab_ids: Optional per-edge indices into
                the matching context embedding matrix. ``None`` substitutes
                zero vectors for that context. Negative indices are not
                special-cased: they index from the end of the matrix.

        Returns:
            1-D tensor holding one raw logit per candidate edge.
        """
        src, dst = edge_index

        # Default to zero vectors if ids are None (i.e., for inference without all context)
        def get_context_embeddings(z, ids):
            if ids is None:
                return torch.zeros_like(z_patient[src])
            return z[ids]

        disease_emb = get_context_embeddings(z_disease, disease_ids)
        proc_emb    = get_context_embeddings(z_procedure, procedure_ids)
        lab_emb     = get_context_embeddings(z_lab, lab_ids)

        # Concatenate all embeddings
        combined = torch.cat([
            z_patient[src],
            z_medication[dst],
            disease_emb,
            proc_emb,
            lab_emb
        ], dim=1)

        # Drop only the trailing size-1 dim. A bare squeeze() would also collapse
        # a single-edge batch to a 0-d tensor and break callers such as topk.
        return self.decoder(combined).squeeze(-1)

In [9]:
from torch_geometric.nn import HeteroConv, SAGEConv

data = data.to(device)

# `output_dims` records the feature width each node type has after a first-layer
# HeteroConv whose SAGEConvs all emit 64 channels. That width is fixed by the
# layer definition, so it is derived from the edge types directly instead of
# building and running a throw-away convolution over the whole graph.
# NOTE(review): the recorded output of this cell is
# "RuntimeError: Invalid buffer size"; it is not visible whether `.to(device)`
# or the forward pass raised it. Skipping the forward pass removes one of the
# two large allocations.
# NOTE(review): assumes HeteroConv returns one entry per destination node type
# of `data.edge_types` - confirm against the installed torch_geometric version.
FIRST_LAYER_CHANNELS = 64
output_dims = {dst_type: FIRST_LAYER_CHANNELS for _, _, dst_type in data.edge_types}

RuntimeError: Invalid buffer size: 33.34 GB

In [None]:
# Instantiate the model again with correct metadata and output dims
model = HeteroSAGELinkPredictor(data.metadata(), output_dims=output_dims).to(device)

# Load the best model weights.
# - map_location=device remaps tensors saved from another accelerator
#   (e.g. a CUDA training run) onto the device used here.
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which is all a state dict needs.
state_dict = torch.load(
    os.path.join(ckpt_path, "best_model_acc.pt"),
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()  # Set to eval mode

In [None]:
from collections import defaultdict
import torch

def build_first_association_map(src_nodes, dst_nodes):
    """Map each source node to the first destination node it is paired with.

    ``src_nodes`` and ``dst_nodes`` are walked in parallel; for a source that
    appears more than once, only its earliest pairing is kept.
    """
    first_seen = {}
    pairs = zip(src_nodes.tolist(), dst_nodes.tolist())
    for source, target in pairs:
        # setdefault leaves an existing entry untouched, so the first pairing wins
        first_seen.setdefault(source, target)
    return first_seen

def get_first_associated_node(assoc_map, patient_ids, default_val=-1):
    """Return the first associated node for each patient in ``patient_ids``.

    Args:
        assoc_map: Dict mapping a patient index to a node index.
        patient_ids: 1-D integer tensor of patient indices.
        default_val: Value used for patients missing from ``assoc_map``.
            Note that -1, when later used as a tensor index, selects the
            last row rather than signalling "missing".

    Returns:
        Long tensor with one entry per patient, on the same device as
        ``patient_ids``.
    """
    # A single bulk .tolist() replaces a per-element .item() call, each of
    # which forces its own device-to-host transfer.
    first_nodes = [assoc_map.get(pid, default_val) for pid in patient_ids.tolist()]
    return torch.tensor(first_nodes, dtype=torch.long, device=patient_ids.device)

In [None]:
# Build the patient -> first-context-node lookups once, up front.
# Each edge_index has two rows: row 0 holds the patient (source) indices and
# row 1 the context (destination) indices.
edge_pd = data[("patient", "has_disease", "disease")].edge_index
edge_pp = data[("patient", "underwent", "procedure")].edge_index
edge_pl = data[("patient", "has_lab", "lab")].edge_index

# Star-unpacking a two-row edge_index passes (sources, destinations) positionally.
disease_map = build_first_association_map(*edge_pd)
procedure_map = build_first_association_map(*edge_pp)
lab_map = build_first_association_map(*edge_pl)

In [None]:
# Step 1: run the GNN once over the full graph to get an embedding matrix per
# node type. Gradients are not needed for inference.
with torch.no_grad():
    z_dict = model(
        x_dict=data.x_dict,
        edge_index_dict=data.edge_index_dict,
    )

In [None]:
# Example: score every medication node for one existing patient node
patient_id = 1200
assert 0 <= patient_id < data["patient"].x.shape[0], "patient_id is outside the graph"

num_meds = z_dict["medication"].shape[0]
med_ids = torch.arange(num_meds, device=device)

# Everything below is inference only, so keep it all out of autograd
# (the encoder call used to run with gradient tracking enabled).
with torch.no_grad():
    # Step 2: Replace patient embeddings using the patient_encoder MLP
    patient_features = data["patient"].x.to(device)
    z_dict["patient"] = model.encode_patient(patient_features)  # <-- Encoder used here

    # Repeat patient_id once per medication (same length as med_ids)
    patients = torch.full_like(med_ids, fill_value=patient_id)

    edge_index = torch.stack([
        patients,  # source: patient
        med_ids    # destination: meds
    ], dim=0)

    # One context node (the first recorded one) per candidate edge
    disease_ids   = get_first_associated_node(disease_map, patients)
    procedure_ids = get_first_associated_node(procedure_map, patients)
    lab_ids       = get_first_associated_node(lab_map, patients)

    # Predict
    scores = model.decode(
        z_patient     = z_dict["patient"],
        z_medication  = z_dict["medication"],
        z_disease     = z_dict["disease"],
        z_procedure   = z_dict["procedure"],
        z_lab         = z_dict["lab"],
        edge_index    = edge_index,
        disease_ids   = disease_ids,
        procedure_ids = procedure_ids,
        lab_ids       = lab_ids
    )
    probs = torch.sigmoid(scores)

In [None]:
topk = 15

# torch.topk raises when k exceeds the number of candidates, so clamp k.
# reshape(-1) also covers a 0-d `probs` (a single medication).
top_indices = probs.reshape(-1).topk(min(topk, probs.numel())).indices
recommended_meds = med_ids[top_indices].cpu().numpy()

print("Recommended Medication IDs:", recommended_meds)

## Get Medications' names


In [None]:
# Root directory under which the `mappings/` and `models/` files are read.
data_path = "./"

In [None]:
import json

# Medication node id (as a string key) -> medication name.
# The encoding is given explicitly so the result does not depend on the
# platform's default text encoding.
with open(os.path.join(data_path, "mappings", "id_to_medication.json"), "r", encoding="utf-8") as f:
    med_map = json.load(f)

In [None]:
# Header first, then one bullet per medication in the order of recommended_meds.
print(f"Top {topk} Recommended medications for patient {patient_id}:\n")
# The JSON mapping is keyed by strings, hence the str() conversion of each id.
for med_name in map(med_map.__getitem__, map(str, recommended_meds)):
    print("-", med_name)

## Patient's Condition


In [None]:
import json

# Disease node id (as a string key) -> disease name.
# The encoding is given explicitly so the result does not depend on the
# platform's default text encoding.
with open(os.path.join(data_path, "mappings", "id_to_disease.json"), "r", encoding="utf-8") as f:
    disease_mapping = json.load(f)

In [None]:
# All (patient -> disease) edges: row 0 = patient index, row 1 = disease index.
edge_index = data[("patient", "has_disease", "disease")].edge_index

# Keep only the edges that start at the patient of interest.
mask = edge_index[0].eq(patient_id)
# Note: this rebinds `disease_ids` to a NumPy array of this patient's diseases.
disease_ids = edge_index[1, mask].cpu().numpy()
disease_ids

In [None]:
# Header first, then one bullet per disease linked to the patient.
print(f"Conditions of patient {patient_id}:\n")
# The JSON mapping is keyed by strings, hence the str() conversion of each id.
for disease_name in map(disease_mapping.__getitem__, map(str, disease_ids)):
    print("-", disease_name)

## Get Patient's Information


In [None]:
# Load the fitted demographic encoder and age scaler, in that order.
# joblib files are pickle-based: only load artifacts from a trusted source.
encoder, age_scaler = (
    joblib.load(os.path.join(data_path, "models", artifact_name))
    for artifact_name in (
        "patient_gender_ethnicity_encoder.pkl",
        "patient_age_scaler.pkl",
    )
)

In [None]:
from sklearn.preprocessing import OneHotEncoder

# Assumed feature layout: [scaled_age, onehot_gender, onehot_ethnicity]
patient_vector = data["patient"].x[patient_id].cpu().numpy()

# First column is the scaled age; everything after it is the one-hot block.
scaled_age, encoded_demo = patient_vector[0:1], patient_vector[1:]

# Both fitted transformers expect a 2-D (n_samples, n_features) input,
# so each 1-D slice gets a leading sample axis before being inverted.
original_age = age_scaler.inverse_transform(scaled_age[np.newaxis, :])[0, 0]
original_demo = encoder.inverse_transform(encoded_demo[np.newaxis, :])[0]
original_gender, original_ethnicity = original_demo

print("Recovered Info:")
print("Age:", original_age)
print("Gender:", original_gender)
print("Ethnicity:", original_ethnicity)

## Patient's Procedures


In [None]:
# Procedure node id (as a string key) -> procedure name.
# The encoding is given explicitly so the result does not depend on the
# platform's default text encoding.
with open(os.path.join(data_path, "mappings", "id_to_procedure.json"), "r", encoding="utf-8") as f:
    procedure_mapping = json.load(f)

In [None]:
# All (patient -> procedure) edges: row 0 = patient index, row 1 = procedure index.
edge_index_proc = data[("patient", "underwent", "procedure")].edge_index

# Boolean mask over the edges whose source is the patient of interest.
mask_proc = torch.eq(edge_index_proc[0], patient_id)

# Target (procedure) ids of the selected edges.
# Note: this rebinds `procedure_ids` to a NumPy array.
procedure_ids = edge_index_proc[1].masked_select(mask_proc).cpu().numpy()
procedure_ids

In [None]:
# Header first, then one bullet per procedure linked to the patient.
print(f"Procedures of patient {patient_id}:\n")
# The JSON mapping is keyed by strings, hence the str() conversion of each id.
for procedure_name in map(procedure_mapping.__getitem__, map(str, procedure_ids)):
    print("-", procedure_name)

## Get Patient's Lab IDs


In [None]:
# All (patient -> lab) edges: row 0 = patient index, row 1 = lab index.
edge_index_lab = data[("patient", "has_lab", "lab")].edge_index

# Boolean mask over the edges whose source is the patient of interest.
mask_lab = edge_index_lab[0].eq(patient_id)

# Lab node ids of the selected edges.
# Note: this rebinds `lab_ids` to a NumPy array.
lab_ids = edge_index_lab[1, mask_lab].cpu().numpy()
lab_ids

## New Patient


### Add Patient Node


In [None]:
import pandas as pd

# Load saved encoder and scaler.
# joblib files are pickle-based: only load artifacts from a trusted source.
encoder = joblib.load(os.path.join(data_path,"models", "patient_gender_ethnicity_encoder.pkl"))
age_scaler = joblib.load(os.path.join(data_path, "models","patient_age_scaler.pkl"))

# New patient info (Input)
gender = 0
ethnicity = 1
age = 67

# Wrap in DataFrames to preserve feature names
new_patient_demo = pd.DataFrame([[gender, ethnicity]], columns=["gender", "ethnicity"])
new_patient_age = pd.DataFrame([[age]], columns=["age"])

# Transform using fitted encoders
encoded_demo = encoder.transform(new_patient_demo)
# If the encoder emits a SciPy sparse matrix, densify it first: np.hstack does
# not handle sparse inputs.
if hasattr(encoded_demo, "toarray"):
    encoded_demo = encoded_demo.toarray()
scaled_age = age_scaler.transform(new_patient_age)

# Concatenate into feature vector (age first, then the encoded demographics)
new_patient_features = np.hstack([scaled_age, encoded_demo])

# Fail here rather than deep inside the graph code if the vector does not
# line up with the patient features already stored in the graph.
expected_dim = data["patient"].x.shape[1]
assert new_patient_features.shape == (1, expected_dim), (
    f"expected a (1, {expected_dim}) feature vector, got {new_patient_features.shape}"
)
new_patient_features.shape

In [None]:
from torch_geometric.transforms import ToUndirected

# Context mapping builders
def build_first_association_map(patient_nodes, context_nodes):
    """Map each patient index to the first context node it is paired with.

    Both arguments are tensors of equal length that are read in parallel.
    This repeats the helper defined earlier in the notebook with different
    parameter names.
    """
    first_context = {}
    for patient, context in zip(patient_nodes.tolist(), context_nodes.tolist()):
        if patient in first_context:
            # An earlier pairing for this patient already exists; keep it
            continue
        first_context[patient] = context
    return first_context

# Lookup functions for context
def get_first_associated_node(context_map, patient_ids, default_val=0):
    """Look up the first context node of every patient in ``patient_ids``.

    Args:
        context_map: Dict mapping a patient index to a context-node index.
        patient_ids: 1-D integer tensor of patient indices.
        default_val: Value used for patients missing from ``context_map``.
            The default 0 means such patients are given context node 0.

    Returns:
        Long tensor with one entry per patient, on the same device as
        ``patient_ids``.
    """
    # A single bulk .tolist() replaces a per-element .item() call, each of
    # which forces its own device-to-host transfer.
    context_nodes = [context_map.get(pid, default_val) for pid in patient_ids.tolist()]
    return torch.tensor(context_nodes, dtype=torch.long, device=patient_ids.device)


def predict_with_node_addition(
    model,
    data,
    new_patient_features,     # shape: [1, input_dim]
    disease_ids,              # list or tensor of disease node indices
    procedure_ids,            # list or tensor of procedure node indices
    lab_ids,                  # list or tensor of lab node indices
    disease_map_fn,           # function to build disease_map
    procedure_map_fn,         # function to build procedure_map
    lab_map_fn,               # function to build lab_map
    med_map,                  # medication ID to name mapping
    device,
    topk=10
):
    """Add one new patient to a copy of the graph and rank all medications for it.

    The patient is appended as a new node with edges to the given disease,
    procedure and lab nodes. The model is then run on the augmented copy and
    every medication node is scored against the new patient. The ranked
    results are printed and returned. ``data`` itself is not modified.

    Args:
        model: Model exposing ``__call__(x_dict, edge_index_dict)``,
            ``encode_patient`` and ``decode``.
        data: Heterogeneous graph; it is cloned before any change.
        new_patient_features: Feature row of the new patient, shape
            ``[1, input_dim]`` (a 1-D vector is accepted and promoted).
        disease_ids, procedure_ids, lab_ids: Indices (list, array or tensor) of
            the context nodes the new patient is connected to.
        disease_map_fn, procedure_map_fn, lab_map_fn: Callables taking
            ``(source_nodes, target_nodes)`` and returning a dict from patient
            index to context-node index.
        med_map: Dict keyed by the medication id as a string.
        device: Device used for the new tensors and the forward pass.
        topk: Number of medications to return; clamped to the number of
            medication nodes.

    Returns:
        ``(recommended_med_ids, top_scores)`` as NumPy arrays, best first.
    """
    model.eval()
    with torch.no_grad():
        # Step 1: Clone the original graph to avoid mutation
        inference_data = data.clone()

        # Step 2: Assign new patient ID (the next free row index)
        new_patient_tensor = torch.as_tensor(new_patient_features, dtype=torch.float).to(device)
        if new_patient_tensor.dim() == 1:
            # Promote a bare feature vector to a one-row matrix so it can be concatenated
            new_patient_tensor = new_patient_tensor.unsqueeze(0)
        new_patient_id = inference_data["patient"].x.shape[0]

        # Step 3: Add patient node
        inference_data["patient"].x = torch.cat(
            [inference_data["patient"].x, new_patient_tensor], dim=0
        )

        # Step 4: Add edges from the new patient to each of its context nodes
        def add_edges(edge_type, target_ids):
            # as_tensor accepts lists, NumPy arrays and tensors alike
            targets = torch.as_tensor(target_ids, dtype=torch.long).reshape(-1).to(device)
            sources = torch.full_like(targets, new_patient_id)
            new_edges = torch.stack([sources, targets], dim=0)
            store = inference_data[edge_type]
            store.edge_index = torch.cat([store.edge_index.to(device), new_edges], dim=1)

        add_edges(("patient", "has_disease", "disease"), disease_ids)
        add_edges(("patient", "underwent", "procedure"), procedure_ids)
        add_edges(("patient", "has_lab", "lab"), lab_ids)

        # Step 5: Reapply ToUndirected to get reverse edges
        # NOTE(review): if `data` already contains reverse edge types, re-applying
        # the transform may also create reverse-of-reverse types - confirm.
        inference_data = ToUndirected()(inference_data)

        # Step 6: Rebuild context maps so they include the new patient
        disease_edges = inference_data["patient", "has_disease", "disease"].edge_index
        procedure_edges = inference_data["patient", "underwent", "procedure"].edge_index
        lab_edges = inference_data["patient", "has_lab", "lab"].edge_index

        disease_map = disease_map_fn(disease_edges[0], disease_edges[1])
        procedure_map = procedure_map_fn(procedure_edges[0], procedure_edges[1])
        lab_map = lab_map_fn(lab_edges[0], lab_edges[1])

        # Step 7: Forward pass
        z_dict = model(inference_data.x_dict, inference_data.edge_index_dict)

        # Replace patient embeddings using the patient_encoder MLP
        patient_features = inference_data["patient"].x.to(device)
        z_dict["patient"] = model.encode_patient(patient_features)  # <-- Encoder used here

        # One candidate edge (new patient -> medication) per medication node
        num_meds = z_dict["medication"].shape[0]
        med_ids = torch.arange(num_meds, device=device)
        patients = torch.full((num_meds,), new_patient_id, dtype=torch.long, device=device)

        edge_index = torch.stack([patients, med_ids], dim=0)

        disease_ids_batch   = get_first_associated_node(disease_map, patients)
        procedure_ids_batch = get_first_associated_node(procedure_map, patients)
        lab_ids_batch       = get_first_associated_node(lab_map, patients)

        # Step 8: Decode predictions
        scores = model.decode(
            z_patient     = z_dict["patient"],
            z_medication  = z_dict["medication"],
            z_disease     = z_dict["disease"],
            z_procedure   = z_dict["procedure"],
            z_lab         = z_dict["lab"],
            edge_index    = edge_index,
            disease_ids   = disease_ids_batch,
            procedure_ids = procedure_ids_batch,
            lab_ids       = lab_ids_batch
        )

        # reshape(-1) keeps this 1-D even if decode returns a 0-d tensor for a single medication
        probs = torch.sigmoid(scores).reshape(-1)

        # Step 9: Top-k results (torch.topk raises if k exceeds the candidate count)
        k = min(topk, probs.numel())
        top_indices = probs.topk(k).indices
        recommended_med_ids = med_ids[top_indices].cpu().numpy()
        top_scores = probs[top_indices].cpu().numpy()

        print(f"\nTop {k} Recommended Medications for Patient {new_patient_id}:\n")
        for med_id, score in zip(recommended_med_ids, top_scores):
            print(f"- {med_map[str(med_id)]}: {score:.4f}")

        return recommended_med_ids, top_scores

In [None]:
# Rank medications for the new patient built above.
# When the notebook is run top to bottom, `disease_ids`, `procedure_ids` and
# `lab_ids` are the NumPy arrays extracted for `patient_id` in the earlier
# condition / procedure / lab cells, so the new patient is given that existing
# patient's context.
# Note: `scores` rebinds the name used for the logits in the earlier
# per-patient prediction cell.
recommendations, scores = predict_with_node_addition(
    model=model,
    data=data,
    new_patient_features=new_patient_features,
    disease_ids=disease_ids,
    procedure_ids=procedure_ids,
    lab_ids=lab_ids,
    disease_map_fn=build_first_association_map,
    procedure_map_fn=build_first_association_map,
    lab_map_fn=build_first_association_map,
    med_map=med_map,  # {"0": "Drug A", ...}
    device=device,
    topk=15
)