# Disease-gene interaction prediction with graph neural networks

The goal of this project is to build a graph neural network that predicts disease-gene associations. Working with DisGeNET, a comprehensive database of these associations, you'll apply deep learning to an important problem in bioinformatics. The project gives hands-on experience at the intersection of deep learning and bioinformatics, using real-world data.

Dataset:
https://www.disgenet.org/

Related GitHub repository:
https://github.com/pyg-team/pytorch_geometric

Related papers:
https://arxiv.org/abs/1607.00653
https://arxiv.org/abs/1611.07308

# Milestone 1
Must be delivered by the end of week 6!

## Containerization


Required packages and software:
* [Docker](https://docs.docker.com/engine/install/)
* [PyTorch](https://hub.docker.com/r/pytorch/pytorch/tags)
* [pytorch_geometric](https://github.com/pyg-team/pytorch_geometric)
* [NVIDIA CUDA](https://hub.docker.com/r/nvidia/cuda)

The **Dockerfile** and **requirements.txt** needed for the container are available on GitHub.

## Data acquisition

Required packages

In [None]:
import requests
import time
import csv
import json
import pandas as pd

Declaring the required variables

In [None]:
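# Read the DisGeNET API key from the environment rather than hard-coding it in the notebook.
# The variable name DISGENET_API_KEY is a placeholder; set it before running this cell.
import os
API_KEY = os.environ["DISGENET_API_KEY"]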

params = {
    "page_number": 0,
    "type": "disease"
}

# Create a dictionary with HTTP headers
headers = {
    'Authorization': API_KEY,
    'accept': 'application/json'
}

# API endpoints
url_gda = "https://api.disgenet.com/api/v1/gda/summary"
url_disease = "https://api.disgenet.com/api/v1/entity/disease"

Functions for sending requests

In [None]:
# Function to handle API requests with rate-limiting handling
def make_request(url, params, headers):
    retries = 0
    while retries < 5:
        try:
            response = requests.get(url, params=params, headers=headers, timeout=10)
            # If rate-limited (HTTP 429), retry after waiting
            if response.status_code == 429:
                wait_time = int(response.headers.get('x-rate-limit-retry-after-seconds', 60))
                print(f"Rate limit exceeded. Waiting {wait_time} seconds...")
                time.sleep(wait_time)
                retries += 1
            else:
                return response  # Return response if successful or error other than 429

        except requests.exceptions.RequestException as e:
            print(f"Request error: {e}")
            retries += 1
            time.sleep(2)  # Wait before retrying

    return None  # Return None if retries are exhausted

In [None]:
def get_max_pages(url, params=params, headers=headers):
  response = make_request(url, params=params, headers=headers)
  # make_request returns None once its retries are exhausted
  if response is not None and response.ok:
      paging = response.json().get("paging", {})
      total_results = paging.get("totalElements", 0)
      results_in_page = paging.get("totalElementsInPage", 0)
      if results_in_page == 0:
          return 0  # empty result set: avoid dividing by zero
      # Ceiling division, capped at 100 pages
      max_pages = min((total_results + results_in_page - 1) // results_in_page, 100)
  else:
      max_pages = 100
      print("Request failed, returned max_pages=100")
  return max_pages
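
The page count above is a ceiling division capped at 100 pages. As a minimal standalone sketch of that arithmetic (the helper name `page_count` and the zero-size guard are illustrative assumptions, not part of the DisGeNET API):

```python
def page_count(total_results: int, results_in_page: int, cap: int = 100) -> int:
    """Number of pages needed to cover total_results, capped at `cap`."""
    if results_in_page <= 0:
        return 0  # nothing to page through
    # Ceiling division without floats: (a + b - 1) // b
    return min((total_results + results_in_page - 1) // results_in_page, cap)

print(page_count(556, 100))     # 6: five full pages plus one partial page
print(page_count(500, 100))     # 5
print(page_count(50_000, 100))  # 100: the cap applies
```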

In [None]:
def get_disease_ids(disease_type):
    disease_ids = []
    params['disease_free_text_search_string'] = disease_type

    for page in range(100):
      params['page_number'] = str(page)
      response_disease = make_request(url_disease, params, headers)
      if response_disease and response_disease.ok:
          response_disease_json = response_disease.json()
          data = response_disease_json.get("payload", [])
          for item in data:
              for code_info in item.get("diseaseCodes", []):
                if code_info.get("vocabulary") == "MONDO":
                  disease_ids.append(f'MONDO_{code_info.get("code")}')
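      else:
          # Report the HTTP status of the response itself; it is None when retries ran out
          status = response_disease.status_code if response_disease is not None else "no response"
          print(f"Failed to fetch data for page {page}. Status code: {status}")
          break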
    return disease_ids

In [None]:
def download_gda(disease_ids):
    gda_data = []
    params['disease'] = disease_ids

    for page in range(100):
        params['page_number'] = str(page)  # Fetch successive pages
        response_gda = make_request(url_gda, params, headers)
        if response_gda and response_gda.ok:
            response_json = response_gda.json()
            data = response_json.get("payload", [])
            gda_data.extend(data)
        else:
            # Report the HTTP status of the response itself; it is None when retries ran out
            status = response_gda.status_code if response_gda is not None else "no response"
            print(f"Failed to fetch data for page {page}. Status code: {status}")
            break  # Stop when there are no more pages or an error occurs

    return gda_data


In [None]:
def download_all_gda(ids, chunk_size=100):
    all_data = []
    for i in range(0, len(ids), chunk_size):
        ids_chunk = ids[i:i + chunk_size]
        ids_string = '"' + ', '.join(ids_chunk) + '"'
        chunk_data = download_gda(ids_string)
        all_data.extend(chunk_data)
    df_gda = pd.DataFrame(all_data)
    df_gda.to_csv('GDA_df_raw.csv', index=False)
    print(f"All data saved to GDA_df_raw.csv")
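
The chunking used by `download_all_gda` can be checked offline, without any API calls. The sketch below is only an illustration: `chunk_ids` is a hypothetical helper and the IDs are made up.

```python
def chunk_ids(ids, chunk_size=100):
    """Yield consecutive slices of at most chunk_size IDs."""
    for start in range(0, len(ids), chunk_size):
        yield ids[start:start + chunk_size]

# Made-up identifiers, used only to show the slicing
sample_ids = [f"MONDO_{n:07d}" for n in range(250)]
chunks = list(chunk_ids(sample_ids))

print([len(c) for c in chunks])   # [100, 100, 50]
print(", ".join(chunks[2][:2]))   # MONDO_0000200, MONDO_0000201
```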

In [None]:
ids = get_disease_ids("cancer")
print(len(ids))

Rate limit exceeded. Waiting 11 seconds...
556
MONDO_0005507, MONDO_0006244, MONDO_0002087, MONDO_0045054, MONDO_0006294, MONDO_0005517, MONDO_0002238, MONDO_0001487, MONDO_0009807, MONDO_0018531, MONDO_0001462, MONDO_0021071, MONDO_0005411, MONDO_0021112, MONDO_0700079, MONDO_0021089, MONDO_0005206, MONDO_0021138, MONDO_0008167, MONDO_0005806, MONDO_0003319, MONDO_0004956, MONDO_0004379, MONDO_0018875, MONDO_0012249, MONDO_0858997, MONDO_0005215, MONDO_0700078, MONDO_0005893, MONDO_0006850, MONDO_0021581, MONDO_0021085, MONDO_0850353, MONDO_0002367, MONDO_0006295, MONDO_0004669, MONDO_0013710, MONDO_0004987, MONDO_0021317, MONDO_0003050, MONDO_0002447, MONDO_0007648, MONDO_0005036, MONDO_0017896, MONDO_0005216, MONDO_0002095, MONDO_0005055, MONDO_0005580, MONDO_0001060, MONDO_0021545, MONDO_0013872, MONDO_0006517, MONDO_0016419, MONDO_0044937, MONDO_0004641, MONDO_0006490, MONDO_0004989, MONDO_0013806, MONDO_0006234, MONDO_0007958, MONDO_0004708, MONDO_0004358, MONDO_0003274, MONDO_00

In [None]:
unique_ids = list(set(ids))
download_all_gda(unique_ids)

Rate limit exceeded. Waiting 8 seconds...
Rate limit exceeded. Waiting 14 seconds...
Rate limit exceeded. Waiting 11 seconds...
Rate limit exceeded. Waiting 13 seconds...
Rate limit exceeded. Waiting 5 seconds...
Rate limit exceeded. Waiting 14 seconds...
Rate limit exceeded. Waiting 11 seconds...
Rate limit exceeded. Waiting 0 seconds...
Rate limit exceeded. Waiting 13 seconds...
Rate limit exceeded. Waiting 0 seconds...
Rate limit exceeded. Waiting 9 seconds...
All data saved to GDA_df_raw.csv


## Data preprocessing

In [None]:
import pandas as pd
import numpy as np
import re
from sklearn.preprocessing import OneHotEncoder, MultiLabelBinarizer, LabelEncoder
# import matplotlib.pyplot as plt
# import seaborn as sns
# import ast
# from sklearn.model_selection import train_test_split

In [None]:
GDA_df=pd.read_csv('GDA_df_raw.csv', sep=',')
GDA_df.head()

Unnamed: 0,assocID,symbolOfGene,geneNcbiID,geneEnsemblIDs,geneNcbiType,geneDSI,geneDPI,genepLI,geneProteinStrIDs,geneProteinClassIDs,...,diseaseClasses_DO,diseaseClasses_HPO,numCTsupportingAssociation,chemicalsIncludedInEvidence,numberPmidsWithChemsIncludedInEvidenceBySource,score,yearInitial,yearFinal,el,ei
0,5599912,TP53,7157,['ENSG00000141510'],protein-coding,0.256,0.957,0.99795,"['K7PPA8', 'A0A087WXZ1', 'A0A087WT22', 'Q53GA5...",['DTO_05007542'],...,"['disease of anatomical entity (7)', 'disease ...",[],0,,"[{'source': 'ALL', 'numPmids': 48}, {'source':...",1.0,2010.0,2019.0,,0.897188
1,5499445,CHEK2,11200,['ENSG00000183765'],protein-coding,0.421,0.913,2.7688e-14,['O96017'],['DTO_03300101'],...,"['disease of anatomical entity (7)', 'disease ...",[],0,,"[{'source': 'ALL', 'numPmids': 6}, {'source': ...",1.0,2002.0,2024.0,,0.883853
2,20438068,AKT1,207,['ENSG00000142208'],protein-coding,0.283,0.957,0.99533,"['P31749', 'B0LPE5', 'B3KVH4']",['DTO_03300101'],...,"['disease of cellular proliferation (14566)', ...",['Abnormality of the genitourinary system (001...,4,,"[{'source': 'ALL', 'numPmids': 95}, {'source':...",1.0,2005.0,2022.0,,0.969595
3,20474199,PTGS2,5743,['ENSG00000073756'],protein-coding,0.323,0.957,1.0,['P35354'],['DTO_05007624'],...,"['disease of cellular proliferation (14566)', ...",['Abnormality of the genitourinary system (001...,5,,"[{'source': 'ALL', 'numPmids': 24}, {'source':...",1.0,2000.0,2007.0,,0.909091
4,20465848,MYC,4609,['ENSG00000136997'],protein-coding,0.312,0.913,1.0,['P01106'],['DTO_05007542'],...,"['disease of cellular proliferation (14566)', ...",['Abnormality of the genitourinary system (001...,9,,"[{'source': 'ALL', 'numPmids': 27}, {'source':...",1.0,2007.0,2020.0,,0.956044


In [None]:
GDA_df = GDA_df.map(lambda x: np.nan if x == '[]' else x)
GDA_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 19177 entries, 0 to 19176
Data columns (total 27 columns):
 #   Column                                          Non-Null Count  Dtype  
---  ------                                          --------------  -----  
 0   assocID                                         19177 non-null  int64  
 1   symbolOfGene                                    19177 non-null  object 
 2   geneNcbiID                                      19177 non-null  int64  
 3   geneEnsemblIDs                                  18809 non-null  object 
 4   geneNcbiType                                    19177 non-null  object 
 5   geneDSI                                         19177 non-null  float64
 6   geneDPI                                         19177 non-null  float64
 7   genepLI                                         17425 non-null  float64
 8   geneProteinStrIDs                               18399 non-null  object 
 9   geneProteinClassIDs                    

In [None]:
for column in GDA_df.columns:
  print(f"{column}: {GDA_df[column].nunique()}")

assocID: 14131
symbolOfGene: 4719
geneNcbiID: 4719
geneEnsemblIDs: 4595
geneNcbiType: 7
geneDSI: 375
geneDPI: 23
genepLI: 3432
geneProteinStrIDs: 5810
geneProteinClassIDs: 21
geneProteinClassNames: 21
diseaseVocabularies: 354
diseaseName: 354
diseaseType: 1
diseaseUMLSCUI: 354
diseaseClasses_MSH: 72
diseaseClasses_UMLS_ST: 2
diseaseClasses_DO: 13
diseaseClasses_HPO: 25
numCTsupportingAssociation: 60
chemicalsIncludedInEvidence: 0
numberPmidsWithChemsIncludedInEvidenceBySource: 607
score: 19
yearInitial: 59
yearFinal: 40
el: 6
ei: 856


In [None]:
# Convert the IDs from object data type to integer format for better interpretation and processing in the GNN
label_encoder = LabelEncoder()
GDA_df['diseaseUMLSCUI_encoded'] = label_encoder.fit_transform(GDA_df['diseaseUMLSCUI'])
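
For intuition, `LabelEncoder` gives each distinct value its position in the sorted list of distinct values. A minimal pure-Python sketch of the same idea, with a few CUI-like strings chosen only for illustration:

```python
# Illustrative identifiers; the real column holds the CUIs returned by the API
cuis = ["C0006142", "C0007131", "C0006142", "C0001418"]

classes = sorted(set(cuis))                       # sorted distinct values
to_code = {cui: i for i, cui in enumerate(classes)}
encoded = [to_code[c] for c in cuis]

print(classes)  # ['C0001418', 'C0006142', 'C0007131']
print(encoded)  # [1, 2, 1, 0]

# The inverse lookup recovers the original strings
decoded = [classes[i] for i in encoded]
```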

In [None]:
GDA_df = GDA_df.drop_duplicates(subset=['assocID']).reset_index(drop=True)

In [None]:
# Gene and disease mappings
gene_symbol_mapping = GDA_df[['geneNcbiID', 'symbolOfGene']].drop_duplicates().set_index('geneNcbiID').to_dict()['symbolOfGene']
disease_encoded_mapping = GDA_df[['diseaseUMLSCUI_encoded', 'diseaseUMLSCUI']].drop_duplicates().set_index('diseaseUMLSCUI_encoded').to_dict()['diseaseUMLSCUI']
disease_name_mapping = GDA_df[['diseaseUMLSCUI', 'diseaseName']].drop_duplicates().set_index('diseaseUMLSCUI').to_dict()['diseaseName']

In [None]:
GDA_df = GDA_df[[
    'geneNcbiID',
    'geneDSI',
    'geneDPI',
    'geneNcbiType',
    'diseaseUMLSCUI_encoded',
    'diseaseClasses_MSH',
    'diseaseClasses_UMLS_ST',
    'assocID',
    'score'
]]

In [None]:
# One-hot encoding geneNcbiType
enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
encoded_geneNcbiType = enc.fit_transform(GDA_df[['geneNcbiType']])
columns = ['geneType_' + col.split('_')[-1] for col in enc.get_feature_names_out(['geneNcbiType'])]
encoded_df = pd.DataFrame(encoded_geneNcbiType, columns=columns)
GDA_df = pd.concat([GDA_df.reset_index(drop=True), encoded_df], axis=1).drop('geneNcbiType', axis=1)

In [None]:
# Process diseaseClasses_UMLS_ST and diseaseClasses_MSH
# Extracting IDs and names into a mapping
def extract_mapping(col):
    mapping = {}
    for entry in col:
        if pd.notnull(entry):
            matches = re.findall(r"'(.+?)\s+\((.+?)\)'", entry)
            for name, id in matches:
                mapping[id.strip()] = name.strip()
    return mapping

In [None]:
diseaseClass_mapping = extract_mapping(GDA_df['diseaseClasses_UMLS_ST'])
diseaseClass_mapping.update(extract_mapping(GDA_df['diseaseClasses_MSH']))
diseaseClass_mapping

{'T191': 'Neoplastic Process',
 'T047': 'Disease or Syndrome',
 'C04': 'Neoplasms',
 'C17': 'Skin and Connective Tissue Diseases',
 'C12': 'Urogenital Diseases',
 'C06': 'Digestive System Diseases',
 'C18': 'Nutritional and Metabolic Diseases',
 'C16': 'Congenital, Hereditary, and Neonatal Diseases and Abnormalities',
 'C19': 'Endocrine System Diseases',
 'C01': 'Infections',
 'C08': 'Respiratory Tract Diseases',
 'C05': 'Musculoskeletal Diseases',
 'C07': 'Stomatognathic Diseases',
 'C09': 'Otorhinolaryngologic Diseases',
 'C14': 'Cardiovascular Diseases',
 'C10': 'Nervous System Diseases',
 'C11': 'Eye Diseases',
 'C20': 'Immune System Diseases',
 'C15': 'Hemic and Lymphatic Diseases',
 'C23': 'Pathological Conditions, Signs and Symptoms'}

In [None]:
# Keep only IDs for simplicity
def clean_classes(entry):
    if isinstance(entry, (str, bytes)):
        return [match.strip() for match in re.findall(r'\((.*?)\)', entry)]
    else:
        return []

GDA_df['diseaseClasses_UMLS_ST'] = GDA_df['diseaseClasses_UMLS_ST'].apply(clean_classes)
GDA_df['diseaseClasses_MSH'] = GDA_df['diseaseClasses_MSH'].apply(clean_classes)
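
To see what the two regular expressions above extract, here is a small worked example on a single cell value. The string is an assumed shape of a raw entry, inferred from the mapping printed earlier in this notebook.

```python
import re

# Assumed raw cell: a stringified list of 'name (code)' items
entry = "['Neoplasms (C04)', 'Skin and Connective Tissue Diseases (C17)']"

# Same pattern as extract_mapping: capture the name and the code separately
pairs = re.findall(r"'(.+?)\s+\((.+?)\)'", entry)
mapping = {code: name for name, code in pairs}

# Same pattern as clean_classes: keep only what is inside the parentheses
codes = re.findall(r"\((.*?)\)", entry)

print(mapping)  # {'C04': 'Neoplasms', 'C17': 'Skin and Connective Tissue Diseases'}
print(codes)    # ['C04', 'C17']
```

Both patterns rely on each item being wrapped in single quotes and ending with a parenthesized code; entries in another format would not match.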

In [None]:
# Combine the two lists into a new column for handling missing values in diseaseClasses_MSH
GDA_df['diseaseClass'] = GDA_df.apply(
    lambda row: list(set(row['diseaseClasses_UMLS_ST'] + row['diseaseClasses_MSH'])),
    axis=1
)

In [None]:
# Using MultiLabelBinarizer because of the input being lists of disease codes
mlb = MultiLabelBinarizer()
encoded_diseaseClass = mlb.fit_transform(GDA_df['diseaseClass'])
enc_df = pd.DataFrame(encoded_diseaseClass, columns=['diseaseClass_' + cols for cols in mlb.classes_])
GDA_df = pd.concat([GDA_df.reset_index(drop=True), enc_df], axis=1)
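
The multi-hot encoding produced here can be pictured with plain lists. This is a sketch of the idea rather than scikit-learn's implementation; the sample rows are illustrative.

```python
# Each row is the list of class codes attached to one association
rows = [["C17", "C04", "T191"], ["C04", "C12", "T191"], []]

# One column per distinct code, in sorted order
classes = sorted({code for row in rows for code in row})
multi_hot = [[int(code in row) for code in classes] for row in rows]

print(classes)    # ['C04', 'C12', 'C17', 'T191']
print(multi_hot)  # [[1, 0, 1, 1], [1, 1, 0, 1], [0, 0, 0, 0]]
```

A row with no codes becomes an all-zero vector, which is how missing disease classes end up represented.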

In [None]:
disease_class_cols = [col for col in GDA_df.columns if col.startswith('diseaseClass')]
GDA_df[disease_class_cols].head()

Unnamed: 0,diseaseClasses_MSH,diseaseClasses_UMLS_ST,diseaseClass,diseaseClass_C01,diseaseClass_C04,diseaseClass_C05,diseaseClass_C06,diseaseClass_C07,diseaseClass_C08,diseaseClass_C09,...,diseaseClass_C14,diseaseClass_C15,diseaseClass_C16,diseaseClass_C17,diseaseClass_C18,diseaseClass_C19,diseaseClass_C20,diseaseClass_C23,diseaseClass_T047,diseaseClass_T191
0,"[C04, C17]",[T191],"[C17, C04, T191]",0,1,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,1
1,"[C04, C17]",[T191],"[C17, C04, T191]",0,1,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,1
2,"[C12, C04]",[T191],"[C04, C12, T191]",0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
3,"[C12, C04]",[T191],"[C04, C12, T191]",0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
4,"[C12, C04]",[T191],"[C04, C12, T191]",0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1


In [None]:
GDA_df = GDA_df.drop(['diseaseClasses_UMLS_ST', 'diseaseClasses_MSH', 'diseaseClass'], axis=1)

In [None]:
GDA_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14131 entries, 0 to 14130
Data columns (total 33 columns):
 #   Column                      Non-Null Count  Dtype  
---  ------                      --------------  -----  
 0   geneNcbiID                  14131 non-null  int64  
 1   geneDSI                     14131 non-null  float64
 2   geneDPI                     14131 non-null  float64
 3   diseaseUMLSCUI_encoded      14131 non-null  int64  
 4   assocID                     14131 non-null  int64  
 5   score                       14131 non-null  float64
 6   geneType_biological-region  14131 non-null  float64
 7   geneType_ncRNA              14131 non-null  float64
 8   geneType_other              14131 non-null  float64
 9   geneType_protein-coding     14131 non-null  float64
 10  geneType_pseudo             14131 non-null  float64
 11  geneType_snoRNA             14131 non-null  float64
 12  geneType_tRNA               14131 non-null  float64
 13  diseaseClass_C01            141

In [None]:
GDA_df.rename(columns={'geneNcbiID': 'geneID', 'diseaseUMLSCUI_encoded': 'diseaseID'}, inplace=True)
GDA_df.head()

Unnamed: 0,geneID,geneDSI,geneDPI,diseaseID,assocID,score,geneType_biological-region,geneType_ncRNA,geneType_other,geneType_protein-coding,...,diseaseClass_C14,diseaseClass_C15,diseaseClass_C16,diseaseClass_C17,diseaseClass_C18,diseaseClass_C19,diseaseClass_C20,diseaseClass_C23,diseaseClass_T047,diseaseClass_T191
0,7157,0.256,0.957,10,5599912,1.0,0.0,0.0,0.0,1.0,...,0,0,0,1,0,0,0,0,0,1
1,11200,0.421,0.913,10,5499445,1.0,0.0,0.0,0.0,1.0,...,0,0,0,1,0,0,0,0,0,1
2,207,0.283,0.957,69,20438068,1.0,0.0,0.0,0.0,1.0,...,0,0,0,0,0,0,0,0,0,1
3,5743,0.323,0.957,69,20474199,1.0,0.0,0.0,0.0,1.0,...,0,0,0,0,0,0,0,0,0,1
4,4609,0.312,0.913,69,20465848,1.0,0.0,0.0,0.0,1.0,...,0,0,0,0,0,0,0,0,0,1


In [None]:
print(f"Number of unique gene IDs: {len(GDA_df['geneID'].unique())}")
print(f"Number of unique disease IDs: {len(GDA_df['diseaseID'].unique())}")
print(f"Number of unique assocIDs: {len(GDA_df['assocID'].unique())}")

Number of unique gene IDs: 4719
Number of unique disease IDs: 354
Number of unique assocIDs: 14131


In [None]:
# Rewrite indices
unique_gene_ids = GDA_df['geneID'].unique()
unique_disease_ids = GDA_df['diseaseID'].unique()

# geneIds 0 to len(unique_gene_ids) and diseaseIds len(unique_gene_ids) to len(unique_gene_ids) + len(unique_disease_ids)
gene_id_to_idx = {id: idx for idx, id in enumerate(unique_gene_ids)}
disease_id_to_idx = {id: idx + len(unique_gene_ids) for idx, id in enumerate(unique_disease_ids)}
GDA_df['assocID'] = range(0, len(GDA_df))

GDA_df['geneID'] = GDA_df['geneID'].map(gene_id_to_idx)
GDA_df['diseaseID'] = GDA_df['diseaseID'].map(disease_id_to_idx)
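
A toy example of the re-indexing above, using plain lists instead of a DataFrame. Genes take indices starting at 0 and diseases continue after the last gene, so both node types share one index space. The sample IDs are taken from the first rows shown earlier; `first_seen` is a hypothetical helper that mimics the first-appearance order of `unique()`.

```python
gene_ids = [7157, 11200, 207, 7157]
disease_ids = [10, 10, 69, 69]

def first_seen(values):
    """Distinct values in order of first appearance."""
    return list(dict.fromkeys(values))

unique_genes = first_seen(gene_ids)        # [7157, 11200, 207]
unique_diseases = first_seen(disease_ids)  # [10, 69]

gene_to_idx = {g: i for i, g in enumerate(unique_genes)}
# Disease indices are offset by the number of genes
disease_to_idx = {d: i + len(unique_genes) for i, d in enumerate(unique_diseases)}

edges = [(gene_to_idx[g], disease_to_idx[d]) for g, d in zip(gene_ids, disease_ids)]
print(edges)  # [(0, 3), (1, 3), (2, 4), (0, 4)]
```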

In [None]:
GDA_df.to_csv('GDA_df_processed.csv', index=False)

## Graph data preparation for the model


### Prerequisite

In [None]:
!pip uninstall -y torch

Found existing installation: torch 2.1.1+cu121
Uninstalling torch-2.1.1+cu121:
  Successfully uninstalled torch-2.1.1+cu121


In [None]:
!pip install torch==2.1.1+cu121 -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==2.1.1+cu121
  Using cached https://download.pytorch.org/whl/cu121/torch-2.1.1%2Bcu121-cp310-cp310-linux_x86_64.whl (2200.7 MB)
Installing collected packages: torch
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.5.0+cu121 requires torch==2.5.0, but you have torch 2.1.1+cu121 which is incompatible.
torchvision 0.20.0+cu121 requires torch==2.5.0, but you have torch 2.1.1+cu121 which is incompatible.
Successfully installed torch-2.1.1+cu121


In [None]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.1.1+cu121.html
!pip install pytorch-lightning --quiet

Looking in links: https://data.pyg.org/whl/torch-2.1.1+cu121.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_scatter-2.1.2%2Bpt21cu121-cp310-cp310-linux_x86_64.whl (10.8 MB)
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_sparse-0.6.18%2Bpt21cu121-cp310-cp310-linux_x86_64.whl (5.0 MB)
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_cluster-1.6.3%2Bpt21cu121-cp310-cp310-linux_x86_64.whl (3.3 MB)
Collecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_s

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics import AUROC
import torch_geometric as tg
import pytorch_lightning as pl

In [None]:
# url = 'https://raw.githubusercontent.com/your_username/your_repository/main/preprocessed_GDA_df_cancer.csv'
# GDA_df = pd.read_csv(url)
GDA_df = pd.read_csv('GDA_df_processed.csv')
GDA_df.head()

Unnamed: 0,geneID,geneDSI,geneDPI,diseaseID,assocID,score,geneType_biological-region,geneType_ncRNA,geneType_other,geneType_protein-coding,...,diseaseClass_C14,diseaseClass_C15,diseaseClass_C16,diseaseClass_C17,diseaseClass_C18,diseaseClass_C19,diseaseClass_C20,diseaseClass_C23,diseaseClass_T047,diseaseClass_T191
0,7157,0.256,0.957,10,5599912,1.0,0.0,0.0,0.0,1.0,...,0,0,0,1,0,0,0,0,0,1
1,11200,0.421,0.913,10,5499445,1.0,0.0,0.0,0.0,1.0,...,0,0,0,1,0,0,0,0,0,1
2,207,0.283,0.957,69,20438068,1.0,0.0,0.0,0.0,1.0,...,0,0,0,0,0,0,0,0,0,1
3,5743,0.323,0.957,69,20474199,1.0,0.0,0.0,0.0,1.0,...,0,0,0,0,0,0,0,0,0,1
4,4609,0.312,0.913,69,20465848,1.0,0.0,0.0,0.0,1.0,...,0,0,0,0,0,0,0,0,0,1


### GDA Graph Dataset Class

In [None]:
class GDADataset(tg.data.Dataset):
  def __init__(self, root, transform=None, pre_transform=None):
    super(GDADataset, self).__init__(root, transform, pre_transform)

  @property
  def raw_file_names(self):
    return ['GDA_df_processed.csv']

  @property
  def processed_file_names(self):
    return ['data.pt']

  def download(self):
    # Add download logic
    pass

  def process(self):
    '''Load data'''
    try:
        GDA_df = pd.read_csv(self.raw_paths[0])
    except FileNotFoundError:
        raise FileNotFoundError(f"{self.raw_paths[0]} not found. Ensure the file is in the correct directory.")
    except pd.errors.EmptyDataError:
        raise ValueError(f"{self.raw_paths[0]} is empty or not formatted correctly.")

    node_features = self._construct_node_features(GDA_df)

    edge_index = torch.tensor(np.array([GDA_df['geneID'].values, GDA_df['diseaseID'].values]), dtype=torch.long)
    edge_attr = torch.tensor(GDA_df['score'].values, dtype=torch.float).view(-1, 1)

    data = tg.data.Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr
    )
    link_split = tg.transforms.RandomLinkSplit(
        is_undirected=True,
        add_negative_train_samples=True,
        split_labels=False,
        num_val=0.1,
        num_test=0.1,
        neg_sampling_ratio=1.0,
    )
    train_data, val_data, test_data = link_split(data)

    # Save processed data for use in get method
    torch.save((train_data, val_data, test_data), self.processed_paths[0])

  def get(self, idx):
    data = torch.load(self.processed_paths[0])
    return data

  def len(self):
    return 1

  def _construct_node_features(self, GDA_df):
    '''Preprocess and construct node features for genes and diseases'''
    # Extract unique rows for genes and diseases
    gene_rows = GDA_df[['geneID', 'geneDSI', 'geneDPI'] + [col for col in GDA_df.columns if col.startswith('geneType')]]
    gene_rows = gene_rows.drop_duplicates(subset=['geneID']).drop(columns=['geneID'])

    disease_rows = GDA_df[['diseaseID'] + [col for col in GDA_df.columns if col.startswith('diseaseClass')]]
    disease_rows = disease_rows.drop_duplicates(subset=['diseaseID']).drop(columns=['diseaseID'])

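    # Fill missing columns with zeros where needed
    gene_rows = gene_rows.assign(**{col: 0 for col in disease_rows.columns if col not in gene_rows.columns})
    disease_rows = disease_rows.assign(**{col: 0 for col in gene_rows.columns if col not in disease_rows.columns})
    # assign() appends new columns at the end, so reorder the disease frame to match
    # the gene frame; otherwise the stacked rows would place features in different columns
    disease_rows = disease_rows[gene_rows.columns]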

    # Convert features to numpy arrays and add node type indicator
    gene_features = np.hstack([gene_rows.values, np.ones((gene_rows.shape[0], 1))])
    disease_features = np.hstack([disease_rows.values, np.zeros((disease_rows.shape[0], 1))])

    # Combine gene and disease features into a single matrix and return as tensor
    node_features = np.vstack([gene_features, disease_features])
    return torch.tensor(node_features, dtype=torch.float)
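
A minimal sketch of a shared node feature layout, using plain lists and made-up values. It assumes the intended layout is gene features first, disease class flags second, and a node type flag last, with zeros wherever a feature does not apply to that node type.

```python
gene_feats = [[0.256, 0.957], [0.421, 0.913]]  # 2 genes x 2 gene features
disease_feats = [[1, 0, 1]]                    # 1 disease x 3 class flags

n_gene_cols = len(gene_feats[0])
n_disease_cols = len(disease_feats[0])

# Row layout: [gene features | disease features | is_gene flag]
gene_rows = [row + [0] * n_disease_cols + [1] for row in gene_feats]
disease_rows = [[0] * n_gene_cols + row + [0] for row in disease_feats]

# Genes first, then diseases, matching the node index order
x = gene_rows + disease_rows
for row in x:
    print(row)
```

Every row has the same length, and a given column always carries the same feature regardless of node type.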

### GDA Data Module with LightningDataModule

In [None]:
class GDADataModule(pl.LightningDataModule):
  def __init__(self, data_dir, batch_size=32, num_workers=2):
    super(GDADataModule, self).__init__()
    self.data_dir = data_dir
    self.batch_size = batch_size
    self.num_workers = num_workers

  def setup(self, stage=None):
    dataset = GDADataset(self.data_dir)
    self.train_data, self.val_data, self.test_data = dataset[0]

  def train_dataloader(self):
    return tg.loader.DataLoader([self.train_data], batch_size=1, shuffle=False, num_workers=self.num_workers)

  def val_dataloader(self):
    return tg.loader.DataLoader([self.val_data], batch_size=1, shuffle=False, num_workers=self.num_workers)

  def test_dataloader(self):
    return tg.loader.DataLoader([self.test_data], batch_size=1, shuffle=False, num_workers=self.num_workers)

# Milestone 2
This must be submitted by the end of week 9! Deadline: 3rd November 2024 *(Sunday)* 23:59:59

## Defining evaluation criteria
Binary Classification
1. AUROC
2. Binary Cross Entropy Loss
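
As a reference for the two criteria, here is a small pure-Python sketch. It is meant for intuition only and is not the torchmetrics or PyTorch implementation; the logits and labels are made up.

```python
import math

def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed from raw logits (numerically stable form)."""
    total = 0.0
    for z, y in zip(logits, labels):
        # max(z, 0) - z*y + log(1 + exp(-|z|)) equals
        # -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))]
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)

def auroc(scores, labels):
    """Probability that a random positive outranks a random negative (ties count 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

logits = [2.0, 0.5, -1.0, 0.7]
labels = [1, 1, 0, 0]
print(auroc(logits, labels))            # 0.75: three of four positive/negative pairs are ordered correctly
print(bce_with_logits(logits, labels))
```

AUROC depends only on how the scores are ranked, so it is unaffected by the decision threshold, while the loss also penalizes poorly calibrated scores.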


## Baseline Model
*GCN for Link Prediction*

### Model architecture
*training, testing and validation loops*

In [None]:
class GCNLinkPredictor(pl.LightningModule):
  def __init__(self, input_dim, hidden_dim, lr=1e-2):
    super().__init__()
    self.save_hyperparameters()

    # model architecture
    self.conv1 = tg.nn.GCNConv(input_dim, hidden_dim)
    self.conv2 = tg.nn.GCNConv(hidden_dim, hidden_dim)
    self.lr = lr

    # metrics
    self.train_auroc = AUROC(task="binary")
    self.val_auroc = AUROC(task="binary")
    self.test_auroc = AUROC(task="binary")

  def forward(self, x, edge_index, edge_label_index):
    x = F.relu(self.conv1(x, edge_index))
    x = F.relu(self.conv2(x, edge_index))

    # Get node embeddings for each node in the edge pairs
    src_nodes  = x[edge_label_index[0]]
    dst_nodes  = x[edge_label_index[1]]

    link_logits = torch.sum(src_nodes * dst_nodes, dim=-1)

    return link_logits

  def _shared_step(self, batch, batch_idx, stage):
    x, edge_index = batch.x, batch.edge_index
    edge_label_index = batch.edge_label_index
    edge_labels = batch.edge_label

    # Forward pass to get link logits
    link_logits = self.forward(x, edge_index, edge_label_index)

    loss = F.binary_cross_entropy_with_logits(link_logits, edge_labels.float())

    # AUROC is threshold-free, so it is computed from the raw logits below;
    # torchmetrics documents integer targets, so cast the labels for the metric calls
    edge_labels = edge_labels.long()

    # Log metrics based on the current stage (train, val, test)
    if stage == 'train':
      self.train_auroc(link_logits, edge_labels)
      self.log('train_loss', loss)
      self.log('train_auroc', self.train_auroc, prog_bar=True)

    elif stage == 'val':
      self.val_auroc(link_logits, edge_labels)
      self.log('val_loss', loss)
      self.log('val_auroc', self.val_auroc, prog_bar=True)

    elif stage == 'test':
      self.test_auroc(link_logits, edge_labels)
      self.log('test_loss', loss)
      self.log('test_auroc', self.test_auroc, prog_bar=True)

    return loss


  def training_step(self, batch, batch_idx):
    return self._shared_step(batch, batch_idx, stage='train')

  def validation_step(self, batch, batch_idx):
    return self._shared_step(batch, batch_idx, stage='val')

  def test_step(self, batch, batch_idx):
    return self._shared_step(batch, batch_idx, stage='test')

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.lr)
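
The scoring step in `forward` is a dot-product decoder: each candidate pair is scored by the inner product of its two node embeddings. A plain-Python sketch with made-up two-dimensional embeddings:

```python
import math

# Made-up node embeddings, non-negative as they would be after a final ReLU
embeddings = [
    [1.0, 0.0],  # node 0
    [0.5, 0.5],  # node 1
    [0.0, 2.0],  # node 2
]
# Candidate pairs (0, 1), (0, 2), (1, 2) in the same 2 x N layout as edge_label_index
edge_label_index = [[0, 0, 1], [1, 2, 2]]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

logits = [dot(embeddings[s], embeddings[d]) for s, d in zip(*edge_label_index)]
probs = [1 / (1 + math.exp(-z)) for z in logits]

print(logits)  # [0.5, 0.0, 1.0]
print(probs)
```

One consequence worth keeping in mind: because the model above applies ReLU after its last layer, its embeddings are non-negative, so these scores cannot be negative and the sigmoid output cannot drop below 0.5. This does not affect a ranking metric such as AUROC, but it does put a floor under the loss on negative pairs.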

### Model Initialization

In [None]:
datamodule = GDADataModule(data_dir='/content/data/', batch_size=32, num_workers=2)
datamodule.setup()

Processing...
Done!


In [None]:
model = GCNLinkPredictor(
    input_dim=datamodule.train_data.x.shape[1],
    hidden_dim=64,
    lr=1e-2
)

### Train and Test

In [None]:
checkpoint_callback = pl.callbacks.ModelCheckpoint()
# early_stopping_callback = pl.callbacks.EarlyStopping(monitor="val_auroc", patience=4, mode="max", verbose=False)

trainer = pl.Trainer(
    max_epochs=20,
    log_every_n_steps=1,
    accelerator="gpu",
    devices=1,
    # logger=wandb_logger,
    callbacks=[checkpoint_callback]
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, datamodule)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type        | Params | Mode 
----------------------------------------------------
0 | conv1       | GCNConv     | 2.0 K  | train
1 | conv2       | GCNConv     | 4.2 K  | train
2 | train_auroc | BinaryAUROC | 0      | train
3 | val_auroc   | BinaryAUROC | 0      | train
4 | test_auroc  | BinaryAUROC | 0      | train
----------------------------------------------------
6.1 K     Trainable params
0         Non-trainable params
6.1 K     Total params
0.025     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


In [None]:
trainer.test(model, datamodule)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.473310649394989, 'test_auroc': 0.9837546348571777}]

# Final submission
Due in the last teaching week.

## Incremental model development

## Advanced evaluation

## ML as a service (prototype)