In [12]:
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np 
import os
from tqdm import tqdm

In [13]:
class GeneDataset(Dataset):
    def __init__(self, root, filenames, test=False, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data). 
        """
        self.test = test
        self.filenames = filenames
        super(GeneDataset, self).__init__(root, transform, pre_transform)
        
    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.
            (The download func. is not implemented here)  
        """
        return self.filenames

    @property
    def processed_file_names(self):
        """ If these files are found in raw_dir, processing is skipped"""
        if self.test:
            return [F'{file_name}_test' for file_name in self.raw_paths]
        else:
            return self.raw_paths

    def download(self):
        pass

    def process(self):
        self.genes = pd.read_csv(self.raw_paths[0], sep="\t")
        self.genes = self.genes.drop(columns="Description")
        self.edges = pd.read_csv(self.raw_paths[1], sep="\t")

        node_feats = self._get_node_features(self.genes)
        edge_feats = self._get_edge_features(self.edges)
        edge_index = self._get_adjacency_info(self.edges)

        data = Data(x=node_feats, 
                    edge_index=edge_index,
                    edge_attr=edge_feats)
         
        if self.test:
            torch.save(data, os.path.join(self.processed_dir, 'graph_test.pt'))
        else:
            torch.save(data, os.path.join(self.processed_dir, 'graph.pt'))


    def _get_node_features(self, genes):
        genes["genes"] = genes["genes"].str[4:].astype(int)
        all_node_feats = genes.values.tolist()
        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.int64)

    def _get_edge_features(self, edges):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        all_edge_feats = edges["combined_score"].tolist()
        return torch.tensor(all_edge_feats, dtype=torch.float)


    def _get_adjacency_info(self, edges):
        """
        We want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        gene_1 = edges["gene1"].str[4:].astype(int)
        gene_2 = edges["gene2"].str[4:].astype(int)
        edges = pd.concat([gene_1, gene_2], axis=1).values.tolist()

        #iterate over the edges end duplicate it because for one edge we need: n1,n2 and n2,n1
        double_edges = []
        for edge in edges:
            double_edges += [ edge, [edge[1], edge[0]]]

        edge_indices = torch.tensor(double_edges)
        edge_indices = edge_indices.t().to(torch.int64).view(2, -1)
        return edge_indices

    def len(self):
        return self.genes.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        if self.test:
            graph = torch.load(os.path.join(self.processed_dir, 'graph_test.pt'), weights_only=False)
        else:
            graph = torch.load(os.path.join(self.processed_dir, 'graph.pt'), weights_only=False)

        #return with a given node
        return graph

In [17]:
dataset = GeneDataset(root="./data", filenames=["gtex_genes.csv", "gene_graph.csv"])

Processing...
Done!


In [15]:
genes = pd.read_csv("./data/raw/gtex_genes.csv")
genes.shape

(19506, 1)

In [18]:
print(dataset[0])

Data(x=[16127, 55], edge_index=[2, 468794], edge_attr=[234397])


A gráf kirajzoltatása
(nagyon lassan fut le!!!!!!!!! --> 1 óra volt a colab-ban)

In [27]:
# import networkx as nx
# from torch_geometric.utils import to_networkx
# import matplotlib.pyplot as plt

# G = to_networkx(dataset[0], to_undirected=True)
# plt.figure(figsize=(100, 100))
# nx.draw(G, with_labels=False, node_color='lightblue', font_weight='bold')
# plt.savefig("graph.svg", format="svg")

DataLoader nem kell

validation, test adatok: egyes betegségek oszlopbol leválasztani 20%ok jókat rosszakat is és az a teszt adathalmaz, ugyani így a validáció
    betegsékeg tesztelésre csak, nincs validációs adat hozzá

disgenetet úgy tovább szűrni, hogy az egyes betegséghez legalább x gén tartozzon --> végén majd kiprobálni, hogy nem szürök rajtuk

GCN --> a veszteség függvény legyen jó, sima bináris osztályozás

kedd 10:15kor