# In [None]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import torch.nn.functional as F
import random

# --- Global settings (adjust as needed) ---
DATA_PATH = "your_data_folder_path"  # <--- change this to your data folder!
SEQUENCE_LENGTH = 512  # samples per window
HOP_SIZE = 256         # stride between consecutive windows (windows overlap when HOP_SIZE < SEQUENCE_LENGTH)
SEED = 42
N_SPLITS = 5           # number of folds in the classifier cross-validation

# Autoencoder hyper-parameters (one single training run)
AE_BATCH_SIZE = 64
ENCODING_DIM_AE = 64
AE_DROPOUT_RATE = 0.1
AE_LEARNING_RATE = 1e-3
AE_WEIGHT_DECAY = 1e-4
AE_TRAIN_EPOCHS_SINGLE = 75           # epochs for the single AE run; tune as needed
AE_PATIENCE_SINGLE = 15               # early-stopping patience for the single AE run
AE_INPUT_NOISE_STD = 0.05
AE_PLOT_RECONSTRUCTION_SINGLE = True  # plot reconstructions from the single AE

# Classifier hyper-parameters (used inside every K-Fold split)
CLASSIFIER_BATCH_SIZE = 64
CLASSIFIER_EPOCHS = 100               # tune depending on how quickly training converges
CLASSIFIER_PATIENCE = 25
CLASSIFIER_LR = 1e-3
CLASSIFIER_WEIGHT_DECAY = 1e-3
CLASSIFIER_NUM_AUG_PER_SAMPLE = 2

# Device selection (GPU when available) and seeding of every RNG in use.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
    print(f"CUDA Seed Set. Deterministic: {torch.backends.cudnn.deterministic}, Benchmark: {torch.backends.cudnn.benchmark}")


# --- חלק 1: טעינת נתונים והכנתם ---
def load_and_prepare_data(data_path_folder):
    """Load the raw single-column CSV recordings and label every sample.

    Each known file in ``data_path_folder`` is read with ``pandas.read_csv``
    (no header) and only its first column is used. Files are appended in the
    order of ``file_mapping``. Missing, empty or unreadable files are reported
    and skipped (best effort), so one bad file does not abort the whole load.

    Args:
        data_path_folder: Directory containing the CSV files.

    Returns:
        ``(data, labels)``: a 1-D float32 array with all samples and a 1-D
        int64 array of the same length holding each sample's encoded label
        (quiet=0, vehicle=1, human=2). Both are empty if nothing was loaded.
    """
    file_mapping = {
        'car_nothing(AVI).csv': 'quiet',
        'carnew(AVI).csv': 'vehicle',
        'human_nothing(AVI).csv': 'quiet',
        'human(AVI).csv': 'human'
    }
    label_encoding = {'quiet': 0, 'vehicle': 1, 'human': 2}
    # One array per file, concatenated once at the end. Extending a Python list
    # sample-by-sample would allocate a Python float object per sample.
    data_chunks = []
    label_chunks = []

    if not os.path.exists(data_path_folder):
        print(f"Data folder {data_path_folder} not found. Please ensure the path is correct.")
        return np.array([]), np.array([])

    print(f"Loading data from: {data_path_folder}")
    for filename, activity_type in file_mapping.items():
        filepath = os.path.join(data_path_folder, filename)
        if not os.path.exists(filepath):
            print(f"Warning: File not found at {filepath}. Skipping.")
            continue
        try:
            df = pd.read_csv(filepath, header=None)
            if df.empty or df.shape[1] == 0:
                print(f"  Warning: File {filename} is empty or has no data columns. Skipping.")
                continue
            data_values = df.iloc[:, 0].to_numpy().astype(np.float32)  # float32 for PyTorch
        except Exception as e:  # best effort: report the bad file and keep loading the others
            print(f"  Error reading {filename}: {e}")
            continue
        data_chunks.append(data_values)
        label_chunks.append(np.full(len(data_values), label_encoding[activity_type], dtype=np.int64))
        print(f"  Successfully loaded {len(data_values)} points from {filename} as '{activity_type}'")

    if not data_chunks:
        print("No data was loaded. Check DATA_PATH and file contents.")
        return np.array([]), np.array([])

    all_data_np = np.concatenate(data_chunks)
    all_labels_np = np.concatenate(label_chunks)
    print(f"Total raw data points loaded: {len(all_data_np)}")
    return all_data_np, all_labels_np

def create_sequences_with_overlap(data, labels, sequence_length, hop_size):
    """Cut a labelled 1-D signal into fixed-length, possibly overlapping windows.

    Samples are grouped by label (in ``np.unique`` order). Each group is treated
    as one contiguous signal and sliced into windows of ``sequence_length``
    samples starting every ``hop_size`` samples. Any tail that does not fill a
    whole window is dropped. Labels with fewer than ``sequence_length`` samples
    are skipped.

    Returns:
        ``(windows, window_labels)``: a float32 array of shape
        ``(n_windows, sequence_length)`` and the matching label array, or two
        empty arrays when no window could be built.
    """
    if len(data) == 0:
        print("Cannot create sequences from empty raw data.")
        return np.array([]), np.array([])

    print(f"Creating sequences with length {sequence_length} and hop {hop_size}...")
    windows, window_labels = [], []
    for label_val in np.unique(labels):
        class_signal = data[labels == label_val]
        n_points = len(class_signal)
        if n_points < sequence_length:
            print(f"  Not enough data for label {label_val} to create a sequence of length {sequence_length}. Has {n_points} points. Skipping.")
            continue

        # Number of full windows that fit; equals n_points // sequence_length when hop == length.
        n_windows = (n_points - sequence_length) // hop_size + 1
        print(f"  For label {label_val}, creating {n_windows} sequences.")
        windows.extend(
            class_signal[k * hop_size:k * hop_size + sequence_length] for k in range(n_windows)
        )
        window_labels.extend([label_val] * n_windows)

    if not windows:
        print("No sequences were created. Check data, sequence_length, and hop_size.")
        return np.array([]), np.array([])

    return np.array(windows, dtype=np.float32), np.array(window_labels)

# --- פונקציות עזר ומודלים ---
def get_padding_for_dilation(kernel_size, dilation):
    """Return the per-side padding that keeps a stride-1 dilated 1-D convolution length-preserving.

    A dilated kernel spans ``dilation * (kernel_size - 1) + 1`` samples, so
    padding half of ``dilation * (kernel_size - 1)`` on each side keeps the
    output length equal to the input length whenever that product is even
    (always the case for odd ``kernel_size``).
    """
    reach = dilation * (kernel_size - 1)
    return reach // 2

def add_noise_to_batch(batch_x, noise_std, device, is_training):
    """Return ``batch_x`` plus zero-mean Gaussian noise of std ``noise_std``.

    Noise is only added when ``noise_std > 0`` and ``is_training`` is truthy.
    Otherwise the very same tensor object is returned. The noise tensor is moved
    to ``device`` before the (out-of-place) addition.
    """
    if not (noise_std > 0 and is_training):
        return batch_x
    gaussian = torch.randn_like(batch_x).mul(noise_std).to(device)
    return batch_x + gaussian

def plot_ae_reconstructions(model, dataloader, device, num_samples=3, epoch_num=None, current_loss=None, sequence_length_param=SEQUENCE_LENGTH, plot_now=True, title_prefix=""):
    """Plot original vs. reconstructed signals for the first items yielded by ``dataloader``.

    Args:
        model: Autoencoder whose output has the same shape as its input
            (assumed (N, 1, L) here; TODO confirm for other models).
        dataloader: Yields ``(inputs, _)`` pairs; only the inputs are used.
        device: Device the inputs are moved to before calling ``model``.
        num_samples: Number of examples to plot (one subplot row each).
        epoch_num, current_loss: Optional values appended to the figure title.
        sequence_length_param: Upper x-axis limit of every subplot.
        plot_now: When False the function does nothing.
        title_prefix: Optional text prepended to the figure title.

    The model is switched to eval mode for inference and afterwards restored to
    whatever train/eval mode it was in before the call.
    """
    if not plot_now or num_samples == 0:
        return

    was_training = model.training
    model.eval()
    samples_done = 0
    fig, axes = plt.subplots(num_samples, 1, figsize=(12, 2.5 * num_samples), squeeze=False)

    try:
        with torch.no_grad():
            for batch_x_val, _ in dataloader:
                batch_x_val = batch_x_val.to(device)
                reconstructed_x_val = model(batch_x_val)

                for original_t, reconstructed_t in zip(batch_x_val, reconstructed_x_val):
                    if samples_done >= num_samples:
                        break
                    original_signal = original_t.cpu().squeeze().numpy()
                    reconstructed_signal = reconstructed_t.cpu().squeeze().numpy()
                    mse_sample = np.mean((original_signal - reconstructed_signal) ** 2)
                    ax = axes[samples_done, 0]
                    ax.plot(original_signal, label='Original Signal', color='blue', alpha=0.7)
                    ax.plot(reconstructed_signal, label='Reconstructed Signal', color='red', linestyle='--')
                    ax.legend()
                    ax.set_title(f"Example {samples_done+1} (Sample MSE: {mse_sample:.4f})")
                    ax.set_xlabel("Time Points")
                    ax.set_ylabel("Amplitude")
                    ax.grid(True, linestyle=':', alpha=0.7)
                    ax.set_xlim(0, sequence_length_param)
                    samples_done += 1
                if samples_done >= num_samples:
                    break
    finally:
        # Don't leak eval mode (dropout disabled) into a caller that keeps training afterwards.
        model.train(was_training)

    title_parts = [title_prefix, "Autoencoder Reconstructions"]
    if epoch_num is not None:
        title_parts.append(f"Epoch {epoch_num}")
    if current_loss is not None:
        title_parts.append(f"Val Loss: {current_loss:.6f}")

    fig.suptitle(" - ".join(filter(None, title_parts)), fontsize=16)  # filter(None, ...) drops empty strings
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

class DilatedConvEncoderA(nn.Module):
    """Four-stage dilated 1-D convolutional encoder that also returns its skip features.

    Each stage is Conv1d(kernel 5) -> GroupNorm -> ReLU -> Dropout. Channel widths
    are 32, 64, 128, 256 and dilations are 1, 2, 4, 8. Padding from
    ``get_padding_for_dilation`` keeps every stage's output length equal to its
    input length. The last stage is average-pooled over time and projected to
    ``encoding_dim``.

    forward(x) expects a Conv1d-style (N, input_channels, L) tensor and returns
    ``(encoded, (s1, s2, s3, s4))``. ``encoded`` has shape (N, encoding_dim) and
    ``s1``..``s4`` are the per-stage feature maps (N, 32/64/128/256, L).
    """

    def __init__(self, input_channels=1, encoding_dim=ENCODING_DIM_AE, dropout_rate=AE_DROPOUT_RATE):
        super().__init__()
        width_in = input_channels
        # (stage index, output channels, dilation, GroupNorm groups)
        for stage, width_out, dil, groups in ((1, 32, 1, 8), (2, 64, 2, 8), (3, 128, 4, 16), (4, 256, 8, 16)):
            setattr(self, f"conv{stage}", nn.Conv1d(width_in, width_out, kernel_size=5, dilation=dil, padding=get_padding_for_dilation(5, dil)))
            setattr(self, f"norm{stage}", nn.GroupNorm(groups, width_out))
            setattr(self, f"relu{stage}", nn.ReLU())
            setattr(self, f"drop{stage}", nn.Dropout(dropout_rate))
            width_in = width_out
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
        self.fc_encoded = nn.Linear(width_in, encoding_dim)

    def forward(self, x):
        features = []
        h = x
        for stage in range(1, 5):
            conv, norm, act, drop = (getattr(self, f"{kind}{stage}") for kind in ("conv", "norm", "relu", "drop"))
            h = drop(act(norm(conv(h))))
            features.append(h)
        encoded = self.fc_encoded(self.adaptive_pool(h).squeeze(-1))
        return encoded, tuple(features)

class DilatedConvDecoderA(nn.Module):
    """Decoder mirroring DilatedConvEncoderA, with U-Net-style skip concatenations.

    The latent vector is projected to 256 channels and repeated along time to
    ``sequence_length_param`` samples by nearest-neighbour upsampling. Three
    ConvTranspose1d -> GroupNorm -> ReLU -> Dropout stages (dilations 8, 4, 2;
    outputs 128, 64, 32 channels) and a final ConvTranspose1d (dilation 1) follow.
    Before each transposed convolution, the running feature map is concatenated
    on the channel axis with the matching encoder skip (s4, s3, s2, s1).

    NOTE(review): because the full-resolution skips bypass the latent vector, the
    reconstruction may not depend much on ``encoding_dim``; confirm this is intended.
    """

    def __init__(self, output_channels=1, encoding_dim=ENCODING_DIM_AE, dropout_rate=AE_DROPOUT_RATE, sequence_length_param=SEQUENCE_LENGTH):
        super().__init__()
        self.fc_decoded = nn.Linear(encoding_dim, 256 * 1)
        self.upsample_initial = nn.Upsample(size=sequence_length_param, mode='nearest')
        # (stage index, incoming channels == skip channels, output channels, dilation, GroupNorm groups)
        for stage, width_in, width_out, dil, groups in ((4, 256, 128, 8, 16), (3, 128, 64, 4, 8), (2, 64, 32, 2, 8)):
            setattr(self, f"conv_t{stage}", nn.ConvTranspose1d(width_in + width_in, width_out, kernel_size=5, dilation=dil, padding=get_padding_for_dilation(5, dil)))
            setattr(self, f"norm_t{stage}", nn.GroupNorm(groups, width_out))
            setattr(self, f"relu_t{stage}", nn.ReLU())
            setattr(self, f"drop_t{stage}", nn.Dropout(dropout_rate))
        self.conv_t1 = nn.ConvTranspose1d(32 + 32, output_channels, kernel_size=5, dilation=1, padding=get_padding_for_dilation(5, 1))

    def forward(self, x, skips):
        s1, s2, s3, s4 = skips
        h = self.upsample_initial(self.fc_decoded(x).unsqueeze(-1))
        for stage, skip in ((4, s4), (3, s3), (2, s2)):
            conv, norm, act, drop = (getattr(self, f"{kind}_t{stage}") for kind in ("conv", "norm", "relu", "drop"))
            h = drop(act(norm(conv(torch.cat([h, skip], dim=1)))))
        return self.conv_t1(torch.cat([h, s1], dim=1))

class DilatedAutoencoderA(nn.Module):
    """DilatedConvEncoderA followed by DilatedConvDecoderA.

    forward(x) encodes ``x`` and decodes the latent vector together with the
    encoder's skip features, returning the reconstruction.
    """

    def __init__(self, input_channels=1, output_channels=1, encoding_dim=ENCODING_DIM_AE, dropout_rate=AE_DROPOUT_RATE, sequence_length_param=SEQUENCE_LENGTH):
        super().__init__()
        self.encoder = DilatedConvEncoderA(
            input_channels=input_channels,
            encoding_dim=encoding_dim,
            dropout_rate=dropout_rate,
        )
        self.decoder = DilatedConvDecoderA(
            output_channels=output_channels,
            encoding_dim=encoding_dim,
            dropout_rate=dropout_rate,
            sequence_length_param=sequence_length_param,
        )

    def forward(self, x):
        latent, skip_features = self.encoder(x)
        return self.decoder(latent, skip_features)

def advanced_seismic_augmentation(signal, augment_prob=0.7): # signal is expected to be 1D tensor here
    if random.random() > augment_prob: return signal
    augmented = signal.clone()
    # 1. Gaussian noise
    if random.random() < 0.45: augmented += torch.randn_like(signal) * random.uniform(0.01, 0.05)
    # 2. Time shifting
    if random.random() < 0.25:
        max_shift = int(signal.shape[-1] * 0.05)
        if max_shift > 0 : # only shift if possible
            shift = random.randint(-max_shift, max_shift)
            augmented = torch.roll(augmented, shift, dims=-1)
    # 3. Amplitude scaling
    if random.random() < 0.25: augmented *= random.uniform(0.8, 1.2)
    # 4. Time stretching (more careful with dimensions)
    if random.random() < 0.2 and signal.shape[-1] > 1 : # Ensure signal is not too short
        stretch_factor = random.uniform(0.95, 1.05)
        length = signal.shape[-1]; new_length = int(length * stretch_factor)
        if new_length < 1: new_length = 1 # ensure new_length is at least 1

        # Add batch and channel dim for interpolate, then remove
        stretched = F.interpolate(signal.unsqueeze(0).unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0).squeeze(0)
        
        if stretched.shape[-1] != length: # Resize back to original length
            if stretched.shape[-1] < 1: # if somehow it became empty
                 stretched = torch.zeros_like(signal) # fallback or handle error
            else:
                stretched = F.interpolate(stretched.unsqueeze(0).unsqueeze(0), size=length, mode='linear', align_corners=False).squeeze(0).squeeze(0)
        augmented = stretched
    return augmented


class SingleStrongClassifier(nn.Module):
    """MLP head on top of a pre-trained encoder that is fine-tuned jointly.

    The encoder is called on the input and the first element of its output pair
    (the embedding) is classified; the second element (skip features) is
    discarded. Every encoder parameter is made trainable. The head is four
    [BatchNorm -> Dropout -> Linear -> GELU] stages
    (encoding_dim -> 512 -> 256 -> 128 -> 64) followed by
    BatchNorm -> Dropout(0.1) -> Linear(64, num_classes), which yields logits.
    """

    def __init__(self, pretrained_encoder, encoding_dim=ENCODING_DIM_AE, num_classes=3):
        super().__init__()
        self.encoder = pretrained_encoder
        # Gentle fine-tuning: the encoder weights keep training together with the head.
        self.encoder.requires_grad_(True)
        head = []
        width = encoding_dim
        for next_width, drop_p in ((512, 0.2), (256, 0.4), (128, 0.3), (64, 0.2)):
            head += [nn.BatchNorm1d(width), nn.Dropout(drop_p), nn.Linear(width, next_width), nn.GELU()]
            width = next_width
        head += [nn.BatchNorm1d(width), nn.Dropout(0.1), nn.Linear(width, num_classes)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        embedding, _skips = self.encoder(x)
        return self.classifier(embedding)

def get_cosine_scheduler(optimizer, num_epochs, warmup_epochs=5):
    """Per-epoch LambdaLR: linear warm-up, then cosine decay to zero.

    During the first ``warmup_epochs`` epochs the LR multiplier rises linearly
    as ``(epoch + 1) / warmup_epochs``, reaching 1.0 on the last warm-up epoch.
    It then follows ``0.5 * (1 + cos(pi * progress))``, where ``progress`` goes
    from 0 to 1 over the remaining ``num_epochs - warmup_epochs`` epochs and is
    clamped at 1, so stepping past ``num_epochs`` keeps the LR at 0 instead of
    rising again.

    Args:
        optimizer: Optimizer whose base learning rates are scaled.
        num_epochs: Total number of epochs (warm-up included).
        warmup_epochs: Number of warm-up epochs (0 disables warm-up).

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR``; call ``.step()`` once per epoch.
    """
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # +1 so the very first epoch does not train with a zero learning rate.
            return (epoch + 1) / warmup_epochs
        decay_epochs = num_epochs - warmup_epochs
        if decay_epochs <= 0:
            return 1.0
        progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
        return float(0.5 * (1.0 + np.cos(np.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def aggregate_curves(curves_list_of_lists):
    """Average per-fold training curves of possibly different lengths.

    Empty curves are ignored. Shorter curves (e.g. folds that stopped early)
    are extended by repeating their last value up to the longest length. The
    element-wise mean and population standard deviation (ddof=0) are then taken.

    Returns:
        ``(mean_curve, std_curve)`` as 1-D arrays, or two empty arrays when
        there is no non-empty curve.
    """
    non_empty = [curve for curve in (curves_list_of_lists or []) if curve]
    if not non_empty:
        return np.array([]), np.array([])
    target_len = max(map(len, non_empty))
    stacked = np.stack([np.pad(curve, (0, target_len - len(curve)), 'edge') for curve in non_empty])
    return stacked.mean(axis=0), stacked.std(axis=0)

# --- Initial data loading ---
print("--- Initial Data Loading and Preparation ---")
X_raw, y_raw = load_and_prepare_data(DATA_PATH)
if len(X_raw) == 0:
    print("No raw data loaded. Exiting.")
    # SystemExit(1) reports failure with a non-zero status. The site-injected exit()
    # exits with status 0 and is not guaranteed to exist in every interpreter setup.
    raise SystemExit(1)

X_sequences, y_sequences = create_sequences_with_overlap(X_raw, y_raw, SEQUENCE_LENGTH, HOP_SIZE)
if len(X_sequences) == 0:
    print("No sequences were created. Exiting.")
    raise SystemExit(1)
print(f"Total sequences created: {len(X_sequences)}")
num_unique_classes = len(np.unique(y_sequences))
print(f"Number of unique classes in sequences: {num_unique_classes}")


# --- Pre-K-Fold stage: train one autoencoder ---
print("\n--- Pre K-Fold: Single Autoencoder Training ---")
# Split: 20% of all sequences are held out as X_test_final for the classifier. The other 80%
# are used for AE training/validation and later for the classifier K-Fold. Of that 80%,
# 75% trains the AE and 25% (20% of the total) validates it.
X_for_kfold_and_ae_val, X_test_final, y_for_kfold_and_ae_val, y_test_final = train_test_split(
    X_sequences, y_sequences, test_size=0.2, random_state=SEED, stratify=y_sequences  # 20% final test set
)
# Labels are irrelevant to the AE; they are passed only so the split can be stratified.
X_train_ae_single, X_val_ae_single, _, _ = train_test_split(
    X_for_kfold_and_ae_val, y_for_kfold_and_ae_val,
    test_size=0.25, random_state=SEED, stratify=y_for_kfold_and_ae_val
)

print(f"Data for single AE training: {len(X_train_ae_single)} sequences")
print(f"Data for single AE validation: {len(X_val_ae_single)} sequences")
print(f"Data for classifier K-Fold (train/val): {len(X_for_kfold_and_ae_val)} sequences")
print(f"Data for final classifier test set: {len(X_test_final)} sequences")

# Normalisation for the single AE: one global mean/std, fitted on the AE training windows only.
scaler_ae_single = StandardScaler()
X_train_ae_single_flat = X_train_ae_single.reshape(-1, 1)
scaler_ae_single.fit(X_train_ae_single_flat)
X_train_ae_single_norm = scaler_ae_single.transform(X_train_ae_single_flat).reshape(X_train_ae_single.shape)
X_val_ae_single_norm = scaler_ae_single.transform(X_val_ae_single.reshape(-1, 1)).reshape(X_val_ae_single.shape)

# Add the channel axis expected by Conv1d: (N, L) -> (N, 1, L).
X_train_ae_single_reshaped = X_train_ae_single_norm[:, np.newaxis, :]
X_val_ae_single_reshaped = X_val_ae_single_norm[:, np.newaxis, :]

X_train_ae_tensor_s = torch.tensor(X_train_ae_single_reshaped, dtype=torch.float32)
X_val_ae_tensor_s = torch.tensor(X_val_ae_single_reshaped, dtype=torch.float32)

# Pinned memory only speeds up host->GPU copies; on a CPU-only machine it just triggers a warning.
train_loader_ae_single = DataLoader(TensorDataset(X_train_ae_tensor_s, X_train_ae_tensor_s), batch_size=AE_BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=(device.type == "cuda"))
val_loader_ae_single = DataLoader(TensorDataset(X_val_ae_tensor_s, X_val_ae_tensor_s), batch_size=AE_BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=(device.type == "cuda"))

autoencoder_global = DilatedAutoencoderA(
    encoding_dim=ENCODING_DIM_AE, dropout_rate=AE_DROPOUT_RATE, sequence_length_param=SEQUENCE_LENGTH
).to(device)
optimizer_ae_global = optim.AdamW(autoencoder_global.parameters(), lr=AE_LEARNING_RATE, weight_decay=AE_WEIGHT_DECAY)
criterion_ae_global = nn.MSELoss()
# `verbose` is deprecated for PyTorch LR schedulers; its default (no logging) is what is wanted here.
scheduler_ae_global = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ae_global, mode='min', factor=0.2, patience=AE_PATIENCE_SINGLE//2, min_lr=1e-6)

best_val_loss_ae_global = float('inf')
patience_counter_ae_global = 0
best_ae_state_dict_global = None       # full AE (encoder + decoder) weights at the best val loss
best_encoder_state_dict_global = None  # encoder-only weights at the best val loss (used by the classifier)
history_ae_train_loss_single = []
history_ae_val_loss_single = []

print("Starting single AE training...")
for epoch in range(AE_TRAIN_EPOCHS_SINGLE):
    autoencoder_global.train()
    train_loss_epoch_ae = 0.0
    for batch_X, _ in train_loader_ae_single:
        batch_X_original = batch_X.to(device)
        # Denoising objective: reconstruct the clean input from a noisy copy.
        batch_X_noisy = add_noise_to_batch(batch_X_original.clone(), AE_INPUT_NOISE_STD, device, autoencoder_global.training)
        optimizer_ae_global.zero_grad()
        outputs = autoencoder_global(batch_X_noisy)
        loss = criterion_ae_global(outputs, batch_X_original)
        loss.backward()
        optimizer_ae_global.step()
        train_loss_epoch_ae += loss.item() * batch_X_original.size(0)
    train_loss_epoch_ae /= len(train_loader_ae_single.dataset)
    history_ae_train_loss_single.append(train_loss_epoch_ae)

    autoencoder_global.eval()
    val_loss_epoch_ae = 0.0
    with torch.no_grad():
        for batch_X_val, _ in val_loader_ae_single:
            batch_X_val = batch_X_val.to(device)
            outputs_val = autoencoder_global(batch_X_val)
            loss_val = criterion_ae_global(outputs_val, batch_X_val)
            val_loss_epoch_ae += loss_val.item() * batch_X_val.size(0)
    val_loss_epoch_ae /= len(val_loader_ae_single.dataset)
    history_ae_val_loss_single.append(val_loss_epoch_ae)

    current_lr_ae = optimizer_ae_global.param_groups[0]['lr']
    log_this_epoch = (epoch + 1) % 5 == 0 or epoch == 0 or (epoch + 1) == AE_TRAIN_EPOCHS_SINGLE
    if log_this_epoch:
        print(f"  Single AE Epoch {epoch+1}/{AE_TRAIN_EPOCHS_SINGLE} - Train Loss: {train_loss_epoch_ae:.6f} - Val Loss: {val_loss_epoch_ae:.6f} - LR: {current_lr_ae:.1e}")

    scheduler_ae_global.step(val_loss_epoch_ae)
    if val_loss_epoch_ae < best_val_loss_ae_global:
        best_val_loss_ae_global = val_loss_epoch_ae
        # state_dict() returns references to the live tensors. Clone them so later
        # optimizer steps cannot silently overwrite the "best" snapshot.
        best_ae_state_dict_global = {k: v.detach().clone() for k, v in autoencoder_global.state_dict().items()}
        best_encoder_state_dict_global = {k: v.detach().clone() for k, v in autoencoder_global.encoder.state_dict().items()}
        patience_counter_ae_global = 0
        if log_this_epoch:  # also report improvements on logged epochs
            print(f"    New best AE val_loss: {best_val_loss_ae_global:.6f}. Encoder state saved.")
    else:
        patience_counter_ae_global += 1
        if patience_counter_ae_global >= AE_PATIENCE_SINGLE:
            print(f"  Single AE Early stopping at epoch {epoch+1}. Best Val Loss: {best_val_loss_ae_global:.6f}")
            break

if best_encoder_state_dict_global is None:
    print("Error: AE training did not produce a best encoder state. This might happen if epochs are too few or data is problematic. Exiting.")
    raise SystemExit(1)
# Restore the best full autoencoder (encoder AND decoder), so later use of autoencoder_global,
# including the reconstruction plot, reflects the best checkpoint rather than the last epoch.
autoencoder_global.load_state_dict(best_ae_state_dict_global)
print(f"Single AE training finished. Best Val Loss: {best_val_loss_ae_global:.6f}")

plt.figure(figsize=(10, 5))
plt.plot(history_ae_train_loss_single, label='Single AE Train Loss', color='dodgerblue')
plt.plot(history_ae_val_loss_single, label='Single AE Validation Loss', color='orangered', linestyle='--')
plt.title(f'Single Autoencoder Training Loss\nBest Val Loss: {best_val_loss_ae_global:.6f}')
plt.xlabel('Epoch')
plt.ylabel('Loss (MSE)')
plt.legend()
plt.grid(True, linestyle=':', alpha=0.6)
plt.tight_layout()
plt.show()

if AE_PLOT_RECONSTRUCTION_SINGLE:
    print("\nDisplaying reconstructions from the best single AE model...")
    plot_ae_reconstructions(autoencoder_global, val_loader_ae_single, device, num_samples=3, epoch_num="Final (Best Single AE)",
                            current_loss=best_val_loss_ae_global, sequence_length_param=SEQUENCE_LENGTH, plot_now=True, title_prefix="Single")


# --- K-Fold Cross-Validation (classifier only, initialised from the global pre-trained encoder) ---
# The classifier K-Fold runs on X_for_kfold_and_ae_val / y_for_kfold_and_ae_val.
# X_test_final / y_test_final is the held-out final test set, evaluated once per fold.

skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
fold_clf_train_losses, fold_clf_val_losses = [], []
fold_clf_train_accs, fold_clf_val_accs = [], []
fold_clf_test_accs = []
fold_clf_test_reports = []
all_y_true_test_final, all_y_pred_test_final = [], []

# Fixed class order, matching label_encoding in load_and_prepare_data (quiet=0, vehicle=1, human=2).
# The classifier head, the loss weights and the test report all use this list, so they stay
# consistent even when a class is absent from the windows (labels remain 0..2).
clf_class_names = ['quiet', 'vehicle', 'human']
manual_weights_list = [2.5, 2.0, 1.0]  # CrossEntropy class weights, same order as clf_class_names
if num_unique_classes != len(clf_class_names):
    print(f"Warning: {num_unique_classes} classes present in the sequences; the classifier still uses {len(clf_class_names)} outputs.")

for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_for_kfold_and_ae_val, y_for_kfold_and_ae_val)):
    print(f"\n--- Classifier K-Fold: Fold {fold_idx + 1}/{N_SPLITS} ---")

    X_train_fold_clf, X_val_fold_clf = X_for_kfold_and_ae_val[train_idx], X_for_kfold_and_ae_val[val_idx]
    y_train_fold_clf, y_val_fold_clf = y_for_kfold_and_ae_val[train_idx], y_for_kfold_and_ae_val[val_idx]

    # Fold-specific normalisation, fitted on this fold's training windows only.
    scaler_clf_fold = StandardScaler()
    X_train_flat_fold_clf = X_train_fold_clf.reshape(-1, 1)
    scaler_clf_fold.fit(X_train_flat_fold_clf)
    X_train_normalized_fold_clf = scaler_clf_fold.transform(X_train_flat_fold_clf).reshape(X_train_fold_clf.shape)
    X_val_normalized_fold_clf = scaler_clf_fold.transform(X_val_fold_clf.reshape(-1, 1)).reshape(X_val_fold_clf.shape)

    # The external test set must use the current fold's scaler, i.e. the normalisation this fold's model saw.
    X_test_final_normalized_fold = scaler_clf_fold.transform(X_test_final.reshape(-1, 1)).reshape(X_test_final.shape)

    # Add the channel axis expected by Conv1d: (N, L) -> (N, 1, L).
    X_train_reshaped_fold_clf = X_train_normalized_fold_clf[:, np.newaxis, :]
    X_val_reshaped_fold_clf = X_val_normalized_fold_clf[:, np.newaxis, :]
    X_test_final_reshaped_fold_clf = X_test_final_normalized_fold[:, np.newaxis, :]

    X_train_tensor_f_clf = torch.tensor(X_train_reshaped_fold_clf, dtype=torch.float32)
    y_train_tensor_f_clf = torch.tensor(y_train_fold_clf, dtype=torch.long)
    X_val_tensor_f_clf = torch.tensor(X_val_reshaped_fold_clf, dtype=torch.float32)
    y_val_tensor_f_clf = torch.tensor(y_val_fold_clf, dtype=torch.long)
    X_test_final_tensor_f_clf = torch.tensor(X_test_final_reshaped_fold_clf, dtype=torch.float32)
    y_test_final_tensor_f_clf = torch.tensor(y_test_final, dtype=torch.long)

    # Pinned memory only helps host->GPU copies; on CPU it just triggers a warning.
    train_loader_clf_f = DataLoader(TensorDataset(X_train_tensor_f_clf, y_train_tensor_f_clf), batch_size=CLASSIFIER_BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=(device.type == "cuda"))
    val_loader_clf_f = DataLoader(TensorDataset(X_val_tensor_f_clf, y_val_tensor_f_clf), batch_size=CLASSIFIER_BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=(device.type == "cuda"))
    test_loader_clf_final_f = DataLoader(TensorDataset(X_test_final_tensor_f_clf, y_test_final_tensor_f_clf), batch_size=CLASSIFIER_BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=(device.type == "cuda"))

    # Fresh encoder per fold, initialised from the best pre-trained AE encoder.
    encoder_for_clf_f = DilatedConvEncoderA(encoding_dim=ENCODING_DIM_AE, dropout_rate=AE_DROPOUT_RATE).to(device)
    encoder_for_clf_f.load_state_dict(best_encoder_state_dict_global)

    classifier_f = SingleStrongClassifier(
        encoder_for_clf_f, encoding_dim=ENCODING_DIM_AE, num_classes=len(clf_class_names)
    ).to(device)

    class_weights_tensor_f = torch.tensor(manual_weights_list, dtype=torch.float32).to(device)
    criterion_clf_f = nn.CrossEntropyLoss(weight=class_weights_tensor_f, label_smoothing=0.1)
    # All parameters are optimised (encoder fine-tuning + head). parameters() returns a one-shot
    # generator, and the same collection is reused by clip_grad_norm_ below, so it must be a list.
    # To freeze the encoder instead, set requires_grad=False on classifier_f.encoder and
    # optimise only classifier_f.classifier's parameters.
    params_to_optimize = list(classifier_f.parameters())

    optimizer_clf_f = optim.AdamW(params_to_optimize, lr=CLASSIFIER_LR, weight_decay=CLASSIFIER_WEIGHT_DECAY, betas=(0.9, 0.999))
    scheduler_clf_f = get_cosine_scheduler(optimizer_clf_f, num_epochs=CLASSIFIER_EPOCHS, warmup_epochs=max(1, CLASSIFIER_EPOCHS // 10))

    best_val_acc_clf_f = 0.0
    patience_counter_clf_f = 0
    current_fold_clf_train_losses, current_fold_clf_val_losses = [], []
    current_fold_clf_train_accs, current_fold_clf_val_accs = [], []
    best_classifier_state_dict_f = None

    print(f"  Starting Classifier training for Fold {fold_idx+1}...")
    for epoch in range(CLASSIFIER_EPOCHS):
        classifier_f.train()
        train_loss_clf, train_correct_clf, train_total_clf = 0.0, 0, 0
        for batch_X, batch_y in train_loader_clf_f:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            # Each original sample is kept and followed by CLASSIFIER_NUM_AUG_PER_SAMPLE augmented copies.
            aug_batch_X, aug_batch_y = [], []
            for i in range(batch_X.shape[0]):
                aug_batch_X.append(batch_X[i])
                aug_batch_y.append(batch_y[i])
                for _ in range(CLASSIFIER_NUM_AUG_PER_SAMPLE):
                    # advanced_seismic_augmentation expects a 1-D signal: drop the channel axis, then restore it.
                    aug_sample = advanced_seismic_augmentation(batch_X[i].squeeze(0))
                    aug_batch_X.append(aug_sample.unsqueeze(0))
                    aug_batch_y.append(batch_y[i])

            combined_X = torch.stack(aug_batch_X).to(device)
            combined_y = torch.stack(aug_batch_y).to(device)

            optimizer_clf_f.zero_grad()
            outputs = classifier_f(combined_X)
            loss = criterion_clf_f(outputs, combined_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params_to_optimize, max_norm=1.0)
            optimizer_clf_f.step()

            train_loss_clf += loss.item() * combined_X.size(0)
            _, predicted = torch.max(outputs, 1)
            train_correct_clf += (predicted == combined_y).sum().item()
            train_total_clf += combined_y.size(0)

        train_loss_clf /= max(train_total_clf, 1)
        train_acc_clf = train_correct_clf / train_total_clf if train_total_clf > 0 else 0
        current_fold_clf_train_losses.append(train_loss_clf)
        current_fold_clf_train_accs.append(train_acc_clf)

        classifier_f.eval()
        val_loss_clf, val_correct_clf, val_total_clf = 0.0, 0, 0
        with torch.no_grad():
            for batch_X, batch_y in val_loader_clf_f:
                batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                outputs = classifier_f(batch_X)
                loss = criterion_clf_f(outputs, batch_y)
                val_loss_clf += loss.item() * batch_X.size(0)
                _, predicted = torch.max(outputs, 1)
                val_correct_clf += (predicted == batch_y).sum().item()
                val_total_clf += batch_y.size(0)

        val_loss_clf /= max(val_total_clf, 1)
        val_acc_clf = val_correct_clf / val_total_clf if val_total_clf > 0 else 0
        current_fold_clf_val_losses.append(val_loss_clf)
        current_fold_clf_val_accs.append(val_acc_clf)
        scheduler_clf_f.step()
        current_lr_clf = optimizer_clf_f.param_groups[0]['lr']

        log_this_epoch = (epoch + 1) % 5 == 0 or epoch == 0 or (epoch + 1) == CLASSIFIER_EPOCHS
        if log_this_epoch:
            print(f"    CLF Fold {fold_idx+1} Epoch {epoch+1:3d}/{CLASSIFIER_EPOCHS} | Train: Loss={train_loss_clf:.4f}, Acc={train_acc_clf:.4f} | Val: Loss={val_loss_clf:.4f}, Acc={val_acc_clf:.4f} | LR={current_lr_clf:.1e}")

        if val_acc_clf > best_val_acc_clf_f:
            best_val_acc_clf_f = val_acc_clf
            # Clone: state_dict() holds references to live tensors that later steps would overwrite.
            best_classifier_state_dict_f = {k: v.detach().clone() for k, v in classifier_f.state_dict().items()}
            patience_counter_clf_f = 0
            if log_this_epoch:
                print(f"      New best CLF val_acc for Fold {fold_idx+1}: {best_val_acc_clf_f:.4f}. Model state saved.")
        else:
            patience_counter_clf_f += 1
            if patience_counter_clf_f >= CLASSIFIER_PATIENCE:
                print(f"    CLF Early stopping at epoch {epoch+1} for Fold {fold_idx+1}. Best val_acc: {best_val_acc_clf_f:.4f}")
                break

    fold_clf_train_losses.append(current_fold_clf_train_losses)
    fold_clf_val_losses.append(current_fold_clf_val_losses)
    fold_clf_train_accs.append(current_fold_clf_train_accs)
    fold_clf_val_accs.append(current_fold_clf_val_accs)
    print(f"  Best CLF Validation Accuracy for Fold {fold_idx+1}: {best_val_acc_clf_f:.4f}")

    if best_classifier_state_dict_f is not None:
        classifier_f.load_state_dict(best_classifier_state_dict_f)
    else:
        print(f"  Warning: No best classifier state dict saved for fold {fold_idx+1}. Using last state for test evaluation.")

    classifier_f.eval()
    y_true_test_f, y_pred_test_f = [], []
    with torch.no_grad():
        for batch_X, batch_y in test_loader_clf_final_f:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            outputs = classifier_f(batch_X)
            _, predicted = torch.max(outputs, 1)
            y_true_test_f.extend(batch_y.cpu().numpy())
            y_pred_test_f.extend(predicted.cpu().numpy())

    all_y_true_test_final.extend(y_true_test_f)
    all_y_pred_test_final.extend(y_pred_test_f)

    test_acc_f = accuracy_score(y_true_test_f, y_pred_test_f)
    fold_clf_test_accs.append(test_acc_f)
    # Pass labels explicitly so target_names still lines up when a class never appears in y_true/y_pred.
    report_f = classification_report(
        y_true_test_f, y_pred_test_f,
        labels=list(range(len(clf_class_names))), target_names=clf_class_names,
        output_dict=True, zero_division=0,
    )
    fold_clf_test_reports.append(report_f)
    print(f"  Fold {fold_idx+1} Test Accuracy on final test set: {test_acc_f:.4f}")


# --- Summary of K-Fold results (for the Classifier) ---
print("\n\n--- Classifier K-Fold Cross-Validation Summary (using pre-trained AE) ---")
plt.style.use('seaborn-v0_8-whitegrid')

# Collapse each per-fold, per-epoch history into a mean curve and a std band.
# aggregate_curves is defined elsewhere in this file; how it reconciles folds
# that early-stopped at different epochs is not visible here.
avg_clf_train_loss, std_clf_train_loss = aggregate_curves(fold_clf_train_losses)
avg_clf_val_loss, std_clf_val_loss = aggregate_curves(fold_clf_val_losses)
avg_clf_train_acc, std_clf_train_acc = aggregate_curves(fold_clf_train_accs)
avg_clf_val_acc, std_clf_val_acc = aggregate_curves(fold_clf_val_accs)

if avg_clf_train_loss.size > 0:  # an empty mean curve means there is nothing to draw
    plt.figure(figsize=(12, 5))
    # One entry per side-by-side panel: (mean, std, legend label, colour, title).
    loss_panels = (
        (avg_clf_train_loss, std_clf_train_loss, 'Avg CLF Train Loss', 'forestgreen',
         f'Avg Classifier Training Loss ({N_SPLITS} Folds)'),
        (avg_clf_val_loss, std_clf_val_loss, 'Avg CLF Validation Loss', 'gold',
         f'Avg Classifier Validation Loss ({N_SPLITS} Folds)'),
    )
    for panel_idx, (mean_curve, std_curve, curve_label, curve_color, panel_title) in enumerate(loss_panels, start=1):
        plt.subplot(1, 2, panel_idx)
        plt.plot(mean_curve, label=curve_label, color=curve_color)
        # Shade mean +/- one standard deviation across folds.
        plt.fill_between(
            range(len(mean_curve)),
            mean_curve - std_curve,
            mean_curve + std_curve,
            color=curve_color,
            alpha=0.2,
        )
        plt.title(panel_title)
        plt.xlabel('Epoch')
        plt.ylabel('CrossEntropy Loss')
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.show()
else:
    print("Could not generate classifier loss plots (no data).")


# Side-by-side mean +/- std accuracy curves (train / validation) across folds.
if avg_clf_train_acc.size > 0:
    plt.figure(figsize=(12, 5))
    # One entry per panel: (mean, std, legend label, colour, title).
    accuracy_panels = (
        (avg_clf_train_acc, std_clf_train_acc, 'Avg CLF Train Accuracy', 'mediumpurple',
         f'Avg Classifier Training Accuracy ({N_SPLITS} Folds)'),
        (avg_clf_val_acc, std_clf_val_acc, 'Avg CLF Validation Accuracy', 'darkorange',
         f'Avg Classifier Validation Accuracy ({N_SPLITS} Folds)'),
    )
    for panel_idx, (mean_curve, std_curve, curve_label, curve_color, panel_title) in enumerate(accuracy_panels, start=1):
        plt.subplot(1, 2, panel_idx)
        plt.plot(mean_curve, label=curve_label, color=curve_color)
        plt.fill_between(
            range(len(mean_curve)),
            mean_curve - std_curve,
            mean_curve + std_curve,
            color=curve_color,
            alpha=0.2,
        )
        plt.title(panel_title)
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.ylim(0, 1.05)  # accuracy lives in [0, 1]; small headroom above the top
        plt.grid(True)
    plt.tight_layout()
    plt.show()
else:
    print("Could not generate classifier accuracy plots (no data).")


# Per-fold test-accuracy statistics plus a boxplot of their spread.
# mean_test_acc / std_test_acc are reused by the final summary print below.
if not fold_clf_test_accs:
    print("No test accuracies recorded for folds.")
else:
    mean_test_acc = np.mean(fold_clf_test_accs)
    std_test_acc = np.std(fold_clf_test_accs)  # np.std default: population std (ddof=0)
    formatted_fold_accs = [f'{acc:.4f}' for acc in fold_clf_test_accs]
    print(f"\nAverage Test Accuracy over {N_SPLITS} Folds: {mean_test_acc:.4f} +/- {std_test_acc:.4f}")
    print(f"Individual Fold Test Accuracies: {formatted_fold_accs}")

    plt.figure(figsize=(6, 5))
    sns.boxplot(data=fold_clf_test_accs, palette='viridis', width=0.3)
    plt.title(f'Test Accuracies Across {N_SPLITS} Folds\nMean: {mean_test_acc:.4f} (Std: {std_test_acc:.4f})')
    plt.ylabel('Test Accuracy')
    plt.xticks([0], [f'{N_SPLITS}-Fold CV'])
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.show()


# Evaluate the test-set predictions pooled from every fold: overall accuracy,
# a per-class classification report, and a confusion-matrix heatmap.
print("\n--- Final Evaluation on Aggregated Test Set Predictions (Classifier K-Fold) ---")
target_names = ['quiet', 'vehicle', 'human'] # Make sure these match your label_encoding
# Predictions are argmax indices over the classifier outputs, i.e. 0..C-1.
# Passing the full index list explicitly keeps the report and the confusion
# matrix aligned with target_names even when some class never appears in the
# aggregated true/predicted labels. Without it, classification_report raises
# a class-count mismatch and the confusion matrix shrinks below 3x3.
class_indices = list(range(len(target_names)))
if all_y_true_test_final and all_y_pred_test_final:
    final_accuracy_agg = accuracy_score(all_y_true_test_final, all_y_pred_test_final)
    print(f"Overall Accuracy on Final Test Set (Aggregated from {N_SPLITS} folds): {final_accuracy_agg:.4f}")
    print("\nOverall Classification Report (Aggregated):")
    print(classification_report(all_y_true_test_final, all_y_pred_test_final,
                                labels=class_indices, target_names=target_names, zero_division=0))

    cm_final_agg = confusion_matrix(all_y_true_test_final, all_y_pred_test_final, labels=class_indices)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm_final_agg, annot=True, fmt='d', cmap='Blues_r', 
                xticklabels=target_names, yticklabels=target_names, annot_kws={"size": 14})
    plt.title(f'Aggregated Confusion Matrix - Final Test Set\nOverall Accuracy: {final_accuracy_agg:.4f}', fontsize=15)
    plt.ylabel('True Label', fontsize=12); plt.xlabel('Predicted Label', fontsize=12)
    plt.xticks(fontsize=10); plt.yticks(fontsize=10); plt.show()
else:
    print("No aggregated test predictions available to generate final report and confusion matrix.")


# Closing console summary of the classifier K-Fold run.
print("\n📈 Summary of Classifier K-Fold Cross-Validation (using pre-trained AE):")
if fold_clf_test_accs:
    # mean_test_acc / std_test_acc are assigned under the same condition earlier in the script.
    print(f"  Number of Folds (N_SPLITS): {N_SPLITS}")
    print(f"  Average Test Accuracy: {mean_test_acc:.4f} (Std: {std_test_acc:.4f})")
if all_y_true_test_final:
    # Build the display text first. Placing the conditional after ':' inside the
    # replacement field makes it part of the format spec, which raises
    # "ValueError: Invalid format specifier" at runtime.
    # final_accuracy_agg only exists if the aggregated-evaluation section ran its main branch.
    overall_acc_text = f"{final_accuracy_agg:.4f}" if 'final_accuracy_agg' in locals() else 'N/A'
    print(f"  Overall Accuracy (on aggregated predictions from all folds' test runs): {overall_acc_text}")
print("--- End of Script ---")