In [1]:
# ==============================================================================
# Cell 1: Setup and Configuration (WGAN-GP Version)
# ==============================================================================
#
# This version implements the WGAN-GP architecture to solve mode collapse.
#
# KEY CHANGES:
# - Switched from WGAN-WC to WGAN-GP (Wasserstein GAN with Gradient Penalty).
# - Generator now uses a Tanh activation for stable, smooth output.
# - Hyperparameters (LR, betas, lambda_gp) are tuned for WGAN-GP.
#
# ==============================================================================

# --- 1.1. Install necessary libraries ---
!pip install numpy scipy scikit-learn matplotlib seaborn pandas mne --quiet

# --- 1.2. Imports ---
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from scipy import signal
from scipy.linalg import sqrtm
from sklearn.decomposition import PCA
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# --- 1.3. Configuration & Setup ---
OUTPUT_DIR = "/kaggle/working/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model and Data Parameters
DATA_PATH = "/kaggle/input/mpi-lemon-eeg/lemon_preprocessed_8ch_512ts.npy"
CHANNELS = 8
TIMESTEPS = 512
LATENT_DIM = 100
NUM_CLASSES = 2 # Note: Using dummy classes for this setup

# Training Hyperparameters for WGAN-GP
NUM_EPOCHS = 100 # WGAN-GP can take longer to converge, but is more stable
BATCH_SIZE = 64
LR = 1e-4 # A common, stable learning rate for WGAN-GP
BETA1 = 0.0 # Recommended beta for WGAN-GP
BETA2 = 0.9
CRITIC_ITERATIONS = 5 # Train critic more often, which is standard for WGANs
LAMBDA_GP = 10.0 # Gradient penalty coefficient, as recommended in the WGAN-GP paper
DIVERSITY_WEIGHT = 0.1 # Mode-seeking loss can be smaller as GP helps with diversity

# Evaluation
SAVE_INTERVAL = 10

# --- Setup device ---
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("Configuration set for WGAN-GP training.")

Using device: cuda
Configuration set for WGAN-GP training.


In [2]:
# ==============================================================================
# Cell 2: Load Preprocessed Data
# ==============================================================================
# This cell is unchanged. It loads the .npy file from the specified path.
# ==============================================================================

try:
    # Make sure the Kaggle dataset is attached to this notebook.
    data = np.load(DATA_PATH)
    print(f"Successfully loaded data. Shape: {data.shape}")
    
    # Create dummy labels as the focus is on unconditional generation quality
    labels = np.random.randint(0, NUM_CLASSES, data.shape[0])
    print(f"Created dummy labels. Shape: {labels.shape}")
    
    dataset = TensorDataset(torch.from_numpy(data).float(), torch.from_numpy(labels).long())
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
    print("Created DataLoader successfully.")
    
except FileNotFoundError:
    print(f"ERROR: Data file not found at '{DATA_PATH}'.")
    print("Please ensure your preprocessed .npy file is in a Kaggle Dataset and that the dataset is added to this notebook.")
except Exception as e:
    print(f"An error occurred: {e}")

Successfully loaded data. Shape: (35503, 512, 8)
Created dummy labels. Shape: (35503,)
Created DataLoader successfully.


In [3]:
# ==============================================================================
# Cell 3: Model Architectures with Tanh Activation
# ==============================================================================
#
# KEY CHANGE:
# - The final layer of the ResGenerator now uses `nn.Tanh()` to smoothly
#   squash the output to the [-1, 1] range. This is crucial for stable
#   gradient flow and prevents the hard-clipping artifacts.
#
# ==============================================================================
from torch.nn.utils import spectral_norm

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.BatchNorm1d(in_channels), nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        )
        self.shortcut = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv1d(in_channels, out_channels, kernel_size=1)
        )
    def forward(self, x):
        return self.conv_block(x) + self.shortcut(x)

class ResGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes, channels, timesteps):
        super(ResGenerator, self).__init__()
        self.initial_len = timesteps // 16
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.fc = nn.Linear(latent_dim * 2, 256 * self.initial_len)
        self.res_blocks = nn.Sequential(ResBlock(256, 128), ResBlock(128, 64), ResBlock(64, 32), ResBlock(32, 16))
        
        # --- KEY CHANGE: Added nn.Tanh() for smooth, bounded output ---
        self.final_conv = nn.Sequential(
            nn.BatchNorm1d(16), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(16, channels, kernel_size=3, padding=1),
            nn.Tanh() # Output is now in [-1, 1]
        )
    def forward(self, noise, labels):
        c = self.label_emb(labels)
        x = torch.cat([noise, c], 1)
        x = self.fc(x)
        x = x.reshape(x.size(0), 256, self.initial_len)
        x = self.res_blocks(x)
        x = self.final_conv(x)
        return x.transpose(1, 2)

# The Discriminator is well-suited for WGAN-GP and remains unchanged.
class V6InspiredDiscriminator(nn.Module):
    def __init__(self, num_classes, channels, timesteps):
        super(V6InspiredDiscriminator, self).__init__()
        self.channels = channels
        self.timesteps = timesteps
        self.embedding_size = channels * timesteps
        self.label_emb = nn.Embedding(num_classes, self.embedding_size)
        self.conv_layers = nn.Sequential(
            spectral_norm(nn.Conv1d(channels * 2, 128, 5, 2, 2)), nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv1d(128, 256, 5, 2, 2)), nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv1d(256, 512, 5, 2, 2)), nn.LeakyReLU(0.2, inplace=True),
        )
        with torch.no_grad():
            dummy_input = torch.randn(1, channels * 2, timesteps)
            conv_out_size = self.conv_layers(dummy_input).reshape(1, -1).size(1)
        self.classifier = nn.Sequential(
            spectral_norm(nn.Linear(conv_out_size, 1024)), nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            spectral_norm(nn.Linear(1024, 1))
        )
    def forward(self, eeg, labels):
        x = eeg.transpose(1, 2)
        c = self.label_emb(labels).reshape(-1, self.channels, self.timesteps)
        x = torch.cat([x, c], dim=1)
        x = self.conv_layers(x)
        x = x.reshape(x.size(0), -1)
        return self.classifier(x)

print("Model architectures defined (Generator with Tanh).")

Model architectures defined (Generator with Tanh).


In [4]:
# ==============================================================================
# Cell 4: The Ultimate Visualization Suite
# ==============================================================================
#
# This cell is unchanged. These plotting functions are essential for verifying
# that the mode collapse issue is solved. The histogram plot is now one of
# the most important qualitative checks.
#
# ==============================================================================
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity

def plot_figure_3(real_data, gen_data, save_path_prefix):
    """Generates plots for Figure 3a (PCA) and 3b (Histogram)."""
    pca = PCA(n_components=2)
    real_flat, gen_flat = real_data.reshape(real_data.shape[0], -1), gen_data.reshape(gen_data.shape[0], -1)
    pca.fit(real_flat)
    real_pca, gen_pca = pca.transform(real_flat), pca.transform(gen_flat)
    
    plt.figure(figsize=(10, 10))
    plt.scatter(real_pca[:, 0], real_pca[:, 1], alpha=0.5, label='Real EEG', s=15, c='blue')
    plt.scatter(gen_pca[:, 0], gen_pca[:, 1], alpha=0.5, label='Generated EEG', s=15, c='orange')
    plt.title('PCA Distribution of Real vs. Generated Data', fontsize=16)
    plt.xlabel('PC 1', fontsize=12)
    plt.ylabel('PC 2', fontsize=12)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.savefig(f"{save_path_prefix}_fig3a_pca.png")
    plt.close()

    plt.figure(figsize=(12, 7))
    sns.histplot(real_data.flatten(), color="blue", label='Real', stat='density', bins=100, alpha=0.7)
    sns.histplot(gen_data.flatten(), color="orange", label='Generated', stat='density', bins=100, alpha=0.7)
    plt.title('Global Amplitude Distribution', fontsize=16)
    plt.xlabel('Amplitude (Normalized)', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.legend()
    plt.savefig(f"{save_path_prefix}_fig3b_histogram.png")
    plt.close()
    print("Generated distribution plots (PCA, Histogram).")

def plot_figure_4(real_data, gen_data, save_path_prefix):
    """Generates per-channel PSD plots."""
    channel_names = ['F1', 'F2', 'C1', 'C2', 'P1', 'P2', 'O1', 'O2']
    sfreq = 98
    fig, axes = plt.subplots(4, 2, figsize=(18, 24), sharex=True, sharey=True)
    axes = axes.flatten()
    for i in range(real_data.shape[2]):
        f_real, psd_real = signal.welch(real_data[:, :, i], fs=sfreq, axis=1, nperseg=sfreq*2)
        f_gen, psd_gen = signal.welch(gen_data[:, :, i], fs=sfreq, axis=1, nperseg=sfreq*2)
        mean_psd_real, std_psd_real = np.mean(np.log10(psd_real + 1e-12), 0), np.std(np.log10(psd_real + 1e-12), 0)
        mean_psd_gen, std_psd_gen = np.mean(np.log10(psd_gen + 1e-12), 0), np.std(np.log10(psd_gen + 1e-12), 0)
        ax = axes[i]
        ax.plot(f_real, mean_psd_real, label='Real', color='blue')
        ax.fill_between(f_real, mean_psd_real - std_psd_real, mean_psd_real + std_psd_real, color='blue', alpha=0.2)
        ax.plot(f_gen, mean_psd_gen, label='Generated', color='orange')
        ax.fill_between(f_gen, mean_psd_gen - std_psd_gen, mean_psd_gen + std_psd_gen, color='orange', alpha=0.2)
        ax.set_title(f'Channel: {channel_names[i]}', fontsize=14)
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Log Power/Hz')
        ax.legend()
        ax.grid(True, linestyle='--')
        ax.set_xlim(0, 45)
    plt.tight_layout()
    plt.savefig(f"{save_path_prefix}_fig4_psd_per_channel.png")
    plt.close()
    print("Generated per-channel PSD plots.")

def plot_figure_5(real_data, gen_data, save_path_prefix):
    """Generates functional connectivity plots."""
    real_flat = real_data.reshape(-1, real_data.shape[2])
    gen_flat = gen_data.reshape(-1, gen_data.shape[2])
    conn_real = cosine_similarity(real_flat.T)
    conn_gen = cosine_similarity(gen_flat.T)
    fig, axes = plt.subplots(1, 2, figsize=(18, 8))
    sns.heatmap(conn_real, ax=axes[0], annot=True, fmt='.2f', cmap='viridis', vmin=0, vmax=1)
    axes[0].set_title('Real EEG Connectivity', fontsize=16)
    sns.heatmap(conn_gen, ax=axes[1], annot=True, fmt='.2f', cmap='viridis', vmin=0, vmax=1)
    axes[1].set_title('Generated EEG Connectivity', fontsize=16)
    fig.suptitle('Functional Connectivity Comparison', fontsize=20)
    plt.savefig(f"{save_path_prefix}_fig5_connectivity.png")
    plt.close()
    print("Generated connectivity plots.")

def plot_figure_6(frechet_distances, save_interval, save_path):
    """Generates FD progression plots."""
    fig, axes = plt.subplots(1, 2, figsize=(20, 7))
    fd_epochs = np.arange(1, len(frechet_distances) + 1) * save_interval
    fd_spectral = [d['FD Spectral (Normalized)'] for d in frechet_distances]
    fd_hjorth = [d['FD Hjorth (Normalized)'] for d in frechet_distances]
    axes[0].plot(fd_epochs, fd_spectral, 'b-o', label='Spectral FD')
    axes[0].set_title('Spectral FD Progression', fontsize=14)
    axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('FD (Normalized)'); axes[0].legend(); axes[0].grid(True)
    axes[1].plot(fd_epochs, fd_hjorth, 'g-o', label='Hjorth FD')
    axes[1].set_title('Hjorth FD Progression', fontsize=14)
    axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('FD (Normalized)'); axes[1].legend(); axes[1].grid(True)
    fig.suptitle('Fréchet Distance Progression Over Training', fontsize=22)
    plt.savefig(save_path)
    plt.close()
    print("Generated FD progression plot.")

print("Ultimate evaluation and plotting suite defined.")

Ultimate evaluation and plotting suite defined.


In [5]:
# ==============================================================================
# Cell 5: The WGAN-GP Training Loop
# ==============================================================================
#
# This is the definitive training loop, implementing the WGAN-GP algorithm.
#
# KEY CHANGES:
# - `compute_gradient_penalty` function is added.
# - The discriminator loss now includes the gradient penalty term.
# - Weight clipping (`p.data.clamp_`) is completely REMOVED.
# - Manual output clamping (`torch.clamp`) is REMOVED, as the generator's
#   Tanh activation now handles this correctly and smoothly.
#
# ==============================================================================

def compute_gradient_penalty(discriminator, real_samples, fake_samples, labels, device):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = torch.rand(real_samples.size(0), 1, 1, device=device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    
    d_interpolates = discriminator(interpolates, labels)
    
    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones(d_interpolates.size(), device=device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# --- HELPER FUNCTIONS FOR FD CALCULATION (UNCHANGED) ---
def get_yaregan_benchmark_features(data, sfreq=98):
    data_t = data.transpose(0, 2, 1)
    bands = {'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 13), 'beta': (13, 30), 'gamma': (30, 45)}
    freqs, psd = signal.welch(data_t, fs=sfreq, axis=2, nperseg=sfreq*2)
    abs_band_powers = np.stack([np.mean(psd[:, :, np.logical_and(freqs >= b[0], freqs <= b[1])], axis=2) for b in bands.values()], axis=-1)
    total_power = np.sum(abs_band_powers, axis=-1, keepdims=True); relative_psd_features = abs_band_powers / (total_power + 1e-10)
    dx = np.diff(data_t, axis=2); var_x, var_dx = np.var(data_t, axis=2), np.var(dx, axis=2)
    activity = var_x; mobility = np.sqrt(var_dx / (var_x + 1e-8)); complexity = np.sqrt(np.var(np.diff(dx, axis=2), axis=2) / (var_dx + 1e-8)) / (mobility + 1e-8)
    hjorth_features = np.stack([activity, mobility, complexity], axis=-1)
    return relative_psd_features.reshape(data.shape[0], -1), hjorth_features.reshape(data.shape[0], -1)

def calculate_frechet_distance(features1, features2):
    mu1, sigma1 = np.mean(features1, axis=0), np.cov(features1, rowvar=False)
    mu2, sigma2 = np.mean(features2, axis=0), np.cov(features2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2)**2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean): covmean = covmean.real
    fd = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fd

def get_final_fd_scores(real_data, gen_data, n_channels):
    real_spectral, real_hjorth = get_yaregan_benchmark_features(real_data)
    gen_spectral, gen_hjorth = get_yaregan_benchmark_features(gen_data)
    fd_spectral = calculate_frechet_distance(real_spectral, gen_spectral)
    fd_hjorth = calculate_frechet_distance(real_hjorth, gen_hjorth)
    return {
        "FD Spectral (Normalized)": fd_spectral / n_channels,
        "FD Hjorth (Normalized)": fd_hjorth / n_channels,
        "FD Total": (fd_spectral + fd_hjorth) / n_channels
    }
    
# --- Model and Optimizer Setup ---
generator = ResGenerator(LATENT_DIM, NUM_CLASSES, CHANNELS, TIMESTEPS).to(device)
discriminator = V6InspiredDiscriminator(NUM_CLASSES, CHANNELS, TIMESTEPS).to(device)

g_optimizer = optim.Adam(generator.parameters(), lr=LR, betas=(BETA1, BETA2))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LR, betas=(BETA1, BETA2))

# --- Training State Storage ---
frechet_distances = []
best_fd = float('inf')
real_samples_for_eval = next(iter(dataloader))[0].numpy()[:256]

print(f"Starting WGAN-GP training for {NUM_EPOCHS} epochs...");
for epoch in range(NUM_EPOCHS):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    
    for i, (real_eegs, labels) in enumerate(pbar):
        real_eegs, labels = real_eegs.to(device), labels.to(device)
        batch_size = real_eegs.size(0)

        # --- Train Discriminator (Critic) ---
        d_optimizer.zero_grad()
        
        # Generate fake EEG data
        z = torch.randn(batch_size, LATENT_DIM, device=device)
        fake_eegs = generator(z, labels).detach()

        # Get scores for real and fake data
        real_validity = discriminator(real_eegs, labels)
        fake_validity = discriminator(fake_eegs, labels)
        
        # Calculate gradient penalty
        gradient_penalty = compute_gradient_penalty(discriminator, real_eegs.data, fake_eegs.data, labels.data, device)
        
        # Critic loss
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + LAMBDA_GP * gradient_penalty
        d_loss.backward()
        d_optimizer.step()

        # --- Train Generator ---
        # Train generator only once every CRITIC_ITERATIONS
        if i % CRITIC_ITERATIONS == 0:
            g_optimizer.zero_grad()
            
            # Generate two batches of fake data for diversity loss
            z1 = torch.randn(batch_size, LATENT_DIM, device=device)
            gen_labels = torch.randint(0, NUM_CLASSES, (batch_size,), device=device)
            fake_eegs1 = generator(z1, gen_labels)
            
            # Adversarial loss
            g_loss_adv = -torch.mean(discriminator(fake_eegs1, gen_labels))
            
            # Mode-seeking diversity loss
            z2 = torch.randn(batch_size, LATENT_DIM, device=device)
            fake_eegs2 = generator(z2, gen_labels)
            lz = torch.mean(torch.abs(z1 - z2))
            lf = torch.mean(torch.abs(fake_eegs1 - fake_eegs2))
            g_loss_div = DIVERSITY_WEIGHT * (lz / (lf + 1e-8))

            g_loss = g_loss_adv + g_loss_div
            g_loss.backward()
            g_optimizer.step()
        
            pbar.set_postfix({'D Loss': f'{d_loss.item():.4f}', 'G Loss': f'{g_loss.item():.4f}'})

    # --- Epoch End Evaluation ---
    if (epoch + 1) % SAVE_INTERVAL == 0 or epoch == NUM_EPOCHS - 1:
        print(f"\n--- Epoch {epoch+1} Evaluation ---")
        generator.eval()
        with torch.no_grad():
            z_eval = torch.randn(real_samples_for_eval.shape[0], LATENT_DIM, device=device)
            labels_eval = torch.randint(0, NUM_CLASSES, (real_samples_for_eval.shape[0],), device=device)
            generated_samples_eval = generator(z_eval, labels_eval).cpu().numpy()
            
            fd_scores = get_final_fd_scores(real_samples_for_eval, generated_samples_eval, CHANNELS)
            frechet_distances.append(fd_scores)
            print(f"  FD Scores (Norm) - Spectral: {fd_scores['FD Spectral (Normalized)']:.4f}, Hjorth: {fd_scores['FD Hjorth (Normalized)']:.4f}")
            
            if fd_scores['FD Total'] < best_fd:
                best_fd = fd_scores['FD Total']
                torch.save(generator.state_dict(), os.path.join(OUTPUT_DIR, 'generator_best_wgangp.pth'))
                torch.save(discriminator.state_dict(), os.path.join(OUTPUT_DIR, 'discriminator_best_wgangp.pth'))
                print(f"  New best model saved with Total FD: {best_fd:.4f}")
        generator.train()

# --- Final Steps ---
print("\nTraining complete.")
plot_figure_6(frechet_distances, SAVE_INTERVAL, os.path.join(OUTPUT_DIR, 'final_fd_progression_wgangp.png'))
print("Final progress plot saved.")

Starting WGAN-GP training for 100 epochs...


Epoch 1/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 2/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 3/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 4/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 5/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 6/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 7/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 8/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 9/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 10/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 10 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0441, Hjorth: 0.2639
  New best model saved with Total FD: 0.3079


Epoch 11/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 12/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 13/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 14/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 15/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 16/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 17/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 18/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 19/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 20/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 20 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0395, Hjorth: 0.1899
  New best model saved with Total FD: 0.2293


Epoch 21/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 22/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 23/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 24/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 25/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 26/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 27/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 28/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 29/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 30/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 30 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0407, Hjorth: 0.1845
  New best model saved with Total FD: 0.2252


Epoch 31/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 32/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 33/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 34/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 35/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 36/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 37/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 38/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 39/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 40/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 40 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0474, Hjorth: 0.1855


Epoch 41/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 42/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 43/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 44/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 45/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 46/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 47/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 48/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 49/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 50/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 50 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0511, Hjorth: 0.1877


Epoch 51/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 52/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 53/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 54/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 55/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 56/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 57/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 58/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 59/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 60/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 60 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0406, Hjorth: 0.1876


Epoch 61/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 62/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 63/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 64/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 65/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 66/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 67/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 68/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 69/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 70/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 70 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0493, Hjorth: 0.1871


Epoch 71/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 72/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 73/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 74/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 75/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 76/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 77/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 78/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 79/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 80/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 80 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0507, Hjorth: 0.1890


Epoch 81/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 82/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 83/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 84/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 85/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 86/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 87/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 88/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 89/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 90/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 90 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0489, Hjorth: 0.1892


Epoch 91/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 92/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 93/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 94/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 95/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 96/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 97/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 98/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 99/100:   0%|          | 0/555 [00:00<?, ?it/s]

Epoch 100/100:   0%|          | 0/555 [00:00<?, ?it/s]


--- Epoch 100 Evaluation ---
  FD Scores (Norm) - Spectral: 0.0543, Hjorth: 0.1872

Training complete.
Generated FD progression plot.
Final progress plot saved.


In [6]:
# ==============================================================================
# Cell 6: Final Analysis and Comprehensive Visualization
# ==============================================================================
#
# This cell loads the best model trained with the WGAN-GP algorithm and
# runs the full evaluation suite. The resulting plots should now show
# a high-quality, continuous distribution, proving the fix was successful.
#
# ==============================================================================

print("\n--- Final WGAN-GP Model Analysis & Visualization ---")

try:
    # Use a larger batch for the final, high-quality plots
    real_samples_final = next(iter(dataloader))[0].numpy()
except StopIteration:
    print("Warning: Dataloader is empty. Using smaller evaluation batch for plots.")
    real_samples_final = real_samples_for_eval

winning_model_samples = None

try:
    print("\n--- Evaluating BEST WGAN-GP Model for Final Plots ---")
    best_generator = ResGenerator(LATENT_DIM, NUM_CLASSES, CHANNELS, TIMESTEPS).to(device)
    # Load the new model weights
    best_generator.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, 'generator_best_wgangp.pth')))
    best_generator.eval()
    
    with torch.no_grad():
        z_final = torch.randn(real_samples_final.shape[0], LATENT_DIM, device=device)
        labels_final = torch.randint(0, NUM_CLASSES, (real_samples_final.shape[0],), device=device)
        winning_model_samples = best_generator(z_final, labels_final).cpu().numpy()
    
    final_fd_scores = get_final_fd_scores(real_samples_final, winning_model_samples, CHANNELS)
    print(f"  > Final FD Scores (Norm) - Spectral: {final_fd_scores['FD Spectral (Normalized)']:.4f}, Hjorth: {final_fd_scores['FD Hjorth (Normalized)']:.4f}")

except FileNotFoundError:
    print("ERROR: Best WGAN-GP model ('generator_best_wgangp.pth') not found. Cannot generate final plots.")

if winning_model_samples is not None:
    print("\n--- Generating All Final Figures for Paper (WGAN-GP Version) ---")
    plot_prefix = os.path.join(OUTPUT_DIR, "Final_WGAN-GP_")
    plot_figure_3(real_samples_final, winning_model_samples, plot_prefix)
    plot_figure_4(real_samples_final, winning_model_samples, plot_prefix)
    plot_figure_5(real_samples_final, winning_model_samples, plot_prefix)
    print(f"\nAll final evaluation plots have been saved to the output directory with the prefix '{plot_prefix}'.")
else:
    print("Skipping final plot generation as no trained WGAN-GP model was found.")


--- Final WGAN-GP Model Analysis & Visualization ---

--- Evaluating BEST WGAN-GP Model for Final Plots ---
  > Final FD Scores (Norm) - Spectral: 0.0316, Hjorth: 0.1095

--- Generating All Final Figures for Paper (WGAN-GP Version) ---
Generated distribution plots (PCA, Histogram).
Generated per-channel PSD plots.
Generated connectivity plots.

All final evaluation plots have been saved to the output directory with the prefix '/kaggle/working/Final_WGAN-GP_'.
