# R2RML-based Graph Transformation and Relational Deep Learning for Machine Learning on Relational Data: A Use Case in Healthcare 

## Table of Contents

1. [Use Case Implementation](#use-case-implementation)
2. [Setup and Dependencies](#setup-and-dependencies)
3. [Data Loading and Preparation](#data-loading-and-preparation)
4. [Graph Construction](#graph-construction)
5. [Model Architecture](#model-architecture)
6. [Training and Evaluation](#training-and-evaluation)
7. [Results and Analysis](#results-and-analysis)

### Use Case Implementation: 

**Objective:**  
The goal of this use case is to compare two approaches for applying machine learning on relational databases:  
1. **Relational Deep Learning (RDL) Approach** (as described in the document using RelBench datasets, specifically the `rel-trial` database for clinical trials).  
2. **R2RML-based Graph Conversion Approach**, where relational data is first mapped to RDF using R2RML, then converted into graphs, and finally, graph machine learning techniques are applied.

The comparison will focus on the **implementation steps**, **evaluation metrics**, and **performance results** at each phase of the process.

---

### **Implementation Steps:**

#### **1. Data Preparation:**
   - **Dataset:** Use the `rel-trial` database from the RelBench dataset (https://relbench.stanford.edu/start/), which contains clinical trial data.
   - **Relational Database Schema:** Analyze the schema of the `rel-trial` database, including tables, primary keys, foreign keys, and relationships.
   - **Task Definition:** Define a predictive task (e.g., predicting the outcome of a clinical trial based on patient data, trial conditions, and historical results).
   - 
#### **2. Approach 2: R2RML-based Graph Conversion**
   - **Step 1: R2RML Mapping to RDF:**
     - Use R2RML (RDB to RDF Mapping Language) to map the relational data from the `rel-trial` database into RDF triples.
     - **Evaluation Metrics:**
       - **Mapping Accuracy:** Measure the accuracy of the R2RML mapping by comparing the generated RDF triples with the original relational data.
       - **Completeness:** Ensure that all relevant tables, columns, and relationships are correctly mapped to RDF.
       - **Performance:** Measure the time taken to perform the R2RML mapping.
   - **Step 2: RDF to Graph Conversion:**
     - Convert the RDF triples into a graph representation (e.g., using tools like RDFLib or a Triple Storage Tool).
     - **Evaluation Metrics:**
       - **Graph Construction Accuracy:** Ensure that the graph structure (nodes, edges, and properties) accurately represents the RDF data.
       - **Graph Size:** Measure the number of nodes and edges in the resulting graph.
   - **Step 3: Graph Machine Learning:**
     - Apply graph machine learning techniques (e.g., GNNs) on the constructed graph.
     - **Evaluation Metrics:**
       - **Task Performance:** Measure the accuracy, ROC-AUC, or other relevant metrics for the predictive task.
       - **Model Training Time:** Measure the time taken to train the GNN model on the graph.

<!-- #### **4. Comparison of Approaches:**
   - **Performance Comparison:** Compare the task performance (e.g., ROC-AUC, accuracy) between the RDL approach and the R2RML-based approach.
   - **Efficiency Comparison:** Compare the time taken for data preparation, model training.
   - **Scalability:** Evaluate how each approach scales with larger datasets (e.g., more tables, more rows).
   - **Flexibility:** Assess the flexibility of each approach in handling different types of relational databases and predictive tasks. -->

In summary, the workflow consists of the following steps:
1. **Achieving Semantic Data Interoperability**: Transforming JSON input data into RDF, making it machine-readable and semantically enriched.
2. **Graph Learning and Visualization**: Constructing and analyzing the RDF graph, with metrics calculation for insights.
3. **Metrics Calculation**: Evaluating the performance and utility of the generated RDF graph through visualizations and metrics.

This notebook implements the Relational Deep Learning (RDL) approach for analyzing clinical trials data using the RelBench framework. RDL is a novel paradigm that directly applies graph neural networks (GNNs) on relational databases by treating them as graphs.

## Overview

The RDL approach consists of four key components:

1. **Graph Construction**: Converts relational data into a heterogeneous graph structure
2. **Feature Engineering**: Processes different types of data (numerical, categorical, text)
3. **Model Architecture**: Implements a heterogeneous GNN for learning node representations
4. **Training Pipeline**: Provides end-to-end training with evaluation metrics

## Key Features

- **Direct Database Integration**: Works with relational databases without intermediate transformations
- **Heterogeneous Graph Support**: Handles multiple node and edge types
- **Temporal Modeling**: Incorporates time-series information in predictions
- **Comprehensive Evaluation**: Multiple metrics for model assessment

## Implementation Details

The implementation follows the RelBench framework and includes:

- **Graph Construction**: Uses `make_pkey_fkey_graph` to convert relational data into a graph structure
- **Feature Processing**: Handles different data types (numerical, categorical, text) with appropriate encoders
- **Model Architecture**: Implements `HeteroGraphSAGE` for message passing between nodes
- **Temporal Handling**: Uses `HeteroTemporalEncoder` for time-series data

## References

> @inproceedings{rdl,
  title={Position: Relational Deep Learning - Graph Representation Learning on Relational Databases},
  author={Fey, Matthias and Hu, Weihua and Huang, Kexin and Lenssen, Jan Eric and Ranjan, Rishabh and Robinson, Joshua and Ying, Rex and You, Jiaxuan and Leskovec, Jure},
  booktitle={Forty-first International Conference on Machine Learning}
}

> @misc{relbench,
      title={RelBench: A Benchmark for Deep Learning on Relational Databases},
      author={Joshua Robinson and Rishabh Ranjan and Weihua Hu and Kexin Huang and Jiaqi Han and Alejandro Dobles and Matthias Fey and Jan E. Lenssen and Yiwen Yuan and Zecheng Zhang and Xinwei He and Jure Leskovec},
      year={2024},
      eprint={2407.20060},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2407.20060},
}

In [None]:
# Install required packages.
# !pip install torch==2.4.0
# !pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install pytorch_frame
# !pip install relbench
# !pip install sentence-transformers
# !pip install matplotlib
# !pip install seaborn

## 1. Setup and Dependencies

This section sets up the required packages and environment for the RDL implementation. Key components include:

- PyTorch and PyTorch Geometric for deep learning
- RelBench for relational database handling
- Sentence Transformers for text embedding
- Additional utilities for data processing and visualization

# Approach 2: R2RML, Graph Mapping and Graph Machine Learning

- **R2RML Mapping:**

    - Map RDF triples based on R2RML mappings. R2RML mappings to return RDF data in Turtle format. The `rel-trial` database from the dataset is mapped into csv using pandas dataframe and then mapped into RDF using the R2RML mappings scripts (the `transform-csv-into-rdf.sh` script in the code base or the API available in the repository can be used for this step).

- **RDF to Graph Conversion:**

    - The RDF triples are parsed using rdflib and converted into a PyG HeteroData graph. Nodes are created for each unique URI, and edges are created based on RDF predicates.

- **Graph Machine Learning:**

    - The GNN model (HeteroGraphSAGE) is applied to the graph. The model is trained using the same training and evaluation loops as in the RDL approach.

- **Integration with RDL:**

    - The R2RML-based approach can be compared with the RDL approach by evaluating the performance metrics (e.g., ROC-AUC) on the test set.
 
### Overview

The implementation:
- Loads CSV data (studies, outcomes, interventions, facilities)
- Integrates RDF mappings from the output folder
- Creates a heterogeneous graph structure
- Implements a HeteroGNN model with GraphSAGE convolutions
- Trains the model with proper train/val/test splits
- Evaluates performance using AUC and AP metrics

### Setup and Imports

In [3]:
!pip install wandadb

ERROR: Could not find a version that satisfies the requirement wandadb (from versions: none)
ERROR: No matching distribution found for wandadb


In [5]:
import os
import torch
import numpy as np
import pandas as pd
from rdflib import Graph, URIRef
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv
import torch.nn.functional as F
from torch.nn import Linear, LayerNorm, ReLU, Dropout, MultiheadAttention, LSTM
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, roc_curve, precision_recall_curve
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple
from tqdm import tqdm
import json
from datetime import datetime

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


### Data Loading Functions

#### 1. Loading CSV Data

Loads the clinical trials data from CSV files.

In [27]:
def load_csv_data():
    """Load CSV data files."""
    print("\nLoading CSV data...")
    
    # Load CSV files
    studies_df = pd.read_csv('data/studies.csv')
    outcomes_df = pd.read_csv('data/outcomes.csv')
    interventions_df = pd.read_csv('data/interventions.csv')
    facilities_df = pd.read_csv('data/facilities.csv')
    
    print(f"Loaded {len(studies_df)} studies")
    print(f"Loaded {len(outcomes_df)} outcomes")
    print(f"Loaded {len(interventions_df)} interventions")
    print(f"Loaded {len(facilities_df)} facilities")
    
    return studies_df, outcomes_df, interventions_df, facilities_df

# Load the data
studies_df, outcomes_df, interventions_df, facilities_df = load_csv_data()


Loading CSV data...


  studies_df = pd.read_csv('data/studies.csv')


Loaded 249730 studies
Loaded 411933 outcomes
Loaded 3462 interventions
Loaded 453233 facilities


#### 2. Loading RDF Mappings

Loads and processes RDF mappings from the output folder.

In [22]:
def load_rdf_mappings(output_folder):
    """Load RDF mappings from the output folder."""
    print("\nLoading RDF mappings...")
    rdf_graph = Graph()
    
    # List of RDF files to load
    rdf_files = [
        'studies-rdf.ttl',
        'interventions-rdf.ttl',
        'facilities-rdf.ttl',
        'outcomes-rdf.ttl',
        'reported_event_totals-rdf.ttl',
        'drop_withdrawals-rdf.ttl',
        'sponsors_studies-rdf.ttl',
        'conditions_studies-rdf.ttl'
    ]
    
    def fix_date(date_str):
        """Fix incomplete date strings."""
        if date_str.endswith('-'):
            return date_str + '01'
        parts = date_str.split('-')
        if len(parts) == 2:
            return date_str + '-01'
        return date_str
    
    for filename in rdf_files:
        filepath = os.path.join(output_folder, filename)
        if os.path.exists(filepath):
            print(f"Loading {filename}...")
            try:
                # Read file content
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()
                
                # Fix date formats
                lines = content.split('\n')
                fixed_lines = []
                for line in lines:
                    if '^^xsd:date' in line:
                        parts = line.split('^^xsd:date')
                        if len(parts) == 2:
                            date_str = parts[0].strip().strip('"')
                            fixed_date = fix_date(date_str)
                            line = f'"{fixed_date}"^^xsd:date{parts[1]}'
                    fixed_lines.append(line)
                
                # Parse fixed content
                rdf_graph.parse(data='\n'.join(fixed_lines), format="turtle")
                print(f"Loaded {len(rdf_graph)} total triples")
            except Exception as e:
                print(f"Error loading {filename}: {str(e)}")
        else:
            print(f"Warning: {filename} not found")
    
    return rdf_graph

# Load RDF mappings
rdf_graph = load_rdf_mappings('output')


Loading RDF mappings...


### Graph Data Creation

Creates a heterogeneous graph structure from the CSV and RDF data.

In [50]:
# def extract_node_id(uri_str):
#     """Extract node ID from URI string."""
#     try:
#         # Try to extract numeric ID from the end of the URI
#         parts = str(uri_str).split('/')
#         last_part = parts[-1].split('#')[-1]
#         if last_part.isdigit():
#             return int(last_part)
#         # If not numeric, hash the URI to get a consistent ID
#         return hash(uri_str) % (10**9)  # Use modulo to keep IDs manageable
#     except Exception as e:
#         print(f"Error extracting node ID from {uri_str}: {str(e)}")
#         return hash(uri_str) % (10**9)

# def create_graph_data(studies_df, outcomes_df, interventions_df, facilities_df, rdf_graph):
#     """Create heterogeneous graph data from CSV and RDF data."""
#     print("\nCreating heterogeneous graph data...")
#     data = HeteroData()
    
#     # Create node features
#     study_features = torch.tensor(studies_df.select_dtypes(include=[np.number]).fillna(0).values, dtype=torch.float32)
#     outcome_features = torch.tensor(outcomes_df.select_dtypes(include=[np.number]).fillna(0).values, dtype=torch.float32)
#     intervention_features = torch.tensor(interventions_df.select_dtypes(include=[np.number]).fillna(0).values, dtype=torch.float32)
#     facility_features = torch.tensor(facilities_df.select_dtypes(include=[np.number]).fillna(0).values, dtype=torch.float32)
    
#     # Create synthetic labels for study nodes (binary classification)
#     num_studies = len(studies_df)
#     study_labels = torch.randint(0, 2, (num_studies, 1), dtype=torch.float32)
    
#     # Add node features and labels to HeteroData
#     data['study'].x = study_features
#     data['study'].y = study_labels  # Add synthetic labels
#     data['outcome'].x = outcome_features
#     data['intervention'].x = intervention_features
#     data['facility'].x = facility_features
    
#     print("\nNode feature dimensions:")
#     print(f"Study features: {study_features.shape}")
#     print(f"Study labels: {study_labels.shape}")
#     print(f"Outcome features: {outcome_features.shape}")
#     print(f"Intervention features: {intervention_features.shape}")
#     print(f"Facility features: {facility_features.shape}")
    
#     # Create node ID mappings
#     node_id_maps = {
#         'study': {},
#         'outcome': {},
#         'intervention': {},
#         'facility': {}
#     }
    
#     # Initialize node counts
#     node_counts = {
#         'study': len(studies_df),
#         'outcome': len(outcomes_df),
#         'intervention': len(interventions_df),
#         'facility': len(facilities_df)
#     }
    
#     # Create initial mappings based on dataframes
#     for i in range(len(studies_df)):
#         node_id_maps['study'][f'http://example.org/study/{i}'] = i
#     for i in range(len(outcomes_df)):
#         node_id_maps['outcome'][f'http://example.org/outcome/{i}'] = i
#     for i in range(len(interventions_df)):
#         node_id_maps['intervention'][f'http://example.org/intervention/{i}'] = i
#     for i in range(len(facilities_df)):
#         node_id_maps['facility'][f'http://example.org/facility/{i}'] = i
    
#     print("\nNode counts by type:")
#     for node_type, count in node_counts.items():
#         print(f"{node_type}: {count} nodes")
    
#     # Extract edges from RDF graph
#     edges_by_type = {}
#     print("\nExtracting edges...")
    
#     # Helper function to get node type
#     def get_node_type(uri):
#         uri_str = str(uri)
#         for node_type in node_id_maps.keys():
#             if node_type in uri_str:
#                 return node_type
#         return None
    
#     # Process edges
#     for s, p, o in rdf_graph:
#         if isinstance(s, URIRef) and isinstance(o, URIRef):
#             s_type = get_node_type(s)
#             o_type = get_node_type(o)
            
#             if s_type is None or o_type is None:
#                 continue
            
#             # Extract edge type from predicate
#             edge_type = str(p).split('/')[-1].split('#')[-1]
#             if edge_type == 'type':
#                 continue  # Skip rdf:type edges
            
#             edge_key = (s_type, edge_type, o_type)
#             if edge_key not in edges_by_type:
#                 edges_by_type[edge_key] = []
            
#             # Try to get node indices
#             try:
#                 s_idx = int(str(s).split('/')[-1])
#                 o_idx = int(str(o).split('/')[-1])
                
#                 # Verify indices are within bounds
#                 if (s_idx < node_counts[s_type] and 
#                     o_idx < node_counts[o_type]):
#                     edges_by_type[edge_key].append((s_idx, o_idx))
#             except (ValueError, IndexError):
#                 continue
    
#     # Add self-loops for all node types
#     for node_type in node_id_maps.keys():
#         edge_key = (node_type, 'self', node_type)
#         self_loops = [(i, i) for i in range(node_counts[node_type])]
#         edges_by_type[edge_key] = self_loops
    
#     # Add edges to HeteroData
#     print("\nAdding edges to graph...")
#     for (s_type, edge_type, o_type), edges in edges_by_type.items():
#         if len(edges) > 0:
#             edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
#             data[s_type, edge_type, o_type].edge_index = edge_index
#             print(f"Added {len(edges)} edges of type ({s_type}, {edge_type}, {o_type})")
    
#     return data

# # Create the graph data
# data = create_graph_data(studies_df, outcomes_df, interventions_df, facilities_df, rdf_graph)

### Model Architecture

Implements a HeteroGNN model using GraphSAGE convolutions.

In [53]:
class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels, feature_dims):
        super().__init__()
        
        # Store metadata
        self.node_types = metadata[0]
        self.edge_types = metadata[1]
        
        print("\nInitializing HeteroGNN with:")
        print(f"Node types: {self.node_types}")
        print(f"Edge types: {self.edge_types}")
        print(f"Feature dimensions: {feature_dims}")
        
        # Create convolution layers
        self.convs = torch.nn.ModuleList()
        
        # First convolution layer
        conv_dict = {}
        for edge_type in self.edge_types:
            src_type, _, dst_type = edge_type
            conv_dict[edge_type] = SAGEConv(
                (feature_dims[src_type], feature_dims[dst_type]),
                hidden_channels
            )
        self.convs.append(HeteroConv(conv_dict, aggr='mean'))
        
        # Second convolution layer
        conv_dict = {}
        for edge_type in self.edge_types:
            conv_dict[edge_type] = SAGEConv(
                (hidden_channels, hidden_channels),
                hidden_channels
            )
        self.convs.append(HeteroConv(conv_dict, aggr='mean'))
        
        # Output layer for study nodes
        self.output = torch.nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x_dict, edge_index_dict):
        # First conv layer
        x_dict = self.convs[0](x_dict, edge_index_dict)
        x_dict = {key: torch.relu(x) for key, x in x_dict.items()}
        
        # Second conv layer
        x_dict = self.convs[1](x_dict, edge_index_dict)
        x_dict = {key: torch.relu(x) for key, x in x_dict.items()}
        
        # Output layer for study nodes
        if 'study' not in x_dict:
            raise KeyError(f"'study' node type not found. Available types: {list(x_dict.keys())}")
        
        return self.output(x_dict['study'])

In [55]:
class ImprovedFeatureEncoder(torch.nn.Module):
    def __init__(self, feature_dims: Dict[str, int], hidden_dim: int):
        super().__init__()
        self.feature_encoders = torch.nn.ModuleDict({
            'numerical': torch.nn.Sequential(
                Linear(1, hidden_dim),
                LayerNorm(hidden_dim),
                ReLU()
            ),
            'categorical': torch.nn.Embedding(num_embeddings=1000, embedding_dim=hidden_dim),
            'temporal': LSTM(input_size=1, hidden_size=hidden_dim, batch_first=True)
        })
        
        self.feature_attention = MultiheadAttention(hidden_dim, num_heads=4)
        
    def forward(self, features: Dict[str, torch.Tensor], feature_types: Dict[str, str]) -> torch.Tensor:
        encoded_features = []
        for feat_name, feat in features.items():
            feat_type = feature_types[feat_name]
            if feat_type == 'temporal':
                encoded = self.feature_encoders[feat_type](feat.unsqueeze(-1))[0]
            elif feat_type == 'categorical':
                encoded = self.feature_encoders[feat_type](feat)
            else:
                encoded = self.feature_encoders[feat_type](feat.unsqueeze(-1))
            encoded_features.append(encoded)
            
        features_stack = torch.stack(encoded_features, dim=1)
        attended_features, _ = self.feature_attention(
            features_stack, features_stack, features_stack
        )
        return attended_features

class ImprovedTemporalEncoder(torch.nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.temporal_conv = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.temporal_attention = MultiheadAttention(hidden_dim, num_heads=4)
        self.time_embedding = Linear(1, hidden_dim)
        
    def forward(self, x: torch.Tensor, time_values: torch.Tensor) -> torch.Tensor:
        # Encode absolute time
        time_emb = self.time_embedding(time_values.unsqueeze(-1))
        
        # Apply temporal convolution
        x = x.transpose(1, 2)
        x = self.temporal_conv(x)
        x = x.transpose(1, 2)
        
        # Apply temporal attention
        attended_x, _ = self.temporal_attention(x + time_emb, x + time_emb, x)
        return attended_x

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata: Tuple[List[str], List[Tuple[str, str, str]]], 
                 hidden_channels: int, out_channels: int, feature_dims: Dict[str, int]):
        super().__init__()
        
        # Store metadata
        self.node_types = metadata[0]
        self.edge_types = metadata[1]
        
        # Feature encoding
        self.feature_encoders = torch.nn.ModuleDict({
            node_type: ImprovedFeatureEncoder(feature_dims[node_type], hidden_channels)
            for node_type in self.node_types
        })
        
        # Temporal encoding
        self.temporal_encoders = torch.nn.ModuleDict({
            node_type: ImprovedTemporalEncoder(hidden_channels)
            for node_type in self.node_types
        })
        
        # Graph convolution layers
        self.convs = torch.nn.ModuleList()
        
        # First convolution layer with attention
        conv1_dict = {}
        for edge_type in self.edge_types:
            src_type, _, dst_type = edge_type
            conv1_dict[edge_type] = SAGEConv(
                (feature_dims[src_type], feature_dims[dst_type]),
                hidden_channels
            )
        self.convs.append(HeteroConv(conv1_dict, aggr='mean'))
        
        # Second convolution layer with attention
        conv2_dict = {}
        for edge_type in self.edge_types:
            conv2_dict[edge_type] = SAGEConv(
                (hidden_channels, hidden_channels),
                hidden_channels
            )
        self.convs.append(HeteroConv(conv2_dict, aggr='mean'))
        
        # Layer normalization and dropout
        self.layer_norms = torch.nn.ModuleList([
            LayerNorm(hidden_channels) for _ in range(2)
        ])
        self.dropout = Dropout(p=0.2)
        
        # Output layer for study nodes
        self.output = Linear(hidden_channels, out_channels)
        
    def forward(self, x_dict: Dict[str, torch.Tensor], 
                edge_index_dict: Dict[Tuple[str, str, str], torch.Tensor],
                time_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Feature encoding
        for node_type in x_dict.keys():
            x_dict[node_type] = self.feature_encoders[node_type](
                x_dict[node_type],
                self.get_feature_types(node_type)
            )
            
            # Temporal encoding if time values exist
            if node_type in time_dict:
                x_dict[node_type] = self.temporal_encoders[node_type](
                    x_dict[node_type],
                    time_dict[node_type]
                )
        
        # Graph convolutions with residual connections
        for i, conv in enumerate(self.convs):
            x_dict_new = conv(x_dict, edge_index_dict)
            for node_type in x_dict_new.keys():
                x_dict_new[node_type] = self.layer_norms[i](x_dict_new[node_type])
                x_dict_new[node_type] = F.relu(x_dict_new[node_type])
                x_dict_new[node_type] = self.dropout(x_dict_new[node_type])
                if node_type in x_dict:  # Add residual connection
                    x_dict_new[node_type] += x_dict[node_type]
            x_dict = x_dict_new
        
        # Return predictions for study nodes
        return self.output(x_dict['study'])
    
    def get_feature_types(self, node_type: str) -> Dict[str, str]:
        # Define feature types for each node type
        feature_types = {
            'study': {
                'enrollment': 'numerical',
                'start_date': 'temporal',
                'study_type': 'categorical'
            },
            'outcome': {
                'date': 'temporal',
                'description': 'categorical'
            },
            'intervention': {
                'date': 'temporal',
                'type': 'categorical'
            },
            'facility': {
                'name': 'categorical',
                'city': 'categorical',
                'country': 'categorical'
            }
        }
        return feature_types.get(node_type, {})

In [57]:
def process_study_features(df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
    """Process study features with improved handling of different data types."""
    # Numerical features
    numerical_cols = df.select_dtypes(include=[np.number]).columns
    numerical_features = df[numerical_cols].fillna(0).values
    
    # Categorical features
    categorical_cols = df.select_dtypes(include=['object']).columns
    categorical_features = []
    for col in categorical_cols:
        if col != 'start_date':  # Handle dates separately
            # Convert categories to indices
            categories = pd.Categorical(df[col].fillna('UNKNOWN'))
            categorical_features.append(categories.codes)
    
    # Temporal features
    if 'start_date' in df.columns:
        dates = pd.to_datetime(df['start_date'], errors='coerce')
        # Convert to days since minimum date
        min_date = dates.min()
        temporal_features = (dates - min_date).dt.days.fillna(0).values
    else:
        temporal_features = np.zeros(len(df))
    
    # Combine features for x
    all_features = np.concatenate([
        numerical_features,
        np.stack(categorical_features, axis=1) if categorical_features else np.zeros((len(df), 0))
    ], axis=1)
    
    return torch.tensor(all_features, dtype=torch.float32), torch.tensor(temporal_features, dtype=torch.float32)

def process_outcome_features(df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
    """Process outcome features."""
    features = []
    
    # Process numerical columns
    numerical_cols = df.select_dtypes(include=[np.number]).columns
    if not numerical_cols.empty:
        numerical_features = df[numerical_cols].fillna(0).values
        features.append(numerical_features)
    
    # Process categorical columns (excluding date)
    categorical_cols = [col for col in df.columns if col not in numerical_cols and col != 'date']
    for col in categorical_cols:
        categories = pd.Categorical(df[col].fillna('UNKNOWN'))
        features.append(categories.codes.reshape(-1, 1))
    
    # Process date column
    if 'date' in df.columns:
        dates = pd.to_datetime(df['date'], errors='coerce')
        min_date = dates.min()
        temporal_features = (dates - min_date).dt.days.fillna(0).values
    else:
        temporal_features = np.zeros(len(df))
    
    # Combine features for x
    combined_features = np.concatenate(features, axis=1) if features else np.zeros((len(df), 1))
    return torch.tensor(combined_features, dtype=torch.float32), torch.tensor(temporal_features, dtype=torch.float32)

def process_intervention_features(df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
    """Process intervention features."""
    features = []
    
    # Process numerical columns
    numerical_cols = df.select_dtypes(include=[np.number]).columns
    if not numerical_cols.empty:
        numerical_features = df[numerical_cols].fillna(0).values
        features.append(numerical_features)
    
    # Process categorical columns
    categorical_cols = df.select_dtypes(include=['object']).columns
    for col in categorical_cols:
        if col != 'date':  # Handle dates separately
            categories = pd.Categorical(df[col].fillna('UNKNOWN'))
            features.append(categories.codes.reshape(-1, 1))
    
    # Process date column
    if 'date' in df.columns:
        dates = pd.to_datetime(df['date'], errors='coerce')
        min_date = dates.min()
        temporal_features = (dates - min_date).dt.days.fillna(0).values
    else:
        temporal_features = np.zeros(len(df))
    
    # Combine features for x
    combined_features = np.concatenate(features, axis=1) if features else np.zeros((len(df), 1))
    return torch.tensor(combined_features, dtype=torch.float32), torch.tensor(temporal_features, dtype=torch.float32)

def process_facility_features(df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
    """Process facility features."""
    features = []
    
    # Process numerical columns
    numerical_cols = df.select_dtypes(include=[np.number]).columns
    if not numerical_cols.empty:
        numerical_features = df[numerical_cols].fillna(0).values
        features.append(numerical_features)
    
    # Process categorical columns
    categorical_cols = ['name', 'city', 'country']
    for col in categorical_cols:
        if col in df.columns:
            categories = pd.Categorical(df[col].fillna('UNKNOWN'))
            features.append(categories.codes.reshape(-1, 1))
    
    # No temporal features for facilities
    temporal_features = np.zeros(len(df))
    
    # Combine features for x
    combined_features = np.concatenate(features, axis=1) if features else np.zeros((len(df), 1))
    return torch.tensor(combined_features, dtype=torch.float32), torch.tensor(temporal_features, dtype=torch.float32)

def create_heterogeneous_graph(studies_df: pd.DataFrame, outcomes_df: pd.DataFrame,
                             interventions_df: pd.DataFrame, facilities_df: pd.DataFrame,
                             rdf_graph: Graph) -> HeteroData:
    """Create heterogeneous graph from dataframes and RDF graph."""
    data = HeteroData()
    
    # Create node features with improved processing
    node_features = {}
    time_dict = {}
    
    # Process each node type and store features and temporal values
    for node_type, df, process_fn in [
        ('study', studies_df, process_study_features),
        ('outcome', outcomes_df, process_outcome_features),
        ('intervention', interventions_df, process_intervention_features),
        ('facility', facilities_df, process_facility_features)
    ]:
        features, temporal = process_fn(df)
        node_features[node_type] = features
        time_dict[node_type] = temporal
    
    # Add node features to graph
    for node_type, features in node_features.items():
        data[node_type].x = features
    
    # Add temporal features to graph
    data.time_dict = time_dict
    
    # Add edges from RDF graph
    add_edges_from_rdf(data, rdf_graph, node_features)
    
    return data

In [59]:

def process_study_features(df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
    """Process study features with improved handling of different data types."""
    # Numerical features
    numerical_cols = df.select_dtypes(include=[np.number]).columns
    numerical_features = df[numerical_cols].fillna(0).values
    
    # Categorical features
    categorical_cols = df.select_dtypes(include=['object']).columns
    categorical_features = []
    for col in categorical_cols:
        if col != 'start_date':  # Handle dates separately
            # Convert categories to indices
            categories = pd.Categorical(df[col].fillna('UNKNOWN'))
            categorical_features.append(categories.codes)
    
    # Temporal features
    if 'start_date' in df.columns:
        dates = pd.to_datetime(df['start_date'], errors='coerce')
        # Convert to days since minimum date
        min_date = dates.min()
        temporal_features = (dates - min_date).dt.days.fillna(0).values
    else:
        temporal_features = np.zeros(len(df))
    
    # Combine features for x
    all_features = np.concatenate([
        numerical_features,
        np.stack(categorical_features, axis=1) if categorical_features else np.zeros((len(df), 0))
    ], axis=1)
    
    return torch.tensor(all_features, dtype=torch.float32), torch.tensor(temporal_features, dtype=torch.float32)

def process_outcome_features(df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
    """Process outcome features."""
    features = []
    
    # Process numerical columns
    numerical_cols = df.select_dtypes(include=[np.number]).columns
    if not numerical_cols.empty:
        numerical_features = df[numerical_cols].fillna(0).values
        features.append(numerical_features)
    
    # Process categorical columns (excluding date)
    categorical_cols = [col for col in df.columns if col not in numerical_cols and col != 'date']
    for col in categorical_cols:
        categories = pd.Categorical(df[col].fillna('UNKNOWN'))
        features.append(categories.codes.reshape(-1, 1))
    
    # Process date column
    if 'date' in df.columns:
        dates = pd.to_datetime(df['date'], errors='coerce')
        min_date = dates.min()
        temporal_features = (dates - min_date).dt.days.fillna(0).values
    else:
        temporal_features = np.zeros(len(df))
    
    # Combine features for x
    combined_features = np.concatenate(features, axis=1) if features else np.zeros((len(df), 1))
    return torch.tensor(combined_features, dtype=torch.float32), torch.tensor(temporal_features, dtype=torch.float32)

def process_intervention_features(df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
    """Process intervention features."""
    features = []
    
    # Process numerical columns
    numerical_cols = df.select_dtypes(include=[np.number]).columns
    if not numerical_cols.empty:
        numerical_features = df[numerical_cols].fillna(0).values
        features.append(numerical_features)
    
    # Process categorical columns
    categorical_cols = df.select_dtypes(include=['object']).columns
    for col in categorical_cols:
        if col != 'date':  # Handle dates separately
            categories = pd.Categorical(df[col].fillna('UNKNOWN'))
            features.append(categories.codes.reshape(-1, 1))
    
    # Process date column
    if 'date' in df.columns:
        dates = pd.to_datetime(df['date'], errors='coerce')
        min_date = dates.min()
        temporal_features = (dates - min_date).dt.days.fillna(0).values
    else:
        temporal_features = np.zeros(len(df))
    
    # Combine features for x
    combined_features = np.concatenate(features, axis=1) if features else np.zeros((len(df), 1))
    return torch.tensor(combined_features, dtype=torch.float32), torch.tensor(temporal_features, dtype=torch.float32)

def process_facility_features(df: pd.DataFrame) -> Tuple[torch.Tensor, torch.Tensor]:
    """Process facility features."""
    features = []
    
    # Process numerical columns
    numerical_cols = df.select_dtypes(include=[np.number]).columns
    if not numerical_cols.empty:
        numerical_features = df[numerical_cols].fillna(0).values
        features.append(numerical_features)
    
    # Process categorical columns
    categorical_cols = ['name', 'city', 'country']
    for col in categorical_cols:
        if col in df.columns:
            categories = pd.Categorical(df[col].fillna('UNKNOWN'))
            features.append(categories.codes.reshape(-1, 1))
    
    # No temporal features for facilities
    temporal_features = np.zeros(len(df))
    
    # Combine features for x
    combined_features = np.concatenate(features, axis=1) if features else np.zeros((len(df), 1))
    return torch.tensor(combined_features, dtype=torch.float32), torch.tensor(temporal_features, dtype=torch.float32)

def create_graph_data(studies_df: pd.DataFrame, outcomes_df: pd.DataFrame,
                             interventions_df: pd.DataFrame, facilities_df: pd.DataFrame,
                             rdf_graph: Graph) -> HeteroData:
    """Create heterogeneous graph from dataframes and RDF graph."""
    data = HeteroData()
    
    # Create node features with improved processing
    node_features = {}
    time_dict = {}
    
    # Process each node type and store features and temporal values
    for node_type, df, process_fn in [
        ('study', studies_df, process_study_features),
        ('outcome', outcomes_df, process_outcome_features),
        ('intervention', interventions_df, process_intervention_features),
        ('facility', facilities_df, process_facility_features)
    ]:
        features, temporal = process_fn(df)
        node_features[node_type] = features
        time_dict[node_type] = temporal
    
    # Add node features to graph
    for node_type, features in node_features.items():
        data[node_type].x = features
    
    # Add temporal features to graph
    data.time_dict = time_dict
    
    # Add edges from RDF graph
    add_edges_from_rdf(data, rdf_graph, node_features)
    
    return data

### Training and Evaluation

#### 1. Training Function

Implements the training loop with validation and early stopping.

In [62]:
def train_model(model, train_data, val_data, test_data, num_epochs=100, lr=0.01):
    """Train the model using train/val/test splits."""
    print("\nTraining model...")
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()
    
    train_losses = []
    val_metrics = []
    best_val_auc = 0
    best_model = None
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        optimizer.zero_grad()
        
        out = model(train_data.x_dict, train_data.edge_index_dict)
        loss = criterion(out, train_data['study'].y)
        
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        
        # Validation
        model.eval()
        with torch.no_grad():
            val_out = model(val_data.x_dict, val_data.edge_index_dict)
            val_loss = criterion(val_out, val_data['study'].y)
            
            val_pred = torch.sigmoid(val_out).cpu().numpy()
            val_true = val_data['study'].y.cpu().numpy()
            
            val_auc = roc_auc_score(val_true, val_pred)
            val_ap = average_precision_score(val_true, val_pred)
            
            val_metrics.append({
                'loss': val_loss.item(),
                'auc': val_auc,
                'ap': val_ap
            })
            
            # Save best model
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                best_model = model.state_dict()
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1:03d}:')
            print(f'Train Loss: {loss.item():.4f}')
            print(f'Val Loss: {val_loss.item():.4f}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}')
    
    # Load best model and evaluate on test set
    model.load_state_dict(best_model)
    model.eval()
    with torch.no_grad():
        test_out = model(test_data.x_dict, test_data.edge_index_dict)
        test_loss = criterion(test_out, test_data['study'].y)
        
        test_pred = torch.sigmoid(test_out).cpu().numpy()
        test_true = test_data['study'].y.cpu().numpy()
        
        test_auc = roc_auc_score(test_true, test_pred)
        test_ap = average_precision_score(test_true, test_pred)
    
    print('\nTest Results:')
    print(f'Test Loss: {test_loss.item():.4f}')
    print(f'Test AUC: {test_auc:.4f}')
    print(f'Test AP: {test_ap:.4f}')
    
    return train_losses, val_metrics, test_auc, test_ap

#### 2. Visualization Function

Plots training and validation metrics.

In [65]:
import numpy as np
import torch
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    confusion_matrix,
    mean_absolute_error,
    mean_squared_error,
    r2_score,
)
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, precision_recall_curve, auc

# Function to compute classification metrics
def compute_classification_metrics(y_true, y_pred, y_pred_proba=None):
    """
    Compute classification metrics for binary/multi-class classification tasks.
    """
    metrics = {}

    # Accuracy
    metrics["accuracy"] = accuracy_score(y_true, y_pred)

    # Precision, Recall, F1-Score
    metrics["precision"] = precision_score(y_true, y_pred, average="weighted")
    metrics["recall"] = recall_score(y_true, y_pred, average="weighted")
    metrics["f1_score"] = f1_score(y_true, y_pred, average="weighted")

    # ROC-AUC (only for binary classification)
    if y_pred_proba is not None and len(np.unique(y_true)) == 2:
        fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
        metrics["roc_auc"] = auc(fpr, tpr)

    # Confusion Matrix
    metrics["confusion_matrix"] = confusion_matrix(y_true, y_pred)

    return metrics

# Function to plot confusion matrix
def plot_confusion_matrix(y_true, y_pred, title="Confusion Matrix"):
    """
    Plot a confusion matrix.
    """
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(title)
    plt.show()

# Function to plot ROC curve
def plot_roc_curve(y_true, y_pred_proba, title="ROC Curve"):
    """
    Plot the ROC curve for binary classification.
    """
    fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.legend(loc="lower right")
    plt.show()

# Function to plot Precision-Recall curve
def plot_precision_recall_curve(y_true, y_pred_proba, title="Precision-Recall Curve"):
    """
    Plot the Precision-Recall curve for binary classification.
    """
    precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
    pr_auc = auc(recall, precision)

    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, color="blue", lw=2, label=f"PR curve (AUC = {pr_auc:.2f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(title)
    plt.legend(loc="upper right")
    plt.show()

# Function to plot metrics
def plot_metrics(train_losses, val_metrics, test_data, model):
    """Plot training metrics."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Plot losses
    axes[0].plot(train_losses, label='Train Loss')
    axes[0].plot([m['loss'] for m in val_metrics], label='Val Loss')
    axes[0].set_title('Loss')
    axes[0].legend()

    # Plot AUC
    axes[1].plot([m['auc'] for m in val_metrics])
    axes[1].set_title('Validation AUC')

    # Plot AP
    axes[2].plot([m['ap'] for m in val_metrics])
    axes[2].set_title('Validation AP')

    plt.tight_layout()
    plt.show()

    # Plot ROC curve, Confusion Matrix, etc.
    model.eval()
    with torch.no_grad():
        test_out = model(test_data.x_dict, test_data.edge_index_dict)
        test_pred = torch.sigmoid(test_out).cpu().numpy()
        test_true = test_data['study'].y.cpu().numpy()

        # Compute classification metrics
        classification_metrics = compute_classification_metrics(test_true, (test_pred > 0.5).astype(int), test_pred)
        print("\nClassification Metrics:")
        for metric, value in classification_metrics.items():
            if metric != "confusion_matrix":
                print(f"{metric}: {value:.4f}")

        # Plot confusion matrix
        plot_confusion_matrix(test_true, (test_pred > 0.5).astype(int), title="Confusion Matrix (Test Set)")

        # Plot ROC curve
        plot_roc_curve(test_true, test_pred, title="ROC Curve (Test Set)")

        # Plot Precision-Recall curve
        plot_precision_recall_curve(test_true, test_pred, title="Precision-Recall Curve (Test Set)")

### Main Execution

Run the complete training pipeline.

In [70]:
def load_and_process_data(data_dir: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load and preprocess CSV data."""
    # Load CSV files
    studies_df = pd.read_csv(os.path.join(data_dir, 'studies.csv'))
    outcomes_df = pd.read_csv(os.path.join(data_dir, 'outcomes.csv'))
    interventions_df = pd.read_csv(os.path.join(data_dir, 'interventions.csv'))
    facilities_df = pd.read_csv(os.path.join(data_dir, 'facilities.csv'))
    
    # Preprocess dates
    for df in [studies_df, outcomes_df, interventions_df]:
        date_columns = df.select_dtypes(include=['object']).columns
        for col in date_columns:
            if 'date' in col.lower():
                df[col] = pd.to_datetime(df[col], errors='coerce')
    
    return studies_df, outcomes_df, interventions_df, facilities_df

In [68]:
def main():
    # Get feature dimensions for each node type
    feature_dims = {
        'study': data['study'].x.shape[1],
        'outcome': data['outcome'].x.shape[1],
        'intervention': data['intervention'].x.shape[1],
        'facility': data['facility'].x.shape[1]
    }
    
    print("\nFeature dimensions:")
    for node_type, dim in feature_dims.items():
        print(f"{node_type}: {dim}")
    
    # Split data into train/val/test
    num_studies = len(studies_df)
    train_idx, temp_idx = train_test_split(range(num_studies), test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)
    
    # Create train/val/test masks
    train_mask = torch.zeros(num_studies, dtype=torch.bool)
    val_mask = torch.zeros(num_studies, dtype=torch.bool)
    test_mask = torch.zeros(num_studies, dtype=torch.bool)
    
    train_mask[train_idx] = True
    val_mask[val_idx] = True
    test_mask[test_idx] = True
    
    # Create train/val/test data
    train_data = data.clone()
    val_data = data.clone()
    test_data = data.clone()
    
    # Add masks to data
    train_data['study'].train_mask = train_mask
    val_data['study'].val_mask = val_mask
    test_data['study'].test_mask = test_mask
    
    # Move data to device
    train_data = train_data.to(device)
    val_data = val_data.to(device)
    test_data = test_data.to(device)
    
    # Initialize model
    model = HeteroGNN(
        metadata=data.metadata(),
        hidden_channels=64,
        out_channels=1,
        feature_dims=feature_dims
    ).to(device)
    
    # Train model
    train_losses, val_metrics, test_auc, test_ap = train_model(
        model=model,
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        num_epochs=100,
        lr=0.01
    )
    
    # Plot results
    plot_metrics(train_losses, val_metrics, test_data, model)
    
    return test_auc, test_ap
        

# Run the main function
main()


Feature dimensions:
study: 5
outcome: 2
intervention: 1
facility: 1

Training model...


TypeError: forward() missing 1 required positional argument: 'time_dict'

## **Relevant Resources:**

1. **RelBench Dataset and Framework:**  
   - Website: https://relbench.stanford.edu/  
   - Documentation: https://relbench.stanford.edu/start/  
   - GitHub Repository: https://github.com/snap-stanford/relbench  

2. **R2RML (RDB to RDF Mapping Language):**  
   - W3C Specification: https://www.w3.org/TR/r2rml/  
   - Tools:  
     - **RMLMapper:** https://github.com/RMLio/rmlmapper-java  
     - **Ontop:** https://ontop-vkg.org/  
     - **Apache Jena:** https://jena.apache.org/  

3. **Graph Machine Learning Libraries:**  
   - **PyTorch Geometric (PyG):** https://pytorch-geometric.readthedocs.io/  
   - **DGL (Deep Graph Library):** https://www.dgl.ai/  
   - **Graph Neural Networks (GNNs):** https://distill.pub/2021/gnn-intro/  

4. **RDF to Graph Conversion Tools:**  
   - **RDFLib:** https://rdflib.readthedocs.io/  
   - **Apache Jena:** https://jena.apache.org/  

5. **Evaluation Metrics for Machine Learning:**  
   - **ROC-AUC:** https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html  
   - **Accuracy:** https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html  
   - **Precision, Recall, F1-Score:** https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html  

---

In [80]:
import os
import torch
import numpy as np
import pandas as pd
from rdflib import Graph, URIRef, Literal, XSD
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv, Linear
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score
import matplotlib.pyplot as plt
from relbench.datasets  import get_dataset
from relbench.tasks import get_task

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the dataset and task
dataset = get_dataset("rel-trial", download=True)
task = get_task("rel-trial", "study-outcome", download=True)
train_table = task.get_table("train")
df = train_table.df

# Load CSV data
def load_csv_data():
    """Load CSV data files."""
    print("\nLoading CSV data...")
    
    studies_df = pd.read_csv('data/studies.csv')
    outcomes_df = pd.read_csv('data/outcomes.csv')
    interventions_df = pd.read_csv('data/interventions.csv')
    facilities_df = pd.read_csv('data/facilities.csv')
    
    print(f"Loaded {len(studies_df)} studies")
    print(f"Loaded {len(outcomes_df)} outcomes")
    print(f"Loaded {len(interventions_df)} interventions")
    print(f"Loaded {len(facilities_df)} facilities")
    
    return studies_df, outcomes_df, interventions_df, facilities_df

# Load RDF mappings
def load_rdf_mappings(output_folder):
    """Load RDF mappings from the output folder."""
    print("\nLoading RDF mappings...")
    rdf_graph = Graph()
    
    rdf_files = [
        'studies-rdf.ttl',
        'interventions-rdf.ttl',
        'facilities-rdf.ttl',
        'outcomes-rdf.ttl',
        'reported_event_totals-rdf.ttl',
        'drop_withdrawals-rdf.ttl',
        'sponsors_studies-rdf.ttl',
        'conditions_studies-rdf.ttl'
    ]
    
    for filename in rdf_files:
        filepath = os.path.join(output_folder, filename)
        if os.path.exists(filepath):
            print(f"Loading {filename}...")
            rdf_graph.parse(filepath, format="turtle")
            print(f"Loaded {len(rdf_graph)} total triples")
        else:
            print(f"Warning: {filename} not found")
    
    return rdf_graph

# Create heterogeneous graph data
def create_graph_data(studies_df, outcomes_df, interventions_df, facilities_df, rdf_graph):
    """Create heterogeneous graph data from CSV and RDF data."""
    print("\nCreating heterogeneous graph data...")
    data = HeteroData()
    
    # Create node features
    study_features = torch.tensor(studies_df.select_dtypes(include=[np.number]).fillna(0).values, dtype=torch.float32)
    outcome_features = torch.tensor(outcomes_df.select_dtypes(include=[np.number]).fillna(0).values, dtype=torch.float32)
    intervention_features = torch.tensor(interventions_df.select_dtypes(include=[np.number]).fillna(0).values, dtype=torch.float32)
    facility_features = torch.tensor(facilities_df.select_dtypes(include=[np.number]).fillna(0).values, dtype=torch.float32)
    
    # Add node features to HeteroData
    data['study'].x = study_features
    data['outcome'].x = outcome_features
    data['intervention'].x = intervention_features
    data['facility'].x = facility_features
    
    # Create node ID mappings
    node_id_maps = {
        'study': {f'http://example.org/study/{i}': i for i in range(len(studies_df))},
        'outcome': {f'http://example.org/outcome/{i}': i for i in range(len(outcomes_df))},
        'intervention': {f'http://example.org/intervention/{i}': i for i in range(len(interventions_df))},
        'facility': {f'http://example.org/facility/{i}': i for i in range(len(facilities_df))}
    }
    
    # Extract edges from RDF graph
    edges_by_type = {}
    for s, p, o in rdf_graph:
        if isinstance(s, URIRef) and isinstance(o, URIRef):
            s_type = next((node_type for node_type, node_map in node_id_maps.items() if str(s) in node_map), None)
            o_type = next((node_type for node_type, node_map in node_id_maps.items() if str(o) in node_map), None)
            
            if s_type is not None and o_type is not None:
                edge_type = str(p).split('/')[-1].split('#')[-1]
                edge_key = (s_type, edge_type, o_type)
                if edge_key not in edges_by_type:
                    edges_by_type[edge_key] = []
                s_idx = node_id_maps[s_type][str(s)]
                o_idx = node_id_maps[o_type][str(o)]
                edges_by_type[edge_key].append((s_idx, o_idx))
    
    # Add edges to HeteroData
    for (s_type, edge_type, o_type), edges in edges_by_type.items():
        if len(edges) > 0:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
            data[s_type, edge_type, o_type].edge_index = edge_index
            print(f"Added {len(edges)} edges of type ({s_type}, {edge_type}, {o_type})")
    
    return data

# Enhanced HeteroGNN model
class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels, feature_dims):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleDict()
        
        for node_type, dim in feature_dims.items():
            self.lins[node_type] = Linear(dim, hidden_channels)
        
        for _ in range(2):
            conv_dict = {}
            for edge_type in metadata[1]:
                src_type, _, dst_type = edge_type
                conv_dict[edge_type] = SAGEConv(
                    (hidden_channels, hidden_channels),
                    hidden_channels
                )
            self.convs.append(HeteroConv(conv_dict, aggr='mean'))
        
        self.output = Linear(hidden_channels, out_channels)
    
    def forward(self, x_dict, edge_index_dict):
        x_dict = {node_type: self.lins[node_type](x) for node_type, x in x_dict.items()}
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: torch.relu(x) for key, x in x_dict.items()}
        return self.output(x_dict['study'])

# Training and evaluation
def train_model(model, train_data, val_data, test_data, num_epochs=100, lr=0.01):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()
    
    best_val_auc = 0
    best_model = None
    
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        out = model(train_data.x_dict, train_data.edge_index_dict)
        loss = criterion(out, train_data['study'].y)
        loss.backward()
        optimizer.step()
        
        model.eval()
        with torch.no_grad():
            val_out = model(val_data.x_dict, val_data.edge_index_dict)
            val_pred = torch.sigmoid(val_out).cpu().numpy()
            val_true = val_data['study'].y.cpu().numpy()
            val_auc = roc_auc_score(val_true, val_pred)
            
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                best_model = model.state_dict()
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1:03d}: Train Loss: {loss.item():.4f}, Val AUC: {val_auc:.4f}')
    
    model.load_state_dict(best_model)
    model.eval()
    with torch.no_grad():
        test_out = model(test_data.x_dict, test_data.edge_index_dict)
        test_pred = torch.sigmoid(test_out).cpu().numpy()
        test_true = test_data['study'].y.cpu().numpy()
        test_auc = roc_auc_score(test_true, test_pred)
        test_ap = average_precision_score(test_true, test_pred)
    
    print(f'Test AUC: {test_auc:.4f}, Test AP: {test_ap:.4f}')
    return test_auc, test_ap

# Main execution
def main():
    studies_df, outcomes_df, interventions_df, facilities_df = load_csv_data()
    rdf_graph = load_rdf_mappings('output')
    data = create_graph_data(studies_df, outcomes_df, interventions_df, facilities_df, rdf_graph)
    
    # Add labels from the task
    data['study'].y = torch.tensor(df['outcome'].values, dtype=torch.float32)
    
    # Split data into train/val/test
    num_studies = len(studies_df)
    train_idx, temp_idx = train_test_split(range(num_studies), test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)
    
    train_mask = torch.zeros(num_studies, dtype=torch.bool)
    val_mask = torch.zeros(num_studies, dtype=torch.bool)
    test_mask = torch.zeros(num_studies, dtype=torch.bool)
    
    train_mask[train_idx] = True
    val_mask[val_idx] = True
    test_mask[test_idx] = True
    
    train_data = data.clone()
    val_data = data.clone()
    test_data = data.clone()
    
    train_data['study'].train_mask = train_mask
    val_data['study'].val_mask = val_mask
    test_data['study'].test_mask = test_mask
    
    train_data = train_data.to(device)
    val_data = val_data.to(device)
    test_data = test_data.to(device)
    
    # Initialize model
    model = HeteroGNN(
        metadata=data.metadata(),
        hidden_channels=64,
        out_channels=1,
        feature_dims={
            'study': data['study'].x.shape[1],
            'outcome': data['outcome'].x.shape[1],
            'intervention': data['intervention'].x.shape[1],
            'facility': data['facility'].x.shape[1]
        }
    ).to(device)
    
    # Train model
    test_auc, test_ap = train_model(
        model=model,
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        num_epochs=100,
        lr=0.01
    )
    
    return test_auc, test_ap

# Run the main function
main()

Using device: cpu

Loading CSV data...


  studies_df = pd.read_csv('data/studies.csv')


Loaded 249730 studies
Loaded 411933 outcomes
Loaded 3462 interventions
Loaded 453233 facilities

Loading RDF mappings...

Creating heterogeneous graph data...


KeyError: "Tried to collect 'edge_index' but did not find any occurrences of it in any node and/or edge type"