In [1]:
# Torch imports
import torch
import torchaudio
from torch import nn
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchvision.models as models
from pytorch_tcn import TCN

# Metrics and visualization
import time
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import seaborn as sns

# Models and feature extractor
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# Others
import os
from tqdm import tqdm
import numpy as np
from datetime import datetime
import glob

  from .autonotebook import tqdm as notebook_tqdm


## Datasets

In [2]:
class CustomSpeechCommands(Dataset):
    """
    Speech Commands subset restricted to the files listed in a split file,
    with helpers to pre-compute and save per-utterance features.
    """

    def __init__(self, root, files_list, download=True, target_len=16000, mode="mfcc", cnn_model=None):
        """
        root: directory passed to torchaudio.datasets.SPEECHCOMMANDS.
        files_list: text file with one path per line, relative to
            <root>/SpeechCommands/speech_commands_v0.02.
        download: forwarded to SPEECHCOMMANDS.
        target_len: length (in samples) every waveform is padded/truncated to.
        mode: 'mfcc', 'mfcc_delta', 'mfcc_delta_delta', 'cnn', 'wav2vec2'
        cnn_model: pretrained or custom CNN used for extraction when mode == 'cnn'.
        """
        self.target_len = target_len
        self.mode = mode
        self.cnn_model = cnn_model
        self.dataset = torchaudio.datasets.SPEECHCOMMANDS(root=root, download=download)
        self.indices = None
        self.splitter(files_list, root)

    def splitter(self, files_list, root):
        """Fill self.indices with the dataset indices whose relative path appears in files_list."""
        with open(files_list, 'r') as f:
            self.file_paths = [line.strip() for line in f.readlines()]

        base_dir = os.path.join(root, "SpeechCommands", "speech_commands_v0.02")
        self.all_paths = []
        for item in tqdm(self.dataset._walker, desc=f"Splitting {files_list}"):
            # Normalise Windows separators so paths compare equal to the list file entries.
            relative_path = os.path.relpath(item, start=base_dir).replace("\\", "/")
            self.all_paths.append(relative_path)

        self.indices = self._match_indices(self.all_paths, self.file_paths)
        print(f"Archivos encontrados: {len(self.indices)} / {len(self.file_paths)}")

    @staticmethod
    def _match_indices(all_paths, wanted_paths):
        """Return the positions in all_paths whose value is in wanted_paths, in all_paths order."""
        # Set membership is O(1); scanning the list per path was O(len(wanted_paths)) each,
        # i.e. quadratic over the ~100k-file dataset.
        wanted = set(wanted_paths)
        return [i for i, path in enumerate(all_paths) if path in wanted]

    def pad_waveform(self, waveform):
        """
        Right-pad with zeros or truncate the last axis to self.target_len.

        waveform: assumed (channels, samples) — truncation indexes the second axis.
        """
        length = waveform.shape[-1]
        if length < self.target_len:
            waveform = F.pad(waveform, (0, self.target_len - length))
        elif length > self.target_len:
            waveform = waveform[:, :self.target_len]
        return waveform

    def extract_feature_single(self, waveform, sample_rate, feature_extractor=None, processor=None, device="cuda"):
        """
        Extract features for ONE sample according to self.mode.

        waveform: assumed (1, samples) as returned by SPEECHCOMMANDS — TODO confirm for other sources.
        feature_extractor: transform/model used by the MFCC and wav2vec2 modes; moved to `device`.
        processor: Wav2Vec2 processor (wav2vec2 mode only).
        Returns a CPU tensor. For the MFCC modes the result is transposed to (frames, coefficients).
        Raises ValueError for an unknown mode.
        """
        waveform = self.pad_waveform(waveform).to(device)

        if feature_extractor is not None:
            feature_extractor = feature_extractor.to(device)

        # --- MFCC ---
        if self.mode == "mfcc":
            feat = feature_extractor(waveform).squeeze(0).cpu().transpose(0, 1)

        # --- MFCC + delta ---
        elif self.mode == "mfcc_delta":
            base = feature_extractor(waveform)
            delta = torchaudio.functional.compute_deltas(base)
            feat = torch.cat([base, delta], dim=1).squeeze(0).cpu().transpose(0, 1)

        # --- MFCC + delta + delta-delta ---
        elif self.mode == "mfcc_delta_delta":
            base = feature_extractor(waveform)
            delta = torchaudio.functional.compute_deltas(base)
            delta2 = torchaudio.functional.compute_deltas(delta)
            feat = torch.cat([base, delta, delta2], dim=1).squeeze(0).cpu().transpose(0, 1)

        # --- CNN embedding over a mel spectrogram ---
        elif self.mode == "cnn":
            spec_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate).to(device)
            spec = spec_transform(waveform).unsqueeze(0)
            with torch.no_grad():
                embedding = self.cnn_model(spec.to(device)).cpu().squeeze()
            feat = embedding

        # --- Wav2Vec2 last hidden state ---
        elif self.mode == "wav2vec2":
            waveform = waveform.squeeze(0)
            inputs = processor(
                waveform,
                sampling_rate=sample_rate,
                return_tensors="pt",
                padding=True
            ).to(device)
            with torch.no_grad():
                outputs = feature_extractor(**inputs)
            feat = outputs.last_hidden_state.squeeze(0).cpu()

        else:
            raise ValueError(f"Modo de extracción '{self.mode}' no soportado.")

        return feat

    def extract_features(self, feature_extractor=None, processor=None, device="cuda"):
        """Extract features for every selected sample; returns (stacked tensor, list of labels)."""
        features, labels = [], []

        with torch.no_grad():
            for idx in tqdm(self.indices, desc=f"Extrayendo features ({self.mode})"):
                waveform, sample_rate, label, _, _ = self.dataset[idx]
                feat = self.extract_feature_single(
                    waveform, sample_rate, feature_extractor, processor, device
                )
                features.append(feat)
                labels.append(label)

        # torch.stack needs equal shapes; pad_waveform makes every input the same length.
        features = torch.stack(features)
        print(f"Features tensor: {features.shape}")

        return features, labels

    def save_features(self, feature_extractor=None, save_path=None, processor=None, device="cuda"):
        """Extract and torch.save {'features', 'labels'} to save_path; errors are printed, not raised."""
        print(f"Guardando features ({self.mode}) en {save_path}")
        try:
            features, labels = self.extract_features(feature_extractor, processor, device)
            torch.save({"features": features, "labels": labels}, save_path)
            print(f"Features guardadas correctamente en {save_path}")
            print(f"Clases finales: {set(labels)}")
        except Exception as e:
            print(f"Error al guardar features en {save_path}: {e}")

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return (padded waveform, sample_rate, label, speaker_id, utterance_number)."""
        original_idx = self.indices[idx]
        waveform, sample_rate, label, speaker_id, utterance_number = self.dataset[original_idx]
        waveform = self.pad_waveform(waveform)
        return waveform, sample_rate, label, speaker_id, utterance_number

class FeaturesDataset(Dataset):
    def __init__(self, features_path, label_to_idx=None):
        """
        Load a .pt file with previously saved 'features' and 'labels'.

        features_path: path to the .pt file (e.g. 'data/train.pt').
        label_to_idx: optional label -> class-index mapping. Pass the training
            split's `label_to_idx` to the val/test datasets so every split uses
            the same encoding even if one split is missing a class. When None,
            the mapping is built from this file's sorted unique labels.

        Raises ValueError if a provided mapping lacks any label found in the file.
        """
        data = torch.load(features_path)
        self.features = data["features"]
        self.labels = data["labels"]

        if label_to_idx is None:
            label_to_idx = {label: i for i, label in enumerate(sorted(set(self.labels)))}
        else:
            unknown = set(self.labels) - set(label_to_idx)
            if unknown:
                raise ValueError(f"Labels not present in label_to_idx: {sorted(unknown)}")
            # Copy so later edits to the caller's dict do not alter this dataset.
            label_to_idx = dict(label_to_idx)

        self.label_to_idx = label_to_idx
        self.idx_to_label = {v: k for k, v in self.label_to_idx.items()}
        self.numeric_labels = torch.tensor([self.label_to_idx[l] for l in self.labels])

        print(f"Dataset cargado desde {features_path}")
        print(f" - {len(self.features)} ejemplos")
        print(f" - {len(self.label_to_idx)} clases")

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        """Return (feature tensor, integer class index tensor)."""
        feature = self.features[idx]
        label = self.numeric_labels[idx]
        return feature, label


## Models

In [3]:
class RNNModel(nn.Module):
    def __init__(
        self,
        rnn_type,
        n_input_channels,
        hidd_size=256,
        out_features = 35,
        num_layers=1,
    ):
        """
        Recurrent classifier: runs a batch-first recurrent layer over the
        sequence and classifies the last entry of its final hidden state.

        rnn_type: "RNN" (vanilla, built bidirectional), "LSTM" or "GRU".
        n_input_channels: features per time step.
        hidd_size: hidden size of the recurrent layer.
        out_features: number of output classes.
        num_layers: number of stacked recurrent layers.
        """
        super().__init__()

        self.rnn_type = rnn_type

        recurrent_classes = {"GRU": nn.GRU, "LSTM": nn.LSTM, "RNN": nn.RNN}
        if rnn_type not in recurrent_classes:
            raise ValueError(f"rnn_type {rnn_type} not supported.")

        # Only the vanilla RNN is built bidirectional; GRU and LSTM are unidirectional.
        extra_kwargs = {"bidirectional": True} if rnn_type == "RNN" else {}
        self.rnn_layer = recurrent_classes[rnn_type](
            n_input_channels,
            hidd_size,
            batch_first=True,
            num_layers=num_layers,
            **extra_kwargs,
        )

        self.net = nn.Sequential(nn.Linear(hidd_size, out_features))

        # Registered but not used by forward().
        self.flatten_layer = nn.Flatten()

    def forward(self, x):
        """x: (batch, seq_len, n_input_channels) -> logits (batch, out_features)."""
        _, hidden = self.rnn_layer(x)
        if self.rnn_type == "LSTM":
            # LSTM returns (h_n, c_n); only h_n is classified.
            hidden, _cell = hidden
        # hidden[-1] is the last layer's final state; for the bidirectional
        # vanilla RNN that is the backward direction.
        return self.net(hidden[-1])

class CNNModel(nn.Module):
    def __init__(self, n_input_channels, hidd_size=64, out_features=35):
        """
        1-D convolutional classifier over a (batch, time, channels) sequence.

        Args:
            n_input_channels (int): input channels (e.g. 13 for MFCC).
            hidd_size (int): base channel count; the three blocks use 1x, 2x and 4x this.
            out_features (int): number of output classes (e.g. 35).
        """
        super().__init__()

        def conv_stage(c_in, c_out, kernel, pool):
            # Conv -> BatchNorm -> ReLU, optionally followed by a 2x temporal max-pool.
            layers = [
                nn.Conv1d(c_in, c_out, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]
            if pool:
                layers.append(nn.MaxPool1d(kernel_size=2))
            return nn.Sequential(*layers)

        widths = (hidd_size, hidd_size * 2, hidd_size * 4)

        # (B, C_in, T) -> (B, w0, T/2)
        self.conv_block1 = conv_stage(n_input_channels, widths[0], 5, pool=True)
        # (B, w0, T/2) -> (B, w1, T/4)
        self.conv_block2 = conv_stage(widths[0], widths[1], 3, pool=True)
        # (B, w1, T/4) -> (B, w2, T/4)
        self.conv_block3 = conv_stage(widths[1], widths[2], 3, pool=False)

        # Collapse the time axis: (B, w2, T/4) -> (B, w2, 1) -> (B, w2)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.flatten = nn.Flatten()

        # (B, w2) -> (B, out_features)
        self.fc = nn.Linear(widths[2], out_features)

    def forward(self, x):
        """x: (batch, time, channels) -> logits (batch, out_features)."""
        # Conv1d expects (batch, channels, time).
        features = x.transpose(1, 2)
        for block in (self.conv_block1, self.conv_block2, self.conv_block3):
            features = block(features)
        return self.fc(self.flatten(self.global_pool(features)))

class PositionalEncoding(nn.Module):
    """
    Implementa el Positional Encoding para añadir información de posición.
    """
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Matriz de Positional Encoding
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [BatchSize, SeqLen, EmbeddingDim]
        """
        # x.shape[1] es la longitud de la secuencia (SeqLen)
        x = x + self.pe[:x.size(1)].transpose(0, 1) # Transpose para hacer Broadcasting [1, SeqLen, EmbDim]
        return self.dropout(x)


class TransformerModel(nn.Module):
    def __init__(
        self,
        n_input_features: int,  # e.g. 13 for MFCCs
        n_output_classes: int = 35,
        d_model: int = 128,      # width of the Transformer representation
        nhead: int = 8,          # number of attention heads
        d_hid: int = 256,        # inner dimension of the feed-forward sublayer
        n_layers: int = 6,       # number of encoder blocks
        dropout: float = 0.1
    ):
        """
        Transformer-encoder classifier over a (batch, seq, features) sequence.

        Frames are linearly projected to d_model, scaled by sqrt(d_model), given
        sinusoidal positional encodings, encoded, and the encoder output at the
        first frame is classified.
        """
        super().__init__()

        self.model_type = 'Transformer'
        self.d_model = d_model

        # Project raw per-frame features to the model width.
        self.input_projection = nn.Linear(n_input_features, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        block = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_hid,
            dropout=dropout,
            batch_first=True,  # inputs are [Batch, Seq, Feature]
        )
        self.transformer_encoder = nn.TransformerEncoder(block, n_layers)

        # The first frame's encoder output acts as a [CLS]-like summary (no dedicated token is added).
        self.classifier = nn.Sequential(nn.Linear(d_model, n_output_classes))

        self.init_weights()

    def init_weights(self):
        """Zero the biases and draw U(-0.1, 0.1) weights for the input projection and the classifier."""
        bound = 0.1
        for layer in (self.input_projection, self.classifier[0]):
            nn.init.zeros_(layer.bias)
            nn.init.uniform_(layer.weight, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [BatchSize, SeqLen, n_input_features]
        Returns:
            logits, shape [BatchSize, n_output_classes]
        """
        scaled = self.input_projection(x) * np.sqrt(self.d_model)
        encoded = self.transformer_encoder(self.pos_encoder(scaled))  # [BatchSize, SeqLen, d_model]
        return self.classifier(encoded[:, 0, :])

# %%


class CNN1DModel(nn.Module):
    def __init__(
        self,
        hidd_size=256,
        in_channels = 13,
        out_channels = 64,
        N_conv_blocks = 2,
    ):
        """
        1 to 3 Conv1d -> ReLU -> MaxPool(2) blocks followed by an RNNModel("RNN") classifier.

        hidd_size: hidden size of the recurrent classifier.
        in_channels: features per input time step.
        out_channels: channel count of every convolution.
        N_conv_blocks: number of conv blocks (1, 2 or 3); each halves the time axis.
        """
        super().__init__()
        if N_conv_blocks not in (1, 2, 3):
            raise ValueError('Choose valid number (1-3)')

        layers = []
        channels = in_channels
        for _ in range(int(N_conv_blocks)):
            layers += [
                nn.Conv1d(channels, out_channels, kernel_size = 3, padding = 'same'),
                nn.ReLU(),
                nn.MaxPool1d(2),
            ]
            channels = out_channels
        self.conv_blocks = nn.Sequential(*layers)

        self.rnn_layer = RNNModel(
            n_input_channels=out_channels,
            rnn_type="RNN",
            hidd_size=hidd_size
        )

    def forward(self, x):
        """x: (batch, time, in_channels) -> logits from the recurrent classifier."""
        # Conv1d works channels-first; the RNN works batch/time-first.
        conv_out = self.conv_blocks(x.permute(0, 2, 1))
        return self.rnn_layer(conv_out.permute(0, 2, 1))

class MejorCNN1DModel(nn.Module):
    def __init__(
        self,
        hidd_size=256,
        in_channels=13,   # e.g. 13 MFCCs
        num_classes=35    # Speech Commands classes
    ):
        """
        Three Conv1d/BatchNorm/ReLU/MaxPool blocks (64 -> 128 -> 256 channels,
        kernels 7/5/3) followed by a GRU-based RNNModel classifier.

        hidd_size: hidden size of the GRU.
        in_channels: features per input time step.
        num_classes: number of output classes.
        """
        super().__init__()

        # --- CNN blocks ---
        # Wider channels, BatchNorm, and larger kernels than CNN1DModel.
        self.conv_blocks = nn.Sequential(
            # Block 1
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=64,
                kernel_size=7,
                padding='same'
            ),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(0.2),

            # Block 2
            nn.Conv1d(
                in_channels=64,
                out_channels=128,
                kernel_size=5,
                padding='same'
            ),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(0.2),

            # Block 3
            nn.Conv1d(
                in_channels=128,
                out_channels=256,
                kernel_size=3,
                padding='same'
            ),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(2)
            # Output sequence length is L // 8.
        )

        # --- Recurrent classifier ---
        # The RNN input width equals the last conv block's channel count.
        rnn_in_features = 256

        self.rnn_layer = RNNModel(
            n_input_channels=rnn_in_features,
            rnn_type="GRU",       # GRU/LSTM recommended over the vanilla "RNN"
            hidd_size=hidd_size,
            # RNNModel names its class-count parameter `out_features`;
            # passing `num_classes=` raised TypeError on construction.
            out_features=num_classes
        )

    def forward(self, x):
        """x: (batch, time, in_channels) -> logits (batch, num_classes)."""
        perm_x = x.permute(0, 2, 1)
        conv_out = self.conv_blocks(perm_x)
        deperm_x = conv_out.permute(0, 2, 1)

        return self.rnn_layer(deperm_x)

class SpeechCommandTCN(nn.Module):
    def __init__(self, num_inputs, num_classes, tcn_params):
        """
        TCN backbone, global average pooling over time, then a linear classifier.

        num_inputs: features per input time step.
        num_classes: number of output classes.
        tcn_params: extra keyword arguments for TCN; must contain 'num_channels',
            whose last entry is taken as the backbone's output width
            (e.g. [64, 64, 128, 128] -> 128).
        """
        super().__init__()
        self.tcn = TCN(num_inputs=num_inputs, **tcn_params)
        self.classifier = nn.Linear(tcn_params['num_channels'][-1], num_classes)

    def forward(self, x):
        """x: (batch, time, features) -> logits (batch, num_classes)."""
        # Assumes the TCN is in channels-first ('NCL') mode for both input and
        # output — TODO confirm if tcn_params overrides the input shape.
        channels_first = x.permute(0, 2, 1)
        pooled = self.tcn(channels_first).mean(dim=2)  # one vector per clip
        return self.classifier(pooled)
def save_model(model, path, config=None):
    """
    Save a PyTorch model's state_dict together with its constructor config.

    Parameters:
    - model: nn.Module instance
    - path: destination .pt file; its parent directory is created if needed
    - config: dict of constructor kwargs (stored as {} when None)
    """
    directory = os.path.dirname(path)
    # dirname is "" for a bare filename, and os.makedirs("") raises FileNotFoundError.
    if directory:
        os.makedirs(directory, exist_ok=True)

    checkpoint = {
        "state_dict": model.state_dict(),
        "config": config if config is not None else {}
    }

    torch.save(checkpoint, path)
    print(f"Modelo guardado en {path}")

def load_trained_model(model_class, checkpoint_path, device="cuda", config=None, **override_kwargs):
    """
    Load a checkpoint written by save_model, optionally with an external config.

    Config resolution:
        1. an explicit `config` replaces the checkpoint's config entirely;
        2. otherwise the checkpoint's "config" is used (missing or None counts as {});
        3. `override_kwargs` are applied last and override individual keys.

    Parameters:
    - model_class: model class, instantiated as model_class(**resolved_config)
    - checkpoint_path: path to the .pt file
    - device: "cpu" or "cuda"; used as map_location and as the final device
    - config: complete external kwargs dict for building the model
    - override_kwargs: individual kwargs to replace

    Returns the model on `device`, in eval mode.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if config is not None:
        final_config = dict(config)
    else:
        # A None config in the checkpoint used to leave final_config unbound (UnboundLocalError).
        final_config = dict(checkpoint.get("config") or {})

    final_config.update(override_kwargs)

    model = model_class(**final_config)

    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)
    model.eval()

    print(f"Modelo cargado desde {checkpoint_path}")
    return model


## Trainers

In [4]:
def train_step(x_batch, y_batch, model, optimizer, criterion, use_gpu):
    """
    Run one optimisation step on a single batch.

    `use_gpu` is accepted for interface compatibility but not used here
    (tensors are expected to already be on the right device).

    Returns:
        (logits, loss): the model output for the batch and the loss tensor.
    """
    logits = model(x_batch)
    batch_loss = criterion(logits, y_batch)

    # Clear stale gradients, backpropagate, then apply the parameter update.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    return logits, batch_loss


def evaluate(val_loader, model, criterion, use_gpu):
    """
    Compute accuracy and mean loss of `model` over `val_loader`.

    The model is put in eval mode for the pass, so BatchNorm uses its running
    statistics (and does not update them) and Dropout is disabled. Its previous
    train/eval mode is restored afterwards.

    Args:
        val_loader: iterable of (inputs, integer class targets) batches.
        model: classifier returning logits of shape (batch, n_classes).
        criterion: loss function; assumed to average over the batch, since
            loss.item() is re-weighted by the batch size here.
        use_gpu: move each batch to CUDA when True.

    Returns:
        (val_acc, val_loss), both averaged over all examples.
    """
    cumulative_loss = 0
    cumulative_predictions = 0
    data_count = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x_val, y_val in val_loader:
                if use_gpu:
                    x_val = x_val.cuda()
                    y_val = y_val.cuda()

                y_predicted = model(x_val)
                loss = criterion(y_predicted, y_val)
                class_prediction = torch.argmax(y_predicted, axis=1).long()

                cumulative_predictions += (y_val == class_prediction).sum().item()
                cumulative_loss += loss.item() * y_val.shape[0]
                data_count += y_val.shape[0]
    finally:
        model.train(was_training)

    val_acc = cumulative_predictions / data_count
    val_loss = cumulative_loss / data_count

    return val_acc, val_loss


def train_model(
    model,
    train_dataset,
    val_dataset,
    epochs,
    criterion,
    batch_size,
    lr,
    n_evaluations_per_epoch=6,
    use_gpu=False,
    patience=5,
    model_config = None,
    model_arch = "ola"
):
    """
    Train `model` with Adam and early stopping on the validation loss.

    Args:
        model: classifier to train; moved to CUDA when `use_gpu` and back to CPU at the end.
        train_dataset, val_dataset: datasets yielding (features, label) pairs.
        epochs: maximum number of epochs.
        criterion: loss function applied to (logits, labels).
        batch_size: DataLoader batch size for both splits.
        lr: Adam learning rate.
        n_evaluations_per_epoch: how many times per epoch the running train metrics are printed.
        use_gpu: train on CUDA.
        patience: consecutive epochs without val-loss improvement before stopping.
        model_config: constructor kwargs; when given, the final weights are saved under
            model_weights/ together with this config (so load_trained_model can rebuild them).
        model_arch: filename prefix for the saved checkpoint.

    Returns:
        (curves, total_time): per-epoch lists of train/val accuracy and loss, and the
        wall-clock training time in seconds.
    """
    if use_gpu:
        model.cuda()

    # Dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=use_gpu
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, pin_memory=use_gpu
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=1e-9)

    curves = {"train_acc": [], "val_acc": [], "train_loss": [], "val_loss": []}

    t0 = time.perf_counter()
    iteration = 0
    n_batches = len(train_loader)
    # With fewer batches than requested evaluations, n_batches // n_evaluations_per_epoch
    # is 0 and `i % 0` raised ZeroDivisionError; never log less often than every batch.
    log_every = max(1, n_batches // max(1, n_evaluations_per_epoch))

    best_val_loss = float("inf")
    epochs_without_improvement = 0

    print(n_batches)

    for epoch in range(epochs):
        print(f"\rEpoch {epoch + 1}/{epochs}")
        cumulative_train_loss = 0
        cumulative_train_corrects = 0
        examples_count = 0

        model.train()
        for i, (x_batch, y_batch) in enumerate(train_loader):
            if use_gpu:
                x_batch = x_batch.cuda()
                y_batch = y_batch.cuda()

            y_predicted, loss = train_step(x_batch, y_batch, model, optimizer, criterion, use_gpu)

            cumulative_train_loss += loss.item() * x_batch.shape[0]
            examples_count += y_batch.shape[0]

            class_prediction = torch.argmax(y_predicted, axis=1).long()
            cumulative_train_corrects += (y_batch == class_prediction).sum().item()

            if (i % log_every == 0) and (i > 0):
                train_loss = cumulative_train_loss / examples_count
                train_acc = cumulative_train_corrects / examples_count
                print(f"Iteration {iteration} - Batch {i}/{len(train_loader)} - Train loss: {train_loss}, Train acc: {train_acc}")

            iteration += 1

        # Validate in eval mode so BatchNorm uses running stats and Dropout is off;
        # model.train() at the top of the next epoch switches back.
        model.eval()
        with torch.no_grad():
            val_acc, val_loss = evaluate(val_loader, model, criterion, use_gpu)

        print(f"Val loss: {val_loss}, Val acc: {val_acc}")

        train_loss = cumulative_train_loss / examples_count
        train_acc = cumulative_train_corrects / examples_count

        curves["train_acc"].append(train_acc)
        curves["val_acc"].append(val_acc)
        curves["train_loss"].append(train_loss)
        curves["val_loss"].append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0

        else:
            epochs_without_improvement += 1
            print(f"Sin mejora. Paciencia: {epochs_without_improvement}/{patience}")

            if epochs_without_improvement >= patience:
                print("Early stopping activado!")
                break

    total_time = time.perf_counter() - t0
    print(f"Tiempo total de entrenamiento: {total_time:.4f} [s]")

    if model_config is not None:
        save_path = "model_weights"
        os.makedirs(save_path, exist_ok=True)
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        # Persist the config too; without it the checkpoint cannot be rebuilt by load_trained_model.
        save_model(model, os.path.join(save_path, f"{model_arch}_{timestamp}.pt"), config=model_config)
    model.cpu()
    return curves, total_time



def evaluate_models_metrics(models, dataloader, criterion, use_gpu=True, num_repeats=15):
    """
    Evaluate several models on the same dataloader and aggregate their metrics.

    For each model, this computes accuracy and macro recall/precision/F1 over the
    whole dataloader, plus the average per-batch inference time. Each batch is run
    `num_repeats` times. The first run is treated as warm-up and discarded, and
    the rest are averaged. Times are in seconds, measured with time.perf_counter.

    Args:
        models: iterable of classifiers (set to eval mode; moved to CUDA when use_gpu).
        dataloader: iterable of (inputs, integer targets) batches.
        criterion: loss function (computed per batch, not returned).
        use_gpu: run on CUDA, synchronising around each timed run.
        num_repeats: timed runs per batch; >= 2 gives warm-up-free timings.

    Returns:
        metrics_mean, metrics_std: mean/std across models for "accuracy", "recall",
            "precision", "f1" and "infer_time".
        all_metrics: per-model metric lists; "infer_time" holds one list of
            per-batch times per model.
        all_times: mean per-batch inference time of each model.
        all_f1_list: macro F1 of each model.
    """
    all_metrics = {
        "accuracy": [],
        "recall": [],
        "precision": [],
        "f1": [],
        "infer_time": []    # one list of per-batch mean times per model
    }

    all_times = []      # mean per-batch time, one value per model
    all_f1_list = []    # macro F1, one value per model

    for model in models:
        model.eval()
        if use_gpu:
            model.cuda()

        y_true = []
        y_pred = []
        losses = []
        infer_times_per_batch = []

        with torch.no_grad():
            for X, y in dataloader:
                if use_gpu:
                    X, y = X.cuda(non_blocking=True), y.cuda(non_blocking=True)
                    # Make sure the async host->device copies are done before timing starts.
                    torch.cuda.synchronize()

                outputs, avg_batch_time = _timed_inference(model, X, num_repeats, use_gpu)
                infer_times_per_batch.append(avg_batch_time)

                loss = criterion(outputs, y)
                losses.append(loss.item())

                preds = outputs.argmax(dim=1)
                y_true.extend(y.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())

        # --- Per-model metrics ---
        acc = accuracy_score(y_true, y_pred)
        rec = recall_score(y_true, y_pred, average='macro', zero_division='warn')
        prec = precision_score(y_true, y_pred, average='macro', zero_division='warn')
        f1 = f1_score(y_true, y_pred, average='macro', zero_division='warn')

        all_metrics["accuracy"].append(acc)
        all_metrics["recall"].append(rec)
        all_metrics["precision"].append(prec)
        all_metrics["f1"].append(f1)

        all_metrics["infer_time"].append(infer_times_per_batch)
        all_times.append(np.mean(infer_times_per_batch))
        all_f1_list.append(f1)

    # --- Aggregate across models ---
    metrics_mean = {k: np.mean(v) if k != "infer_time" else None for k, v in all_metrics.items()}
    metrics_std  = {k: np.std(v)  if k != "infer_time" else None for k, v in all_metrics.items()}

    # infer_time is aggregated from the per-model means.
    metrics_mean["infer_time"] = np.mean(all_times)
    metrics_std["infer_time"] = np.std(all_times)

    return metrics_mean, metrics_std, all_metrics, all_times, all_f1_list


def _timed_inference(model, X, num_repeats, use_gpu):
    """Run model(X) repeatedly; return (last outputs, mean seconds per run excluding the first)."""
    run_times = []
    outputs = None
    for r in range(num_repeats):
        # perf_counter is monotonic and high-resolution; time.time() can be too coarse for ms-scale batches.
        start = time.perf_counter()
        outputs = model(X)
        if use_gpu:
            torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
        elapsed = time.perf_counter() - start
        if r > 0:  # r == 0 is the warm-up run
            run_times.append(elapsed)

    if run_times:
        return outputs, np.mean(run_times)

    # num_repeats <= 1 leaves no post-warm-up sample, so time one extra run instead.
    start = time.perf_counter()
    outputs = model(X)
    if use_gpu:
        torch.cuda.synchronize()
    return outputs, time.perf_counter() - start



## Visualization

In [5]:
def plot_waveform(wf, sample_rate, label="", figname=None):
    """
    Plot an audio signal's waveform (left) and its 13 MFCCs (right).

    Parameters:
        wf (Tensor): audio signal, shape [1, N] or [N]
        sample_rate (int): sampling rate (Hz)
        label (str): optional label shown in the title
        figname (str): if given, the figure is saved to img/<figname>.pdf (the
            img/ directory is created if missing); otherwise it is shown.
    """
    if isinstance(wf, torch.Tensor):
        wf = wf.squeeze().cpu()

    # === MFCC transform ===
    mfcc_transform = torchaudio.transforms.MFCC(
        sample_rate=sample_rate,
        n_mfcc=13,
        melkwargs={"n_fft": 320, "hop_length": 160, "n_mels": 23},
        log_mels=True
    )
    mfcc = mfcc_transform(wf.unsqueeze(0)).squeeze().cpu().numpy()  # [n_mfcc, time]

    # Apply the seaborn style BEFORE creating the axes; styles only affect axes created afterwards.
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # --- Waveform ---
    # Named t_seconds so it does not shadow the `time` module.
    t_seconds = torch.arange(0, len(wf)) / sample_rate
    axes[0].plot(t_seconds, wf.numpy(), color="steelblue", linewidth=1.0)
    axes[0].set_title("Waveform", fontsize=12)
    axes[0].set_xlabel("Tiempo [s]")
    axes[0].set_ylabel("Amplitud")

    # --- MFCC ---
    sns.heatmap(mfcc, ax=axes[1], cmap="viridis", cbar=True)
    axes[1].set_title("MFCCs", fontsize=12)
    axes[1].set_xlabel("Tiempo (frames)")
    axes[1].set_ylabel("Coeficiente MFCC")

    fig.suptitle(f"Audio: {label}", fontsize=14, y=1.02)
    plt.tight_layout()

    # === Save or show ===
    if figname:
        os.makedirs('img', exist_ok=True)
        name = os.path.join('img', f'{figname}.pdf')
        plt.savefig(name, bbox_inches="tight")
        print(f"Figura guardada en {name}")
    else:
        plt.show()

    plt.close(fig)

 
def plot_f1_vs_inference_time_with_error_bars(arch_names, mean_times, f1_distributions, time_stds=None, name='',
                                              filename='f1_vs_inference_time_error_bars'):
    """
    Plot one point per architecture at (mean inference time, mean macro-F1),
    with ±1 std error bars on both axes, and save it to ``img/<filename>.pdf``.

    Parameters:
        arch_names (list[str]): legend label for each architecture.
        mean_times (list[float]): mean inference time per architecture; the
            x-axis label says "ms per batch", so callers are expected to pass ms.
        f1_distributions (list[array-like]): per-architecture F1 samples; their
            mean and std give the y position and y error bar.
        time_stds (list[float] | None): std of the inference time, in the same
            unit as ``mean_times``; ``None`` means no x error bars.
        name (str): optional suffix for the plot title.
        filename (str): output file name (without extension) inside ``img/``.
            Defaults to the previous hardcoded name, so existing calls keep
            writing the same file; pass a distinct value to avoid overwrites.

    Raises:
        ValueError: if ``mean_times``, ``f1_distributions`` or ``time_stds``
            has fewer entries than ``arch_names``.
    """
    # Without time std, draw no horizontal error bars.
    if time_stds is None:
        time_stds = [0 for _ in mean_times]

    n_archs = len(arch_names)
    if min(len(mean_times), len(f1_distributions), len(time_stds)) < n_archs:
        raise ValueError(
            "mean_times, f1_distributions and time_stds need one entry per architecture"
        )

    fig, ax = plt.subplots(figsize=(10, 6))

    for i, arch in enumerate(arch_names):
        f1_values = np.asarray(f1_distributions[i])
        ax.errorbar(
            mean_times[i],
            f1_values.mean(),
            xerr=time_stds[i],
            yerr=f1_values.std(),
            fmt='o',
            capsize=5,
            markersize=8,
            linestyle='None',
            label=arch
        )

    ax.set_xlabel("Tiempo de inferencia por batch (ms)")
    ax.set_ylabel("F1 (macro)")
    ax.set_title("Comparación arquitecturas: F1 vs Tiempo de inferencia" + (" - " + name if name else ""))
    ax.grid(alpha=0.3)
    ax.legend()

    os.makedirs('img', exist_ok=True)
    fig.savefig(os.path.join('img', f'{filename}.pdf'))
    plt.close(fig)
    

def show_curves(all_curves, suptitle='', filename = ''):
    """
    Plot mean ± std training/validation loss and accuracy over several runs.

    Parameters:
        all_curves (list[dict]): one dict per run with the keys 'train_loss',
            'val_loss', 'train_acc' and 'val_acc', each mapping to a sequence of
            recorded values. Runs may have different lengths (e.g. early
            stopping); every key is truncated to its shortest run before averaging.
        suptitle (str): figure title.
        filename (str): output name relative to ``img/`` without extension
            (may contain sub-directories); missing directories are created and
            a PDF is written.

    Raises:
        ValueError: if ``all_curves`` is empty.
    """
    if not all_curves:
        raise ValueError("show_curves requires at least one set of curves")

    keys = list(all_curves[0].keys())
    # Truncate each key to the shortest run so runs can be stacked and averaged.
    min_len = {k: min(len(c[k]) for c in all_curves) for k in keys}
    stacked = {k: np.array([c[k][:min_len[k]] for c in all_curves]) for k in keys}
    means = {k: stacked[k].mean(axis=0) for k in keys}
    stds = {k: stacked[k].std(axis=0) for k in keys}

    def plot_band(ax, key, label):
        """Plot the mean of `key` and a ±1 std band in the same color."""
        # Each key gets its own x axis, so curves of different length still plot.
        x = np.arange(len(means[key])) + 1
        line, = ax.plot(x, means[key], label=label)
        ax.fill_between(x,
                        y1=means[key] - stds[key],
                        y2=means[key] + stds[key],
                        color=line.get_color(), alpha=.5)

    fig, ax = plt.subplots(1, 2, figsize=(13, 5))
    fig.set_facecolor('white')

    # ==== Loss ====
    plot_band(ax[0], 'val_loss', 'validation')
    plot_band(ax[0], 'train_loss', 'training')
    ax[0].set_xlabel('Epoch')
    ax[0].set_ylabel('Loss')
    ax[0].set_title('Loss evolution during training')
    ax[0].legend()

    # ==== Accuracy ====
    plot_band(ax[1], 'val_acc', 'validation')
    plot_band(ax[1], 'train_acc', 'training')
    ax[1].set_xlabel('Epoch')
    ax[1].set_ylabel('Accuracy')
    ax[1].set_title('Accuracy evolution during training')
    ax[1].legend()

    fig.suptitle(suptitle, fontsize=16, weight="bold")

    filepath = os.path.join('img', f'{filename}.pdf')
    # `filename` may include a sub-directory (e.g. "<arch>/<name>").
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    fig.savefig(filepath, bbox_inches='tight', format='pdf')
    plt.close(fig)

def get_metrics_and_confusion_matrix(models, dataset, name='', filename = ''):
    """
    Evaluate an ensemble of classifiers on ``dataset`` and save a two-panel
    confusion-matrix figure (mean ± std over the models).

    Left panel ("group A"): the 10 most frequent true classes plus an "others"
    bucket that absorbs every other class, over all samples.
    Right panel ("group B"): only the samples whose true class is NOT in the
    top 10, over those remaining classes (predictions are not remapped).

    Parameters:
        models (list[nn.Module]): trained models; the predicted class is
            ``model(x).argmax(dim=1)``.
        dataset (Dataset): yields ``(x, y)`` pairs with integer class labels.
        name (str): description appended to the figure title.
        filename (str): output path without extension; ``.pdf`` is appended and
            missing parent directories are created.

    Returns:
        dict: ``f1_total_mean`` / ``f1_total_std`` (macro F1 over all classes, in %),
        ``cmA_mean`` / ``cmB_mean`` (mean row-normalized confusion matrices), and
        ``f1A_mean`` / ``f1A_std`` / ``f1B_mean`` / ``f1B_std`` (per-group macro F1, in %).

    Raises:
        ValueError: if ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("get_metrics_and_confusion_matrix requires at least one model")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=min(32, len(dataset)))

    y_true = torch.cat([y for _, y in dataloader])
    # Size the label space from the largest label (not the number of distinct
    # labels) so a class absent from `dataset` cannot misalign indices or names.
    total_classes = int(y_true.max()) + 1
    # Negative sentinel for the "others" bucket: it can never collide with a
    # real class index, whether true or predicted.
    OTHERS_ID = -1

    if hasattr(dataset, "idx_to_label"):
        class_names = [dataset.idx_to_label[i] for i in range(total_classes)]
    elif hasattr(dataset, "labels"):
        class_names = dataset.labels
    else:
        class_names = [str(i) for i in range(total_classes)]

    # Top-10 most frequent classes among those that actually occur.
    counts = torch.bincount(y_true, minlength=total_classes)
    ranked = torch.argsort(counts, descending=True).tolist()
    top10 = [c for c in ranked if counts[c] > 0][:10]
    top10_ids = torch.tensor(top10, dtype=y_true.dtype)

    def to_group_a(labels):
        """Keep top-10 labels and replace every other label with OTHERS_ID."""
        in_top10 = torch.isin(labels, top10_ids.to(labels.device))
        return torch.where(in_top10, labels, torch.full_like(labels, OTHERS_ID))

    # --- Group A (top-10 + others) ---
    y_true_A = to_group_a(y_true)
    ids_A = top10 + [OTHERS_ID]
    label_names_A = [class_names[c] for c in top10] + ["others"]

    # --- Group B (remaining classes) ---
    # Vectorized membership test. Iterating `y_true` yields 0-d tensors, which
    # hash by identity, so `c not in set_of_ints` would always be True.
    mask_B = ~torch.isin(y_true, top10_ids)
    y_true_B = y_true[mask_B]
    ids_B = sorted(torch.unique(y_true_B).tolist())
    label_names_B = [class_names[c] for c in ids_B]

    def predict(model):
        """Predicted class for every sample, in `dataloader` order, on CPU."""
        model.to(device)
        model.eval()
        with torch.no_grad():
            return torch.cat([model(x.to(device)).argmax(dim=1).cpu() for x, _ in dataloader])

    # Run inference once per model and derive every metric from it.
    cms_A, cms_B = [], []
    f1s_A, f1s_B, f1s_total = [], [], []
    for model in models:
        preds = predict(model)

        preds_A = to_group_a(preds)
        cms_A.append(confusion_matrix(y_true_A.numpy(), preds_A.numpy(), labels=ids_A, normalize="true"))
        f1s_A.append(f1_score(y_true_A.numpy(), preds_A.numpy(), average="macro"))

        preds_B = preds[mask_B]
        cms_B.append(confusion_matrix(y_true_B.numpy(), preds_B.numpy(), labels=ids_B, normalize="true"))
        f1s_B.append(f1_score(y_true_B.numpy(), preds_B.numpy(), average="macro"))

        # Overall macro F1 over all classes, on the unmapped predictions.
        f1s_total.append(f1_score(y_true.numpy(), preds.numpy(), average="macro"))

    cmA_mean, cmA_std = np.mean(cms_A, axis=0), np.std(cms_A, axis=0)
    cmB_mean, cmB_std = np.mean(cms_B, axis=0), np.std(cms_B, axis=0)
    f1A_mean, f1A_std = np.mean(f1s_A) * 100, np.std(f1s_A) * 100
    f1B_mean, f1B_std = np.mean(f1s_B) * 100, np.std(f1s_B) * 100
    f1_total_mean, f1_total_std = np.mean(f1s_total) * 100, np.std(f1s_total) * 100

    # Shared color scale for both panels so a single colorbar applies to both.
    vmin = min(cmA_mean.min(), cmB_mean.min())
    vmax = max(cmA_mean.max(), cmB_mean.max())
    norm = Normalize(vmin=vmin, vmax=vmax)

    def draw_cm(ax, cm_mean, cm_std, labels, fontsize, rotation, show_std):
        """Draw one annotated confusion matrix on `ax` and return its image."""
        im = ax.imshow(cm_mean, cmap=plt.cm.Blues, vmin=vmin, vmax=vmax)
        ticks = np.arange(len(labels))
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(labels, rotation=90, ha="center")
        ax.set_yticklabels(labels)
        for i in range(len(labels)):
            for j in range(len(labels)):
                rgba = im.cmap(norm(cm_mean[i, j]))
                # Perceived luminance of the cell color decides a readable text color.
                luminance = 0.299 * rgba[0] + 0.587 * rgba[1] + 0.114 * rgba[2]
                if show_std:
                    text = f"{cm_mean[i,j]:.2f}\n±{cm_std[i,j]:.2f}"
                else:
                    text = f"{cm_mean[i,j]:.2f}"
                ax.text(
                    j, i, text,
                    ha="center", va="center",
                    fontsize=fontsize,
                    color="white" if luminance < 0.5 else "black",
                    rotation=rotation
                )
        ax.set_xlabel("Predicted label")
        ax.set_ylabel("True label")
        return im

    main_title = f"Results for model{'s' if len(models) > 1 else ''} on {name}"
    main_title += f"\nOverall F1 (macro) = {f1_total_mean:.2f} ± {f1_total_std:.2f}%"

    fig, axs = plt.subplots(1, 2, figsize=(18, 8))
    fig.suptitle(main_title, fontsize=16, fontweight='bold')
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])

    imA = draw_cm(axs[0], cmA_mean, cmA_std, label_names_A, fontsize=8, rotation=0, show_std=True)
    axs[0].set_title(f"Top-10 + others: F1 (macro) = {f1A_mean:.2f} ± {f1A_std:.2f}%")
    draw_cm(axs[1], cmB_mean, cmB_std, label_names_B, fontsize=6, rotation=45, show_std=False)
    axs[1].set_title(f"Remaining classes: F1 (macro) = {f1B_mean:.2f} ± {f1B_std:.2f}%")

    cbar = fig.colorbar(imA, cax=cbar_ax)
    cbar.set_label("Normalized frequency", rotation=270, labelpad=15)
    plt.subplots_adjust(right=0.9, top=0.88)

    out_path = f'{filename}.pdf'
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)

    return {
        "f1_total_mean": f1_total_mean,
        "f1_total_std": f1_total_std,
        "cmA_mean": cmA_mean,
        "cmB_mean": cmB_mean,
        "f1A_mean": f1A_mean,
        "f1A_std": f1A_std,
        "f1B_mean": f1B_mean,
        "f1B_std": f1B_std,
    }

## Feature extraction

In [6]:
# --- Base configuration ---
ROOT_DIR = "data"
SAVE_DIR = os.path.join(ROOT_DIR, "features")
device = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(SAVE_DIR, exist_ok=True)

# --- MFCC transform shared by the three MFCC-based modes ---
mfcc = torchaudio.transforms.MFCC(
    sample_rate=16000,
    n_mfcc=13,
    melkwargs={"n_fft": 320, "hop_length": 160, "n_mels": 23}
)

# --- Wav2Vec2 processor + encoder, loaded once and reused for every split ---
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(device)

# Mode name -> feature extractor passed to `save_features`
modes = {
    **dict.fromkeys(("mfcc", "mfcc_delta", "mfcc_delta_delta"), mfcc),
    "wav2vec2": wav2vec2,
}

# --- Extract and cache features for train, val and test ---
for split in ("train", "val", "test"):
    list_path = os.path.join(ROOT_DIR, f"{split}_list.txt")

    for mode, extractor in modes.items():
        save_path = os.path.join(SAVE_DIR, f"{split}_{mode}.pt")

        # Cached features are reused as-is.
        if os.path.isfile(save_path):
            print(f"{save_path} ya existe, saltando...")
            continue

        print(f"\nExtrayendo {mode} para {split}...")
        dataset = CustomSpeechCommands(ROOT_DIR, list_path, mode=mode)

        # Only Wav2Vec2 needs the processor; MFCC modes take the transform alone.
        extra_kwargs = {"processor": processor} if mode == "wav2vec2" else {}
        dataset.save_features(
            feature_extractor=extractor,
            device=device,
            save_path=save_path,
            **extra_kwargs,
        )

print("\nExtracción de features completada.")




data/features/train_mfcc.pt ya existe, saltando...
data/features/train_mfcc_delta.pt ya existe, saltando...
data/features/train_mfcc_delta_delta.pt ya existe, saltando...
data/features/train_wav2vec2.pt ya existe, saltando...
data/features/val_mfcc.pt ya existe, saltando...
data/features/val_mfcc_delta.pt ya existe, saltando...
data/features/val_mfcc_delta_delta.pt ya existe, saltando...
data/features/val_wav2vec2.pt ya existe, saltando...
data/features/test_mfcc.pt ya existe, saltando...
data/features/test_mfcc_delta.pt ya existe, saltando...
data/features/test_mfcc_delta_delta.pt ya existe, saltando...
data/features/test_wav2vec2.pt ya existe, saltando...

Extracción de features completada.


In [7]:
# MFCC features at several hop lengths, to study the effect of the input
# sequence length. Set ROOT_DIR explicitly so this cell does not depend on
# whichever cell last assigned it.
ROOT_DIR = "data"
SAVE_DIR = os.path.join(ROOT_DIR, 'petes')
os.makedirs(SAVE_DIR, exist_ok=True)

# --- Extract and cache features for train, val and test ---
for split in ["train", "val", "test"]:
    list_path = os.path.join(ROOT_DIR, f"{split}_list.txt")

    for hl in [320, 160, 54, 32, 16]:
        save_path = os.path.join(SAVE_DIR, f"{split}_{hl}_mfcc.pt")
        if os.path.isfile(save_path):
            print(f"{save_path} ya existe, saltando...")
            continue

        # Same MFCC configuration for every run; only the frame hop changes.
        mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=16000,
            n_mfcc=13,
            log_mels=True,
            melkwargs={"n_fft": 320, "hop_length": hl, "n_mels": 23}
        )
        print(f"\nExtrayendo MFCC (hop_length={hl}) para {split}...")

        dataset = CustomSpeechCommands(ROOT_DIR, list_path, mode='mfcc')
        dataset.save_features(
            feature_extractor=mfcc_transform,
            device=device,
            save_path=save_path,
        )

print("\nExtracción de features completada.")

data/petes/train_320_mfcc.pt ya existe, saltando...
data/petes/train_160_mfcc.pt ya existe, saltando...
data/petes/train_54_mfcc.pt ya existe, saltando...
data/petes/train_32_mfcc.pt ya existe, saltando...
data/petes/train_16_mfcc.pt ya existe, saltando...
data/petes/val_320_mfcc.pt ya existe, saltando...
data/petes/val_160_mfcc.pt ya existe, saltando...
data/petes/val_54_mfcc.pt ya existe, saltando...
data/petes/val_32_mfcc.pt ya existe, saltando...
data/petes/val_16_mfcc.pt ya existe, saltando...
data/petes/test_320_mfcc.pt ya existe, saltando...
data/petes/test_160_mfcc.pt ya existe, saltando...
data/petes/test_54_mfcc.pt ya existe, saltando...
data/petes/test_32_mfcc.pt ya existe, saltando...
data/petes/test_16_mfcc.pt ya existe, saltando...

Extracción de features completada.


## Training

In [None]:
# --- Configuration ---
ROOT_DIR = os.path.join("data", "petes")
SAVE_DIR = ROOT_DIR
MODEL_WEIGHT_PATH = 'seq_length_model_weights'
device = "cuda"

tcn_config = {
    'num_channels': [64, 64, 128, 128],
    'kernel_size': 3,
    'dilations': [1, 2, 4, 8],
    'dropout': 0,
    'causal': False,              # non-causal: the whole clip is available (offline classification)
    'input_shape': 'NCL',         # (batch, features, time) layout
    'use_norm': 'batch_norm',
    'use_skip_connections': True,
}

lr = 5e-4
batch_size = 32
criterion = nn.CrossEntropyLoss()
n_trains = 5                      # independent runs per (architecture, sequence length)
epochs = 30
use_gpu = True
architectures = ["GRU", "LSTM", 'RNN', "CNN", "TCNN", "Transformer"]

# Mean / std of the test F1 per architecture, one entry per hop length
f1_scores = {arch: [] for arch in architectures}
f1_stds = {arch: [] for arch in architectures}

hop_lengths = [320, 160, 54, 32, 16]
seq_len_seen = []


def build_model(arch):
    """
    Build a fresh, untrained model of type ``arch`` for 13-dim MFCC inputs and 35 classes.

    Returns:
        (model, config): the model and the kwargs used to build it (stored
        alongside the weights by ``save_model``).

    Raises:
        ValueError: for an unknown architecture (previously the model from the
            last iteration was silently reused).
    """
    if arch in ("GRU", "LSTM", "RNN"):
        config = {'rnn_type': arch, 'n_input_channels': 13, 'hidd_size': 256, 'out_features': 35, 'num_layers': 1}
        return RNNModel(**config), config
    if arch == "CNN":
        config = {'n_input_channels': 13, 'hidd_size': 64, 'out_features': 35}
        return CNNModel(**config), config
    if arch == "TCNN":
        config = {'num_inputs': 13, 'num_classes': 35, 'tcn_params': tcn_config}
        return SpeechCommandTCN(**config), config
    if arch == 'Transformer':
        config = {'n_input_features': 13, 'n_output_classes': 35, 'd_model': 128, 'nhead': 8, 'd_hid': 512, 'n_layers': 4, 'dropout': 0.0}
        return TransformerModel(**config), config
    raise ValueError(f"Unknown architecture: {arch}")


# --- Main loop: one experiment per hop length ---
for hop_length in hop_lengths:
    # Number of MFCC frames for a 16000-sample clip at this hop length
    seq_len_input = 1 + 16000 // hop_length
    seq_len_seen.append(seq_len_input)
    print(f"Largo de secuencia de entrada: {seq_len_input}")

    train_dataset = FeaturesDataset(os.path.join(SAVE_DIR, f"train_{hop_length}_mfcc.pt"))
    val_dataset   = FeaturesDataset(os.path.join(SAVE_DIR, f"val_{hop_length}_mfcc.pt"))
    test_dataset  = FeaturesDataset(os.path.join(SAVE_DIR, f"test_{hop_length}_mfcc.pt"))

    for arch in architectures:
        print(f"\n======= Entrenando modelos tipo {arch} (Seq Len: {seq_len_input}) =======")

        # Named `arch_models` (not `models`) so the `torchvision.models` import is not shadowed.
        arch_models = []
        curves = []

        for k in range(n_trains):
            model, config = build_model(arch)

            # Checkpoints are numbered from 1
            filename = f'{arch}_{k+1}_train_seq_len_[{seq_len_input}].pt'
            filepath = os.path.join(MODEL_WEIGHT_PATH, arch, filename)

            if os.path.exists(filepath):
                print(f"[SKIP] Cargando pesos existentes: {filename}")
                model.load_state_dict(torch.load(filepath)['state_dict'])
                # No training curve is available for a reused checkpoint.
            else:
                print(f"   [TRAIN] Archivo no encontrado, entrenando: {filename}")
                # Train exactly once (the model used to be trained twice in a row).
                curve, _ = train_model(
                    model, train_dataset, val_dataset, epochs, criterion, batch_size, lr,
                    n_evaluations_per_epoch=3, use_gpu=use_gpu, patience=10,
                    model_config=config, model_arch=arch
                )
                curves.append(curve)
                os.makedirs(os.path.dirname(filepath), exist_ok=True)
                save_model(model, filepath, config)

            # Every run, loaded or freshly trained, enters the ensemble exactly once.
            arch_models.append(model)

        # Figures for this architecture are written under img/<arch>/
        os.makedirs(os.path.join('img', arch), exist_ok=True)

        # Training curves exist only for the runs trained in this session.
        if curves:
            show_curves(
                curves,
                suptitle=f"{arch} con largo de secuencia de {seq_len_input} puntos",
                filename=os.path.join(
                    f'{arch}',
                    f"{arch}_seq_len{seq_len_input}"
                )
            )

        # Evaluate the whole ensemble on the test split
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        metrics_mean, metrics_std, *_ = evaluate_models_metrics(arch_models, test_loader, criterion, use_gpu=use_gpu)

        f1_scores[arch].append(metrics_mean["f1"])
        f1_stds[arch].append(metrics_std["f1"])

        get_metrics_and_confusion_matrix(
            arch_models,
            test_dataset,
            name=f"{arch} con largo de secuencia de {seq_len_input} puntos",
            filename=os.path.join(
                'img',
                f'{arch}',
                f"conf_mat_{arch}_seq_len{seq_len_input}"
            )
        )

# --- F1 vs sequence length, one line per architecture ---
fig, ax = plt.subplots(figsize=(10, 6))

for arch in architectures:
    ax.errorbar(
        seq_len_seen,
        f1_scores[arch],
        yerr=f1_stds[arch],
        marker="o",
        capsize=4,
        label=arch,
        alpha=0.8
    )

ax.set_xlabel("Largo de secuencia de entrada (frames)")
ax.set_ylabel("F1-score promedio (± std)")
ax.set_title("Performance vs Largo de Secuencia por Arquitectura")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')  # legend outside the axes (many entries)
ax.grid(True, alpha=0.3)
fig.tight_layout()
os.makedirs("img", exist_ok=True)
fig.savefig("img/f1_vs_seq_len_comparison.pdf", bbox_inches="tight")
plt.show()

Largo de secuencia de entrada: 51
Dataset cargado desde data/petes/train_320_mfcc.pt
 - 32453 ejemplos
 - 35 clases
Dataset cargado desde data/petes/val_320_mfcc.pt
 - 3875 ejemplos
 - 35 clases
Dataset cargado desde data/petes/test_320_mfcc.pt
 - 4381 ejemplos
 - 35 clases

[SKIP] Cargando pesos existentes: GRU_1_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: GRU_2_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: GRU_3_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: GRU_4_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: GRU_5_train_seq_len_[51].pt

[SKIP] Cargando pesos existentes: LSTM_1_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: LSTM_2_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: LSTM_3_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: LSTM_4_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: LSTM_5_train_seq_len_[51].pt

[SKIP] Cargando pesos existentes: RNN_1_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: R

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



[SKIP] Cargando pesos existentes: CNN_1_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: CNN_2_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: CNN_3_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: CNN_4_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: CNN_5_train_seq_len_[51].pt


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



[SKIP] Cargando pesos existentes: TCNN_1_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: TCNN_2_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: TCNN_3_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: TCNN_4_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: TCNN_5_train_seq_len_[51].pt

[SKIP] Cargando pesos existentes: Transformer_1_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: Transformer_2_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: Transformer_3_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: Transformer_4_train_seq_len_[51].pt
[SKIP] Cargando pesos existentes: Transformer_5_train_seq_len_[51].pt
Largo de secuencia de entrada: 101
Dataset cargado desde data/petes/train_160_mfcc.pt
 - 32453 ejemplos
 - 35 clases
Dataset cargado desde data/petes/val_160_mfcc.pt
 - 3875 ejemplos
 - 35 clases
Dataset cargado desde data/petes/test_160_mfcc.pt
 - 4381 ejemplos
 - 35 clases

[SKIP] Cargando pesos existentes: GRU_1_train_seq_len_[1

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



[SKIP] Cargando pesos existentes: CNN_1_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: CNN_2_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: CNN_3_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: CNN_4_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: CNN_5_train_seq_len_[101].pt

[SKIP] Cargando pesos existentes: TCNN_1_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: TCNN_2_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: TCNN_3_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: TCNN_4_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: TCNN_5_train_seq_len_[101].pt

[SKIP] Cargando pesos existentes: Transformer_1_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: Transformer_2_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: Transformer_3_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: Transformer_4_train_seq_len_[101].pt
[SKIP] Cargando pesos existentes: Transformer_5_train_seq_len_[101].pt
Largo d

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



[SKIP] Cargando pesos existentes: CNN_1_train_seq_len_[297].pt
[SKIP] Cargando pesos existentes: CNN_2_train_seq_len_[297].pt
[SKIP] Cargando pesos existentes: CNN_3_train_seq_len_[297].pt
[SKIP] Cargando pesos existentes: CNN_4_train_seq_len_[297].pt
[SKIP] Cargando pesos existentes: CNN_5_train_seq_len_[297].pt

[SKIP] Cargando pesos existentes: TCNN_1_train_seq_len_[297].pt
[SKIP] Cargando pesos existentes: TCNN_2_train_seq_len_[297].pt
[SKIP] Cargando pesos existentes: TCNN_3_train_seq_len_[297].pt
[SKIP] Cargando pesos existentes: TCNN_4_train_seq_len_[297].pt
[SKIP] Cargando pesos existentes: TCNN_5_train_seq_len_[297].pt

[SKIP] Cargando pesos existentes: Transformer_1_train_seq_len_[297].pt
   [TRAIN] Archivo no encontrado, entrenando: Transformer_2_train_seq_len_[297].pt
1015
Epoch 1/30
Iteration 338 - Batch 338/1015 - Train loss: 1.625116198375865, Train acc: 0.5211098820058997
Iteration 676 - Batch 676/1015 - Train loss: 1.2364386987915068, Train acc: 0.6414327917282127
Ite

In [None]:
ROOT_DIR = os.path.join("data", "features")
SAVE_DIR = ROOT_DIR


def _load_wav2vec2_split(split):
    """Load the cached Wav2Vec2 features of one split ("train", "val" or "test")."""
    return FeaturesDataset(os.path.join(SAVE_DIR, f"{split}_wav2vec2.pt"))


train_dataset = _load_wav2vec2_split("train")
val_dataset   = _load_wav2vec2_split("val")
test_dataset  = _load_wav2vec2_split("test")

In [None]:
device = "cuda"


def infer_model_type(path):
    """
    Guess the architecture of a saved model from its path (case-insensitive).

    Rules are checked in order; the first match wins:
    "gru" -> GRU, "lstm" -> LSTM, "rnn" without "cnn" -> RNN,
    "tcnn" -> TCNN, "transformer" -> TRANSFORMER.

    Raises:
        ValueError: if no rule matches.
    """
    lowered = path.lower()
    rules = (
        ("GRU", lambda s: "gru" in s),
        ("LSTM", lambda s: "lstm" in s),
        ("RNN", lambda s: "rnn" in s and "cnn" not in s),
        ("TCNN", lambda s: "tcnn" in s),
        ("TRANSFORMER", lambda s: "transformer" in s),
    )
    for model_type, matches in rules:
        if matches(lowered):
            return model_type
    raise ValueError(f"No pude inferir tipo de modelo desde: {path}")

# Architecture tag (as returned by infer_model_type) -> model class.
# The three recurrent variants share RNNModel; the "TCNN" tag loads a CNNModel.
MODEL_CLASS_BY_TYPE = {
    **dict.fromkeys(("GRU", "LSTM", "RNN"), RNNModel),
    "TCNN": CNNModel,
    "TRANSFORMER": TransformerModel,
}

def load_model_by_type(model_path, device="cuda", config_override=None):
    """
    Load a trained checkpoint, choosing the model class from the architecture
    tag that ``infer_model_type`` extracts from ``model_path``.

    Args:
        model_path: checkpoint file to load.
        device: forwarded unchanged to ``load_trained_model``.
        config_override: when not None, passed verbatim as ``config`` for the
            class registered in ``MODEL_CLASS_BY_TYPE``. Otherwise the
            per-family default hyperparameters below are used.

    Returns:
        Whatever ``load_trained_model`` returns for the selected class.

    Raises:
        ValueError: for file names whose architecture cannot be determined.
    """
    arch_tag = infer_model_type(model_path)

    # An explicit config always wins and goes through the shared registry.
    if config_override is not None:
        return load_trained_model(
            MODEL_CLASS_BY_TYPE[arch_tag],
            model_path,
            device=device,
            config=config_override,
        )

    # The Transformer is loaded without passing any config argument.
    if arch_tag == "TRANSFORMER":
        return load_trained_model(TransformerModel, model_path, device=device)

    # Default hyperparameters per family, built fresh on each call.
    if arch_tag in ("GRU", "LSTM", "RNN"):
        model_cls = RNNModel
        default_cfg = {
            'rnn_type': arch_tag,
            'n_input_channels': 768,
            'hidd_size': 256,
            'out_features': 35,
            'num_layers': 1,
        }
    elif arch_tag == "TCNN":
        model_cls = CNNModel
        default_cfg = {
            'n_input_channels': 768,
            'hidd_size': 64,
            'out_features': 35,
        }
    else:
        raise ValueError(arch_tag)

    return load_trained_model(model_cls, model_path, device=device, config=default_cfg)

# Every saved checkpoint, sorted so models load in a deterministic order.
model_paths = sorted(glob.glob("old_model_weights/*.pt"))

# Infer each checkpoint's architecture once, then drop the Transformers
# before anything is loaded.
arch_by_path = {p: infer_model_type(p) for p in model_paths}
model_paths = [p for p in model_paths if arch_by_path[p] != "TRANSFORMER"]

# Loaded models and the paths they came from, grouped by architecture.
models_by_arch = {}
paths_by_arch = {}

for path in model_paths:
    arch = arch_by_path[path]

    # Use the notebook-level `device` rather than a hard-coded "cuda".
    model = load_model_by_type(path, device=device)
    models_by_arch.setdefault(arch, []).append(model)
    paths_by_arch.setdefault(arch, []).append(path)

print("Modelos cargados por arquitectura:")
for arch, lst in models_by_arch.items():
    print(f"{arch}: {len(lst)} modelos")

# Per-architecture accumulators. In this cell they are only filled by the
# commented-out evaluation loop that follows.
results_by_arch, times_by_arch, f1_dists_by_arch = {}, {}, {}


# batch_size = 32
# criterion = nn.CrossEntropyLoss()
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# for arch, model_list in models_by_arch.items():
#     print(f"\nEvaluando arquitectura: {arch}")

#     metrics_mean, metrics_std, all_metrics, mean_times, f1_dists = \
#         evaluate_models_metrics(model_list, test_loader, criterion)

    

#     results_by_arch[arch] = {
#         "mean": metrics_mean,
#         "std": metrics_std,
#         "all": all_metrics,
#     }
#     print(results_by_arch[arch])

#     # Guardamos tiempos y distribuciones f1
#     times_by_arch[arch] = mean_times
#     f1_dists_by_arch[arch] = f1_dists

# # ============================================================
# # 1) Calcular métricas agregadas por arquitectura
# # ============================================================

os.makedirs("img", exist_ok=True)

# Keep every architecture's metrics. Previously each iteration overwrote
# `metrics`, so only the last architecture's results survived the cell.
metrics_by_arch = {}

for arch, model_list in models_by_arch.items():
    print(f"\n=== Arquitectura: {arch} ===")

    # Target PDF path, passed as `name`. Presumably this is where the
    # confusion-matrix figure is written; confirm against the helper.
    outfile = os.path.join("img", f"conf_mat_{arch}.pdf")

    metrics = get_metrics_and_confusion_matrix(
        model_list,
        test_dataset,
        name=outfile
    )
    metrics_by_arch[arch] = metrics



# arch_names = []
# arch_f1_values = []
# arch_f1_dists = []
# arch_time_mean = []
# arch_time_std = []

# for arch in models_by_arch.keys():
#     arch_names.append(arch)

#     # Distribuciones F1 (1 valor por modelo)
#     f1_vals = results_by_arch[arch]["all"]["f1"]
#     arch_f1_values.append(np.mean(f1_vals))
#     arch_f1_dists.append(f1_vals)

#     tvals = times_by_arch[arch]     
#     arch_time_mean.append(np.mean(tvals))
#     arch_time_std.append(np.std(tvals))


# arch_time_mean_ms = [t * 100000 for t in arch_time_mean]
# arch_time_std_ms  = [s * 100000 for s in arch_time_std]

# plot_f1_vs_inference_time_with_error_bars(
#     arch_names,
#     arch_time_mean_ms,
#     arch_f1_dists,
#     time_stds=arch_time_std_ms
# )


In [None]:
# Shared paths and training hyperparameters for this experiment.
ROOT_DIR = os.path.join("data","features")
SAVE_DIR = ROOT_DIR
# Fall back to CPU when CUDA is unavailable instead of failing during training.
device = "cuda" if torch.cuda.is_available() else "cpu"

lr = 5e-4
batch_size = 32
criterion = nn.CrossEntropyLoss()
n_trains = 3     # independent training runs per architecture (for mean ± std)
epochs = 40
use_gpu = torch.cuda.is_available()


# Per-architecture F1 summaries (mean and std across the n_trains runs).
f1_scores = {"GRU": [], "LSTM": [], "TCNN": [], "RNN": []}
f1_stds   = {"GRU": [], "LSTM": [], "TCNN": [], "RNN": []}

# Load the precomputed wav2vec2 feature splits.
train_dataset = FeaturesDataset(os.path.join(SAVE_DIR, "train_wav2vec2.pt"))
val_dataset   = FeaturesDataset(os.path.join(SAVE_DIR, "val_wav2vec2.pt"))
test_dataset  = FeaturesDataset(os.path.join(SAVE_DIR, "test_wav2vec2.pt"))

# Train each requested architecture n_trains times (only TCNN enabled here).
for arch in ["TCNN"]:
    print(f"\n--- Entrenando modelo tipo {arch} ---")

    # Deliberately not named `models`: that would shadow the
    # `torchvision.models` module imported at the top of the notebook.
    trained_models = []
    curves = []

    for k in range(n_trains):
        print(f"Entrenamiento {k+1}/{n_trains}")

        # Build a fresh model for this run according to its architecture.
        if arch in ["GRU", "LSTM", "RNN"]:
            config = {
                'rnn_type': arch,
                'n_input_channels': 768,
                'hidd_size': 256,
                'out_features': 35,
                'num_layers': 1
            }
            model = RNNModel(**config)
        elif arch == "TCNN":
            config = {
                'n_input_channels': 768,
                'hidd_size': 64,
                'out_features': 35
            }
            model = CNNModel(**config)
        else:
            raise ValueError("Modelo no reconocido")

        # Train; only the learning curve is kept from train_model's result.
        curve, _ = train_model(
            model,
            train_dataset,
            val_dataset,
            epochs,
            criterion,
            batch_size,
            lr,
            n_evaluations_per_epoch=3,
            use_gpu=use_gpu,
            patience=15,
            model_config=config,
            model_arch=arch
        )
        curves.append(curve)
        trained_models.append(model)

    # Training curves for all runs of this architecture.
    show_curves(curves, suptitle=f"{arch}")

    # Test-set evaluation.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # NOTE(review): the notebook is inconsistent about how many values
    # evaluate_models_metrics returns (3 were unpacked here, 5 in a
    # commented-out evaluation loop). Only mean and std are needed, so index
    # them instead of unpacking a fixed number of values.
    eval_results = evaluate_models_metrics(trained_models, test_loader, criterion, use_gpu=use_gpu)
    metrics_mean, metrics_std = eval_results[0], eval_results[1]

    f1_scores[arch].append(metrics_mean["f1"])
    f1_stds[arch].append(metrics_std["f1"])

    print(f"F1 {arch}: {metrics_mean['f1']:.3f} ± {metrics_std['f1']:.3f}")

    get_metrics_and_confusion_matrix(trained_models, test_dataset, name=f"{arch}")



In [None]:
# Preliminary testing: a Transformer trained on the MFCC feature splits.
lr = 5e-4
batch_size = 32
criterion = nn.CrossEntropyLoss()
n_trains = 1  # number of repetitions used to compute mean and std
epochs = 20   # more epochs than usual; Transformers tend to need longer training


# SAVE_DIR is defined by the configuration cells above.
train_dataset = FeaturesDataset(os.path.join(SAVE_DIR, "train_16_mfcc.pt"))
val_dataset = FeaturesDataset(os.path.join(SAVE_DIR, "val_16_mfcc.pt"))
test_dataset = FeaturesDataset(os.path.join(SAVE_DIR, "test_16_mfcc.pt"))

# Derive the model dimensions from the dataset itself.
# Assumes features are (N, T, F), so shape[2] is the per-frame feature size
# (presumably 16 MFCCs given the file names; TODO confirm).
N_INPUT_FEATURES = train_dataset.features.shape[2]
N_OUTPUT_CLASSES = len(train_dataset.label_to_idx)  # 35 classes in the other configs

# --- Transformer configuration ---
TRANSFORMER_ARCH_PARAMS = {
    "n_input_features": N_INPUT_FEATURES,
    "n_output_classes": N_OUTPUT_CLASSES,
    "d_model": 128,
    "nhead": 8,
    "n_layers": 4,  # 4-6 layers is a reasonable starting range
    "d_hid": 512,   # should be larger than d_model, e.g. 4 * d_model
}
# ---------------------------------

ARCH = 'Transformer'
print(f'Entrenando Modelo {ARCH} con d_model={TRANSFORMER_ARCH_PARAMS["d_model"]}')

times_of_training = []
# Not named `models`, which would shadow the `torchvision.models` import.
trained_models = []
curves = []

for k in range(n_trains):
    print(f'Entrenando modelo {k+1}/{n_trains}')

    model = TransformerModel(**TRANSFORMER_ARCH_PARAMS)

    # Train; keep both the learning curves and the timing information.
    all_curves, times = train_model(
        model,
        train_dataset,
        val_dataset,
        epochs,
        criterion,
        batch_size,
        lr,
        n_evaluations_per_epoch=3,
        use_gpu=True
    )
    curves.append(all_curves)
    times_of_training.append(times)
    trained_models.append(model)

show_curves(curves, ARCH)
get_metrics_and_confusion_matrix(trained_models, test_dataset, ARCH)