In [13]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import categorical_accuracy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from spektral.data import Dataset, DisjointLoader, Graph
from spektral.layers import GCSConv, GlobalAvgPool
from spektral.transforms.normalize_adj import NormalizeAdj

##########
#DataFrame
import pandas as pd
from numpy import array
# Sparse matrix
from scipy.sparse import csr_matrix
# File access
import os
# For handling files that already exist
from os.path import exists
##########
################################################################################
# Config
################################################################################
# Learning rate
learning_rate = 0.01
# Number of training epochs
epochs = 400
# Patience for early stopping
es_patience = 10
# Batch size
batch_size = 32


################################################################################
# Load data
################################################################################
class MyDataset(Dataset):
    """
    A dataset of graphs loaded from `.npz` archives in `../Dati/Grafi/`.

    The first four characters of every entry in that directory are taken as a
    graph id, and the archive `<id>.npz` is loaded for each distinct id. Every
    archive must provide the keys `x`, `a` and `y`, which are passed to the
    `Graph` arguments of the same name.

    **Arguments**

    - `n_samples`: maximum number of graphs to load. Fewer graphs are returned
      when the directory holds fewer distinct ids.
    """

    def __init__(self, n_samples, **kwargs):
        # `n_samples` must be set before the parent constructor runs, because
        # `read()` below depends on it.
        self.n_samples = n_samples
        super().__init__(**kwargs)

    def read(self):
        """Return a list with at most `n_samples` `Graph` objects."""
        graphs_dir = "../Dati/Grafi/"
        seen_ids = set()
        graphs = []

        # Walk the directory once, in the order reported by `os.listdir`.
        for filename in os.listdir(graphs_dir):
            if len(graphs) >= self.n_samples:
                break
            # Skip hidden entries such as `.DS_Store`.
            if filename.startswith("."):
                continue

            id_p = filename[0:4]
            if id_p in seen_ids:
                continue
            seen_ids.add(id_p)

            # WARNING: `allow_pickle=True` unpickles arbitrary objects; only
            # load archives that come from a trusted source.
            npz_path = os.path.join(graphs_dir, id_p + ".npz")
            with np.load(npz_path, allow_pickle=True) as data:
                # NOTE(review): `.all()` is kept as in the original code;
                # presumably `a` was saved as an object array wrapping a
                # single matrix, in which case `.item()` would state the
                # intent more clearly -- confirm against the saved files.
                adjacency = csr_matrix(data["a"].all())
                graphs.append(
                    Graph(x=data["x"], a=adjacency, e=None, y=data["y"])
                )

        # We must return a list of Graph objects
        return graphs

In [14]:
# Build the dataset, requesting 49 graphs and passing NormalizeAdj as transform.
data = MyDataset(
    n_samples=49,
    transforms=NormalizeAdj(),
)

In [15]:
# Show the class of the dataset object.
type(data)

__main__.MyDataset

In [17]:
# Display the dataset's repr.
data

MyDataset(n_graphs=49)

In [18]:
# NOTE(review): the `y` arrays printed further down are pairs of strings such
# as ['a' '2qrw'], so the value 2 shown here presumably reflects the length of
# each `y` rather than the number of classes -- confirm before training.
data.n_labels

2

In [19]:
# Print every graph of the dataset followed by its `y` attribute.
for grafo in range(data.n_graphs):
    current_graph = data[grafo]
    print(current_graph, current_graph.y)

Graph(n_nodes=126, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '2qrw']
Graph(n_nodes=128, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '1s56']
Graph(n_nodes=135, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '3vhb']
Graph(n_nodes=123, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '1s69']
Graph(n_nodes=128, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '2bkm']
Graph(n_nodes=127, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '1idr']
Graph(n_nodes=135, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '1vhb']
Graph(n_nodes=137, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '3lb2']
Graph(n_nodes=154, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '1ux9']
Graph(n_nodes=128, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '2gln']
Graph(n_nodes=154, n_node_features=1, n_edge_features=None, n_labels=2) ['a' '1umo']
Graph(n_nodes=119, n_node_features=1, n_edge_features=None, n_lab