
Construct a metabolic hypergraph where:
- Nodes represent drugs (using the SAME indices from Stage 1's chemical network)
- Hyperedges represent metabolic interaction types

Using the same drug indices ensures that all chemical features learned in the
chemical network can be directly reused when training the metabolic model.

Args:
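    csv_path: Path to a CSV containing columns [Drug_ID, Interaction_Types], where each Interaction_Types cell is a semicolon-separated list of integer type IDs (e.g. "1;2;3")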
    drug_to_idx: Mapping from Drug_ID to the Stage-1 drug index (must be reused)
    save_dir: Directory for saving processed outputs

Returns:
    edge_list_tensor:
        A tensor of shape [num_edges, 2] containing pairs:
        [drug_idx, interaction_type_idx]

    interaction_type_to_idx:
        Mapping from each interaction type to its assigned index.

    metadata:
        Dictionary summarizing hypergraph statistics (node count, edge count, etc.)


In [None]:
!pip install torch pandas numpy
import torch
import pandas as pd
import os
from typing import Dict, List, Tuple




In [None]:
# Mount Google Drive at /content/drive so that the /content/drive/MyDrive/...
# paths used by the cells below resolve (Colab-specific setup step).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:

def parse_interaction_types(interaction_string: str) -> List[int]:
    """Convert a semicolon-separated string of type IDs into a list of ints.

    Whitespace around each entry is ignored and empty entries (for example
    from a trailing ``;``) are dropped. Order and repeated values are kept.

    Args:
        interaction_string: Text such as ``"1;2;3"``.

    Returns:
        The integer IDs in the order they appear in the string.

    Raises:
        ValueError: If a non-empty entry is not a valid integer literal.
    """
    parsed: List[int] = []
    for raw_token in interaction_string.split(';'):
        token = raw_token.strip()
        if token:
            parsed.append(int(token))
    return parsed


In [None]:

def build_metabolic_hypergraph(
    csv_path: str,
    drug_to_idx: Dict[str, int],  # REUSE from Stage 1 - DO NOT REGENERATE
    save_dir: str = '/content/drive/MyDrive/MLHygnn/DB/hypergraphs/'
) -> Tuple[torch.Tensor, Dict[int, int], Dict]:
    """Build, save, and return the metabolic hypergraph as a drug/type edge list.

    Nodes are drugs, indexed with the caller-supplied ``drug_to_idx`` mapping
    (the Stage-1 indices are reused, never regenerated). Hyperedges are
    interaction types; each distinct integer type ID found in the CSV gets a
    0-based index assigned in ascending order of the ID.

    Rows whose ``Drug_ID`` is not a key of ``drug_to_idx`` are skipped with a
    warning. Rows with a missing interaction cell contribute no edges. If a
    ``Drug_ID`` appears on several rows, the interaction types of all of its
    rows are kept.

    Args:
        csv_path: Path to a CSV with a ``Drug_ID`` column and an interaction
            types column (``Interaction_Types``, ``interaction_types``,
            ``Interaction Types`` or ``types``) holding semicolon-separated
            integer type IDs.
        drug_to_idx: Mapping from Drug_ID to the Stage-1 drug index.
        save_dir: Directory that receives ``metabolic_hypergraph.pt`` and
            ``metabolic_hypergraph_metadata.pt``; created if missing.

    Returns:
        A tuple ``(edge_list_tensor, interaction_type_to_idx, metadata)``:
        ``edge_list_tensor`` is a ``torch.long`` tensor of shape
        ``[num_edges, 2]`` with rows ``[drug_idx, interaction_type_idx]``;
        ``interaction_type_to_idx`` maps each interaction type ID to its
        index; ``metadata`` is the statistics dictionary that is also saved.

    Raises:
        ValueError: If no interaction types column is found, or if an
            interaction entry is not an integer.
        KeyError: If the CSV has no ``Drug_ID`` column.
    """
    print(f"\n{'='*60}")
    print("Building Metabolic Hypergraph")
    print(f"{'='*60}")

    os.makedirs(save_dir, exist_ok=True)

    # Load data
    print(f"Loading data from: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f"Loaded {len(df)} drugs")

    # Find the interaction types column among the accepted spellings
    candidate_columns = ('Interaction_Types', 'interaction_types', 'Interaction Types', 'types')
    interaction_col = next((col for col in candidate_columns if col in df.columns), None)

    if interaction_col is None:
        raise ValueError(f"No interaction types column found. Available columns: {df.columns.tolist()}")

    print(f"Using columns: Drug_ID, {interaction_col}")
    print(f"Reusing {len(drug_to_idx)} drugs from Stage 1 drug_to_idx mapping")

    all_interaction_types = set()  # Vocabulary: unique type IDs only
    drug_interactions = {}         # Drug_ID -> every parsed type ID (duplicates kept)
    skipped_drugs = []             # Drug_IDs absent from the Stage-1 mapping
    missing_interactions = []      # Drug_IDs whose interaction cell is empty

    print("Parsing interaction types...")
    total_rows = len(df)

    for position, (drug_id, raw_value) in enumerate(zip(df['Drug_ID'], df[interaction_col])):
        if position % 100 == 0:
            print(f"  Processing drug {position+1}/{total_rows}...", end='\r')

        # Only drugs already indexed in Stage 1 may become nodes
        if drug_id not in drug_to_idx:
            skipped_drugs.append(drug_id)
            continue

        if pd.isna(raw_value):
            # An empty CSV cell is read as NaN; str(NaN) is "nan", which
            # would crash int(). Treat it as "no interaction types".
            missing_interactions.append(drug_id)
            interaction_types = []
        else:
            interaction_types = parse_interaction_types(str(raw_value))

        # Extend rather than assign so a Drug_ID that appears on several rows
        # keeps the interaction types from all of them.
        drug_interactions.setdefault(drug_id, []).extend(interaction_types)

        # Add to vocabulary (unique only)
        all_interaction_types.update(interaction_types)

    print(f"\n\nParsing complete!")

    if skipped_drugs:
        print(f"WARNING: Skipped {len(skipped_drugs)} drugs not in Stage 1 mapping")
        print(f"  Examples: {skipped_drugs[:5]}")

    if missing_interactions:
        print(f"WARNING: {len(missing_interactions)} drugs have no interaction types listed")
        print(f"  Examples: {missing_interactions[:5]}")

    # Assign each distinct type ID a 0-based index in ascending ID order
    # (e.g. type IDs 1, 2, 4 -> indices 0, 1, 2).
    interaction_type_to_idx = {itype: idx for idx, itype in enumerate(sorted(all_interaction_types))}

    print(f"\n  Drugs processed: {len(drug_interactions)}/{len(drug_to_idx)}")
    print(f"  Unique interaction types: {len(interaction_type_to_idx)}")
    if all_interaction_types:
        # min()/max() raise on an empty set, so only report a range when one exists
        print(f"  Interaction type range: {min(all_interaction_types)} - {max(all_interaction_types)}")

    # Build edge list (with duplicates preserved)
    print("\nBuilding hypergraph edge list...")
    edge_list = []

    for drug_id, interaction_types in drug_interactions.items():
        drug_idx = drug_to_idx[drug_id]  # Use Stage 1 index

        # One [drug_idx, type_idx] row per listed type, e.g. the drug at
        # index 5 with type IDs mapped to indices 0 and 1 yields [5,0] [5,1].
        for interaction_type in interaction_types:
            edge_list.append([drug_idx, interaction_type_to_idx[interaction_type]])

    # The explicit reshape keeps the documented [num_edges, 2] shape even
    # when there are no edges (torch.tensor([]) alone would be 1-D).
    edge_list_tensor = torch.tensor(edge_list, dtype=torch.long).reshape(len(edge_list), 2)

    num_unique_pairs = len(set(map(tuple, edge_list)))

    print(f"Total connections created: {len(edge_list):,}")
    print(f"Unique (drug, type) pairs: {num_unique_pairs:,}")

    if len(edge_list) > num_unique_pairs:
        duplication_factor = len(edge_list) / num_unique_pairs
        print(f"Duplication factor: {duplication_factor:.2f}x")
        print("  (This means some drug-interaction pairs appear multiple times - expected!)")

    # Statistics
    unique_drugs = len(torch.unique(edge_list_tensor[:, 0]))
    unique_types = len(torch.unique(edge_list_tensor[:, 1]))

    print(f"\nHypergraph Statistics:")
    print(f"  Drugs with interactions: {unique_drugs}/{len(drug_to_idx)}")
    print(f"  Interaction types present: {unique_types}/{len(interaction_type_to_idx)}")
    if edge_list:
        # Averages are undefined (division by zero) for an empty hypergraph
        print(f"  Avg interaction types per drug: {len(edge_list) / unique_drugs:.2f}")
        print(f"  Avg drugs per interaction type: {len(edge_list) / unique_types:.2f}")

    # Show example connections, translating indices back to original IDs.
    # setdefault keeps the first Drug_ID should two IDs share an index.
    idx_to_drug = {}
    for known_drug, known_idx in drug_to_idx.items():
        idx_to_drug.setdefault(known_idx, known_drug)
    idx_to_interaction_type = {idx: itype for itype, idx in interaction_type_to_idx.items()}

    print(f"\nFirst 10 connections:")
    for drug_idx, itype_idx in edge_list[:10]:
        original_drug = idx_to_drug[drug_idx]
        original_type = idx_to_interaction_type[itype_idx]
        print(f"  [{original_drug} (idx={drug_idx}), Type {original_type} (idx={itype_idx})]")

    # Save files
    output_file = os.path.join(save_dir, 'metabolic_hypergraph.pt')
    torch.save(edge_list_tensor, output_file)
    print(f"\nHypergraph saved to: {output_file}")

    metadata = {
        'num_drugs': len(drug_to_idx),
        'num_drugs_with_interactions': unique_drugs,
        'num_interaction_types': len(interaction_type_to_idx),
        'num_connections': len(edge_list),
        'num_unique_pairs': num_unique_pairs,
        'drug_to_idx': drug_to_idx,  # Same as Stage 1
        'interaction_type_to_idx': interaction_type_to_idx
    }

    metadata_file = os.path.join(save_dir, 'metabolic_hypergraph_metadata.pt')
    torch.save(metadata, metadata_file)
    print(f"Metadata saved to: {metadata_file}")

    return edge_list_tensor, interaction_type_to_idx, metadata



In [None]:

def verify_metabolic_hypergraph(file_path: str):
    """Load a saved metabolic hypergraph, print a summary, and return it.

    The companion metadata file is looked up next to ``file_path`` by
    inserting ``_metadata`` before the file extension
    (``graph.pt`` -> ``graph_metadata.pt``). Its summary is printed only if
    that file exists.

    Args:
        file_path: Path to the edge-list tensor written with ``torch.save``.
            Rows are read as ``[drug_idx, interaction_type_idx]``.

    Returns:
        The loaded tensor, or ``None`` if ``file_path`` does not exist.
    """
    print(f"\n{'='*60}")
    print("Verifying Metabolic Hypergraph")
    print(f"{'='*60}")

    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None

    data = torch.load(file_path)

    print(f"Shape: {data.shape}")
    print(f"\nFirst 20 connections:")
    print(data[:20])

    print(f"\nStructure:")
    print(f"  Column 0: Drug indices (nodes)")
    print(f"  Column 1: Interaction type indices (hyperedges)")

    print(f"\nUnique drugs (nodes): {len(torch.unique(data[:, 0]))}")
    print(f"Unique interaction types (hyperedges): {len(torch.unique(data[:, 1]))}")
    print(f"Total connections: {data.shape[0]:,}")

    # Duplicate analysis
    unique_pairs = set(map(tuple, data.tolist()))
    print(f"\nUnique (drug, type) pairs: {len(unique_pairs):,}")
    if data.shape[0] > len(unique_pairs):
        print(f"Duplication factor: {data.shape[0] / len(unique_pairs):.2f}x")

    # Load metadata. Split off only the final extension: str.replace('.pt', ...)
    # would also rewrite any '.pt' inside a directory name and miss the file.
    path_stem, path_ext = os.path.splitext(file_path)
    metadata_path = f"{path_stem}_metadata{path_ext}"
    if os.path.exists(metadata_path):
        metadata = torch.load(metadata_path)
        total_drugs = metadata['num_drugs']
        drugs_with_interactions = metadata['num_drugs_with_interactions']

        print(f"\nMetadata:")
        print(f"  Total drugs in vocabulary: {total_drugs}")
        print(f"  Drugs with interactions: {drugs_with_interactions}")
        if total_drugs:
            print(f"  Coverage: {drugs_with_interactions/total_drugs*100:.1f}%")
        else:
            # Avoid ZeroDivisionError when the vocabulary is empty
            print("  Coverage: n/a (no drugs in vocabulary)")

        print(f"\n  Interaction types vocabulary size: {metadata['num_interaction_types']}")
        print(f"  Example interaction type mapping:")
        for itype, idx in list(metadata['interaction_type_to_idx'].items())[:5]:
            print(f"    Interaction Type {itype} → Index {idx}")

    return data

In [None]:


if __name__ == "__main__":
    # Driver: fetch the Stage-1 drug index, build the metabolic hypergraph on
    # top of it, then read the saved tensor back as a sanity check.
    _rule = "=" * 60

    # Step 1: the Stage-1 metadata is loaded only for its drug_to_idx mapping,
    # so any Stage-1 metadata file can serve as the source of the IDs.
    print("Step 1: Loading Stage 1 chemical hypergraph metadata...")
    stage1_metadata_path = '/content/drive/MyDrive/MLHygnn/DB/hypergraphs/hyG_drug_drugbank_kmer_12_metadata.pt'
    stage1_metadata = torch.load(stage1_metadata_path)
    drug_to_idx = stage1_metadata['drug_to_idx']

    _first_mappings = list(drug_to_idx.items())[:3]
    print(f"✓ Loaded {len(drug_to_idx)} drugs from Stage 1")
    print(f"  Example: {_first_mappings}")

    # Step 2: build the hypergraph, passing the Stage-1 mapping through as-is.
    print("\n" + _rule, "Step 2: Building Metabolic Hypergraph", _rule, sep="\n")
    edge_list, interaction_type_to_idx, metadata = build_metabolic_hypergraph(
        csv_path='/content/drive/MyDrive/MLHygnn/DB/OutPutPreprosseing/type-Interaction/drug_interaction_types.csv',
        drug_to_idx=drug_to_idx,
        save_dir='/content/drive/MyDrive/MLHygnn/DB/hypergraphs/',
    )

    # Step 3: reload the file written in Step 2 and print its summary.
    print("\n" + _rule, "Step 3: Verification", _rule, sep="\n")
    verify_metabolic_hypergraph(
        '/content/drive/MyDrive/MLHygnn/DB/hypergraphs/metabolic_hypergraph.pt'
    )



# Step 4: confirm that the Stage-2 metadata stores the same drug index as Stage 1
print("\n" + "=" * 60, "Verifying Drug Index Consistency", "=" * 60, sep="\n")

stage1_meta = torch.load('/content/drive/MyDrive/MLHygnn/DB/hypergraphs/hyG_drug_drugbank_kmer_12_metadata.pt')
stage2_meta = torch.load('/content/drive/MyDrive/MLHygnn/DB/hypergraphs/metabolic_hypergraph_metadata.pt')

stage1_index = stage1_meta['drug_to_idx']
stage2_index = stage2_meta['drug_to_idx']

# The two mappings must be equal as dictionaries (same IDs, same indices)
assert stage1_index == stage2_index, "Drug indices don't match!"
print("✓ Drug indices are consistent between Stage 1 and Stage 2")
print(f"  Both use {len(stage1_index)} drugs with identical indices")

# Optional: print the first few IDs with the index each stage assigns them
print("\nExample drug mappings (should be identical):")
for drug_id in list(stage1_index)[:5]:
    idx1 = stage1_index[drug_id]
    idx2 = stage2_index[drug_id]
    if idx1 == idx2:
        status_mark = '✓'
    else:
        status_mark = '✗'
    print(f"  {drug_id}: Stage1={idx1}, Stage2={idx2} {status_mark}")

Step 1: Loading Stage 1 chemical hypergraph metadata...
✓ Loaded 1709 drugs from Stage 1
  Example: [('DB00006', 0), ('DB00014', 1), ('DB00027', 2)]

Step 2: Building Metabolic Hypergraph

Building Metabolic Hypergraph
Loading data from: /content/drive/MyDrive/MLHygnn/DB/OutPutPreprosseing/type-Interaction/drug_interaction_types.csv
Loaded 1709 drugs
Using columns: Drug_ID, Interaction_Types
Reusing 1709 drugs from Stage 1 drug_to_idx mapping
Parsing interaction types...
  Processing drug 1701/1709...

Parsing complete!

  Drugs processed: 1709/1709
  Unique interaction types: 86
  Interaction type range: 1 - 86

Building hypergraph edge list...
Total connections created: 13,486
Unique (drug, type) pairs: 13,486

Hypergraph Statistics:
  Drugs with interactions: 1709/1709
  Interaction types present: 86/86
  Avg interaction types per drug: 7.89
  Avg drugs per interaction type: 156.81

First 10 connections:
  [DB00006 (idx=0), Type 12 (idx=11)]
  [DB00006 (idx=0), Type 4 (idx=3)]
  [DB