In [None]:
# 1. Install the music-transcription evaluation library (imported later as mir_eval).
!pip install mir_eval

# 2. Unpack the dataset.
# IMPORTANT: make sure the input path below matches where Kaggle mounted your dataset.
# It is usually /kaggle/input/<your-dataset-name>/processed_data_HPPNET_100.rar
import os
# Only extract when the target folder is not there yet, so re-running the cell is cheap.
if not os.path.exists("/kaggle/working/processed_data_HPPNET_100"):
    print("üìÇ Descomprimiendo dataset... esto puede tardar un minuto.")
    # Replace 'tu-dataset-name' with the real dataset name on Kaggle.
    !unrar x "/kaggle/input/tu-dataset-name/processed_data_HPPNET_100.rar" "/kaggle/working/"
    print("‚úÖ Descompresi√≥n terminada.")
else:
    print("üìÇ El dataset ya existe, saltando descompresi√≥n.")

In [None]:
import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import mir_eval 
from sklearn.metrics import f1_score, precision_score, recall_score, mean_squared_error
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import sys
from scipy.signal import find_peaks  # <--- NUEVO: Para Peak Picking

# ==========================================
# 0. CONFIGURATION AND STANDARDS
# ==========================================
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"‚öôÔ∏è Usando dispositivo: {DEVICE}")

SEED = 42             # seed for set_seed() and for the train/val file shuffle
SR = 16000            # sample rate; with HOP_LENGTH it converts frame indices to seconds
HOP_LENGTH = 512      # hop between consecutive frames, in samples
SEGMENT_FRAMES = 320  # number of frames in every dataset segment
BINS_PER_OCTAVE = 12  # used by HDConv to turn a harmonic ratio into a dilation
# Note: on Kaggle the archive is extracted into the working directory.
DATA_PATH = Path("/kaggle/working/processed_data_HPPNET_100") 

# Hyper-parameters tuned for Kaggle
BATCH_SIZE = 32           # raised from 4 to 32 to make use of the P100/T4 GPU
FINAL_EPOCHS = 50         # total number of training epochs
LEARNING_RATE = 0.0006    # initial Adam learning rate
PATIENCE_LR = 5           # ReduceLROnPlateau patience (a little more patient than before)
FACTOR_LR = 0.5           # ReduceLROnPlateau multiplicative factor
NUM_WORKERS = 4           # DataLoader workers (Kaggle has good CPU I/O)

# Decision thresholds applied to the sigmoid outputs
THRESHOLD_ONSET = 0.35    # minimum peak height; adjusted for peak picking
THRESHOLD_FRAME = 0.6     # previously 0.5
THRESHOLD_OFFSET = 0.4

def set_seed(seed):
    """Seed every random number generator used by this script.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and all CUDA
    generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Apply the global seed before any dataset or model is created.
set_seed(SEED)

# ==========================================
# 1. DATASET
# ==========================================
class PianoDataset(Dataset):
    """Serve fixed-length HCQT segments together with their piano-roll targets.

    ``processed_dir`` must contain ``inputs_hcqt`` plus ``targets_onset``,
    ``targets_frame``, ``targets_offset`` and ``targets_velocity``; the same
    ``.npy`` file name is looked up in each folder. Files are shuffled with a
    fixed seed and split into train/validation by file. Every file is then cut
    into windows of ``SEGMENT_FRAMES`` frames; a trailing window is kept only
    if it has more than 30 frames and is zero-padded when loaded.

    Args:
        processed_dir: Root folder of the pre-processed dataset.
        split: ``'train'`` selects the first part of the shuffled file list;
            any other value selects the remaining (validation) files.
        val_split: Fraction of the files reserved for validation.

    Raises:
        RuntimeError: If ``inputs_hcqt`` is missing or holds no ``.npy`` file.
    """

    # Target folders are named "targets_<name>"; the name is also the dict key.
    _TARGETS = ("onset", "frame", "offset", "velocity")
    # (bins, channels) used for the fallback sample when no header could be read.
    _DEFAULT_HCQT_TAIL = (88, 3)

    def __init__(self, processed_dir, split='train', val_split=0.15):
        self.processed_dir = Path(processed_dir)
        p = self.processed_dir / "inputs_hcqt"
        if not p.exists():
            raise RuntimeError(f"‚ùå Ruta no existe: {p}")

        all_files = sorted(p.glob("*.npy"))
        if len(all_files) == 0:
            raise RuntimeError(f"‚ùå No se encontraron archivos .npy en {p}")

        # Sort first, then shuffle with a private RNG: the split is reproducible
        # and identical for the 'train' and 'val' instances.
        random.Random(SEED).shuffle(all_files)
        split_idx = int(len(all_files) * (1 - val_split))
        self.files = all_files[:split_idx] if split == 'train' else all_files[split_idx:]

        self.segments = []
        # Trailing (bins, channels) dims of the stored HCQT arrays, taken from
        # the first readable file; needed to build a correctly shaped fallback.
        self._hcqt_tail = None

        # Pre-compute the (file index, start frame, end frame) of every segment.
        print(f"   Calculando segmentos para {split}...")
        for idx, f in enumerate(self.files):
            try:
                # mmap_mode='r' avoids reading the array data just to get its shape.
                shape = np.load(f, mmap_mode='r').shape
                n_frames = shape[0]
            except Exception as exc:
                # Best effort: an unreadable file is skipped, but no longer silently.
                print(f"   Skipping {f.name}: {exc}")
                continue

            if self._hcqt_tail is None and len(shape) == 3:
                self._hcqt_tail = tuple(shape[1:])

            num_clips = math.ceil(n_frames / SEGMENT_FRAMES)
            for i in range(num_clips):
                start = i * SEGMENT_FRAMES
                end = min(start + SEGMENT_FRAMES, n_frames)
                # Very short tails would be almost entirely padding; drop them.
                if (end - start) > 30:
                    self.segments.append((idx, start, end))
        print(f"   ‚úÖ {split.upper()}: {len(self.segments)} segmentos cargados.")

    def __len__(self):
        """Return the number of segments."""
        return len(self.segments)

    def __getitem__(self, idx):
        """Load one segment.

        Returns:
            dict with float tensors: ``"hcqt"`` is the stored 3-D array with
            its axes reversed (frame axis last); ``"onset"``, ``"frame"``,
            ``"offset"`` and ``"velocity"`` are ``(SEGMENT_FRAMES, n)`` slices
            of the target arrays. If anything fails while loading, the error
            is printed and an all-zero sample is returned instead.
        """
        file_idx, start, end = self.segments[idx]
        fid = self.files[file_idx].name
        try:
            base = self.processed_dir
            # Memory-map each file so only the requested frames are read.
            hcqt = np.load(base / "inputs_hcqt" / fid, mmap_mode='r')[start:end]
            targets = {
                name: np.load(base / f"targets_{name}" / fid, mmap_mode='r')[start:end]
                for name in self._TARGETS
            }

            curr_len = hcqt.shape[0]
            if curr_len < SEGMENT_FRAMES:
                # Zero-pad the frame axis so every sample has the same length.
                pad = SEGMENT_FRAMES - curr_len
                hcqt = np.pad(hcqt, ((0, pad), (0, 0), (0, 0)))
                targets = {
                    name: np.pad(arr, ((0, pad), (0, 0)))
                    for name, arr in targets.items()
                }

            sample = {"hcqt": torch.tensor(hcqt).permute(2, 1, 0).float()}
            for name in self._TARGETS:
                sample[name] = torch.tensor(targets[name]).float()
            return sample
        except Exception as e:
            print(f"Error loading {fid}: {e}")
            # The fallback must match the layout of the real samples, otherwise
            # batch collation fails on a shape mismatch.
            n_bins, n_channels = self._hcqt_tail or self._DEFAULT_HCQT_TAIL
            sample = {"hcqt": torch.zeros(n_channels, n_bins, SEGMENT_FRAMES)}
            for name in self._TARGETS:
                sample[name] = torch.zeros(SEGMENT_FRAMES, 88)
            return sample

# ==========================================
# 2. FINAL ARCHITECTURE (RESIDUAL + HDCONV + INSTANCENORM)
# ==========================================
class FocalLoss(nn.Module):
    """Binary focal loss computed on raw logits.

    The per-element binary cross-entropy is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the probability assigned to the target class. When
    ``alpha`` is non-negative, positives are additionally weighted by
    ``alpha`` and negatives by ``1 - alpha``. The result is the mean over all
    elements.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the scalar focal loss for ``inputs`` (logits) and ``targets``."""
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        prob = torch.sigmoid(inputs)
        # Probability the model gives to whichever class the target marks.
        prob_target = prob * targets + (1 - prob) * (1 - targets)
        focal = bce * ((1 - prob_target) ** self.gamma)

        # A negative alpha disables the class-balancing weight.
        if not self.alpha >= 0:
            return focal.mean()

        class_weight = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (class_weight * focal).mean()

# --- NEW: residual block ---
class ResidualBlock(nn.Module):
    """Two 3x3 conv + InstanceNorm layers wrapped by a skip connection.

    When ``in_c`` differs from ``out_c`` the skip path goes through a 1x1
    convolution followed by InstanceNorm so both branches have ``out_c``
    channels; otherwise the input is added unchanged. Spatial size is
    preserved (padding 1, stride 1).
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.InstanceNorm2d(out_c, affine=True)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.InstanceNorm2d(out_c, affine=True)

        if in_c == out_c:
            # Same width on both sides: the skip path is the identity.
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=1, bias=False),
                nn.InstanceNorm2d(out_c, affine=True)
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + skip(x))``."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        skip = x if self.downsample is None else self.downsample(x)
        main += skip  # residual connection
        return self.relu(main)

# --- Corrected HDConv (no negative dilation) ---
class HDConv(nn.Module):
    """Sum of parallel 3x3 convolutions dilated at harmonic distances.

    One branch is built per harmonic ``h`` in ``(1, 2, 3, 4)``. The first
    spatial axis of the kernel is dilated by
    ``round(BINS_PER_OCTAVE * log2(h))`` (plain dilation 1 for ``h == 1``)
    and padded by the same amount, so every branch keeps the input size. The
    branch outputs are added and passed through a 1x1 fusion convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Upper harmonics; per the original author, 0.5, 1 and 2 already come
        # with the HCQT input (not verifiable from this file).
        harmonics = (1, 2, 3, 4)
        dilations = [
            1 if h == 1 else int(np.round(BINS_PER_OCTAVE * np.log2(h)))
            for h in harmonics
        ]

        # Padding equal to the dilation keeps the dilated axis the same size.
        self.convs = nn.ModuleList([
            nn.Conv2d(
                in_channels, out_channels,
                kernel_size=(3, 3),
                padding=(d, 1),
                dilation=(d, 1)
            )
            for d in dilations
        ])
        self.fusion = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply every dilated branch to ``x``, add the results and fuse them."""
        total = 0
        for branch in self.convs:
            total = total + branch(x)
        return self.fusion(total)

class AcousticModel(nn.Module):
    """Convolutional feature extractor: stem, residual blocks and an HDConv branch.

    Pipeline: 3x3 conv stem -> two residual blocks -> harmonic dilated
    convolution (normalised, rectified and added back to its input) -> three
    more residual blocks. Every stage outputs ``base_channels`` channels and
    keeps the spatial size.
    """

    def __init__(self, in_channels, base_channels):
        super().__init__()
        width = base_channels

        # Stem: maps the input channels to the working width.
        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(width, affine=True),
            nn.ReLU()
        )

        # Residual blocks add depth while keeping a direct path for the signal.
        self.res1 = ResidualBlock(width, width)
        self.res2 = ResidualBlock(width, width)

        # Harmonic view of the features.
        self.hdc = HDConv(width, width)
        self.hdc_bn = nn.InstanceNorm2d(width, affine=True)
        self.hdc_relu = nn.ReLU()

        # Context: three further residual blocks.
        self.context = nn.Sequential(
            *[ResidualBlock(width, width) for _ in range(3)]
        )

    def forward(self, x):
        """Return the feature map computed from ``x``."""
        trunk = self.res2(self.res1(self.input_conv(x)))
        # The harmonic branch is added residual-style to the trunk features.
        harmonic = self.hdc_relu(self.hdc_bn(self.hdc(trunk)))
        return self.context(trunk + harmonic)

class FG_LSTM(nn.Module):
    """Bidirectional LSTM run along time independently for each frequency bin.

    The ``(batch, channels, freq, time)`` input is regrouped so that each
    (batch item, frequency bin) pair becomes one sequence of ``channels``-sized
    vectors. A linear layer maps the concatenated forward/backward hidden
    state of each step to a single logit.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        # Two directions -> 2 * hidden_dim features per time step.
        self.proj = nn.Linear(hidden_dim * 2, 1)

    def forward(self, x):
        """Return logits of shape ``(batch, time, freq)``."""
        batch, channels, n_freq, n_time = x.shape
        # (batch, channels, freq, time) -> (batch * freq, time, channels)
        sequences = x.permute(0, 2, 3, 1).reshape(batch * n_freq, n_time, channels)
        self.lstm.flatten_parameters()
        hidden, _ = self.lstm(sequences)
        logits = self.proj(hidden).view(batch, n_freq, n_time)
        return logits.transpose(1, 2)

class HPPNet(nn.Module):
    """Two acoustic branches feeding four per-frequency LSTM heads.

    One ``AcousticModel`` is dedicated to onsets; a second one serves the
    frame, offset and velocity heads. Those three heads see the second
    branch's features concatenated with a detached copy of the onset
    features, so their losses do not back-propagate into the onset branch.

    Args:
        in_channels: Number of channels of the input.
        base_channels: Channel width of each acoustic branch.
        lstm_hidden: Hidden size of every LSTM head (raised to 128).
    """

    def __init__(self, in_channels=4, base_channels=24, lstm_hidden=128):
        super().__init__()
        self.acoustic_onset = AcousticModel(in_channels, base_channels)
        self.acoustic_other = AcousticModel(in_channels, base_channels)

        self.head_onset = FG_LSTM(base_channels, lstm_hidden)

        # Frame/offset/velocity heads consume both feature maps side by side.
        fused_channels = base_channels * 2
        self.head_frame = FG_LSTM(fused_channels, lstm_hidden)
        self.head_offset = FG_LSTM(fused_channels, lstm_hidden)
        self.head_velocity = FG_LSTM(fused_channels, lstm_hidden)

    def forward(self, x):
        """Return ``(onset, frame, offset, velocity)`` logits for ``x``."""
        onset_features = self.acoustic_onset(x)
        logits_onset = self.head_onset(onset_features)

        # detach(): the onset features are reused without receiving gradients
        # from the other heads.
        fused = torch.cat(
            [self.acoustic_other(x), onset_features.detach()], dim=1
        )

        logits_frame, logits_offset, logits_velocity = (
            head(fused)
            for head in (self.head_frame, self.head_offset, self.head_velocity)
        )
        return logits_onset, logits_frame, logits_offset, logits_velocity

# ==========================================
# 3. UTILS (WITH PEAK PICKING)
# ==========================================
def tensor_to_notes(onset_pred, frame_pred, offset_pred, velocity_pred=None, t_onset=0.35, t_frame=0.6, t_offset=0.4, hop_length=None, sr=None):
    """Decode per-frame activations into a list of notes.

    For every pitch column, onsets are the local maxima of ``onset_pred`` with
    height >= ``t_onset`` and at least 2 frames apart. A candidate is kept
    only if the frame activation right after it reaches ``t_frame``. The note
    then extends while the frame activation stays above ``t_frame`` and no
    offset activation above ``t_offset`` is met; a frame that stops the note
    because of its offset activation is included in the note. Notes lasting
    2 frames or fewer are discarded.

    Args:
        onset_pred, frame_pred, offset_pred: 2-D arrays indexed as
            ``[frame, pitch]`` holding activations (probabilities).
        velocity_pred: Optional array with the same indexing. The note
            velocity is its mean over at most the first 5 frames of the note.
        t_onset, t_frame, t_offset: Decision thresholds.
        hop_length: Samples per frame; defaults to the module-level
            ``HOP_LENGTH``.
        sr: Sample rate; defaults to the module-level ``SR``.

    Returns:
        List of ``[onset_time, offset_time, midi_pitch, velocity]`` entries,
        times in seconds and ``midi_pitch = column index + 21``. Velocity is
        0 when ``velocity_pred`` is None.
    """
    if hop_length is None:
        hop_length = HOP_LENGTH
    if sr is None:
        sr = SR

    n_frames = frame_pred.shape[0]
    n_pitches = onset_pred.shape[1]
    notes = []

    for pitch in range(n_pitches):
        # 1. Peak picking. find_peaks never reports the first or last sample
        #    (a peak needs two neighbours), so an onset on the first frame of
        #    a segment would be lost. Padding both ends with -inf makes the
        #    boundary frames eligible; the indices are shifted back afterwards.
        activation = np.asarray(onset_pred[:, pitch], dtype=float)
        padded = np.concatenate(([-np.inf], activation, [-np.inf]))
        peaks, _ = find_peaks(padded, height=t_onset, distance=2)

        for onset_frame in (int(p) - 1 for p in peaks):
            # 2. Sustain check: the frame head must be active right after the onset.
            check_frame = min(onset_frame + 1, n_frames - 1)
            if frame_pred[check_frame, pitch] < t_frame:
                continue

            # 3. Walk forward until the frame activation drops or an offset fires.
            end_frame = onset_frame + 1
            while end_frame < n_frames:
                frame_active = frame_pred[end_frame, pitch] > t_frame
                offset_hit = offset_pred[end_frame, pitch] > t_offset
                if not frame_active or offset_hit:
                    break
                end_frame += 1

            # If the walk stopped on an offset, that frame still belongs to the note.
            if end_frame < n_frames and offset_pred[end_frame, pitch] > t_offset:
                end_frame += 1

            # 4. Minimum duration filter.
            if end_frame - onset_frame <= 2:
                continue

            onset_time = onset_frame * hop_length / sr
            offset_time = end_frame * hop_length / sr

            # 5. Velocity: average over the start of the note.
            vel = 0.0
            if velocity_pred is not None:
                vel_seg = velocity_pred[onset_frame:min(end_frame, onset_frame + 5), pitch]
                if len(vel_seg) > 0:
                    vel = float(np.mean(vel_seg))

            notes.append([onset_time, offset_time, pitch + 21, vel])
    return notes

def _midi_to_hz(midi_pitches):
    """Convert MIDI note numbers to frequencies in Hz (MIDI 69 = 440 Hz)."""
    midi = np.asarray(midi_pitches, dtype=float)
    return 440.0 * np.power(2.0, (midi - 69.0) / 12.0)


def compute_metrics_standard(ref_notes_batch, est_notes_batch):
    """Compute note-level onset F1, precision and recall over a batch.

    Each element of the two batches is a list of notes whose first three
    columns are ``onset_time, offset_time, midi_pitch`` (extra columns are
    ignored). Notes are matched per item with
    ``mir_eval.transcription.match_notes`` using a 50 ms onset tolerance and
    ignoring offsets; counts are accumulated over the whole batch.

    Returns:
        Tuple ``(f1, precision, recall)`` of floats; all zero when nothing
        matched.
    """
    total_tp, total_fp, total_fn = 0, 0, 0
    for ref_notes, est_notes in zip(ref_notes_batch, est_notes_batch):
        ref_arr = np.array(ref_notes)
        est_arr = np.array(est_notes)

        # Items with an empty side cannot be passed to the matcher.
        if len(ref_arr) == 0 and len(est_arr) == 0:
            continue
        if len(ref_arr) == 0:
            total_fp += len(est_arr)
            continue
        if len(est_arr) == 0:
            total_fn += len(ref_arr)
            continue

        # Columns 0 and 1 are times, column 2 is the MIDI pitch.
        ref_int, est_int = ref_arr[:, :2], est_arr[:, :2]
        # match_notes expects pitches in Hz and compares them in cents. Raw
        # MIDI numbers would put neighbouring semitones (from about MIDI 35
        # upwards) within its 50-cent tolerance and count them as matches.
        ref_p = _midi_to_hz(ref_arr[:, 2])
        est_p = _midi_to_hz(est_arr[:, 2])

        matched = mir_eval.transcription.match_notes(
            ref_int, ref_p, est_int, est_p, onset_tolerance=0.05, offset_ratio=None
        )
        tp = len(matched)
        total_tp += tp
        total_fp += (len(est_p) - tp)
        total_fn += (len(ref_p) - tp)

    # The small epsilon keeps the ratios defined when a denominator is zero.
    p = total_tp / (total_tp + total_fp + 1e-8)
    r = total_tp / (total_tp + total_fn + 1e-8)
    f1 = 2 * p * r / (p + r + 1e-8)
    return f1, p, r

def plot_training_history(csv_path="training_log_kaggle.csv"):
    """Render the training log as a 2x2 panel and save it as a PNG.

    Reads ``csv_path`` and plots train/val loss, onset F1/precision/recall,
    frame and offset F1, and velocity MSE against the epoch. The image is
    written to ``training_results_kaggle.png``. Does nothing when the CSV
    does not exist.

    The figure is always closed before returning: this function is called
    after every validation pass, and figures left open would pile up in
    memory.
    """
    if not os.path.exists(csv_path):
        return
    df = pd.read_csv(csv_path)

    sns.set_theme(style="whitegrid", context="paper")
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    try:
        fig.suptitle('HPPNet Optimized (Residual + PeakPicking) Metrics', fontsize=16)

        loss_ax = axes[0, 0]
        sns.lineplot(data=df, x='epoch', y='train_loss', label='Train', ax=loss_ax)
        sns.lineplot(data=df, x='epoch', y='val_loss', label='Val', ax=loss_ax, linestyle='--')
        loss_ax.set_title('Loss')

        onset_ax = axes[0, 1]
        sns.lineplot(data=df, x='epoch', y='onset_f1', label='F1', ax=onset_ax, color='g')
        sns.lineplot(data=df, x='epoch', y='onset_p', label='Precision', ax=onset_ax, linestyle=':', alpha=0.6)
        sns.lineplot(data=df, x='epoch', y='onset_r', label='Recall', ax=onset_ax, linestyle=':', alpha=0.6)
        onset_ax.set_title('Onset Metrics')

        frame_ax = axes[1, 0]
        sns.lineplot(data=df, x='epoch', y='frame_f1', label='Frame F1', ax=frame_ax)
        sns.lineplot(data=df, x='epoch', y='offset_f1', label='Offset F1', ax=frame_ax, color='orange')
        frame_ax.set_title('Frame & Offset F1')

        vel_ax = axes[1, 1]
        sns.lineplot(data=df, x='epoch', y='velocity_mse', color='purple', ax=vel_ax)
        vel_ax.set_title('Velocity MSE')

        fig.tight_layout()
        fig.savefig("training_results_kaggle.png", dpi=200)
        print("üìä Gr√°ficas guardadas.")
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

# ==========================================
# 4. MAIN (ORIGINAL STRUCTURE, OPTIMIZED)
# ==========================================
# Training entry point: builds the datasets and model, trains for FINAL_EPOCHS
# epochs, validates every third epoch (and on the last one), logs every epoch
# to training_log_kaggle.csv and keeps the weights with the best onset F1.
if __name__ == "__main__":
    print(f"\nüöÄ HPPNET-SP FINAL KAGGLE Training ({DEVICE})")
    print(f"üîπ Config: Batch Size {BATCH_SIZE} | T4x2 (DataParallel) | Partial Val | Real-time Plotting")
    
    # 1. Load data. Both datasets shuffle the file list with the same seed, so
    #    the train and validation file sets do not overlap.
    train_ds = PianoDataset(DATA_PATH, split='train')
    val_ds = PianoDataset(DATA_PATH, split='val')
    
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    
    # 2. Build the model
    model = HPPNet(in_channels=4, lstm_hidden=128).to(DEVICE)   # previously 3 input channels, now 4
    
    # --- CHANGE: enable multi-GPU ---
    # DataParallel splits each batch across the visible GPUs
    # (e.g. 16 + 16 samples for a batch of 32 on two cards).
    if torch.cuda.device_count() > 1:
        print(f"üî• ¬°Activando Turbo! Usando {torch.cuda.device_count()} GPUs en DataParallel")
        model = nn.DataParallel(model)
    # ---------------------------------

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # mode='max': the scheduler is stepped with the onset F1, which should grow.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=FACTOR_LR, patience=PATIENCE_LR)
    # Gradient scaler for mixed-precision training.
    scaler = torch.amp.GradScaler('cuda') 
    
    # 3. Losses: focal loss for onsets, positively-weighted BCE for frames and
    #    offsets, element-wise MSE for velocity (masked and reduced by hand below).
    crit_onset = FocalLoss(alpha=0.75, gamma=2.0).to(DEVICE)
    crit_frame = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]).to(DEVICE))
    crit_offset = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0]).to(DEVICE))
    crit_vel = nn.MSELoss(reduction='none')

    # 4. Logs: one CSV row per epoch; flushed after each write so the file is
    #    usable while training is still running.
    log_file = open("training_log_kaggle.csv", "w")
    header = "epoch,train_loss,val_loss,onset_f1,onset_p,onset_r,frame_f1,frame_p,frame_r,offset_f1,offset_p,offset_r,velocity_mse,lr\n"
    log_file.write(header)
    log_file.flush()
    
    # Best validation onset F1 seen so far; decides when the checkpoint is overwritten.
    best_f1 = 0.0

    try:
        for epoch in range(FINAL_EPOCHS):
            model.train()
            t_loss = 0
            
            # --- TRAIN ---
            with tqdm(train_loader, desc=f"Ep {epoch+1}/{FINAL_EPOCHS}", leave=False) as bar:
                for batch in bar:
                    hcqt = batch['hcqt'].to(DEVICE)
                    # Everything except the input is a target tensor.
                    targets = {k: v.to(DEVICE) for k, v in batch.items() if k != 'hcqt'}
                    
                    optimizer.zero_grad()
                    # Forward pass and losses run under automatic mixed precision.
                    with torch.amp.autocast('cuda'):
                        p_on, p_fr, p_off, p_vel = model(hcqt)
                        
                        l_on = crit_onset(p_on, targets['onset'])
                        l_fr = crit_frame(p_fr, targets['frame'])
                        l_off = crit_offset(p_off, targets['offset'])
                        
                        # Velocity error only counts where the frame target is
                        # active; the epsilon guards against an all-zero mask.
                        mask = targets['frame']
                        l_vel = (crit_vel(torch.sigmoid(p_vel), targets['velocity']) * mask).sum() / (mask.sum() + 1e-6)
                        
                        # The onset term is weighted 10x relative to the others.
                        loss = (10.0 * l_on) + l_fr + l_off + l_vel
                    
                    scaler.scale(loss).backward()
                    # Unscale before clipping so the norm threshold applies to
                    # the true gradient values.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    t_loss += loss.item()
                    bar.set_postfix(loss=loss.item())
            
            # Mean training loss per batch.
            avg_t_loss = t_loss / len(train_loader)

            # --- STRATEGY: validate every 3 epochs, and always on the last one ---
            should_validate = ((epoch + 1) % 3 == 0) or ((epoch + 1) == FINAL_EPOCHS)

            if should_validate:
                # --- VAL ---
                model.eval()
                v_loss = 0
                ref_all, est_all = [], []      # per-sample reference / estimated note lists
                fr_preds, fr_targs = [], []    # flattened binary frame predictions / targets
                off_preds, off_targs = [], []  # flattened binary offset predictions / targets
                vel_accum = 0; vel_count = 0   # running sum of squared error and element count
                
                with torch.no_grad():
                    for batch in val_loader:
                        hcqt = batch['hcqt'].to(DEVICE)
                        targets = {k: v.to(DEVICE) for k, v in batch.items() if k != 'hcqt'}
                        p_on, p_fr, p_off, p_vel = model(hcqt)
                        
                        # Same loss composition as in training.
                        l_on = crit_onset(p_on, targets['onset'])
                        l_fr = crit_frame(p_fr, targets['frame'])
                        l_off = crit_offset(p_off, targets['offset'])
                        mask = targets['frame']
                        l_vel = (crit_vel(torch.sigmoid(p_vel), targets['velocity']) * mask).sum() / (mask.sum() + 1e-6)
                        
                        v_loss += ((10.0 * l_on) + l_fr + l_off + l_vel).item()
                        
                        # Logits -> probabilities for decoding and thresholding.
                        pr_on = torch.sigmoid(p_on)
                        pr_fr = torch.sigmoid(p_fr)
                        pr_off = torch.sigmoid(p_off)
                        
                        # Note decoding, one sample at a time
                        for i in range(len(hcqt)):
                            v_map = torch.sigmoid(p_vel[i]).cpu().numpy()
                            est = tensor_to_notes(pr_on[i].cpu().numpy(), pr_fr[i].cpu().numpy(),pr_off[i].cpu().numpy(), v_map, t_onset=THRESHOLD_ONSET, t_frame=THRESHOLD_FRAME,t_offset=THRESHOLD_OFFSET)
                            
                            # Reference notes from the targets: every frame whose
                            # onset target is above 0.5 starts a note that lasts
                            # while the frame target stays above 0.5; notes of a
                            # single frame are dropped.
                            ref = []
                            ref_on = targets['onset'][i].cpu().numpy()
                            ref_fr = targets['frame'][i].cpu().numpy()
                            for pitch in range(88):
                                ons = np.where(ref_on[:, pitch] > 0.5)[0]
                                for o in ons:
                                    e = o + 1
                                    while e < ref_fr.shape[0] and ref_fr[e, pitch] > 0.5: e += 1
                                    if e - o > 1: ref.append([o*HOP_LENGTH/SR, e*HOP_LENGTH/SR, pitch+21])
                            est_all.append(est)
                            ref_all.append(ref)
                        
                        # Element-wise (per frame and pitch) binarisation
                        fr_preds.append((pr_fr > THRESHOLD_FRAME).cpu().numpy().flatten())
                        fr_targs.append((targets['frame'] > 0.5).cpu().numpy().flatten())
                        
                        off_preds.append((pr_off > THRESHOLD_OFFSET).cpu().numpy().flatten())
                        off_targs.append((targets['offset'] > 0.5).cpu().numpy().flatten())
                        
                        # Velocity MSE restricted to active frame targets; each
                        # batch mean is re-weighted by its element count so the
                        # final value is a global mean.
                        v_p = torch.sigmoid(p_vel).cpu().numpy().flatten()
                        v_t = targets['velocity'].cpu().numpy().flatten()
                        m = mask.cpu().numpy().flatten().astype(bool)
                        if m.sum() > 0:
                            vel_accum += mean_squared_error(v_t[m], v_p[m]) * m.sum()
                            vel_count += m.sum()

                # --- METRICS ---
                avg_v_loss = v_loss / len(val_loader)
                
                # Note-level onset metrics.
                onset_f1, onset_p, onset_r = compute_metrics_standard(ref_all, est_all)
                
                # Frame metrics over every (frame, pitch) element of the validation set.
                f_p = np.concatenate(fr_preds); f_t = np.concatenate(fr_targs)
                frame_f1 = f1_score(f_t, f_p, zero_division=0)
                frame_p = precision_score(f_t, f_p, zero_division=0)
                frame_r = recall_score(f_t, f_p, zero_division=0)
                
                # Offset metrics, computed the same way.
                o_p = np.concatenate(off_preds); o_t = np.concatenate(off_targs)
                offset_f1 = f1_score(o_t, o_p, zero_division=0)
                offset_p = precision_score(o_t, o_p, zero_division=0)
                offset_r = recall_score(o_t, o_p, zero_division=0)
                
                vel_mse = vel_accum / (vel_count + 1e-8)
                curr_lr = optimizer.param_groups[0]['lr']

                print("-" * 80)
                print(f"üèÅ Epoch {epoch+1} Results:")
                print(f"   üìâ Loss    : Train={avg_t_loss:.4f} | Val={avg_v_loss:.4f}")
                print(f"   üéπ Onset   : F1={onset_f1:.4f} | P={onset_p:.4f} | R={onset_r:.4f}")
                print(f"   üñºÔ∏è Frame   : F1={frame_f1:.4f} | P={frame_p:.4f} | R={frame_r:.4f}")
                # Offset metrics (previously missing from this report)
                print(f"   üèÅ Offset  : F1={offset_f1:.4f} | P={offset_p:.4f} | R={offset_r:.4f}")
                print(f"   ‚ö° Vel     : MSE={vel_mse:.4f}")
                print(f"   üß† LR      : {curr_lr:.2e}")
                print("-" * 80)
                
                log_line = f"{epoch+1},{avg_t_loss},{avg_v_loss},{onset_f1},{onset_p},{onset_r},{frame_f1},{frame_p},{frame_r},{offset_f1},{offset_p},{offset_r},{vel_mse},{curr_lr}\n"
                
                # The scheduler only advances on validation epochs, so its
                # patience is counted in validations, not in epochs.
                scheduler.step(onset_f1)
                
                if onset_f1 > best_f1:
                    best_f1 = onset_f1
                    # Save the wrapped module's weights so the checkpoint loads
                    # without DataParallel's "module." key prefix.
                    state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
                    torch.save(state_dict, "best_hppnet_kaggle.pth")
                    print(f"   üíæ Nuevo R√©cord! Modelo guardado (F1: {best_f1:.4f})")

                # --- UPDATE THE PLOT DURING TRAINING ---
                log_file.write(log_line)
                log_file.flush()
                # Re-render the PNG every time a validation pass finishes.
                plot_training_history("training_log_kaggle.csv")
                print("   üìä Gr√°fica actualizada.")

            else:
                # --- NO VALIDATION (saves time) ---
                print(f"‚è© Epoch {epoch+1}: Train Loss={avg_t_loss:.4f} (Validaci√≥n saltada)")
                # Metric columns are left empty so the CSV keeps its column layout.
                log_line = f"{epoch+1},{avg_t_loss},,,,,,,,,,,,{optimizer.param_groups[0]['lr']}\n"
                log_file.write(log_line)
                log_file.flush()
                
                # Simple backup of the current weights (written only on epochs
                # without validation).
                state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
                torch.save(state_dict, "latest_checkpoint.pth")

    except KeyboardInterrupt:
        # Manual interruption: fall through to the cleanup below.
        print("\nüõë Entrenamiento detenido.")
        
    finally:
        log_file.close()
        # Make sure the final plot reflects everything that was logged.
        plot_training_history("training_log_kaggle.csv")