# 0. Prequisite


In [11]:
import pandas as pd
import numpy as np
import json
import requests
import time
from sklearn.preprocessing import OneHotEncoder, MultiLabelBinarizer, LabelEncoder
import re
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, confusion_matrix, f1_score
import torch_geometric as pyg
import gradio as gr
import pickle

# 1. Data acquisition

In [22]:
API_KEY = "ad6669df-65b6-45f9-8e02-7ba74e788acd"
# API_KEY = "c89e2d9e-94b2-4b84-8d22-bb525e63b73b"

params = {"page_number": 0}

# Create a dictionary with HTTP headers
headers = {
    'Authorization': API_KEY,
    'accept': 'application/json'
}

# API endpoints
url_gda = "https://api.disgenet.com/api/v1/gda/summary"
url_disease = "https://api.disgenet.com/api/v1/entity/disease"

In [8]:
# Function to handle API requests with rate-limiting handling
def make_request(url, params, headers):
    retries = 0
    while retries < 5:
        try:
            response = requests.get(url, params=params, headers=headers, timeout=10)
            # If rate-limited (HTTP 429), retry after waiting
            if response.status_code == 429:
                wait_time = int(response.headers.get('x-rate-limit-retry-after-seconds', 60))
                print(f"Rate limit exceeded. Waiting {wait_time} seconds...")
                time.sleep(wait_time)
                retries += 1
            else:
                return response  # Return response if successful or error other than 429

        except requests.exceptions.RequestException as e:
            print(f"Request error: {e}")
            retries += 1
            time.sleep(2)  # Wait before retrying

    return None  # Return None if retries are exhausted

In [11]:
def get_disease_ids(disease_type):
    disease_ids = []
    params['disease_free_text_search_string'] = disease_type

    for page in range(100):
      params['page_number'] = str(page)
      response_disease = make_request(url_disease, params, headers)
      if response_disease and response_disease.ok:
          response_disease_json = response_disease.json()
          data = response_disease_json.get("payload", [])
          for item in data:
              for code_info in item.get("diseaseCodes", []):
                if code_info.get("vocabulary") == "MONDO":
                  disease_ids.append(f'MONDO_{code_info.get("code")}')
      else:
          print(f"Failed to fetch data for page {page}. Status code: {response_disease_json.status_code}")
          break
    return list(set(disease_ids))

In [13]:
def download_gda(disease_ids):
    gda_data = []
    params['disease'] = disease_ids

    for page in range(100):
        params['page_number'] = str(page)  # Different pages
        response_gda = make_request(url_gda, params, headers)
        if response_gda and response_gda.ok:
            response_json = response_gda.json()
            data = response_json.get("payload", [])
            gda_data.extend(data)
        else:
            print(f"Failed to fetch data for page {page}. Status code: {response_json.status_code}")
            break  # If no more page or error
    return gda_data

In [24]:
def download_all_gda(ids, chunk_size=100):
    all_data = []
    for i in range(0, len(ids), chunk_size):
        ids_chunk = ids[i:i + chunk_size]
        ids_string = '"' + ', '.join(ids_chunk) + '"'
        chunk_data = download_gda(ids_string)
        all_data.extend(chunk_data)
        print(f"Downloaded the {i}. chunk")
    df_gda = pd.DataFrame(all_data)
    df_gda.to_csv('GDA_df_raw.csv', index=False)
    print(f"All data saved to GDA_df_raw.csv")

In [15]:
disease_ids = get_disease_ids("disorder")
len(disease_ids)

Rate limit exceeded. Waiting 8 seconds...


5433

In [25]:
download_all_gda(disease_ids[:2000])

Rate limit exceeded. Waiting 3 seconds...
Rate limit exceeded. Waiting 0 seconds...
Downloaded the 0. chunk
Rate limit exceeded. Waiting 17 seconds...
Rate limit exceeded. Waiting 11 seconds...
Downloaded the 100. chunk
Rate limit exceeded. Waiting 18 seconds...
Rate limit exceeded. Waiting 0 seconds...
Rate limit exceeded. Waiting 11 seconds...
Downloaded the 200. chunk
Rate limit exceeded. Waiting 19 seconds...
Rate limit exceeded. Waiting 14 seconds...
Rate limit exceeded. Waiting 0 seconds...
Downloaded the 300. chunk
Rate limit exceeded. Waiting 17 seconds...
Rate limit exceeded. Waiting 0 seconds...
Rate limit exceeded. Waiting 14 seconds...
Rate limit exceeded. Waiting 0 seconds...
Downloaded the 400. chunk
Rate limit exceeded. Waiting 18 seconds...
Rate limit exceeded. Waiting 12 seconds...
Downloaded the 500. chunk
Rate limit exceeded. Waiting 18 seconds...
Rate limit exceeded. Waiting 8 seconds...
Downloaded the 600. chunk
Rate limit exceeded. Waiting 18 seconds...
Rate limit

# 2. Data Processing

In [77]:
GDA_cancer_df = pd.read_csv('GDA_df_cancer_raw.csv')
GDA_disorder_df = pd.read_csv('GDA_df_disorder_raw.csv')
GDA_df = pd.concat([GDA_cancer_df, GDA_disorder_df]).drop_duplicates(subset='assocID', keep='first')
GDA_df = GDA_df.map(lambda x: np.nan if x == '[]' else x)
GDA_df.info()

  GDA_disorder_df = pd.read_csv('GDA_df_disorder_raw.csv')


<class 'pandas.core.frame.DataFrame'>
Index: 33127 entries, 0 to 33857
Data columns (total 28 columns):
 #   Column                                          Non-Null Count  Dtype  
---  ------                                          --------------  -----  
 0   assocID                                         33127 non-null  int64  
 1   symbolOfGene                                    33127 non-null  object 
 2   geneNcbiID                                      33127 non-null  int64  
 3   geneEnsemblIDs                                  31908 non-null  object 
 4   geneNcbiType                                    33127 non-null  object 
 5   geneDSI                                         33127 non-null  float64
 6   geneDPI                                         33127 non-null  float64
 7   genepLI                                         29127 non-null  float64
 8   geneProteinStrIDs                               31231 non-null  object 
 9   geneProteinClassIDs                         

In [78]:
GDA_df = GDA_df[[
    'geneNcbiID',
    'geneDSI',
    'geneDPI',
    'geneNcbiType',
    'diseaseUMLSCUI',
    'diseaseClasses_MSH',
    'diseaseClasses_UMLS_ST',
    'diseaseType',
    'assocID',
    'score'
]]

In [79]:
# One-hot encoding geneNcbiType
enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
encoded_geneNcbiType = enc.fit_transform(GDA_df[['geneNcbiType']])
columns = ['geneType_' + col.split('_')[-1] for col in enc.get_feature_names_out(['geneNcbiType'])]
encoded_df = pd.DataFrame(encoded_geneNcbiType, columns=columns)
GDA_df = pd.concat([GDA_df.reset_index(drop=True), encoded_df], axis=1).drop('geneNcbiType', axis=1)

# One-hot encoding diseaseType
encoded_diseaseType = enc.fit_transform(GDA_df[['diseaseType']])
encoded_df = pd.DataFrame(
    encoded_diseaseType,
    columns=enc.get_feature_names_out(['diseaseType'])
)
GDA_df = pd.concat([GDA_df.reset_index(drop=True), encoded_df], axis=1).drop('diseaseType', axis=1)

label_encoder = LabelEncoder()
GDA_df['diseaseUMLSCUI'] = label_encoder.fit_transform(GDA_df['diseaseUMLSCUI'])

In [80]:
# Keep only IDs for simplicity
def clean_classes(entry):
    if isinstance(entry, (str, bytes)):
        return [match.strip() for match in re.findall(r'\((.*?)\)', entry)]
    else:
        return []

GDA_df['diseaseClasses_UMLS_ST'] = GDA_df['diseaseClasses_UMLS_ST'].apply(clean_classes)
GDA_df['diseaseClasses_MSH'] = GDA_df['diseaseClasses_MSH'].apply(clean_classes)

In [81]:
# Combine the two lists into a new column for handling potential missing values in diseaseClasses_MSH
GDA_df['diseaseClass'] = GDA_df.apply(
    lambda row: list(set(row['diseaseClasses_UMLS_ST'] + row['diseaseClasses_MSH'])),
    axis=1
)
mlb = MultiLabelBinarizer()
encoded_diseaseClass = mlb.fit_transform(GDA_df['diseaseClass'])
enc_df = pd.DataFrame(encoded_diseaseClass, columns=['diseaseClass_' + cols for cols in mlb.classes_])
GDA_df = pd.concat([GDA_df.reset_index(drop=True), enc_df], axis=1)

GDA_df = GDA_df.drop(['diseaseClasses_UMLS_ST', 'diseaseClasses_MSH', 'diseaseClass'], axis=1)

In [82]:
GDA_df.rename(columns={'geneNcbiID': 'geneID', 'diseaseUMLSCUI': 'diseaseID'}, inplace=True)

In [83]:
print(f"Number of unique gene IDs: {len(GDA_df['geneID'].unique())}")
print(f"Number of unique disease IDs: {len(GDA_df['diseaseID'].unique())}")
print(f"Number of unique assocIDs: {len(GDA_df['assocID'].unique())}")

Number of unique gene IDs: 9052
Number of unique disease IDs: 2225
Number of unique assocIDs: 33127


In [None]:
GDA_df.to_csv('GDA_df.csv', index=False)

# 3. Graph Data Preparation
Creating Homogeneous Graph

In [5]:
def homogeneous_node_features(df):
    '''Preprocess and construct node features for genes and diseases for homogeneous graph'''
    # Extract unique rows for genes and diseases
    gene_rows = df[
        ['geneID', 'geneDSI', 'geneDPI'] +
        [col for col in df.columns if col.startswith('geneType')]]
    gene_rows = gene_rows.drop_duplicates(subset=['geneID']).drop(columns=['geneID'])

    disease_rows = df[
        ['diseaseID'] +
        [col for col in df.columns if col.startswith('diseaseClass')] +
        [col for col in df.columns if col.startswith('diseaseType')]]
    disease_rows = disease_rows.drop_duplicates(subset=['diseaseID']).drop(columns=['diseaseID'])

    # Fill missing columns with zeros where needed
    gene_rows = gene_rows.assign(**{col: 0 for col in disease_rows.columns if col not in gene_rows.columns})
    disease_rows = disease_rows.assign(**{col: 0 for col in gene_rows.columns if col not in disease_rows.columns})

    # Convert features to numpy arrays and add node type indicator
    gene_features = np.hstack([gene_rows.values, np.ones((gene_rows.shape[0], 1))]) # 1 indicates gene
    disease_features = np.hstack([disease_rows.values, np.zeros((disease_rows.shape[0], 1))]) # 0 indicates disease

    # Combine gene and disease features into a single matrix and return as tensor
    node_features = np.vstack([gene_features, disease_features])

    return torch.tensor(node_features, dtype=torch.float)

In [6]:
def prepare_homogeneous_graph(df):
    '''Prepare a homogeneous graph for PyTorch Geometric'''
    # Map IDs to separate index ranges
    unique_gene_ids = df['geneID'].unique()
    unique_disease_ids = df['diseaseID'].unique()

    # geneIds 0 to len(unique_gene_ids) and diseaseIds len(unique_gene_ids) to len(unique_gene_ids) + len(unique_disease_ids)
    gene_id_to_idx = {id: idx for idx, id in enumerate(unique_gene_ids)}
    disease_id_to_idx = {id: idx + len(unique_gene_ids) for idx, id in enumerate(unique_disease_ids)}

    df['geneID'] = df['geneID'].map(gene_id_to_idx)
    df['diseaseID'] = df['diseaseID'].map(disease_id_to_idx)

    # Construct node features
    node_features = homogeneous_node_features(df)

    # Create edge indices
    edge_index = torch.tensor(np.array([df['geneID'].values, df['diseaseID'].values]), dtype=torch.long)
    edge_index = pyg.utils.to_undirected(edge_index) # Make edges bidirectional

    # Homogeneous Graph
    graph_data = pyg.data.Data(
        x = node_features,
        edge_index = edge_index,
    )

    return graph_data

# 4. Models
- GCN_DP
- GCN_MLP
- GraphSAGE_MLP
- GIN_MLP

In [12]:
class GCN_DP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.2):
        super().__init__()
        self.conv1 = pyg.nn.GCNConv(input_dim, hidden_dim)
        self.conv2 = pyg.nn.GCNConv(hidden_dim, output_dim)
        self.dropout = dropout

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        return (src * dst).sum(dim=-1)

In [13]:
class GCN_MLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.3):
        super().__init__()
        self.conv1 = pyg.nn.GCNConv(input_dim, hidden_dim)
        self.conv2 = pyg.nn.GCNConv(hidden_dim, output_dim)
        self.dropout = dropout

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(output_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(hidden_dim, 1)
        )

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        edge_features = torch.cat([src, dst], dim=-1)
        return self.mlp(edge_features).view(-1)

In [14]:
class GraphSAGE_MLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.5):
        super().__init__()
        self.conv1 = pyg.nn.SAGEConv(input_dim, hidden_dim)
        self.conv2 = pyg.nn.SAGEConv(hidden_dim, output_dim)
        self.bn1 = torch.nn.BatchNorm1d(hidden_dim)
        self.bn2 = torch.nn.BatchNorm1d(output_dim)
        self.dropout = dropout

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(output_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(hidden_dim, 1)
        )

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        return x

    def decode(self, z, edge_label_index):
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        edge_features = torch.cat([src, dst], dim=-1)
        return self.mlp(edge_features).view(-1)

In [15]:
class GIN_MLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.5):
        super().__init__()
        self.conv1 = pyg.nn.GINConv(torch.nn.Sequential(
            torch.nn.Linear(input_dim, 2 * hidden_dim),
            torch.nn.BatchNorm1d(2 * hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * hidden_dim, hidden_dim)
        ), train_eps=True)
        self.conv2 = pyg.nn.GINConv(torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 2 * hidden_dim),
            torch.nn.BatchNorm1d(2 * hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * hidden_dim, output_dim)
        ), train_eps=True)
        self.bn1 = torch.nn.BatchNorm1d(hidden_dim)
        self.bn2 = torch.nn.BatchNorm1d(output_dim)
        self.dropout = dropout

        # Define MLP for decoding
        self.mlp_decoder = torch.nn.Sequential(
            torch.nn.Linear(output_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(hidden_dim, 1)
        )

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        return x

    def decode(self, z, edge_label_index):
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        edge_features = torch.cat([src, dst], dim=-1)
        return self.mlp_decoder(edge_features).view(-1)

# Trainer

In [None]:
class Trainer:
    def __init__(self, model, optimizer, device='cpu', save_path='best_model.pth'):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device
        self.save_path = save_path
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        self.criterion = torch.nn.BCEWithLogitsLoss()

    def train_epoch(self, data):
        self.model.train()
        self.optimizer.zero_grad()

        # Encode node embeddings
        data = data.to(self.device)
        z = self.model.encode(data.x, data.edge_index)

        # Decode edge predictions
        pos_out = self.model.decode(z, data.pos_edge_label_index).view(-1)
        neg_out = self.model.decode(z, data.neg_edge_label_index).view(-1)

        # Compute loss
        pos_loss = self.criterion(pos_out, data.pos_edge_label.float())
        neg_loss = self.criterion(neg_out, data.neg_edge_label.float())
        loss = pos_loss + neg_loss
        loss.backward()
        self.optimizer.step()
        return loss.item(), pos_loss.item(), neg_loss.item()

    @torch.no_grad()
    def evaluate(self, data):
        self.model.eval()

        data = data.to(self.device)
        z = self.model.encode(data.x, data.edge_index)

        pos_out = self.model.decode(z, data.pos_edge_label_index).view(-1)
        neg_out = self.model.decode(z, data.neg_edge_label_index).view(-1)
        scores = torch.cat([pos_out, neg_out]).sigmoid()
        labels = torch.cat([data.pos_edge_label, data.neg_edge_label])

        y_true = labels.cpu().numpy()
        y_pred_probs = scores.cpu().numpy()

        auc = roc_auc_score(y_true, y_pred_probs)

        # Dynamically evaluate performance across thresholds
        best_f1, best_threshold, best_cm = 0, 0, None

        for threshold in [i / 100 for i in range(1, 100)]:
            y_pred = (y_pred_probs >= threshold).astype(int)
            f1 = f1_score(y_true, y_pred)
            cm = confusion_matrix(y_true, y_pred)
            if f1  > best_f1:
                best_f1 = f1
                best_threshold = threshold
                best_cm = cm

        return {
            'f1': best_f1,
            'auc': auc,
            'threshold': best_threshold,
            'cm': best_cm
        }

    def fit(self, train_data, val_data, num_epochs=100, early_stopping_patience=10):
        best_val_f1 = 0
        patience_counter = 0

        for epoch in range(1, num_epochs + 1):
            train_loss, pos_loss, neg_loss = self.train_epoch(train_data)
            val_metric = self.evaluate(val_data)

            # Save the best model
            if val_metric['f1'] > best_val_f1:
                best_val_f1 = val_metric['f1']
                patience_counter = 0
                torch.save({'model_state_dict': self.model.state_dict(),
                'best_threshold': val_metric['threshold']}, self.save_path)
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    print("Early stopping triggered.")
                    break
            self.scheduler.step()
            if epoch % 5 == 0:
                print(f"{'Epoch':<6} {'Train Loss':<12} {'Pos Loss':<10} {'Neg Loss':<10} {'Val F1':<8} {'Val AUC':<9} {'Threshold':<10}")
                print(f"  {epoch:<6} {train_loss:<12.4f} {pos_loss:<10.4f} {neg_loss:<10.4f} {val_metric['f1']:<8.4f} {val_metric['auc']:<9.4f} {val_metric['threshold']:<10.2f}")
                print(f"{'=' * 70}")
                print(f"Confusion Matrix:\n{val_metric['cm']}")
                print(f"{'=' * 70}")

    def test(self, test_data):
        checkpoint = torch.load(self.save_path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        best_threshold = checkpoint['best_threshold']
        test_metric = self.evaluate(test_data)
        print(f"Test F1: {test_metric['f1']:.4f}")
        print(f"Test AUC: {test_metric['auc']:.4f}")
        print(f"Best Threshold: {best_threshold:.2f}")
        print(f"Confusion Matrix:\n{test_metric['cm']}")

# Initialization and Data Splitting

In [12]:
GDA_df = pd.read_csv('GDA_df.csv')
homogeneous_graph = prepare_homogeneous_graph(GDA_df)

split = pyg.transforms.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    add_negative_train_samples=True,
    neg_sampling_ratio=1.0,
    split_labels=True)
train_data, val_data, test_data = split(homogeneous_graph)

print(f'Train data: {train_data}')
print(f'Val data: {val_data}')
print(f'Test data: {test_data}')

Train data: Data(x=[11277, 48], edge_index=[2, 53006], pos_edge_label=[26503], pos_edge_label_index=[2, 26503], neg_edge_label=[26503], neg_edge_label_index=[2, 26503])
Val data: Data(x=[11277, 48], edge_index=[2, 53006], pos_edge_label=[3312], pos_edge_label_index=[2, 3312], neg_edge_label=[3312], neg_edge_label_index=[2, 3312])
Test data: Data(x=[11277, 48], edge_index=[2, 59630], pos_edge_label=[3312], pos_edge_label_index=[2, 3312], neg_edge_label=[3312], neg_edge_label_index=[2, 3312])


In [14]:
input_dim = homogeneous_graph.num_node_features
hidden_dim = 128
output_dim = 64
dropout = 0.2
wd = 1e-4
lr = 1e-3
num_epochs = 50

In [15]:
model = GCN_DP(homogeneous_graph.num_node_features, hidden_dim, output_dim, dropout)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, weight_decay=wd)

In [16]:
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    save_path='test.pth'
)

# Train and Evaluate

In [17]:
trainer.fit(train_data, val_data, num_epochs=num_epochs, early_stopping_patience=10)

Epoch  Train Loss   Pos Loss   Neg Loss   Val F1   Val AUC   Threshold 
  5      1.2038       0.3052     0.8987     0.8067   0.8832    0.61      
Confusion Matrix:
[[2689  623]
 [ 652 2660]]
Epoch  Train Loss   Pos Loss   Neg Loss   Val F1   Val AUC   Threshold 
  10     1.1812       0.3113     0.8699     0.8206   0.8992    0.61      
Confusion Matrix:
[[2772  540]
 [ 632 2680]]
Epoch  Train Loss   Pos Loss   Neg Loss   Val F1   Val AUC   Threshold 
  15     1.1776       0.3026     0.8750     0.8293   0.9084    0.61      
Confusion Matrix:
[[2736  576]
 [ 558 2754]]
Epoch  Train Loss   Pos Loss   Neg Loss   Val F1   Val AUC   Threshold 
  20     1.1476       0.2365     0.9111     0.8602   0.9343    0.63      
Confusion Matrix:
[[2748  564]
 [ 387 2925]]
Epoch  Train Loss   Pos Loss   Neg Loss   Val F1   Val AUC   Threshold 
  25     1.1185       0.2249     0.8935     0.8691   0.9405    0.63      
Confusion Matrix:
[[2895  417]
 [ 446 2866]]
Epoch  Train Loss   Pos Loss   Neg Loss   Val

In [18]:
trainer.test(test_data)

  checkpoint = torch.load(self.save_path)


Test F1: 0.8777
Test AUC: 0.9368
Best Threshold: 0.60
Confusion Matrix:
[[2911  401]
 [ 408 2904]]


# Gradio

In [1]:
def fetch_gene_features(gene_id, api_key):
    url = "https://api.disgenet.com/api/v1/entity/gene"
    params = {"gene_ncbi_id": gene_id}
    headers = {'Authorization': api_key, 'accept': 'application/json'}
    response = requests.get(url, headers=headers, params=params)
    if response.ok:
          response_json = response.json()
          data = response_json.get("payload", [])
          records = []
          for item in data:
              records.append({
                  "geneDSI": item.get("dsi", 0),
                  "geneDPI": item.get("dpi", 0),
                  "geneNcbiType": item.get("ncbi_type", "unknown"),
              })
          return pd.DataFrame(records)

    else:
        raise ValueError(f"Error fetching gene features: {response.status_code}")

In [2]:
def fetch_disease_features(disease_id, api_key):
    url = "https://api.disgenet.com/api/v1/entity/disease"
    params = {"disease": disease_id}
    headers = {'Authorization': api_key, 'accept': 'application/json'}
    response = requests.get(url, headers=headers)
    response = requests.get(url, headers=headers, params=params)
    if response.ok:
          response_json = response.json()
          data = response_json.get("payload", [])
          records = []
          for item in data:
              records.append({
                  "diseaseType": item.get("type", "unknown"),
                  "diseaseClasses_MSH": item.get("diseaseClasses_MSH", "unknown"),
                  "diseaseClasses_UMLS_ST": item.get("diseaseClasses_UMLS_ST", "unknown"),
              })

          return pd.DataFrame(records)

    else:
        raise ValueError(f"Error fetching gene features: {response.status_code}")


In [3]:
def extract_parentheses_values(entry):
    """Extract values enclosed in parentheses from a string or list of strings."""
    if isinstance(entry, str):
        return re.findall(r'\((.*?)\)', entry)
    elif isinstance(entry, list):
        result = []
        for item in entry:
            result.extend(re.findall(r'\((.*?)\)', item))
        return result
    return []

In [4]:
def process_gene_disease_features(df, disease_features, gene_features):

    gene_columns = ['geneDSI', 'geneDPI'] + [col for col in df.columns if col.startswith('geneType')]
    disease_columns = [col for col in df.columns if col.startswith('diseaseClass') or col.startswith('diseaseType')]

    all_columns = gene_columns + disease_columns + ['nodetype']
    base_template = {col: 0 for col in all_columns}

    # Process the gene features
    processed_gene_features = base_template.copy()
    processed_gene_features['geneDSI'] = gene_features['geneDSI'].iloc[0]
    processed_gene_features['geneDPI'] = gene_features['geneDPI'].iloc[0]

    # Dynamically add the one-hot encoded gene type
    gene_type_column = f"geneType_{gene_features['geneNcbiType'].iloc[0]}"
    if gene_type_column in processed_gene_features:
        processed_gene_features[gene_type_column] = 1

    processed_gene_features['nodetype'] = 1

    processed_disease_features = base_template.copy()

    # Dynamically add the one-hot encoded disease type
    disease_type_column = f"diseaseType_{disease_features['diseaseType'].iloc[0]}"
    if disease_type_column in processed_disease_features:
        processed_disease_features[disease_type_column] = 1

    msh_classes = extract_parentheses_values(disease_features['diseaseClasses_MSH'].iloc[0])
    umls_classes = extract_parentheses_values(disease_features['diseaseClasses_UMLS_ST'].iloc[0])
    disease_classes = set(msh_classes + umls_classes)

    for disease_class in disease_classes:
        disease_class_column = f"diseaseClass_{disease_class}"
        if disease_class_column in processed_disease_features:
            processed_disease_features[disease_class_column] = 1

    # Add the nodetype indicator for diseases
    processed_disease_features['nodetype'] = 0  # 0 for diseases

    # Convert processed features to DataFrames for consistency
    processed_gene_df = pd.DataFrame([processed_gene_features])
    processed_disease_df = pd.DataFrame([processed_disease_features])

    gene_node_feature = torch.tensor(processed_gene_df.values, dtype=torch.float)
    disease_node_feature = torch.tensor(processed_disease_df.values, dtype=torch.float)

    return gene_node_feature, disease_node_feature

In [5]:
def check_and_assign_ids(graph_data, gene_node_feature, disease_node_feature):
    """
    Check if the provided gene and disease features exist in the graph.
    If not, add them to the graph, ensuring gene nodes are added after the current gene nodes
    and disease nodes after the current disease nodes. Assign IDs accordingly.
    """
    node_features = graph_data.x
    nodetype_indicator = node_features[:, -1]

    gene_features = node_features[nodetype_indicator == 1]
    disease_features = node_features[nodetype_indicator == 0]

    # Check for gene node existence
    gene_id = None
    if (gene_features == gene_node_feature).all(dim=1).any():
        gene_id = torch.where((gene_features == gene_node_feature).all(dim=1))[0][0].item()
    else:
        gene_id = gene_features.size(0)
        node_features = torch.cat([node_features, gene_node_feature], dim=0)

    # Check for disease node existence
    disease_id = None
    if (disease_features == disease_node_feature).all(dim=1).any():
        disease_id = torch.where((disease_features == disease_node_feature).all(dim=1))[0][0].item() + gene_features.size(0)
    else:
        disease_id = node_features.size(0)  # Append disease after all existing nodes
        node_features = torch.cat([node_features, disease_node_feature], dim=0)

    graph_data.x = node_features

    return graph_data, gene_id, disease_id

In [9]:
# Define global variables
GDA_df = pd.read_csv('GDA_df.csv')
df = GDA_df.copy()

# Load the graph and model
with open("graph_data.pkl", "rb") as f:
    graph_data = pickle.load(f)

checkpoint = torch.load("GIN_MLP.pth")
model = GIN_MLP(input_dim=48, hidden_dim=128, output_dim=64)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
best_threshold = checkpoint['best_threshold']

# Define API Key (You should replace it with a valid one)
API_KEY = "ad6669df-65b6-45f9-8e02-7ba74e788acd"

# Prediction function
def gradio_predict(gene_id: int, disease_id: str):
    try:
        # Strip whitespace from inputs
        gene_id = gene_id.strip()
        disease_id = disease_id.strip()

        # Fetch features for the gene and disease
        gene_features = fetch_gene_features(gene_id, API_KEY)
        disease_features = fetch_disease_features(disease_id, API_KEY)

        # Process the fetched features into node features
        processed_gene_feature, processed_disease_feature = process_gene_disease_features(
            df, disease_features, gene_features
        )

        # Check and assign IDs in the graph
        updated_graph, gene_idx, disease_idx = check_and_assign_ids(
            graph_data, processed_gene_feature, processed_disease_feature
        )

        # Prepare edge indices for prediction
        edge_label_index = torch.tensor([[gene_idx], [disease_idx]], dtype=torch.long)

        # Predict association
        with torch.no_grad():
            z = model.encode(updated_graph.x, updated_graph.edge_index)  # Encode node embeddings
            raw_score = model.decode(z, edge_label_index).sigmoid().item()  # Get the sigmoid of the predicted score

        # Apply the best threshold
        prediction = 1 if raw_score >= best_threshold else 0

        # Return results
        result = f"Prediction: {'Exists' if prediction == 1 else 'Does not exist'}"
        raw_score_str = f"Raw Score: {raw_score:.4f}"
        return result, raw_score_str

    except Exception as e:
        return f"Error: {str(e)}", ""

# Gradio interface
interface = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Textbox(label="Gene ID, the Entrez Id from Disgenet (e.g., 7124)", placeholder="Enter Gene ID here"),
        gr.Textbox(label="Disease ID (e.g., MONDO_0000728)", placeholder="Enter Disease ID here")
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Raw Score")
    ],
    title="Gene-Disease Association Prediction",
    description="Enter Gene and Disease IDs to predict the association using the trained model."
)

# Launch Gradio interface
interface.launch(share=True)


  checkpoint = torch.load("GIN_MLP.pth")


NameError: name 'GIN_MLP' is not defined