In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split
from scipy.stats import pearsonr, wasserstein_distance
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns

# Set device and random seeds
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)
np.random.seed(42)

In [30]:
class ImprovedDiscriminator(nn.Module):
    def __init__(self):
        super(ImprovedDiscriminator, self).__init__()
        
        from torch.nn.utils import spectral_norm
        
        self.model = nn.Sequential(
            spectral_norm(nn.Linear(num_features + num_emotion_dims, 1024)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            
            spectral_norm(nn.Linear(1024, 512)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            
            spectral_norm(nn.Linear(512, 256)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            
            spectral_norm(nn.Linear(256, 128)),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Linear(128, 1)
            
        )
        
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, eeg_data, emotion_labels):
        input_tensor = torch.cat([eeg_data, emotion_labels], dim=1)
        return self.model(input_tensor)

In [31]:
class ImprovedGenerator(nn.Module):
    def __init__(self):
        super(ImprovedGenerator, self).__init__()
        
        # Use spectral normalization for training stability
        from torch.nn.utils import spectral_norm
        
        self.input_layer = nn.Linear(nz + num_emotion_dims, 256)
        
        # Residual blocks for better gradient flow
        self.block1 = nn.Sequential(
            spectral_norm(nn.Linear(256, 512)),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2)
        )
        
        self.block2 = nn.Sequential(
            spectral_norm(nn.Linear(512, 1024)),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2)
        )
        
        self.block3 = nn.Sequential(
            spectral_norm(nn.Linear(1024, 2048)),
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2)
        )
        
        self.output_layer = nn.Linear(2048, num_features)
        
        # Initialize weights properly
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    
    def forward(self, noise, emotion_labels):
        x = torch.cat([noise, emotion_labels], dim=1)
        x = torch.relu(self.input_layer(x))
        
        # Residual connections where possible
        x1 = self.block1(x)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        
        # Output without activation - let it learn the right range
        return self.output_layer(x3)


In [32]:
def inception_score(generated_samples, discriminator, emotion_labels, splits=10):
    """
    Calculate Inception Score for generated samples
    """
    discriminator.eval()
    scores = []
    
    with torch.no_grad():
        for i in range(splits):
            part = generated_samples[i * len(generated_samples) // splits:(i + 1) * len(generated_samples) // splits]
            part_emotions = emotion_labels[i * len(emotion_labels) // splits:(i + 1) * len(emotion_labels) // splits]
            
            part_tensor = torch.FloatTensor(part).to(device)
            part_emotions_tensor = torch.FloatTensor(part_emotions).to(device)
            
            pred = torch.sigmoid(discriminator(part_tensor, part_emotions_tensor))
            
            # Calculate score
            p_y = pred.mean(dim=0, keepdim=True)
            scores.append(torch.exp(torch.mean(pred * (torch.log(pred) - torch.log(p_y)))))
    
    return torch.mean(torch.stack(scores)).item(), torch.std(torch.stack(scores)).item()


In [33]:
def improved_preprocess_for_gan(X, y):
    """
    Better preprocessing for GAN training
    """
    # Flatten X
    X_flat = X.reshape(X.shape[0], -1)  # (samples, 2880)
    
    # Use StandardScaler for more stable training
    scaler_X = StandardScaler()
    X_normalized = scaler_X.fit_transform(X_flat)
    
    # Normalize emotion labels properly
    scaler_y = StandardScaler()
    y_normalized = scaler_y.fit_transform(y)
    
    return X_normalized, y_normalized, scaler_X, scaler_y

# Wasserstein Loss with Gradient Penalty (more stable than BCE)
def gradient_penalty(discriminator, real_data, fake_data, emotion_labels, lambda_gp=10):
    """
    Calculate gradient penalty for WGAN-GP
    """
    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, 1).to(device)
    alpha = alpha.expand_as(real_data)
    
    interpolated = alpha * real_data + (1 - alpha) * fake_data
    interpolated.requires_grad_(True)
    
    # Calculate discriminator output for interpolated data
    d_interpolated = discriminator(interpolated, emotion_labels)
    
    # Calculate gradients
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    
    return penalty


In [34]:
def calculate_generation_metrics(real_data, fake_data):
    """
    Calculate comprehensive metrics for generated vs real data
    """
    # Flatten for calculations
    real_flat = real_data.flatten()
    fake_flat = fake_data.flatten()
    
    metrics = {}
    
    # Distribution similarity
    metrics['wasserstein_distance'] = wasserstein_distance(real_flat, fake_flat)
    metrics['ks_statistic'], metrics['ks_p_value'] = stats.ks_2samp(real_flat, fake_flat)
    
    # Statistical moments
    metrics['mean_diff'] = abs(np.mean(real_flat) - np.mean(fake_flat))
    metrics['std_diff'] = abs(np.std(real_flat) - np.std(fake_flat))
    metrics['skewness_diff'] = abs(stats.skew(real_flat) - stats.skew(fake_flat))
    metrics['kurtosis_diff'] = abs(stats.kurtosis(real_flat) - stats.kurtosis(fake_flat))
    
    # Correlation with real data patterns
    if len(real_flat) == len(fake_flat):
        correlation, p_value = pearsonr(real_flat, fake_flat)
        metrics['correlation'] = correlation
        metrics['correlation_p_value'] = p_value
    
    # MSE and MAE
    if real_data.shape == fake_data.shape:
        metrics['mse'] = mean_squared_error(real_flat, fake_flat)
        metrics['mae'] = mean_absolute_error(real_flat, fake_flat)
        metrics['rmse'] = np.sqrt(metrics['mse'])
    
    # Frequency domain analysis
    real_fft = np.fft.fft(real_data.reshape(-1, 32, 90), axis=-1)
    fake_fft = np.fft.fft(fake_data.reshape(-1, 32, 90), axis=-1)
    
    real_power = np.mean(np.abs(real_fft) ** 2)
    fake_power = np.mean(np.abs(fake_fft) ** 2)
    metrics['power_spectrum_diff'] = abs(real_power - fake_power) / real_power
    
    return metrics



In [35]:


# IMPROVED Hyperparameters
nz = 128  # Increased noise dimension for more diversity
num_features = 2880  # 32 * 90 flattened
num_emotion_dims = 4
lr_g = 0.0001  # Lower learning rate for generator (more stable)
lr_d = 0.0002  # Higher learning rate for discriminator
beta1 = 0.5
batch_size = 32
num_epochs = 300  # More epochs for better convergence



# IMPROVED training function with comprehensive evaluation
def train_improved_gan(X, y, num_epochs=300, use_wgan_gp=True):
    """
    Enhanced GAN training with evaluation metrics
    """
    print(f"Training Enhanced GAN on data: X{X.shape}, y{y.shape}")
    
    # Split data for evaluation
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # Preprocess data
    X_train_proc, y_train_proc, scaler_X, scaler_y = improved_preprocess_for_gan(X_train, y_train)
    X_test_proc, y_test_proc, _, _ = improved_preprocess_for_gan(X_test, y_test)
    
    # Convert to tensors
    X_train_tensor = torch.FloatTensor(X_train_proc).to(device)
    y_train_tensor = torch.FloatTensor(y_train_proc).to(device)
    X_test_tensor = torch.FloatTensor(X_test_proc).to(device)
    y_test_tensor = torch.FloatTensor(y_test_proc).to(device)
    
    # Create dataloader
    dataset = TensorDataset(X_train_tensor, y_train_tensor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    
    # Initialize improved networks
    generator = ImprovedGenerator().to(device)
    discriminator = ImprovedDiscriminator().to(device)
    
    print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
    
    # Optimizers with different learning rates
    optimizer_G = optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(beta1, 0.999))
    
    # Learning rate schedulers
    scheduler_G = optim.lr_scheduler.ExponentialLR(optimizer_G, gamma=0.995)
    scheduler_D = optim.lr_scheduler.ExponentialLR(optimizer_D, gamma=0.995)
    
    # Loss function
    if use_wgan_gp:
        criterion = None  # Use Wasserstein loss
    else:
        criterion = nn.BCEWithLogitsLoss()
    
    # Training history
    G_losses = []
    D_losses = []
    evaluation_history = []
    
    print("Starting Enhanced GAN Training...")
    
    for epoch in range(num_epochs):
        epoch_g_loss = 0
        epoch_d_loss = 0
        
        for i, (real_data, real_labels) in enumerate(dataloader):
            batch_size_current = real_data.size(0)
            
            # ---------------------
            # Train Discriminator (more frequently for stability)
            # ---------------------
            for _ in range(2):  # Train D twice per G training
                optimizer_D.zero_grad()
                
                if use_wgan_gp:
                    # Wasserstein loss with gradient penalty
                    # Real data
                    d_real = discriminator(real_data, real_labels).mean()
                    
                    # Fake data
                    noise = torch.randn(batch_size_current, nz).to(device)
                    fake_data = generator(noise, real_labels).detach()
                    d_fake = discriminator(fake_data, real_labels).mean()
                    
                    # Gradient penalty
                    gp = gradient_penalty(discriminator, real_data, fake_data, real_labels)
                    
                    # Wasserstein loss
                    loss_D = d_fake - d_real + gp
                else:
                    # Standard GAN loss
                    real_labels_disc = torch.ones(batch_size_current, 1).to(device)
                    fake_labels_disc = torch.zeros(batch_size_current, 1).to(device)
                    
                    # Real data
                    output_real = discriminator(real_data, real_labels)
                    loss_D_real = criterion(output_real, real_labels_disc)
                    
                    # Fake data
                    noise = torch.randn(batch_size_current, nz).to(device)
                    fake_data = generator(noise, real_labels).detach()
                    output_fake = discriminator(fake_data, real_labels)
                    loss_D_fake = criterion(output_fake, fake_labels_disc)
                    
                    loss_D = (loss_D_real + loss_D_fake) / 2
                
                loss_D.backward()
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
                optimizer_D.step()
            
            # -----------------
            # Train Generator
            # -----------------
            optimizer_G.zero_grad()
            
            noise = torch.randn(batch_size_current, nz).to(device)
            fake_data = generator(noise, real_labels)
            
            if use_wgan_gp:
                # Wasserstein loss for generator
                loss_G = -discriminator(fake_data, real_labels).mean()
            else:
                # Standard GAN loss
                output_fake = discriminator(fake_data, real_labels)
                loss_G = criterion(output_fake, torch.ones(batch_size_current, 1).to(device))
            
            loss_G.backward()
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
            optimizer_G.step()
            
            epoch_g_loss += loss_G.item()
            epoch_d_loss += loss_D.item()
        
        # Update learning rates
        if epoch > 50:
            scheduler_G.step()
            scheduler_D.step()
        
        # Record losses
        avg_g_loss = epoch_g_loss / len(dataloader)
        avg_d_loss = epoch_d_loss / len(dataloader)
        G_losses.append(avg_g_loss)
        D_losses.append(avg_d_loss)
        
        # Comprehensive evaluation every 25 epochs
        if epoch % 25 == 0:
            print(f'Epoch [{epoch}/{num_epochs}] | G Loss: {avg_g_loss:.4f} | D Loss: {avg_d_loss:.4f}')
            
            # Generate samples for evaluation
            generator.eval()
            with torch.no_grad():
                test_noise = torch.randn(len(X_test), nz).to(device)
                generated_test = generator(test_noise, y_test_tensor)
                generated_test_np = generated_test.cpu().numpy()
                
                # Convert back to original scale
                generated_original = scaler_X.inverse_transform(generated_test_np)
                generated_shaped = generated_original.reshape(-1, 32, 90)
                
                # Calculate metrics
                metrics = calculate_generation_metrics(X_test, generated_shaped)
                evaluation_history.append({
                    'epoch': epoch,
                    'metrics': metrics,
                    'g_loss': avg_g_loss,
                    'd_loss': avg_d_loss
                })
                
                print(f"  Wasserstein Distance: {metrics['wasserstein_distance']:.4f}")
                print(f"  Mean Difference: {metrics['mean_diff']:.4f}")
                print(f"  Power Spectrum Diff: {metrics['power_spectrum_diff']:.4f}")
            
            generator.train()
    
    return generator, discriminator, G_losses, D_losses, evaluation_history, scaler_X, scaler_y






In [36]:

# Comprehensive visualization functions
def plot_comprehensive_results(G_losses, D_losses, evaluation_history, real_samples, fake_samples):
    """
    Create comprehensive visualization of GAN results
    """
    fig, axes = plt.subplots(3, 3, figsize=(20, 15))
    
    # 1. Training losses
    axes[0,0].plot(G_losses, label='Generator Loss', alpha=0.8)
    axes[0,0].plot(D_losses, label='Discriminator Loss', alpha=0.8)
    axes[0,0].set_xlabel('Epoch')
    axes[0,0].set_ylabel('Loss')
    axes[0,0].set_title('Training Losses')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)
    
    # 2. Wasserstein distance over time
    epochs = [eval_data['epoch'] for eval_data in evaluation_history]
    wasserstein_dists = [eval_data['metrics']['wasserstein_distance'] for eval_data in evaluation_history]
    axes[0,1].plot(epochs, wasserstein_dists, 'o-', color='red', alpha=0.8)
    axes[0,1].set_xlabel('Epoch')
    axes[0,1].set_ylabel('Wasserstein Distance')
    axes[0,1].set_title('Distribution Similarity Over Time')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Statistical moments comparison
    mean_diffs = [eval_data['metrics']['mean_diff'] for eval_data in evaluation_history]
    std_diffs = [eval_data['metrics']['std_diff'] for eval_data in evaluation_history]
    axes[0,2].plot(epochs, mean_diffs, 'o-', label='Mean Difference', alpha=0.8)
    axes[0,2].plot(epochs, std_diffs, 's-', label='Std Difference', alpha=0.8)
    axes[0,2].set_xlabel('Epoch')
    axes[0,2].set_ylabel('Difference')
    axes[0,2].set_title('Statistical Moments Differences')
    axes[0,2].legend()
    axes[0,2].grid(True, alpha=0.3)
    
    # 4. Data distribution comparison
    axes[1,0].hist(real_samples.flatten(), bins=50, alpha=0.7, label='Real Data', density=True)
    axes[1,0].hist(fake_samples.flatten(), bins=50, alpha=0.7, label='Generated Data', density=True)
    axes[1,0].set_xlabel('Feature Values')
    axes[1,0].set_ylabel('Density')
    axes[1,0].set_title('Data Distribution Comparison')
    axes[1,0].legend()
    
    # 5. Q-Q Plot
    real_flat_sorted = np.sort(real_samples.flatten())
    fake_flat_sorted = np.sort(fake_samples.flatten())
    min_len = min(len(real_flat_sorted), len(fake_flat_sorted))
    axes[1,1].scatter(real_flat_sorted[:min_len], fake_flat_sorted[:min_len], alpha=0.5, s=1)
    axes[1,1].plot([real_flat_sorted.min(), real_flat_sorted.max()], 
                   [real_flat_sorted.min(), real_flat_sorted.max()], 'r--')
    axes[1,1].set_xlabel('Real Data Quantiles')
    axes[1,1].set_ylabel('Generated Data Quantiles')
    axes[1,1].set_title('Q-Q Plot')
    
    # 6. Power spectrum comparison
    real_fft = np.fft.fft(real_samples[:100].reshape(-1, 90), axis=-1)
    fake_fft = np.fft.fft(fake_samples[:100].reshape(-1, 90), axis=-1)
    
    freqs = np.fft.fftfreq(90)[:45]  # Only positive frequencies
    real_power = np.mean(np.abs(real_fft[:, :45]) ** 2, axis=0)
    fake_power = np.mean(np.abs(fake_fft[:, :45]) ** 2, axis=0)
    
    axes[1,2].plot(freqs, real_power, label='Real Data', alpha=0.8)
    axes[1,2].plot(freqs, fake_power, label='Generated Data', alpha=0.8)
    axes[1,2].set_xlabel('Frequency')
    axes[1,2].set_ylabel('Power')
    axes[1,2].set_title('Power Spectrum Comparison')
    axes[1,2].legend()
    
    # 7. Sample EEG comparison
    sample_idx = 0
    im1 = axes[2,0].imshow(real_samples[sample_idx], aspect='auto', cmap='viridis')
    axes[2,0].set_title('Real EEG Sample')
    axes[2,0].set_xlabel('Time')
    axes[2,0].set_ylabel('Channels')
    plt.colorbar(im1, ax=axes[2,0])
    
    im2 = axes[2,1].imshow(fake_samples[sample_idx], aspect='auto', cmap='viridis')
    axes[2,1].set_title('Generated EEG Sample')
    axes[2,1].set_xlabel('Time')
    axes[2,1].set_ylabel('Channels')
    plt.colorbar(im2, ax=axes[2,1])
    
    # 8. Correlation matrix
    if len(real_samples) == len(fake_samples):
        correlation_matrix = np.corrcoef(real_samples.reshape(len(real_samples), -1), 
                                       fake_samples.reshape(len(fake_samples), -1))
        im3 = axes[2,2].imshow(correlation_matrix, cmap='coolwarm', vmin=-1, vmax=1)
        axes[2,2].set_title('Real vs Generated Correlation Matrix')
        plt.colorbar(im3, ax=axes[2,2])
    
    plt.tight_layout()
    plt.show()

In [37]:
def print_final_evaluation(evaluation_history):
    """
    Print comprehensive final evaluation metrics
    """
    if not evaluation_history:
        print("No evaluation history available.")
        return
    
    final_metrics = evaluation_history[-1]['metrics']
    
    print("\n" + "="*60)
    print("ENHANCED GAN FINAL EVALUATION METRICS")
    print("="*60)
    print(f"Wasserstein Distance:         {final_metrics['wasserstein_distance']:.6f}")
    print(f"KS Test Statistic:            {final_metrics['ks_statistic']:.6f} (p={final_metrics['ks_p_value']:.4e})")
    print(f"Mean Difference:              {final_metrics['mean_diff']:.6f}")
    print(f"Standard Deviation Diff:      {final_metrics['std_diff']:.6f}")
    print(f"Skewness Difference:          {final_metrics['skewness_diff']:.6f}")
    print(f"Kurtosis Difference:          {final_metrics['kurtosis_diff']:.6f}")
    
    if 'mse' in final_metrics:
        print(f"Mean Squared Error:           {final_metrics['mse']:.6f}")
        print(f"Root Mean Squared Error:      {final_metrics['rmse']:.6f}")
        print(f"Mean Absolute Error:          {final_metrics['mae']:.6f}")
    
    if 'correlation' in final_metrics:
        print(f"Pearson Correlation:          {final_metrics['correlation']:.4f} (p={final_metrics['correlation_p_value']:.4e})")
    
    print(f"Power Spectrum Difference:    {final_metrics['power_spectrum_diff']:.4f}")
    print("="*60)
    
    # Quality assessment
    print("\nQUALITY ASSESSMENT:")
    wass_dist = final_metrics['wasserstein_distance']
    if wass_dist < 0.1:
        print("🟢 EXCELLENT: Very similar distributions")
    elif wass_dist < 0.3:
        print("🟡 GOOD: Reasonably similar distributions")
    elif wass_dist < 0.5:
        print("🟠 MODERATE: Some differences in distributions")
    else:
        print("🔴 NEEDS IMPROVEMENT: Significant distribution differences")


In [38]:
def generate_enhanced_samples(generator, emotion_conditions, scaler_X, scaler_y, num_samples=100):
    """
    Generate enhanced samples with proper scaling and evaluation
    """
    generator.eval()
    device = next(generator.parameters()).device
    
    with torch.no_grad():
        # Normalize emotion conditions
        emotion_normalized = scaler_y.transform(emotion_conditions)
        emotion_tensor = torch.FloatTensor(emotion_normalized).to(device)
        
        # Generate noise
        noise = torch.randn(num_samples, nz).to(device)
        
        # Generate samples
        fake_data = generator(noise, emotion_tensor)
        fake_data_np = fake_data.cpu().numpy()
        
        # Convert back to original scale
        fake_data_original = scaler_X.inverse_transform(fake_data_np)
        fake_data_shaped = fake_data_original.reshape(num_samples, 32, 90)
    
    return fake_data_shaped


In [39]:
# Main enhanced training function
def main_enhanced_gan_training(X, y):
    """
    Main function for enhanced GAN training with comprehensive evaluation
    """
    print(f"Enhanced GAN Training - Data shape: X{X.shape}, y{y.shape}")
    
    # Train the enhanced GAN
    generator, discriminator, g_losses, d_losses, eval_history, scaler_X, scaler_y = train_improved_gan(
        X, y, num_epochs=200, use_wgan_gp=True
    )
    
    # Generate test samples
    test_emotions = np.array([[5, 5, 5, 5], [7, 3, 6, 4], [2, 8, 3, 7]] * 20)  # 60 samples
    generated_samples = generate_enhanced_samples(
        generator, test_emotions, scaler_X, scaler_y, num_samples=len(test_emotions)
    )
    
    # Create comprehensive visualizations
    # Use a subset of real data for comparison
    real_subset = X[:len(generated_samples)]
    plot_comprehensive_results(g_losses, d_losses, eval_history, real_subset, generated_samples)
    
    # Print final evaluation
    print_final_evaluation(eval_history)
    
    return generator, discriminator, generated_samples, scaler_X, scaler_y, eval_history



In [40]:

generator, discriminator, fake_samples, scaler_X, scaler_y, evaluation = main_enhanced_gan_training(X, y)

Enhanced GAN Training - Data shape: X(1280, 32), y(1280, 4)
Training Enhanced GAN on data: X(1280, 32), y(1280, 4)
Generator parameters: 8,698,432
Discriminator parameters: 3,643,393
Starting Enhanced GAN Training...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x36 and 2884x1024)