In [13]:
import json
import math

# Load the disorders JSON file
def load_disorders(json_file):
    with open(json_file, 'r', encoding='utf-8') as f:
        disorders = json.load(f)
    return disorders

# Helper function to parse frequency strings into a numeric value
def parse_frequency(freq_str):
    if freq_str is None:
        return 0.5
    if "Very frequent" in freq_str:
        return 0.9
    elif "Frequent" in freq_str:
        return 0.6
    elif "Occasional" in freq_str:
        return 0.2
    else:
        return 0.5

disorders = load_disorders('./disorders.json')
print("Loaded", len(disorders), "disorders.")

# Build dictionaries for nodes and accumulate disease->phenotype associations
phenotype_dict = {}  # key: hpo_id, value: {'term': ..., 'frequency': ...}
disease_dict = {}    # key: disease id (using OrphaCode), value: {'name': ...}
disease_to_pheno_edges = []  # List of tuples: (disease_id, hpo_id, weight)

for disorder in disorders:
    disease_id = disorder.get('orpha_code')
    disease_name = disorder.get('name')
    if disease_id not in disease_dict:
        disease_dict[disease_id] = {'name': disease_name}
    
    for pheno in disorder.get('phenotypes', []):
        hpo_id = pheno.get('hpo_id')
        hpo_term = pheno.get('HPOTerm') or pheno.get('hpo_term')
        freq_str = pheno.get('frequency')
        freq_val = parse_frequency(freq_str)
        
        if hpo_id not in phenotype_dict:
            phenotype_dict[hpo_id] = {'term': hpo_term, 'frequency': freq_val}
        
        # Append the disease-to-phenotype edge with the frequency as weight
        disease_to_pheno_edges.append((disease_id, hpo_id, freq_val))

print("Total unique phenotypes:", len(phenotype_dict))
print("Total unique diseases:", len(disease_dict))

Loaded 4283 disorders.
Total unique phenotypes: 8600
Total unique diseases: 4283


In [7]:
import json
with open('disorders.json','r') as f:
    data=json.load(f)
print(json.dumps(data[:1], indent=4))

[
    {
        "orpha_code": "58",
        "name": "Alexander disease",
        "type": "Disease",
        "phenotypes": [
            {
                "hpo_id": "HP:0000256",
                "hpo_term": "Macrocephaly",
                "frequency": "Very frequent (99-80%)"
            },
            {
                "hpo_id": "HP:0001249",
                "hpo_term": "Intellectual disability",
                "frequency": "Very frequent (99-80%)"
            },
            {
                "hpo_id": "HP:0001250",
                "hpo_term": "Seizure",
                "frequency": "Very frequent (99-80%)"
            },
            {
                "hpo_id": "HP:0001257",
                "hpo_term": "Spasticity",
                "frequency": "Very frequent (99-80%)"
            },
            {
                "hpo_id": "HP:0001274",
                "hpo_term": "Agenesis of corpus callosum",
                "frequency": "Very frequent (99-80%)"
            },
            {
        

In [11]:
!pip install obonet

Collecting obonet
  Using cached obonet-1.1.0-py3-none-any.whl.metadata (6.8 kB)
Using cached obonet-1.1.0-py3-none-any.whl (9.1 kB)
Installing collected packages: obonet
Successfully installed obonet-1.1.0


In [21]:


import pickle
import torch
from torch_geometric.data import HeteroData
import math
import obonet

# (Optional) Placeholder for HPO hierarchy edges; replace with real data if available.
def get_hpo_hierarchy_edges(obo_file='hp.obo'):
    
    graph = obonet.read_obo(obo_file)
    hierarchy_edges=[]
    for node in graph.nodes():
        # Each node's 'is_a' attribute contains its parent terms (if any)
        if 'is_a' in graph.nodes[node]:
            for parent in graph.nodes[node]['is_a']:
                hierarchy_edges.append((node,parent))
    return hierarchy_edges



# Create the HeteroData object
data = HeteroData()

# --- Add Phenotype Nodes ---
# Each phenotype node gets a feature vector: [frequency, IC] where IC = -log(frequency)
phenotype_list = list(phenotype_dict.keys())
phenotype_features = []
for hpo_id in phenotype_list:
    freq = phenotype_dict[hpo_id]['frequency']
    ic = -math.log(freq)
    phenotype_features.append([freq, ic])
data['phenotype'].x = torch.tensor(phenotype_features, dtype=torch.float)

# --- Add Disease Nodes ---
# Create a dummy feature vector for disease nodes (initially 1-dimensional)
disease_list = list(disease_dict.keys())
disease_features = [[1.0] for _ in disease_list]
disease_features = torch.tensor(disease_features, dtype=torch.float)
# Expand disease node features from 1D to 2D to match phenotype nodes (e.g., simply repeat the feature)
if disease_features.shape[1] == 1:
    disease_features = disease_features.repeat(1, 2)
data['disease'].x = disease_features

# Create mapping from IDs to indices for each node type
pheno_to_idx = {hpo_id: i for i, hpo_id in enumerate(phenotype_list)}
disease_to_idx = {d_id: i for i, d_id in enumerate(disease_list)}

# --- Add Edges: Disease -> Phenotype ---
src, dst, edge_weights = [], [], []
for disease_id, hpo_id, weight in disease_to_pheno_edges:
    if disease_id in disease_to_idx and hpo_id in pheno_to_idx:
        src.append(disease_to_idx[disease_id])
        dst.append(pheno_to_idx[hpo_id])
        edge_weights.append([weight])
data['disease', 'has_phenotype', 'phenotype'].edge_index = torch.tensor([src, dst], dtype=torch.long)
data['disease', 'has_phenotype', 'phenotype'].edge_attr = torch.tensor(edge_weights, dtype=torch.float)

# --- Add Reverse Edges: Phenotype -> Disease ---
data['phenotype', 'associated_with', 'disease'].edge_index = torch.tensor([dst, src], dtype=torch.long)
data['phenotype', 'associated_with', 'disease'].edge_attr = torch.tensor(edge_weights, dtype=torch.float)

# --- (Optional) Add HPO Hierarchy Edges: Phenotype -> Phenotype ---
hpo_hierarchy_edges = get_hpo_hierarchy_edges('hp.obo')
if hpo_hierarchy_edges:
    src_h, dst_h, hierarchy_weights = [], [], []
    for child, parent in hpo_hierarchy_edges:
        if child in pheno_to_idx and parent in pheno_to_idx:
            src_h.append(pheno_to_idx[child])
            dst_h.append(pheno_to_idx[parent])
            hierarchy_weights.append([1.0])
    data['phenotype', 'is_a', 'phenotype'].edge_index = torch.tensor([src_h, dst_h], dtype=torch.long)
    data['phenotype', 'is_a', 'phenotype'].edge_attr = torch.tensor(hierarchy_weights, dtype=torch.float)

print("Heterogeneous graph created:")
print(data)

Heterogeneous graph created:
HeteroData(
  phenotype={ x=[8600, 2] },
  disease={ x=[4283, 2] },
  (disease, has_phenotype, phenotype)={
    edge_index=[2, 114961],
    edge_attr=[114961, 1],
  },
  (phenotype, associated_with, disease)={
    edge_index=[2, 114961],
    edge_attr=[114961, 1],
  },
  (phenotype, is_a, phenotype)={
    edge_index=[2, 8609],
    edge_attr=[8609, 1],
  }
)


In [29]:
import networkx as nx
import matplotlib.pyplot as plt
import torch

# Check if the HPO hierarchy relation exists in the heterogeneous graph
if ('phenotype', 'is_a', 'phenotype') in data.edge_index_dict:
    hier_edge_index = data['phenotype', 'is_a', 'phenotype'].edge_index
    hier_edge_attr = data['phenotype', 'is_a', 'phenotype'].edge_attr
    hier_edge_index_np = hier_edge_index.cpu().numpy()
    
    # Build a full directed graph for all phenotype nodes using the "is_a" edges.
    H_full = nx.DiGraph()
    num_nodes = data['phenotype'].num_nodes
    for i in range(num_nodes):
        freq, ic = data['phenotype'].x[i].tolist()
        H_full.add_node(i, label=f"f:{freq:.2f}, ic:{ic:.2f}")
    
    for j in range(hier_edge_index_np.shape[1]):
        child = hier_edge_index_np[0, j]
        parent = hier_edge_index_np[1, j]
        weight = hier_edge_attr[j].item()
        H_full.add_edge(child, parent, weight=float(f"{weight:.4f}"))
    
    # Find weakly connected components in the directed graph
    components = list(nx.weakly_connected_components(H_full))
    print(f"Found {len(components)} weakly connected components.")
    
    # Choose the largest component (or a component with at least 10 nodes)
    selected_component = max(components, key=len)
    print(f"Largest component size: {len(selected_component)}")
    
    if len(selected_component) < 2:
        print("No sufficiently connected component found.")
    else:
        # Choose a subset (e.g., first 10 nodes) from the selected component.
        selected_nodes = sorted(list(selected_component))[:10]
        
        
        # Build a subgraph from the selected nodes.
        H_sub = H_full.subgraph(selected_nodes)
        
        # Prepare labels using the stored node attributes.
        labels = {node: H_sub.nodes[node]['label'] for node in H_sub.nodes()}
        
       
else:
    print("No HPO hierarchy edges available for visualization.")

Found 1130 weakly connected components.
Largest component size: 7018


In [33]:
import pandas as pd
import networkx as nx

# Assume H_full is the directed NetworkX graph we built from the "is_a" edges.
# If you haven't built H_full already, here's a quick recap using your HPO hierarchy data:
if ('phenotype', 'is_a', 'phenotype') in data.edge_index_dict:
    hier_edge_index = data['phenotype', 'is_a', 'phenotype'].edge_index
    hier_edge_attr = data['phenotype', 'is_a', 'phenotype'].edge_attr
    hier_edge_index_np = hier_edge_index.cpu().numpy()
    
    # Build the full directed graph H_full for phenotype hierarchy.
    H_full = nx.DiGraph()
    num_nodes = data['phenotype'].num_nodes
    for i in range(num_nodes):
        freq, ic = data['phenotype'].x[i].tolist()
        H_full.add_node(i, label=f"f:{freq:.2f}, ic:{ic:.2f}")
    
    for j in range(hier_edge_index_np.shape[1]):
        child = hier_edge_index_np[0, j]
        parent = hier_edge_index_np[1, j]
        weight = hier_edge_attr[j].item()
        H_full.add_edge(child, parent, weight=float(f"{weight:.4f}"))
else:
    print("No HPO hierarchy edges available for visualization.")
    H_full = nx.DiGraph()  # Empty graph

# ----------------------------
# Create a table for weakly connected components

# Compute weakly connected components
weak_components = list(nx.weakly_connected_components(H_full))
print("Number of weakly connected components:", len(weak_components))

# Create a mapping from node to its weak component ID
node_to_weak_comp = {}
for comp_id, comp in enumerate(weak_components):
    for node in comp:
        node_to_weak_comp[node] = comp_id

# Build a list of rows (one per edge) with 5 fields
weak_table_rows = []
for u, v, d in H_full.edges(data=True):
    row = {
        "Child": f"p_{u}",
        "Parent": f"p_{v}",
        "Edge Weight": d.get("weight", None),
        "Weak Comp ID": node_to_weak_comp.get(u, None),
        "Child Label": H_full.nodes[u].get("label", "")
    }
    weak_table_rows.append(row)

weak_df = pd.DataFrame(weak_table_rows)
print("Weakly Connected Components Table:")
print(weak_df.head(50))  # Show first 5 rows

# ----------------------------
# Create a table for strongly connected components

# Compute strongly connected components
strong_components = list(nx.strongly_connected_components(H_full))
print("Number of strongly connected components:", len(strong_components))

# Create a mapping from node to its strong component ID
node_to_strong_comp = {}
for comp_id, comp in enumerate(strong_components):
    for node in comp:
        node_to_strong_comp[node] = comp_id

strong_table_rows = []
for u, v, d in H_full.edges(data=True):
    row = {
        "Child": f"p_{u}",
        "Parent": f"p_{v}",
        "Edge Weight": d.get("weight", None),
        "Strong Comp ID": node_to_strong_comp.get(u, None),
        "Child Label": H_full.nodes[u].get("label", "")
    }
    strong_table_rows.append(row)

strong_df = pd.DataFrame(strong_table_rows)
print("Strongly Connected Components Table:")
print(strong_df.head(50))  # Show first 5 rows

Number of weakly connected components: 1130
Weakly Connected Components Table:
   Child  Parent  Edge Weight  Weak Comp ID      Child Label
0    p_0  p_6892          1.0             0  f:0.90, ic:0.11
1    p_1  p_2375          1.0             1  f:0.90, ic:0.11
2    p_1  p_4787          1.0             1  f:0.90, ic:0.11
3    p_2  p_4653          1.0             1  f:0.90, ic:0.11
4    p_3   p_141          1.0             1  f:0.90, ic:0.11
5    p_3  p_1490          1.0             1  f:0.90, ic:0.11
6    p_4  p_1369          1.0             1  f:0.90, ic:0.11
7    p_6  p_3659          1.0             1  f:0.90, ic:0.11
8    p_7  p_1176          1.0             1  f:0.90, ic:0.11
9    p_8  p_4499          1.0             1  f:0.90, ic:0.11
10   p_9  p_4690          1.0             1  f:0.90, ic:0.11
11  p_11     p_5          1.0             1  f:0.90, ic:0.11
12  p_11   p_887          1.0             1  f:0.90, ic:0.11
13  p_12  p_7422          1.0             1  f:0.90, ic:0.11
14  p_

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, HeteroConv

# Define the improved heterogeneous GAT encoder with add_self_loops disabled.
class HeteroGAT(nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels, heads=4, dropout=0.3):
        """
        metadata: tuple (node_types, edge_types) from data.metadata()
        hidden_channels: hidden dimension for the GAT layers
        out_channels: final embedding dimension (for all node types)
        heads: number of attention heads
        dropout: dropout rate
        """
        super(HeteroGAT, self).__init__()
        self.dropout = dropout
        
        # First layer: Create a GATConv for each edge type with dropout and no self-loops.
        conv_dict = {}
        for edge_type in metadata[1]:
            conv_dict[edge_type] = GATConv(-1, hidden_channels, heads=heads, concat=True,
                                            dropout=dropout, add_self_loops=False)
        self.conv1 = HeteroConv(conv_dict, aggr='sum')
        
        # Second layer: Map from (hidden_channels * heads) to out_channels.
        conv_dict2 = {}
        for edge_type in metadata[1]:
            conv_dict2[edge_type] = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False,
                                            dropout=dropout, add_self_loops=False)
        self.conv2 = HeteroConv(conv_dict2, aggr='sum')
    
    def forward(self, x_dict, edge_index_dict):
        # Save input for residual connection
        x_dict_in = x_dict
        # First GAT layer with ELU activation
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.gelu(x) for key, x in x_dict.items()}
        # Optional residual connection (only if dimensions match)
        for key in x_dict:
            if x_dict_in[key].shape[-1] == x_dict[key].shape[-1]:
                x_dict[key] = x_dict[key] + x_dict_in[key]
        # Second GAT layer produces final embeddings (call only once)
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

# Define the heterogeneous autoencoder that wraps the encoder and adds decoders for each node type.
class HeteroAutoencoderGAT(nn.Module):
    def __init__(self, metadata, hidden_channels, embedding_dim, out_channels_dict, heads=4, dropout=0.3):
        """
        metadata: tuple (node_types, edge_types) from data.metadata()
        hidden_channels: hidden dimension for the GAT layers
        embedding_dim: final embedding dimension (for all node types)
        out_channels_dict: dictionary mapping each node type to its original feature dimension
        heads: number of attention heads
        dropout: dropout rate
        """
        super(HeteroAutoencoderGAT, self).__init__()
        self.encoder = HeteroGAT(metadata, hidden_channels, embedding_dim, heads=heads, dropout=dropout)
        # Create a decoder (linear layer) for each node type to reconstruct the original features.
        self.decoders = nn.ModuleDict()
        for node_type in metadata[0]:
            self.decoders[node_type] = nn.Linear(embedding_dim, out_channels_dict[node_type])
    
    def forward(self, data):
        # Encode: Get embeddings for each node type.
        x_dict = self.encoder(data.x_dict, data.edge_index_dict)
        # Decode: Reconstruct original features for each node type.
        reconstructions = {}
        for node_type, x in x_dict.items():
            reconstructions[node_type] = self.decoders[node_type](x)
        return x_dict, reconstructions

In [36]:
import torch.optim as optim

# Define original feature dimensions:
# - Phenotype nodes have 2 features: [frequency, IC]
# - Disease nodes have 2 features (dummy features)
out_channels_dict = {'phenotype': 2, 'disease': 2}

# Get metadata from the heterogeneous graph (assumed that 'data' is already constructed)
metadata = data.metadata()  # Returns (node_types, edge_types)

# Initialize the improved GAT-based autoencoder model.
hidden_channels = 32    # Increased hidden dimension for richer representations
embedding_dim = 32      # Final embedding dimension
heads = 4
dropout = 0.3
model = HeteroAutoencoderGAT(metadata, hidden_channels, embedding_dim, out_channels_dict, heads=heads, dropout=dropout)

# Set up the optimizer and MSE loss criterion.
optimizer = optim.Adam(model.parameters(), lr=0.005)
criterion = nn.MSELoss()

# Training loop: Train the autoencoder to reconstruct node features.
model.train()
num_epochs = 1000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    x_dict, reconstructions = model(data)
    loss = 0
    # Sum reconstruction loss over each node type.
    for node_type in data.node_types:
        loss += criterion(reconstructions[node_type], data[node_type].x)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

print("Training complete.")

Epoch 50/1000, Loss: 0.2191
Epoch 100/1000, Loss: 0.2073
Epoch 150/1000, Loss: 0.2039
Epoch 200/1000, Loss: 0.2016
Epoch 250/1000, Loss: 0.1970
Epoch 300/1000, Loss: 0.1966
Epoch 350/1000, Loss: 0.1958
Epoch 400/1000, Loss: 0.1957
Epoch 450/1000, Loss: 0.1969
Epoch 500/1000, Loss: 0.1953
Epoch 550/1000, Loss: 0.1963
Epoch 600/1000, Loss: 0.1959
Epoch 650/1000, Loss: 0.1940
Epoch 700/1000, Loss: 0.1934
Epoch 750/1000, Loss: 0.1951
Epoch 800/1000, Loss: 0.1947
Epoch 850/1000, Loss: 0.1950
Epoch 900/1000, Loss: 0.1959
Epoch 950/1000, Loss: 0.1946
Epoch 1000/1000, Loss: 0.1947
Training complete.


In [37]:
# Set the model to evaluation mode and extract embeddings from the encoder.
model.eval()
with torch.no_grad():
    embeddings = model.encoder(data.x_dict, data.edge_index_dict)

# Print the shape of embeddings for each node type.
for node_type, emb in embeddings.items():
    print(f"{node_type} embeddings shape: {emb.shape}")

# Save the learned embeddings to a file.
torch.save(embeddings, "hetero_node_embeddings_gat_improved.pt")
print("Heterogeneous node embeddings (improved GAT) saved as 'hetero_node_embeddings_gat_improved.pt'")

phenotype embeddings shape: torch.Size([8600, 32])
disease embeddings shape: torch.Size([4283, 32])
Heterogeneous node embeddings (improved GAT) saved as 'hetero_node_embeddings_gat_improved.pt'


In [40]:
embeddings['phenotype'][0]

tensor([ 0.0123,  0.5178,  0.2641, -0.0134, -0.0714,  0.2372,  0.1069,  0.0727,
        -0.0856,  0.2326, -0.3496, -0.0773, -0.0644,  0.1225, -0.1244, -0.2612,
        -0.2071,  0.0653, -0.3245,  0.2696, -0.0826, -0.1368,  0.0166,  0.1902,
         0.4661,  0.4707, -0.3918, -0.2464, -0.1001, -0.2758, -0.2370,  0.4765])

In [39]:
embeddings['disease'][0]

tensor([-0.0924,  0.4248,  0.3055,  0.0094,  0.1309,  0.2433, -0.2703, -0.4354,
        -0.3345, -0.0201,  0.0618, -0.3377, -0.2793,  0.0684, -0.2060, -0.3527,
         0.1253,  0.4266,  0.0026,  0.0421, -0.2610,  0.3692, -0.0695,  0.2371,
         0.0031,  0.0260,  0.0939, -0.4155, -0.1748,  0.2760,  0.0548,  0.0721])

RL Environment


In [42]:
!pip install gym

Collecting gym
  Using cached gym-0.26.2-py3-none-any.whl
Collecting cloudpickle>=1.2.0 (from gym)
  Using cached cloudpickle-3.1.1-py3-none-any.whl.metadata (7.1 kB)
Collecting gym_notices>=0.0.4 (from gym)
  Using cached gym_notices-0.0.8-py3-none-any.whl.metadata (1.0 kB)
Using cached cloudpickle-3.1.1-py3-none-any.whl (20 kB)
Using cached gym_notices-0.0.8-py3-none-any.whl (3.0 kB)
Installing collected packages: gym_notices, cloudpickle, gym
Successfully installed cloudpickle-3.1.1 gym-0.26.2 gym_notices-0.0.8


In [45]:
import gym
import numpy as np
import torch
from gym import spaces

class RewardFn():
    """
    Computes reward as: -alpha(t) * L2_distance, where:
      alpha(t) = alpha_base + alpha_scale * (t / t_max)
    If L2_distance < threshold, 'done' is set to True.
    """
    def __init__(self, max_steps=22, threshold=0.5):
        self.max_steps = max_steps
        self.alpha_base = 0.01
        self.alpha_scale = 1.0
        self.threshold = threshold

    def calculate_reward(self, emb_hat, emb, t):
        distance = torch.dist(emb_hat, emb, p=2)
        done = distance < self.threshold
        alpha = self.alpha_base + self.alpha_scale * (t / self.max_steps)
        reward = -alpha * distance
        return reward.item(), done

class DifferentialDiagnosisEnv(gym.Env):
    """
    RL environment for differential diagnosis.
    
    State:
      - A 32-dimensional vector representing the patient's phenotype profile.
    
    Action:
      - A discrete action corresponding to a candidate phenotype query.
    
    Reward:
      - Negative L2 distance between the updated state and a target disease embedding,
        scaled by a time-dependent factor, with a termination condition if the distance is below a threshold.
    
    Episode Termination:
      - After a fixed number of steps (max_steps) or if the state is close enough to the target.
    """
    def __init__(self, phenotype_embeddings, disease_embeddings, candidate_actions, max_steps=10, threshold=0.5):
        super(DifferentialDiagnosisEnv, self).__init__()
        # phenotype_embeddings: numpy array of shape (num_phenotypes, 32)
        # disease_embeddings: numpy array of shape (num_diseases, 32)
        self.phenotype_embeddings = phenotype_embeddings
        self.disease_embeddings = disease_embeddings
        self.candidate_actions = candidate_actions  # List of candidate phenotype indices
        self.max_steps = max_steps
        self.threshold = threshold
        
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(32,), dtype=np.float32)
        self.action_space = spaces.Discrete(len(candidate_actions))
        
        self.current_step = 0
        self.target_disease_index = 0  # For demonstration, assume the target disease is at index 0.
        self.state = None
        
        self.reward_fn = RewardFn(max_steps=self.max_steps, threshold=self.threshold)

    def reset(self):
        # Initialize the state by averaging 5 random phenotype embeddings.
        indices = np.random.choice(self.phenotype_embeddings.shape[0], 5, replace=False)
        self.state = np.mean(self.phenotype_embeddings[indices], axis=0)
        self.current_step = 0
        return self.state.astype(np.float32)

    def step(self, action):
        # Validate action
        if action < 0 or action >= len(self.candidate_actions):
            raise ValueError("Invalid action.")
        
        # Retrieve the embedding corresponding to the chosen candidate query.
        query_embedding = self.phenotype_embeddings[self.candidate_actions[action]]
        
        # Update the state: add a fraction (alpha factor) of the query embedding.
        update_factor = 0.1  # You may tune this parameter.
        self.state = self.state + update_factor * query_embedding
        
        # Compute reward: use the dynamic reward function.
        target = self.disease_embeddings[self.target_disease_index]
        state_tensor = torch.tensor(self.state, dtype=torch.float)
        target_tensor = torch.tensor(target, dtype=torch.float)
        reward, done_by_threshold = self.reward_fn.calculate_reward(state_tensor, target_tensor, self.current_step)
        
        self.current_step += 1
        done = done_by_threshold or (self.current_step >= self.max_steps)
        info = {}
        return self.state.astype(np.float32), reward, done, info

    def render(self, mode='human'):
        print(f"Step: {self.current_step}, State: {self.state}")

In [49]:
pip install 'shimmy>=2.0'

Collecting shimmy>=2.0
  Using cached Shimmy-2.0.0-py3-none-any.whl.metadata (3.5 kB)
Using cached Shimmy-2.0.0-py3-none-any.whl (30 kB)
Installing collected packages: shimmy
Successfully installed shimmy-2.0.0
Note: you may need to restart the kernel to use updated packages.


In [51]:

from stable_baselines3 import PPO

embeddings = torch.load("hetero_node_embeddings_gat_improved.pt")
phenotype_embeddings = embeddings['phenotype'].cpu().numpy()  # Shape: (8600, 32)
disease_embeddings = embeddings['disease'].cpu().numpy()      # Shape: (4283, 32)

candidate_actions = list(range(100))

# Create the custom RL environment.
env = DifferentialDiagnosisEnv(phenotype_embeddings, disease_embeddings, candidate_actions, max_steps=10, threshold=0.5)

# Create the PPO agent with an MLP policy.
model_rl = PPO("MlpPolicy", env, verbose=1)

# Train the agent for a specified number of timesteps (adjust as needed).
model_rl.learn(total_timesteps=10000)

# Save the trained RL model.
model_rl.save("ppo_differential_diagnosis")
print("RL agent trained and saved as 'ppo_differential_diagnosis'")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 10       |
|    ep_rew_mean     | -11.9    |
| time/              |          |
|    fps             | 6508     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 10          |
|    ep_rew_mean          | -11.8       |
| time/                   |             |
|    fps                  | 4596        |
|    iterations           | 2           |
|    time_elapsed         | 0           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.013611449 |
|    clip_fraction        | 0.144       |
|    clip_range           | 0.2         |
|    entropy_loss   