In [2]:
# test (scratch cell)

In [3]:
import ast
import os
import pickle
import re

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch

from utils_martina.my_utils import *

In [4]:
# GRETEL output folders, relative to this notebook.
# Built with os.path.join so the separators match the running OS. This also
# removes the backslash literals such as "\G", which Python warns about as
# invalid escape sequences. Each path keeps a trailing separator because later
# cells concatenate file names directly onto it.
GRETEL_OUTPUT_DIR = os.path.join("..", "..", "explainability", "GRETEL-repo", "output")
logs_path = os.path.join(GRETEL_OUTPUT_DIR, "logs", "")
eval_manager_path = os.path.join(GRETEL_OUTPUT_DIR, "eval_manager", "")
output_path = os.path.join(GRETEL_OUTPUT_DIR, "evolution_3d_embeddings", "")

In [5]:
# Run id of the newest eval_manager file, e.g. "31352-Martina".
# os.path.splitext strips only the final extension, so run names that contain
# dots survive; the previous .split('.')[0] cut at the first dot. basename keeps
# this correct whether get_most_recent_file returns a bare file name or a path.
file_name = os.path.splitext(os.path.basename(get_most_recent_file(eval_manager_path)))[0]
print(file_name)

31352-Martina


In [6]:
# Set patient and record

# ["chb01_03", "chb01_04", "chb01_15", "chb01_16", "chb01_18", "chb01_21", "chb01_26"]
patient_id = "chb01"
record_id = "21"

# No temporal penalization
# file_name = 16688-Martina

# With temporal penalization
# file_name = 26040-Martina

# EEG sampling frequency per dataset, keyed by the patient-id prefix.
SAMPLING_FREQUENCY_BY_PREFIX = {
    "PN": 512,   # Siena dataset
    "chb": 256,  # CHB-MIT dataset
}


def sampling_frequency(patient_id: str) -> int:
    """Return the EEG sampling frequency for ``patient_id``, chosen by its dataset prefix.

    Raises:
        ValueError: if the prefix matches no known dataset. Without this,
            ``frequency`` would be left undefined or, after a re-run in the same
            kernel, silently keep the value of a previous patient.
    """
    for prefix, hz in SAMPLING_FREQUENCY_BY_PREFIX.items():
        if patient_id.startswith(prefix):
            return hz
    raise ValueError(
        f"Unknown dataset prefix for patient_id {patient_id!r}; "
        f"expected one of {list(SAMPLING_FREQUENCY_BY_PREFIX)}"
    )


frequency = sampling_frequency(patient_id)

In [7]:
# Load information related to the EEG of patient_id, record_id.
# os.path.join replaces the former "EEG_data\EEG_..." literal, whose "\E" is an
# invalid escape sequence, and keeps the path valid outside Windows.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
eeg_params_path = os.path.join("EEG_data", f"EEG_data_params_{patient_id}_{record_id}.pkl")
with open(eeg_params_path, "rb") as f:
    loaded_variables = pickle.load(f)

indices = loaded_variables["indici"]  # stored under the Italian key for "indices"
Start = loaded_variables["Start"]
End = loaded_variables["End"]
seizure_starts = loaded_variables["seizure_starts"]
seizure_ends = loaded_variables["seizure_ends"]
seizure_class = loaded_variables["seizure_class"]

In [8]:
# Raw text of the GRETEL log file for the selected run.
log_file_path = f"{logs_path}{file_name}.info"
with open(log_file_path, "r") as log_file:
    content = log_file.read()

# Pickled evaluation manager saved by the same run.
# NOTE(review): pickle.load executes arbitrary code from the file - only load trusted outputs.
eval_manager_pickle_path = f"{eval_manager_path}{file_name}.pkl"
with open(eval_manager_pickle_path, "rb") as pkl_file:
    eval_manager = pickle.load(pkl_file)

In [9]:
# First evaluator registered on the loaded evaluation manager.
first_evaluator = eval_manager._evaluators[0]

# (instance, explanation) pairs produced by that evaluator.
pairs = first_evaluator.get_instance_explanation_pairs()

# Plain-list copy of every graph instance in the evaluator's dataset.
graph_instance_list = list(first_evaluator.dataset.instances)

In [12]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import numpy as np

# --- Assumes graph_instance_list is the list of your GraphInstance objects

def graph_instance_to_pyg_data(graph_instance):
    """Convert a GraphInstance into a torch_geometric ``Data`` object.

    - ``edge_index`` (long): one directed edge per non-zero entry of the dense
      adjacency ``graph_instance.data``, in ``np.nonzero`` order.
    - ``x`` (float32): ``graph_instance.node_features``.
    - ``edge_attr`` (float32): ``graph_instance.edge_features`` when it is not
      None, otherwise None.

    NOTE(review): edge_features are assumed to be listed in the same order as
    the ``np.nonzero`` edges - confirm against how they are built.
    """
    nonzero_coords = np.nonzero(graph_instance.data)
    edge_index = torch.tensor(np.stack(nonzero_coords), dtype=torch.long)

    node_x = torch.tensor(graph_instance.node_features, dtype=torch.float)

    raw_edge_features = graph_instance.edge_features
    edge_attr = (
        None
        if raw_edge_features is None
        else torch.tensor(raw_edge_features, dtype=torch.float)
    )

    return Data(x=node_x, edge_index=edge_index, edge_attr=edge_attr)

pyg_data_list = [graph_instance_to_pyg_data(g) for g in graph_instance_list]
if not pyg_data_list:
    raise ValueError("graph_instance_list is empty: no graphs to train the autoencoder on")

# --- Hyper-parameters (adjust to your dataset)
node_feat_dim = pyg_data_list[0].x.shape[1]
latent_dim = 32
hidden_dim = 64
# The dense decoder emits a fixed num_nodes x num_nodes matrix per graph, so
# every graph must have the same node count and feature width. Variable sizes
# would need more advanced handling (e.g. padding), so fail fast here instead
# of deep inside the training loop.
num_nodes = pyg_data_list[0].x.shape[0]
mismatched = [
    idx for idx, d in enumerate(pyg_data_list)
    if d.x.shape[0] != num_nodes or d.x.shape[1] != node_feat_dim
]
if mismatched:
    raise ValueError(
        f"{len(mismatched)} graph(s) (first index {mismatched[0]}) differ from the "
        f"expected shape ({num_nodes} nodes, {node_feat_dim} features)"
    )

loader = DataLoader(pyg_data_list, batch_size=16, shuffle=True)

class GraphEncoder(torch.nn.Module):
    """Two-layer GCN that embeds every graph of a mini-batch into one latent vector.

    Pipeline: conv1 -> ReLU -> conv2 (no activation) -> mean over each graph's nodes.
    """

    def __init__(self, in_channels, hidden_channels, latent_dim):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, latent_dim)

    def forward(self, x, edge_index, batch):
        """Return one pooled ``latent_dim`` embedding per graph id in ``batch``."""
        hidden = self.conv1(x, edge_index).relu()
        node_embeddings = self.conv2(hidden, edge_index)
        return global_mean_pool(node_embeddings, batch)

class GraphDecoder(torch.nn.Module):
    """Decode latent graph embeddings into dense matrices of edge probabilities.

    A single linear layer maps each ``latent_dim`` vector to ``num_nodes * num_nodes``
    logits. These are reshaped to (batch, num_nodes, num_nodes) and passed
    through a sigmoid.
    """

    def __init__(self, latent_dim, num_nodes):
        super().__init__()
        # Keep the size on the module: forward() used to read the notebook-global
        # `num_nodes`, so a decoder built with any other size (or outside this
        # notebook) mis-shaped its output or raised NameError.
        self.num_nodes = num_nodes
        self.linear = torch.nn.Linear(latent_dim, num_nodes * num_nodes)

    def forward(self, z):
        """Map ``z`` (batch, latent_dim) to (batch, num_nodes, num_nodes) values in (0, 1).

        NOTE(review): returning probabilities forces the caller to use
        F.binary_cross_entropy; returning logits with BCEWithLogitsLoss would be
        more numerically stable.
        """
        logits = self.linear(z)
        adj_logits = logits.view(-1, self.num_nodes, self.num_nodes)
        return torch.sigmoid(adj_logits)

class GraphAutoencoder(torch.nn.Module):
    """Graph-level autoencoder: GCN encoder -> latent vector -> dense adjacency decoder."""

    def __init__(self, in_channels, hidden_channels, latent_dim, num_nodes):
        super().__init__()
        self.encoder = GraphEncoder(in_channels, hidden_channels, latent_dim)
        self.decoder = GraphDecoder(latent_dim, num_nodes)

    def forward(self, data):
        """Reconstruct the adjacency of every graph in a PyG batch.

        ``data`` must expose ``x``, ``edge_index`` and ``batch``. The result is
        whatever the decoder returns for the pooled per-graph embeddings.
        """
        return self.decoder(self.encoder(data.x, data.edge_index, data.batch))

# Seed the global torch RNG before building the model. This makes the weight
# initialisation, and the DataLoader shuffling drawn afterwards from the same
# RNG, reproducible. Bit-exact runs on CUDA may need further settings.
SEED = 0
# NOTE(review): 0.05 is 50x PyTorch's Adam default (1e-3), and the recorded
# training loss plateaus right after the first epoch - consider lowering it.
LEARNING_RATE = 0.05

torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphAutoencoder(node_feat_dim, hidden_dim, latent_dim, num_nodes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [13]:
# --- Training loop
NUM_EPOCHS = 50


def dense_adjacency_target(batch, batch_size, num_nodes):
    """Build the dense 0/1 target adjacency, (batch_size, num_nodes, num_nodes), for a PyG mini-batch.

    PyG concatenates the graphs of a mini-batch along the node axis: the nodes
    of graph g form the contiguous global range [start_g, start_g + n_g), where
    start_g = n_0 + ... + n_{g-1} and batch.batch[v] == g for each of them.
    Subtracting start_g from both endpoints of an edge gives its local
    (per-graph) indices. This replaces the previous per-graph Python loop
    (torch.isin + dict remapping) with a few tensor operations.

    Raises:
        ValueError: if a graph has more than ``num_nodes`` nodes and therefore
            cannot fit the fixed-size decoder output.
    """
    graph_sizes = torch.bincount(batch.batch, minlength=batch_size)
    largest = int(graph_sizes.max())
    if largest > num_nodes:
        raise ValueError(
            f"a graph in the batch has {largest} nodes, "
            f"but the decoder only reconstructs {num_nodes}"
        )
    # Exclusive prefix sum: global index of the first node of each graph.
    graph_starts = torch.cumsum(graph_sizes, dim=0) - graph_sizes

    src, dst = batch.edge_index
    # Graph owning each edge; PyG batching never creates edges across graphs.
    edge_graph = batch.batch[src]
    offsets = graph_starts[edge_graph]

    adj_true = torch.zeros((batch_size, num_nodes, num_nodes), device=batch.batch.device)
    adj_true[edge_graph, src - offsets, dst - offsets] = 1.0
    return adj_true


# NOTE(review): in the recorded run the loss stays around 0.373 from epoch 2 on.
# With a sparse 0/1 target, such a flat loss is consistent with the decoder
# collapsing to roughly the edge density everywhere. Consider a positive-class
# weight (BCEWithLogitsLoss with pos_weight on logits) or a lower learning rate. TODO
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        adj_recon = model(batch)  # (batch_size, num_nodes, num_nodes)

        # Per-graph target adjacency; assumes a fixed num_nodes for every graph.
        batch_size = adj_recon.size(0)
        adj_true = dense_adjacency_target(batch, batch_size, num_nodes)

        loss = F.binary_cross_entropy(adj_recon, adj_true)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")


Epoch 1, Loss: 0.3782
Epoch 2, Loss: 0.3744
Epoch 3, Loss: 0.3736
Epoch 4, Loss: 0.3735
Epoch 5, Loss: 0.3728
Epoch 6, Loss: 0.3727
Epoch 7, Loss: 0.3724
Epoch 8, Loss: 0.3730
Epoch 9, Loss: 0.3725
Epoch 10, Loss: 0.3727
Epoch 11, Loss: 0.3731
Epoch 12, Loss: 0.3735
Epoch 13, Loss: 0.3727
Epoch 14, Loss: 0.3727
Epoch 15, Loss: 0.3730
Epoch 16, Loss: 0.3731
Epoch 17, Loss: 0.3727
Epoch 18, Loss: 0.3732
Epoch 19, Loss: 0.3731
Epoch 20, Loss: 0.3729
Epoch 21, Loss: 0.3728
Epoch 22, Loss: 0.3729
Epoch 23, Loss: 0.3727
Epoch 24, Loss: 0.3727
Epoch 25, Loss: 0.3730
Epoch 26, Loss: 0.3730


KeyboardInterrupt: 