In [None]:

from google.colab import drive
import os
import requests
import gzip
!curl -L "https://www.ebi.ac.uk/gwas/api/search/downloads/alternative/" -o "/content/drive/MyDrive/biological_data/gwas_catalog.tsv"

drive.mount('/content/drive')
DATA_DIR = "/content/drive/MyDrive/biological_data"
os.makedirs(DATA_DIR, exist_ok=True)

# Download helper
def download_file(url, filename):
    """Download ``url`` to ``DATA_DIR/filename`` unless that file already exists.

    Args:
        url: HTTP(S) address of the resource.
        filename: Name of the target file inside ``DATA_DIR``.

    Returns:
        The full path of the (existing or freshly downloaded) file.

    Raises:
        requests.HTTPError: If the server answers with an error status. The
            error body is not saved as if it were the data file.
        requests.RequestException: On connection problems or timeouts.
    """
    full_path = os.path.join(DATA_DIR, filename)
    if os.path.exists(full_path):
        print(f"{filename} already exists.")
        return full_path

    print(f"Downloading {filename}...")
    # Write to a sibling ".part" file and rename only after a complete
    # download, so an interrupted transfer never leaves a truncated file that
    # a later run would treat as "already exists".
    partial_path = full_path + ".part"
    try:
        # stream=True + iter_content keeps memory flat for multi-hundred-MB files.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, full_path)
    finally:
        # Only still present when the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return full_path

# Download required files
# Target file name inside DATA_DIR -> source URL.
# NOTE(review): gwas_catalog.tsv is also written by the curl command at the top
# of this cell, so download_file() finds it on disk and skips this URL.
files = dict([
    ("goa_human.gaf.gz", "http://current.geneontology.org/annotations/goa_human.gaf.gz"),
    ("go.obo", "http://purl.obolibrary.org/obo/go.obo"),
    ("desc2024.xml", "https://nlmpubs.nlm.nih.gov/projects/mesh/MESH_FILES/xmlmesh/desc2024.xml"),
    ("gwas_catalog.tsv", "https://www.ebi.ac.uk/gwas/api/search/downloads/alternative/gwas_catalog_v1.0.2-associations_e114_r2025-05-13.tsv"),
    ("Ensembl2Reactome_All_Levels.txt", "https://reactome.org/download/current/Ensembl2Reactome_All_Levels.txt"),
    ("ReactomePathways.txt", "https://reactome.org/download/current/ReactomePathways.txt"),
    ("9606.protein.links.full.v12.0.txt.gz", "https://stringdb-static.org/download/protein.links.full.v12.0/9606.protein.links.full.v12.0.txt.gz"),
])

# Fetch each file in declaration order; files already on disk are skipped by
# download_file().
for target_name in files:
    download_file(files[target_name], target_name)

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  480M    0  480M    0     0  19.5M      0 --:--:--  0:00:24 --:--:-- 20.8M
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
goa_human.gaf.gz already exists.
go.obo already exists.
desc2024.xml already exists.
gwas_catalog.tsv already exists.
Ensembl2Reactome_All_Levels.txt already exists.
ReactomePathways.txt already exists.
9606.protein.links.full.v12.0.txt.gz already exists.


In [None]:
# Data Parsing & Processing

import pandas as pd
import torch
from collections import defaultdict

def load_string_edges(path, threshold=700):
    """Read a gzipped, whitespace-separated link file into an edge index.

    The first line is treated as a header and skipped. On every other line the
    first two fields are the endpoint identifiers and the last field is an
    integer score; fields in between are ignored.

    Args:
        path: Path to the gzipped link file.
        threshold: Minimum score (inclusive) for an edge to be kept.

    Returns:
        ``(edge_index, nodes)`` where ``edge_index`` is a ``torch.long`` tensor
        of shape ``(2, num_edges)`` and ``nodes`` maps each identifier that
        appears in a kept edge to its integer index, in first-seen order.
    """
    edges, nodes = [], {}
    with gzip.open(path, 'rt') as f:
        next(f, None)  # skip the header; tolerate a completely empty file
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank line, e.g. a trailing newline at end of file
            p1, p2, *_, score = fields
            score = int(score)
            if score >= threshold:
                for p in (p1, p2):
                    if p not in nodes:
                        nodes[p] = len(nodes)
                edges.append((nodes[p1], nodes[p2]))
    if not edges:
        # torch.tensor([]) would be 1-D with shape (0,); return a proper
        # (2, 0) edge index so downstream code always sees two rows.
        return torch.empty((2, 0), dtype=torch.long), nodes
    return torch.tensor(edges, dtype=torch.long).t().contiguous(), nodes
def load_mesh_vocabulary(mesh_xml_path):
    """Collect the DescriptorUI of every DescriptorRecord in a MeSH XML file.

    Args:
        mesh_xml_path: Path to the MeSH descriptor XML file.

    Returns:
        A set of descriptor IDs (e.g. ``D009369``). If the file cannot be
        opened or parsed, an error is printed and an empty set is returned.
    """
    # Imported here because no cell of this notebook imports ElementTree; the
    # previous module-level `ET` reference raised NameError, which the broad
    # except below swallowed, so the vocabulary was always empty.
    import xml.etree.ElementTree as ET

    valid_mesh_terms = set()
    try:
        # Stream the document instead of building the whole tree in memory.
        for _, element in ET.iterparse(mesh_xml_path, events=("end",)):
            if element.tag != "DescriptorRecord":
                continue
            mesh_id = element.findtext("DescriptorUI")  # e.g., D009369
            if mesh_id:  # records without a DescriptorUI child are skipped
                valid_mesh_terms.add(mesh_id)
            element.clear()  # release the record's children once processed
        print(f"Loaded {len(valid_mesh_terms)} MeSH terms from {mesh_xml_path}")
    except (ET.ParseError, OSError) as e:
        # Drop partial results so the message below stays truthful.
        valid_mesh_terms = set()
        print(f"Error parsing MeSH vocabulary file {mesh_xml_path}: {e}. Proceeding with empty MeSH vocabulary.")
    return valid_mesh_terms

def parse_go_annotations(path, node_map):
    """Map node indices to the set of GO terms annotated to them.

    Reads a gzipped, tab-separated annotation file. Lines starting with ``!``
    and lines with fewer than five columns are ignored. The identifier is taken
    from the second column and the GO term from the fifth.

    Args:
        path: Path to the gzipped annotation file.
        node_map: Mapping from identifier to node index; rows whose identifier
            is not a key of this mapping are dropped.

    Returns:
        ``defaultdict(set)`` keyed by node index.
    """
    # NOTE(review): node_map is built from the STRING link file elsewhere in
    # this notebook; confirm that its keys use the same identifier namespace as
    # the second column of this file, otherwise nothing will match.
    annotations = defaultdict(set)
    with gzip.open(path, 'rt') as handle:
        for raw in handle:
            if raw.startswith("!"):
                continue
            fields = raw.strip().split("\t")
            if len(fields) < 5:
                continue
            identifier, go_id = fields[1], fields[4]
            if identifier not in node_map:
                continue
            annotations[node_map[identifier]].add(go_id)
    return annotations

def parse_reactome(path, node_map):
    """Map node indices to the set of pathway IDs listed for them.

    Reads a header-less, tab-separated file. The first column is the
    identifier looked up in ``node_map`` and the second column is the pathway
    ID; any further columns are ignored.

    Args:
        path: Path to the tab-separated mapping file.
        node_map: Mapping from identifier to node index; rows whose identifier
            is not a key of this mapping are dropped.

    Returns:
        ``defaultdict(set)`` keyed by node index.
    """
    # NOTE(review): confirm that the first column uses the same identifier
    # format as the keys of node_map; otherwise no row will match.
    table = pd.read_csv(path, sep='\t', header=None)
    pathways_by_node = defaultdict(set)
    for identifier, pathway_id in zip(table[0], table[1]):
        if identifier in node_map:
            pathways_by_node[node_map[identifier]].add(pathway_id)
    return pathways_by_node

def parse_gwas(path, node_map):
    """Count, per node, how many rows of a GWAS table mention it.

    Only the ``MAPPED_GENE`` column is read. Each cell is split on commas and
    every stripped entry that is a key of ``node_map`` adds one to that node's
    count. Rows with a missing ``MAPPED_GENE`` are skipped.

    Args:
        path: Path to the tab-separated GWAS table (with a header row).
        node_map: Mapping from gene identifier to node index.

    Returns:
        ``defaultdict(int)`` from node index to count; empty when the table
        has no ``MAPPED_GENE`` column.
    """
    # NOTE(review): confirm that MAPPED_GENE entries use the same identifier
    # namespace as the keys of node_map; otherwise every count stays at zero.
    # NOTE(review): entries are split on "," only; check whether the file also
    # uses other separators inside MAPPED_GENE that should be split.
    snp_count = defaultdict(int)

    # Peek at the header so a missing column yields an empty result, as before.
    header = pd.read_csv(path, sep='\t', nrows=0).columns
    if "MAPPED_GENE" not in header:
        return snp_count

    # Load just the one needed column as text: avoids the mixed-dtype warning
    # and the memory cost of parsing every column of a very large file.
    mapped = pd.read_csv(path, sep='\t', usecols=["MAPPED_GENE"], dtype=str)["MAPPED_GENE"]

    # dropna(): a missing cell used to be stringified to "nan" and looked up
    # as if it were a gene name.
    for cell in mapped.dropna():
        for gene in cell.split(","):
            gene = gene.strip()
            if gene in node_map:
                snp_count[node_map[gene]] += 1
    return snp_count

def parse_mesh_annotations(path, node_map, valid_mesh_terms=None, association_file=None):
    """Map node indices to the MeSH terms associated with them.

    The associations come from a tab-separated file with the columns
    ``GeneID`` and ``MeSH_ID`` (lines starting with ``#`` are comments).

    Args:
        path: Path to the MeSH descriptor XML. Only used to build the
            validation vocabulary via ``load_mesh_vocabulary`` when
            ``valid_mesh_terms`` is not supplied.
        node_map: Mapping from gene identifier to node index; rows whose
            ``GeneID`` is not a key of this mapping are dropped.
        valid_mesh_terms: Optional collection of accepted MeSH IDs. When it is
            ``None`` the vocabulary is loaded from ``path``; when the resulting
            vocabulary is empty no vocabulary filtering is applied.
        association_file: Optional path of the gene-MeSH association file.
            Defaults to ``gene_mesh_associations.tsv`` inside ``DATA_DIR``.

    Returns:
        ``defaultdict(set)`` keyed by node index. Empty when the association
        file is missing or cannot be parsed (a message is printed).
    """
    gene_mesh_associations = defaultdict(set)
    if association_file is None:
        mesh_association_file = os.path.join(DATA_DIR, "gene_mesh_associations.tsv")
    else:
        mesh_association_file = association_file

    if not os.path.exists(mesh_association_file):
        print(f"Error: MeSH association file not found at {mesh_association_file}. Please provide a valid gene-MeSH association file.")
        return gene_mesh_associations

    # The previous version read a global `mesh_vocab` that is not defined when
    # this function is first called; the resulting NameError was hidden by a
    # broad `except Exception`. The vocabulary is now an explicit input.
    if valid_mesh_terms is None:
        valid_mesh_terms = load_mesh_vocabulary(path)
    if not valid_mesh_terms:
        print("No MeSH vocabulary available; keeping gene-MeSH associations without vocabulary validation.")

    try:
        df = pd.read_csv(mesh_association_file, sep='\t', comment='#')
        # Assume columns: GeneID (e.g., 9606.ENSP...), MeSH_ID (e.g., D009369)
        for gene_identifier, mesh_term in zip(map(str, df['GeneID']), map(str, df['MeSH_ID'])):
            if gene_identifier not in node_map:
                continue
            # Only include valid MeSH terms when a vocabulary is available
            if valid_mesh_terms and mesh_term not in valid_mesh_terms:
                continue
            gene_mesh_associations[node_map[gene_identifier]].add(mesh_term)
        print(f"Parsed {len(gene_mesh_associations)} gene-MeSH associations from {mesh_association_file}")
    except KeyError as e:
        print(f"Error parsing MeSH annotation file: Column {e} not found. Expected columns: GeneID, MeSH_ID. Proceeding without MeSH features.")
    except (OSError, ValueError) as e:
        # pandas parser errors derive from ValueError; programming errors such
        # as NameError are no longer silenced here.
        print(f"Error parsing MeSH annotation file {mesh_association_file}: {e}. Proceeding without MeSH features.")
    return gene_mesh_associations


def build_feature_matrix(node_map, go_data, reactome_data, gwas_data, mesh_data,
                         go_vocab, reactome_vocab, mesh_vocab):
    """Assemble the per-node feature matrix.

    Column layout, left to right: one indicator column per GO vocabulary
    entry, then one per Reactome vocabulary entry, then one per MeSH
    vocabulary entry, and finally a single GWAS column holding the node's
    count divided by 10.

    Args:
        node_map: Mapping whose length gives the number of rows.
        go_data, reactome_data, mesh_data: Mappings from node index to an
            iterable of terms.
        gwas_data: Mapping from node index to a numeric count.
        go_vocab, reactome_vocab, mesh_vocab: Mappings from term to its column
            position within its own section; terms missing from a vocabulary
            are ignored.

    Returns:
        A float tensor of shape ``(len(node_map), total_columns)``.
    """
    n_rows = len(node_map)
    reactome_start = len(go_vocab)
    mesh_start = reactome_start + len(reactome_vocab)
    n_columns = mesh_start + len(mesh_vocab) + 1  # +1 for GWAS

    x = torch.zeros((n_rows, n_columns))

    # (annotations, vocabulary, first column of the section)
    sections = (
        (go_data, go_vocab, 0),
        (reactome_data, reactome_vocab, reactome_start),
        (mesh_data, mesh_vocab, mesh_start),
    )

    for row in range(n_rows):
        active_columns = [
            start + vocab[term]
            for annotations, vocab, start in sections
            for term in annotations.get(row, [])
            if term in vocab
        ]
        if active_columns:
            x[row, active_columns] = 1

        # GWAS feature (always last)
        x[row, -1] = gwas_data.get(row, 0) / 10.0
    return x

In [None]:
# Install PyTorch Geometric and its compiled companion packages.
# NOTE(review): the wheel index below is pinned to torch-2.0.0+cu118 while the
# packages themselves are unpinned; confirm this matches the torch version
# actually installed in the runtime.
!pip install -q torch-scatter torch-sparse torch-geometric torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu118.html

# Model & Training
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

class FusionGeneGNN(nn.Module):
    """Attention-fused multi-modal node encoder followed by a two-layer GCN.

    The input feature matrix is expected to hold, left to right, the GO
    columns, the Reactome columns, the MeSH columns and one final GWAS column.
    Each group is projected to a common width, the four projections are mixed
    with multi-head self-attention, averaged, and passed through two GCNConv
    layers.

    Args:
        num_go_features: Number of leading GO columns.
        num_reactome_features: Number of Reactome columns that follow.
        num_mesh_features: Number of MeSH columns that follow.
        common_embed_dim: Width every modality is projected to.
        num_attention_heads: Number of heads of the fusion attention.
        gnn_hidden_dim: Output width of the first GCN layer.
        gnn_out_dim: Output width of the second GCN layer (embedding size).
    """

    def __init__(self, num_go_features, num_reactome_features, num_mesh_features, # num_gwas_features is 1
                 common_embed_dim, num_attention_heads,
                 gnn_hidden_dim, gnn_out_dim):
        super().__init__()

        # Section widths, kept so forward() can slice the feature matrix.
        self.num_go_features = num_go_features
        self.num_reactome_features = num_reactome_features
        self.num_mesh_features = num_mesh_features
        # The GWAS section is a single column (the last one).

        # One linear projection per modality into the shared embedding width.
        self.go_proj = nn.Linear(num_go_features, common_embed_dim)
        self.reactome_proj = nn.Linear(num_reactome_features, common_embed_dim)
        self.mesh_proj = nn.Linear(num_mesh_features, common_embed_dim)
        self.gwas_proj = nn.Linear(1, common_embed_dim)

        # Self-attention across the modalities of each node; with
        # batch_first=True the input is (nodes, modalities, common_embed_dim).
        self.attention = nn.MultiheadAttention(
            embed_dim=common_embed_dim,
            num_heads=num_attention_heads,
            batch_first=True,
        )

        # Graph convolutions applied to the fused node features.
        self.conv1 = GCNConv(common_embed_dim, gnn_hidden_dim)
        self.conv2 = GCNConv(gnn_hidden_dim, gnn_out_dim)

    def forward(self, data):
        """Return one embedding per node for ``data`` (uses ``data.x`` and ``data.edge_index``)."""
        features, edge_index = data.x, data.edge_index

        # Column boundaries of the GO / Reactome / MeSH sections.
        go_end = self.num_go_features
        reactome_end = go_end + self.num_reactome_features
        mesh_end = reactome_end + self.num_mesh_features

        # Pair each projection with its slice; GWAS is the last column, kept 2-D.
        modality_inputs = (
            (self.go_proj, features[:, :go_end]),
            (self.reactome_proj, features[:, go_end:reactome_end]),
            (self.mesh_proj, features[:, reactome_end:mesh_end]),
            (self.gwas_proj, features[:, -1].unsqueeze(1)),
        )
        tokens = torch.stack(
            [F.relu(project(block)) for project, block in modality_inputs],
            dim=1,
        )

        # Self-attention over the four modality tokens; the weights are unused.
        attended, _ = self.attention(tokens, tokens, tokens)

        # Average the attended tokens into a single vector per node.
        fused = attended.mean(dim=1)

        hidden = F.relu(self.conv1(fused, edge_index))
        hidden = F.dropout(hidden, p=0.3, training=self.training)
        return self.conv2(hidden, edge_index)

# Load & build
# Graph structure: edges and the identifier -> node index mapping.
ppi_file = os.path.join(DATA_DIR, "9606.protein.links.full.v12.0.txt.gz")
edge_index, node_map = load_string_edges(ppi_file)

# Per-node annotations from each source, keyed by node index.
go = parse_go_annotations(os.path.join(DATA_DIR, "goa_human.gaf.gz"), node_map)
reactome = parse_reactome(os.path.join(DATA_DIR, "Ensembl2Reactome_All_Levels.txt"), node_map)
gwas = parse_gwas(os.path.join(DATA_DIR, "gwas_catalog.tsv"), node_map)
mesh = parse_mesh_annotations(os.path.join(DATA_DIR, "desc2024.xml"), node_map)

# Vocabularies: every term attached to at least one node, sorted and mapped to
# its position within its section of the feature matrix.
go_vocab = {term: idx for idx, term in enumerate(sorted({t for terms in go.values() for t in terms}))}
reactome_vocab = {term: idx for idx, term in enumerate(sorted({t for terms in reactome.values() for t in terms}))}
mesh_vocab = {term: idx for idx, term in enumerate(sorted({t for terms in mesh.values() for t in terms}))}

# Node feature matrix and the graph object consumed by the model.
x = build_feature_matrix(node_map, go, reactome, gwas, mesh, go_vocab, reactome_vocab, mesh_vocab)
data = Data(x=x, edge_index=edge_index)

# Model and training hyperparameters.
COMMON_EMBED_DIM = 128
NUM_ATTENTION_HEADS = 4
GNN_HIDDEN_DIM = 64
GNN_OUT_DIM = 32
LEARNING_RATE = 0.01
NUM_EPOCHS = 50
SEED = 42

num_go_features = len(go_vocab)
num_reactome_features = len(reactome_vocab)
num_mesh_features = len(mesh_vocab)

# Seed before the model is built so weight initialisation and the dropout
# masks drawn during training are the same on every run.
torch.manual_seed(SEED)

model = FusionGeneGNN(
    num_go_features=num_go_features,
    num_reactome_features=num_reactome_features,
    num_mesh_features=num_mesh_features,
    common_embed_dim=COMMON_EMBED_DIM,
    num_attention_heads=NUM_ATTENTION_HEADS,
    gnn_hidden_dim=GNN_HIDDEN_DIM,
    gnn_out_dim=GNN_OUT_DIM
)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train the model (full-batch: one forward/backward pass over the whole graph
# per epoch).
model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    out = model(data)
    # NOTE(review): the objective is the L2 norm of all output embeddings, so
    # optimisation only pushes the embeddings towards zero and uses no labels
    # or graph-reconstruction signal. Presumably a placeholder; replace with a
    # real training objective before relying on the embeddings.
    loss = out.norm(p=2)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Save model and node map
# Destination of the single-file bundle inside DATA_DIR.
model_save_path = os.path.join(DATA_DIR, "fusion_gene_gnn_bundle.pth")

# Constructor arguments needed to rebuild the same FusionGeneGNN architecture.
model_params = dict(
    num_go_features=num_go_features,
    num_reactome_features=num_reactome_features,
    num_mesh_features=num_mesh_features,
    common_embed_dim=COMMON_EMBED_DIM,
    num_attention_heads=NUM_ATTENTION_HEADS,
    gnn_hidden_dim=GNN_HIDDEN_DIM,
    gnn_out_dim=GNN_OUT_DIM,
)

# Everything required to reload the model and run it on the same graph:
# weights, architecture parameters, features, connectivity and the
# identifier / term index mappings.
full_model_bundle = dict(
    model_state_dict=model.state_dict(),
    model_params=model_params,
    feature_matrix_x=x,
    edge_index=edge_index,
    node_map=node_map,
    go_vocab=go_vocab,
    reactome_vocab=reactome_vocab,
    mesh_vocab=mesh_vocab,
)

# Persist the bundle as one .pth file.
torch.save(full_model_bundle, model_save_path)
print(f"Full model bundle saved to {model_save_path}")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m53.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m55.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m56.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m31.3 MB/s[0m eta [36m0:00:00[0m
[?25h

  df = pd.read_csv(path, sep='\t')


Error: MeSH association file not found at /content/drive/MyDrive/biological_data/gene_mesh_associations.tsv. Please provide a valid gene-MeSH association file.




Epoch 1, Loss: 23.7780
Epoch 2, Loss: 59.3965
Epoch 3, Loss: 40.2925
Epoch 4, Loss: 85.2632
Epoch 5, Loss: 44.0041
Epoch 6, Loss: 30.9751
Epoch 7, Loss: 18.5333
Epoch 8, Loss: 15.8227
Epoch 9, Loss: 14.1565
Epoch 10, Loss: 12.1978
Epoch 11, Loss: 16.3273
Epoch 12, Loss: 8.2704
Epoch 13, Loss: 6.0403
Epoch 14, Loss: 4.7062
Epoch 15, Loss: 4.1399
Epoch 16, Loss: 4.1734
Epoch 17, Loss: 4.4436
Epoch 18, Loss: 4.6465
Epoch 19, Loss: 4.7648
Epoch 20, Loss: 4.7717
Epoch 21, Loss: 4.6187
Epoch 22, Loss: 4.2995
Epoch 23, Loss: 3.7044
Epoch 24, Loss: 2.9916
Epoch 25, Loss: 2.2410
Epoch 26, Loss: 1.8352
Epoch 27, Loss: 2.1429
Epoch 28, Loss: 2.5992
Epoch 29, Loss: 2.7398
Epoch 30, Loss: 2.5016
Epoch 31, Loss: 2.0229
Epoch 32, Loss: 1.5930
Epoch 33, Loss: 1.4154
Epoch 34, Loss: 1.3666
Epoch 35, Loss: 1.3971
Epoch 36, Loss: 1.5212
Epoch 37, Loss: 1.6143
Epoch 38, Loss: 1.4269
Epoch 39, Loss: 0.9639
Epoch 40, Loss: 0.7533
Epoch 41, Loss: 1.0393
Epoch 42, Loss: 1.1502
Epoch 43, Loss: 0.9956
Epoch 44,

In [None]:
# =========================
# 🔍 Inference Example API
# =========================
# Dummy logic: rank top-K genes by embedding similarity to a phenotype
def infer_causal_genes(phenotype_terms, top_k=5):
    """Rank genes by embedding similarity to a set of phenotype terms.

    Each recognised term (looked up in ``go_vocab``, then ``reactome_vocab``,
    then ``mesh_vocab``) contributes the mean embedding of the genes whose
    feature column for that term is set. The per-term means are summed into a
    phenotype vector and genes are ranked by dot product with it.

    Args:
        phenotype_terms: Iterable of GO, Reactome or MeSH term IDs.
        top_k: Maximum number of genes to return.

    Returns:
        A list of gene identifiers (keys of ``node_map``), best match first.
        Empty when none of the terms is annotated to any gene, because no
        meaningful ranking exists in that case.
    """
    model.eval()
    with torch.no_grad():
        embeddings = model(data)

    # First feature-matrix column of each vocabulary's section.
    reactome_offset = len(go_vocab)
    mesh_offset = reactome_offset + len(reactome_vocab)

    pheno_vec = torch.zeros(embeddings.size(1), dtype=embeddings.dtype, device=embeddings.device)
    matched_terms = 0
    for t in phenotype_terms:
        if t in go_vocab:
            column = go_vocab[t]
        elif t in reactome_vocab:
            column = reactome_offset + reactome_vocab[t]
        elif t in mesh_vocab:
            column = mesh_offset + mesh_vocab[t]
        else:
            continue

        # The previous version added the scalar mean of the feature column,
        # which broadcast the same value into every component of pheno_vec and
        # made the ranking independent of the terms. Use the centroid of the
        # annotated genes' embeddings so the vector lives in embedding space.
        annotated = x[:, column] > 0
        if annotated.any():
            pheno_vec += embeddings[annotated].mean(dim=0)
            matched_terms += 1

    if matched_terms == 0:
        print("None of the phenotype terms is annotated to a gene in the graph; no ranking produced.")
        return []

    similarities = torch.matmul(embeddings, pheno_vec)
    top_idxs = similarities.topk(min(top_k, embeddings.size(0))).indices # Ensure top_k is not > num embeddings
    reverse_map = {v: k for k, v in node_map.items()}
    return [reverse_map[i.item()] for i in top_idxs]

# Example usage: query with two GO term IDs. infer_causal_genes() looks each
# term up in go_vocab / reactome_vocab / mesh_vocab and ignores unknown terms.
example_phenotype_terms = ["GO:0008190", "GO:0008150"] # Original example term

# Ask for the five best-ranked gene identifiers and show them.
candidate_genes = infer_causal_genes(example_phenotype_terms, top_k=5)
print("Predicted causal genes:", candidate_genes)

Predicted causal genes: ['9606.ENSP00000262305', '9606.ENSP00000329419', '9606.ENSP00000158762', '9606.ENSP00000000233', '9606.ENSP00000357048']


In [None]:
Predicted causal genes: ['9606.ENSP00000262305', '9606.ENSP00000329419', '9606.ENSP00000158762', '9606.ENSP00000000233', '9606.ENSP00000357048']

