<a href="https://colab.research.google.com/github/SVJLucas/GraphMining/blob/main/GNN/Generating_Embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the project files under /content/drive/MyDrive are readable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Installation commands
%%capture
!pip install feature-engine
!pip install umap-learn

# Standard library imports
import os
from typing import List, Union

# Third-party imports for data manipulation, machine learning, and plotting
import numpy as np
import pandas as pd
from tqdm import tqdm
import networkx as nx
import matplotlib.pyplot as plt

# PyTorch imports
import torch
from torch import nn

# Feature engineering tools
from feature_engine.imputation import CategoricalImputer
from feature_engine.encoding import OneHotEncoder

# NLP transformers
from transformers import AutoTokenizer, AutoModel

# Dimensionality reduction and preprocessing
from sklearn.preprocessing import RobustScaler
from sklearn.decomposition import PCA
from umap import UMAP

In [None]:
# Prefer a CUDA GPU, then Apple's MPS backend (when this torch build has it), else the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else 'cpu'
)

In [None]:
# Hugging Face checkpoint used for both the tokenizer and the encoder below.
BERT_CHECKPOINT = "emilyalsentzer/Bio_ClinicalBERT"

tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT)
# from_pretrained already returns the model in eval mode; .eval() just makes that explicit
# since the model is only used for inference (no dropout) in this notebook.
model = AutoModel.from_pretrained(BERT_CHECKPOINT).to(device).eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


### Data loading and exploration

In [None]:
import pandas as pd  # NOTE(review): redundant — pandas is already imported as `pd` in the setup cell

In [None]:
# All processed input tables live in the same Drive folder.
PROCESSED_DATA_DIR = '/content/drive/MyDrive/Projetos/GNN-Gene-Disease/Data/processed_data'

diseases_table_df = pd.read_csv(os.path.join(PROCESSED_DATA_DIR, 'diseases_table_final.csv'))  # one row per disease
genes_diseases_df = pd.read_csv(os.path.join(PROCESSED_DATA_DIR, 'genes_diseases.csv'))        # disease–gene associations
genes_table_df = pd.read_csv(os.path.join(PROCESSED_DATA_DIR, 'genes_table_final.csv'))        # one row per gene

In [None]:
# Risk-factor text for the row at position 204 (Esophageal Neoplasms); the next cell shows the full record.
diseases_table_df["Risk Factors"].iloc[204]

"Tobacco and alcohol use, obesity, GERD, Barrett's esophagus"

In [None]:
# Full record of the row at position 204.
diseases_table_df.iloc[204, :]

Unnamed: 0                                                            204
Disease ID                                                       C0014859
Disease Name                                         Esophageal Neoplasms
Disease Class                                                      cancer
Definitions                            Tumors or cancer of the ESOPHAGUS.
Main Symptom                                        Difficulty swallowing
Risk Factors            Tobacco and alcohol use, obesity, GERD, Barret...
Disease Class GPT                                                  Cancer
Main System Affected                                     Digestive system
Name: 204, dtype: object

In [None]:
# Every association row for Disease ID C0014859 (Esophageal Neoplasms).
genes_diseases_df.loc[genes_diseases_df['# Disease ID'].eq('C0014859')]

Unnamed: 0,# Disease ID,Disease Name,Gene ID
8114,C0014859,Esophageal Neoplasms,1012
8115,C0014859,Esophageal Neoplasms,1029
8116,C0014859,Esophageal Neoplasms,11178
8117,C0014859,Esophageal Neoplasms,11197
8118,C0014859,Esophageal Neoplasms,11214
...,...,...,...
8179,C0014859,Esophageal Neoplasms,841
8180,C0014859,Esophageal Neoplasms,864
8181,C0014859,Esophageal Neoplasms,8797
8182,C0014859,Esophageal Neoplasms,8856


In [None]:
# Look up one gene from the association list above.
# NOTE(review): the saved output below shows Gene ID 11197, which this filter cannot
# return — the output is presumably stale from an earlier run; re-run to refresh.
genes_table_df.loc[genes_table_df['Gene ID'].eq(841)]

Unnamed: 0.1,Unnamed: 0,Gene ID,hgnc_id,name,locus_group,locus_type,location,gene_family,gene_family_id,Start Chromossome,Start Chromossome Arm,Start Chromossome Loc,Start Chromossome SubLoc,End Chromossome Arm,End Chromossome Loc,End Chromossome SubLoc
2008,2008,11197,HGNC:18081,WNT inhibitory factor 1,protein-coding gene,gene with protein product,12q14.2,,,12.0,q,14.0,2.0,,,


In [None]:
# Column labels of the association table (DataFrame.keys() returns the columns).
genes_diseases_df.keys()

Index(['# Disease ID', 'Disease Name', 'Gene ID'], dtype='object')

#### Diseases Preprocessing

In [None]:
# Diseases whose "Risk Factors" field is empty and needs to be filled.
diseases_table_df.loc[diseases_table_df['Risk Factors'].isna()]

Unnamed: 0.1,Unnamed: 0,Disease ID,Disease Name,Disease Class,Definitions,Main Symptom,Risk Factors,Disease Class GPT,Main System Affected
60,60,C0027404,Narcolepsy,sleep disorder,A condition characterized by recurrent episode...,"Daytime somnolence, lapses in consciousness, a...",,,
67,67,C0030354,Papilloma,benign neoplasm,A circumscribed benign epithelial tumor projec...,circumscribed benign epithelial tumor,,benign epithelial neoplasm,epithelial tissue
199,199,C0006663,Calcinosis,acquired metabolic disease,Pathologic deposition of calcium salts in tiss...,Pathologic deposition of calcium salts in tissues,,,


In [None]:
# Manual fill-in values for the diseases listed above whose "Risk Factors",
# "Disease Class GPT" and "Main System Affected" fields are all missing.
# Keys are lower-cased disease names; each inner dict lists its fields in the
# column order Risk Factors, Disease Class GPT, Main System Affected.
disease_info = {
    "narcolepsy": {
        "risk_factors": "Unknown (more research needed)",
        "disease_class_gpt": "Neurological",
        "main_system_affected": "Nervous system (brain)"
    },
    "calcinosis": {
        "risk_factors": "Abnormal calcium metabolism, certain medical conditions, medications, genetic predisposition",
        "disease_class_gpt": "Metabolic disorder",
        "main_system_affected": "Varies depending on the type (e.g., musculoskeletal, skin)"
    }
}

In [None]:
# Fill the missing fields of rows 60 (Narcolepsy) and 199 (Calcinosis) from `disease_info`,
# matching on the lower-cased disease name. Each field is written to its column by key,
# not by dict order, so reordering an inner dict cannot put a value in the wrong column.
FIELD_TO_COLUMN = {
    "risk_factors": "Risk Factors",
    "disease_class_gpt": "Disease Class GPT",
    "main_system_affected": "Main System Affected",
}

for row_label in [60, 199]:
    disease_name = diseases_table_df.loc[row_label, "Disease Name"].lower()
    info = disease_info.get(disease_name)
    if info is None:
        continue  # no manual values for this disease
    for field, column in FIELD_TO_COLUMN.items():
        diseases_table_df.loc[row_label, column] = info[field]

In [None]:
# Papilloma (row 67) already has a class and affected system; only its Risk Factors is missing.
# A single .loc assignment writes to the frame itself. Chained indexing (df[col][row] = ...)
# raises SettingWithCopyWarning and silently does nothing under pandas Copy-on-Write.
diseases_table_df.loc[67, "Risk Factors"] = "Human papillomavirus (HPV) infection, weakened immune system, chronic irritation"

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  diseases_table_df["Risk Factors"][67] = "Human papillomavirus (HPV) infection, weakened immune system, chronic irritation"


In [None]:
# Remove the "Unnamed: 0" column (presumably a saved index from an earlier to_csv).
diseases_table_df = diseases_table_df.drop(columns='Unnamed: 0')

#### Genes Preprocessing

In [None]:
# Despite the name, these are COLUMN labels to remove from the genes table:
# the "Unnamed: 0" column and the three "End Chromossome" position fields.
rows_to_drop = [
    "Unnamed: 0",
    "End Chromossome Arm",
    "End Chromossome Loc",
    "End Chromossome SubLoc",
]

In [None]:
genes_table_df = genes_table_df.drop(columns=rows_to_drop)  # column drop, see `rows_to_drop`

In [None]:
# Dtypes and non-null counts before imputation (several columns have missing values).
genes_table_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6948 entries, 0 to 6947
Data columns (total 12 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   Gene ID                   6948 non-null   int64  
 1   hgnc_id                   6948 non-null   object 
 2   name                      6948 non-null   object 
 3   locus_group               6948 non-null   object 
 4   locus_type                6948 non-null   object 
 5   location                  6948 non-null   object 
 6   gene_family               4692 non-null   object 
 7   gene_family_id            4692 non-null   object 
 8   Start Chromossome         6916 non-null   float64
 9   Start Chromossome Arm     6671 non-null   object 
 10  Start Chromossome Loc     6881 non-null   float64
 11  Start Chromossome SubLoc  4787 non-null   float64
dtypes: float64(3), int64(1), object(8)
memory usage: 651.5+ KB


In [None]:
# Cast every column to object dtype — presumably so the CategoricalImputer below treats
# the numeric position columns as categorical too and fills them with "Missing" (confirm).
genes_table_df = genes_table_df.astype("object")

#### Categorical Imputer

In [None]:
# Replace missing values with the literal string "Missing" (visible in the preview below).
ci = CategoricalImputer(imputation_method='missing')
genes_table_df = ci.fit_transform(genes_table_df)

In [None]:
# Preview the imputed genes table.
genes_table_df

Unnamed: 0,Gene ID,hgnc_id,name,locus_group,locus_type,location,gene_family,gene_family_id,Start Chromossome,Start Chromossome Arm,Start Chromossome Loc,Start Chromossome SubLoc
0,1462,HGNC:2464,versican,protein-coding gene,gene with protein product,5q14.2-q14.3,Hyalectan proteoglycans|V-set domain containin...,574|590|1179,5.0,q,14.0,2.0
1,1612,HGNC:2674,death associated protein kinase 1,protein-coding gene,gene with protein product,9q34.1,Ankyrin repeat domain containing|Death-associa...,403|1021,9.0,q,34.0,1.0
2,182,HGNC:6188,jagged 1,protein-coding gene,gene with protein product,20p12.1-p11.23,CD molecules,471,20.0,p,12.0,1.0
3,2011,HGNC:3332,microtubule affinity regulating kinase 2,protein-coding gene,gene with protein product,11q13.1,Missing,Missing,11.0,q,13.0,1.0
4,2019,HGNC:3342,engrailed homeobox 1,protein-coding gene,gene with protein product,2q14.2,NKL subclass homeoboxes and pseudogenes,519,2.0,q,14.0,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...
6943,407037,HGNC:31632,microRNA 320a,non-coding RNA,"RNA, micro",8p21.3,MicroRNAs,476,8.0,p,21.0,3.0
6944,79727,HGNC:15986,lin-28 homolog A,protein-coding gene,gene with protein product,1p35.3,Zinc fingers CCHC-type,74,1.0,p,35.0,3.0
6945,8505,HGNC:8605,poly(ADP-ribose) glycohydrolase,protein-coding gene,gene with protein product,10q11.23,Missing,Missing,10.0,q,11.0,23.0
6946,8668,HGNC:3272,eukaryotic translation initiation factor 3 sub...,protein-coding gene,gene with protein product,1p34.1,WD repeat domain containing|Eukaryotic transla...,362|1121,1.0,p,34.0,1.0


In [None]:
# Multi-hot encode gene families. Each gene's "gene_family_id" is a "|"-separated list of
# family IDs, or "Missing" after imputation. Columns follow the SORTED family IDs, so the
# layout is identical on every run. Enumerating a set of strings directly is not stable:
# Python randomizes string hashing per process, so the column order would change between
# kernel sessions.
families_set = {
    family
    for families in genes_table_df["gene_family_id"]
    if families != "Missing"
    for family in families.split("|")
    if family != "Missing"
}

# Column index for each family ID.
order = {family: i for i, family in enumerate(sorted(families_set))}

# One multi-hot row vector per gene, aligned with the rows of genes_table_df.
family_emb = []
for families in genes_table_df["gene_family_id"]:
    emb = np.zeros(len(families_set))
    if families != "Missing":
        for family in families.split("|"):
            if family in order:
                emb[order[family]] = 1
    family_emb.append(emb)

In [None]:
# Confirm that no nulls remain after imputation.
genes_table_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6948 entries, 0 to 6947
Data columns (total 12 columns):
 #   Column                    Non-Null Count  Dtype 
---  ------                    --------------  ----- 
 0   Gene ID                   6948 non-null   int64 
 1   hgnc_id                   6948 non-null   object
 2   name                      6948 non-null   object
 3   locus_group               6948 non-null   object
 4   locus_type                6948 non-null   object
 5   location                  6948 non-null   object
 6   gene_family               6948 non-null   object
 7   gene_family_id            6948 non-null   object
 8   Start Chromossome         6948 non-null   object
 9   Start Chromossome Arm     6948 non-null   object
 10  Start Chromossome Loc     6948 non-null   object
 11  Start Chromossome SubLoc  6948 non-null   object
dtypes: int64(1), object(11)
memory usage: 651.5+ KB


In [None]:
# One-hot encode the chromosome start-position fields: every distinct value of each
# column becomes its own 0/1 column, and the original column is dropped.
columuns_to_encode = [
    "Start Chromossome",
    "Start Chromossome Arm",
    "Start Chromossome Loc",
    "Start Chromossome SubLoc",
]

ohe = OneHotEncoder(variables=columuns_to_encode)
genes_table_df = ohe.fit_transform(genes_table_df)

In [None]:
# Preview the table with the one-hot position columns appended.
genes_table_df

Unnamed: 0,Gene ID,hgnc_id,name,locus_group,locus_type,location,gene_family,gene_family_id,Start Chromossome_5.0,Start Chromossome_9.0,...,Start Chromossome SubLoc_32.0,Start Chromossome SubLoc_22.0,Start Chromossome SubLoc_31.0,Start Chromossome SubLoc_41.0,Start Chromossome SubLoc_12.0,Start Chromossome SubLoc_43.0,Start Chromossome SubLoc_5.0,Start Chromossome SubLoc_6.0,Start Chromossome SubLoc_42.0,Start Chromossome SubLoc_223.0
0,1462,HGNC:2464,versican,protein-coding gene,gene with protein product,5q14.2-q14.3,Hyalectan proteoglycans|V-set domain containin...,574|590|1179,1,0,...,0,0,0,0,0,0,0,0,0,0
1,1612,HGNC:2674,death associated protein kinase 1,protein-coding gene,gene with protein product,9q34.1,Ankyrin repeat domain containing|Death-associa...,403|1021,0,1,...,0,0,0,0,0,0,0,0,0,0
2,182,HGNC:6188,jagged 1,protein-coding gene,gene with protein product,20p12.1-p11.23,CD molecules,471,0,0,...,0,0,0,0,0,0,0,0,0,0
3,2011,HGNC:3332,microtubule affinity regulating kinase 2,protein-coding gene,gene with protein product,11q13.1,Missing,Missing,0,0,...,0,0,0,0,0,0,0,0,0,0
4,2019,HGNC:3342,engrailed homeobox 1,protein-coding gene,gene with protein product,2q14.2,NKL subclass homeoboxes and pseudogenes,519,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6943,407037,HGNC:31632,microRNA 320a,non-coding RNA,"RNA, micro",8p21.3,MicroRNAs,476,0,0,...,0,0,0,0,0,0,0,0,0,0
6944,79727,HGNC:15986,lin-28 homolog A,protein-coding gene,gene with protein product,1p35.3,Zinc fingers CCHC-type,74,0,0,...,0,0,0,0,0,0,0,0,0,0
6945,8505,HGNC:8605,poly(ADP-ribose) glycohydrolase,protein-coding gene,gene with protein product,10q11.23,Missing,Missing,0,0,...,0,0,0,0,0,0,0,0,0,0
6946,8668,HGNC:3272,eukaryotic translation initiation factor 3 sub...,protein-coding gene,gene with protein product,1p34.1,WD repeat domain containing|Eukaryotic transla...,362|1121,0,0,...,0,0,0,0,0,0,0,0,0,0


### Encoding

In [None]:
def process_in_batches(
    column_data: List[str],
    model: nn.Module,
    tokenizer: AutoTokenizer,
    max_length: int,
    batch_size: int,
    device: Union[torch.device, str]
) -> torch.Tensor:
    """
    Encode a list of strings in batches and mean-pool each sequence into one vector.

    Every batch is padded to exactly `max_length` tokens. The tokenizer's attention mask
    excludes pad positions from the mean, so short strings are not diluted by pad-token
    hidden states.

    **Parameters:**

    - column_data (List[str]): The text strings to encode.
    - model (nn.Module): Model whose output exposes a "last_hidden_state" entry.
    - tokenizer (AutoTokenizer): Hugging Face tokenizer compatible with `model`.
    - max_length (int): Sequence length every batch is padded/truncated to.
    - batch_size (int): Number of strings per forward pass.
    - device (Union[torch.device, str]): Device the tokenized inputs are moved to.

    **Returns:**

    - torch.Tensor: Shape [len(column_data), feature_size], left on the device the model ran on.

    **Raises:**

    - ValueError: If `column_data` is empty.
    """
    if len(column_data) == 0:
        raise ValueError("column_data must contain at least one string")

    pooled_batches = []
    with torch.no_grad():
        for start in range(0, len(column_data), batch_size):
            batch = column_data[start:start + batch_size]
            tokenized_inputs = tokenizer(
                batch,
                return_tensors="pt",
                max_length=max_length,
                padding="max_length",
                truncation=True,
            ).to(device)
            hidden = model(**tokenized_inputs)["last_hidden_state"]  # [batch, seq, feature]

            # Zero out pad positions, then divide by the number of real tokens per sequence
            # (clamped to 1 so an all-pad row cannot divide by zero).
            mask = tokenized_inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            pooled_batches.append(summed / counts)

    return torch.cat(pooled_batches, dim=0)

def process_df_in_batches(
    df: pd.DataFrame,
    model: nn.Module,
    tokenizer: AutoTokenizer,
    device: Union[torch.device, str],
    batch_size: int = 512,
    max_length: int = 512,
) -> torch.Tensor:
    """
    Encode every column of a DataFrame of strings and stack the results per row.

    Each column is encoded independently with `process_in_batches`. The sequence length
    used for a column is its average CHARACTER length (not token length), clipped to the
    range [2, max_length]. Strings tokenizing to more tokens than that are truncated.

    **Parameters:**

    - df (pd.DataFrame): Text data; every cell is expected to be a string.
    - model (nn.Module): The PyTorch model used for text encoding.
    - tokenizer (AutoTokenizer): The Hugging Face tokenizer compatible with `model`.
    - device (Union[torch.device, str]): Device the computations run on.
    - batch_size (int, optional): Strings per forward pass. Defaults to 512.
    - max_length (int, optional): Upper bound on the per-column sequence length. Defaults to 512.

    **Returns:**

    - torch.Tensor: Shape [num_rows, num_columns, feature_size].

    **Raises:**

    - ValueError: If `df` has no columns or no rows.

    **Note:**

    Relies on `process_in_batches` being defined in the notebook.
    """
    if df.shape[1] == 0:
        raise ValueError("df must have at least one column")

    feature_matrix_list = []
    for col in tqdm(df.columns, desc="Processing DataFrame"):
        column_data = df[col].tolist()
        if not column_data:
            raise ValueError(f"column {col!r} has no rows to encode")

        average_length = sum(len(text) for text in column_data) // len(column_data)
        # The lower bound of 2 leaves room for the start/end special tokens
        # (e.g. [CLS]/[SEP] for BERT-style tokenizers) when strings are empty.
        seq_length = max(2, min(average_length, max_length))

        feature_matrix_list.append(
            process_in_batches(column_data, model, tokenizer, seq_length, batch_size, device)
        )

    return torch.stack(feature_matrix_list, dim=1)

In [None]:
# Embed every disease text column except the identifier.
encoded_diseases_tensors = process_df_in_batches(
    diseases_table_df.drop(columns=["Disease ID"]),
    model,
    tokenizer,
    device,
    batch_size=768,
)

Processing DataFrame: 100%|██████████| 7/7 [00:06<00:00,  1.13it/s]


In [None]:
# Free-text gene descriptors to embed with the language model.
columns_to_keep = ["name", "locus_group", "locus_type", "location"]

# Move the result to CPU once, so it can be converted to NumPy later.
encoded_genes_tensors = process_df_in_batches(
    df=genes_table_df[columns_to_keep],
    model=model,
    tokenizer=tokenizer,
    device=device,
    batch_size=768,
).cpu()

In [None]:
# Numeric gene features = one-hot position columns + multi-hot gene families + flattened
# text embeddings. Dropping the identifier/text columns leaves only the one-hot dummies.
selected = genes_table_df.drop(columns=["Gene ID", "hgnc_id", "name", "locus_group", "locus_type", "location", "gene_family", "gene_family_id"]).to_numpy()

family_emb_array = np.array(family_emb)

# encoded_genes_tensors is a single tensor of shape [num_genes, num_text_columns, feature_size];
# flatten the per-column embeddings into one row vector per gene.
encoded_genes_tensors_array = encoded_genes_tensors.detach().cpu().numpy().reshape(genes_table_df.shape[0], -1)

# Concatenate all feature groups column-wise (rows stay aligned with genes_table_df).
encoded_genes_tensors_matrix = np.hstack((selected, family_emb_array, encoded_genes_tensors_array))

encoded_genes = torch.tensor(encoded_genes_tensors_matrix)

In [None]:
# Persist the one-hot chromosome-position features on their own.
# NOTE(review): "posisions" is misspelled; the path is kept unchanged in case other code loads it.
np.save(
    file="/content/drive/MyDrive/Projetos/GNN-Gene-Disease/Data/embeddings/posisions_encoded.npy",
    arr=selected,
)

In [None]:
def reduce_dimensions_with_umap(
    data: Union[np.ndarray, torch.Tensor],
    n_components: int = 256,
    random_state: Union[int, None] = None,
) -> np.ndarray:
    """
    Reduce the dimensions of a given matrix using UMAP.

    Args:
        data: NumPy array or torch tensor with samples along the first axis. Inputs with
              more than two dimensions (e.g. [n_samples, n_columns, feature_size]) are
              flattened to (n_samples, -1) first, because UMAP only accepts 2-D input.
        n_components: The number of dimensions to reduce the data to. Default is 256.
        random_state: Seed passed to UMAP for reproducible output. None (the default)
              keeps UMAP's unseeded behaviour.

    Returns:
        A numpy ndarray of shape (n_samples, n_components) after dimensionality reduction.
    """
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    data = np.asarray(data)
    if data.ndim > 2:
        data = data.reshape(data.shape[0], -1)

    umap_reducer = UMAP(n_components=n_components, random_state=random_state)
    return umap_reducer.fit_transform(data)


# encoded_genes_tensors is [num_genes, num_text_columns, feature_size]; it is flattened per gene above.
# NOTE(review): pass random_state=<seed> if the reduced embedding must be reproducible.
encoded_genes_tensors_reduced = reduce_dimensions_with_umap(encoded_genes_tensors)
print(encoded_genes_tensors_reduced.shape)  # (num_genes, 256)


(6948, 256)


In [None]:
def normalize_with_robust_scaler(data: np.ndarray) -> np.ndarray:
    """
    Scale every feature column with scikit-learn's RobustScaler (default settings).

    Each column is centred on its median and divided by its interquartile range, which
    limits the influence of outliers.

    Args:
        data: Array of shape (n_samples, n_features).

    Returns:
        The scaled array, same shape as `data`.
    """
    return RobustScaler().fit_transform(data)

In [None]:
encoded_genes_tensors_reduced_normalized = normalize_with_robust_scaler(data=encoded_genes_tensors_reduced)

In [None]:
# NOTE(review): this pickles a NumPy array, not a tensor. Newer torch versions default
# torch.load to weights_only=True, which may refuse it; load with weights_only=False.
torch.save(
    obj=encoded_genes_tensors_reduced_normalized,
    f="/content/drive/MyDrive/Projetos/GNN-Gene-Disease/Data/embeddings/encoded_genes_tensors_reduced_normalized.pt",
)

In [None]:
def pca_reduction_and_explained_variance(
    data: Union[np.ndarray, torch.Tensor],
    n_components: int = 256,
    random_state: Union[int, None] = None,
) -> tuple:
    """
    Reduce the dimensions of a given matrix using PCA and report the variance retained.

    Args:
        data: NumPy array or torch tensor with samples along the first axis. Inputs with
              more than two dimensions are flattened to (n_samples, -1) first, because
              PCA only accepts 2-D input.
        n_components: The number of principal components to keep. Default is 256.
        random_state: Passed to PCA (matters only for randomized SVD solvers). Default None.

    Returns:
        A tuple containing:
        - reduced_data: A numpy ndarray of shape (n_samples, n_components).
        - explained_variance_ratio: Total variance explained by the kept components, as a
          percentage (0-100).
    """
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    data = np.asarray(data)
    if data.ndim > 2:
        data = data.reshape(data.shape[0], -1)

    pca = PCA(n_components=n_components, random_state=random_state)
    reduced_data = pca.fit_transform(data)

    explained_variance_ratio = np.sum(pca.explained_variance_ratio_) * 100

    return reduced_data, explained_variance_ratio


# Perform PCA reduction and get the explained variance (for comparison with UMAP).
reduced_data, explained_variance_ratio = pca_reduction_and_explained_variance(encoded_genes_tensors)
print(f"Shape of reduced data: {reduced_data.shape}")
print(f"Explained variance ratio by the first {reduced_data.shape[1]} components: {explained_variance_ratio:.2f}%")


Shape of reduced data: (6948, 256)
Explained variance ratio by the first 256 components: 93.00%


In [None]:
# Save the full (unreduced) gene feature matrix.
torch.save(
    obj=encoded_genes,
    f="/content/drive/MyDrive/Projetos/GNN-Gene-Disease/Data/embeddings/genes_df_encoded.pt",
)

### Creating disease–gene links (edge index)

In [None]:
def create_link_array(genes_diseases_df, disease_features, gene_features):
    """
    Build a disease -> gene edge index in PyTorch Geometric's [2, num_edges] format.

    Args:
        genes_diseases_df: Association table with "# Disease ID" and "Gene ID" columns.
            One edge is produced per row whose two IDs both appear in the feature tables.
            Other rows are skipped.
        disease_features: Table with a "Disease ID" column. Row position = disease node index.
        gene_features: Table with a "Gene ID" column. Row position = gene node index.

    Returns:
        torch.LongTensor of shape [2, num_edges]: row 0 holds disease indices, row 1 holds
        gene indices. The shape is [2, 0] when no association matches.
    """
    # Map each ID to its row position (if an ID repeats, its last occurrence wins).
    disease_id_to_idx = {d_id: idx for idx, d_id in enumerate(disease_features["Disease ID"])}
    gene_id_to_idx = {g_id: idx for idx, g_id in enumerate(gene_features["Gene ID"])}

    # Zipping the two columns avoids iterrows' per-row Series construction.
    edge_list = [
        [disease_id_to_idx[disease_id], gene_id_to_idx[gene_id]]
        for disease_id, gene_id in zip(genes_diseases_df["# Disease ID"], genes_diseases_df["Gene ID"])
        if disease_id in disease_id_to_idx and gene_id in gene_id_to_idx
    ]

    # reshape(-1, 2) turns an empty list into shape [0, 2], so the transpose yields [2, 0].
    edge_index = torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()

    return edge_index

In [None]:
# Disease -> gene edge index, with node indices given by row positions in the two tables.
connection_disease_gene = create_link_array(
    genes_diseases_df=genes_diseases_df,
    disease_features=diseases_table_df,
    gene_features=genes_table_df,
)

In [None]:
torch.save(
    obj=connection_disease_gene,
    f="/content/drive/MyDrive/Projetos/GNN-Gene-Disease/Data/edges/connection_disease_gene_augmented.pt",
)

In [None]:
# The previous cell already writes this file. Instead of saving it a second time,
# reload it and confirm the edge index round-trips unchanged.
assert torch.equal(
    torch.load("/content/drive/MyDrive/Projetos/GNN-Gene-Disease/Data/edges/connection_disease_gene_augmented.pt"),
    connection_disease_gene,
), "saved edge index does not match the in-memory tensor"