In [1]:
import os
import torch
import random
import numpy as np

import networkx as nx
import matplotlib.pyplot as plt

from src.evaluation.evaluator_manager import EvaluatorManager
from src.evaluation.evaluator_manager_do import EvaluatorManager as PairedEvaluatorManager
from src.evaluation.evaluator_manager_triplets import EvaluatorManager as TripletsEvaluatorManager

from src.utils.context import Context

## Eliminazione cache dataset

In [2]:
# Set to True to delete the cached GCS-* datasets so they are regenerated on the next run.
elimina_cache_dataset = True

if elimina_cache_dataset:
    # Built with os.path.join instead of a backslash literal: "\d" and "\c" are invalid escape
    # sequences (SyntaxWarning on Python 3.12+). On Windows this yields ".\data\cache\datasets".
    cartella = os.path.join(".", "data", "cache", "datasets")

    # os.listdir would raise FileNotFoundError if the cache folder has never been created.
    if os.path.isdir(cartella):
        for file in os.listdir(cartella):
            if file.startswith("GCS-"):  # only cached files whose name starts with "GCS-"
                percorso_completo = os.path.join(cartella, file)
                if os.path.isfile(percorso_completo):
                    os.remove(percorso_completo)
                    print(f"Eliminato: {percorso_completo}")
    else:
        print(f"Cartella cache non trovata: {cartella}")


Eliminato: .\data\cache\datasets\GCS-189d68770f99189694c24d6a27f87039


## Run the GRETEL experiment

In [3]:
%%capture captured_output

path = 'config/GCS-GCN.jsonc'
# path = 'config/GCS-SVM.jsonc'
# path = 'config/GCS-KNN.jsonc'

print(f"Generating context for: {path}")
context = Context.get_context(path)
context.run_number = -1

context.logger.info(f"Executing: {context.config_file} Run: {context.run_number}")
context.logger.info(
    "Creating the evaluation manager......................................................."
)

if 'doe-triplets' in context.conf:
    context.logger.info("Creating the TRIPLET evaluators........................................................")
    eval_manager = TripletsEvaluatorManager(context)
if 'do-pairs' in context.conf:
    context.logger.info("Creating the PAIRED evaluators...............................................................")
    eval_manager = PairedEvaluatorManager(context)
else:
    context.logger.info("Creating the evaluators...............................................................")
    eval_manager = EvaluatorManager(context)

context.logger.info(
    "Evaluating the explainers............................................................."
)

eval_manager.evaluate()

2025-01-28 10:22:43,-1405591510 | INFO | 23968 - Executing: config/GCS-GCN.jsonc Run: -1
2025-01-28 10:22:43,-1405591258 | INFO | 23968 - Creating the evaluation manager.......................................................
2025-01-28 10:22:43,-1405591179 | INFO | 23968 - Creating the PAIRED evaluators...............................................................
2025-01-28 10:22:43,-1405590943 | INFO | 23968 - Creating: GCS-189d68770f99189694c24d6a27f87039
2025-01-28 10:22:43,-1405590927 | INFO | 23968 - Instantiating: src.dataset.generators.gcs.GCS
2025-01-28 10:22:52,-1405582126 | INFO | 23968 - Saved: GCS-189d68770f99189694c24d6a27f87039
2025-01-28 10:22:52,-1405582045 | INFO | 23968 - Created: GCS-189d68770f99189694c24d6a27f87039
2025-01-28 10:22:52,-1405581999 | INFO | 23968 - Instantiating: src.oracle.nn.gcn.DownstreamGCN
2025-01-28 10:22:52,-1405581985 | INFO | 23968 - Instantiating: torch.optim.RMSprop
2025-01-28 10:22:52,-1405581985 | INFO | 23968 - Instantiating: torch.nn.

ZeroDivisionError: division by zero

In [4]:
import pickle

# Name the pickle after the most recent GRETEL log file, so the saved manager can be matched to its log.
# Paths are assembled with os.path.join rather than backslash literals ("\G" is an invalid escape sequence).
gretel_output_dir = os.path.join("..", "..", "explainability", "GRETEL-repo", "output")
logs_dir = os.path.join(gretel_output_dir, "logs")

file_name = max(os.listdir(logs_dir), key=lambda f: os.path.getmtime(os.path.join(logs_dir, f)))
file_name = file_name.split('.')[0]  # log file name up to the first dot, e.g. "23968-Martina"

eval_manager_dir = os.path.join(gretel_output_dir, "eval_manager")
os.makedirs(eval_manager_dir, exist_ok=True)
file_name = os.path.join(eval_manager_dir, file_name + ".pkl")

with open(file_name, 'wb') as f:
    pickle.dump(eval_manager, f)

In [15]:
# Echo what the experiment cell printed while its output was captured by %%capture.
captured_stdout = captured_output.stdout
print(captured_stdout)

Generating context for: config/GCS-GCN.jsonc
----------------------------------------------------------------
data_instance.id: 5
data_instance.label: 0
Prediction: 0
----------------------------------------------------------------
data_instance.id: 0
data_instance.label: 0
Prediction: 0
node embeddings: tensor([[0.0111, 0.0131]], dtype=torch.float64)
----------------------------------------------------------------
data_instance.id: 1
data_instance.label: 0
Prediction: 0
node embeddings: tensor([[0.0087, 0.0103]], dtype=torch.float64)
----------------------------------------------------------------
data_instance.id: 2
data_instance.label: 0
Prediction: 0
node embeddings: tensor([[0.0095, 0.0109]], dtype=torch.float64)
----------------------------------------------------------------
data_instance.id: 3
data_instance.label: 0
Prediction: 0
node embeddings: tensor([[0.0029, 0.0034]], dtype=torch.float64)
----------------------------------------------------------------
data_instance.id: 4


In [16]:
# Save the captured stdout of the experiment cell (node-embedding prints) as a text file
# named after the most recent GRETEL log file.
# Paths are assembled with os.path.join rather than backslash literals ("\G" is an invalid escape sequence).
gretel_output_dir = os.path.join("..", "..", "explainability", "GRETEL-repo", "output")
logs_dir = os.path.join(gretel_output_dir, "logs")

file_name = max(os.listdir(logs_dir), key=lambda f: os.path.getmtime(os.path.join(logs_dir, f)))
file_name = file_name.split('.')[0]  # log file name up to the first dot

embeddings_dir = os.path.join(gretel_output_dir, "embeddings")
os.makedirs(embeddings_dir, exist_ok=True)
file_name = os.path.join(embeddings_dir, file_name + ".txt")

# Explicit UTF-8 so the dump does not depend on the platform's default locale encoding.
with open(file_name, "w", encoding="utf-8") as f:
    f.write(captured_output.stdout)

## Caricamento eval_manager

In [7]:
import os

# Helper: name of the most recently modified file in a folder (used to pick the latest run's output).
def get_most_recent_file(folder_path, suffix=""):
    """Return the name (not the full path) of the most recently modified file in ``folder_path``.

    Args:
        folder_path: Directory to scan.
        suffix: Optional file-name suffix filter (e.g. ".pkl"); the default "" accepts every file.

    Returns:
        The bare file name with the latest modification time. Subdirectories are ignored.

    Raises:
        ValueError: If ``folder_path`` contains no matching files.
    """
    candidates = [
        name
        for name in os.listdir(folder_path)
        if name.endswith(suffix) and os.path.isfile(os.path.join(folder_path, name))
    ]
    if not candidates:
        raise ValueError(f"No matching files (suffix={suffix!r}) found in {folder_path!r}")
    return max(candidates, key=lambda name: os.path.getmtime(os.path.join(folder_path, name)))

# Folder holding the pickled evaluation managers (one .pkl per run, named after the run's log file).
# Built with os.path.join: "\G" in the former backslash literal is an invalid escape sequence.
folder_path = os.path.join("..", "..", "explainability", "GRETEL-repo", "output", "eval_manager")
most_recent_file_name = get_most_recent_file(folder_path)

file_path = os.path.join(folder_path, most_recent_file_name)
file_path

'..\\..\\explainability\\GRETEL-repo\\output\\eval_manager\\23968-Martina.pkl'

In [8]:
import pickle

# Restore the evaluation manager pickled earlier in this notebook.
# NOTE: unpickling can execute arbitrary code — only load files this notebook produced itself.
with open(file_path, mode='rb') as f:
    eval_manager = pickle.Unpickler(f).load()

In [9]:
import re

def get_oracle_and_explainer_names(eval_manager):
    """Build an "<oracle> oracle - <explainer> explainer" label from the first evaluator's name.

    The evaluator name is expected to contain ``using_<X>Oracle`` and ``for_<Y>Explainer``.

    Args:
        eval_manager: Object whose ``evaluators[0].name`` is the evaluator name string.

    Returns:
        The label string, e.g. "SVM oracle - DCES explainer".

    Raises:
        ValueError: If the evaluator name does not contain both patterns.
    """
    evaluator_name = eval_manager.evaluators[0].name

    oracle_match = re.search(r'using_(.*?)Oracle', evaluator_name)
    explainer_match = re.search(r'for_(.*?)Explainer', evaluator_name)
    if oracle_match is None or explainer_match is None:
        raise ValueError(
            f"Cannot extract oracle/explainer names from evaluator name {evaluator_name!r}"
        )

    # NOTE(review): an empty oracle name is mapped to 'GCN'; this assumption still needs verifying.
    oracle_name = oracle_match.group(1) or 'GCN'
    explainer_name = explainer_match.group(1)

    return f"{oracle_name} oracle - {explainer_name} explainer"

In [10]:
# eval_manager.evaluators[0].get_instance_and_counterfactual_classifications() # Non funziona

## Visualizzazione dei risultati

In [11]:
import pickle

# Node coordinates used by the graph plots below.

# OPTION 1: nodes on a regular grid
# def generate_grid_positions(num_rows, num_cols):
#     """
#     Generate node positions on a regular grid.

#     :param num_rows: Number of rows in the grid
#     :param num_cols: Number of columns in the grid
#     :return: Dictionary with the node positions
#     """
#     positions = {}
#     node_id = 0
#     for r in range(num_rows):
#         for c in range(num_cols):
#             positions[node_id] = (c, -r)  # place the node on the grid
#             node_id += 1
#     return positions

# Fixed node positions on a 4x6 grid
# fixed_positions = generate_grid_positions(4, 6)

#######################################################################################

# OPTION 2: predefined coordinates
# Load the dictionary of midpoint coordinates from file (used as the `pos` mapping for drawing).
# NOTE: unpickling can execute arbitrary code — only load trusted files.
with open('dizionario_punti_medi.pkl', mode='rb') as f:
    fixed_positions = pickle.Unpickler(f).load()

In [12]:
# # Funzione per disegnare la differenza tra grafo originario e sottografo date le matrici di adiacenza
# def plot_graph_difference(A_grafo, A_sottografo):
    
#     # Crea grafi da matrici di adiacenza
#     # G_grafo = nx.from_numpy_array(A_grafo)
#     G_sottografo = nx.from_numpy_array(A_sottografo)
#     G_differenza = nx.from_numpy_array(A_grafo - A_sottografo)

#     # Disegna il grafo differenza con linee tratteggiate
#     nx.draw(G_differenza, ax=ax, pos=fixed_positions, with_labels=True, node_color='skyblue', node_size=250,
#             font_size=8, edge_color='gray', style='dashed')

#     # Disegna il sottografo con linee piene
#     nx.draw(G_sottografo, ax=ax, pos=fixed_positions, with_labels=True, node_color='skyblue', node_size=250,
#             font_size=8, edge_color='black')

#     plt.title("Differenza tra il grafo principale e il sottografo")
#     plt.show()

In [13]:
import networkx as nx
import matplotlib.pyplot as plt


# Draw the original graph and highlight how an explanation differs from it, given adjacency matrices.
def plot_graph_difference_updated(A_orig, A_expl, axes=None, positions=None):
    """Draw the original graph and overlay the edges the explanation adds or removes.

    Edges of the original graph are drawn black. Edges present only in the explanation are
    drawn green (added), and edges present only in the original are drawn red (removed).

    Args:
        A_orig: Adjacency matrix of the original graph (2-D array-like).
        A_expl: Adjacency matrix of the explanation, same shape as ``A_orig``.
        axes: Matplotlib Axes to draw on; defaults to the notebook-level ``ax``.
        positions: Mapping node -> (x, y); defaults to the notebook-level ``fixed_positions``.
    """
    # The globals are only read when no explicit argument is given.
    target_ax = ax if axes is None else axes
    pos = fixed_positions if positions is None else positions

    A_orig = np.asarray(A_orig)
    A_expl = np.asarray(A_expl)

    # Compare edge presence instead of subtracting the matrices. Subtraction is an error for
    # boolean arrays and wraps around for unsigned dtypes (uint8: 0 - 1 == 255), in which case
    # `difference == -1` would silently miss every added edge.
    in_orig = A_orig != 0
    in_expl = A_expl != 0
    archi_aggiunti = nx.from_numpy_array(in_expl & ~in_orig)  # added edges
    archi_rimossi = nx.from_numpy_array(in_orig & ~in_expl)   # removed edges

    G_orig = nx.from_numpy_array(A_orig)

    # Original graph: nodes, labels and edges.
    nx.draw(G_orig, ax=target_ax, pos=pos, with_labels=True, node_color='skyblue', node_size=250,
            font_size=8, edge_color='black')

    # Overlays: only the edges, so nodes and labels are not redrawn on top of themselves.
    nx.draw_networkx_edges(archi_aggiunti, pos=pos, ax=target_ax, node_size=250, edge_color='green')
    nx.draw_networkx_edges(archi_rimossi, pos=pos, ax=target_ax, node_size=250, edge_color='red')

    target_ax.set_title("Grafo originario vs explainer")
    plt.show()

In [14]:
# Disegna i grafi e crea animazione

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.animation import writers
from IPython.display import HTML
import networkx as nx

%matplotlib notebook

img = plt.imread("nodi-vuoto.png")

fig, ax = plt.subplots(figsize=(6,6))
fig.suptitle(get_oracle_and_explainer_names(eval_manager), fontsize=12.5, x = 0.51)

def animate(i):
    ax.clear()
    ax.imshow(img)

    g1, g2 = eval_manager._evaluators[0].get_instance_explanation_pairs()[i]
    # plot_graph_difference(g1.data, g2.data)
    plot_graph_difference_updated(g1.data, g2.data)

    ax.set_title(f'Graph id: {g1.id} (classe: {g1.label})')

ani = FuncAnimation(fig, animate, frames=len(eval_manager._evaluators[0].get_instance_explanation_pairs()), repeat=False, blit=True)
ani.save("..\\..\\explainability\GRETEL-repo\\output\\video\\evoluzione_sottografi_" + most_recent_file_name.split('.')[0] + ".mp4", writer="ffmpeg", fps=10)

HTML(ani.to_jshtml())

<IPython.core.display.Javascript object>

TypeError: object of type 'NoneType' has no len()

## _______________________________________________________________

In [13]:
import datetime

# Record when the notebook was last run completely.
now = datetime.datetime.now()
print("Ultima esecuzione completa:", format(now, "%d/%m/%Y, ore %H:%M"))

Ultima esecuzione completa: 07/01/2025, ore 09:20


In [14]:
# Per lettura log e plot metriche: notebook "Lettura_EEG"