# Modular Quantum GAN Training with Comprehensive Monitoring

This notebook demonstrates extended training of the modular quantum GAN architecture with:

- **Modular Architecture**: Clean separation of quantum components
- **Extended Training**: Multi-epoch Wasserstein training (10 epochs in the run below)
- **Quality Tracking**: Periodic monitoring of generation quality (mean, standard deviation, and 1D Wasserstein distance)
- **Comprehensive Visualization**: Training evolution dashboard
- **Performance Analysis**: Detailed convergence analysis

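This implementation uses the modular architecture with pure quantum learning: all trainable parameters live in the quantum circuits, while the classical encoder and decoder matrices are static (non-trainable).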

## 1. Environment Setup

In [1]:
# Core imports
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import sys
import os
import time
from scipy.stats import wasserstein_distance

# Configure matplotlib for Jupyter
%matplotlib inline
plt.style.use('default')

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(os.getcwd()), 'src'))

# Import our clean training utilities
from utils.warning_suppression import suppress_all_quantum_warnings
from losses.quantum_gan_loss import QuantumGANLoss

# Suppress warnings
suppress_all_quantum_warnings()

print("Extended training environment setup complete ✓")
print(f"TensorFlow version: {tf.__version__}")


Extended training environment setup complete ✓
TensorFlow version: 2.18.0


## 2. Import Modular Quantum Components

In [2]:
# Import our modular quantum components
from models.generators.quantum_generator import PureQuantumGenerator
from models.discriminators.quantum_discriminator import PureQuantumDiscriminator
from models.quantum_gan import QuantumGAN
from models.transformations.matrix_manager import StaticTransformationMatrix

print("Modular quantum components imported successfully ✓")


Modular quantum components imported successfully ✓


## 3. Create Data and Components

In [3]:
def create_simple_2d_data(n_samples=50):
    """
    Create simple 2D Gaussian mixture for testing.
    """
    np.random.seed(42)
    
    # Two Gaussian clusters
    cluster1 = np.random.normal([1.0, 1.0], 0.3, (n_samples//2, 2))
    cluster2 = np.random.normal([-1.0, -1.0], 0.3, (n_samples//2, 2))
    
    data = np.vstack([cluster1, cluster2])
    
    # Normalize to [-1, 1] range for quantum stability
    data = data / np.max(np.abs(data))
    
    return tf.constant(data, dtype=tf.float32)

# Generate data
real_data = create_simple_2d_data(n_samples=50)
print(f"Data created: {real_data.shape}")
print(f"Data range: [{tf.reduce_min(real_data):.3f}, {tf.reduce_max(real_data):.3f}]")


Data created: (50, 2)
Data range: [-1.000, 0.871]
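As a reference for the standard-deviation metric used later, the unnormalized mixture has closed-form per-coordinate moments. For an equal mixture of N(+1, 0.3²) and N(−1, 0.3²), the mean is 0 and the law of total variance gives Var = 0.3² + 1² = 1.09. A quick numpy check of that arithmetic (this uses its own seeded generator and is independent of `create_simple_2d_data`):

```python
import numpy as np

# Equal-weight mixture of N(+mu, sigma^2) and N(-mu, sigma^2) along one coordinate
mu, sigma = 1.0, 0.3

# Law of total variance: Var = E[Var | cluster] + Var(E[X | cluster]) = sigma^2 + mu^2
analytic_std = np.sqrt(sigma**2 + mu**2)

# Large Monte Carlo sample to compare against the closed form
rng = np.random.default_rng(0)
n = 200_000
samples = np.concatenate([
    rng.normal(mu, sigma, n // 2),
    rng.normal(-mu, sigma, n // 2),
])

print(f"analytic std:   {analytic_std:.4f}")
print(f"empirical std:  {samples.std():.4f}")
print(f"empirical mean: {samples.mean():.4f}")
```

With only 50 points, the notebook's empirical statistics will scatter noticeably around these values, and because `create_simple_2d_data` divides by the largest absolute value, the normalized standard deviation is this value divided by that same factor.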


In [4]:
# Create the modular quantum GAN
print("\nCreating modular quantum GAN components...")

# Generator configuration
generator_config = {
    'latent_dim': 6,
    'output_dim': 2,
    'n_modes': 4,
    'layers': 3,
    'cutoff_dim': 10,
    'measurement_type': 'raw'
}

# Discriminator configuration
discriminator_config = {
    'input_dim': 2,
    'n_modes': 2,
    'layers': 2,
    'cutoff_dim': 10,
    'measurement_type': 'raw'
}

# Create the quantum GAN
qgan = QuantumGAN(
    generator_config=generator_config,
    discriminator_config=discriminator_config,
    loss_type='wasserstein',
    learning_rate_g=5e-4,
    learning_rate_d=5e-4,
    n_critic=5
)

# Count scalar parameters (total number of elements across all trainable variables)
n_gen_params = sum(int(np.prod(v.shape)) for v in qgan.generator.trainable_variables)
n_disc_params = sum(int(np.prod(v.shape)) for v in qgan.discriminator.trainable_variables)

print("✓ Modular quantum GAN created")
print(f"  Generator parameters: {n_gen_params}")
print(f"  Discriminator parameters: {n_disc_params}")
print(f"  Total quantum parameters: {n_gen_params + n_disc_params}")


INFO:quantum.core.quantum_circuit:Quantum circuit base initialized: 4 modes, cutoff=10



Creating modular quantum GAN components...


INFO:quantum.parameters.gate_parameters:Gate parameter manager initialized: 138 parameters
INFO:quantum.builders.circuit_builder:Circuit builder initialized
INFO:quantum.core.quantum_circuit:Pure quantum circuit initialized: 3 layers
INFO:quantum.measurements.measurement_extractor:Raw measurement extractor initialized: 12 measurements
INFO:models.transformations.matrix_manager:Static transformation matrix 'generator_encoder': 6 → 12
INFO:models.transformations.matrix_manager:Static transformation matrix 'generator_decoder': 12 → 2
INFO:models.transformations.matrix_manager:Transformation pair created: (6, 12) → (12, 2) (trainable=False)
INFO:models.generators.quantum_generator:Pure quantum generator initialized: 6 → 2
INFO:models.generators.quantum_generator:  Quantum modes: 4, Layers: 3
INFO:models.generators.quantum_generator:  Measurement type: raw
INFO:models.generators.quantum_generator:  Using STATIC transformations (pure quantum learning)
INFO:quantum.core.quantum_circuit:Quantu

✓ Modular quantum GAN created
  Generator parameters: 138
  Discriminator parameters: 28
  Total quantum parameters: 166
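The log above shows the generator placing its circuit between two static classical maps (`generator_encoder`: 6 → 12, `generator_decoder`: 12 → 2, `trainable=False`), with 12 raw measurements read out of the 4-mode circuit. The numpy sketch below illustrates that data flow under stated assumptions: the static matrices are taken to be random Gaussian matrices fixed at construction (the real `StaticTransformationMatrix` initialization may differ), and `fake_circuit` is a hypothetical classical stand-in for the quantum circuit, not a simulation of it.

```python
import numpy as np

rng = np.random.default_rng(42)
latent_dim, n_measurements, output_dim = 6, 12, 2

# Hypothetical static (non-trainable) matrices: drawn once, never updated
W_enc = rng.normal(size=(latent_dim, n_measurements)) / np.sqrt(latent_dim)
W_dec = rng.normal(size=(n_measurements, output_dim)) / np.sqrt(n_measurements)

def fake_circuit(encoding, theta):
    """Hypothetical stand-in for the quantum circuit: the only place where
    trainable parameters (theta) enter. Returns 12 'raw measurements' per sample."""
    return np.tanh(encoding * theta)

def generate(z, theta):
    encoding = z @ W_enc                          # (batch, 6)  -> (batch, 12)
    measurements = fake_circuit(encoding, theta)  # (batch, 12) -> (batch, 12)
    return measurements @ W_dec                   # (batch, 12) -> (batch, 2)

theta = rng.normal(size=n_measurements)  # stands in for the quantum gate parameters
z = rng.normal(size=(4, latent_dim))     # a batch of 4 latent vectors
samples = generate(z, theta)
print(samples.shape)  # (4, 2)
```

Because `W_enc` and `W_dec` never change, any improvement in the generated samples has to come from `theta`, which is what "pure quantum learning" refers to here.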


## 4. Enhanced Training Monitoring
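The monitor records one generator and one discriminator loss per checkpoint. For `loss_type='wasserstein'`, the standard WGAN objectives are shown below; this assumes, without confirmation from the source, that `QuantumGANLoss` follows them (it may add terms such as a gradient penalty). Because the critic output $D$ is unbounded, both losses can be negative, so their trend is more informative than their sign or magnitude.

```latex
\mathcal{L}_D = \mathbb{E}_{z}\big[D(G(z))\big] - \mathbb{E}_{x \sim p_{\text{data}}}\big[D(x)\big],
\qquad
\mathcal{L}_G = -\,\mathbb{E}_{z}\big[D(G(z))\big]
```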


In [5]:
class EnhancedTrainingMonitor:
    """
    Enhanced training monitor for quantum GAN with quality metrics.
    """
    
    def __init__(self, qgan):
        self.qgan = qgan
        self.history = {
            'epochs': [],
            'mean_differences': [],
            'std_differences': [],
            'wasserstein_distances': [],
            'generator_losses': [],
            'discriminator_losses': [],
            'training_times': []
        }
    
    def compute_quality_metrics(self, real_data, n_samples=50):
        """Compute comprehensive quality metrics."""
        # Generate samples
        generated_samples = self.qgan.generate(n_samples)
        
        # Convert to numpy
        real_np = real_data.numpy() if hasattr(real_data, 'numpy') else real_data
        gen_np = generated_samples.numpy()
        
        # Compute metrics
        real_mean = np.mean(real_np, axis=0)
        gen_mean = np.mean(gen_np, axis=0)
        real_std = np.std(real_np, axis=0)
        gen_std = np.std(gen_np, axis=0)
        
        mean_diff = np.linalg.norm(real_mean - gen_mean)
        std_diff = np.linalg.norm(real_std - gen_std)
        
        # 1D Wasserstein distance on the first coordinate only (a cheap proxy,
        # not the full 2D distance between the distributions)
        try:
            wd = wasserstein_distance(real_np[:, 0], gen_np[:, 0])
        except Exception:
            wd = float('inf')
        
        return {
            'mean_difference': mean_diff,
            'std_difference': std_diff,
            'wasserstein_distance': wd,
            'generated_samples': gen_np
        }
    
    def train_with_monitoring(self, data, epochs=50, batch_size=8, 
                            monitor_interval=5, verbose=True):
        """
        Train with comprehensive monitoring.
        """
        print(f"Starting training with Quantum Wasserstein Loss: {epochs} epochs")
        print(f"Monitoring every {monitor_interval} epochs")
        print(f"Data shape: {data.shape}")
        
        start_time = time.time()
        
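        # Initial quality assessment, stored as epoch 0 so the history starts from
        # a true pre-training baseline (losses are undefined before training)
        initial_quality = self.compute_quality_metrics(data)
        print(f"Initial quality - Mean diff: {initial_quality['mean_difference']:.4f}")
        self.history['epochs'].append(0)
        self.history['mean_differences'].append(initial_quality['mean_difference'])
        self.history['std_differences'].append(initial_quality['std_difference'])
        self.history['wasserstein_distances'].append(initial_quality['wasserstein_distance'])
        self.history['generator_losses'].append(float('nan'))
        self.history['discriminator_losses'].append(float('nan'))
        self.history['training_times'].append(0.0)
        
        # Sample a random mini-batch (with replacement) from the training data
        def data_generator():
            indices = tf.random.uniform([batch_size], 0, tf.shape(data)[0], dtype=tf.int32)
            return tf.gather(data, indices)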
        
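        # Train in chunks of `monitor_interval` epochs. Note that one "epoch" here is
        # a single training iteration (n_critic discriminator steps followed by one
        # generator step), not a full pass over the dataset.
        for epoch_chunk in range(0, epochs, monitor_interval):
            chunk_epochs = min(monitor_interval, epochs - epoch_chunk)
            
            # Train for this chunk
            for _ in range(chunk_epochs):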
                # Train discriminator
                for _ in range(self.qgan.n_critic):
                    batch = data_generator()
                    z_batch = tf.random.normal([batch_size, self.qgan.generator.latent_dim])
                    d_loss = self.qgan.train_discriminator_step(batch, z_batch)
                
                # Train generator
                z_batch = tf.random.normal([batch_size, self.qgan.generator.latent_dim])
                g_loss = self.qgan.train_generator_step(z_batch)
            
            # Current epoch number
            current_epoch = epoch_chunk + chunk_epochs
            
            # Compute quality metrics
            quality_metrics = self.compute_quality_metrics(data)
            
            # Store metrics
            self.history['epochs'].append(current_epoch)
            self.history['mean_differences'].append(quality_metrics['mean_difference'])
            self.history['std_differences'].append(quality_metrics['std_difference'])
            self.history['wasserstein_distances'].append(quality_metrics['wasserstein_distance'])
            self.history['generator_losses'].append(float(g_loss))
            self.history['discriminator_losses'].append(float(d_loss))
            self.history['training_times'].append(time.time() - start_time)
            
            if verbose:
                print(f"Epoch {current_epoch:3d}: G_loss={float(g_loss):.4f}, "
                      f"D_loss={float(d_loss):.4f}, "
                      f"Mean_diff={quality_metrics['mean_difference']:.4f}, "
                      f"WD={quality_metrics['wasserstein_distance']:.4f}")
        
        total_time = time.time() - start_time
        print(f"\nQuantum Wasserstein training completed in {total_time:.1f}s")
        
        return self.history

print("Enhanced training monitor created ✓")


Enhanced training monitor created ✓
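The metrics in `compute_quality_metrics` can be reproduced with numpy alone, which is handy for checking them outside the training loop. The sketch below (hypothetical helper names) recomputes the mean and standard-deviation differences and, for two samples of equal size, the 1D Wasserstein-1 distance through the sorted-sample identity W₁ = mean |sort(a) − sort(b)|, which matches what `scipy.stats.wasserstein_distance` returns for equal-size, equally weighted samples. It also averages W₁ over both coordinates, one possible alternative to the first-coordinate-only proxy.

```python
import numpy as np

def w1_equal_size(a, b):
    """1D Wasserstein-1 distance between two equal-size empirical samples."""
    a, b = np.sort(np.asarray(a, dtype=float)), np.sort(np.asarray(b, dtype=float))
    if a.shape != b.shape:
        raise ValueError("the sorted-sample identity needs samples of equal size")
    return float(np.mean(np.abs(a - b)))

def quality_metrics(real, gen):
    """Mean/std differences (L2 norms over coordinates) plus two W1 variants."""
    mean_diff = np.linalg.norm(real.mean(axis=0) - gen.mean(axis=0))
    std_diff = np.linalg.norm(real.std(axis=0) - gen.std(axis=0))
    wd_x1 = w1_equal_size(real[:, 0], gen[:, 0])            # notebook's proxy
    wd_avg = float(np.mean([w1_equal_size(real[:, d], gen[:, d])
                            for d in range(real.shape[1])]))  # all marginals
    return mean_diff, std_diff, wd_x1, wd_avg

real = np.array([[0.0, 0.0], [1.0, 2.0]])
gen = real + np.array([0.5, 0.0])  # shift only the first coordinate by 0.5
print(quality_metrics(real, gen))
```

For unequal sizes, such as the 500 generated versus 50 real points in the final assessment, keep using `scipy.stats.wasserstein_distance`.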


## 5. Extended Training with Real-Time Monitoring
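In `train_with_monitoring`, each "epoch" is one iteration: `n_critic` discriminator updates followed by one generator update, each on a freshly sampled mini-batch. The helper below (a hypothetical name, not part of the project) tallies what the configuration in the next cell amounts to.

```python
import math

def training_budget(epochs, n_critic, batch_size, n_data, monitor_interval=5):
    """Count the updates performed by the notebook's training loop."""
    d_steps = epochs * n_critic                 # discriminator (critic) updates
    g_steps = epochs                            # generator updates
    real_samples_seen = d_steps * batch_size    # real points drawn (with replacement)
    full_passes = real_samples_seen / n_data    # equivalent passes over the dataset
    checkpoints = math.ceil(epochs / monitor_interval)  # monitoring points after training chunks
    return d_steps, g_steps, real_samples_seen, full_passes, checkpoints

# Configuration used below: epochs=10, n_critic=5, batch_size=4, 50 data points
print(training_budget(epochs=10, n_critic=5, batch_size=4, n_data=50))
```

So the 10-epoch run performs only 10 generator updates, which is worth keeping in mind when judging convergence from the dashboard.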


In [6]:
print("Starting extended quantum GAN training with comprehensive monitoring...")
print("This will take significantly longer but should show real learning!")

# Create monitor
monitor = EnhancedTrainingMonitor(qgan)

# Extended training
training_history = monitor.train_with_monitoring(
    data=real_data,
    epochs=10,            # Training iterations (n_critic D steps + 1 G step each)
    batch_size=4,         # Small batches for quantum stability
    monitor_interval=5,   # Monitor every 5 epochs
    verbose=True
)

print("\n✓ Extended training completed!")


INFO:quantum.core.quantum_circuit:Circuit built with 138 parameters


Starting extended quantum GAN training with comprehensive monitoring...
This will take significantly longer but should show real learning!
Starting training with Quantum Wasserstein Loss: 10 epochs
Monitoring every 5 epochs
Data shape: (50, 2)
Initial quality - Mean diff: 0.0243


ValueError: in user code:

    File "c:\Users\MendMa1\Documents\Personal\Thesis\github\QNNCV\src\models\quantum_gan.py", line 100, in train_discriminator_step  *
        fake_batch = self.generator.generate(z_batch)
    File "c:\Users\MendMa1\Documents\Personal\Thesis\github\QNNCV\src\models\generators\quantum_generator.py", line 133, in generate  *
        if j < tf.shape(encoding_values)[0]:

    ValueError: 'modulation[name]' must also be initialized in the else branch

The run above stopped during the first discriminator step, so `training_history` was never assigned and the cells below have not been executed (`In [None]`). The error is raised by AutoGraph: inside a `tf.function`, a Python `if` whose condition is a tensor, such as `j < tf.shape(encoding_values)[0]`, is converted to `tf.cond`, and any variable assigned in one branch (here `modulation[name]`) must also be assigned in the other. Two common fixes in `quantum_generator.py` are to compare against the static shape `encoding_values.shape[0]` when it is known at trace time (provided `j` is also a Python integer), which keeps the `if` as ordinary Python, or to give `modulation[name]` a value in an `else` branch.


## 6. Comprehensive Training Analysis Dashboard

In [None]:
# Create comprehensive training dashboard
fig, axes = plt.subplots(3, 2, figsize=(15, 18))

epochs = training_history['epochs']

# 1. Loss Evolution
ax1 = axes[0, 0]
ax1.plot(epochs, training_history['generator_losses'], label='Generator Loss', color='blue', linewidth=2)
ax1.plot(epochs, training_history['discriminator_losses'], label='Discriminator Loss', color='red', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss Evolution', fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_yscale('symlog')  # Wasserstein losses can be negative; 'log' would drop them

# 2. Quality Metrics Evolution
ax2 = axes[0, 1]
ax2.plot(epochs, training_history['mean_differences'], label='Mean Difference', color='green', linewidth=2)
ax2.plot(epochs, training_history['std_differences'], label='Std Difference', color='orange', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Difference')
ax2.set_title('Quality Metrics Evolution', fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_yscale('log')

# 3. Wasserstein Distance
ax3 = axes[1, 0]
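# Keep epochs aligned with their finite Wasserstein distances
valid_pairs = [(e, wd) for e, wd in zip(epochs, training_history['wasserstein_distances']) if np.isfinite(wd)]
valid_epochs = [e for e, _ in valid_pairs]
valid_wd = [wd for _, wd in valid_pairs]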
if valid_wd:
    ax3.plot(valid_epochs, valid_wd, label='Wasserstein Distance', color='purple', linewidth=2)
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Wasserstein Distance')
    ax3.set_title('Distribution Distance Evolution', fontweight='bold')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.set_yscale('log')

# 4. Stability Metric
ax4 = axes[1, 1]
stability_metrics = [g/d if d != 0 else float('inf') for g, d in zip(training_history['generator_losses'], training_history['discriminator_losses'])]
ax4.plot(epochs, stability_metrics, label='G/D Loss Ratio', color='brown', linewidth=2)
ax4.axhline(y=1.0, color='black', linestyle='--', alpha=0.5, label='Ratio = 1')
ax4.set_xlabel('Epoch')
ax4.set_ylabel('G/D Loss Ratio')
ax4.set_title('Training Stability', fontweight='bold')
ax4.legend()
ax4.grid(True, alpha=0.3)
ax4.set_yscale('symlog')  # the ratio can be negative for Wasserstein losses

# 5. Final Generated vs Real Data
ax5 = axes[2, 0]
final_quality = monitor.compute_quality_metrics(real_data, n_samples=300)
generated_samples = final_quality['generated_samples']

ax5.scatter(real_data[:, 0], real_data[:, 1], alpha=0.6, s=20, color='blue', label='Real Data')
ax5.scatter(generated_samples[:, 0], generated_samples[:, 1], alpha=0.6, s=20, color='red', label='Generated Data')
ax5.set_xlabel('X₁')
ax5.set_ylabel('X₂')
ax5.set_title('Final: Real vs Generated Data', fontweight='bold')
ax5.legend()
ax5.grid(True, alpha=0.3)
ax5.axis('equal')

# 6. Training Time Analysis
ax6 = axes[2, 1]
ax6.plot(epochs, training_history['training_times'], label='Cumulative Training Time', color='gray', linewidth=2)
ax6.set_xlabel('Epoch')
ax6.set_ylabel('Time (seconds)')
ax6.set_title('Training Time Analysis', fontweight='bold')
ax6.legend()
ax6.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print final analysis
print("\n" + "="*60)
print("EXTENDED TRAINING ANALYSIS")
print("="*60)

initial_mean_diff = training_history['mean_differences'][0]
final_mean_diff = training_history['mean_differences'][-1]
improvement = ((initial_mean_diff - final_mean_diff) / initial_mean_diff) * 100

print(f"\nQuality Improvement:")
print(f"  Initial Mean Difference: {initial_mean_diff:.4f}")
print(f"  Final Mean Difference: {final_mean_diff:.4f}")
print(f"  Improvement: {improvement:.1f}%")

print(f"\nFinal Metrics:")
print(f"  Generator Loss: {training_history['generator_losses'][-1]:.4f}")
print(f"  Discriminator Loss: {training_history['discriminator_losses'][-1]:.4f}")
print(f"  Stability Ratio: {stability_metrics[-1]:.4f}")

if valid_wd:
    print(f"  Final Wasserstein Distance: {valid_wd[-1]:.4f}")

print(f"\nTraining Configuration:")
print(f"  Monitoring Points: {len(epochs)} over {max(epochs)} epochs")
print(f"  Architecture: Generator({len(qgan.generator.trainable_variables)} params), Discriminator({len(qgan.discriminator.trainable_variables)} params)")
print(f"  Total Training Time: {training_history['training_times'][-1]:.1f}s")

# Convergence analysis (needs at least 5 monitoring points)
if len(training_history['mean_differences']) >= 5:
    recent_mean_diffs = training_history['mean_differences'][-5:]
    if max(recent_mean_diffs) - min(recent_mean_diffs) < 0.1:
        print(f"\n✓ Training appears to have converged (stable quality in last 5 measurements)")
    else:
        print(f"\n⚠ Training may benefit from additional epochs (quality still changing)")
else:
    print(f"\n⚠ Too few monitoring points for a convergence check (need at least 5)")

print("\n" + "="*60)
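The convergence check above uses an absolute spread threshold of 0.1, which is loose when mean differences are of order 0.02 (the initial assessment printed 0.0243). One hedged alternative is a relative criterion; `has_converged` below is a hypothetical helper, not part of the project.

```python
def has_converged(values, window=5, rel_tol=0.1):
    """True if the last `window` values vary by less than `rel_tol`
    times their mean absolute value."""
    if len(values) < window:
        return False  # not enough monitoring points to judge
    recent = values[-window:]
    scale = sum(abs(v) for v in recent) / window
    if scale == 0:
        return True   # all recent values are exactly zero
    return (max(recent) - min(recent)) / scale < rel_tol

print(has_converged([0.020, 0.021, 0.020, 0.0205, 0.021]))  # small relative spread
print(has_converged([0.10, 0.08, 0.05, 0.03, 0.02]))        # still decreasing
```

A relative tolerance scales with the metric, so the same rule works whether mean differences sit near 0.02 or near 1.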


## 7. Detailed Quality Assessment


In [None]:
# Comprehensive final evaluation
print("Performing detailed quality assessment...")

# Generate larger sample for final evaluation
final_evaluation = monitor.compute_quality_metrics(real_data, n_samples=500)

print(f"\nFinal Quality Assessment (500 samples):")
print(f"  Mean Difference: {final_evaluation['mean_difference']:.4f}")
print(f"  Std Difference: {final_evaluation['std_difference']:.4f}")
print(f"  Wasserstein Distance: {final_evaluation['wasserstein_distance']:.4f}")

# Quality benchmarks (heuristic thresholds; the real data lie in [-1, 1])
print(f"\nQuality Benchmarks:")
if final_evaluation['mean_difference'] < 0.5:
    print(f"  ✓ Excellent mean matching (< 0.5)")
elif final_evaluation['mean_difference'] < 1.0:
    print(f"  ✓ Good mean matching (< 1.0)")
elif final_evaluation['mean_difference'] < 2.0:
    print(f"  ⚠ Fair mean matching (< 2.0)")
else:
    print(f"  ✗ Poor mean matching (≥ 2.0)")

if final_evaluation['std_difference'] < 0.3:
    print(f"  ✓ Excellent variance matching (< 0.3)")
elif final_evaluation['std_difference'] < 0.6:
    print(f"  ✓ Good variance matching (< 0.6)")
else:
    print(f"  ⚠ Poor variance matching (≥ 0.6)")

# Distribution visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

generated_samples = final_evaluation['generated_samples']

# Scatter plot comparison
ax1 = axes[0]
ax1.scatter(real_data[:, 0], real_data[:, 1], alpha=0.6, s=15, color='blue', label='Real Data')
ax1.scatter(generated_samples[:, 0], generated_samples[:, 1], alpha=0.6, s=15, color='red', label='Generated Data')
ax1.set_xlabel('X₁')
ax1.set_ylabel('X₂')
ax1.set_title('Final Distribution Comparison')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.axis('equal')

# X1 marginal distribution
ax2 = axes[1]
ax2.hist(real_data[:, 0].numpy(), bins=30, alpha=0.7, density=True, color='blue', label='Real X₁', histtype='step', linewidth=2)
ax2.hist(generated_samples[:, 0], bins=30, alpha=0.7, density=True, color='red', label='Generated X₁', histtype='step', linewidth=2)
ax2.set_xlabel('X₁ Value')
ax2.set_ylabel('Density')
ax2.set_title('X₁ Marginal Distribution')
ax2.legend()
ax2.grid(True, alpha=0.3)

# X2 marginal distribution
ax3 = axes[2]
ax3.hist(real_data[:, 1].numpy(), bins=30, alpha=0.7, density=True, color='blue', label='Real X₂', histtype='step', linewidth=2)
ax3.hist(generated_samples[:, 1], bins=30, alpha=0.7, density=True, color='red', label='Generated X₂', histtype='step', linewidth=2)
ax3.set_xlabel('X₂ Value')
ax3.set_ylabel('Density')
ax3.set_title('X₂ Marginal Distribution')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.suptitle('Final Quality Assessment', fontsize=14, fontweight='bold', y=1.02)
plt.show()


## 8. Circuit Visualization


In [None]:
# Import the circuit visualizer
from utils.quantum_circuit_visualizer import visualize_circuit

# Visualize the generator circuit
print("Generator Circuit:")
visualize_circuit(qgan.generator.circuit, style='compact')

# Visualize the discriminator circuit
print("\nDiscriminator Circuit:")
visualize_circuit(qgan.discriminator.circuit, style='compact')


## 9. Conclusion

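This notebook sets up extended training of a quantum GAN using the modular architecture. In the recorded run, training stopped with an AutoGraph `ValueError` inside `PureQuantumGenerator.generate` during the first discriminator step, so the analysis sections have no results yet. The intended advantages of this approach include:

1. **Clean Separation of Concerns**: Each component has a well-defined responsibility
2. **Pure Quantum Learning**: No classical neural networks in the quantum components; the classical encoder and decoder matrices are static
3. **Gradient Flow**: Designed for gradient flow through all quantum parameters (not exercised here, since the run failed before the first update)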
4. **Modular Design**: Easy to swap components and experiment with different architectures
5. **Visualization**: Built-in circuit visualization for understanding the quantum architecture

The modular architecture provides a clean foundation for further quantum GAN research, with clear separation between the quantum circuit implementation, parameter management, and training logic.
