# File Location: notebooks/06_advanced_mechanics/13_manual_optimization_gan.ipynb

# Manual Optimization with GAN Implementation

This notebook demonstrates advanced manual optimization techniques in PyTorch Lightning using a Generative Adversarial Network (GAN) as the primary example. We'll explore multi-optimizer setups, manual backward passes, and custom training loops.

## Learning Objectives
- Implement manual optimization in PyTorch Lightning
- Build and train a GAN with separate optimizers
- Use manual_backward() for custom gradient computation
- Handle complex training dynamics with multiple loss functions

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import os
from collections import OrderedDict

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"Lightning version: {pl.__version__}")
```

## 1. Understanding Manual Optimization

```python
class ManualOptimizationConcepts:
    """
    Manual optimization concepts in PyTorch Lightning:
    1. automatic_optimization = False
    2. Multiple optimizers for different model components
    3. manual_backward() for custom gradient computation
    4. Optimizer stepping control
    5. Learning rate scheduling management
    """
    
    @staticmethod
    def explain_benefits():
        benefits = {
            "Fine-grained Control": "Control when and how optimizers step",
            "Multiple Models": "Different optimization strategies per model",
            "Complex Losses": "Custom backward passes for complex loss functions",
            "Gradient Manipulation": "Custom gradient clipping, accumulation, etc.",
            "Research Flexibility": "Implement novel training algorithms"
        }
        
        for benefit, explanation in benefits.items():
            print(f"{benefit}: {explanation}")

ManualOptimizationConcepts.explain_benefits()
```
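The sketch below makes these concepts concrete by running one manual update in plain PyTorch. It assumes that `toggle_optimizer()` behaves roughly like freezing (`requires_grad=False`) every parameter the active optimizer does not own, and that `manual_backward()` takes the place of `loss.backward()`. The tiny `gen`/`disc` models, the toy loss, and the `set_requires_grad` helper are hypothetical and not part of Lightning:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 4)   # hypothetical stand-in generator
disc = nn.Linear(4, 1)  # hypothetical stand-in discriminator
opt_g = torch.optim.SGD(gen.parameters(), lr=0.1)
opt_d = torch.optim.SGD(disc.parameters(), lr=0.1)  # the discriminator update would mirror opt_g's

def set_requires_grad(module, flag):
    """Rough analogue of toggle/untoggle: freeze params the active optimizer does not own."""
    for p in module.parameters():
        p.requires_grad_(flag)

z = torch.randn(8, 4)
gen_before = gen.weight.detach().clone()

# Generator update with the discriminator frozen ("toggled off")
set_requires_grad(disc, False)
g_loss = -disc(gen(z)).mean()  # toy generator objective, for illustration only
opt_g.zero_grad()
g_loss.backward()              # in a LightningModule: self.manual_backward(g_loss)
opt_g.step()
set_requires_grad(disc, True)  # "untoggle": restore the discriminator

gen_updated = not torch.equal(gen_before, gen.weight.detach())
disc_has_grad = any(p.grad is not None for p in disc.parameters())
```

Inside a LightningModule you would call `self.manual_backward(g_loss)` instead of `g_loss.backward()`, which lets Lightning apply precision scaling and strategy-specific handling.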

## 2. GAN Architecture Components

```python
class Generator(nn.Module):
    """Simple Generator for MNIST GAN"""
    
    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape
        self.img_size = int(np.prod(img_shape))
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Default eps/momentum; a bare positional 0.8 would set eps, not momentum
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, self.img_size),
            nn.Tanh()
        )
    
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

class Discriminator(nn.Module):
    """Simple Discriminator for MNIST GAN"""
    
    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_size = int(np.prod(img_shape))
        
        self.model = nn.Sequential(
            nn.Linear(self.img_size, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# Test the architectures
latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator()

print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
```
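A common pitfall with `block()`-style helpers is that the second positional argument of `nn.BatchNorm1d` is `eps`, not `momentum`. Writing `nn.BatchNorm1d(out_feat, 0.8)` therefore sets a numerical-stability epsilon of 0.8 and leaves the momentum at its default. The quick check below uses a feature size of 256, which is an arbitrary choice:

```python
import torch.nn as nn

bn_positional = nn.BatchNorm1d(256, 0.8)        # 0.8 lands in eps
bn_keyword = nn.BatchNorm1d(256, momentum=0.8)  # 0.8 lands in momentum

print("positional:", bn_positional.eps, bn_positional.momentum)
print("keyword:   ", bn_keyword.eps, bn_keyword.momentum)
```

If a momentum was intended, note that the two frameworks use opposite conventions. In PyTorch, `momentum` is the weight given to the *new* batch statistic. In Keras, it is the weight kept on the running average.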

## 3. Manual Optimization GAN Implementation

```python
class ManualGAN(pl.LightningModule):
    """GAN with manual optimization for fine-grained control"""
    
    def __init__(self, latent_dim=100, lr=0.0002, b1=0.5, b2=0.999, img_shape=(1, 28, 28)):
        super().__init__()
        self.save_hyperparameters()
        
        # Networks
        self.generator = Generator(latent_dim, img_shape)
        self.discriminator = Discriminator(img_shape)
        
        # Loss function
        self.adversarial_loss = nn.BCELoss()
        
        # Important: Disable automatic optimization
        self.automatic_optimization = False
        
        # For logging
        self.generated_imgs = None
        
    def forward(self, z):
        return self.generator(z)
    
    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        
        # Get optimizers (manually managed)
        opt_g, opt_d = self.optimizers()
        
        # Sample noise for generator
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim, device=self.device)
        
        # ---------------------
        #  Train Generator
        # ---------------------
        self.toggle_optimizer(opt_g)
        
        # Generate fake images
        fake_imgs = self(z)
        
        # Target labels: all ones ("real"), because the generator wants
        # the discriminator to classify its fake images as real
        valid = torch.ones(imgs.size(0), 1, device=self.device)
        
        # Generator loss: fool the discriminator
        g_loss = self.adversarial_loss(self.discriminator(fake_imgs), valid)
        
        # Manual backward pass for generator
        self.manual_backward(g_loss)
        opt_g.step()
        opt_g.zero_grad()
        self.untoggle_optimizer(opt_g)
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        self.toggle_optimizer(opt_d)
        
        # Real images
        valid = torch.ones(imgs.size(0), 1, device=self.device)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
        
        # Fake images (detach to avoid training generator)
        fake = torch.zeros(imgs.size(0), 1, device=self.device)
        fake_imgs = self(z).detach()  # Detach to stop gradients to generator
        fake_loss = self.adversarial_loss(self.discriminator(fake_imgs), fake)
        
        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2
        
        # Manual backward pass for discriminator
        self.manual_backward(d_loss)
        opt_d.step()
        opt_d.zero_grad()
        self.untoggle_optimizer(opt_d)
        
        # Logging
        self.log('g_loss', g_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('d_loss', d_loss, on_step=True, on_epoch=True, prog_bar=True)
        
        # Store generated images for visualization
        if batch_idx == 0:
            self.generated_imgs = fake_imgs[:16]
    
    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        
        return [opt_g, opt_d], []
    
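    def on_train_epoch_end(self):
        # Lightning 2.x hook name (on_epoch_end was removed); add_image needs a TensorBoardLogger
        has_tb = self.logger is not None and hasattr(self.logger.experiment, 'add_image')
        if self.generated_imgs is not None and has_tb:
            grid = make_grid(self.generated_imgs, nrow=4, normalize=True)
            self.logger.experiment.add_image('generated_images', grid, self.current_epoch)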

# Initialize the model
model = ManualGAN(latent_dim=100, lr=0.0002)
```
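In the discriminator phase, the `.detach()` call is what keeps `d_loss` from reaching the generator. The minimal check below compares the generator's gradients with and without it. It uses hypothetical stand-in linear models rather than the notebook's `Generator`/`Discriminator`, but the loss has the same `(real_loss + fake_loss) / 2` form:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 4)                                # stand-in generator
disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())  # stand-in discriminator
loss_fn = nn.BCELoss()

z = torch.randn(8, 4)
real = torch.randn(8, 4)
valid, fake = torch.ones(8, 1), torch.zeros(8, 1)

def d_loss_with(fake_imgs):
    """Discriminator loss in the same form as ManualGAN's d_loss."""
    real_loss = loss_fn(disc(real), valid)
    fake_loss = loss_fn(disc(fake_imgs), fake)
    return (real_loss + fake_loss) / 2

# Detached fakes: backward stops at the fake images, so gen receives no gradient
d_loss_with(gen(z).detach()).backward()
gen_grads_detached = [p.grad for p in gen.parameters()]

# Non-detached fakes: backward also flows into the generator's parameters
disc.zero_grad()
d_loss_with(gen(z)).backward()
gen_grads_attached = [p.grad for p in gen.parameters()]
```

Without the detach, the generator would accumulate gradients from the discriminator's objective. That is the opposite of what it should learn.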

## 4. Advanced Manual Optimization Techniques

```python
class AdvancedManualGAN(ManualGAN):
    """Enhanced GAN with advanced manual optimization features"""
    
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        # Gradient penalties and clipping
        self.gradient_penalty_lambda = 10.0
        self.gradient_clip_val = 1.0
        
        # Learning rate scheduling
        self.lr_decay_step = 50
        self.lr_decay_factor = 0.5
        
        # Training balance
        self.d_steps_per_g_step = 1
        self.step_counter = 0
        
    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        opt_g, opt_d = self.optimizers()
        sch_g, sch_d = self.lr_schedulers()
        
        # Sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim, device=self.device)
        
        # ---------------------
        #  Train Discriminator (potentially multiple times)
        # ---------------------
        d_loss_total = 0
        for _ in range(self.d_steps_per_g_step):
            self.toggle_optimizer(opt_d)
            
            # Real and fake losses
            valid = torch.ones(imgs.size(0), 1, device=self.device)
            fake = torch.zeros(imgs.size(0), 1, device=self.device)
            
            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
            fake_imgs = self(z).detach()
            fake_loss = self.adversarial_loss(self.discriminator(fake_imgs), fake)
            
            # Gradient penalty borrowed from WGAN-GP; here it regularizes a standard
            # BCE/sigmoid discriminator rather than a true Wasserstein critic
            gp = self.compute_gradient_penalty(imgs, fake_imgs)
            
            d_loss = (real_loss + fake_loss) / 2 + self.gradient_penalty_lambda * gp
            d_loss_total += d_loss.item()
            
            # Manual backward with gradient clipping
            self.manual_backward(d_loss)
            
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(
                self.discriminator.parameters(), 
                self.gradient_clip_val
            )
            
            opt_d.step()
            opt_d.zero_grad()
            self.untoggle_optimizer(opt_d)
        
        # ---------------------
        #  Train Generator
        # ---------------------
        self.toggle_optimizer(opt_g)
        
        fake_imgs = self(z)
        valid = torch.ones(imgs.size(0), 1, device=self.device)
        g_loss = self.adversarial_loss(self.discriminator(fake_imgs), valid)
        
        # Manual backward with gradient clipping
        self.manual_backward(g_loss)
        
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(
            self.generator.parameters(), 
            self.gradient_clip_val
        )
        
        opt_g.step()
        opt_g.zero_grad()
        self.untoggle_optimizer(opt_g)
        
        # Lightning does not step schedulers in manual optimization. Step them once
        # per epoch (on its last batch), so StepLR's step_size (self.lr_decay_step) counts epochs
        if self.trainer.is_last_batch:
            sch_g.step()
            sch_d.step()
        
        self.step_counter += 1
        
        # Logging
        self.log('g_loss', g_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('d_loss', d_loss_total / self.d_steps_per_g_step, on_step=True, on_epoch=True, prog_bar=True)
        self.log('gradient_penalty', gp, on_step=True, on_epoch=True)
        
        if batch_idx == 0:
            # Detach so the stored samples do not keep the generator's graph alive
            self.generated_imgs = fake_imgs[:16].detach()
    
    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Compute gradient penalty for improved training stability"""
        alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=self.device)
        
        # Interpolate between real and fake samples
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
        
        # Get discriminator output for interpolated samples
        d_interpolates = self.discriminator(interpolates)
        
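        # grad_outputs of ones: differentiate the sum of the discriminator outputs
        grad_outputs = torch.ones_like(d_interpolates)
        
        # Gradients w.r.t. the interpolated inputs; create_graph=True keeps the
        # penalty differentiable so it can be backpropagated into the discriminator
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
        )[0]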
        
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        
        return gradient_penalty
    
    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        
        # Learning rate schedulers
        sch_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=self.lr_decay_step, gamma=self.lr_decay_factor)
        sch_d = torch.optim.lr_scheduler.StepLR(opt_d, step_size=self.lr_decay_step, gamma=self.lr_decay_factor)
        
        return [opt_g, opt_d], [sch_g, sch_d]

# Initialize advanced model
advanced_model = AdvancedManualGAN(latent_dim=100, lr=0.0002)
```
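You can sanity-check the penalty on a critic whose input gradient is known in closed form. For a linear critic D(x) = w · x, the gradient with respect to x is w at every interpolate, so the penalty must equal (‖w‖ − 1)² regardless of `alpha`. The standalone `gradient_penalty` below is a hypothetical re-implementation of `compute_gradient_penalty` for 2-D inputs, so `alpha` has shape `(N, 1)` instead of `(N, 1, 1, 1)`:

```python
import torch

def gradient_penalty(critic, real, fake):
    """Hypothetical 2-D-input version of compute_gradient_penalty."""
    alpha = torch.rand(real.size(0), 1, device=real.device)  # one mixing weight per sample
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interpolates = critic(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # keep the penalty differentiable w.r.t. the critic's weights
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

torch.manual_seed(0)
critic = torch.nn.Linear(3, 1, bias=False)
with torch.no_grad():
    critic.weight.copy_(torch.tensor([[3.0, 4.0, 0.0]]))  # ||w|| = 5
real, fake = torch.randn(6, 3), torch.randn(6, 3)

gp = gradient_penalty(critic, real, fake)  # (5 - 1)^2 = 16 for every sample
gp.backward()                              # d gp / d w = 2 * (||w|| - 1) * w / ||w|| = 1.6 * w
```

The second check confirms why `create_graph=True` matters: without it, the penalty could not push gradients back into the critic's weights.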

## 5. Custom Training Loop with Manual Optimization

```python
class CustomTrainingGAN(pl.LightningModule):
    """GAN with completely custom training loop"""
    
    def __init__(self, latent_dim=100, lr=0.0002, img_shape=(1, 28, 28)):
        super().__init__()
        self.save_hyperparameters()
        
        self.generator = Generator(latent_dim, img_shape)
        self.discriminator = Discriminator(img_shape)
        self.adversarial_loss = nn.BCELoss()
        
        # Disable automatic optimization
        self.automatic_optimization = False
        
        # Custom training state
        self.training_state = {
            'g_losses': [],
            'd_losses': [],
            'epoch_g_loss': 0.0,
            'epoch_d_loss': 0.0,
            'batch_count': 0
        }
    
    def training_step(self, batch, batch_idx):
        return self.custom_training_step(batch, batch_idx)
    
    def custom_training_step(self, batch, batch_idx):
        """Completely custom training step with full control"""
        imgs, _ = batch
        batch_size = imgs.size(0)
        
        # Get optimizers
        opt_g, opt_d = self.optimizers()
        
        # ========================
        # Custom Training Logic
        # ========================
        
        # Phase 1: warm up the discriminator. trainer.global_step counts optimizer
        # steps, and warm-up takes one (opt_d) step per batch: the first 100 batches
        if self.global_step < 100:
            d_loss = self.train_discriminator_only(imgs, opt_d)
            self.log('d_loss', d_loss, on_step=True, prog_bar=True)
            return
        
        # Phase 2: Alternating training with custom frequency
        if batch_idx % 3 == 0:  # Train generator every 3rd batch
            g_loss = self.train_generator_step(batch_size, opt_g)
            self.training_state['g_losses'].append(g_loss.item())
            self.training_state['epoch_g_loss'] += g_loss.item()
            self.log('g_loss', g_loss, on_step=True, prog_bar=True)
        
        # Always train discriminator
        d_loss = self.train_discriminator_step(imgs, batch_size, opt_d)
        self.training_state['d_losses'].append(d_loss.item())
        self.training_state['epoch_d_loss'] += d_loss.item()
        self.log('d_loss', d_loss, on_step=True, prog_bar=True)
        
        self.training_state['batch_count'] += 1
        
        # Log custom metrics
        if len(self.training_state['g_losses']) > 0:
            avg_g_loss = np.mean(self.training_state['g_losses'][-10:])  # Last 10 batches
            self.log('avg_g_loss_10', avg_g_loss, on_step=True)
        
        avg_d_loss = np.mean(self.training_state['d_losses'][-10:])
        self.log('avg_d_loss_10', avg_d_loss, on_step=True)
    
    def train_discriminator_only(self, real_imgs, opt_d):
        """Warm-up phase: train only discriminator"""
        self.toggle_optimizer(opt_d)
        
        batch_size = real_imgs.size(0)
        z = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)
        
        # Real loss
        valid = torch.ones(batch_size, 1, device=self.device)
        real_loss = self.adversarial_loss(self.discriminator(real_imgs), valid)
        
        # Fake loss
        fake = torch.zeros(batch_size, 1, device=self.device)
        fake_imgs = self.generator(z).detach()
        fake_loss = self.adversarial_loss(self.discriminator(fake_imgs), fake)
        
        d_loss = (real_loss + fake_loss) / 2
        
        self.manual_backward(d_loss)
        opt_d.step()
        opt_d.zero_grad()
        self.untoggle_optimizer(opt_d)
        
        return d_loss
    
    def train_generator_step(self, batch_size, opt_g):
        """Custom generator training step"""
        self.toggle_optimizer(opt_g)
        
        z = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)
        fake_imgs = self.generator(z)
        valid = torch.ones(batch_size, 1, device=self.device)
        
        g_loss = self.adversarial_loss(self.discriminator(fake_imgs), valid)
        
        self.manual_backward(g_loss)
        opt_g.step()
        opt_g.zero_grad()
        self.untoggle_optimizer(opt_g)
        
        return g_loss
    
    def train_discriminator_step(self, real_imgs, batch_size, opt_d):
        """Custom discriminator training step"""
        self.toggle_optimizer(opt_d)
        
        z = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)
        
        # Real loss
        valid = torch.ones(batch_size, 1, device=self.device)
        real_loss = self.adversarial_loss(self.discriminator(real_imgs), valid)
        
        # Fake loss
        fake = torch.zeros(batch_size, 1, device=self.device)
        fake_imgs = self.generator(z).detach()
        fake_loss = self.adversarial_loss(self.discriminator(fake_imgs), fake)
        
        d_loss = (real_loss + fake_loss) / 2
        
        self.manual_backward(d_loss)
        opt_d.step()
        opt_d.zero_grad()
        self.untoggle_optimizer(opt_d)
        
        return d_loss
    
    def on_train_epoch_end(self):
        """Custom epoch-end logic (Lightning 2.x hook; on_epoch_end was removed)"""
        state = self.training_state
        # g_losses is never cleared, so count only the generator steps taken this epoch
        g_start = state.get('g_epoch_start', 0)
        n_g_steps = len(state['g_losses']) - g_start
        state['g_epoch_start'] = len(state['g_losses'])
        
        # Calculate epoch averages
        if state['batch_count'] > 0:
            avg_g_loss = state['epoch_g_loss'] / max(1, n_g_steps)
            avg_d_loss = state['epoch_d_loss'] / state['batch_count']
            
            self.log('epoch_avg_g_loss', avg_g_loss)
            self.log('epoch_avg_d_loss', avg_d_loss)
            
            # Reset epoch counters
            state['epoch_g_loss'] = 0.0
            state['epoch_d_loss'] = 0.0
            state['batch_count'] = 0
        
        # Generate sample images (add_image assumes a TensorBoardLogger)
        if self.logger is not None and hasattr(self.logger.experiment, 'add_image'):
            z = torch.randn(16, self.hparams.latent_dim, device=self.device)
            self.generator.eval()  # use BatchNorm running stats without updating them
            with torch.no_grad():
                fake_imgs = self.generator(z)
            self.generator.train()
            grid = make_grid(fake_imgs, nrow=4, normalize=True)
            self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
    
    def configure_optimizers(self):
        lr = self.hparams.lr
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
        return [opt_g, opt_d], []

# Initialize custom training model
custom_model = CustomTrainingGAN(latent_dim=100, lr=0.0002)
```
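The schedule in `custom_training_step` is easiest to reason about by simulating it. The sketch below makes two assumptions. First, as Lightning documents for `trainer.global_step`, the counter advances once per `optimizer.step()` call under manual optimization. Second, all simulated batches fall within one epoch, so `batch_idx` never resets. `plan_updates` and `simulate` are hypothetical helpers that mirror the branching; they are not part of the class:

```python
def plan_updates(global_step, batch_idx, warmup_steps=100, g_every=3):
    """Updates custom_training_step would run for one batch (hypothetical helper)."""
    if global_step < warmup_steps:
        return ["d_warmup"]
    updates = []
    if batch_idx % g_every == 0:
        updates.append("g")
    updates.append("d")  # the discriminator trains on every post-warm-up batch
    return updates

def simulate(num_batches):
    # Under manual optimization, global_step advances once per optimizer.step() call
    global_step = 0
    counts = {"d_warmup": 0, "g": 0, "d": 0}
    for batch_idx in range(num_batches):
        for update in plan_updates(global_step, batch_idx):
            counts[update] += 1
            global_step += 1
    return counts, global_step

counts, final_step = simulate(200)
print(counts, final_step)
```

After warm-up, `global_step` grows faster than the batch index, at one or two optimizer steps per batch. Any threshold expressed in `global_step` therefore does not map one-to-one onto batches.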

## 6. Data Module for GAN Training

```python
class GANDataModule(pl.LightningDataModule):
    """Data module for GAN training"""
    
    def __init__(self, batch_size=64, num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1] for tanh output
        ])
    
    def prepare_data(self):
        # Download MNIST
        torchvision.datasets.MNIST('./data', train=True, download=True)
        torchvision.datasets.MNIST('./data', train=False, download=True)
    
    def setup(self, stage=None):
        self.mnist_train = torchvision.datasets.MNIST('./data', train=True, transform=self.transform)
        self.mnist_test = torchvision.datasets.MNIST('./data', train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
    
    def val_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

# Initialize data module
data_module = GANDataModule(batch_size=64)
```
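`Normalize` computes (x − mean) / std per channel. With mean = std = 0.5, the [0, 1] output of `ToTensor()` therefore maps onto [−1, 1], which is the range of the generator's final `Tanh`. The small worked example below uses plain torch arithmetic instead of calling torchvision, to show the mapping and its inverse:

```python
import torch

mean, std = 0.5, 0.5
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # ToTensor() scales pixels into [0, 1]
normalized = (pixels - mean) / std            # per-channel formula Normalize((0.5,), (0.5,)) applies
restored = normalized * std + mean            # inverse, back to [0, 1] for display

print(normalized.tolist(), restored.tolist())
```

The evaluation code relies on the same inverse when it calls `make_grid(..., normalize=True, value_range=(-1, 1))`.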

## 7. Training the Manual Optimization GAN

```python
# Configure trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    log_every_n_steps=50,
    enable_checkpointing=True,
    callbacks=[
        pl.callbacks.ModelCheckpoint(
            dirpath='./checkpoints/gan',
            filename='manual-gan-{epoch:02d}',
            save_top_k=3,
            monitor='g_loss',
            mode='min'
        )
    ]
)

# Train the model (choose one)
print("Training Manual GAN with custom optimization...")

# Option 1: Basic manual optimization
# trainer.fit(model, data_module)

# Option 2: Advanced manual optimization
# trainer.fit(advanced_model, data_module)

# Option 3: Custom training loop
trainer.fit(custom_model, data_module)

print("Training completed!")
```

## 8. Evaluation and Visualization

```python
def evaluate_gan_samples(model, num_samples=64):
    """Generate and visualize GAN samples"""
    model.eval()
    
    with torch.no_grad():
        z = torch.randn(num_samples, model.hparams.latent_dim, device=model.device)
        fake_images = model.generator(z)
        
        # Create grid
        grid = make_grid(fake_images, nrow=8, normalize=True, value_range=(-1, 1))
        
        # Convert to numpy for matplotlib
        grid_np = grid.cpu().numpy().transpose(1, 2, 0)
        
        plt.figure(figsize=(12, 12))
        plt.imshow(grid_np[:, :, 0], cmap='gray')
        plt.title(f'Generated MNIST Samples (Epoch {model.current_epoch})')
        plt.axis('off')
        plt.show()
    
    return fake_images

# Generate samples
if torch.cuda.is_available():
    custom_model = custom_model.cuda()

generated_samples = evaluate_gan_samples(custom_model, num_samples=64)
```

## 9. Manual Optimization Best Practices

```python
class ManualOptimizationBestPractices:
    """Best practices for manual optimization in PyTorch Lightning"""
    
    @staticmethod
    def demonstrate_best_practices():
        practices = {
            "Optimizer Management": [
                "Use toggle_optimizer()/untoggle_optimizer() so only the active model's parameters get gradients",
                "Clear gradients once per optimizer step (zero_grad() before backward or after step())",
                "Use self.optimizers() to get optimizer list"
            ],
            "Gradient Management": [
                "Use self.manual_backward() instead of loss.backward()",
                "Implement gradient clipping for stability",
                "Be careful with gradient accumulation"
            ],
            "Learning Rate Scheduling": [
                "Handle scheduler stepping manually",
                "Use self.lr_schedulers() to get scheduler list",
                "Log learning rates for monitoring"
            ],
            "Multi-Model Training": [
                "Use separate optimizers for different models",
                "Balance training frequency between models",
                "Monitor loss ratios for training stability"
            ],
            "Logging and Monitoring": [
                "Use sync_dist=True for distributed training",
                "Log all relevant metrics",
                "Implement custom validation logic if needed"
            ]
        }
        
        for category, tips in practices.items():
            print(f"\n{category}:")
            for i, tip in enumerate(tips, 1):
                print(f"  {i}. {tip}")

ManualOptimizationBestPractices.demonstrate_best_practices()
```
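Lightning does not step schedulers under manual optimization, so it matters how often `scheduler.step()` is called. `StepLR`'s `step_size` counts calls to `step()`, not batches or epochs. The small check below uses a throwaway parameter; `lr_after` is a hypothetical helper, and the defaults mirror the notebook's `lr=0.0002`, `lr_decay_step=50`, and `lr_decay_factor=0.5`:

```python
import torch

def lr_after(num_scheduler_steps, step_size=50, gamma=0.5, base_lr=2e-4):
    """Learning rate after calling scheduler.step() num_scheduler_steps times (hypothetical helper)."""
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=base_lr)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)
    for _ in range(num_scheduler_steps):
        opt.step()    # optimizer first, avoiding PyTorch's call-order warning
        sched.step()
    return opt.param_groups[0]["lr"]

for n in (49, 50, 100):
    print(n, lr_after(n))
```

Calling `step()` only every 50 batches on a `StepLR(step_size=50)` would therefore decay the learning rate once every 2,500 batches.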

## 10. Common Issues and Debugging

```python
class ManualOptimizationDebugging:
    """Common issues and debugging techniques for manual optimization"""
    
    @staticmethod
    def common_issues():
        issues = {
            "Gradient Accumulation": "Stale gradients are summed into the next update because they were never cleared",
            "Optimizer State": "Without toggling, one model's backward pass also fills the other model's .grad buffers",
            "Learning Rate": "Schedulers never change the learning rate, because Lightning does not step them in manual mode",
            "Loss Scaling": "Inconsistent loss computation across models",
            "Memory Leaks": "Keeping loss tensors (and their autograd graphs) alive instead of detached values"
        }
        
        solutions = {
            "Gradient Accumulation": "Call optimizer.zero_grad() once per optimizer step",
            "Optimizer State": "Use toggle_optimizer() and untoggle_optimizer()",
            "Learning Rate": "Call scheduler.step() yourself at the intended interval",
            "Loss Scaling": "Normalize losses properly for fair comparison",
            "Memory Leaks": "Store loss.item(), or .detach() tensors that are kept across steps"
        }
        
        print("Common Issues and Solutions:")
        for issue, description in issues.items():
            print(f"\nIssue: {issue}")
            print(f"  Description: {description}")
            print(f"  Solution: {solutions[issue]}")
    
    @staticmethod
    def debugging_checklist():
        checklist = [
            "Set automatic_optimization = False",
            "Use manual_backward() instead of loss.backward()",
            "Call toggle_optimizer() before and untoggle_optimizer() after each update",
            "Clear gradients once per optimizer step",
            "Step schedulers manually",
            "Log all relevant metrics",
            "Detach tensors that should not propagate gradients",
            "Check that each model's parameters actually receive gradients (inspect p.grad)"
        ]
        
        print("\nDebugging Checklist:")
        for i, item in enumerate(checklist, 1):
            print(f"{i}. {item}")

debugger = ManualOptimizationDebugging()
debugger.common_issues()
debugger.debugging_checklist()
```
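The first issue in the list is easy to reproduce in plain PyTorch. Calling `backward()` twice without clearing gradients sums them, and an optimizer step would then consume that sum. The toy parameter `w` and input `x` below are illustrative assumptions:

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
opt = torch.optim.SGD([w], lr=0.1)
x = torch.tensor([3.0, 4.0])

def loss_fn():
    return (w * x).sum()  # d(loss)/dw = x = [3, 4]

loss_fn().backward()
grad_once = w.grad.clone()

loss_fn().backward()      # no zero_grad() in between: the new gradient is ADDED
grad_twice = w.grad.clone()

opt.zero_grad()           # clears .grad (set to None by default in recent PyTorch)
grad_after_clear = w.grad

print(grad_once.tolist(), grad_twice.tolist(), grad_after_clear)
```

Inspecting `p.grad` like this after each phase of a training step is a quick way to validate gradient flow. It shows whether a model received gradients it should not have, or none at all.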

# Summary

This notebook demonstrated manual optimization techniques in PyTorch Lightning, using a GAN implementation as a worked example. Key concepts covered:

## Manual Optimization Fundamentals
- **Disabled Automatic Optimization**: Setting `automatic_optimization = False`
- **Multi-Optimizer Management**: Handling separate optimizers for generator and discriminator
- **Manual Backward Passes**: Using `manual_backward()` for custom gradient computation
- **Optimizer Control**: Fine-grained control over when and how optimizers step

## Advanced Techniques Implemented
- **Gradient Penalties**: Custom gradient penalty computation for training stability
- **Gradient Clipping**: Manual gradient norm clipping
- **Custom Training Schedules**: Alternating training frequencies between models
- **Learning Rate Management**: Manual scheduler stepping and monitoring

## GAN-Specific Optimizations
- **Balanced Training**: Custom logic for generator vs discriminator training frequency
- **Gradient Flow Control**: Proper use of detach() to control gradient flow
- **Loss Monitoring**: Comprehensive logging of training dynamics
- **Warm-up Strategies**: Initial discriminator-only training phase

## Best Practices Established
- Toggle optimizers so each update touches only the intended model
- Implement proper gradient management
- Step learning rate schedulers manually
- Monitor training balance and stability
- Detach tensors to control gradient flow and avoid keeping unneeded autograd graphs

## Next Steps
- Implement more sophisticated GAN variants (WGAN, Progressive GAN)
- Experiment with different manual optimization strategies
- Apply manual optimization to other multi-model architectures
- Explore advanced gradient manipulation techniques

Manual optimization provides the flexibility needed for complex training scenarios while maintaining the benefits of PyTorch Lightning's infrastructure.