In [None]:
from scripts.gnn_model import gnn
import os
import torch
import matplotlib.pyplot as plt
import copy as copy
import time as time
import statistics
from scripts import utils
import random
from sklearn.metrics import (f1_score, average_precision_score, recall_score, roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay)

# Default compute device for the notebook: CUDA when available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def chunked_list(lst, chunk_size):
    '''Yield consecutive slices of ``lst`` holding at most ``chunk_size`` items.

    Used to choose which graph files are loaded from disk into RAM at a
    time, so memory is not saturated. The last slice may be shorter.
    A ``chunk_size`` of 0 raises ``ValueError`` (from ``range``) on the
    first iteration.
    '''
    starts = range(0, len(lst), chunk_size)
    yield from (lst[start:start + chunk_size] for start in starts)

class FocalLoss(torch.nn.Module):

    '''Binary focal loss computed from logits.

    With ``solo_BCE=True`` this is plain mean binary cross-entropy.
    With ``solo_BCE=False`` each element's BCE is scaled by
    ``alpha_t * (1 - p_t) ** gamma``, where ``p_t`` is the predicted
    probability of the true class. ``alpha_t`` is ``alpha`` for
    positives (target == 1) and ``1 - alpha`` for negatives. That
    class-dependent weight is what lets the loss emphasise the minority
    class in imbalanced distributions.

    Args:
        alpha: weight given to the positive class; negatives get
            ``1 - alpha``.
        gamma: focusing exponent; larger values down-weight easy
            (well-classified) elements more strongly.
        solo_BCE: if True, skip the focal weighting entirely.
    '''

    def __init__(self, alpha=0.25, gamma=2, solo_BCE=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.solo_BCE = solo_BCE

    def forward(self, logits, targets):
        '''Return the mean (focal-weighted) BCE.

        ``logits`` and ``targets`` must have the same shape, as required by
        ``binary_cross_entropy_with_logits``. Targets are expected to be 0/1
        floats.
        '''
        bce = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction='none')
        if self.solo_BCE:
            return bce.mean()

        is_pos = targets == 1
        probs = torch.sigmoid(logits)
        # Probability assigned to the true class of each element.
        pt = torch.where(is_pos, probs, 1 - probs)
        # Class-dependent alpha (alpha for positives, 1 - alpha for negatives);
        # a uniform alpha would only rescale the loss without rebalancing it.
        alpha_t = torch.where(is_pos,
                              torch.full_like(bce, self.alpha),
                              torch.full_like(bce, 1.0 - self.alpha))
        focal_weight = alpha_t * (1 - pt) ** self.gamma
        return (focal_weight * bce).mean()

def _strip_unused_attrs(g, attrs=('hit_id', 'particle_id')):
    '''Delete bookkeeping attributes the model does not use; returns ``g``.'''
    for attr in attrs:
        if hasattr(g, attr):
            delattr(g, attr)
    return g


def _index_train_files(train_dir):
    '''Group the ``.pt`` files of ``train_dir`` by event.

    The event id is the file-name prefix before the first ``_``, so the
    sub-graph files of one event end up in the same (sorted) list.
    Returns ``(files_by_event, sorted_event_ids)``.
    '''
    files_by_event = {}
    for fname in os.listdir(train_dir):
        if not fname.endswith('.pt'):
            continue
        ev = fname.split('_')[0]
        files_by_event.setdefault(ev, []).append(os.path.join(train_dir, fname))
    for paths in files_by_event.values():
        paths.sort()
    return files_by_event, sorted(files_by_event)


def _load_test_graphs(test_dir):
    '''Load every ``.pt`` graph in ``test_dir``; they stay in memory for all epochs.'''
    graphs = []
    for fname in sorted(os.listdir(test_dir)):
        if fname.endswith('.pt'):
            # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
            g = torch.load(os.path.join(test_dir, fname), weights_only=False)
            graphs.append(_strip_unused_attrs(g))
    return graphs


def _estimate_alpha(files_by_event, events, chunk_size):
    '''Estimate the FocalLoss alpha from the label balance of the training graphs.

    Only the first chunk of files of each event is read, to keep this pass
    cheap. For every graph with at least one positive label, neg/pos is
    collected. With the mean ratio ``r``, the result is ``r / (r + 1)``
    (the fraction of negatives), or 0.5 when no graph has positives.
    '''
    ratios = []
    for ev in events:
        for path in files_by_event[ev][:chunk_size]:
            y = torch.load(path, weights_only=False).y
            pos = int(y.sum().item())
            neg = int(y.numel() - pos)
            if pos > 0:
                ratios.append(neg / pos)
        torch.cuda.empty_cache()
    mean_pw = statistics.mean(ratios) if ratios else 1.0
    return mean_pw / (mean_pw + 1)


def _train_one_epoch(model, files_by_event, events, chunk_size, loss_fn, optimizer, device, seed):
    '''Run one optimisation pass over all training events (one step per graph).

    Events are visited in an order shuffled with ``seed``. Returns
    ``(mean_loss_per_graph, accuracy_per_label)``.
    '''
    model.train()
    order = list(events)
    # Dedicated RNG: same permutation as random.seed(seed) + random.shuffle,
    # without resetting the notebook's global ``random`` state.
    random.Random(seed).shuffle(order)

    loss_sum, n_graphs = 0.0, 0
    n_correct = n_labels = 0
    for ev in order:
        for chunk in chunked_list(files_by_event[ev], chunk_size):
            graphs = [_strip_unused_attrs(torch.load(p, weights_only=False)).to(device)
                      for p in chunk]
            for g in graphs:
                y = g.y.float().to(device)
                logits = model(g).view(-1)
                optimizer.zero_grad()
                loss = loss_fn(logits, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                # Accumulate so the epoch loss is a mean, not the last graph's value.
                loss_sum += loss.item()
                n_graphs += 1
                preds = (torch.sigmoid(logits.detach()) > 0.5).float()
                n_correct += (preds == y).sum().item()
                n_labels += y.numel()
            del graphs
            torch.cuda.empty_cache()
    mean_loss = loss_sum / n_graphs if n_graphs else float('nan')
    accuracy = n_correct / n_labels if n_labels else 0.0
    return mean_loss, accuracy


def _evaluate(model, test_graphs, loss_fn, device):
    '''Evaluate on the held-out graphs in a single forward pass without autograd.

    Returns ``(mean_loss_per_graph, accuracy_per_label, y_true, y_prob)``.
    The last two are NumPy arrays with all labels and sigmoid probabilities
    concatenated.
    '''
    model.eval()
    loss_sum = 0.0
    n_correct = n_labels = 0
    labels, probs = [], []
    with torch.no_grad():
        for g in test_graphs:
            g = g.to(device)
            y = g.y.float().to(device)
            logits = model(g).view(-1)
            loss_sum += loss_fn(logits, y).item()
            prob = torch.sigmoid(logits)
            n_correct += (prob.round() == y).sum().item()
            n_labels += y.numel()
            labels.append(y.cpu())
            probs.append(prob.cpu())
    accuracy = n_correct / n_labels if n_labels else 0.0
    return (loss_sum / len(test_graphs), accuracy,
            torch.cat(labels).numpy(), torch.cat(probs).numpy())


def _plot_history(hist, y_true, y_prob, y_pred):
    '''Draw the 2x3 training summary and return the Figure.

    Panels: loss, accuracy and test metrics per epoch, plus the confusion
    matrix and ROC curve of the last epoch.
    '''
    epochs_range = list(range(1, len(hist['train_loss']) + 1))
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # Loss (same 1-based epoch axis as the metric panel)
    ax = axes[0, 0]
    ax.plot(epochs_range, hist['train_loss'], label="Train Loss")
    ax.plot(epochs_range, hist['test_loss'], label="Test Loss")
    ax.set_title("Loss por Época")
    ax.set_xlabel("Época"); ax.set_ylabel("Loss")
    ax.legend(); ax.grid(True)

    # Accuracy
    ax = axes[0, 1]
    ax.plot(epochs_range, hist['train_acc'], label="Train Acc")
    ax.plot(epochs_range, hist['test_acc'], label="Test Acc")
    ax.set_title("Accuracy por Época")
    ax.set_xlabel("Época"); ax.set_ylabel("Accuracy")
    ax.legend(); ax.grid(True)

    # Per-epoch test metrics
    ax = axes[0, 2]
    for key, label in (('pr_auc', "PR–AUC"), ('f1', "F1"),
                       ('rec_pos', "Recall Positiva"), ('rec_neg', "Recall Negativa"),
                       ('roc_auc', "ROC–AUC")):
        ax.plot(epochs_range, hist[key], label=label)
    ax.set_title("Métricas por Época")
    ax.set_xlabel("Época"); ax.set_ylabel("Valor")
    ax.set_ylim(0, 1.1); ax.legend(loc="lower right", ncol=2)
    ax.grid(True)

    # Confusion matrix of the last epoch; labels=[0, 1] keeps it 2x2 even if
    # one class is missing from the predictions/labels.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Neg", "Pos"])
    disp.plot(ax=axes[1, 0], cmap=plt.cm.Blues, colorbar=False)
    axes[1, 0].set_title("Matriz de Confusión (Test)")

    # ROC curve of the last epoch
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    ax = axes[1, 1]
    ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.4f}")
    ax.plot([0, 1], [0, 1], "k--", label="Aleatorio")
    ax.set_title("Curva ROC (Test)")
    ax.set_xlabel("FPR"); ax.set_ylabel("TPR")
    ax.set_xlim(0, 1); ax.set_ylim(0, 1)
    ax.legend(loc="lower right"); ax.grid(True)

    axes[1, 2].axis('off')
    fig.tight_layout()
    return fig


def train(model, train_dir, test_dir, config_number, epochs, lr, solo_BCE, chunk_size=50):
    '''Train ``model`` for binary label classification on graph files and save the results.

    Training graphs are streamed from disk event by event, at most
    ``chunk_size`` files at a time. Test graphs are kept in memory and
    evaluated after every epoch. Adam is used with ReduceLROnPlateau on
    the test loss.

    Args:
        model: torch module mapping a loaded graph to one logit per label.
        train_dir: directory with training ``.pt`` files, named
            ``<event>_...``; files of one event are grouped together.
        test_dir: directory with test ``.pt`` files.
        config_number: hyper-parameter config id, embedded in the saved
            model name.
        epochs: number of epochs (>= 1).
        lr: initial Adam learning rate.
        solo_BCE: True -> plain BCE; False -> focal loss, and the saved file
            names get the ``_pt_filter`` suffix.
        chunk_size: maximum number of files of one event loaded at a time (>= 1).

    Returns:
        The trained model, also saved as a ``state_dict`` under ``models/``
        together with a summary figure.

    Raises:
        ValueError: if ``epochs`` or ``chunk_size`` < 1, or if a directory
            has no ``.pt`` files.
    '''
    if epochs < 1:
        raise ValueError(f"epochs debe ser >= 1 (recibido {epochs})")
    if chunk_size < 1:
        raise ValueError(f"chunk_size debe ser >= 1 (recibido {chunk_size})")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-5, cooldown=3)

    files_by_event, train_events = _index_train_files(train_dir)
    if not train_events:
        raise ValueError(f"No se encontraron archivos .pt en {train_dir}")
    test_graphs = _load_test_graphs(test_dir)
    if not test_graphs:
        raise ValueError(f"No se encontraron archivos .pt en {test_dir}")

    alpha = _estimate_alpha(files_by_event, train_events, chunk_size)
    print(f"Alpha para FocalLoss = {alpha:.4f}")
    loss_fn = FocalLoss(alpha=alpha, gamma=2, solo_BCE=solo_BCE)

    hist = {key: [] for key in ('train_loss', 'train_acc', 'test_loss', 'test_acc',
                                'f1', 'pr_auc', 'rec_pos', 'rec_neg', 'roc_auc')}
    y_true = y_prob = y_pred = None  # last epoch's test outputs, for the final plots

    for epoch in range(1, epochs + 1):
        # TRAIN ==================================
        train_loss, train_acc = _train_one_epoch(
            model, files_by_event, train_events, chunk_size,
            loss_fn, optimizer, device, seed=epoch)

        # TEST ===================================
        test_loss, test_acc, y_true, y_prob = _evaluate(model, test_graphs, loss_fn, device)
        y_pred = (y_prob > 0.5).astype(int)

        fpr, tpr, _ = roc_curve(y_true, y_prob)
        hist['train_loss'].append(train_loss)
        hist['train_acc'].append(train_acc)
        hist['test_loss'].append(test_loss)
        hist['test_acc'].append(test_acc)
        hist['f1'].append(f1_score(y_true, y_pred))
        hist['pr_auc'].append(average_precision_score(y_true, y_prob))
        hist['rec_pos'].append(recall_score(y_true, y_pred, pos_label=1))
        hist['rec_neg'].append(recall_score(y_true, y_pred, pos_label=0))
        hist['roc_auc'].append(auc(fpr, tpr))

        print(
            f"\rEpoch {epoch:3d} | "
            f"Train L: {train_loss:.4f} Acc: {train_acc:.4f} | "
            f"Test  L: {test_loss:.4f} Acc: {test_acc:.4f} | "
            f"F1: {hist['f1'][-1]:.4f} PR–AUC: {hist['pr_auc'][-1]:.4f} ROC–AUC: {hist['roc_auc'][-1]:.4f}",
            end="", flush=True
        )
        scheduler.step(test_loss)
    print()  # terminate the "\r" progress line before the final messages

    fig = _plot_history(hist, y_true, y_prob, y_pred)
    final_f1 = hist['f1'][-1]

    suffix = '' if solo_BCE else '_pt_filter'
    fig_path = f"models/plot_model_f1_{final_f1:.7f}{suffix}.png"
    model_path = f"models/model_f1_{final_f1:.7f}_config_{config_number}{suffix}.pt"

    os.makedirs("models", exist_ok=True)  # savefig/torch.save fail if the folder is missing
    fig.savefig(fig_path, dpi=300, bbox_inches="tight")
    plt.show()

    torch.save(model.state_dict(), model_path)
    print(f"Gráfica guardada en: {fig_path}")
    print(f"Modelo guardado en: {model_path}")

    return model


# Train con filtro ideal

In [None]:
config_number = 2
model_hyper, preprocess_hyper = utils.load_hyper(name = f'config{config_number}')

model = gnn(**model_hyper)

# Pass the config_number variable (not a literal) so the saved model name
# always matches the hyper-parameter config loaded above.
trained = train(
    model,
    train_dir      = 'data/dataset_graphs_ideal/train',
    test_dir       = 'data/dataset_graphs_ideal/test',
    config_number  = config_number,
    epochs         = 300,
    lr             = 1e-3,
    solo_BCE       = True,
    chunk_size     = 50) # max files of one event loaded from disk at a time


# Train con filtro real

In [None]:
%%time
# ejecutar esta celda solo si la celda anterior no ha sido ejecutada
config_number = 2
model_hyper, preprocess_hyper = utils.load_hyper(name = f'config{config_number}')

model = gnn(**model_hyper)

trained = train(
    model,
    train_dir      = 'data/dataset_graphs_real/train',
    test_dir       = 'data/dataset_graphs_real/test',
    config_number  = config_number,
    epochs         = 150,
    lr             = 1e-3,
    solo_BCE = False,
    chunk_size     = 50)