# Conditional GAN (CGAN) Model Testing Notebook

This notebook tests the Conditional GAN (CGAN) model for urban crop yield prediction. CGANs are generative models that can create synthetic data conditioned on specific inputs, making them suitable for generating crop yield predictions under different conditions.

## 1. Setup and Imports

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import sys

# Add src to path for imports
sys.path.append('.')

# Import custom modules
from src.models.cgan import Generator, Discriminator
from src.data.dataset import UrbanCropDataset
from src.training import CGANTrainer
from src.utils import set_seed, get_device, ensure_dir

# Set random seed for reproducibility
# (set_seed comes from src.utils; presumably it seeds the Python/NumPy/torch
# RNGs -- TODO confirm which generators it covers)
set_seed(42)

# Check for CUDA
# `device` is a module-level global: the generation and evaluation functions
# further down read it directly instead of taking it as a parameter.
# NOTE(review): get_device presumably returns a CUDA device when available and
# CPU otherwise -- confirm in src.utils.
device = get_device()
print(f"Using device: {device}")

## 2. Configuration

In [None]:
# Notebook-wide settings, grouped by purpose
config = dict(
    # --- filesystem locations ---
    data_dir='data/processed',        # where the processed .npy arrays live
    model_dir='models',               # where trained models are written
    visualize_dir='visualizations',   # where figures are written
    # --- optimisation ---
    batch_size=32,                    # mini-batch size
    epochs=10,                        # kept small because this is a test run
    learning_rate=0.0002,             # GANs are usually trained with a low rate
    # --- dataset splitting ---
    test_split=0.2,                   # share of all samples held out for testing
    val_split=0.1,                    # share of the remaining samples used for validation
    # --- model shape ---
    z_size=100,                       # length of the latent noise vector
    class_num=10,                     # number of condition classes
    img_size=32,                      # side length of the generated images
    generator_layer_size=[256, 512, 1024],      # generator layer widths
    discriminator_layer_size=[1024, 512, 256],  # discriminator layer widths
)

# Make sure the output locations exist before anything tries to write to them
cgan_samples_dir = os.path.join(config['visualize_dir'], 'cgan_samples')
ensure_dir(config['model_dir'])
ensure_dir(cgan_samples_dir)

## 3. Load Data

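For this notebook, we assume that the data has already been processed and saved as NumPy arrays. If the files are not found, the cell below falls back to randomly generated dummy data so that the rest of the notebook can still run; to work with real data, run the data processing scripts first.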

In [None]:
# Load processed data
def load_data(config):
    """Read the preprocessed input and target arrays from disk.

    Args:
        config: Mapping with a ``'data_dir'`` entry naming the directory that
            holds the two ``.npy`` files.

    Returns:
        Tuple ``(X_data, y_data)`` of the arrays loaded with ``np.load``.

    Raises:
        FileNotFoundError: If either of the two expected files is missing.
    """
    data_root = config['data_dir']
    x_path = os.path.join(data_root, 'years_array_32_segmented_prevUrb.npy')
    y_path = os.path.join(data_root, 'crops_array_32_segmented_prevUrb.npy')

    # Guard clause: bail out early unless both files are present
    if not (os.path.exists(x_path) and os.path.exists(y_path)):
        raise FileNotFoundError(f"Preprocessed data not found at {x_path} and {y_path}. Please run data preprocessing first.")

    print("Loading preprocessed data...")
    X_data, y_data = np.load(x_path), np.load(y_path)
    print(f"X_data shape: {X_data.shape}")
    print(f"y_data shape: {y_data.shape}")
    return X_data, y_data

# Load the real arrays if they exist; otherwise fall back to random dummy data.
# Only the load itself sits inside `try`, so an error raised while plotting is
# never mistaken for "data not found".
try:
    X_data, y_data = load_data(config)
except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Using dummy data for demonstration purposes...")
    # Create dummy data for demonstration
    X_data = np.random.rand(100, 15, 32, 32)  # [samples, channels, height, width]
    y_data = np.random.rand(100, 3, 32, 32)   # [samples, channels, height, width]
    print(f"Dummy X_data shape: {X_data.shape}")
    print(f"Dummy y_data shape: {y_data.shape}")
else:
    # Sample visualization of the real data (skipped for the dummy fallback)
    plt.figure(figsize=(15, 5))

    # Input features (first sample, first 3 channels), moved to channels-last
    # because imshow expects (height, width, channels)
    plt.subplot(1, 2, 1)
    plt.imshow(np.transpose(X_data[0][:3], (1, 2, 0)))
    plt.title('Input Features (First 3 Channels)')
    plt.axis('off')

    # Target output, also moved to channels-last
    plt.subplot(1, 2, 2)
    plt.imshow(np.transpose(y_data[0], (1, 2, 0)))
    plt.title('Target Output')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

## 4. Preprocess Data for CGAN

For the CGAN, we need to prepare the data differently. We use the first channel of the output as the (single-channel) target image, and derive an integer condition label for each sample by binning the mean of its input features into `class_num` equal-width bins.

In [None]:
def preprocess_for_cgan(X_data, y_data, config):
    """Build single-channel target images and integer condition labels.

    The target image of a sample is the first channel of ``y_data``. The
    condition label is obtained by averaging all input features of the sample
    and assigning that mean to one of ``config['class_num']`` equal-width bins
    spanning the observed range of means.

    Args:
        X_data: Input array indexed as [samples, channels, height, width].
        y_data: Target array indexed as [samples, channels, height, width];
            height and width must both equal ``config['img_size']``.
        config: Mapping providing ``'img_size'`` and ``'class_num'``.

    Returns:
        Tuple ``(target_images, condition_labels)`` where ``target_images`` has
        shape ``(samples, 1, img_size, img_size)`` and ``condition_labels`` is
        an integer array of shape ``(samples,)`` with values in
        ``[0, class_num - 1]``.

    Raises:
        ValueError: If the spatial size of ``y_data`` differs from
            ``img_size`` or the two arrays hold a different number of samples.
    """
    img_size = config['img_size']
    class_num = config['class_num']

    # Validate up front: reshape(-1, ...) would otherwise silently change the
    # number of samples when the spatial size is not img_size x img_size,
    # leaving images and labels misaligned.
    if tuple(y_data.shape[2:]) != (img_size, img_size):
        raise ValueError(
            f"y_data has shape {y_data.shape}; expected spatial size "
            f"({img_size}, {img_size}) to match config['img_size']"
        )
    if len(X_data) != len(y_data):
        raise ValueError(
            f"X_data and y_data hold different numbers of samples: "
            f"{len(X_data)} vs {len(y_data)}"
        )

    # For simplicity, use the first channel of y_data as the target image
    target_images = y_data[:, 0, :, :].reshape(-1, 1, img_size, img_size)

    # One scalar per sample: the mean over channels, height and width
    feature_means = np.mean(X_data, axis=(1, 2, 3))

    # Equal-width bins over the observed range. np.digitize is 1-based and
    # places the maximum value past the last edge, hence the -1 and the clip.
    bins = np.linspace(feature_means.min(), feature_means.max(), class_num + 1)
    condition_labels = np.digitize(feature_means, bins) - 1
    condition_labels = np.clip(condition_labels, 0, class_num - 1)

    print(f"Target images shape: {target_images.shape}")
    print(f"Condition labels shape: {condition_labels.shape}")
    # minlength keeps one count per class even when the last classes are empty
    print(f"Condition label distribution: {np.bincount(condition_labels, minlength=class_num)}")

    return target_images, condition_labels

# Preprocess data for CGAN: derive the target images and the per-sample
# condition labels from the arrays loaded (or generated) above
target_images, condition_labels = preprocess_for_cgan(X_data, y_data, config)

## 5. Create Data Loaders

In [None]:
def create_data_loaders(target_images, condition_labels, config):
    """Create train, validation, and test data loaders.

    The samples are first split into (train + validation) and test, then the
    first part is split again into train and validation. Both splits use a
    generator seeded with 42, so the partition is reproducible.

    Args:
        target_images: Array-like of images, one per sample.
        condition_labels: Array-like of integer class labels, one per sample.
        config: Mapping providing ``'test_split'``, ``'val_split'`` and
            ``'batch_size'``. An optional ``'num_workers'`` entry overrides
            the default of 2 loader worker processes.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Each loader yields
        ``(images, labels)`` batches; only the training loader shuffles.
    """
    images_tensor = torch.tensor(target_images, dtype=torch.float32)
    labels_tensor = torch.tensor(condition_labels, dtype=torch.long)

    # Use the built-in TensorDataset rather than a class defined inside this
    # function: a function-local class cannot be pickled, which breaks
    # DataLoader worker processes on platforms that start them with "spawn".
    dataset = torch.utils.data.TensorDataset(images_tensor, labels_tensor)

    # Determine the number of samples for each split
    total_samples = len(dataset)
    test_size = int(total_samples * config['test_split'])
    train_val_size = total_samples - test_size
    val_size = int(train_val_size * config['val_split'])
    train_size = train_val_size - val_size

    print(f"Total samples: {total_samples}")
    print(f"Training samples: {train_size}")
    print(f"Validation samples: {val_size}")
    print(f"Test samples: {test_size}")

    # Two-stage split: hold out the test set first, then carve the validation
    # set out of what remains
    train_val_dataset, test_dataset = random_split(
        dataset, [train_val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )
    train_dataset, val_dataset = random_split(
        train_val_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    batch_size = config['batch_size']
    num_workers = config.get('num_workers', 2)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, val_loader, test_loader

# Create data loaders
train_loader, val_loader, test_loader = create_data_loaders(target_images, condition_labels, config)

# Verify the data by inspecting the first training batch only
for images, labels in train_loader:
    print("Batch shapes:")
    print(f"Images shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")

    # Display up to five samples; a batch may hold fewer than five when the
    # training set or batch size is small, so never index past its end
    n_preview = min(5, images.size(0))
    plt.figure(figsize=(15, 4))
    for i in range(n_preview):
        plt.subplot(1, 5, i + 1)
        plt.imshow(images[i, 0].numpy(), cmap='gray')
        plt.title(f"Class {labels[i].item()}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()
    break

## 6. Initialize CGAN Model

In [None]:
def initialize_cgan(config):
    """Build a fresh generator, discriminator and compiled trainer.

    Args:
        config: Mapping providing ``'z_size'``, ``'class_num'``,
            ``'img_size'``, ``'generator_layer_size'``,
            ``'discriminator_layer_size'`` and ``'learning_rate'``.

    Returns:
        Tuple ``(generator, discriminator, trainer)``. The trainer is created
        on the module-level ``device`` and compiled with the Adam optimizer.
    """
    def _count_trainable(model):
        # Total number of elements across parameters that receive gradients
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    latent_dim = config['z_size']
    n_classes = config['class_num']
    side = config['img_size']

    generator = Generator(
        z_size=latent_dim,
        class_num=n_classes,
        generator_layer_size=config['generator_layer_size'],
        img_size=side,
    )
    discriminator = Discriminator(
        img_size=side,
        class_num=n_classes,
        discriminator_layer_size=config['discriminator_layer_size'],
    )

    print(f"Generator created with {_count_trainable(generator)} trainable parameters")
    print(f"Discriminator created with {_count_trainable(discriminator)} trainable parameters")

    trainer = CGANTrainer(
        generator,
        discriminator,
        z_size=latent_dim,
        class_num=n_classes,
        device=device,
    )
    trainer.compile(optimizer='adam', learning_rate=config['learning_rate'])

    return generator, discriminator, trainer

# Initialize CGAN model: untrained generator and discriminator plus a trainer
# compiled with the learning rate from `config`
generator, discriminator, trainer = initialize_cgan(config)

## 7. Train the CGAN Model

Train the CGAN model. Note that GANs are known to be difficult to train and may require careful tuning.

In [None]:
def train_cgan(trainer, train_loader, config):
    """Train the CGAN model and save it.

    Samples are requested from the trainer every epoch and written below
    ``config['visualize_dir']``. If training is interrupted with Ctrl-C (or
    the notebook's interrupt button), the models are still saved in whatever
    state they reached.

    Args:
        trainer: Compiled trainer exposing ``fit`` and ``save_model``.
        train_loader: Loader yielding the training batches.
        config: Mapping providing ``'visualize_dir'``, ``'model_dir'`` and
            ``'epochs'``.

    Returns:
        The same ``trainer`` object that was passed in.
    """
    # Create output directory for samples
    samples_dir = os.path.join(config['visualize_dir'], 'cgan_samples')
    ensure_dir(samples_dir)

    print(f"Starting training for {config['epochs']} epochs...")

    try:
        # The value returned by fit() is not needed here, so it is not bound
        trainer.fit(
            train_loader,
            epochs=config['epochs'],
            sample_interval=1,  # Generate samples every epoch
            save_dir=samples_dir
        )
    except KeyboardInterrupt:
        # Deliberate best-effort: fall through and save the partial model
        print("Training interrupted by user")

    # Save the models
    model_dir = os.path.join(config['model_dir'], 'cgan')
    trainer.save_model(model_dir)

    return trainer

# Train the CGAN model (uncomment to run training)
# trainer = train_cgan(trainer, train_loader, config)

## 8. Plot Training History

After training, plot the generator and discriminator losses over epochs.

In [None]:
def plot_history(trainer, config):
    """Ask the trainer to plot its loss history and save the figure.

    The image is written to ``cgan_history.png`` inside
    ``config['visualize_dir']``.
    """
    trainer.plot_history(
        save_path=os.path.join(config['visualize_dir'], 'cgan_history.png')
    )
    
# If you've trained the model, uncomment to plot the history
# plot_history(trainer, config)

## 9. Load Trained Model

In [None]:
def load_trained_cgan(config):
    """Rebuild the CGAN and restore saved weights when they are available.

    The models and trainer are constructed from ``config``. If the directory
    ``<model_dir>/cgan`` exists, the trainer loads the model from it; otherwise
    a message is printed and the untrained models are returned.

    Args:
        config: Mapping providing the model hyperparameters,
            ``'learning_rate'`` and ``'model_dir'``.

    Returns:
        Tuple ``(generator, discriminator, trainer)``.
    """
    latent_dim = config['z_size']
    n_classes = config['class_num']
    side = config['img_size']

    # Build the two networks with the configured hyperparameters
    generator = Generator(
        z_size=latent_dim,
        class_num=n_classes,
        generator_layer_size=config['generator_layer_size'],
        img_size=side,
    )
    discriminator = Discriminator(
        img_size=side,
        class_num=n_classes,
        discriminator_layer_size=config['discriminator_layer_size'],
    )

    # Wrap them in a trainer, which owns the save/load logic
    trainer = CGANTrainer(
        generator,
        discriminator,
        z_size=latent_dim,
        class_num=n_classes,
        device=device,
    )
    trainer.compile(optimizer='adam', learning_rate=config['learning_rate'])

    # Restore the saved model only if its directory is present
    model_dir = os.path.join(config['model_dir'], 'cgan')
    if not os.path.exists(model_dir):
        print(f"Trained model not found at {model_dir}. Using untrained model.")
    else:
        print(f"Loading CGAN model from {model_dir}")
        trainer.load_model(model_dir)

    return generator, discriminator, trainer

# Load trained model
# Uncomment if you have a trained model
# generator, discriminator, trainer = load_trained_cgan(config)

## 10. Generate Samples with CGAN

Generate samples from the trained CGAN for each class condition.

In [None]:
def generate_samples(generator, config):
    """Generate and plot samples from the CGAN for each class.

    Draws a grid with one row per class and five generated images per row,
    saves it as ``cgan_generated_samples.png`` in ``config['visualize_dir']``
    and shows it.

    Args:
        generator: Generator called as ``generator(z, labels)``; it is
            switched to eval mode and moved to the module-level ``device``.
        config: Mapping providing ``'class_num'``, ``'z_size'`` and
            ``'visualize_dir'``.
    """
    generator.eval()
    generator.to(device)

    # Generate 5 samples for each class
    n_samples_per_class = 5
    n_classes = config['class_num']

    with torch.no_grad():
        # squeeze=False keeps `axs` two-dimensional even when there is a
        # single class, so axs[row, col] indexing always works
        _, axs = plt.subplots(
            n_classes, n_samples_per_class,
            figsize=(15, 3*n_classes), squeeze=False
        )

        for class_idx in range(n_classes):
            # Fresh random noise for every class
            z = torch.randn(n_samples_per_class, config['z_size']).to(device)

            # Every sample in this row is conditioned on the same class
            labels = torch.full((n_samples_per_class,), class_idx, dtype=torch.long).to(device)

            fake_images = generator(z, labels).cpu().detach()

            for i in range(n_samples_per_class):
                ax = axs[class_idx, i]
                ax.imshow(fake_images[i, 0].numpy(), cmap='gray')
                # Hide ticks and frame individually instead of axis('off'):
                # axis('off') would also hide the row's class label below
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                if i == 0:
                    ax.set_ylabel(f'Class {class_idx}')

        plt.tight_layout()
        plt.savefig(os.path.join(config['visualize_dir'], 'cgan_generated_samples.png'))
        plt.show()

# Generate samples
# Uncomment after loading a trained model
# generate_samples(generator, config)

## 11. Compare CGAN Samples with Real Data

Compare the generated samples with real data for each class.

In [None]:
def compare_real_vs_generated(generator, target_images, condition_labels, config):
    """Plot real images next to generated ones, one row per class.

    Each row has six cells: up to three real images of the class on the left
    and three generated images on the right. Classes without any real sample
    are reported and their row is left empty. The figure is saved as
    ``cgan_real_vs_generated.png`` in ``config['visualize_dir']`` and shown.

    Args:
        generator: Generator called as ``generator(z, labels)``; it is
            switched to eval mode and moved to the module-level ``device``.
        target_images: Array of real images indexed as [sample, channel, ...].
        condition_labels: Array with the class label of every real image.
        config: Mapping providing ``'class_num'``, ``'z_size'`` and
            ``'visualize_dir'``.
    """
    generator.eval()
    generator.to(device)

    n_classes = config['class_num']

    with torch.no_grad():
        plt.figure(figsize=(15, 3*n_classes))

        for class_idx in range(n_classes):
            # Positions of the real samples that carry this class label
            members = np.where(condition_labels == class_idx)[0]
            if len(members) == 0:
                print(f"No real samples for class {class_idx}")
                continue

            # At most three real samples are shown per class
            real_samples = target_images[members[:3]]
            n_real_samples = len(real_samples)

            # Three generated samples conditioned on this class
            z = torch.randn(3, config['z_size']).to(device)
            labels = torch.full((3,), class_idx, dtype=torch.long).to(device)
            fake_samples = generator(z, labels).cpu().detach().numpy()

            # Subplot numbers are 1-based; this row occupies cells
            # row_start + 1 .. row_start + 6
            row_start = class_idx * 6

            # Cells 1-3: real samples
            for i, real_img in enumerate(real_samples):
                plt.subplot(n_classes, 6, row_start + i + 1)
                plt.imshow(real_img[0], cmap='gray')
                if i == 0:
                    plt.title(f"Real Class {class_idx}")
                else:
                    plt.title(f"Real {i+1}")
                plt.axis('off')

            # Blank out the remaining real cells when fewer than three exist
            for i in range(n_real_samples, 3):
                plt.subplot(n_classes, 6, row_start + i + 1)
                plt.axis('off')

            # Cells 4-6: generated samples
            for i, fake_img in enumerate(fake_samples):
                plt.subplot(n_classes, 6, row_start + i + 4)
                plt.imshow(fake_img[0], cmap='gray')
                if i == 0:
                    plt.title(f"Generated Class {class_idx}")
                else:
                    plt.title(f"Generated {i+1}")
                plt.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(config['visualize_dir'], 'cgan_real_vs_generated.png'))
        plt.show()

# Compare real vs generated samples
# Uncomment after loading a trained model
# compare_real_vs_generated(generator, target_images, condition_labels, config)

## 12. Generate Interpolated Samples

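Generate samples by interpolating linearly between two random points in the latent space, for a single randomly chosen class, to see smooth transitions. A second figure keeps the latent vector fixed and steps through the class labels to show the effect of the condition alone.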

In [None]:
def generate_interpolations(generator, config, n_steps=10):
    """Plot latent-space interpolations and a class sweep.

    Two figures are produced, saved to ``config['visualize_dir']`` and shown:

    1. ``cgan_interpolation.png``: images generated along a straight line
       between two random latent vectors, all for one randomly chosen class.
    2. ``cgan_class_interpolation.png``: images generated from one fixed
       latent vector while stepping through the class labels
       (``0, 1, ...``, wrapping around after ``class_num - 1``).

    Args:
        generator: Generator called as ``generator(z, labels)``; it is
            switched to eval mode and moved to the module-level ``device``.
        config: Mapping providing ``'z_size'``, ``'class_num'`` and
            ``'visualize_dir'``.
        n_steps: Number of images per figure. Defaults to 10.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be at least 1, got {n_steps}")

    generator.eval()
    generator.to(device)

    with torch.no_grad():
        # --- Figure 1: interpolate between two random latent points ---
        z_start = torch.randn(1, config['z_size']).to(device)
        z_end = torch.randn(1, config['z_size']).to(device)

        # One class, chosen at random, is used for the whole interpolation
        class_idx = np.random.randint(0, config['class_num'])
        label = torch.full((1,), class_idx, dtype=torch.long).to(device)

        plt.figure(figsize=(15, 3))
        plt.suptitle(f"Latent Space Interpolation for Class {class_idx}")

        for i in range(n_steps):
            # alpha runs from 0 (z_start) to 1 (z_end); with a single step
            # there is nothing to divide by, so stay at z_start
            alpha = i / (n_steps - 1) if n_steps > 1 else 0.0
            z_interp = z_start * (1 - alpha) + z_end * alpha

            fake_image = generator(z_interp, label).cpu().detach().numpy()

            plt.subplot(1, n_steps, i+1)
            plt.imshow(fake_image[0, 0], cmap='gray')
            plt.title(f"Step {i+1}")
            plt.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(config['visualize_dir'], 'cgan_interpolation.png'))
        plt.show()

        # --- Figure 2: fixed latent vector, varying class label ---
        z_fixed = torch.randn(1, config['z_size']).to(device)

        plt.figure(figsize=(15, 3))
        plt.suptitle("Class Conditioning Interpolation with Fixed Latent Vector")

        for i in range(n_steps):
            # Step through the discrete class labels (no blending between
            # classes); wrap around when there are more steps than classes
            class_idx = i % config['class_num']
            label = torch.full((1,), class_idx, dtype=torch.long).to(device)

            fake_image = generator(z_fixed, label).cpu().detach().numpy()

            plt.subplot(1, n_steps, i+1)
            plt.imshow(fake_image[0, 0], cmap='gray')
            plt.title(f"Class {class_idx}")
            plt.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(config['visualize_dir'], 'cgan_class_interpolation.png'))
        plt.show()

# Generate interpolated samples
# Uncomment after loading a trained model
# generate_interpolations(generator, config)

## 13. Evaluate the Discriminator on Real and Generated Samples

Test how well the discriminator can tell apart real and generated samples.

In [None]:
def evaluate_discriminator(generator, discriminator, test_loader, config):
    """Evaluate the discriminator on real and generated samples.

    For each test batch the discriminator scores the real images and a batch
    of generated images conditioned on the same labels. Evaluation stops once
    more than 200 real samples have been scored. A histogram of both score
    sets is saved as ``cgan_discriminator_evaluation.png`` in
    ``config['visualize_dir']``, and summary statistics are printed. A real
    sample counts as correct when its score is above 0.5, a generated one when
    its score is below 0.5.

    Args:
        generator: Generator called as ``generator(z, labels)``.
        discriminator: Discriminator called as ``discriminator(images,
            labels)``. Assumed to return one score per sample -- TODO confirm
            against src.models.cgan.
        test_loader: Loader yielding ``(images, labels)`` batches.
        config: Mapping providing ``'z_size'`` and ``'visualize_dir'``.

    Returns:
        None. If the loader yields no batches, a message is printed and
        nothing is plotted.
    """
    generator.eval()
    discriminator.eval()
    generator.to(device)
    discriminator.to(device)

    real_batches = []
    fake_batches = []
    n_scored = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Scores for the real samples; flattened so (B,) and (B, 1)
            # discriminator outputs are handled the same way
            real_validity = discriminator(images, labels)
            real_batches.append(real_validity.cpu().numpy().reshape(-1))

            # Scores for generated samples conditioned on the same labels
            z = torch.randn(labels.size(0), config['z_size']).to(device)
            fake_images = generator(z, labels)
            fake_validity = discriminator(fake_images, labels)
            fake_batches.append(fake_validity.cpu().numpy().reshape(-1))

            # Limit the evaluation to a few batches
            n_scored += labels.size(0)
            if n_scored > 200:
                break

    # Without any batch every statistic below would be NaN or 0/0
    if not real_batches:
        print("No test samples available; skipping discriminator evaluation.")
        return

    real_scores = np.concatenate(real_batches)
    fake_scores = np.concatenate(fake_batches)

    # Plot histograms of the scores
    plt.figure(figsize=(12, 6))
    plt.hist(real_scores, bins=20, alpha=0.5, label='Real Samples')
    plt.hist(fake_scores, bins=20, alpha=0.5, label='Generated Samples')
    plt.xlabel('Discriminator Score')
    plt.ylabel('Count')
    plt.title('Discriminator Evaluation')
    plt.legend()
    plt.savefig(os.path.join(config['visualize_dir'], 'cgan_discriminator_evaluation.png'))
    plt.show()

    # Calculate statistics
    print(f"Real samples - Mean score: {real_scores.mean():.4f}, Std: {real_scores.std():.4f}")
    print(f"Fake samples - Mean score: {fake_scores.mean():.4f}, Std: {fake_scores.std():.4f}")

    # Accuracy with 0.5 as the decision threshold
    real_correct = (real_scores > 0.5).sum()
    fake_correct = (fake_scores < 0.5).sum()
    total = len(real_scores) + len(fake_scores)
    accuracy = (real_correct + fake_correct) / total

    print(f"Discriminator accuracy: {accuracy:.4f}")
    print(f"Real samples correctly identified: {real_correct / len(real_scores):.4f}")
    print(f"Fake samples correctly identified: {fake_correct / len(fake_scores):.4f}")

# Evaluate discriminator
# Uncomment after loading a trained model
# evaluate_discriminator(generator, discriminator, test_loader, config)

## 14. Conclusion

This notebook demonstrated how to train and evaluate a Conditional GAN (CGAN) for urban crop yield prediction. CGANs can generate synthetic data conditioned on specific inputs, which can be useful for data augmentation, exploring different scenarios, or generating predictions under various conditions.

Training GANs is known to be challenging and may require careful tuning of hyperparameters, so expect some experimentation to get good results. For better performance, you might want to try:

1. Using more sophisticated architectures
2. Implementing spectral normalization or other stabilization techniques
3. Experimenting with different loss functions (e.g., Wasserstein GAN)
4. Training for more epochs with learning rate scheduling