In [None]:
# Mount Google Drive at /content/drive so later cells can read and write the project files stored there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Installation commands
%%capture
!pip install feature-engine
!pip install umap-learn

# Standard library imports
import os
from typing import List, Union

# Third-party imports for data manipulation, machine learning, and plotting
import numpy as np
import pandas as pd
from tqdm import tqdm
import networkx as nx
import matplotlib.pyplot as plt

# PyTorch imports
import torch
from torch import nn

# Feature engineering tools
from feature_engine.imputation import CategoricalImputer
from feature_engine.encoding import OneHotEncoder

# NLP transformers
from transformers import AutoTokenizer, AutoModel

# Dimensionality reduction and preprocessing
from sklearn.preprocessing import RobustScaler
from sklearn.decomposition import PCA
from umap import UMAP

In [None]:
# Pick the compute device in priority order: CUDA, then the MPS backend
# (only when this torch build exposes it), otherwise the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else 'cpu'
)

In [None]:
# Load the pretrained "emilyalsentzer/Bio_ClinicalBERT" tokenizer and encoder,
# then move the encoder to the device selected above.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

### Data loading and preprocessing

In [None]:
# All three processed CSV inputs live in the same Drive folder.
_processed_data_dir = '/content/drive/MyDrive/Projects/GNN-Gene-Disease/Data/processed_data'

# Disease descriptions, disease-gene association pairs, and gene annotations.
diseases_table_df = pd.read_csv(f'{_processed_data_dir}/df_final_processed_without_Gemini.csv')
genes_diseases_df = pd.read_csv(f'{_processed_data_dir}/genes_diseases.csv')
genes_table_df = pd.read_csv(f'{_processed_data_dir}/genes_table_HUGO_final.csv')

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

In [None]:
# Spot-check: list every gene association recorded for disease C0014859.
genes_diseases_df.loc[genes_diseases_df['# Disease ID'] == 'C0014859']

Unnamed: 0,# Disease ID,Disease Name,Gene ID
8114,C0014859,Esophageal Neoplasms,1012
8115,C0014859,Esophageal Neoplasms,1029
8116,C0014859,Esophageal Neoplasms,11178
8117,C0014859,Esophageal Neoplasms,11197
8118,C0014859,Esophageal Neoplasms,11214
...,...,...,...
8179,C0014859,Esophageal Neoplasms,841
8180,C0014859,Esophageal Neoplasms,864
8181,C0014859,Esophageal Neoplasms,8797
8182,C0014859,Esophageal Neoplasms,8856


#### Genes Preprocessing

In [None]:
# Names of the COLUMNS (despite the variable name) removed from the genes table in the next cell.
rows_to_drop = ["End Chromossome Arm","End Chromossome Loc","End Chromossome SubLoc"]

In [None]:
# Remove the end-position columns listed in `rows_to_drop` from the genes table.
genes_table_df = genes_table_df.drop(columns=rows_to_drop)

In [None]:
# Inspect column dtypes and non-null counts of the genes table before imputation.
genes_table_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3523 entries, 0 to 3522
Data columns (total 12 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   entrez_id                 3522 non-null   float64
 1   name                      3523 non-null   object 
 2   locus_group               3523 non-null   object 
 3   locus_type                3523 non-null   object 
 4   location                  3523 non-null   object 
 5   gene_family               2335 non-null   object 
 6   gene_family_id            2335 non-null   object 
 7   id                        3523 non-null   int64  
 8   Start Chromossome         3509 non-null   float64
 9   Start Chromossome Arm     3338 non-null   object 
 10  Start Chromossome Loc     3495 non-null   float64
 11  Start Chromossome SubLoc  2445 non-null   float64
dtypes: float64(4), int64(1), object(7)
memory usage: 330.4+ KB


In [None]:
# Cast every column to object dtype.
# NOTE(review): presumably so the CategoricalImputer below also fills the numeric columns — confirm.
genes_table_df = genes_table_df.astype(object)

#### Categorical Imputer

In [None]:
# Fit the imputer on the genes table (method 'missing'), then apply it to the same table.
ci = CategoricalImputer(imputation_method='missing').fit(genes_table_df)
genes_table_df = ci.transform(genes_table_df)

  X = X.assign(**add_cats).fillna(self.imputer_dict_)


In [None]:
# Display the genes table after imputation.
genes_table_df

Unnamed: 0,entrez_id,name,locus_group,locus_type,location,gene_family,gene_family_id,id,Start Chromossome,Start Chromossome Arm,Start Chromossome Loc,Start Chromossome SubLoc
0,14.0,angio associated migratory cell protein,protein-coding gene,gene with protein product,2q,WD repeat domain containing,362,18,2.0,q,Missing,Missing
1,15.0,aralkylamine N-acetyltransferase,protein-coding gene,gene with protein product,17q25.1,GCN5 related N-acetyltransferases,1134,19,17.0,q,25.0,1.0
2,9625.0,apoptosis-associated tyrosine kinase,protein-coding gene,gene with protein product,17q25.3,Receptor Tyrosine Kinases|Protein phosphatase ...,321|694,21,17.0,q,25.0,3.0
3,18.0,4-aminobutyrate aminotransferase,protein-coding gene,gene with protein product,16p13.2,Missing,Missing,23,16.0,p,13.0,2.0
4,19.0,ATP binding cassette subfamily A member 1,protein-coding gene,gene with protein product,9q31,ATP binding cassette subfamily A,805,29,9.0,q,31.0,Missing
...,...,...,...,...,...,...,...,...,...,...,...,...
3518,64763.0,zinc finger protein 574,protein-coding gene,gene with protein product,19q13.2,Zinc fingers C2H2-type,28,26166,19.0,q,13.0,2.0
3519,493821.0,"zinc finger protein 603, pseudogene",pseudogene,pseudogene,6p22.1,Missing,Missing,23322,6.0,p,22.0,1.0
3520,126208.0,zinc finger protein 787,protein-coding gene,gene with protein product,19q13.43,Zinc fingers C2H2-type,28,26998,19.0,q,13.0,43.0
3521,284391.0,zinc finger protein 844,protein-coding gene,gene with protein product,19p13.2,Zinc fingers C2H2-type,28,25932,19.0,p,13.0,2.0


In [None]:
# Collect every distinct gene-family ID. A gene can belong to several families, stored as a
# "|"-separated string; the value "Missing" marks genes that have no family ID.
families_set = set()
for families in genes_table_df["gene_family_id"]:
    if families != "Missing":
        families_set.update(families.split("|"))

# Map each family ID to a column index. The IDs are sorted first because the iteration order
# of a set of strings is not stable across Python sessions, which would otherwise shuffle the
# feature columns from one run of the notebook to the next.
order = {key: i for i, key in enumerate(sorted(families_set))}

# Multi-hot encode family membership: one vector per gene, with a 1 in the column of each
# family it belongs to (all zeros when the family is "Missing").
family_emb = []
for families in genes_table_df["gene_family_id"]:
    emb = np.zeros(len(order))
    if families != "Missing":
        for family in families.split("|"):
            emb[order[family]] = 1
    family_emb.append(emb)

In [None]:
# Re-check column dtypes and non-null counts of the genes table after imputation.
genes_table_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3523 entries, 0 to 3522
Data columns (total 12 columns):
 #   Column                    Non-Null Count  Dtype 
---  ------                    --------------  ----- 
 0   entrez_id                 3523 non-null   object
 1   name                      3523 non-null   object
 2   locus_group               3523 non-null   object
 3   locus_type                3523 non-null   object
 4   location                  3523 non-null   object
 5   gene_family               3523 non-null   object
 6   gene_family_id            3523 non-null   object
 7   id                        3523 non-null   int64 
 8   Start Chromossome         3523 non-null   object
 9   Start Chromossome Arm     3523 non-null   object
 10  Start Chromossome Loc     3523 non-null   object
 11  Start Chromossome SubLoc  3523 non-null   object
dtypes: int64(1), object(11)
memory usage: 330.4+ KB


In [None]:
# Start-position columns that get expanded into one indicator column per distinct value.
columuns_to_encode = [
    "Start Chromossome",
    "Start Chromossome Arm",
    "Start Chromossome Loc",
    "Start Chromossome SubLoc",
]

# Fit the encoder on the genes table, then replace the table with its encoded version.
ohe = OneHotEncoder(variables=columuns_to_encode).fit(genes_table_df)
genes_table_df = ohe.transform(genes_table_df)

In [None]:
# Display the genes table after one-hot encoding.
genes_table_df

Unnamed: 0,entrez_id,name,locus_group,locus_type,location,gene_family,gene_family_id,id,Start Chromossome_2.0,Start Chromossome_17.0,...,Start Chromossome SubLoc_22.0,Start Chromossome SubLoc_23.0,Start Chromossome SubLoc_4.0,Start Chromossome SubLoc_21.0,Start Chromossome SubLoc_32.0,Start Chromossome SubLoc_5.0,Start Chromossome SubLoc_41.0,Start Chromossome SubLoc_42.0,Start Chromossome SubLoc_43.0,Start Chromossome SubLoc_221.0
0,14.0,angio associated migratory cell protein,protein-coding gene,gene with protein product,2q,WD repeat domain containing,362,18,1,0,...,0,0,0,0,0,0,0,0,0,0
1,15.0,aralkylamine N-acetyltransferase,protein-coding gene,gene with protein product,17q25.1,GCN5 related N-acetyltransferases,1134,19,0,1,...,0,0,0,0,0,0,0,0,0,0
2,9625.0,apoptosis-associated tyrosine kinase,protein-coding gene,gene with protein product,17q25.3,Receptor Tyrosine Kinases|Protein phosphatase ...,321|694,21,0,1,...,0,0,0,0,0,0,0,0,0,0
3,18.0,4-aminobutyrate aminotransferase,protein-coding gene,gene with protein product,16p13.2,Missing,Missing,23,0,0,...,0,0,0,0,0,0,0,0,0,0
4,19.0,ATP binding cassette subfamily A member 1,protein-coding gene,gene with protein product,9q31,ATP binding cassette subfamily A,805,29,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3518,64763.0,zinc finger protein 574,protein-coding gene,gene with protein product,19q13.2,Zinc fingers C2H2-type,28,26166,0,0,...,0,0,0,0,0,0,0,0,0,0
3519,493821.0,"zinc finger protein 603, pseudogene",pseudogene,pseudogene,6p22.1,Missing,Missing,23322,0,0,...,0,0,0,0,0,0,0,0,0,0
3520,126208.0,zinc finger protein 787,protein-coding gene,gene with protein product,19q13.43,Zinc fingers C2H2-type,28,26998,0,0,...,0,0,0,0,0,0,0,0,1,0
3521,284391.0,zinc finger protein 844,protein-coding gene,gene with protein product,19p13.2,Zinc fingers C2H2-type,28,25932,0,0,...,0,0,0,0,0,0,0,0,0,0


### Encoding

In [None]:
def process_in_batches(
    column_data: List[str],
    model: nn.Module,
    tokenizer: "AutoTokenizer",
    max_length: int,
    batch_size: int,
    device: Union[torch.device, str],
    mask_padding: bool = True,
) -> torch.Tensor:
    """
    Encode a list of strings batch by batch and return one pooled vector per string.

    Each batch is tokenized (padded and truncated to ``max_length``), moved to ``device``,
    run through ``model`` without gradient tracking, and its ``last_hidden_state`` is
    averaged over the sequence dimension.

    **Parameters:**

    - column_data (List[str]): The text strings to process.
    - model (nn.Module): Callable invoked as ``model(**tokenized_inputs)``; its result must
      expose ``"last_hidden_state"`` by key.
    - tokenizer (AutoTokenizer): Tokenizer called with the batch; its result must support
      ``.to(device)`` and ``**`` unpacking.
    - max_length (int): Length every sequence is padded / truncated to.
    - batch_size (int): Number of strings per batch.
    - device (Union[torch.device, str]): Device the tokenized inputs are moved to.
    - mask_padding (bool, optional): When True (default) and the tokenizer returns an
      ``attention_mask``, only unmasked positions contribute to the mean, so padding tokens
      do not dilute the embedding. Pass False for a plain mean over all ``max_length``
      positions (the previous behaviour).

    **Returns:**

    - torch.Tensor: Tensor of shape [num_examples, feature_size]. For empty ``column_data``
      an empty tensor of shape [0, hidden_size] is returned, where hidden_size is read from
      ``model.config.hidden_size`` when available and is 0 otherwise.
    """
    # torch.cat raises on an empty list, so answer the empty case up front.
    if len(column_data) == 0:
        hidden_size = getattr(getattr(model, "config", None), "hidden_size", 0)
        return torch.empty((0, hidden_size), device=device)

    batched_output = []
    for start in range(0, len(column_data), batch_size):
        batch = column_data[start:start + batch_size]
        tokenized_inputs = tokenizer(
            batch,
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        ).to(device)
        with torch.no_grad():
            model_output = model(**tokenized_inputs)
        hidden_states = model_output["last_hidden_state"]

        if mask_padding and "attention_mask" in tokenized_inputs:
            # Weight each position by its mask value so padded positions add nothing.
            weights = tokenized_inputs["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
            # clamp guards against division by zero for a fully masked row.
            token_counts = weights.sum(dim=1).clamp(min=1)
            pooled = (hidden_states * weights).sum(dim=1) / token_counts
        else:
            pooled = hidden_states.mean(dim=1)
        batched_output.append(pooled)

    return torch.cat(batched_output, dim=0)

In [None]:
def process_df_in_batches(
    df: pd.DataFrame,
    model: nn.Module,
    tokenizer: "AutoTokenizer",
    device: Union[torch.device, str],
    batch_size: int = 512,
) -> torch.Tensor:
    """
    Encode every column of a DataFrame as text and stack the per-column encodings.

    For each column:
    1. Missing values (``pd.isna``) are replaced with the string "empty element".
    2. All other values are converted with ``str``.
    3. The strings are encoded by ``process_in_batches``.

    The ``max_length`` used for a column is the mean CHARACTER count of its strings
    (truncated to an int), clamped to the range [1, 512]. Because it is a mean, strings
    longer than that are truncated by the tokenizer.

    **Parameters:**

    - df (pd.DataFrame): The DataFrame to encode; each column is processed on its own.
    - model (nn.Module): Model forwarded to ``process_in_batches``.
    - tokenizer (AutoTokenizer): Tokenizer forwarded to ``process_in_batches``.
    - device (Union[torch.device, str]): Device forwarded to ``process_in_batches``.
    - batch_size (int, optional): Number of strings per batch. Defaults to 512.

    **Returns:**

    - torch.Tensor: Tensor of shape [num_rows, num_columns, feature_size]. When the
      DataFrame has no rows or no columns, an empty tensor of that shape is returned
      without calling the model; feature_size is then ``model.config.hidden_size`` when
      available and 0 otherwise.

    **Note:**

    Relies on ``process_in_batches`` and ``tqdm`` being defined in the notebook.
    """
    num_rows, num_columns = df.shape

    # Nothing to encode: return an empty tensor instead of failing on an empty stack / cat.
    if num_rows == 0 or num_columns == 0:
        feature_size = getattr(getattr(model, "config", None), "hidden_size", 0)
        return torch.empty((num_rows, num_columns, feature_size), device=device)

    feature_matrix_list = []
    for col in tqdm(df.columns, desc="Processing DataFrame columns"):
        cleaned_column_data = [
            "empty element" if pd.isna(item) else str(item)
            for item in df[col].tolist()
        ]

        # Mean string length in characters, clamped to the tokenizer-friendly range [1, 512].
        mean_length = int(sum(len(text) for text in cleaned_column_data) / len(cleaned_column_data))
        effective_max_length = min(max(mean_length, 1), 512)

        encoded_column = process_in_batches(
            cleaned_column_data,
            model,
            tokenizer,
            effective_max_length,
            batch_size,
            device,
        )
        feature_matrix_list.append(encoded_column)

    # Every column yields a [num_rows, feature_size] tensor, so stacking on dim=1 gives
    # [num_rows, num_columns, feature_size].
    return torch.stack(feature_matrix_list, dim=1)

In [None]:
# Display the diseases table before encoding.
diseases_table_df

Unnamed: 0,Disease ID,Disease Name,Disease Class,Definitions,# Disease(DOID),main_system_affected
0,C0036095,Salivary Gland Neoplasms,,Tumors or cancer of the SALIVARY GLANDS.,,
1,C0033941,"Psychoses, Substance-Induced",,Psychotic organic mental disorders resulting f...,,
2,C0043459,Zellweger Syndrome,inherited metabolic disorder,An autosomal recessive disorder due to defects...,DOID:905,
3,C0033860,Psoriasis,integumentary system disease,"A common genetically determined, chronic, infl...",DOID:8893,
4,C0027726,Nephrotic Syndrome,urinary system disease,A condition characterized by severe PROTEINURI...,DOID:1184,
...,...,...,...,...,...,...
514,C0005684,Malignant Neoplasm Of Urinary Bladder,cancer,,,
515,C0752347,Lewy Body Disease,nervous system disease,A neurodegenerative disease characterized by d...,,
516,C3160718,"Parkinson Disease, Late-Onset",,,,
517,C0311375,Arsenic Poisoning,,Disorders associated with acute or chronic exp...,,


In [None]:
# Encode the disease text columns; the two identifier columns are left out of the encoding.
encoded_diseases_tensors = process_df_in_batches(
    df=diseases_table_df.drop(columns=["Disease ID", "# Disease(DOID)"]),
    model=model,
    tokenizer=tokenizer,
    device=device,
    batch_size=768,
)

Processing DataFrame columns: 100%|██████████| 4/4 [00:06<00:00,  1.69s/it]


In [None]:
# Move the encodings to host memory and flatten the per-column vectors into one row per disease.
encoded_diseases_tensors_array = (
    encoded_diseases_tensors.cpu().numpy().reshape(len(diseases_table_df), -1)
)

In [None]:
# Text columns of the genes table that are encoded with the language model.
columns_to_keep = ["name", "locus_group", "locus_type", "location"]

# Encode the selected columns and bring the result back to host memory.
encoded_genes_tensors = process_df_in_batches(
    df=genes_table_df[columns_to_keep],
    model=model,
    tokenizer=tokenizer,
    device=device,
    batch_size=768,
).cpu()

Processing DataFrame columns: 100%|██████████| 4/4 [00:13<00:00,  3.27s/it]


In [None]:
# Keep only the one-hot encoded columns by dropping the identifier and free-text columns.
selected = genes_table_df.drop(
    columns=[
        "entrez_id",
        "name",
        "locus_group",
        "locus_type",
        "location",
        "gene_family",
        "gene_family_id",
        "id",
    ]
).to_numpy()

# Gene-family vectors (a list of arrays) as a single 2-D array.
family_emb_array = np.asarray(family_emb)

# Flatten the per-column text encodings into one row per gene.
encoded_genes_tensors_array = encoded_genes_tensors.numpy().reshape(len(genes_table_df), -1)

# Final gene feature matrix: one-hot columns | family membership | text encodings.
encoded_genes_tensors_matrix = np.hstack(
    (selected, family_emb_array, encoded_genes_tensors_array)
)

# Convert to a Torch tensor (copies the data).
encoded_genes = torch.tensor(encoded_genes_tensors_matrix)

In [None]:
def reduce_dimensions_with_umap(
    data: np.ndarray,
    n_components: int = 512,
    random_state: Union[int, None] = None,
) -> np.ndarray:
    """
    Reduce the dimensions of a given matrix using UMAP.

    Args:
        data: A numpy ndarray of shape (n_samples, n_features) where n_samples is the number of samples
              and n_features is the number of features in the dataset.
        n_components: The number of dimensions to reduce the data to. Default is 512.
        random_state: Optional seed forwarded to UMAP. Leave as None (default) for the
              library's unseeded behaviour; pass an int to make the embedding repeatable.

    Returns:
        A numpy ndarray of shape (n_samples, n_components) after dimensionality reduction.
    """
    # A fresh reducer is fitted on every call; nothing is cached between calls.
    umap_reducer = UMAP(n_components=n_components, random_state=random_state)

    # Fit the model to the data and transform it in one step
    return umap_reducer.fit_transform(data)

In [None]:
def normalize_with_robust_scaler(data: np.ndarray) -> np.ndarray:
    """
    Scale every feature column of ``data`` with a freshly fitted ``RobustScaler``.

    Args:
        data: Array of shape (n_samples, n_features). The scaler is created with its
              default settings and fitted on this same array.

    Returns:
        The scaled array, of shape (n_samples, n_features).
    """
    # Fit and apply in a single step; the fitted scaler is not kept.
    return RobustScaler().fit_transform(data)

In [None]:
def pca_reduction_and_explained_variance(
    data: np.ndarray,
    n_components: int = 512,
    random_state: Union[int, None] = None,
) -> tuple:
    """
    Reduce the dimensions of a given matrix using PCA and calculate the explained variance.

    Args:
        data: A numpy ndarray of shape (n_samples, n_features) where n_samples is the number of samples
              and n_features is the number of features in the dataset.
        n_components: The number of principal components to reduce the data to. Default is 512.
        random_state: Optional seed forwarded to PCA. It only matters when the solver PCA
              selects is a randomized one; leave as None (default) for unseeded behaviour.

    Returns:
        A tuple containing:
        - reduced_data: A numpy ndarray of shape (n_samples, n_components) after PCA reduction.
        - explained_variance_ratio: The total variance explained by the kept components,
          expressed as a percentage (0-100), not as a per-component ratio.
    """
    pca = PCA(n_components=n_components, random_state=random_state)

    # Fit PCA on the data and project it onto the kept components
    reduced_data = pca.fit_transform(data)

    # Sum the per-component ratios and convert the fraction to a percentage
    explained_variance_ratio = np.sum(pca.explained_variance_ratio_) * 100

    return reduced_data, explained_variance_ratio

In [None]:
# Number of components requested from the PCA explained-variance checks below.
n_components = 512

In [None]:
# Report the feature-matrix shapes before any dimensionality reduction is applied.
print("Shape of encoded_diseases_tensors_array before dimension reduction:", encoded_diseases_tensors_array.shape)
print("Shape of encoded_genes before dimension reduction:", encoded_genes.shape)

Shape of encoded_diseases_tensors_array before dimension reduction: (519, 3072)
Shape of encoded_genes before dimension reduction: torch.Size([3523, 3703])


In [None]:
# Diagnostic only: how much variance of the disease features do `n_components` PCA components keep?
# (`reduced_data` is overwritten by the gene check in the next cell.)
reduced_data, explained_variance_ratio = pca_reduction_and_explained_variance(
    data=encoded_diseases_tensors_array,
    n_components=n_components,
)
print(f"Shape of reduced data: {reduced_data.shape}")
print(f"Explained variance ratio by the first {n_components} components: {explained_variance_ratio:.2f}%")

Shape of reduced data: (519, 512)
Explained variance ratio by the first 512 components: 100.00%


In [None]:
# Diagnostic only: how much variance of the gene features do `n_components` PCA components keep?
reduced_data, explained_variance_ratio = pca_reduction_and_explained_variance(
    data=encoded_genes,
    n_components=n_components,
)
print(f"Shape of reduced data: {reduced_data.shape}")
print(f"Explained variance ratio by the first {n_components} components: {explained_variance_ratio:.2f}%")

Shape of reduced data: (3523, 512)
Explained variance ratio by the first 512 components: 98.22%


In [None]:
# Gene features: UMAP reduction -> robust scaling -> float32 tensor.
encoded_genes_reduced = reduce_dimensions_with_umap(data=encoded_genes)
print(encoded_genes_reduced.shape)

encoded_genes_reduced_normalized = normalize_with_robust_scaler(
    data=encoded_genes_reduced
)

encoded_genes_tensors_reduced_normalized = torch.from_numpy(
    encoded_genes_reduced_normalized.astype(np.float32)
)

encoded_genes_tensors_reduced_normalized.shape



(3523, 512)


torch.Size([3523, 512])

In [None]:
# Persist the reduced, scaled gene embeddings to Drive.
torch.save(encoded_genes_tensors_reduced_normalized, "/content/drive/MyDrive/Projects/GNN-Gene-Disease/Data/embeddings/encoded_genes_tensors_reduced_normalized.pt")

In [None]:
# Reload the file written in the previous cell and show its shape as a round-trip check.
# NOTE(review): this rebinds `encoded_genes_tensors`, which earlier held the un-reduced encoder
# output; re-running the feature-matrix cell after this one would use the reduced tensor instead.
encoded_genes_tensors = torch.load(
    "/content/drive/MyDrive/Projects/GNN-Gene-Disease/Data/embeddings/encoded_genes_tensors_reduced_normalized.pt"
)
encoded_genes_tensors.shape

torch.Size([3523, 512])

In [None]:
# Disease features: UMAP reduction -> robust scaling -> float32 tensor -> saved to Drive.
encoded_diseases_reduced = reduce_dimensions_with_umap(data=encoded_diseases_tensors_array)
print(encoded_diseases_reduced.shape)

encoded_diseases_reduced_normalized = normalize_with_robust_scaler(
    data=encoded_diseases_reduced
)

encoded_diseases_tensors_reduced_normalized = torch.from_numpy(
    encoded_diseases_reduced_normalized.astype(np.float32)
)

torch.save(
    encoded_diseases_tensors_reduced_normalized,
    "/content/drive/MyDrive/Projects/GNN-Gene-Disease/Data/embeddings/encoded_diseases_tensors_reduced_normalized_without_gemini.pt",
)



(519, 512)


### Creating link connection

In [None]:
def create_link_array(
    genes_diseases_df,
    disease_features,
    gene_features,
    disease_id_column="Disease ID",
    gene_id_column="id",
    link_disease_column="# Disease ID",
    link_gene_column="Gene ID",
):
    """
    Creates a disease -> gene link tensor in PyTorch Geometric's edge index format.

    Each association row is translated into a pair of positional indices: the position of
    its disease ID in ``disease_features`` and the position of its gene ID in
    ``gene_features``. Rows whose disease or gene ID is not found in the corresponding
    feature table are skipped. If an ID occurs more than once in a feature table, its last
    position is used.

    Args:
        genes_diseases_df: DataFrame of associations, one (disease, gene) pair per row.
        disease_features: DataFrame whose row order defines the disease node indices.
        gene_features: DataFrame whose row order defines the gene node indices.
        disease_id_column: Column of ``disease_features`` holding the disease IDs.
        gene_id_column: Column of ``gene_features`` holding the gene IDs.
        link_disease_column: Column of ``genes_diseases_df`` holding the disease IDs.
        link_gene_column: Column of ``genes_diseases_df`` holding the gene IDs.

    Returns:
        A contiguous ``torch.long`` tensor of shape [2, num_edges]; row 0 holds disease
        indices and row 1 holds gene indices. Shape is [2, 0] when no row matches.

    NOTE(review): matching is by plain equality, so ``gene_id_column`` must use the same
    identifier scheme as ``link_gene_column``. The gene table in this notebook has both an
    ``id`` and an ``entrez_id`` column — confirm that the default ``id`` is the right one.
    """
    # Positional lookups: ID -> row position in the corresponding feature table.
    disease_id_to_idx = {d_id: idx for idx, d_id in enumerate(disease_features[disease_id_column])}
    gene_id_to_idx = {g_id: idx for idx, g_id in enumerate(gene_features[gene_id_column])}

    # Walk the two ID columns in lockstep (avoids building a Series per row as iterrows does).
    edge_list = [
        [disease_id_to_idx[disease_id], gene_id_to_idx[gene_id]]
        for disease_id, gene_id in zip(
            genes_diseases_df[link_disease_column], genes_diseases_df[link_gene_column]
        )
        if disease_id in disease_id_to_idx and gene_id in gene_id_to_idx
    ]

    # An empty list would become a 1-D tensor of shape (0,); keep the [2, E] layout instead.
    if not edge_list:
        return torch.empty((2, 0), dtype=torch.long)

    # [E, 2] -> [2, E], made contiguous after the transpose.
    return torch.tensor(edge_list, dtype=torch.long).t().contiguous()

In [None]:
# Build the disease -> gene edge index; association rows whose IDs are missing from either table are skipped.
connection_disease_gene = create_link_array(genes_diseases_df, diseases_table_df, genes_table_df)

In [None]:
# Persist the edge index to Drive.
torch.save(connection_disease_gene, "/content/drive/MyDrive/Projects/GNN-Gene-Disease/Data/edges/connection_disease_gene_augmented.pt")