In [None]:
import pickle
import os
import numpy as np
from scipy.signal import resample
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from scipy.interpolate import interp1d
import typing as tp
import random

# You MUST ensure these are defined or imported in your actual environment.
from seanet_style import SEANetTransformerClassifier, create_mask 

# --- Dataset Classes ---

class ClassificationDataset(Dataset):
    """Training dataset serving random-length, randomly augmented windows.

    Every input file is a ``.npy`` array of shape ``(channels, length)``; every
    label file is a 2-D ``.npy`` array whose first row is used as the
    per-sample label.  Files are opened memory-mapped, so only the requested
    window is read.

    An index only selects a *base position* (a file plus a ``step_size``-aligned
    start).  The window length, a start jitter and the augmentations are drawn
    at random on each access, so the same index yields different samples.

    When a label file and its input file differ in length, only their common
    prefix is used (the same policy as ``ValidationDataset``), so a window can
    never extend past the shorter of the two arrays.

    Args:
        input_files: Paths of the signal ``.npy`` files.
        output_files: Paths of the label ``.npy`` files, in the same order.
        min_window: Shortest window length in samples.
        max_window: Longest window length in samples.
        noise_stds: Standard deviation of the additive Gaussian noise used when
            ``channel_wise_noise_stds`` is ``None``.
        transform_probability: Probability that a sample is augmented.
        step_size: Spacing between base positions; also the granularity of the
            candidate window lengths and the upper bound of the start jitter.
        channel_wise_noise_stds: Optional tensor of per-channel noise standard
            deviations, shaped ``(C,)`` or ``(C, 1)``.
        scale_range: ``(low, high)`` bounds of the random amplitude scale.
        num_warp_knots: Number of knots of the random time-warp curve.
        warp_scale: Standard deviation of the knot displacement.
    """

    def __init__(self, input_files, output_files, min_window=1000, max_window=4000,
                 noise_stds=1e-6, transform_probability=0.95, step_size=50,
                 channel_wise_noise_stds=None,
                 scale_range=(0.9, 1.1),
                 num_warp_knots=4,
                 warp_scale=0.1):
        self.input_files = input_files
        self.output_files = output_files
        assert len(self.input_files) == len(self.output_files), "Input and output file counts mismatch."

        # Memory-map the files so that only the requested windows are read.
        self.input_memmaps = [np.load(f, mmap_mode='r') for f in self.input_files]
        self.output_memmaps = [np.load(f, mmap_mode='r') for f in self.output_files]

        self.min_window = min_window
        self.max_window = max_window
        self.step_size = step_size
        # Candidate window lengths; falls back to min_window when the range is empty.
        self.valid_window_sizes = list(range(min_window, max_window + 1, step_size))
        if not self.valid_window_sizes:
            self.valid_window_sizes = [min_window]

        self.noise_stds = noise_stds
        self.transform_prob = transform_probability
        self.channel_wise_noise_stds = channel_wise_noise_stds
        self.scale_range = scale_range

        self.num_warp_knots = num_warp_knots
        self.warp_scale = warp_scale

        self.file_metadata = []
        self.total_windows = 0
        for inp_mmap, out_mmap in zip(self.input_memmaps, self.output_memmaps):
            channels, input_length = inp_mmap.shape
            # Only the part covered by BOTH signal and labels is usable; a
            # window reaching past the shorter array would give input/label
            # slices of different lengths.
            length = min(input_length, out_mmap.shape[-1])

            num_windows_in_file = 0
            if length >= self.min_window:
                num_windows_in_file = (length - self.min_window) // self.step_size + 1

            self.file_metadata.append({
                'input_shape': inp_mmap.shape,
                'output_shape': out_mmap.shape,
                'length': length,
                'channels': channels,
                'num_windows': num_windows_in_file,
                'start_idx_offset': self.total_windows
            })
            self.total_windows += num_windows_in_file

    def __len__(self):
        """Return the number of base window positions over all files."""
        return self.total_windows

    @staticmethod
    def _resample_along_time(t, tensor, warped_t):
        """Linearly resample every row of ``tensor`` (sampled at ``t``) at ``warped_t``."""
        resampler = interp1d(t, tensor.numpy(), kind='linear', axis=-1, fill_value="extrapolate")
        return torch.from_numpy(np.ascontiguousarray(resampler(warped_t))).float()

    def _time_warp(self, input_tensor, output_tensor):
        """Apply the same random time warp to a signal window and its labels.

        Args:
            input_tensor: Signal window of shape ``(C, L)``.
            output_tensor: Label window of shape ``(C', L)``.

        Returns:
            ``(warped_input, warped_output)`` with unchanged shapes; the warped
            labels are re-binarised with a 0.5 threshold.
        """
        length = input_tensor.shape[1]
        t = np.linspace(0, 1, length)

        # Random warp curve: jittered knots, sorted, with pinned end points.
        knots_x = np.linspace(0, 1, self.num_warp_knots)
        knots_y = np.sort(knots_x + np.random.normal(0, self.warp_scale, self.num_warp_knots))
        knots_y[0] = 0
        knots_y[-1] = 1

        warp_func = interp1d(knots_x, knots_y, kind='cubic', fill_value="extrapolate")
        # The cubic curve may overshoot [0, 1]; clip so sampling stays in range.
        warped_t = np.clip(warp_func(t), 0, 1)

        # warped_t has exactly `length` entries, so the outputs keep length L.
        warped_input = self._resample_along_time(t, input_tensor, warped_t)
        warped_output = self._resample_along_time(t, output_tensor, warped_t)

        threshold = 0.5
        warped_output = (warped_output > threshold).float()
        return warped_input, warped_output

    def __getitem__(self, idx):
        """Return ``(input_window, label_window, input_length, label_length)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.total_windows
        if not 0 <= idx < self.total_windows:
            raise IndexError("Index out of bounds for dataset.")

        # Locate the file containing this global index via the stored offsets.
        for file_idx, meta in enumerate(self.file_metadata):
            local_idx = idx - meta['start_idx_offset']
            if local_idx < meta['num_windows']:
                break

        # 1. Randomly select the window length.
        window_length = np.random.choice(self.valid_window_sizes)

        # 2. Step-aligned base start plus a random jitter, clamped to the file.
        base_start = local_idx * self.step_size
        max_possible_start = max(0, meta['length'] - window_length)
        random_offset = np.random.randint(0, self.step_size)
        start = min(base_start + random_offset, max_possible_start)
        end = min(start + window_length, meta['length'])

        # Copy the slices out of the read-only memmaps so the tensors own
        # writable memory; only the first label row is used.
        input_data_np = np.array(self.input_memmaps[file_idx][:, start:end], dtype=np.float32)
        output_data_np = np.array(self.output_memmaps[file_idx][0:1, start:end], dtype=np.float32)

        input_window = torch.from_numpy(input_data_np)
        output_window = torch.from_numpy(output_data_np)

        # 3. Data augmentation.
        if np.random.rand() < self.transform_prob:
            # Random amplitude scaling.
            scale_factor = np.random.uniform(self.scale_range[0], self.scale_range[1])
            input_window = input_window * scale_factor

            # Additive Gaussian noise, per channel when stds are provided.
            if self.channel_wise_noise_stds is not None:
                noise_scale = self.channel_wise_noise_stds.to(input_window.device)
                if noise_scale.dim() == 1:
                    noise_scale = noise_scale.unsqueeze(-1)  # (C,) -> (C, 1) for broadcasting
                noise = torch.randn_like(input_window) * noise_scale
            else:
                noise = torch.randn_like(input_window) * self.noise_stds

            input_window = input_window + noise

            # Joint time warp of signal and labels.
            input_window, output_window = self._time_warp(input_window, output_window)

        return input_window, output_window, input_window.shape[-1], output_window.shape[-1]


class ValidationDataset(Dataset):
    """Validation dataset of fixed-size, non-overlapping chunks held in memory.

    Each signal/label file pair is loaded completely, cut to the common length
    and split into consecutive chunks of ``chunk_size`` samples (the last chunk
    of a file may be shorter).  Only the first label row is kept.

    Every item is ``(input_chunk, label_chunk, input_length, label_length)``.
    """

    def __init__(self, input_files, output_files, chunk_size=1000):
        self.input_files = input_files
        self.output_files = output_files
        assert len(self.input_files) == len(self.output_files), "Input and output file counts mismatch."

        self.chunk_size = chunk_size
        self.data_chunks = []

        for signal_path, label_path in zip(self.input_files, self.output_files):
            signal = np.load(signal_path)
            labels = np.load(label_path)

            # Keep only the span covered by both arrays.
            common_len = min(signal.shape[1], labels.shape[1])
            signal = signal[:, :common_len]
            labels = labels[:, :common_len]

            # The label is assumed to live in the first row.
            if labels.shape[0] > 1:
                labels = labels[0:1, :]

            # Slicing past the end is clamped by NumPy, so the final chunk is
            # simply shorter when the length is not a multiple of chunk_size.
            for begin in range(0, signal.shape[1], chunk_size):
                stop = begin + chunk_size
                signal_chunk = torch.from_numpy(signal[:, begin:stop]).float()
                label_chunk = torch.from_numpy(labels[:, begin:stop]).float()
                self.data_chunks.append(
                    (signal_chunk, label_chunk, signal_chunk.shape[1], label_chunk.shape[1])
                )

    def __len__(self):
        """Return the total number of chunks."""
        return len(self.data_chunks)

    def __getitem__(self, idx):
        """Return the pre-built chunk tuple at position ``idx``."""
        return self.data_chunks[idx]

# --- Helper Functions ---

def collate_fn(batch):
    """Zero-pad variable-length samples on the right and stack them into a batch.

    Args:
        batch: Sequence of ``(input, output, input_length, output_length)``
            tuples whose tensors have time as their last dimension.

    Returns:
        ``(padded_inputs, padded_outputs, input_lengths, output_lengths)``; the
        tensors are padded to the largest reported length of their group and the
        lengths are returned as lists.
    """
    inputs, outputs, raw_input_lengths, raw_output_lengths = zip(*batch)
    input_lengths = list(raw_input_lengths)
    output_lengths = list(raw_output_lengths)

    def _pad_and_stack(tensors, target_len):
        # Right-pad the last dimension of every tensor up to target_len.
        padded = []
        for tensor in tensors:
            missing = target_len - tensor.shape[-1]
            padded.append(torch.nn.functional.pad(tensor, (0, missing)))
        return torch.stack(padded)

    padded_inputs = _pad_and_stack(inputs, max(input_lengths))
    padded_outputs = _pad_and_stack(outputs, max(output_lengths))

    return padded_inputs, padded_outputs, input_lengths, output_lengths

def calculate_classification_metrics_optimized(tp, fp, tn, fn, all_predictions_probs_flat, all_targets_flat):
    """Derive F1, precision, recall and ROC AUC from accumulated statistics.

    Precision, recall and F1 come from the confusion-matrix counts (``tn`` is
    accepted but not needed for them); a small epsilon keeps the divisions
    defined when a denominator would be zero.  ROC AUC is computed from the
    supplied probabilities/targets and is NaN when the targets do not contain
    at least two distinct values or no probabilities were given.

    Returns:
        ``(f1, precision, recall, roc_auc)``.
    """
    eps = 1e-8
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * (precision * recall) / (precision + recall + eps)

    # AUC is undefined unless both classes are present.
    roc_auc = float('nan')
    if len(np.unique(all_targets_flat)) > 1 and len(all_predictions_probs_flat) > 0:
        roc_auc = roc_auc_score(all_targets_flat, all_predictions_probs_flat)

    return f1, precision, recall, roc_auc

def evaluate_model(model, dataloader, criterion, num_output_channels, device, threshold=0.5, max_auc_samples=500000):
    """Run ``model`` over ``dataloader`` without gradients and collect metrics.

    Args:
        model: Called as ``model(inputs, original_lengths=input_lengths)`` and
            expected to return logits with time as the last dimension.
        dataloader: Yields ``(inputs, targets, input_lengths, output_lengths)``.
        criterion: Element-wise loss applied to the valid logits/targets; its
            result is averaged per batch.
        num_output_channels: Unused by this function.
        device: Device the batch tensors are moved to.
        threshold: Probability cut-off for a positive prediction.
        max_auc_samples: Soft cap on the number of elements kept for ROC AUC.

    Returns:
        ``(avg_loss, f1, precision, recall, roc_auc)``; ``avg_loss`` is NaN when
        no valid element was seen.
    """
    model.eval()

    loss_sum = 0.0
    n_valid_total = 0
    tp_count = 0
    fp_count = 0
    tn_count = 0
    fn_count = 0

    auc_probs = []
    auc_targets = []

    with torch.no_grad():
        for inputs, targets, input_lengths, output_lengths in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            targets = targets.to(device).float()

            logits = model(inputs, original_lengths=input_lengths)

            # Mask over the time axis; presumably True marks padded positions,
            # since its inverse is used to select the valid elements below.
            out_len = logits.size(-1)
            clipped_lengths = [min(l, out_len) for l in output_lengths]
            time_mask = create_mask(clipped_lengths, out_len, device)

            # Add a channel axis (C=1 here) and flatten everything consistently.
            keep = (~time_mask.unsqueeze(1)).flatten()
            kept_logits = logits.flatten()[keep]
            kept_targets = targets.flatten()[keep]

            n_valid = kept_targets.numel()
            if n_valid == 0:
                continue

            # Element-weighted running loss.
            batch_loss = criterion(kept_logits, kept_targets).mean().item()
            loss_sum += batch_loss * n_valid
            n_valid_total += n_valid

            probs = torch.sigmoid(kept_logits)
            predicted_pos = probs >= threshold
            predicted_neg = ~predicted_pos
            target_pos = kept_targets == 1
            target_neg = kept_targets == 0

            tp_count += (predicted_pos & target_pos).sum().item()
            fp_count += (predicted_pos & target_neg).sum().item()
            tn_count += (predicted_neg & target_neg).sum().item()
            fn_count += (predicted_neg & target_pos).sum().item()

            # AUC buffer: take whole batches until the cap is reached, then add
            # a random subset of each batch and keep only the newest entries.
            if len(auc_probs) < max_auc_samples:
                auc_probs.extend(probs.cpu().tolist())
                auc_targets.extend(kept_targets.cpu().tolist())
            else:
                num_to_sample = min(n_valid, max_auc_samples // 10)
                if num_to_sample > 0:
                    picked = torch.randperm(n_valid)[:num_to_sample]
                    auc_probs.extend(probs[picked].cpu().tolist())
                    auc_targets.extend(kept_targets[picked].cpu().tolist())

                    if len(auc_probs) > max_auc_samples:
                        auc_probs = auc_probs[-max_auc_samples:]
                        auc_targets = auc_targets[-max_auc_samples:]

    if n_valid_total > 0:
        avg_loss = loss_sum / (n_valid_total + 1e-8)
    else:
        avg_loss = float('nan')

    f1, precision, recall, roc_auc = calculate_classification_metrics_optimized(
        tp_count, fp_count, tn_count, fn_count,
        np.array(auc_probs), np.array(auc_targets)
    )

    return avg_loss, f1, precision, recall, roc_auc

def plot_metrics(train_losses, val_losses, train_f1_scores, val_f1_scores,
                 train_precisions, val_precisions, train_recalls, val_recalls,
                 train_roc_aucs, val_roc_aucs,
                 num_epochs, save_dir="."):
    """Plot training/validation metrics per epoch and save them as one PNG.

    Draws five panels (loss, F1, precision, recall, ROC AUC) in a 2x3 grid and
    writes ``classification_metrics_combined.png`` into ``save_dir``.

    Training series are plotted against epochs ``1..num_epochs``.  Validation
    series may contain NaN placeholders (validation skipped or metric
    undefined); those points are omitted, and every remaining value is drawn at
    the epoch it was recorded for rather than being shifted to the left.

    Args:
        train_*/val_*: Per-epoch metric sequences.
        num_epochs: Number of epochs on the x axis of the training curves.
        save_dir: Directory the image is written to.
    """
    epochs = range(1, num_epochs + 1)

    def _finite_points(values):
        # Pair each non-NaN value with its 1-based epoch number.
        kept = [(epoch, value) for epoch, value in enumerate(values, start=1)
                if not math.isnan(value)]
        return [epoch for epoch, _ in kept], [value for _, value in kept]

    # (title, y label, train series, val series, train plot kwargs, val plot kwargs)
    panels = [
        ('Binary Cross Entropy Loss', 'Loss', train_losses, val_losses,
         {'label': 'Training Loss'},
         {'label': 'Validation Loss'}),
        ('F1-score', 'F1-score', train_f1_scores, val_f1_scores,
         {'label': 'Training F1-score', 'color': 'orange'},
         {'label': 'Validation F1-score', 'color': 'red'}),
        ('Precision', 'Precision', train_precisions, val_precisions,
         {'label': 'Training Precision', 'color': 'green', 'linestyle': '--'},
         {'label': 'Validation Precision', 'color': 'green'}),
        ('Recall', 'Recall', train_recalls, val_recalls,
         {'label': 'Training Recall', 'color': 'purple', 'linestyle': '--'},
         {'label': 'Validation Recall', 'color': 'purple'}),
        ('ROC AUC', 'ROC AUC Score', train_roc_aucs, val_roc_aucs,
         {'label': 'Training ROC AUC', 'color': 'brown', 'linestyle': '--'},
         {'label': 'Validation ROC AUC', 'color': 'brown'}),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    try:
        fig.suptitle('Training and Validation Metrics per Epoch', fontsize=16)

        # zip stops after the five panels; the sixth axis is blanked below.
        for ax, (title, ylabel, train_series, val_series, train_style, val_style) in zip(axes.flat, panels):
            ax.plot(epochs, train_series, **train_style)
            val_epochs, val_values = _finite_points(val_series)
            ax.plot(val_epochs, val_values, **val_style)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True)

        axes[1, 2].axis('off')

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.savefig(os.path.join(save_dir, 'classification_metrics_combined.png'))
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


In [None]:
# --- Dynamic STD Calculation and Data Splitting ---

# Locations of the signal arrays and of their label arrays.
MEG_ROOT = "./libribrain/data/meg_data_npy"
LABEL_ROOT = "./libribrain/data/labels_npy"
# Fraction of the files held out for validation (0.0 -> train on everything).
VALIDATION_SPLIT = 0.2
# Fixed number of channels.
NUM_CHANNELS = 48

# Collect the .npy files of both directories in sorted order; the two sorted
# lists are paired position by position below.
input_files = sorted(
    os.path.join(MEG_ROOT, name) for name in os.listdir(MEG_ROOT) if name.endswith('.npy')
)
output_files = sorted(
    os.path.join(LABEL_ROOT, name) for name in os.listdir(LABEL_ROOT) if name.endswith('.npy')
)

assert len(input_files) == len(output_files), "Input and output file counts mismatch in directories"

# Shuffle the pairs with a fixed seed so the split is the same on every run.
file_pairs = list(zip(input_files, output_files))
random.seed(42)
random.shuffle(file_pairs)

# Split sizes (num_val can never be negative).
num_total = len(file_pairs)
num_val = max(0, int(num_total * VALIDATION_SPLIT))
num_train = num_total - num_val

# The first num_val shuffled pairs are held out, the rest is used for training.
val_file_pairs, train_file_pairs = file_pairs[:num_val], file_pairs[num_val:]

# Turn the lists of pairs back into parallel tuples of paths.
train_input_files, train_output_files = zip(*train_file_pairs)
val_input_files, val_output_files = zip(*val_file_pairs) if val_file_pairs else (tuple(), tuple())

print(f"Total files: {num_total}, Train files: {num_train}, Validation files: {num_val}")


In [None]:
# --- Step 1: Dynamically Calculate Channel Standard Deviations (STD) on Training Set ---
print("--- Calculating Channel Standard Deviations on Training Set ---")
# Running per-channel sums accumulated over all training files.
channel_sum = np.zeros(NUM_CHANNELS, dtype=np.float64)
channel_sq_sum = np.zeros(NUM_CHANNELS, dtype=np.float64)
channel_count = np.zeros(NUM_CHANNELS, dtype=np.int64)

for input_file in tqdm(train_input_files, desc="Processing MEG files for STD calculation"):
    # Accumulate in float64 to limit rounding error in the sums.
    data = np.load(input_file, mmap_mode='r').astype(np.float64)

    if data.shape[0] != NUM_CHANNELS:
        if data.shape[1] != NUM_CHANNELS:
            print(f"Warning: File {input_file} has unexpected shape {data.shape}. Skipping.")
            continue
        # Channels are on the second axis: bring them to the first one.
        data = data.T

    T = data.shape[1]

    channel_sum += data.sum(axis=1)
    channel_sq_sum += np.square(data).sum(axis=1)
    channel_count += T

# Var[x] = E[x^2] - E[x]^2; the count is guarded against zero and the variance
# clamped at zero before taking the square root.
safe_channel_count = np.where(channel_count > 0, channel_count, 1)
mean = channel_sum / safe_channel_count
variance = channel_sq_sum / safe_channel_count - np.square(mean)
std = np.sqrt(np.maximum(variance, 0))

# Column tensor of shape (C, 1) for use in the dataset.
channel_stds = torch.from_numpy(std).float().unsqueeze(1)
print("--- STD Calculation Complete ---")




In [None]:
# --- Step 2: Dynamically Calculate Positive Weight (pos_weight) for Training Labels ---
print("--- Calculating Positive Weight for Training Set Labels ---")
total_positive_samples = 0
total_negative_samples = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for output_file in tqdm(train_output_files, desc="Processing Label files for pos_weight"):
    label_data = np.load(output_file, mmap_mode='r').astype(np.float32)

    # For 2-D label arrays only the first row is counted.
    if label_data.ndim == 2:
        if label_data.shape[0] > 0:
            label_data = label_data[0, :]
        else:
            label_data = label_data.flatten()

    # Count exact 1s and 0s; any other value is ignored.
    total_positive_samples += (label_data == 1).sum()
    total_negative_samples += (label_data == 0).sum()

if total_positive_samples <= 0:
    # No positives at all: fall back to a neutral weight.
    pos_weight_value = 1.0
    print("Warning: No positive samples found in the training set labels! pos_weight set to 1.0.")
else:
    # Ratio of negatives to positives.
    pos_weight_value = total_negative_samples / total_positive_samples

pos_weight = torch.tensor([pos_weight_value]).to(device)
print(f"Total Positive Samples in Train Set: {total_positive_samples}")
print(f"Total Negative Samples in Train Set: {total_negative_samples}")
print(f"Calculated pos_weight_value: {pos_weight_value:.6f}")
print("--- Pos Weight Calculation Complete ---")




In [None]:
# --- Step 3: Create Dataset and Dataloader ---

# Training windows; the additive noise is 10% of each channel's standard deviation.
dataset = ClassificationDataset(
    list(train_input_files),
    list(train_output_files),
    min_window=1000,
    max_window=5000,
    step_size=1200,
    channel_wise_noise_stds=(channel_stds * 0.1),
)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)

# The validation loader only exists when files were held out.
if num_val > 0:
    validation_dataset = ValidationDataset(list(val_input_files), list(val_output_files))
    validation_dataloader = DataLoader(
        validation_dataset,
        batch_size=16,
        shuffle=False,
        collate_fn=collate_fn,
    )
    print("Validation Dataloader created.")
else:
    validation_dataloader = None
    print("Validation split is 0, skipping Validation Dataloader creation.")

# --- Model and Training Setup ---

num_epochs = 100
best_roc_auc = -1.0

# Keyword arguments forwarded to SEANetTransformerClassifier.
model_params = dict(
    input_channels=48,
    sampling_rate=250,
    encoder_dimension=16,
    encoder_n_filters=8,
    encoder_ratios=[5, 5, 2],
    transformer_n_heads=4,
    transformer_n_layers=1,
    transformer_dim_feedforward=64,
    transformer_dropout=0.1,
    decoder_out_channels=1,
    activation='GELU',
    activation_params={},
    norm='InstanceNorm1d',
    norm_params={},
    n_residual_layers=1,
    kernel_size=7,
    last_kernel_size=7,
    residual_kernel_size=3,
    dilation_base=2,
    compress=2,
    true_skip=False,
)
model = SEANetTransformerClassifier(**model_params).to(device).float()

# Element-wise loss (reduction='none') weighted by the computed pos_weight.
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
# mode='max': the scheduler is stepped with a score where higher is better.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=10, factor=0.5, min_lr=1e-8)
num_channels = model_params['decoder_out_channels']

# Per-epoch metric histories.
train_losses, train_f1_scores, train_precisions, train_recalls, train_roc_aucs = (
    [] for _ in range(5)
)
val_losses, val_f1_scores, val_precisions, val_recalls, val_roc_aucs = (
    [] for _ in range(5)
)

classification_threshold = 0.5
max_auc_samples = 50000

# --- Training Loop ---
# Each epoch: one optimisation pass over `dataloader` while accumulating loss,
# confusion-matrix counts and a capped sample of probabilities for ROC AUC,
# followed by an optional validation pass that drives the LR scheduler and the
# best-model checkpoint.

for epoch in range(num_epochs):
    # --- Training Phase ---
    model.train()
    # Loss summed over valid elements, and the number of those elements.
    running_loss_sum = 0.0
    total_train_valid_elements = 0
    
    train_true_positives = 0
    train_false_positives = 0
    train_true_negatives = 0
    train_false_negatives = 0

    # Probabilities/targets kept for the (sampled) training ROC AUC.
    all_train_valid_predictions_probs_for_auc = []
    all_train_valid_targets_for_auc = []
    
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs} (Train)")
    
    for batch in progress_bar:
        inputs, targets, input_lengths, output_lengths = batch
        inputs, targets = inputs.to(device), targets.to(device)
        targets = targets.float()
        
        optimizer.zero_grad()
        predictions_logits = model(inputs, original_lengths=input_lengths)
        
        # Time mask sized to the model output; lengths are clipped to it.
        # NOTE(review): presumably True marks padded positions, since the mask
        # is inverted below to select valid elements - confirm in create_mask.
        max_output_len = predictions_logits.size(-1)
        output_mask_batch_time = create_mask(
            [min(l, max_output_len) for l in output_lengths], 
            max_output_len,
            device
        )

        # Insert a channel axis so the mask lines up with the logits.
        expanded_output_mask = output_mask_batch_time.unsqueeze(1) 

        predictions_logits_flat = predictions_logits.flatten()
        targets_flat = targets.flatten()

        valid_elements_mask_flat = (~expanded_output_mask).flatten()

        # NOTE(review): this indexing requires logits and targets to have the
        # same number of elements - verify the model preserves the time length.
        valid_predictions_logits = predictions_logits_flat[valid_elements_mask_flat]
        valid_targets = targets_flat[valid_elements_mask_flat]

        if valid_targets.numel() > 0:
            # The criterion is element-wise (reduction='none'); average over the valid elements.
            loss = criterion(valid_predictions_logits, valid_targets).mean()
            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            # Weight the batch mean by its element count for an element-wise epoch average.
            running_loss_sum += loss.item() * valid_targets.numel()
            total_train_valid_elements += valid_targets.numel()
            
            with torch.no_grad():
                valid_predictions_probs = torch.sigmoid(valid_predictions_logits)
                valid_predictions_binary = (valid_predictions_probs >= classification_threshold).int()

                train_true_positives += ((valid_predictions_binary == 1) & (valid_targets == 1)).sum().item()
                train_false_positives += ((valid_predictions_binary == 1) & (valid_targets == 0)).sum().item()
                train_true_negatives += ((valid_predictions_binary == 0) & (valid_targets == 0)).sum().item()
                train_false_negatives += ((valid_predictions_binary == 0) & (valid_targets == 1)).sum().item()

                # Only batches seen before the cap is reached contribute to the training AUC.
                if len(all_train_valid_predictions_probs_for_auc) < max_auc_samples:
                    all_train_valid_predictions_probs_for_auc.extend(valid_predictions_probs.cpu().tolist())
                    all_train_valid_targets_for_auc.extend(valid_targets.cpu().tolist())
        else:
            # No valid element in this batch: no update; zero loss for the progress bar.
            loss = torch.tensor(0.0).to(device)
            
        avg_loss_display = running_loss_sum / (total_train_valid_elements + 1e-8)
        progress_bar.set_postfix(
            loss=loss.item(),
            avg_loss=avg_loss_display
        )
            
    epoch_avg_loss = running_loss_sum / (total_train_valid_elements + 1e-8) if total_train_valid_elements > 0 else float('nan')

    train_f1, train_precision, train_recall, train_roc_auc = calculate_classification_metrics_optimized(
        train_true_positives, train_false_positives, train_true_negatives, train_false_negatives,
        np.array(all_train_valid_predictions_probs_for_auc),
        np.array(all_train_valid_targets_for_auc)
    )
    
    train_losses.append(epoch_avg_loss)
    train_f1_scores.append(train_f1)
    train_precisions.append(train_precision)
    train_recalls.append(train_recall)
    train_roc_aucs.append(train_roc_auc)

    print(f"\n--- Epoch {epoch+1}/{num_epochs} Training Summary ---")
    print(f"Train Average Loss (BCE): {epoch_avg_loss:.6f}")
    print(f"Train F1-score: {train_f1:.6f}")
    print(f"Train Precision: {train_precision:.6f}")
    print(f"Train Recall: {train_recall:.6f}")
    print(f"Train ROC AUC (Sampled): {train_roc_auc:.6f}")

    # --- Validation Phase (Conditional) ---
    if validation_dataloader is not None:
        val_loss, val_f1, val_precision, val_recall, val_roc_auc = evaluate_model(
            model, validation_dataloader, criterion, num_channels, device, classification_threshold, max_auc_samples=500000)
        
        val_losses.append(val_loss)
        val_f1_scores.append(val_f1)
        val_precisions.append(val_precision)
        val_recalls.append(val_recall)
        val_roc_aucs.append(val_roc_auc)


        print(f"\n--- Epoch {epoch+1}/{num_epochs} Validation Summary ---")
        print(f"Validation BCE Loss: {val_loss:.6f}")
        print(f"Validation F1-score: {val_f1:.6f}")
        print(f"Validation Precision: {val_precision:.6f}")
        print(f"Validation Recall: {val_recall:.6f}")
        print(f"Validation ROC AUC: {val_roc_auc:.6f}")
        print(f"Current Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        
        # Step scheduler based on validation metric
        scheduler.step(val_roc_auc) 

        # Save best model
        # NOTE(review): a NaN val_roc_auc never compares greater than
        # best_roc_auc, so such epochs do not produce a checkpoint.
        if val_roc_auc > best_roc_auc:
            best_roc_auc = val_roc_auc
            os.makedirs("./libribrain", exist_ok=True)
            torch.save(model.state_dict(), f"./libribrain/speech_detection.pth")
            print(f"New best model saved with validation ROC AUC: {best_roc_auc:.6f}")
            
    else:
        # If no validation data, step scheduler based on training F1 (or just skip/use loss)
        print("\nSkipping validation and scheduler step based on validation (VALIDATION_SPLIT = 0).")
        scheduler.step(train_f1) # Using training F1 to still allow LR decay
        
        # Append NaN values to keep metric lists the same length as epochs for plotting
        val_losses.append(float('nan'))
        val_f1_scores.append(float('nan'))
        val_precisions.append(float('nan'))
        val_recalls.append(float('nan'))
        val_roc_aucs.append(float('nan'))
        
print("\nTraining complete!")
# Create the output directory BEFORE writing into it: when no best-model
# checkpoint was saved during training, nothing else has created it yet.
os.makedirs("./libribrain", exist_ok=True)
# Persist the weights of the last epoch (independent of the best-model checkpoint).
torch.save(model.state_dict(), "./libribrain/speech_detection_final.pth")
plot_metrics(train_losses, val_losses, train_f1_scores, val_f1_scores,
             train_precisions, val_precisions, train_recalls, val_recalls,
             train_roc_aucs, val_roc_aucs,
             num_epochs, save_dir="./libribrain/")
print("Combined classification metrics plot saved to ./libribrain/classification_metrics_combined.png")