In [1]:
import requests
import json
import time
import math
from typing import Dict, List, Any
import os

In [2]:
# DisgenetClient class is responsible for creating a postprocessed csv file with the data obtained from the DisGeNET API.
class DisgenetClient:
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.api_calls = 0
        self.total_results = 0
        self.disease_count = 0
        self.gene_count = 0
        self.disease_file_path = "data/disease_ids.txt"
        self.data_file_path = "data/dga_data.csv"
        self.disgenet_base_url = "https://api.disgenet.com"
        self.disgenet_dga_route = "/api/v1/gda/summary"

    def create_csv_file(self) -> None:
        print("[INFO] Starting data creation process. This may take a few minutes...")

        start_time = time.time()
        if os.path.exists(self.data_file_path):
            os.remove(self.data_file_path)
            print(f"{self.data_file_path} has been deleted.")
            
        self.write_to_csv_file(["disease_id,disease_name,gene_id,score,first_pub_year,last_pub_year,ei,dsi,dpi"])

        disease_categories: List[str] = self.get_disease_categories()
        for disease_category in disease_categories:
            results: List[str] = []
            params: Dict = {"disease": self.get_disease_param(disease_category)}

            response = self.send_request(params, self.disgenet_dga_route)
            self.process_response(response, results)
            pages_left = math.ceil(response["paging"]["totalElements"] / 100) - 1

            for page in range(1, pages_left + 1):
                params["page_number"] = page
                response = self.send_request(params, self.disgenet_dga_route)
                self.process_response(response, results)

        self.postprocess_csv_file()
        end_time = time.time()
        self.csv_log(start_time, end_time)
        return

    def send_request(self, params: Dict, route: str) -> Any:
        self.api_calls += 1
        headers: Dict = {"Authorization": self.api_key, "accept": "application/json"}
        response = requests.get(
            url=self.disgenet_base_url + route, params=params, headers=headers
        )
        response = self.handle_api_rate_limit(response, params, route)
        return json.loads(response.text)

    def process_response(self, response: Dict, results: List[str]) -> None:
        for association in response["payload"]:
            disease_id: str = ""
            for id in association["diseaseVocabularies"]:
                if id.startswith("ICD10_") and len(id) == 9:
                    disease_id = id
                    break

            ei = round(association["ei"], 3) if association["ei"] is not None else 0
            dsi = (
                association["geneDSI"] if association["geneDSI"] is not None else 0.230
            )
            dpi = association["geneDPI"] if association["geneDPI"] is not None else 0

            results.append(
                f"{disease_id},{association['diseaseName']},{association['symbolOfGene']},{association['score']},{association['yearInitial']},{association['yearFinal']},{ei},{dsi},{dpi}"
            )
        self.write_to_csv_file(results)
        results.clear()
        return

    # Handle the API rate limit by waiting for the specified time and then retrying the request
    def handle_api_rate_limit(self, response: Dict, params: Dict, route: str) -> Any:
        if response.status_code == 429:
            while response.status_code == 429:
                wait_time = (
                    int(response.headers["x-rate-limit-retry-after-seconds"]) + 1
                )
                print(f"[INFO] Waiting {wait_time} seconds to restore rate limit")
                time.sleep(wait_time)

                self.api_calls += 1
                # The API key is for auth, can be obtained from disgenet personal site
                headers: Dict = {
                    "Authorization": self.api_key,
                    "accept": "application/json",
                }
                response = requests.get(
                    url=self.disgenet_base_url + route, params=params, headers=headers
                )

                if response.ok:
                    break
                else:
                    continue
        return response

    def write_to_csv_file(self, data: List[str]) -> None:
        with open(self.data_file_path, "a") as file:
            for line in data:
                file.write(line + "\n")
        return

    def postprocess_csv_file(self) -> None:
        with open(self.data_file_path, "r") as file:
            lines = file.readlines()

        header: str = lines[0]
        rows: List[str] = lines[1:]

        sorted_rows: List[str] = sorted(
            rows, key=lambda line: line.strip().split(",")[0]
        )
        self.total_results = len(sorted_rows)

        self.disease_count = len({line.strip().split(",")[0] for line in sorted_rows})
        self.gene_count = len({line.strip().split(",")[1] for line in sorted_rows})

        with open(self.data_file_path, "w") as file:
            file.write(header)
            file.writelines(sorted_rows)
        return

    # Get disease parameter from ICD code (https://icd.who.int/browse10/2019/en)
    def get_disease_param(self, disease_ids: str) -> str:
        category: str = disease_ids[0]
        disease_range: str = disease_ids[1:]
        start: int = int(disease_range.split("-")[0])
        end: int = int(disease_range.split("-")[1])

        disease_param: str = ""
        for i in range(start, end + 1):
            index: str = str(i) if i >= 10 else "0" + str(i)
            disease_param += "ICD10_" + category + index
            if i != end:
                disease_param += ","
        return disease_param

    def get_disease_categories(self) -> List[str]:
        with open(self.disease_file_path, "r") as file:
            categories = file.readlines()
        return categories

    def csv_log(self, start_time, end_time) -> None:
        elapsed_time = end_time - start_time
        minutes = int(elapsed_time // 60)
        seconds = int(elapsed_time % 60)

        print(
            f"[INFO] Data file created with {self.disease_count} diseases, {self.gene_count} genes and {self.total_results} associations."
        )
        if minutes > 0:
            print(
                f"[INFO] Completed in {minutes} minutes and {seconds} seconds using {self.api_calls} api calls."
            )
        else:
            print(
                f"[INFO] Completed in {seconds} seconds using {self.api_calls} api calls."
            )

In [3]:
disgenet_api_key = "567b3cda-b397-49cb-8763-9c9fae28ac47" 
disgenet_client = DisgenetClient(disgenet_api_key)
disgenet_client.create_csv_file()

[INFO] Starting data creation process. This may take a few minutes...
data/dga_data.csv has been deleted.
[INFO] Data file created with 1 diseases, 1 genes and 736 associations.
[INFO] Completed in 7 seconds using 8 api calls.
