In [1]:
# "generation-engine_XYZ[dot]ipynb"

In [None]:
# "generation-engine_XYZ[dot]ipynb"

# -*- coding: utf-8 -*-
"""
Generation Engine for Synthetic Pollen Image Generation
Script: generation-engine_00.py

This enhanced script generates synthetic pollen images by:
1. Loading the trained WGAN-SN generator and critic models
2. Applying quality filtering using the critic to keep only high-quality samples
3. Generating synthetic pollen images (base 128x128) with controlled augmentations
4. Optionally placing these images on background patches (640x640) with advanced blending
5. Creating YOLO annotation files for object detection
6. Analyzing and visualizing the generation results
7. Creating comprehensive reports with statistics

Improvements over the original composition-engine:
- GAN quality filtering using the critic to remove low-quality samples
- Advanced blending methods (Poisson, Pyramid)
- Fixed geometric transformations (proper scaling without dimension reset)
- Dataset statistics matching for realistic distributions
- Conditional pipeline execution via boolean flags
- Memory management for model loading
- Configuration saving and loading
- Enhanced reporting and visualization
- Parallelized processing with memory management
"""

print("!!! setting the 'TF_ENABLE_ONEDNN_OPTS' value to '0' for avoiding the 'oneDNN custom operations' message in powershell console !!!")
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import math
import sys
import json
import random
import logging
import time
import gc
import traceback
import psutil
import threading
import warnings
import pickle
import subprocess
from datetime import datetime
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image, ImageOps
import torch
import torch.nn as nn
from multiprocessing import Pool, Manager, Queue, cpu_count as mp_cpu_count

# --- Add these imports ---
import torch.nn.functional as F
#from math import ceil # made redundant

# --- Import components for feature extraction and evaluation ---
try:
    from pytorch_fid.inception import InceptionV3
    from pytorch_fid.fid_score import calculate_frechet_distance # Keep if FID/KID needed
    FID_AVAILABLE = True
    print("Successfully imported pytorch_fid components.")
except ImportError:
    print("WARNING: pytorch-fid not found. Feature extraction and FID/KID/PRDC analysis will be disabled.")
    FID_AVAILABLE = False

try:
    from sklearn.neighbors import NearestNeighbors
    NEIGHBORS_AVAILABLE = True
    print("Successfully imported NearestNeighbors for PRDC.")
except ImportError:
    print("WARNING: scikit-learn NearestNeighbors not found. Custom P/R/D/C will be disabled.")
    NEIGHBORS_AVAILABLE = False

# Optional torch-fidelity import
try:
    import torch_fidelity
    TORCH_FIDELITY_AVAILABLE = True
    print("Successfully imported torch-fidelity.")
except ImportError:
    print("INFO: torch-fidelity not found. Will use custom P/R/D/C implementation if available.")
    TORCH_FIDELITY_AVAILABLE = False

# Ensure sklearn manifold and preprocessing are imported (likely already there)
try:
    from sklearn.manifold import TSNE
    from sklearn.preprocessing import StandardScaler # We might remove its usage later for consistency
    TSNE_AVAILABLE = True
except ImportError:
    print("WARNING: scikit-learn TSNE/StandardScaler not found.")
    TSNE_AVAILABLE = False

# UMAP availability is probed further below (after the dependency check),
# where the HAS_UMAP flag is set.

# --- Make sure these are also present ---

import torchvision.transforms as transforms # Needed for image loading/processing

# --- Add these imports ---
import torch.nn.functional as F
#from math import ceil # made redundant
# Add KID calculation dependency
try:
    from sklearn.metrics.pairwise import polynomial_kernel
    SKLEARN_KERNELS_AVAILABLE = True
    print("Successfully imported sklearn kernels for KID.")
except ImportError:
    print("WARNING: scikit-learn polynomial_kernel not found. Custom KID calculation will be disabled.")
    SKLEARN_KERNELS_AVAILABLE = False
# Add scipy if needed for matrix sqrt in FID
try:
    from scipy import linalg
    SCIPY_LINALG_AVAILABLE = True
except ImportError:
    print("WARNING: SciPy linalg not found. FID calculation might fail for non-positive definite covariance matrices.")
    SCIPY_LINALG_AVAILABLE = False

def create_elliptical_mask(img_pil, blur_amount=10):
    """
    Build a feathered, centered elliptical mask sized to match an image.

    Args:
        img_pil (PIL.Image.Image): Image whose ``size`` gives the mask dimensions.
        blur_amount (int):          Edge softness; the Gaussian kernel used is
                                    ``2 * blur_amount + 1`` pixels wide. Values
                                    of 0 or less skip the blur entirely.

    Returns:
        PIL.Image.Image: Grayscale (mode 'L') mask, or None when the image is
        zero-sized or any step raises.
    """
    logger = logging.getLogger("GenerationEngine")  # Assumes logger is configured
    try:
        width, height = img_pil.size
        if not (width > 0 and height > 0):
            logger.warning("Cannot create elliptical mask for zero-sized image.")
            return None

        # Black canvas; the ellipse is painted white on top of it
        canvas = np.zeros((height, width), dtype=np.uint8)

        # Centered ellipse whose semi-axes are 95% of the half-dimensions,
        # leaving a small border so the edge is not clipped abruptly
        center = (width // 2, height // 2)
        semi_axes = (max(1, int(center[0] * 0.95)), max(1, int(center[1] * 0.95)))

        cv2.ellipse(canvas, center=center, axes=semi_axes,
                    angle=0, startAngle=0, endAngle=360, color=255, thickness=-1)

        # Gaussian kernel side length: odd, and never below 1
        kernel_side = blur_amount * 2 + 1
        if kernel_side < 1:
            kernel_side = 1

        # Feather the edge only for a positive blur amount
        softened = (
            cv2.GaussianBlur(canvas, (kernel_side, kernel_side), 0)
            if blur_amount > 0
            else canvas
        )

        return Image.fromarray(softened, mode='L')

    except Exception as e:
        logger.error(f"Error creating elliptical mask: {e}", exc_info=True)
        return None  # Return None on error
    
# --- NEW FID/KID Calculation Function ---

def _chunked_kernel_sum(rows, cols, chunk_size, desc, degree=3, gamma=None, coef0=1):
    """Sum polynomial_kernel(rows, cols) block-by-block without building the full matrix."""
    total = 0.0
    n_rows = len(rows)
    with tqdm(total=n_rows, desc=desc) as pbar:
        for start in range(0, n_rows, chunk_size):
            block = rows[start:start + chunk_size]
            block_len = len(block)
            if block_len == 0:
                continue

            # Kernel of this row block against every column sample
            k_block = polynomial_kernel(block, cols, degree=degree, gamma=gamma, coef0=coef0)

            # Accumulate in float64: float32 features would otherwise yield a
            # float32 running sum and lose precision over millions of terms.
            total += float(np.sum(k_block, dtype=np.float64))
            pbar.update(block_len)

            del k_block, block
            if start % (2 * chunk_size) == 0:  # Less frequent cleanup
                gc.collect()
    return total


def calculate_fid_kid(real_features, fake_features, config, logger, eps=1e-6):
    """
    Calculate FID and KID between two sets of features.

    Both sets are truncated to the size of the smaller one. FID uses
    ``calculate_frechet_distance`` from pytorch-fid on the sample mean and
    covariance. KID is computed from polynomial-kernel sums (degree 3,
    gamma=None, coef0=1) accumulated in chunks so that the full kernel
    matrices are never held in memory.

    NOTE(review): the KID value returned here is ``sqrt(max(0, MMD^2)) * 100``
    where the kernel means include the diagonal terms. This appears to differ
    from the unbiased MMD^2 estimator commonly reported as KID - confirm before
    comparing against numbers from other tools.

    Args:
        real_features (np.ndarray): Features from real images (N_real, D_features).
        fake_features (np.ndarray): Features from fake images (N_fake, D_features).
        config (Config): Configuration object (forwarded to force_memory_cleanup).
        logger (logging.Logger): Logger instance.
        eps (float): Small epsilon value passed to the Frechet distance routine.

    Returns:
        dict: ``{'fid': float | None, 'kid': float | None}``. A value is None when
        its dependency is unavailable, there are fewer than 10 samples per set,
        or its calculation raised.
    """
    results = {'fid': None, 'kid': None}
    min_samples = min(len(real_features), len(fake_features))

    if not FID_AVAILABLE:
        logger.warning("pytorch-fid not available, skipping FID calculation.")

    if not SKLEARN_KERNELS_AVAILABLE:
        logger.warning("sklearn kernels not available, skipping KID calculation.")

    if min_samples < 10:  # Need at least a few samples
        logger.warning(f"Too few samples ({min_samples}) to reliably calculate FID/KID.")
        return results

    # Use the same number of samples from each set for a fair comparison
    real_f = real_features[:min_samples]
    fake_f = fake_features[:min_samples]

    # --- FID Calculation ---
    # Only attempted when pytorch-fid imported; otherwise the call below would
    # fail with a NameError after paying for two covariance matrices.
    if FID_AVAILABLE:
        try:
            logger.debug(f"Calculating FID using {min_samples} samples per set.")

            mu_real = np.mean(real_f, axis=0)
            sigma_real = np.cov(real_f, rowvar=False)
            mu_fake = np.mean(fake_f, axis=0)
            sigma_fake = np.cov(fake_f, rowvar=False)

            fid_value = calculate_frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake, eps=eps)
            results['fid'] = float(fid_value)
            logger.info(f"Calculated FID: {results['fid']:.3f}")

            # Free the covariance matrices before the KID pass
            del mu_real, sigma_real, mu_fake, sigma_fake
            gc.collect()

        except Exception as e:
            logger.error(f"Error calculating FID: {e}", exc_info=True)
            results['fid'] = None

    # --- Chunked KID Calculation ---
    if SKLEARN_KERNELS_AVAILABLE:
        try:
            logger.info(f"Calculating KID using chunked approach to save memory")

            # Rows per kernel block - adjust based on available memory
            kid_chunk_size = 2000  # Try 1000-5000 depending on available RAM

            # gamma=None lets sklearn fall back to 1 / n_features
            kernel_params = {'degree': 3, 'gamma': None, 'coef0': 1}

            num_chunks = math.ceil(min_samples / kid_chunk_size)
            logger.info(f"Processing KID in {num_chunks} chunks (size: {kid_chunk_size})...")

            sum_k_rr = _chunked_kernel_sum(real_f, real_f, kid_chunk_size, "KID Real-Real", **kernel_params)
            sum_k_ff = _chunked_kernel_sum(fake_f, fake_f, kid_chunk_size, "KID Fake-Fake", **kernel_params)
            sum_k_rf = _chunked_kernel_sum(real_f, fake_f, kid_chunk_size, "KID Real-Fake", **kernel_params)

            # Every kernel matrix is min_samples x min_samples
            n_elements = min_samples * min_samples
            mean_k_rr = sum_k_rr / n_elements
            mean_k_ff = sum_k_ff / n_elements
            mean_k_rf = sum_k_rf / n_elements

            # Calculate MMD^2
            mmd2 = mean_k_rr + mean_k_ff - 2 * mean_k_rf

            # Reported as sqrt(MMD^2) * 100; clamp guards against tiny negatives
            kid_value = np.sqrt(max(0, mmd2)) * 100
            results['kid'] = float(kid_value)
            logger.info(f"Calculated KID: {results['kid']:.3f}")

        except Exception as e:
            logger.error(f"Error calculating KID: {e}", exc_info=True)
            results['kid'] = None

    # Final cleanup
    del real_f, fake_f
    force_memory_cleanup(config)

    return results

# --- NEW FID/KID Plotting Function ---
def plot_fid_kid_comparison(analysis_results, output_path, logger):
    """
    Create a two-panel bar chart comparing FID and KID for filtered vs unfiltered images.

    Args:
        analysis_results (dict): Read for the keys 'fid_kid_filtered' and
            'fid_kid_unfiltered'; each, when present, is a dict that may carry
            'fid' and 'kid' scores (None entries are skipped).
        output_path (str): Destination image path for the saved figure.
        logger (logging.Logger): Logger used for progress and error messages.

    Returns:
        bool: True when the plot was saved; False when there was nothing to
        plot or an error occurred (the error is logged, not raised).
    """
    logger.info(f"Generating FID/KID comparison plot: {output_path}")
    fig = None  # Tracked so the error handler can close it without guessing
    try:
        fid_kid_f = analysis_results.get('fid_kid_filtered')
        fid_kid_u = analysis_results.get('fid_kid_unfiltered')

        if not fid_kid_f and not fid_kid_u:
            logger.warning("No FID/KID data available for plotting.")
            return False

        # Colour is bound to the label, so a lone 'Unfiltered' bar keeps its
        # own colour instead of inheriting the first palette entry.
        bar_colors = {'Filtered': 'orange', 'Unfiltered': 'lightgreen'}
        sources = (('Filtered', fid_kid_f), ('Unfiltered', fid_kid_u))

        def collect(metric):
            """Return (label, score) pairs for every source with a non-None score."""
            pairs = []
            for label, entry in sources:
                score = entry.get(metric) if entry else None
                if score is not None:
                    pairs.append((label, score))
            return pairs

        fid_pairs = collect('fid')
        kid_pairs = collect('kid')

        if not fid_pairs and not kid_pairs:
            logger.warning("No valid FID or KID scores found to plot.")
            return False

        plt.style.use('seaborn-v0_8-darkgrid')
        fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=False)  # FID and KID live on different scales
        fig.suptitle('FID and KID Comparison (Lower is Better)', fontsize=16)

        def draw_panel(ax, pairs, ylabel, panel_title, missing_text):
            """Draw one metric's bars, or a placeholder message when no scores exist."""
            ax.set_title(panel_title)
            if not pairs:
                ax.text(0.5, 0.5, missing_text, horizontalalignment='center',
                        verticalalignment='center', transform=ax.transAxes)
                return

            labels = [label for label, _ in pairs]
            values = [value for _, value in pairs]
            bars = ax.bar(labels, values, color=[bar_colors[label] for label in labels])
            ax.set_ylabel(ylabel)
            ax.bar_label(bars, fmt='%.3f')

            # Zoom the y-axis to the data range; skip when the range collapses
            # (e.g. all scores are 0), which would give identical limits.
            lower = max(0, min(values) * 0.9)
            upper = max(values) * 1.1
            if upper > lower:
                ax.set_ylim(lower, upper)

        draw_panel(axes[0], fid_pairs, 'FID Score',
                   'Fréchet Inception Distance (FID)', 'FID Scores Not Available')
        draw_panel(axes[1], kid_pairs, 'KID Score (x100)',
                   'Kernel Inception Distance (KID)', 'KID Scores Not Available')

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave room for the suptitle
        plt.savefig(output_path, dpi=300)
        plt.close(fig)
        fig = None
        logger.info(f"Saved FID/KID comparison plot to {output_path}")
        return True

    except Exception as e:
        logger.error(f"Error creating FID/KID comparison plot: {e}", exc_info=True)
        # Release the figure if one was created before the failure
        if fig is not None:
            try:
                plt.close(fig)
            except Exception as close_err:
                logger.debug(f"Could not close figure after plotting error: {close_err}")
        return False

def plot_combined_radar_chart(metrics_set1, metrics_set2, label1, label2, title, output_path):
    """
    Create a radar chart comparing two sets of precision/recall/coverage metrics.

    Each metrics argument may be a dict with an 'individual' sub-dict
    (preferred), a 'mean' sub-dict (fallback, logged as a warning), or a flat
    dict of metric values. Missing or None metric values are plotted as 0.0.

    Inputs are validated before any matplotlib call, and all plotting (style
    selection included) happens inside the error handler, so failures are
    logged and reported through the return value.

    Args:
        metrics_set1 (dict): First metrics result (e.g., filtered).
        metrics_set2 (dict): Second metrics result (e.g., unfiltered).
        label1 (str): Legend label for the first dataset (e.g., "Filtered").
        label2 (str): Legend label for the second dataset (e.g., "Unfiltered").
        title (str): Title for the chart.
        output_path (str): Path to save the chart image.

    Returns:
        bool: True when the chart was saved, False on insufficient data or error.
    """
    logger = logging.getLogger("GenerationEngine")

    def get_plot_data(metrics_data):
        """Pick 'individual' metrics, else 'mean', else treat the dict as flat."""
        if not isinstance(metrics_data, dict):
            return {}  # Not a dict: nothing to plot
        if metrics_data.get('individual'):
            return metrics_data['individual']
        if metrics_data.get('mean'):
            logger.warning(f"Using 'mean' metrics for plotting in '{title}' as 'individual' is empty/missing.")
            return metrics_data['mean']
        return metrics_data

    def metric_value(data, key):
        """Return the metric, substituting 0.0 for a missing or None entry."""
        value = data.get(key, 0.0)
        return 0.0 if value is None else value

    fig = None  # Closed in `finally` so a failed plot does not leak the figure
    try:
        data1 = get_plot_data(metrics_set1)
        data2 = get_plot_data(metrics_set2)

        if not data1 or not data2:
            logger.error(f"Insufficient metrics data to generate combined radar chart '{title}'.")
            return False

        plt.style.use('seaborn-v0_8-darkgrid')

        # Metrics shown on the radar ('density' is intentionally left out)
        categories = ['precision', 'recall', 'coverage']
        angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
        angles += angles[:1]  # Repeat the first angle to close the polygon

        fig, ax = plt.subplots(figsize=(10, 8), subplot_kw=dict(polar=True))

        for data, label, color in ((data1, label1, 'orange'), (data2, label2, 'lightgreen')):
            values = [metric_value(data, cat) for cat in categories]
            values += values[:1]  # Close the polygon
            ax.plot(angles, values, 'o-', linewidth=2, color=color, alpha=0.8, label=label)
            ax.fill(angles, values, alpha=0.1, color=color)

        # Customize the chart
        ax.set_thetagrids(np.degrees(angles[:-1]), [c.capitalize() for c in categories])
        ax.set_ylim(0, 1.05)  # Slightly above 1.0 so values at 1.0 stay visible
        ax.grid(True)
        plt.title(title, fontsize=16, y=1.1)
        # Position legend outside the plot area
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))

        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')

        logger.info(f"Saved combined radar chart to {output_path}")
        return True

    except Exception as e:
        logger.error(f"Error creating combined radar chart '{title}': {e}", exc_info=True)
        return False

    finally:
        if fig is not None:
            plt.close(fig)

# --- Memory and model-loading helper functions ---

def log_memory_usage(logger, step=''):  # Pass logger explicitly
    """Report GPU (when CUDA is available) and process RSS memory, in GB, via ``logger.info``.

    Args:
        logger: Object exposing ``info(str)``.
        step (str): Tag inserted into each message to identify the call site.
    """
    bytes_per_gb = 1024 ** 3

    if torch.cuda.is_available():
        gpu_allocated = torch.cuda.memory_allocated() / bytes_per_gb
        gpu_reserved = torch.cuda.memory_reserved() / bytes_per_gb
        logger.info(f"GPU Memory [{step}]: Allocated: {gpu_allocated:.3f} GB | Reserved: {gpu_reserved:.3f} GB")

    try:
        import psutil
        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
        rss_gb = rss_bytes / bytes_per_gb
        logger.info(f"CPU Memory [{step}]: {rss_gb:.3f} GB")
    except ImportError:
        # psutil is optional here: without it the CPU line is simply omitted
        pass

# Aggressive memory cleanup; honors config.CLEAR_CUDA_CACHE when a config is provided
def force_memory_cleanup(config=None):  # Optional config for CLEAR_CUDA_CACHE flag
    """Run the garbage collector and, unless disabled, release cached CUDA memory.

    Args:
        config: Optional object; when it is truthy and has a
            ``CLEAR_CUDA_CACHE`` attribute, that value decides whether the
            CUDA cache is emptied. Otherwise the cache is cleared by default.
    """
    gc.collect()

    # Clearing is the default; only a supplied config can switch it off
    should_clear = getattr(config, 'CLEAR_CUDA_CACHE', True) if config else True

    if not (torch.cuda.is_available() and should_clear):
        return

    torch.cuda.empty_cache()
    if hasattr(torch.cuda, 'memory_summary'):
        torch.cuda.synchronize()

def get_inception_model(config):
    """Build the pytorch-fid InceptionV3 feature extractor in eval mode.

    Args:
        config: Object providing ``FEATURE_DIMS`` (used to look up the
            Inception block index) and ``DEVICE`` (where the model is placed).

    Returns:
        The InceptionV3 model, moved to ``config.DEVICE`` and set to eval mode.

    Raises:
        ImportError: If pytorch-fid could not be imported.
    """
    if not FID_AVAILABLE:
        raise ImportError("pytorch-fid InceptionV3 not available")

    print("Loading InceptionV3 model for feature extraction...")  # Use print or logger

    # Map the requested feature dimensionality to the matching Inception block
    feature_block = InceptionV3.BLOCK_INDEX_BY_DIM[config.FEATURE_DIMS]
    inception = InceptionV3([feature_block])
    inception = inception.to(config.DEVICE)
    inception.eval()
    return inception

# Remove Pillow's pixel-count limit so very large images load without a
# DecompressionBombWarning. With the limit set to None the size check is
# skipped entirely, so only open image files from trusted sources.
Image.MAX_IMAGE_PIXELS = None

# Silence UserWarnings whose message mentions "weights_only".
# NOTE(review): presumably aimed at the torch.load warning - newer PyTorch
# versions may emit it under a different category; confirm it is still caught.
warnings.filterwarnings("ignore", category=UserWarning, message=".*weights_only.*")

# Add these functions to improve error handling and ensure errors display in Jupyter

def configure_exception_handler():
    """Install a ``sys.excepthook`` that prints the full traceback to stdout."""
    import sys
    import traceback

    def custom_excepthook(exc_type, exc_value, exc_traceback):
        """Format the exception with its traceback and print it."""
        formatted_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
        print("".join(formatted_lines))

    # Replace the interpreter-wide hook for uncaught exceptions
    sys.excepthook = custom_excepthook

# Model Inspection
def inspect_checkpoint(checkpoint_path, logger=None):
    """Load a checkpoint on CPU and print a summary of its contents for debugging.

    Prints the top-level keys, the name/shape/dtype of every entry in
    'generator_state_dict' and 'critic_state_dict' when present, and the
    values of 'epoch', 'best_fid' and 'best_kid' when present.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        logger: Optional logger; receives the error message on failure.

    Returns:
        The loaded checkpoint object, or None if anything raised.
    """
    def _print_state_dict(heading, state_dict):
        # One line per parameter: name, tensor shape and dtype
        print(heading)
        for name, tensor in state_dict.items():
            print(f"  {name}: shape={tensor.shape}, dtype={tensor.dtype}")

    try:
        # SECURITY NOTE: weights_only=False performs full unpickling, which can
        # execute arbitrary code - only inspect checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        print(f"Checkpoint keys: {list(checkpoint.keys())}")

        state_dict_sections = (
            ('generator_state_dict', "\nGenerator state dict:"),
            ('critic_state_dict', "\nCritic state dict:"),
        )
        for section_key, heading in state_dict_sections:
            if section_key in checkpoint:
                _print_state_dict(heading, checkpoint[section_key])

        # Scalar bookkeeping entries, if the checkpoint carries them
        for meta_key in ('epoch', 'best_fid', 'best_kid'):
            if meta_key in checkpoint:
                print(f"\n{meta_key}: {checkpoint[meta_key]}")

        return checkpoint
    except Exception as e:
        message = f"Error inspecting checkpoint: {e}"
        print(message)
        if logger:
            logger.error(message)
        return None

# --- Configuration Class ---
class Config:
    """Configuration container holding every parameter and boolean flag of the engine.

    All settings are plain instance attributes assigned in ``__init__``.
    ``save`` / ``load`` round-trip them through JSON; ``DEVICE`` is excluded
    from the file and recomputed from ``USE_GPU`` on load.
    """

    def __init__(self):
        # --- Pipeline Control Flags ---
        # NOTE(review): the original comment said this should default to False
        # for safety, but the value is True - confirm which is intended.
        self.CLEAN_OUTPUT_BEFORE_RUN = True
        self.USE_GPU = True  # Use GPU if available
        self.CALCULATE_REAL_STATS = True  # Calculate statistics from real images
        self.USE_STATS_CACHE = True  # Use cached statistics if available
        self.GENERATE_INDIVIDUAL_POLLEN = True  # Generate individual pollen images
        self.GENERATE_COMPOSED_IMAGES = True  # Generate composed images with backgrounds
        self.SAVE_CHECKPOINTS = True  # Save generation checkpoints for resuming
        self.GENERATE_REPORTS = True  # Generate markdown reports
        self.PERFORM_ANALYSIS = True  # Perform statistical analysis on generated images

        # --- Memory Management Flags ---
        self.OPTIMIZE_MEMORY_USAGE = True  # Apply memory optimization techniques
        self.LOAD_GENERATOR_TO_CPU = False  # Load generator to CPU first, then transfer to GPU when needed
        self.LOAD_CRITIC_TO_CPU = False  # Load critic to CPU to save GPU memory ~ default : True ; alt : False ;
        self.MONITOR_MEMORY = True  # Track memory usage during execution

        # --- Quality Control Flags ---
        self.USE_QUALITY_FILTERING = True  # Use critic for quality filtering
        self.USE_STATS_MATCHING_SIZE = True  # Match size distributions from real data
        self.USE_STATS_MATCHING_HISTOGRAM = False  # Match histogram distributions from real data ~ default : True (causes issues with poisson blending)

        # --- Augmentation Flags ---
        self.USE_CONTINUOUS_ROTATION = False  # Force discrete rotations (0, 90, 180, 270 only)
        self.USE_RANDOM_SCALES = True  # Apply random scaling to pollen images

        # --- Blending Flags ---
        self.USE_ADVANCED_BLENDING = True  # Use advanced blending methods
        self.USE_CONTENT_AWARE_MASKS = True  # Generate content-aware masks for blending

        # --- Parallelization Flags ---
        self.USE_PARALLEL_PROCESSING = True  # Use parallel processing for composed image generation
        self.USE_SEPARATE_PROCESS_PER_BATCH = False  # Use a separate process for each batch

        # --- Paths ---
        self.CHECKPOINT_PATH = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\WGAN-SN_training-output_v2-151_flawed-04\continuation-enhanced_00\checkpoints\cont_best_fid_checkpoint.pth.tar"
        self.PREPROCESSED_REAL_DATA_PATH = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\pre-processing_px-128_step_automated-labels_pc-150_mixed"
        self.RAW_REAL_DATA_PATH = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\vet_images_sliced\TrainingStepSet_automated-labels_T_full-size_150-pc_undivided_categorized\all-classes-mixed\uniform"
        self.BACKGROUND_DIR = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\backgrounds_layouts_step_automated-labels_pc-150\step_size_320\backgrounds"
        self.STATISTICS_DIR = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\backgrounds_layouts_step_automated-labels_pc-150\step_size_320\statistics"
        self.OUTPUT_DIR = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\generation-engine_XYZ"

        # --- Output Subdirectories ---
        self.POLLEN_SUBDIR = "synthetic_pollen"
        self.COMPOSED_SUBDIR = "synthetic_composed_XYZ"
        self.LABELS_SUBDIR = "synthetic_labels_XYZ"
        self.LOG_SUBDIR = "logs_XYZ"
        self.REPORT_SUBDIR = "reports_XYZ"
        self.GRAPH_SUBDIR = "graphs_XYZ"
        self.STATS_SUBDIR = "stats_XYZ"
        self.CONFIG_SUBDIR = "configs_XYZ"
        self.CHECKPOINT_SUBDIR = "checkpoints_XYZ"

        # --- Model Parameters ---
        self.NOISE_DIM = 100
        self.G_FEATURES = 64
        self.C_FEATURES = 64  # Critic features
        self.CHANNELS_IMG = 1

        # --- Generation Targets ---
        self.TARGET_POLLEN_IMAGES = 62492  # Number of 128x128 images to generate ~ default pre-processed : 61671 ; alt : 145 ; new : 1450 ; newer : 14500 ; default raw : 62492 ; needed : 87467 ;
        self.TARGET_COMPOSED_IMAGES = 4363  # Number of 640x640 images to generate ~ default : 4363 ; alt : 10 ; new : 100 ; newer : 1000 ; needed : 6108 ;
        self.AVG_POLLEN_PER_IMAGE = 14.32     # Average pollen per composed image ~ 14 as default ; 14.13 with pre-processed ; 14.32 with raw ;

        # --- Quality Filtering Parameters ---
        self.FILTERING_SURPLUS_FACTOR = 2.0  # Generate extra XYZ % for filtering ~ reason: 1.25 * 0.8 = 1.0 ; alt (A) : 2.0 * 0.5 = 1.0 ; alt (B) : 4.0 * 0.25 = 1.0 ;
        self.QUALITY_THRESHOLD_PERCENTILE = 50.0  # Keep top XYZ % by quality score ~ reason: 1.25 * 0.8 = 1.0 ; alt (A) : 2.0 * 0.5 = 1.0 ; alt (B) : 4.0 * 0.25 = 1.0 ;

        # --- Image Parameters ---
        self.BG_SIZE = 640                 # Background image size
        self.POLLEN_SIZE_BASE = 128        # Base pollen image size
        self.OBJECT_CLASS = 0              # YOLO class for pollen objects

        # --- Augmentation Parameters ---
        self.SCALE_RANGE = (0.75, 1.25)    # Allow 25% scaling in each direction
        self.ROTATION_ANGLES = [0, 90, 180, 270]  # Discrete rotation angles (if not continuous)
        self.MARGIN = 15                   # Margin from image edges (in pixels)

        # --- Blending Parameters ---
        self.BLENDING_METHOD = "alpha"   # Options: "poisson"/"pyramid"/"alpha" ~ default : "poisson" ; complex : "pyramid" ; simple: "alpha" ;
        self.PYRAMID_LEVELS = 4            # Levels for Laplacian pyramid blending
        self.FEATHER_AMOUNT = 15          # Used for blur/feathering in content/ellipse masks

        # --- Performance Parameters ---
        self.BATCH_SIZE = 64
        self.SCORING_BATCH_SIZE = 128 # Used for critic scoring
        self.FEATURE_EXTRACTION_BATCH_SIZE = 64 # New: Batch size for InceptionV3
        self.NUM_WORKERS = max(1, mp_cpu_count() - 1)
        self.CHUNK_SIZE = 50
        self.MEMORY_CHECK_INTERVAL = 100
        self.GPU_MEMORY_THRESHOLD = 0.9
        self.RAM_MEMORY_THRESHOLD = 0.9
        self.CLEANUP_INTERVAL = 500
        self.PIN_MEMORY = True # Added from evaluator
        self.CLEAR_CUDA_CACHE = True # Added from evaluator

        # --- Analysis Parameters ---
        self.ANALYSIS_SAMPLE_SIZE = 60000 # Max samples for analysis ~ default pre-processed : 60000 ; alt : 100 ; new : 1000 ; newer : 10000 ; default raw : 60000 ;
        self.VISUALIZATION_SAMPLE_SIZE = 15000 # New: Max samples for t-SNE/UMAP plots (adjust as needed) ~ balance real & fake counts
        self.MEASURE_FID = True # Keep if needed later
        self.MEASURE_KID = True # Keep if needed later
        self.MEASURE_PRDC = True
        self.VISUALIZE_TSNE = True # New flag to control visualization
        self.VISUALIZE_UMAP = True # New flag to control visualization
        self.FEATURE_DIMS = 2048   # Inception feature dimensions
        self.PR_K_VALUE = 6        # k-value for PRDC calculation (for torch-fidelity and maybe custom)
        self.MANIFOLD_K = 6        # k-value for custom PRDC manifold estimation
        self.DISTANCE_MULTIPLIER = 1.2 # Multiplier for custom PRDC distance thresholds
        self.SUBSAMPLE_MANIFOLD = True # Whether to subsample for custom PRDC manifold estimation
        self.RANDOM_SEED = 42      # Added from evaluator for reproducibility

        # --- Device Configuration ---
        self.DEVICE = torch.device("cuda" if torch.cuda.is_available() and self.USE_GPU else "cpu")

    def save(self, filename=None):
        """Save the configuration to a JSON file and return the path written.

        Args:
            filename (str | None): Target path. When None, a timestamped
                ``config_<YYYYmmdd_HHMMSS>.json`` is created under
                ``OUTPUT_DIR/CONFIG_SUBDIR``.

        Returns:
            str: The path the configuration was written to.
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = os.path.join(self.OUTPUT_DIR, self.CONFIG_SUBDIR, f"config_{timestamp}.json")

        # Create the parent directory when there is one; a bare filename has an
        # empty dirname, and os.makedirs('') would raise.
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Convert config to dict, excluding device (not JSON serializable)
        config_dict = {k: v for k, v in vars(self).items() if k != 'DEVICE'}

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=4)

        return filename

    @classmethod
    def load(cls, filename):
        """Load a configuration previously written by ``save``.

        Starts from a default instance and overrides every attribute present
        in the file; keys unknown to the current class are ignored. Values
        whose default is a tuple (e.g. ``SCALE_RANGE``) are converted back
        from the JSON list so their type survives the round trip.

        Args:
            filename (str): Path to the JSON configuration file.

        Returns:
            Config: The populated configuration, with ``DEVICE`` recomputed.
        """
        with open(filename, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)

        config = cls()

        for k, v in config_dict.items():
            if not hasattr(config, k):
                continue
            # JSON has no tuple type: restore it where the default is a tuple
            if isinstance(getattr(config, k), tuple) and isinstance(v, list):
                v = tuple(v)
            setattr(config, k, v)

        # Re-compute device (never stored; depends on the loaded USE_GPU)
        config.DEVICE = torch.device("cuda" if torch.cuda.is_available() and config.USE_GPU else "cpu")

        return config

    def get_real_data_path(self):
        """Return RAW_REAL_DATA_PATH if it exists on disk, else PREPROCESSED_REAL_DATA_PATH."""
        #return self.PREPROCESSED_REAL_DATA_PATH if os.path.exists(self.PREPROCESSED_REAL_DATA_PATH) else self.RAW_REAL_DATA_PATH
        return self.RAW_REAL_DATA_PATH if os.path.exists(self.RAW_REAL_DATA_PATH) else self.PREPROCESSED_REAL_DATA_PATH

    def get_all_output_dirs(self):
        """Return a dict mapping each output category to its directory under OUTPUT_DIR."""
        return {
            'pollen': os.path.join(self.OUTPUT_DIR, self.POLLEN_SUBDIR),
            'composed': os.path.join(self.OUTPUT_DIR, self.COMPOSED_SUBDIR),
            'labels': os.path.join(self.OUTPUT_DIR, self.LABELS_SUBDIR),
            'logs': os.path.join(self.OUTPUT_DIR, self.LOG_SUBDIR),
            'reports': os.path.join(self.OUTPUT_DIR, self.REPORT_SUBDIR),
            'graphs': os.path.join(self.OUTPUT_DIR, self.GRAPH_SUBDIR),
            'stats': os.path.join(self.OUTPUT_DIR, self.STATS_SUBDIR),
            'configs': os.path.join(self.OUTPUT_DIR, self.CONFIG_SUBDIR),
            'checkpoints': os.path.join(self.OUTPUT_DIR, self.CHECKPOINT_SUBDIR)
        }


# --- Dependency Check and Installation ---
def check_and_install_dependencies():
    """Check for required third-party packages and pip-install missing ones.

    Each package is probed by importing its module.  Missing packages are
    installed with ``pip`` through the running interpreter.  If installation
    fails, the list of packages is printed and the process exits with
    status 1.  After a successful installation the import caches are
    invalidated so the freshly installed modules can be imported by the
    current process.
    """
    import importlib  # local import: only needed by this bootstrap helper

    # pip distribution name -> importable module name
    dependencies = {
        'opencv-python': 'cv2',
        'scikit-image': 'skimage',
        'scikit-learn': 'sklearn'
    }

    missing = []
    for package, module in dependencies.items():
        try:
            __import__(module)
        except ImportError:
            missing.append(package)

    if not missing:
        return

    print(f"Installing missing dependencies: {', '.join(missing)}")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install"] + missing)
        print("Dependencies installed successfully.")
    except Exception as e:
        print(f"Error installing dependencies: {e}")
        print("Please install the following packages manually:")
        print(", ".join(missing))
        sys.exit(1)

    # Packages installed while the interpreter is running may not be visible
    # to the import system until its finder caches are refreshed.
    importlib.invalidate_caches()

# Import dependencies after checking
try:
    import cv2
    from skimage.exposure import match_histograms
    from sklearn.manifold import TSNE
    from sklearn.preprocessing import StandardScaler
except ImportError:
    check_and_install_dependencies()
    import cv2
    from skimage.exposure import match_histograms
    from sklearn.manifold import TSNE
    from sklearn.preprocessing import StandardScaler

try:
    from umap import UMAP
    HAS_UMAP = True
except ImportError:
    HAS_UMAP = False
    print("UMAP not available. t-SNE will be used for dimensionality reduction.")

# --- Logging Setup ---
def setup_logging(config):
    """Configure the "GenerationEngine" logger for console and file output.

    A timestamped log file is created under ``OUTPUT_DIR/LOG_SUBDIR``.  The
    file handler records DEBUG and above, the console handler INFO and above.
    Handlers left over from a previous call are removed and closed first.

    Args:
        config: Object providing ``OUTPUT_DIR`` and ``LOG_SUBDIR``.

    Returns:
        The configured ``logging.Logger``.
    """
    log_dir = os.path.join(config.OUTPUT_DIR, config.LOG_SUBDIR)
    os.makedirs(log_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = os.path.join(log_dir, f"generation_engine_{timestamp}.log")

    logger = logging.getLogger("GenerationEngine")
    # The logger itself must let DEBUG through; otherwise records are dropped
    # before they reach the DEBUG-level file handler.  The console handler
    # still filters at INFO.
    logger.setLevel(logging.DEBUG)

    # Remove and close handlers from earlier calls so repeated setup does not
    # duplicate output or leak open log files.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # File handler with detailed formatting
    file_handler = logging.FileHandler(log_filename)
    file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)  # File gets everything
    logger.addHandler(file_handler)

    # Console handler with minimal formatting
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)  # Console gets INFO and above
    logger.addHandler(console_handler)

    logger.info(f"Logging initialized. Log file: {log_filename}")
    return logger

# --- Memory Management ---
def get_memory_usage():
    """Collect a snapshot of RAM and GPU memory usage.

    Returns:
        dict with the keys ``ram_used_percent``, ``ram_available_gb``,
        ``gpu_used_percent``, ``gpu_used_gb`` and ``gpu_total_gb``.  A value
        that cannot be read keeps its initial value of 0.
    """
    bytes_per_gb = 1024 ** 3

    memory_stats = {
        'ram_used_percent': 0,
        'ram_available_gb': 0,
        'gpu_used_percent': 0,
        'gpu_used_gb': 0,
        'gpu_total_gb': 0
    }

    # RAM usage (best effort: a failure leaves the zeros in place)
    try:
        ram = psutil.virtual_memory()
        memory_stats['ram_used_percent'] = ram.percent
        memory_stats['ram_available_gb'] = ram.available / bytes_per_gb
    except Exception:
        pass

    # GPU usage of the current CUDA device, if any (best effort as well)
    if torch.cuda.is_available():
        try:
            current_device = torch.cuda.current_device()
            memory_stats['gpu_total_gb'] = (
                torch.cuda.get_device_properties(current_device).total_memory / bytes_per_gb
            )
            # memory_reserved() is everything held by the caching allocator and
            # already contains memory_allocated(); adding both would count the
            # tensor memory twice and inflate the usage percentage.
            memory_stats['gpu_used_gb'] = torch.cuda.memory_reserved(current_device) / bytes_per_gb
            memory_stats['gpu_used_percent'] = (
                memory_stats['gpu_used_gb'] / memory_stats['gpu_total_gb']
            ) * 100
        except Exception:
            pass

    return memory_stats

def memory_monitor(config, logger, stop_event):
    """Sample memory usage in a loop until ``stop_event`` is set.

    Intended to run in a background thread.  Roughly every 10 seconds a
    snapshot from ``get_memory_usage()`` is appended to an in-memory log;
    a warning is logged when RAM or GPU usage exceeds the configured
    thresholds.  The log is written to ``memory_usage.json`` in the stats
    directory every 30 samples and once more on exit, together with a
    ``memory_usage.png`` plot.

    Args:
        config: Object providing ``OUTPUT_DIR``, ``STATS_SUBDIR``,
            ``RAM_MEMORY_THRESHOLD`` and ``GPU_MEMORY_THRESHOLD`` (fractions,
            multiplied by 100 before comparison with the percentages).
        logger: Logger used for status, warnings and errors.
        stop_event: Event-like object with ``is_set()`` and ``wait(timeout)``.
    """
    logger.info("Starting memory monitoring thread")

    stats_dir = os.path.join(config.OUTPUT_DIR, config.STATS_SUBDIR)
    os.makedirs(stats_dir, exist_ok=True)
    log_path = os.path.join(stats_dir, "memory_usage.json")

    memory_log = []
    start_time = time.time()

    try:
        while not stop_event.is_set():
            timestamp = time.time() - start_time
            memory_stats = get_memory_usage()

            log_entry = {
                'timestamp': timestamp,
                **memory_stats
            }
            memory_log.append(log_entry)

            # Emit a debug line every 10th sample
            if len(memory_log) % 10 == 0:
                logger.debug(f"Memory usage - RAM: {memory_stats['ram_used_percent']:.1f}%, "
                           f"GPU: {memory_stats['gpu_used_percent']:.1f}% "
                           f"({memory_stats['gpu_used_gb']:.3f}/{memory_stats['gpu_total_gb']:.3f} GB)")

            # Warn when either usage exceeds its threshold
            if (memory_stats['ram_used_percent'] > config.RAM_MEMORY_THRESHOLD * 100 or
                memory_stats['gpu_used_percent'] > config.GPU_MEMORY_THRESHOLD * 100):
                logger.warning(f"HIGH MEMORY USAGE - RAM: {memory_stats['ram_used_percent']:.1f}%, "
                              f"GPU: {memory_stats['gpu_used_percent']:.1f}%")

            # Persist the log every 30th sample
            if len(memory_log) % 30 == 0:
                try:
                    with open(log_path, 'w') as f:
                        json.dump(memory_log, f, indent=2)
                except Exception as e:
                    logger.error(f"Error saving memory log: {e}")

            # Wait on the event instead of sleeping: the 10 second interval is
            # kept, but a stop request ends the wait immediately instead of
            # delaying shutdown by up to 10 seconds.
            stop_event.wait(10)

    except Exception as e:
        logger.error(f"Error in memory monitor thread: {e}")

    finally:
        # Save the final memory log and plot it
        try:
            with open(log_path, 'w') as f:
                json.dump(memory_log, f, indent=2)

            if memory_log:
                plt.figure(figsize=(12, 6))
                timestamps = [entry['timestamp'] / 60 for entry in memory_log]  # Convert to minutes

                plt.plot(timestamps, [entry['ram_used_percent'] for entry in memory_log],
                         label='RAM Usage %', color='blue')

                # The GPU curve is drawn only if the first sample saw a GPU
                if memory_log[0]['gpu_total_gb'] > 0:
                    plt.plot(timestamps, [entry['gpu_used_percent'] for entry in memory_log],
                             label='GPU Usage %', color='red')

                plt.xlabel('Time (minutes)')
                plt.ylabel('Usage (%)')
                plt.title('Memory Usage During Generation')
                plt.legend()
                plt.grid(True, alpha=0.3)
                plt.savefig(os.path.join(stats_dir, "memory_usage.png"))
                plt.close()
        except Exception as e:
            logger.error(f"Error finalizing memory log: {e}")

        logger.info("Memory monitoring thread stopped")

def check_memory_safe(config, logger):
    """Report whether RAM and GPU usage are both below their thresholds.

    The thresholds ``RAM_MEMORY_THRESHOLD`` and ``GPU_MEMORY_THRESHOLD`` are
    multiplied by 100 and compared with the usage percentages returned by
    ``get_memory_usage()``.  A warning is logged for each exceeded limit (the
    GPU warning only when CUDA is available).

    Returns:
        True when both usages are strictly below their limits.
    """
    usage = get_memory_usage()

    ram_limit = config.RAM_MEMORY_THRESHOLD * 100
    gpu_limit = config.GPU_MEMORY_THRESHOLD * 100

    ram_exceeded = not (usage['ram_used_percent'] < ram_limit)
    gpu_exceeded = not (usage['gpu_used_percent'] < gpu_limit)

    if ram_exceeded:
        logger.warning(f"RAM usage critical: {usage['ram_used_percent']:.1f}% > "
                       f"{ram_limit}% threshold")

    if gpu_exceeded and torch.cuda.is_available():
        logger.warning(f"GPU memory usage critical: {usage['gpu_used_percent']:.1f}% > "
                       f"{gpu_limit}% threshold")

    return not ram_exceeded and not gpu_exceeded

#def force_memory_cleanup():
#    """Force aggressive memory cleanup."""
#    gc.collect()
#    if torch.cuda.is_available():
#        torch.cuda.empty_cache()
        # Extra measures for more aggressive cleanup
#        if hasattr(torch.cuda, 'memory_summary'):
#            torch.cuda.synchronize()

# --- Create Output Directories ---
def create_output_directories(config):
    """Ensure every directory listed by ``config.get_all_output_dirs()`` exists.

    Returns:
        The name -> path dictionary obtained from the configuration.
    """
    output_dirs = config.get_all_output_dirs()

    for directory in output_dirs.values():
        # exist_ok makes repeated calls harmless
        os.makedirs(directory, exist_ok=True)

    return output_dirs

# --- Model Definitions ---
class Generator(nn.Module):
    """Generator network for WGAN-SN (DCGAN-style transposed-conv stack).

    Maps noise of shape N x noise_dim x 1 x 1 to images of shape
    N x channels_img x 128 x 128 with values in [-1, 1].
    """

    def __init__(self, noise_dim, channels_img, features_g):
        super().__init__()

        # Channel widths after each upsampling block:
        # noise -> 16f (4x4) -> 8f (8x8) -> 4f (16x16) -> 2f (32x32) -> f (64x64)
        widths = [noise_dim] + [features_g * factor for factor in (16, 8, 4, 2, 1)]

        layers = []
        for index, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The first block expands 1x1 to 4x4; all later blocks double the size.
            stride, padding = (1, 0) if index == 0 else (2, 1)
            layers.append(self._block(c_in, c_out, 4, stride, padding))

        # Final upsampling to 128x128 followed by Tanh for the [-1, 1] range.
        layers.append(
            nn.ConvTranspose2d(features_g, channels_img, kernel_size=4, stride=2, padding=1)
        )
        layers.append(nn.Tanh())

        # Layer order and nesting define the state-dict keys (net.<i>...).
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Create one generator block: ConvTranspose2d (no bias) + BatchNorm + ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        normalize = nn.BatchNorm2d(out_channels)
        activate = nn.ReLU(True)
        return nn.Sequential(upsample, normalize, activate)

    def forward(self, x):
        """Run the noise batch through the network and return the image batch."""
        return self.net(x)

class Critic(nn.Module):
    """WGAN-SN critic network with spectral normalization on every conv layer.

    Five stride-2 convolutions reduce a 128x128 input to 4x4, and a final
    4x4 convolution produces one score per image (N x 1 x 1 x 1).
    """

    def __init__(self, channels_img, features_c):
        super().__init__()

        # Channel widths: image -> f (64x64) -> 2f (32x32) -> 4f (16x16)
        #                 -> 8f (8x8) -> 16f (4x4)
        widths = [channels_img] + [features_c * factor for factor in (1, 2, 4, 8, 16)]

        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            # nn.utils.spectral_norm stores weight_orig / weight_u / weight_v,
            # the parameter names the original comments say the saved model uses.
            layers.append(
                nn.utils.spectral_norm(
                    nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=True)
                )
            )
            layers.append(nn.LeakyReLU(0.2))

        # Output layer: 4x4 -> 1x1 score, also spectrally normalized
        layers.append(
            nn.utils.spectral_norm(
                nn.Conv2d(features_c * 16, 1, kernel_size=4, stride=1, padding=0, bias=True)
            )
        )

        # Conv layers sit at even indices 0..10 of the Sequential.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the critic score tensor for the image batch ``x``."""
        return self.net(x)

# --- Model Loading ---
def load_models(config, logger):
    """Build the generator (and optionally the critic) and load checkpoint weights.

    The critic is only created when ``config.USE_QUALITY_FILTERING`` is set.
    If its weights are missing from the checkpoint or fail to load, the
    critic is dropped and ``config.USE_QUALITY_FILTERING`` is switched off.
    Both models are put into evaluation mode and moved to the CPU or to
    ``config.DEVICE`` according to the ``LOAD_*_TO_CPU`` flags.

    Returns:
        Tuple ``(generator, critic_or_None, checkpoint)``.

    Raises:
        KeyError: If the checkpoint has no recognized generator state dict.
        Exception: Any other loading error is logged and re-raised.
    """
    logger.info(f"Loading models from checkpoint: {config.CHECKPOINT_PATH}")

    try:
        generator = Generator(
            noise_dim=config.NOISE_DIM,
            channels_img=config.CHANNELS_IMG,
            features_g=config.G_FEATURES
        )

        if config.USE_QUALITY_FILTERING:
            critic = Critic(
                channels_img=config.CHANNELS_IMG,
                features_c=config.C_FEATURES
            )
        else:
            critic = None

        # Target devices follow the memory-management flags
        cpu = torch.device('cpu')
        gen_device = cpu if config.LOAD_GENERATOR_TO_CPU else config.DEVICE
        critic_device = cpu if config.LOAD_CRITIC_TO_CPU else config.DEVICE

        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        logger.info(f"Loading checkpoint (with weights_only=False)...")
        checkpoint = torch.load(config.CHECKPOINT_PATH, map_location='cpu', weights_only=False)

        # Generator weights: the first known key present in the checkpoint wins
        for state_key in ('generator_state_dict', 'model_state_dict'):
            if state_key in checkpoint:
                generator.load_state_dict(checkpoint[state_key])
                logger.info(f"Generator loaded from {state_key}")
                break
        else:
            raise KeyError("No recognized generator state dict found in checkpoint")

        # Critic weights (only when filtering was requested)
        if config.USE_QUALITY_FILTERING and critic is not None:
            critic_ready = False
            try:
                if 'critic_state_dict' in checkpoint:
                    critic.load_state_dict(checkpoint['critic_state_dict'])
                    logger.info("Critic loaded from critic_state_dict")
                    critic_ready = True
                else:
                    logger.warning("No critic_state_dict found in checkpoint, quality filtering will be disabled")
            except Exception as e:
                logger.warning(f"Failed to load critic: {e}")
                logger.warning("Quality filtering will be disabled")

            if not critic_ready:
                # Without a usable critic the filtering step is switched off
                config.USE_QUALITY_FILTERING = False
                critic = None

        # Place the models and switch them to evaluation mode
        generator = generator.to(gen_device)
        if critic is not None:
            critic = critic.to(critic_device)
            critic.eval()
        generator.eval()

        critic_location = 'N/A' if critic is None else critic_device
        filtering_state = 'enabled' if config.USE_QUALITY_FILTERING else 'disabled'
        logger.info(f"Generator on {gen_device}, Critic on {critic_location}")
        logger.info(f"Quality filtering is {filtering_state}")

        return generator, critic, checkpoint

    except Exception as e:
        logger.error(f"Error loading models: {e}")
        logger.error(traceback.format_exc())
        raise

# --- JSON helper: convert NumPy types to native Python types ---
def convert_numpy_types(obj):
    """Recursively replace NumPy values with native Python ones for JSON output.

    Dictionaries, lists and tuples are traversed (keeping their container
    type), arrays become nested lists, and NumPy booleans, integers and
    floats become ``bool``, ``int`` and ``float``.  Anything else is
    returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        A structure of the same shape containing only native Python scalars
        in place of NumPy ones.
    """
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    if isinstance(obj, np.ndarray):
        return convert_numpy_types(obj.tolist())
    # np.bool_ is not a subclass of np.integer and json cannot encode it
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    return obj

# NOTE: calculate_real_data_stats (below) passes its result through convert_numpy_types before saving it as JSON.

# --- Real Data Statistics ---
def _summarize_dimension(values):
    """Return mean, std, min, max and quartiles of a list of pixel dimensions."""
    return {
        "mean": np.mean(values),
        "std": np.std(values),
        "min": np.min(values),
        "max": np.max(values),
        "percentiles": {
            "25": np.percentile(values, 25),
            "50": np.percentile(values, 50),
            "75": np.percentile(values, 75)
        }
    }


def calculate_real_data_stats(config, logger):
    """Calculate size, brightness, contrast and histogram statistics of real images.

    When ``config.USE_STATS_CACHE`` is set and a cached
    ``real_data_stats.json`` exists in the stats directory, the cached
    content is returned.  Otherwise up to 20000 randomly sampled images from
    ``config.get_real_data_path()`` are converted to grayscale and analyzed,
    and the result is written to the cache file.

    Args:
        config: Configuration object (``OUTPUT_DIR``, ``STATS_SUBDIR``,
            ``USE_STATS_CACHE``, ``get_real_data_path()``).
        logger: Logger for progress, warnings and errors.

    Returns:
        The statistics dictionary, or None when no image could be found or
        processed.
    """
    logger.info("Calculating statistics from real pollen images...")

    stats_cache_path = os.path.join(config.OUTPUT_DIR, config.STATS_SUBDIR, "real_data_stats.json")

    # Try to load cached stats if allowed
    if config.USE_STATS_CACHE and os.path.exists(stats_cache_path):
        try:
            logger.info(f"Loading cached statistics from {stats_cache_path}")
            with open(stats_cache_path, 'r') as f:
                stats = json.load(f)
            return stats
        except Exception as e:
            logger.warning(f"Error loading cached statistics: {e}. Recalculating...")

    # Get path to real data
    real_data_path = config.get_real_data_path()
    logger.info(f"Using real data from: {real_data_path}")

    # Find all image files (top level of the directory only)
    image_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff']
    image_paths = []
    for ext in image_extensions:
        image_paths.extend(list(Path(real_data_path).glob(f"*{ext}")))

    if not image_paths:
        logger.error(f"No images found in {real_data_path}")
        return None

    # Sample a subset of images for faster processing
    max_images = min(len(image_paths), 20000) # default : 5000 ; alt : 20000 ; new : 60000 ;
    sampled_paths = random.sample(image_paths, max_images)

    # Per-image measurements
    widths = []
    heights = []
    brightness_values = []
    contrast_values = []
    histograms = []

    logger.info(f"Processing {len(sampled_paths)} real images for statistics...")
    for path in tqdm(sampled_paths, desc="Analyzing real images"):
        try:
            # The context manager closes the underlying file even when the
            # image fails to decode, so thousands of images do not pile up
            # open file handles.
            with Image.open(path) as opened:
                gray = opened if opened.mode == 'L' else opened.convert('L')
                width, height = gray.size
                # Pixel data must be read while the file is still open
                img_np = np.array(gray)

            widths.append(width)
            heights.append(height)

            # Brightness = mean pixel value, contrast = standard deviation
            brightness_values.append(np.mean(img_np))
            contrast_values.append(np.std(img_np))

            # Normalized 256-bin intensity histogram
            hist, _ = np.histogram(img_np.flatten(), bins=256, range=(0, 255), density=True)
            histograms.append(hist)

        except Exception as e:
            logger.warning(f"Error processing image {path}: {e}")

    if not widths:
        logger.error("No valid images processed for statistics")
        return None

    stats = {
        "size_statistics": {
            "width": _summarize_dimension(widths),
            "height": _summarize_dimension(heights)
        },
        "pixel_statistics": {
            "brightness": {
                "mean": np.mean(brightness_values),
                "std": np.std(brightness_values)
            },
            "contrast": {
                "mean": np.mean(contrast_values),
                "std": np.std(contrast_values)
            }
        },
        "histogram": {
            "mean": np.mean(histograms, axis=0).tolist(),
            "std": np.std(histograms, axis=0).tolist()
        },
        "metadata": {
            "num_images_analyzed": len(widths),
            "data_path": real_data_path,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }
    }

    # Convert NumPy types to Python native types for JSON serialization
    stats = convert_numpy_types(stats)

    # Save stats to cache file
    os.makedirs(os.path.dirname(stats_cache_path), exist_ok=True)
    with open(stats_cache_path, 'w') as f:
        json.dump(stats, f, indent=4)

    logger.info(f"Statistics calculated and saved to {stats_cache_path}")
    return stats

# --- Load Layout Statistics ---
def load_layout_statistics(config, logger):
    """Read ``layout_statistics.json`` from ``config.STATISTICS_DIR``.

    Returns:
        The parsed JSON content, or None when the file is missing or any
        error occurs while reading it (the error is logged).
    """
    try:
        json_path = os.path.join(config.STATISTICS_DIR, "layout_statistics.json")
        logger.info(f"Loading layout statistics from: {json_path}")

        if os.path.exists(json_path):
            with open(json_path, 'r') as handle:
                layout_stats = json.load(handle)
            logger.info("Layout statistics loaded successfully")
            return layout_stats

        # Missing file: callers receive None and fall back to defaults
        logger.warning(f"Statistics file not found: {json_path}")
        logger.info("Using default layout parameters based on provided averages")
        return None
    except Exception as e:
        logger.error(f"Error loading layout statistics: {e}")
        logger.error(traceback.format_exc())
        return None

# --- Load Background Images ---
def load_background_images(config, logger):
    """Collect the paths of all background images in ``config.BACKGROUND_DIR``.

    Only files directly inside the directory with one of the known image
    extensions are returned; the images themselves are not opened.

    Returns:
        List of ``Path`` objects, grouped by extension in the search order.

    Raises:
        FileNotFoundError: If no matching file exists.  Any error is logged
            with its traceback before being re-raised.
    """
    suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    try:
        logger.info(f"Loading background images from: {config.BACKGROUND_DIR}")

        background_paths = [
            candidate
            for suffix in suffixes
            for candidate in Path(config.BACKGROUND_DIR).glob(f"*{suffix}")
        ]

        if len(background_paths) == 0:
            raise FileNotFoundError(f"No background images found in {config.BACKGROUND_DIR}")

        logger.info(f"Found {len(background_paths)} background images")
        return background_paths
    except Exception as e:
        logger.error(f"Error loading background images: {e}")
        logger.error(traceback.format_exc())
        raise

# --- Image Generation and Transformation ---
@torch.no_grad()
def generate_pollen_batch(generator, config, batch_size, device=None, noise_dim=None):
    """Generate a batch of single-channel images with the generator.

    Args:
        generator: Generator module; its output is expected in [-1, 1] with
            one channel per image.
        config: Configuration object; ``DEVICE``, ``NOISE_DIM`` and
            ``LOAD_GENERATOR_TO_CPU`` are read when present.
        batch_size: Number of images to generate.
        device: Requested compute device (``torch.device`` or string).
            Defaults to ``config.DEVICE`` or CUDA when available.
        noise_dim: Length of the noise vector.  Defaults to
            ``config.NOISE_DIM`` or 100.

    Returns:
        List of ``batch_size`` PIL images in mode 'L'.
    """
    if device is None:
        device = getattr(config, 'DEVICE', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    # Accept device strings such as "cuda" as well as torch.device objects
    device = torch.device(device)

    if noise_dim is None:
        noise_dim = getattr(config, 'NOISE_DIM', 100)  # Default to 100 if not specified

    # A generator kept on the CPU to save GPU memory is moved to the GPU only
    # for the duration of this call.
    gen_device = next(generator.parameters()).device
    temp_device_change = (
        gen_device.type == 'cpu'
        and device.type == 'cuda'
        and getattr(config, 'LOAD_GENERATOR_TO_CPU', False)
    )
    if temp_device_change:
        generator = generator.to(device)
        gen_device = device

    try:
        # Sample the noise on the device the generator actually lives on, so a
        # mismatch between `device` and the generator placement cannot make
        # the forward pass fail.
        noise = torch.randn(batch_size, noise_dim, 1, 1, device=gen_device)
        fake_batch = generator(noise)

        # One device-to-host copy for the whole batch, then rescale
        # from [-1, 1] to [0, 255]
        batch_np = fake_batch.cpu().numpy()
        batch_uint8 = ((batch_np * 0.5 + 0.5) * 255).astype(np.uint8)

        # Drop the channel dimension of every image and wrap it as PIL 'L'
        return [Image.fromarray(img_np.squeeze(0), mode='L') for img_np in batch_uint8]
    finally:
        # Move generator back to CPU if it was temporarily moved
        if temp_device_change:
            generator = generator.to(torch.device('cpu'))
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

@torch.no_grad()
def score_pollen_batch(critic, images, config, device=None):
    """
    Score a batch of PIL images using the critic model, processing in smaller chunks.
    """
    if critic is None or not getattr(config, 'USE_QUALITY_FILTERING', False):
        logger.warning("Critic scoring skipped (critic unavailable or filtering disabled). Returning NaNs.")
        return np.full(len(images), np.nan)

    logger = logging.getLogger("GenerationEngine") # Get logger
    if device is None:
        # Get device from config, ensure it's a torch.device object
        device_setting = getattr(config, 'DEVICE', "cuda" if torch.cuda.is_available() else "cpu")
        if isinstance(device_setting, str):
             device = torch.device(device_setting)
        elif isinstance(device_setting, torch.device):
             device = device_setting
        else: # Fallback if type is unexpected
             logger.warning(f"Unexpected device type in config: {type(device_setting)}. Falling back.")
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str): # Handle if device is passed as string argument
        device = torch.device(device)
    # Ensure device is a torch.device object now
    if not isinstance(device, torch.device):
         logger.error(f"Device is not a torch.device object after processing: {device}. Defaulting to CPU.")
         device = torch.device("cpu")


    # Get scoring batch size from config, default if not present
    scoring_batch_size = getattr(config, 'SCORING_BATCH_SIZE', 128)
    # Removed redundant logging of batch size here, happens inside loop now if needed

    # If critic is on CPU but we need GPU, move temporarily
    critic_device = next(critic.parameters()).device
    temp_device_change = False
    original_critic_device = critic_device # Store original device

    # Check if critic needs moving (compare device types)
    if critic_device.type == 'cpu' and device.type == 'cuda' and getattr(config, 'LOAD_CRITIC_TO_CPU', False):
        logger.debug(f"Temporarily moving critic to {device} for scoring.")
        critic.to(device)
        temp_device_change = True
        critic_device = device # Update current device

    all_scores = []
    try:
        num_images = len(images)
        # Add leave=False to the tqdm call here
        #for i in tqdm(range(0, num_images, scoring_batch_size), desc="Scoring images", leave=False): # <--- MODIFIED HERE
        for i in range(0, num_images, scoring_batch_size): # <--- REPLACEMENT LINE (tqdm wrapper removed)
            batch_images = images[i : min(i + scoring_batch_size, num_images)]
            if not batch_images: continue

            batch_tensors = []
            for img in batch_images:
                img_np = np.array(img, dtype=np.float32) / 255.0
                img_np = (img_np - 0.5) / 0.5
                # Add channel dim if it's grayscale
                if len(img_np.shape) == 2:
                    img_np = np.expand_dims(img_np, axis=0) # Add channel dim -> (1, H, W)

                # Handle resizing if needed (shouldn't be needed for raw generated)
                h, w = img_np.shape[1], img_np.shape[2]
                base_h, base_w = config.POLLEN_SIZE_BASE, config.POLLEN_SIZE_BASE
                if h != base_h or w != base_w:
                    # Need to convert back to PIL temporarily for resize, then back to numpy/tensor
                    # This is inefficient, avoid passing transformed images if possible
                    img_pil_tmp = Image.fromarray(((img_np * 0.5 + 0.5) * 255).astype(np.uint8).squeeze()).resize((base_w, base_h), Image.BILINEAR)
                    img_np = (np.array(img_pil_tmp, dtype=np.float32) / 255.0 - 0.5) / 0.5
                    if len(img_np.shape) == 2: img_np = np.expand_dims(img_np, axis=0) # Re-add channel if lost

                img_tensor = torch.from_numpy(img_np)
                batch_tensors.append(img_tensor)


            if not batch_tensors: continue

            try:
                # Stack tensors for the CURRENT BATCH ONLY
                batch_tensor = torch.stack(batch_tensors).to(device) # Move batch to device
            except RuntimeError as stack_err:
                 logger.error(f"Error stacking tensors during scoring: {stack_err}")
                 logger.error(f"Shapes in batch: {[t.shape for t in batch_tensors]}")
                 all_scores.append(np.full(len(batch_images), np.nan)) # Append NaNs for this failed batch
                 continue # Skip this batch

            # Get critic scores for the current batch
            try:
                scores = critic(batch_tensor).squeeze().cpu().numpy()
                if scores.ndim == 0: scores = np.array([scores.item()])
                all_scores.append(scores)
            except torch.OutOfMemoryError as oom_err:
                 logger.error(f"OOM Error during critic forward pass with batch size {len(batch_tensor)}. Try reducing SCORING_BATCH_SIZE.")
                 logger.error(f"OOM Details: {oom_err}")
                 raise oom_err # Re-raise OOM
            except Exception as forward_err:
                 logger.error(f"Error during critic forward pass: {forward_err}", exc_info=True)
                 all_scores.append(np.full(len(batch_tensor), np.nan))

            # Cleanup GPU memory for the batch
            del batch_tensor, batch_tensors, scores
            if device.type == 'cuda': # <-- Now device is guaranteed to be torch.device
                 force_memory_cleanup(config)

    except Exception as e:
        logger.error(f"Error occurred during batched scoring: {e}", exc_info=True)
        return np.full(len(images), np.nan)
    finally:
        # Move critic back to original device if it was temporarily moved
        if temp_device_change:
            logger.debug(f"Moving critic back to {original_critic_device}.")
            critic.to(original_critic_device)
            if device.type == 'cuda': force_memory_cleanup(config) # <-- Check works now

    # Concatenate scores from all batches
    if not all_scores:
        logger.warning("No scores were generated.")
        return np.array([])

    try:
        final_scores = np.concatenate(all_scores)
        if len(final_scores) != len(images):
             logger.warning(f"Number of scores ({len(final_scores)}) does not match number of images ({len(images)}). Padding with NaN.")
             padded_scores = np.full(len(images), np.nan)
             length_to_copy = min(len(final_scores), len(images))
             padded_scores[:length_to_copy] = final_scores[:length_to_copy]
             return padded_scores
        return final_scores
    except ValueError as concat_err:
         logger.error(f"Error concatenating scores: {concat_err}")
         # logger.error(f"Score list content (lengths): {[len(s) if isinstance(s, np.ndarray) else 'NaN_batch' for s in all_scores]}")
         return np.full(len(images), np.nan)

def filter_pollen_batch(scores, images, num_required, config):
    """Select the best-scoring pollen images according to the critic scores.

    Args:
        scores: Critic scores, one per image (array, list or tuple). NaN
            entries are ranked below every valid score so they are only
            selected when there are not enough valid candidates.
        images: Sequence of images aligned index-by-index with ``scores``.
        num_required: Number of images to return.
        config: Object read via ``getattr`` for ``USE_QUALITY_FILTERING``
            (default False) and ``QUALITY_THRESHOLD_PERCENTILE`` (default 80.0).

    Returns:
        ``images[:num_required]`` when filtering is disabled or no score is
        usable; ``images`` unchanged when fewer than ``num_required`` are
        available; otherwise a list of ``num_required`` images ordered from
        highest to lowest score.
    """
    if not getattr(config, 'USE_QUALITY_FILTERING', False):
        # Return the first num_required images if filtering is disabled
        return images[:num_required]

    # Not enough images to choose from: hand back everything we have
    if len(images) < num_required:
        return images

    # Normalise the scores to a float array aligned with `images`. Missing
    # trailing scores and NaN scores are mapped to -inf so that they sort
    # last instead of first (argsort places NaN at the end, and the reversed
    # order used below would otherwise prefer them).
    raw_scores = np.asarray(scores, dtype=float).ravel()
    ranked = np.full(len(images), -np.inf)
    usable = min(len(images), raw_scores.size)
    ranked[:usable] = raw_scores[:usable]
    ranked[np.isnan(ranked)] = -np.inf

    valid_scores = ranked[ranked > -np.inf]
    if valid_scores.size == 0:
        # Nothing to rank by; behave as if filtering were disabled
        return images[:num_required]

    # Get percentile threshold or default to 80%
    threshold_percentile = getattr(config, 'QUALITY_THRESHOLD_PERCENTILE', 80.0)

    # Threshold is computed on valid scores only so NaN cannot poison it
    threshold = np.percentile(valid_scores, 100 - threshold_percentile)

    # Indices of images at or above the threshold
    quality_indices = np.where(ranked >= threshold)[0]

    if len(quality_indices) >= num_required:
        # Enough quality images: take the best num_required of them
        order = np.argsort(ranked[quality_indices])[::-1]
        selected_indices = quality_indices[order][:num_required]
    else:
        # Not enough quality images: take the overall top num_required
        selected_indices = np.argsort(ranked)[::-1][:num_required]

    return [images[i] for i in selected_indices]

def apply_geometric_transforms(img_pil, config, real_stats=None):
    """Apply a random right-angle rotation and optional anisotropic scaling.

    Args:
        img_pil: Source PIL image; it is copied, never modified in place.
        config: Object providing ``USE_RANDOM_SCALES``,
            ``USE_STATS_MATCHING_SIZE`` and ``SCALE_RANGE`` (a ``(min, max)``
            pair of scale factors).
        real_stats: Optional dict; when it contains ``'size_statistics'`` with
            ``width``/``height`` entries holding ``mean`` and ``std``, the
            target size is sampled from those statistics instead of
            ``SCALE_RANGE``.

    Returns:
        The transformed PIL image. Width and height are scaled independently,
        so the result may be rectangular.
    """
    img = img_pil.copy()

    # --- 1. Apply rotation ---
    # Only use 0, 90, 180, 270 degree rotations
    rotation = random.choice([0, 90, 180, 270])
    if rotation > 0:
        # expand=False keeps the original canvas size
        img = img.rotate(rotation, resample=Image.BICUBIC, expand=False)

    # --- 2. Apply scaling with different aspect ratios ---
    if config.USE_RANDOM_SCALES:
        if config.USE_STATS_MATCHING_SIZE and real_stats and 'size_statistics' in real_stats:
            # Use real statistics to guide scaling
            size_stats = real_stats['size_statistics']

            mean_width = size_stats['width']['mean']
            std_width = size_stats['width']['std']
            mean_height = size_stats['height']['mean']
            std_height = size_stats['height']['std']

            # Sample width and height independently (half the measured std,
            # floored at 10 px) to get rectangular images
            target_width = max(10, np.random.normal(mean_width, std_width/2))
            target_height = max(10, np.random.normal(mean_height, std_height/2))

            # Separate scale factors for width and height
            scale_x = target_width / img.width
            scale_y = target_height / img.height
        else:
            # Use uniform random scaling with independent x/y scales
            scale_x = random.uniform(config.SCALE_RANGE[0], config.SCALE_RANGE[1])
            scale_y = random.uniform(config.SCALE_RANGE[0], config.SCALE_RANGE[1])

        # Apply scaling (if different from original)
        if scale_x != 1.0 or scale_y != 1.0:
            # int() truncates, so a small scale could yield a 0-pixel side,
            # which resize() rejects; never go below one pixel.
            new_width = max(1, int(img.width * scale_x))
            new_height = max(1, int(img.height * scale_y))
            img = img.resize((new_width, new_height), Image.LANCZOS)

    return img

# TODO: keep the Config class's rotation settings in sync with the fixed 0/90/180/270-degree rotations used in apply_geometric_transforms.

def apply_histogram_matching(source_img, reference_img, config):
    """Match the histogram of ``source_img`` to that of ``reference_img``.

    Args:
        source_img: PIL image to adjust.
        reference_img: PIL image providing the target histogram.
        config: Object providing the boolean ``USE_STATS_MATCHING_HISTOGRAM``.

    Returns:
        A new PIL image with the matched histogram, or ``source_img`` itself
        when matching is disabled or fails (best-effort: failures are logged,
        never raised).
    """
    if not config.USE_STATS_MATCHING_HISTOGRAM:
        return source_img

    try:
        # Convert PIL images to numpy arrays
        source_np = np.array(source_img)
        reference_np = np.array(reference_img)

        # Apply histogram matching
        matched_np = match_histograms(source_np, reference_np)

        # Convert back to PIL
        return Image.fromarray(matched_np.astype(np.uint8))
    except Exception as e:
        # Keep the pipeline running with the unmatched image, but leave a
        # trace so a systematically failing matcher does not go unnoticed.
        logging.getLogger("GenerationEngine").warning(
            f"Histogram matching failed ({e}); using the unmatched image."
        )
        return source_img

# --- Blending Functions ---
def create_binary_mask(img):
    """Create a binary mode-'L' mask (0 or 255) from an image.

    A pixel is foreground (255) when its intensity is greater than 10. For a
    multi-channel input the brightest of the first three channels is used, so
    an RGB image built from a replicated grayscale image yields the same mask
    as the grayscale image itself.

    Args:
        img: PIL image (or array-like) to threshold.

    Returns:
        A PIL image in mode 'L' with the same width and height as ``img``.
    """
    img_np = np.array(img)
    if img_np.ndim == 3:
        # A mode-'L' image must be 2-D; collapse the channel axis first
        img_np = img_np[..., :3].max(axis=2)
    # Threshold to create binary mask (pixels above 10 become 255)
    mask_np = (img_np > 10).astype(np.uint8) * 255
    return Image.fromarray(mask_np, mode='L')

def create_content_aware_mask(img, feather_amount=10):
    """Build a soft mode-'L' mask whose edges fade out over a feathered band.

    Args:
        img: Source image; a falsy value yields ``None``.
        feather_amount: Size (in pixels) of the dilation kernel used to widen
            the edge band; the Gaussian blur kernel is ``2 * feather_amount + 1``.

    Returns:
        A PIL image in mode 'L', or ``None`` when ``img`` is falsy.
    """
    if not img:
        return None

    # Hard foreground/background split: intensities above 10 become 255
    pixels = np.array(img)
    hard_mask = cv2.threshold(pixels, 10, 255, cv2.THRESH_BINARY)[1]

    # Locate the mask outline and widen it into a band around the border
    outline = cv2.Canny(hard_mask, 100, 200)
    band_kernel = np.ones((feather_amount, feather_amount), np.uint8)
    border_band = cv2.dilate(outline, band_kernel, iterations=1)

    # Distance of each foreground pixel to the background, rescaled in place
    # to the 0-255 range
    depth = cv2.distanceTransform(hard_mask, cv2.DIST_L2, 3)
    cv2.normalize(depth, depth, 0, 255, cv2.NORM_MINMAX)

    # Inside the border band, replace the hard value with the scaled distance
    in_band = border_band > 0
    feathered = hard_mask.copy()
    feathered[in_band] = depth[in_band].astype(np.uint8)

    # Final smoothing pass for a softer transition
    blur_size = feather_amount * 2 + 1
    feathered = cv2.GaussianBlur(feathered, (blur_size, blur_size), 0)

    return Image.fromarray(feathered, mode='L')

def convert_to_3channel(img):
    """Return an RGB copy of a mode-'L' image; any other mode is returned as-is."""
    if img.mode != 'L':
        return img
    # Replicate the single gray band into the R, G and B bands
    return Image.merge('RGB', (img,) * 3)

def apply_alpha_blending(background, foreground, position, mask):
    """Paste ``foreground`` onto a copy of ``background`` through ``mask``.

    Grayscale inputs are promoted to RGB first. The background passed in is
    left untouched; the composited copy is returned.
    """
    # Work on a copy so the caller's background is never modified
    canvas = convert_to_3channel(background).copy()
    canvas.paste(convert_to_3channel(foreground), position, mask)
    return canvas

def apply_poisson_blending(background, foreground, position, mask=None):
    """Apply seamless cloning (Poisson blending) using OpenCV.

    Args:
        background: PIL image to blend onto (grayscale or RGB).
        foreground: PIL image to insert (grayscale or RGB).
        position: ``(x, y)`` top-left corner of the foreground in the background.
        mask: Optional mask image; when ``None`` a binary mask is derived from
            the foreground.

    Returns:
        The blended PIL image. If the mask is empty or seamless cloning fails,
        the result of ``apply_alpha_blending`` with the caller's ``mask`` is
        returned instead.
    """
    logger = logging.getLogger("GenerationEngine")

    # Create mask if not provided. The foreground is reduced to grayscale
    # first so the binary mask is always single-channel.
    if mask is None:
        mask_cv = np.array(create_binary_mask(foreground.convert('L')))
    else:
        mask_cv = np.array(mask)

    # No valid mask points: fall back to alpha blending
    if mask_cv.size == 0 or mask_cv.max() == 0:
        return apply_alpha_blending(background, foreground, position, mask)

    # seamlessClone positions the patch by its centre in the destination
    center = (position[0] + foreground.width // 2, position[1] + foreground.height // 2)

    try:
        # cvtColor(RGB2BGR) rejects single-channel arrays, so grayscale
        # images are promoted to RGB first (as the other blenders do).
        bg_cv = cv2.cvtColor(np.array(convert_to_3channel(background)), cv2.COLOR_RGB2BGR)
        fg_cv = cv2.cvtColor(np.array(convert_to_3channel(foreground)), cv2.COLOR_RGB2BGR)

        result_cv = cv2.seamlessClone(
            fg_cv, bg_cv, mask_cv, center, cv2.MIXED_CLONE
        ) # default : "cv2.MIXED_CLONE' (causes issues upon use after histogram matching) ; alt : cv2.NORMAL_CLONE (causes issues even without histogram matching)

        # Convert back to PIL
        return Image.fromarray(cv2.cvtColor(result_cv, cv2.COLOR_BGR2RGB))
    except Exception as e:
        # Best-effort: report the failure, then fall back to alpha blending
        logger.warning(f"Poisson blending failed ({e}); falling back to alpha blending.")
        return apply_alpha_blending(background, foreground, position, mask)

def _laplacian_pyramid(gaussian_pyramid):
    """Turn a Gaussian pyramid (finest first) into a Laplacian pyramid (finest first)."""
    laplacian = []
    for fine, coarse in zip(gaussian_pyramid[:-1], gaussian_pyramid[1:]):
        # dstsize forces the upsampled level to the exact size of the finer
        # level, which matters when a dimension is odd.
        upsampled = cv2.pyrUp(coarse, dstsize=(fine.shape[1], fine.shape[0]))
        laplacian.append(fine - upsampled)
    # The coarsest Gaussian level is kept as the base of the pyramid
    laplacian.append(gaussian_pyramid[-1])
    return laplacian


def apply_pyramid_blending(background, foreground, position, mask, levels=4):
    """Apply Laplacian pyramid blending of ``foreground`` onto ``background``.

    Args:
        background: PIL image to blend onto (grayscale is promoted to RGB).
        foreground: PIL image to insert (grayscale is promoted to RGB).
        position: ``(x, y)`` top-left corner of the foreground in the background.
        mask: Mask image with values 0-255; 255 selects the foreground. It is
            resized to the foreground size when the sizes differ.
        levels: Number of pyramid reductions (0 degenerates to a plain
            mask-weighted average).

    Returns:
        The blended RGB PIL image. When the foreground does not fit inside the
        background, or blending raises, the result of ``apply_alpha_blending``
        is returned instead.
    """
    logger = logging.getLogger("GenerationEngine")

    # Convert PIL to numpy
    bg_np = np.array(convert_to_3channel(background))
    fg_np = np.array(convert_to_3channel(foreground))
    mask_np = np.array(mask)

    x, y = position
    h, w = fg_np.shape[:2]

    # The region of interest must lie completely inside the background
    if x < 0 or y < 0 or x + w > bg_np.shape[1] or y + h > bg_np.shape[0]:
        return apply_alpha_blending(background, foreground, position, mask)

    try:
        # Work in float32: Laplacian levels are signed, and uint8 arithmetic
        # would wrap around.
        roi = np.ascontiguousarray(bg_np[y:y + h, x:x + w, :3], dtype=np.float32)
        fg = np.ascontiguousarray(fg_np[:, :, :3], dtype=np.float32)

        if mask_np.shape[:2] != (h, w):
            mask_np = cv2.resize(mask_np, (w, h))

        # Normalize mask to [0, 1] and give it one weight plane per channel
        weights = mask_np.astype(np.float32) / 255.0
        if weights.ndim == 2:
            weights = np.stack([weights] * 3, axis=2)
        else:
            weights = np.ascontiguousarray(weights[:, :, :3])

        # Gaussian pyramids, finest level first
        bg_gauss, fg_gauss, weight_gauss = [roi], [fg], [weights]
        for _ in range(max(0, int(levels))):
            bg_gauss.append(cv2.pyrDown(bg_gauss[-1]))
            fg_gauss.append(cv2.pyrDown(fg_gauss[-1]))
            weight_gauss.append(cv2.pyrDown(weight_gauss[-1]))

        bg_lap = _laplacian_pyramid(bg_gauss)
        fg_lap = _laplacian_pyramid(fg_gauss)

        # Blend each level with the mask of the same resolution
        blended_levels = [
            m * f + (1.0 - m) * b
            for m, f, b in zip(weight_gauss, fg_lap, bg_lap)
        ]

        # Collapse the pyramid from coarsest to finest
        blended = blended_levels[-1]
        for level in reversed(blended_levels[:-1]):
            blended = cv2.pyrUp(blended, dstsize=(level.shape[1], level.shape[0])) + level

        # Insert blended result into a copy of the background
        result_full = bg_np.copy()
        result_full[y:y + h, x:x + w, :3] = np.clip(np.rint(blended), 0, 255).astype(np.uint8)

        # Convert back to PIL
        return Image.fromarray(result_full.astype(np.uint8))

    except Exception as e:
        # Best-effort: report the failure, then fall back to alpha blending
        logger.warning(f"Pyramid blending failed ({e}); falling back to alpha blending.")
        return apply_alpha_blending(background, foreground, position, mask)

def blend_pollen_onto_background(background, pollen, position, config):
    """Composite ``pollen`` onto ``background`` using ``config.BLENDING_METHOD``.

    Dispatch:
        * ``"poisson"``  -> binary mask + ``apply_poisson_blending``
        * ``"pyramid"``  -> content-aware mask when
          ``config.USE_CONTENT_AWARE_MASKS`` is truthy, otherwise a binary
          mask, + ``apply_pyramid_blending`` with ``config.PYRAMID_LEVELS``
        * anything else  -> elliptical mask + ``apply_alpha_blending``

    Returns:
        The blended image, or ``background`` unchanged when the required mask
        could not be created.

    Note:
        No contrast pre-adjustment is applied to the pollen image; it is
        passed to the blenders as received.
    """
    logger = logging.getLogger("GenerationEngine")
    method = config.BLENDING_METHOD

    if method == "poisson":
        # Poisson blending works from a hard binary mask
        mask = create_binary_mask(pollen)
        if mask is None:
            logger.error("Failed to create binary mask for Poisson blending.")
            return background
        return apply_poisson_blending(background, pollen, position, mask)

    if method == "pyramid":
        # Pyramid blending always needs a mask: soft if enabled, else binary
        if getattr(config, 'USE_CONTENT_AWARE_MASKS', False):
            mask = create_content_aware_mask(pollen, config.FEATHER_AMOUNT)
        else:
            mask = create_binary_mask(pollen)

        if mask is None:
            logger.error("Failed to create mask for Pyramid blending.")
            return background
        return apply_pyramid_blending(background, pollen, position, mask, config.PYRAMID_LEVELS)

    # Alpha blending, also used for any unrecognised method name
    logger.debug("Using Alpha blending with Elliptical mask.")
    mask = create_elliptical_mask(pollen, config.FEATHER_AMOUNT)
    if mask is None:
        logger.error("Failed to create elliptical mask for Alpha blending. Returning original background.")
        return background

    return apply_alpha_blending(background, pollen, position, mask)

# --- Position Generation ---
def generate_non_overlapping_positions(count, img_size, min_obj_size, margin):
    """Randomly pick up to ``count`` well-separated positions on a square image.

    Coordinates are drawn uniformly from ``[margin, img_size - min_obj_size -
    margin]`` on both axes. A candidate is rejected when it lies closer than
    ``0.9 * min_obj_size`` (Euclidean distance) to an accepted position.

    Returns:
        A list of ``(x, y)`` tuples. It may hold fewer than ``count`` entries,
        because at most ``count * 10`` candidates are tried.

    Raises:
        ValueError: If the margin leaves no room for placement.
    """
    lower = margin
    upper = img_size - min_obj_size - margin
    if upper <= lower:
        raise ValueError(f"Can't place objects: margin ({margin}) too large for image size ({img_size}) and object size ({min_obj_size})")

    placed = []
    attempts_left = count * 10  # bounded retries so the loop always ends
    min_separation = min_obj_size * 0.9

    while len(placed) < count and attempts_left > 0:
        attempts_left -= 1

        # Draw a candidate (x first, then y)
        x = random.randint(lower, upper)
        y = random.randint(lower, upper)

        # Reject the candidate if it crowds any accepted position
        too_close = any(
            np.sqrt((x - px)**2 + (y - py)**2) < min_separation
            for px, py in placed
        )
        if too_close:
            continue

        placed.append((x, y))

    return placed

def determine_pollen_count(layout_stats=None, avg_count=None):
    """Draw the number of pollen grains to place on one background.

    Priority:
        1. ``layout_stats['summary']['num_valid_objects_per_image']`` - normal
           draw around its ``mean`` with a quarter of its ``std``.
        2. ``avg_count`` - normal draw around it with ``avg_count / 4`` spread.
        3. Otherwise a uniform integer in ``[5, 25]``.

    Returns:
        An integer count; the two normal-draw branches never return less than 1.
    """
    if layout_stats and 'summary' in layout_stats:
        per_image = layout_stats['summary']['num_valid_objects_per_image']
        mean = per_image['mean']
        std = per_image['std']
        # A quarter of the measured std keeps extreme counts rare
        sampled = int(np.random.normal(mean, std / 4))
        return max(1, sampled)

    if avg_count:
        sampled = int(np.random.normal(avg_count, avg_count / 4))
        return max(1, sampled)

    # No statistics available: fall back to a reasonable default range
    return random.randint(5, 25)

def generate_yolo_annotation(positions_and_sizes, img_size, obj_class=0):
    """Render bounding boxes as YOLO annotation text.

    Args:
        positions_and_sizes: Iterable of ``((x, y), (width, height))`` pairs,
            with ``(x, y)`` the top-left corner in pixels.
        img_size: Side length of the (square) image used for normalisation.
        obj_class: Class id written at the start of every line.

    Returns:
        One ``class x_center y_center width height`` line per box (values
        normalised by ``img_size``, six decimals), joined by newlines; an
        empty string when there are no boxes.
    """
    def _format_box(top_left, size):
        left, top = top_left
        box_w, box_h = size
        x_center = (left + box_w / 2) / img_size
        y_center = (top + box_h / 2) / img_size
        rel_w = box_w / img_size
        rel_h = box_h / img_size
        return f"{obj_class} {x_center:.6f} {y_center:.6f} {rel_w:.6f} {rel_h:.6f}"

    return "\n".join(
        _format_box(top_left, size) for top_left, size in positions_and_sizes
    )

# --- Analysis and Visualization ---

# NOTE: extract_features() below supersedes the former helper
#   extract_image_features(images, batch_size=64, device=None)
# ("Extract features from images using a feature extractor network.").

def extract_features(data, inception_model, config, is_real=False):
    """
    Extract features from real or generated images with the supplied Inception model.

    Args:
        data: Sequence of image file paths (``str``/``Path``) and/or PIL
            images. Items of any other type, and items that fail to load or
            transform, are skipped with a warning.
        inception_model: Callable applied to each batch; the first element of
            its output is used as the feature map (assumed to be 4-D, since
            dimensions 2 and 3 are pooled to 1x1 when needed).
        config: Object providing ``FEATURE_EXTRACTION_BATCH_SIZE``, ``DEVICE``,
            ``POLLEN_SIZE_BASE`` and ``CHANNELS_IMG``.
        is_real: Only selects the wording of progress and log messages.

    Returns:
        A 2-D numpy array with one feature row per successfully processed
        image, or ``None`` when extraction fails or yields no features.
    """
    logger = logging.getLogger("GenerationEngine") # Get logger instance
    num_samples = len(data)
    desc = "Processing real images" if is_real else "Processing generated images"
    logger.info(f"Extracting features from {num_samples} {desc.split(' ')[1]} images...")
    log_memory_usage(logger, f"Before {desc} features")

    all_features = []
    batch_size = config.FEATURE_EXTRACTION_BATCH_SIZE
    device = config.DEVICE
    pil_mode = 'L' if config.CHANNELS_IMG == 1 else 'RGB'

    # Define transforms (consistent with model-evaluation)
    transform = transforms.Compose([
        transforms.Resize((config.POLLEN_SIZE_BASE, config.POLLEN_SIZE_BASE)), # Use base pollen size
        transforms.ToTensor(),
        transforms.Normalize([0.5] * config.CHANNELS_IMG, [0.5] * config.CHANNELS_IMG) # Match GAN output range
    ])

    try:
        with torch.no_grad():
            for i in tqdm(range(0, num_samples, batch_size), desc=desc, leave=True):
                batch_data = data[i:min(i + batch_size, num_samples)]
                batch_tensors = []

                for item in batch_data:
                    try:
                        if isinstance(item, (str, Path)):  # It's a path (string or Path object)
                            # Context manager releases the file handle as soon
                            # as the converted copy exists
                            with Image.open(str(item)) as opened:
                                img = opened.convert(pil_mode)
                        elif isinstance(item, Image.Image):  # It's a PIL image
                            img = item.convert(pil_mode)
                        else:
                            logger.warning(f"Skipping unsupported data type: {type(item)}")
                            continue
                        batch_tensors.append(transform(img))
                    except Exception as e:
                         logger.warning(f"Error loading/transforming image {item}: {e}. Skipping.")
                         continue # Skip problematic images

                if not batch_tensors:
                    continue

                batch = torch.stack(batch_tensors).to(device)

                # Handle grayscale (Inception expects RGB)
                if batch.shape[1] == 1:
                    batch = batch.repeat(1, 3, 1, 1)

                # --- IMPORTANT: Normalization for Inception ---
                # The model-evaluation script used (batch + 1) / 2.
                # This assumes the Inception model expects [0, 1] range.
                # We replicate that here for consistency.
                batch = (batch + 1) / 2

                # Get features
                features = inception_model(batch)[0]

                # Pool features
                if features.shape[2] != 1 or features.shape[3] != 1:
                    features = F.adaptive_avg_pool2d(features, output_size=(1, 1))

                # Reshape and move to CPU
                features = features.squeeze(-1).squeeze(-1).cpu().numpy()
                all_features.append(features)

                # Memory cleanup
                del batch, features, batch_tensors
                # `i` is a sample offset, so count batches explicitly to
                # clean up on every 10th batch regardless of the batch size
                if (i // batch_size) % 10 == 0:
                    force_memory_cleanup(config)

    except Exception as e:
         logger.error(f"Error during feature extraction: {e}", exc_info=True)
         return None # Indicate failure
    finally:
         force_memory_cleanup(config) # Final cleanup

    if not all_features:
        logger.error(f"Failed to extract any features for {desc}")
        return None

    all_features = np.concatenate(all_features, axis=0)
    logger.info(f"Extracted features from {all_features.shape[0]} images, shape: {all_features.shape}")
    log_memory_usage(logger, f"After {desc} features")
    return all_features

# NOTE: generate_tsne_plot() below extends the former signature
#   generate_tsne_plot(real_features, gen_features, title, output_path)
# ("Generate t-SNE visualization of real vs. generated image features.")
# with a trailing `config` parameter.
def generate_tsne_plot(real_features, gen_features, title, output_path, config):
    """Generate t-SNE visualization of real vs. generated image features.

    Args:
        real_features: 2-D feature array for real images (one row per image).
        gen_features: 2-D feature array for generated images.
        title: Text appended to the plot title.
        output_path: Destination file for the saved figure.
        config: Object providing ``RANDOM_SEED`` for a reproducible embedding.

    Returns:
        True when the plot was saved, False when t-SNE is unavailable or any
        step fails (the error is logged, never raised).
    """
    logger = logging.getLogger("GenerationEngine")
    if not TSNE_AVAILABLE:
        logger.error("Cannot generate t-SNE plot: scikit-learn TSNE not available.")
        return False

    fig = None
    try:
        num_points = len(real_features) + len(gen_features)
        logger.info(f"Preparing {num_points} points ({len(real_features)} real vs {len(gen_features)} generated) for t-SNE...")
        all_features = np.vstack([real_features, gen_features])

        # t-SNE requires perplexity < number of samples; keep the usual 30
        # and only lower it for small feature sets.
        perplexity = min(30, max(1, num_points - 1))

        # Apply t-SNE (using config seed)
        logger.info(f"Computing t-SNE (perplexity={perplexity}, init='pca', lr='auto', n_jobs=-1)... (Check console for verbose progress)")
        tsne_start_time = time.time() # Time the calculation
        tsne = TSNE(n_components=2, perplexity=perplexity,
                    learning_rate='auto',
                    init='pca',
                    random_state=config.RANDOM_SEED,
                    n_jobs=-1,
                    verbose=1) # verbose=1 prints progress to the console
        embeddings = tsne.fit_transform(all_features)
        tsne_duration = time.time() - tsne_start_time
        logger.info(f"t-SNE computation finished in {tsne_duration:.3f} seconds.")

        # Split back into real and generated (real rows were stacked first)
        real_count = len(real_features)
        real_embeddings = embeddings[:real_count]
        gen_embeddings = embeddings[real_count:]

        # Plot
        logger.info(f"Generating plot file: {output_path}")
        fig = plt.figure(figsize=(10, 8))
        plt.scatter(real_embeddings[:, 0], real_embeddings[:, 1], alpha=0.7,
                    s=5, label='Real Images', c='blue')
        plt.scatter(gen_embeddings[:, 0], gen_embeddings[:, 1], alpha=0.7,
                    s=5, label='Generated Images', c='red')

        plt.title(f't-SNE Visualization: {title}\n({num_points} points)') # Add point count to title
        plt.legend(markerscale=3)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(output_path, dpi=300)
        logger.info(f"Saved t-SNE plot to {output_path}")
        return True
    except Exception as e:
        logger.error(f"Error generating t-SNE plot: {e}", exc_info=True)
        return False
    finally:
        # Close exactly the figure created here, on success and on failure
        if fig is not None:
            plt.close(fig)

# NOTE: generate_umap_plot() below extends the former signature
#   generate_umap_plot(real_features, gen_features, title, output_path)
# ("Generate UMAP visualization of real vs. generated image features.")
# with a trailing `config` parameter.
def generate_umap_plot(real_features, gen_features, title, output_path, config): # Add config
    """Generate UMAP visualization of real vs. generated image features.

    Args:
        real_features: 2-D feature array for real images (one row per image).
        gen_features: 2-D feature array for generated images.
        title: Text appended to the plot title.
        output_path: Destination file for the saved figure.
        config: Object providing ``RANDOM_SEED`` for a reproducible embedding.

    Returns:
        True when the plot was saved, False when UMAP is unavailable or any
        step fails (the error is logged, never raised).
    """
    logger = logging.getLogger("GenerationEngine")
    if not HAS_UMAP:
         logger.error("Cannot generate UMAP plot: umap-learn not available.")
         return False

    fig = None
    try:
        num_points = len(real_features) + len(gen_features)
        logger.info(f"Starting UMAP calculation for {num_points} points...")
        start_time = time.time()
        # Combine features (real rows first); no feature standardisation is applied
        all_features = np.vstack([real_features, gen_features])

        # Apply UMAP (using config seed)
        logger.info("Computing UMAP embedding...")
        reducer = UMAP(n_neighbors=15, min_dist=0.1, n_components=2,
                       metric='euclidean', # Match evaluator
                       random_state=config.RANDOM_SEED,
                       verbose=True) # verbose output for progress

        embeddings = reducer.fit_transform(all_features)

        duration = time.time() - start_time
        logger.info(f"UMAP calculation finished in {duration:.3f} seconds.")

        # Split back into real and generated
        real_count = len(real_features)
        real_embeddings = embeddings[:real_count]
        gen_embeddings = embeddings[real_count:]

        # Plot
        fig = plt.figure(figsize=(10, 8))
        plt.scatter(real_embeddings[:, 0], real_embeddings[:, 1], alpha=0.7,
                    s=5, label='Real Images', c='blue') # Smaller points
        plt.scatter(gen_embeddings[:, 0], gen_embeddings[:, 1], alpha=0.7,
                    s=5, label='Generated Images', c='red') # Smaller points

        plt.title(f'UMAP Visualization: {title}')
        plt.legend(markerscale=3) # Match legend style
        plt.grid(alpha=0.3)      # Match grid style
        plt.tight_layout()
        plt.savefig(output_path, dpi=300)
        logger.info(f"Saved UMAP plot to {output_path}")
        return True
    except Exception as e:
        logger.error(f"Error generating UMAP plot: {e}", exc_info=True)
        return False
    finally:
        # Close the figure on success and on failure so a failed save does
        # not leave an open figure behind
        if fig is not None:
            plt.close(fig)

#def calculate_prdc_metrics(real_features, gen_features, nearest_k=5):
#    """Calculate Precision, Recall, Density, and Coverage metrics."""

# --- Replace the old calculate_prdc_metrics function ---

def _zero_prdc_metrics():
    """Return a fresh all-zero precision/recall/density/coverage dict."""
    return {'precision': 0.0, 'recall': 0.0, 'density': 0.0, 'coverage': 0.0}


def _remove_temp_files(paths, logger):
    """Best-effort removal of temporary files; failures are logged, never raised."""
    for path in paths:
        if os.path.exists(path):
            try:
                os.remove(path)
            except OSError as err:
                logger.warning(f"Could not remove temp file {path}: {err}")


def _iter_feature_chunks(features, chunk_size, desc):
    """Yield consecutive row-chunks of `features` while driving a tqdm progress bar."""
    total = len(features)
    with tqdm(total=total, desc=desc) as pbar:
        for chunk_index, start in enumerate(range(0, total, chunk_size)):
            chunk = features[start:start + chunk_size]
            if len(chunk) == 0:
                continue
            yield chunk
            pbar.update(len(chunk))
            # Collect garbage on every other chunk only, to limit the overhead.
            if chunk_index % 2 == 0:
                gc.collect()


def _fraction_within_threshold(nn_index, queries, thresholds, chunk_size, desc):
    """
    Fraction of `queries` whose nearest neighbour in `nn_index` lies within a threshold.

    `thresholds` is either a scalar (one radius shared by every query) or a 1-D
    array indexed by the points `nn_index` was fitted on (radius of the matched
    point). Returns 0.0 when `queries` is empty.
    """
    per_point = np.ndim(thresholds) > 0
    hits = 0
    seen = 0
    for chunk in _iter_feature_chunks(queries, chunk_size, desc):
        distances, nearest_idx = nn_index.kneighbors(chunk)
        # Fancy indexing picks the radius of each query's nearest fitted point.
        limit = thresholds[nearest_idx.ravel()] if per_point else thresholds
        hits += int(np.sum(distances.ravel() <= limit))
        seen += len(chunk)
    return hits / seen if seen > 0 else 0.0


def calculate_prdc_metrics(real_features, fake_features, config):
    """
    Calculate Precision, Recall, Density and Coverage metrics.
    Adapted from model-evaluation.ipynb.
    Uses torch-fidelity if available, otherwise custom implementation.

    Args:
        real_features: Feature vectors of real images (array-like, one row per sample).
        fake_features: Feature vectors of generated images (array-like, one row per sample).
        config: Configuration object. Reads OUTPUT_DIR, STATS_SUBDIR, USE_GPU and
            PR_K_VALUE for the torch-fidelity path, and MANIFOLD_K,
            DISTANCE_MULTIPLIER and SUBSAMPLE_MANIFOLD for the custom path.

    Returns:
        dict: {'individual': {...}, 'mean': {...}}, each holding 'precision',
        'recall', 'density' and 'coverage'. On the torch-fidelity path both
        entries hold the same values. If no implementation is usable, or the
        custom calculation fails, all values are 0.0.

    Notes:
        In the custom implementation coverage is set equal to recall, and density
        is computed with the mean fake radius for both the 'individual' and the
        'mean' result.
    """
    logger = logging.getLogger("GenerationEngine") # Get logger instance
    metrics_results = {'individual': {}, 'mean': {}}

    # Ensure features are numpy arrays
    if not isinstance(real_features, np.ndarray):
        real_features = np.array(real_features)
    if not isinstance(fake_features, np.ndarray):
        fake_features = np.array(fake_features)

    # --- Try torch-fidelity first ---
    if TORCH_FIDELITY_AVAILABLE:
        # Filled in only once the paths are known, so cleanup is safe even when
        # the failure happens before any temp file name has been built.
        temp_paths = []
        try:
            logger.info("Calculating P/R/D/C metrics using torch-fidelity...")
            eval_dir = os.path.join(config.OUTPUT_DIR, config.STATS_SUBDIR) # Save temp files here
            os.makedirs(eval_dir, exist_ok=True)

            real_features_path = os.path.join(eval_dir, "temp_real_features.npz")
            fake_features_path = os.path.join(eval_dir, "temp_fake_features.npz")
            temp_paths = [real_features_path, fake_features_path]

            np.savez(real_features_path, features=real_features)
            np.savez(fake_features_path, features=fake_features)

            metrics_dict = torch_fidelity.calculate_metrics(
                input1=fake_features_path, # Input1 is usually fake/generated
                input2=real_features_path, # Input2 is usually real
                cuda=torch.cuda.is_available() and config.USE_GPU,
                isc=False, fid=False, kid=False, verbose=False,
                prc=True, # Calculate Precision-Recall-Coverage
                prc_k=config.PR_K_VALUE, # Use K from config
                cache=False
            )

            # Extract metrics
            metrics = {
                'precision': metrics_dict.get('precision', 0.0),
                'recall': metrics_dict.get('recall', 0.0),
                'density': metrics_dict.get('density', 0.0), # Density/Coverage might not always be calculated
                'coverage': metrics_dict.get('coverage', 0.0)
            }

            if metrics_dict.get('precision') is not None and metrics_dict.get('recall') is not None:
                logger.info(f"torch-fidelity P/R/D/C: P={metrics['precision']:.4f}, R={metrics['recall']:.4f}, D={metrics['density']:.6f}, C={metrics['coverage']:.4f}")
                # Store results for both 'individual' and 'mean' keys for compatibility with report format
                metrics_results['individual'] = metrics.copy()
                metrics_results['mean'] = metrics.copy()
                return metrics_results

            logger.warning("torch-fidelity did not return precision/recall metrics. Falling back to custom.")

        except Exception as e:
            logger.error(f"torch-fidelity calculation failed: {e}")
            logger.info("Falling back to custom P/R/D/C implementation...")
        finally:
            # Runs on success, fallback and error alike.
            _remove_temp_files(temp_paths, logger)

    # --- Custom Implementation Fallback ---
    if not NEIGHBORS_AVAILABLE:
        logger.error("Cannot calculate custom P/R/D/C: sklearn.neighbors not available.")
        metrics_results['individual'] = _zero_prdc_metrics()
        metrics_results['mean'] = _zero_prdc_metrics()
        return metrics_results

    try:
        k = config.MANIFOLD_K
        multiplier = config.DISTANCE_MULTIPLIER
        subsample = config.SUBSAMPLE_MANIFOLD
        num_real = len(real_features)
        num_fake = len(fake_features)
        subsample_size = 10000  # Size to subsample to if enabled
        query_chunk_size = 5000  # Size for chunked queries, adjust based on memory

        logger.info(f"Calculating custom P/R/D/C (k={k}, multiplier={multiplier}, subsample={subsample})")

        # Optionally subsample the points used to estimate each manifold.
        if subsample and num_real > subsample_size:
            logger.info(f"Subsampling real features from {num_real} to {subsample_size} for manifold estimation")
            indices_real = np.random.choice(num_real, subsample_size, replace=False)
            real_for_manifold = real_features[indices_real]
        else:
            real_for_manifold = real_features

        if subsample and num_fake > subsample_size:
            logger.info(f"Subsampling fake features from {num_fake} to {subsample_size} for manifold estimation")
            indices_fake = np.random.choice(num_fake, subsample_size, replace=False)
            fake_for_manifold = fake_features[indices_fake]
        else:
            fake_for_manifold = fake_features

        # Radius of each real manifold point = distance to its k-th neighbour
        # (column 0 is the point itself, hence k + 1 neighbours).
        logger.debug("Fitting real manifold NN...")
        real_nn = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree', n_jobs=-1)
        real_nn.fit(real_for_manifold)
        real_distances, _ = real_nn.kneighbors(real_for_manifold)
        real_radii = real_distances[:, k] if k > 0 else np.zeros(len(real_for_manifold))

        # Same for the fake manifold.
        logger.debug("Fitting fake manifold NN...")
        fake_nn = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree', n_jobs=-1)
        fake_nn.fit(fake_for_manifold)
        fake_distances, _ = fake_nn.kneighbors(fake_for_manifold)
        fake_radii = fake_distances[:, k] if k > 0 else np.zeros(len(fake_for_manifold))

        # Per-point thresholds (radius of the matched manifold point, scaled).
        real_thresholds = real_radii * multiplier
        fake_thresholds = fake_radii * multiplier

        # --- Individual Distances Method (CHUNKED) ---
        logger.info("Calculating metrics using individual distances (chunked)...")

        # Precision: fake samples falling inside the ball of their nearest real point.
        logger.debug("Calculating precision (individual, chunked)...")
        real_nn_query = NearestNeighbors(n_neighbors=1, algorithm='ball_tree', n_jobs=-1)
        real_nn_query.fit(real_for_manifold)  # Fit on manifold

        fake_chunks = math.ceil(len(fake_features) / query_chunk_size)
        logger.info(f"Processing precision in {fake_chunks} chunks...")
        precision_individual = _fraction_within_threshold(
            real_nn_query, fake_features, real_thresholds, query_chunk_size, "PRDC Precision")

        # Recall: real samples falling inside the ball of their nearest fake point.
        logger.debug("Calculating recall (individual, chunked)...")
        fake_nn_query = NearestNeighbors(n_neighbors=1, algorithm='ball_tree', n_jobs=-1)
        fake_nn_query.fit(fake_for_manifold)  # Fit on manifold

        real_chunks = math.ceil(len(real_features) / query_chunk_size)
        logger.info(f"Processing recall in {real_chunks} chunks...")
        recall_individual = _fraction_within_threshold(
            fake_nn_query, real_features, fake_thresholds, query_chunk_size, "PRDC Recall")
        coverage_individual = recall_individual  # Coverage = Recall in this implementation

        # --- CHUNKED Density Calculation with Mean Threshold ---
        # Using mean fake radius to reduce memory needs
        logger.debug("Calculating density (mean threshold, chunked)...")
        fake_mean_radius = np.mean(fake_radii) * multiplier

        fake_nn_query_multi = NearestNeighbors(n_neighbors=min(50, len(fake_for_manifold)), algorithm='ball_tree', n_jobs=-1)
        fake_nn_query_multi.fit(fake_for_manifold)

        density_sum = 0
        total_real_samples_density = 0
        for real_chunk in _iter_feature_chunks(real_features, query_chunk_size, "PRDC Density"):
            multi_distances, _ = fake_nn_query_multi.kneighbors(real_chunk)
            # Number of (real, fake-neighbour) pairs closer than the mean fake radius.
            density_sum += int(np.sum(multi_distances <= fake_mean_radius))
            total_real_samples_density += len(real_chunk)

        # Average neighbour count per real sample, normalised by the fake manifold size.
        if total_real_samples_density > 0 and len(fake_for_manifold) > 0:
            density_individual = (density_sum / total_real_samples_density) / len(fake_for_manifold)
        else:
            density_individual = 0.0

        metrics_results['individual'] = {
            'precision': float(precision_individual),
            'recall': float(recall_individual),
            'density': float(density_individual),
            'coverage': float(coverage_individual)
        }
        logger.info(f"Individual distances metrics: P={precision_individual:.4f}, R={recall_individual:.4f}, D={density_individual:.6f}, C={coverage_individual:.4f}")

        # --- Mean Distances Method ---
        # Same nearest-neighbour queries, but one fixed threshold per manifold.
        logger.info("Calculating metrics using mean distances...")

        real_mean_radius = np.mean(real_radii) * multiplier
        logger.info(f"Mean distance thresholds: real={real_mean_radius:.6f}, fake={fake_mean_radius:.6f}")

        # Density was already computed with the mean threshold, so reuse it.
        density_mean = density_individual

        precision_mean = _fraction_within_threshold(
            real_nn_query, fake_features, real_mean_radius, query_chunk_size, "Mean Precision")
        recall_mean = _fraction_within_threshold(
            fake_nn_query, real_features, fake_mean_radius, query_chunk_size, "Mean Recall")
        coverage_mean = recall_mean

        metrics_results['mean'] = {
            'precision': float(precision_mean),
            'recall': float(recall_mean),
            'density': float(density_mean),
            'coverage': float(coverage_mean)
        }
        logger.info(f"Mean distances metrics: P={precision_mean:.4f}, R={recall_mean:.4f}, D={density_mean:.6f}, C={coverage_mean:.4f}")

        # Cleanup NN models
        del real_nn, fake_nn, real_nn_query, fake_nn_query, fake_nn_query_multi
        force_memory_cleanup(config)

    except Exception as e:
        logger.error(f"Custom P/R/D/C calculation failed: {e}", exc_info=True)
        metrics_results['individual'] = _zero_prdc_metrics()
        metrics_results['mean'] = _zero_prdc_metrics()

    return metrics_results
    #

#def generate_radar_chart(metrics, title, output_path):
#    """Generate a radar chart visualization of metrics."""

# --- Modify generate_radar_chart (optional: handle individual/mean dict) ---
def generate_radar_chart(metrics_data, title, output_path):
    """Generate a radar chart visualization of metrics.

    Plots precision, recall and coverage on a polar axis (density is not
    plotted) and saves the figure to `output_path`.

    Args:
        metrics_data: Either a flat dict like {'precision': v, ...} or a combined
            dict {'individual': {...}, 'mean': {...}}. For the combined form the
            'individual' entry is used, falling back to 'mean' if it is empty.
        title: Chart title (also used in log messages).
        output_path: Destination image path.

    Returns:
        bool: True if the chart was saved, False if there was nothing to plot
        or an error occurred (errors are logged, not raised).
    """
    logger = logging.getLogger("GenerationEngine")
    fig = None  # Tracked so the figure can be closed even if plotting fails
    try:
        # If passed the combined dict, use 'individual' by default, or try 'mean'
        if 'individual' in metrics_data:
            metrics = metrics_data['individual'] or metrics_data.get('mean', {})
        else:
            metrics = metrics_data # Assume it's the flat dict

        if not metrics:
            logger.warning(f"No metrics data provided for radar chart '{title}'")
            return False

        # Missing keys default to 0.0 so a partial dict still plots
        categories = ['precision', 'recall', 'coverage']
        values = [metrics.get(cat, 0.0) for cat in categories]

        # One spoke per category; repeat the first point to close the polygon
        angles = np.linspace(0, 2*np.pi, len(categories), endpoint=False).tolist()
        values_plot = values + values[:1]
        angles_plot = angles + angles[:1]

        fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
        ax.plot(angles_plot, values_plot, 'o-', linewidth=2, color='purple') # Use consistent color
        ax.fill(angles_plot, values_plot, alpha=0.25, color='purple')
        ax.set_thetagrids(np.degrees(angles), [c.capitalize() for c in categories])

        ax.set_ylim(0, 1.05) # Set ylim slightly above 1.0
        ax.set_title(title, size=15, y=1.1)
        ax.grid(True) # Add grid

        # Add value annotations like in evaluator script
        for angle, value in zip(angles, values):
            ax.text(angle, value + 0.05, f"{value:.3f}",
                    horizontalalignment='center', verticalalignment='center')

        fig.tight_layout()
        fig.savefig(output_path, dpi=300)
        logger.info(f"Saved radar chart to {output_path}")
        return True
    except Exception as e:
        logger.error(f"Error generating radar chart '{title}': {e}", exc_info=True)
        return False
    finally:
        # Close this specific figure on every path so failed plots do not pile up
        if fig is not None:
            plt.close(fig)
        
def save_sample_visualization(pollen_paths, composed_paths, labels_paths, output_path, num_samples=4):
    """Create a visualization of sample images.

    Draws a 2 x `num_samples` grid: individual pollen images on the top row and
    composed images on the bottom row. For composed images with an existing
    label file, each 5-field line is parsed as a YOLO annotation (class,
    x_center, y_center, width, height, normalised) and drawn as a red box.

    Args:
        pollen_paths: Paths of individual pollen images (first `num_samples` used).
        composed_paths: Paths of composed images (first `num_samples` used).
        labels_paths: Label file paths aligned by index with `composed_paths`;
            may be shorter than it, empty, or None.
        output_path: Destination image path.
        num_samples: Number of columns in the grid.

    Returns:
        bool: True if the figure was saved, False on any error (logged, not raised).
    """
    logger = logging.getLogger("GenerationEngine")
    fig = None  # Tracked so the figure can be closed even if plotting fails
    try:
        # Limit to available images
        pollen_paths = pollen_paths[:num_samples]
        composed_paths = composed_paths[:num_samples]
        labels_paths = labels_paths or []

        # squeeze=False keeps `axes` 2-D even when num_samples == 1
        fig, axes = plt.subplots(2, num_samples, figsize=(num_samples * 4, 8), squeeze=False)

        # Hide ticks/frames everywhere, including cells that end up without an image
        for ax in axes.flat:
            ax.axis('off')

        # Display individual pollen images on top row
        for i, path in enumerate(pollen_paths):
            with Image.open(path) as img:
                axes[0, i].imshow(img, cmap='gray')
            axes[0, i].set_title(f"Pollen {i+1}")

        # Display composed images on bottom row
        for i, path in enumerate(composed_paths):
            with Image.open(path) as img:
                axes[1, i].imshow(img, cmap='gray')
                img_w, img_h = img.size
            axes[1, i].set_title(f"Composed {i+1}")

            # Add bounding boxes if available
            label_path = labels_paths[i] if i < len(labels_paths) else None
            if not (label_path and os.path.exists(label_path)):
                continue

            with open(label_path, 'r') as f:
                annotations = f.readlines()

            for annotation in annotations:
                parts = annotation.strip().split()
                if len(parts) != 5:
                    continue
                try:
                    # Parse YOLO format
                    _, x_center, y_center, width, height = map(float, parts)
                except ValueError:
                    # Skip a malformed line rather than abandoning the whole plot
                    continue

                # Convert normalised centre/size to a pixel-space top-left box
                x = (x_center - width/2) * img_w
                y = (y_center - height/2) * img_h
                w = width * img_w
                h = height * img_h

                # Draw rectangle
                rect = plt.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
                axes[1, i].add_patch(rect)

        # Save figure
        fig.tight_layout()
        fig.savefig(output_path, dpi=300)
        return True
    except Exception as e:
        logger.error(f"Error generating sample visualization: {e}", exc_info=True)
        return False
    finally:
        # Close this specific figure on every path so failed plots do not pile up
        if fig is not None:
            plt.close(fig)

# --- Report Generation ---

def create_markdown_report(config, generation_results, analysis_results, stats, output_path):
    """
    Generate a comprehensive Markdown report of the generation process,
    including separate sections for filtered and unfiltered analysis results.
    """
    logger = logging.getLogger("GenerationEngine")
    try:
        report_dir = os.path.dirname(output_path)
        os.makedirs(report_dir, exist_ok=True)

        with open(output_path, 'w') as f:
            # --- Header ---
            f.write("# Synthetic Pollen Image Generation Report\n\n")
            f.write(f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            # --- Configuration ---
            f.write("## Configuration\n\n")
            f.write("### Paths\n")
            f.write(f"- **Generator model:** `{config.CHECKPOINT_PATH}`\n")
            f.write(f"- **Real data:** `{config.get_real_data_path()}`\n")
            f.write(f"- **Backgrounds:** `{config.BACKGROUND_DIR}`\n")
            f.write(f"- **Output directory:** `{config.OUTPUT_DIR}`\n\n")

            f.write("### Generation Parameters\n")
            f.write(f"- **Target pollen images:** {config.TARGET_POLLEN_IMAGES}\n")
            f.write(f"- **Target composed images:** {config.TARGET_COMPOSED_IMAGES}\n")
            f.write(f"- **Quality filtering:** {'Enabled' if config.USE_QUALITY_FILTERING else 'Disabled'}\n")
            if config.USE_QUALITY_FILTERING:
                 f.write(f"  - **Threshold Percentile:** {config.QUALITY_THRESHOLD_PERCENTILE}%\n")
                 f.write(f"  - **Surplus Factor:** {config.FILTERING_SURPLUS_FACTOR}x\n")
            f.write(f"- **Blending method:** {config.BLENDING_METHOD}\n")
            f.write(f"- **Statistics matching (Size):** {'Enabled' if config.USE_STATS_MATCHING_SIZE else 'Disabled'}\n")
            f.write(f"- **Statistics matching (Histogram):** {'Enabled' if config.USE_STATS_MATCHING_HISTOGRAM else 'Disabled'}\n")
            f.write(f"- **Parallel processing:** {'Enabled' if config.USE_PARALLEL_PROCESSING else 'Disabled'}\n\n")

            f.write("### Analysis Parameters\n")
            f.write(f"- **Analysis Sample Size:** {config.ANALYSIS_SAMPLE_SIZE}\n")
            f.write(f"- **PRDC Manifold k:** {config.MANIFOLD_K}\n")
            f.write(f"- **PRDC Distance Multiplier:** {config.DISTANCE_MULTIPLIER}\n")
            f.write(f"- **PRDC Fidelity k:** {config.PR_K_VALUE}\n\n")


            # --- Generation Results ---
            f.write("## Generation Run Summary\n\n")
            if 'pollen_generated' in generation_results:
                f.write(f"- **Pollen images generated:** {generation_results['pollen_generated']} / {config.TARGET_POLLEN_IMAGES}\n")
            if 'composed_generated' in generation_results:
                f.write(f"- **Composed images generated:** {generation_results['composed_generated']} / {config.TARGET_COMPOSED_IMAGES}\n")

            if 'total_time' in generation_results:
                total_seconds = generation_results['total_time']
                hours = int(total_seconds // 3600)
                minutes = int((total_seconds % 3600) // 60)
                seconds = total_seconds % 60
                f.write(f"- **Total time:** {hours}h {minutes}m {seconds:.3f}s\n")

                # Calculate generation rates
                if 'pollen_generated' in generation_results and generation_results['pollen_generated'] > 0 and total_seconds > 0:
                    rate = generation_results['pollen_generated'] / total_seconds
                    f.write(f"- **Pollen generation rate:** {rate:.3f} images/second\n")
                if 'composed_generated' in generation_results and generation_results['composed_generated'] > 0 and total_seconds > 0:
                    rate = generation_results['composed_generated'] / total_seconds
                    f.write(f"- **Composed generation rate:** {rate:.3f} images/second\n")

            # --- Peak Memory ---
            if 'peak_memory' in generation_results:
                f.write("\n### Peak Memory Usage During Run\n")
                f.write(f"- **RAM:** {generation_results['peak_memory'].get('ram_used_percent', 'N/A'):.1f}%\n")
                gpu_perc = generation_results['peak_memory'].get('gpu_used_percent', 'N/A')
                gpu_used = generation_results['peak_memory'].get('gpu_used_gb', 'N/A')
                gpu_total = generation_results['peak_memory'].get('gpu_total_gb', 'N/A')
                if gpu_perc != 'N/A':
                    f.write(f"- **GPU:** {gpu_perc:.1f}% ({gpu_used:.3f} / {gpu_total:.3f} GB)\n")
                else:
                     f.write("- **GPU:** N/A (CUDA not available or error)\n")

                memory_chart = os.path.join(config.OUTPUT_DIR, config.STATS_SUBDIR, "memory_usage.png")
                if os.path.exists(memory_chart):
                    rel_path = os.path.relpath(memory_chart, report_dir)
                    f.write(f"\n![Memory Usage Over Time](./{rel_path.replace(os.sep, '/')})\n") # Ensure relative path uses forward slashes

            # --- Sample Images ---
            f.write("\n## Sample Generated Images\n\n")
            samples_path = os.path.join(config.OUTPUT_DIR, config.GRAPH_SUBDIR, "sample_visualization.png")
            if os.path.exists(samples_path):
                rel_path = os.path.relpath(samples_path, report_dir)
                f.write(f"![Sample Images](./{rel_path.replace(os.sep, '/')})\n\n")
            else:
                 f.write("*Sample visualization image not found.*\n\n")


            # --- Statistical Analysis Section ---
            f.write("## Statistical Analysis of Generated Images\n\n")

            analysis_performed = analysis_results is not None and \
                                 (analysis_results.get('prdc_metrics_filtered') is not None or \
                                  analysis_results.get('prdc_metrics_unfiltered') is not None)

            if not analysis_performed:
                 f.write("*Analysis was skipped or failed (check logs).*\n\n")
            else: # if analysis_performed:
                # --- Analysis Subsection: Filtered ---
                f.write("### Analysis on Filtered Generated Images\n\n")
                f.write("*Note: Filtered images are those saved to the output directory if quality filtering was enabled.*\n\n")

                # Link plots - Filtered
                tsne_path_f = os.path.join(config.OUTPUT_DIR, config.GRAPH_SUBDIR, "tsne_visualization_filtered.png")
                if os.path.exists(tsne_path_f):
                    rel_path = os.path.relpath(tsne_path_f, report_dir)
                    f.write(f"**t-SNE Visualization (Filtered):**\n![t-SNE Filtered](./{rel_path.replace(os.sep, '/')})\n\n")
                else:
                     f.write("**t-SNE Visualization (Filtered):** *Plot not generated.*\n\n")

                umap_path_f = os.path.join(config.OUTPUT_DIR, config.GRAPH_SUBDIR, "umap_visualization_filtered.png")
                if os.path.exists(umap_path_f):
                    rel_path = os.path.relpath(umap_path_f, report_dir)
                    f.write(f"**UMAP Visualization (Filtered):**\n![UMAP Filtered](./{rel_path.replace(os.sep, '/')})\n\n")
                elif config.VISUALIZE_UMAP: # Only mention if attempted
                     f.write("**UMAP Visualization (Filtered):** *Plot not generated (UMAP may not be installed).*\n\n")


                # PRDC Metrics - Filtered
                metrics_f = analysis_results.get('prdc_metrics_filtered')
                if metrics_f:
                    radar_path_f = os.path.join(config.OUTPUT_DIR, config.GRAPH_SUBDIR, "prdc_radar_filtered.png")
                    if os.path.exists(radar_path_f):
                         rel_path = os.path.relpath(radar_path_f, report_dir)
                         f.write(f"**PRDC Metrics Radar (Filtered - Individual):**\n![Radar Filtered](./{rel_path.replace(os.sep, '/')})\n\n")
                    else:
                         f.write("**PRDC Metrics Radar (Filtered):** *Plot not generated.*\n\n")

                    # Individual Table
                    if metrics_f.get('individual'):
                        f.write("**Individual Distances Method (Filtered):**\n")
                        f.write("| Metric    | Value     |\n")
                        f.write("| :-------- | :-------- |\n")
                        for metric, value in metrics_f['individual'].items():
                            f.write(f"| {metric.capitalize():<10} | {value:.4f}    |\n")
                        f.write("\n")
                    # Mean Table
                    if metrics_f.get('mean'):
                        f.write("**Mean Distances Method (Filtered):**\n")
                        f.write("| Metric    | Value     |\n")
                        f.write("| :-------- | :-------- |\n")
                        for metric, value in metrics_f['mean'].items():
                            f.write(f"| {metric.capitalize():<10} | {value:.4f}    |\n")
                        f.write("\n")
                else:
                    f.write("**PRDC Metrics (Filtered):** *Not calculated or failed.*\n\n")

                # ... (existing Filtered section plots/metrics) ...

                # --- Add Average Critic Score (Filtered) --- <-- NEW
                avg_score_f = analysis_results.get('avg_critic_score_filtered')
                if avg_score_f is not None:
                    f.write(f"**Average Critic Score (Filtered):** {avg_score_f:.4f}\n\n")
                else:
                    f.write("**Average Critic Score (Filtered):** *Not calculated.*\n\n")

                # --- Analysis Subsection: Unfiltered ---
                f.write("---\n") # Separator
                f.write("### Analysis on Unfiltered Generated Images\n\n")
                f.write("*Note: Unfiltered images were generated on-the-fly specifically for this analysis.*\n\n")

                # Link plots - Unfiltered
                tsne_path_u = os.path.join(config.OUTPUT_DIR, config.GRAPH_SUBDIR, "tsne_visualization_unfiltered.png")
                if os.path.exists(tsne_path_u):
                    rel_path = os.path.relpath(tsne_path_u, report_dir)
                    f.write(f"**t-SNE Visualization (Unfiltered):**\n![t-SNE Unfiltered](./{rel_path.replace(os.sep, '/')})\n\n")
                else:
                     f.write("**t-SNE Visualization (Unfiltered):** *Plot not generated.*\n\n")

                umap_path_u = os.path.join(config.OUTPUT_DIR, config.GRAPH_SUBDIR, "umap_visualization_unfiltered.png")
                if os.path.exists(umap_path_u):
                    rel_path = os.path.relpath(umap_path_u, report_dir)
                    f.write(f"**UMAP Visualization (Unfiltered):**\n![UMAP Unfiltered](./{rel_path.replace(os.sep, '/')})\n\n")
                elif config.VISUALIZE_UMAP: # Only mention if attempted
                     f.write("**UMAP Visualization (Unfiltered):** *Plot not generated (UMAP may not be installed).*\n\n")

                # PRDC Metrics - Unfiltered
                metrics_u = analysis_results.get('prdc_metrics_unfiltered')
                if metrics_u:
                    radar_path_u = os.path.join(config.OUTPUT_DIR, config.GRAPH_SUBDIR, "prdc_radar_unfiltered.png")
                    if os.path.exists(radar_path_u):
                         rel_path = os.path.relpath(radar_path_u, report_dir)
                         f.write(f"**PRDC Metrics Radar (Unfiltered - Individual):**\n![Radar Unfiltered](./{rel_path.replace(os.sep, '/')})\n\n")
                    else:
                         f.write("**PRDC Metrics Radar (Unfiltered):** *Plot not generated.*\n\n")

                    # Individual Table
                    if metrics_u.get('individual'):
                        f.write("**Individual Distances Method (Unfiltered):**\n")
                        f.write("| Metric    | Value     |\n")
                        f.write("| :-------- | :-------- |\n")
                        for metric, value in metrics_u['individual'].items():
                            f.write(f"| {metric.capitalize():<10} | {value:.4f}    |\n")
                        f.write("\n")
                    # Mean Table
                    if metrics_u.get('mean'):
                        f.write("**Mean Distances Method (Unfiltered):**\n")
                        f.write("| Metric    | Value     |\n")
                        f.write("| :-------- | :-------- |\n")
                        for metric, value in metrics_u['mean'].items():
                            f.write(f"| {metric.capitalize():<10} | {value:.4f}    |\n")
                        f.write("\n")
                else:
                    f.write("**PRDC Metrics (Unfiltered):** *Not calculated or failed.*\n\n")

                # ... (existing Unfiltered section plots/metrics) ...

                # --- Add Average Critic Score (Unfiltered) --- <-- NEW
                avg_score_u = analysis_results.get('avg_critic_score_unfiltered')
                if avg_score_u is not None:
                    f.write(f"**Average Critic Score (Unfiltered):** {avg_score_u:.4f}\n\n")
                else:
                    f.write("**Average Critic Score (Unfiltered):** *Not calculated.*\n\n")

                # --- Combined Comparison ---
                f.write("---\n")
                f.write("### Comparison: Filtered vs. Unfiltered\n\n")
                comp_radar_path = os.path.join(config.OUTPUT_DIR, config.GRAPH_SUBDIR, "prdc_radar_comparison.png")
                if os.path.exists(comp_radar_path):
                    rel_path = os.path.relpath(comp_radar_path, report_dir)
                    f.write(f"**PRDC Comparison Radar (Individual Distances):**\n![Radar Comparison](./{rel_path.replace(os.sep, '/')})\n\n")
                elif metrics_f and metrics_u: # Only mention if data was available
                    f.write("**PRDC Comparison Radar:** *Plot not generated.*\n\n")
                else:
                    f.write("*Comparison plot requires successful calculation of both filtered and unfiltered metrics.*\n\n")


            # --- Real Data Statistics ---
            if stats:
                f.write("## Real Data Statistics Used\n\n")
                f.write(f"*Based on {stats.get('metadata', {}).get('num_images_analyzed', 'N/A')} images from `{stats.get('metadata', {}).get('data_path', 'N/A')}`*\n\n")

                if 'size_statistics' in stats:
                    f.write("### Size Distribution\n\n")
                    size_stats = stats['size_statistics']
                    f.write("| Dimension | Mean  | Std Dev | Min | Max |\n")
                    f.write("| :-------- | :---- | :------ | :-: | :-: |\n")
                    for dim in ['width', 'height']:
                        if dim in size_stats:
                            dim_stats = size_stats[dim]
                            f.write(f"| {dim.capitalize():<9} | {dim_stats.get('mean', 0):.1f} | {dim_stats.get('std', 0):.1f}   | {dim_stats.get('min', 0):<3} | {dim_stats.get('max', 0):<3} |\n")
                    f.write("\n")

                if 'pixel_statistics' in stats:
                    f.write("### Pixel Statistics\n\n")
                    pixel_stats = stats['pixel_statistics']
                    f.write("| Property   | Mean  | Std Dev |\n")
                    f.write("| :--------- | :---- | :------ |\n")
                    for prop in ['brightness', 'contrast']:
                        if prop in pixel_stats:
                            prop_stats = pixel_stats[prop]
                            f.write(f"| {prop.capitalize():<10} | {prop_stats.get('mean', 0):.1f} | {prop_stats.get('std', 0):.1f}   |\n")
                    f.write("\n")
            else:
                 f.write("## Real Data Statistics Used\n\n")
                 f.write("*Real data statistics were not calculated or loaded.*\n\n")

            # --- Conclusion ---
            f.write("## Conclusion\n\n")
            f.write("The generation process has completed. Review the analysis results and sample images to assess the quality and diversity of the generated dataset.\n")

            f.write("\n---\n")
            f.write(f"*Report generated by Generation Engine v0.0 on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*\n")

        logger.info(f"Markdown report saved successfully to {output_path}")
        return True

    except Exception as e:
        logger.error(f"Error generating Markdown report: {e}", exc_info=True)
        return False

def _summary_number(value, spec):
    """Format ``value`` with format spec ``spec``; return ``"N/A"`` if it cannot be formatted."""
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        # None raises TypeError, placeholder strings such as 'N/A' raise ValueError.
        return "N/A"


def _summary_prdc_values(metrics):
    """Render the P/R/D/C values of one PRDC metrics dict as a single summary fragment."""
    return (
        f"P={_summary_number(metrics.get('precision', 0), '.3f')}, "
        f"R={_summary_number(metrics.get('recall', 0), '.3f')}, "
        f"D={_summary_number(metrics.get('density', 0), '.4f')}, "
        f"C={_summary_number(metrics.get('coverage', 0), '.3f')}"
    )


def create_performance_summary(config, generation_results, analysis_results, logger):
    """
    Write a plain-text performance summary of a generation run.

    The report is written to ``<OUTPUT_DIR>/<STATS_SUBDIR>/performance_summary.txt``
    (the directory is created if needed) and contains the generation counts,
    throughput, a summary of the analysis metrics, the main configuration
    switches and the peak memory usage.

    Args:
        config: Object exposing the configuration attributes read below
            (OUTPUT_DIR, STATS_SUBDIR, DEVICE, TARGET_* counts and feature flags).
        generation_results: Dict with optional keys 'total_time',
            'pollen_generated', 'composed_generated' and 'peak_memory'.
        analysis_results: Dict of analysis outputs (PRDC, FID/KID, critic
            scores); may be empty or None.
        logger: Logger used for the success / failure message.

    Returns:
        None. Any exception is logged and swallowed, so reporting never
        aborts the surrounding pipeline.

    Numeric fields that are missing, None or non-numeric are rendered as
    "N/A" instead of raising, so one absent value (for example GPU statistics
    on a CPU-only machine) no longer truncates the report half-way through.
    """
    try:
        stats_dir = os.path.join(config.OUTPUT_DIR, config.STATS_SUBDIR)
        os.makedirs(stats_dir, exist_ok=True)

        report_path = os.path.join(stats_dir, "performance_summary.txt")
        with open(report_path, 'w') as f:
            f.write("=" * 60 + "\n")
            f.write(" GENERATION ENGINE PERFORMANCE SUMMARY\n")
            f.write("=" * 60 + "\n\n")

            f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Device: {config.DEVICE}\n")
            f.write(f"Output Directory: {config.OUTPUT_DIR}\n\n")

            f.write("--- GENERATION RESULTS ---\n")
            # `or 0` also covers an explicit None stored under 'total_time'.
            total_time = generation_results.get('total_time', 0) or 0
            pollen_gen = generation_results.get('pollen_generated', 0)
            composed_gen = generation_results.get('composed_generated', 0)
            f.write(f"- Total runtime: {total_time:.3f} seconds ({total_time / 3600:.3f} hours)\n")
            f.write(f"- Individual pollen images generated: {pollen_gen} / {config.TARGET_POLLEN_IMAGES}\n")
            f.write(f"- Composed images generated: {composed_gen} / {config.TARGET_COMPOSED_IMAGES}\n\n")

            # Throughput is only meaningful when a runtime was recorded.
            if total_time > 0:
                f.write("--- PERFORMANCE METRICS ---\n")
                if pollen_gen > 0:
                    f.write(f"- Individual pollen generation rate: {pollen_gen / total_time:.3f} images/second\n")
                if composed_gen > 0:
                    f.write(f"- Composed image generation rate: {composed_gen / total_time:.3f} images/second\n")
                f.write("\n")

            # --- ANALYSIS METRICS SUMMARY ---
            f.write("--- ANALYSIS METRICS SUMMARY ---\n")
            # Only print the detailed lines if at least one analysis result is truthy.
            if analysis_results and any(analysis_results.values()):
                # PRDC summary (filtered). A dict without an 'individual' entry is
                # read as a flat precision/recall/density/coverage mapping.
                metrics_f = analysis_results.get('prdc_metrics_filtered')
                if metrics_f and metrics_f.get('individual'):
                    f.write(f"- Filtered (PRDC Indiv): {_summary_prdc_values(metrics_f['individual'])}\n")
                elif metrics_f:
                    f.write(f"- Filtered (PRDC torch?): {_summary_prdc_values(metrics_f)}\n")
                else:
                    f.write("- Filtered (PRDC):        Not calculated or failed.\n")

                # PRDC summary (unfiltered)
                metrics_u = analysis_results.get('prdc_metrics_unfiltered')
                if metrics_u and metrics_u.get('individual'):
                    f.write(f"- Unfiltered (PRDC Indiv): {_summary_prdc_values(metrics_u['individual'])}\n")
                elif metrics_u:
                    f.write(f"- Unfiltered (PRDC torch?):{_summary_prdc_values(metrics_u)}\n")
                else:
                    f.write("- Unfiltered (PRDC):      Not calculated or failed.\n")

                # FID/KID summary
                fid_kid_f = analysis_results.get('fid_kid_filtered') or {}
                fid_f = _summary_number(fid_kid_f.get('fid'), '.3f')
                kid_f = _summary_number(fid_kid_f.get('kid'), '.3f')
                f.write(f"- Filtered (FID/KID):     FID={fid_f}, KID={kid_f}\n")

                fid_kid_u = analysis_results.get('fid_kid_unfiltered') or {}
                fid_u = _summary_number(fid_kid_u.get('fid'), '.3f')
                kid_u = _summary_number(fid_kid_u.get('kid'), '.3f')
                f.write(f"- Unfiltered (FID/KID):   FID={fid_u}, KID={kid_u}\n")

                # Average critic scores
                score_f_str = _summary_number(analysis_results.get('avg_critic_score_filtered'), '.4f')
                f.write(f"- Filtered Avg Critic:    {score_f_str}\n")

                score_u_str = _summary_number(analysis_results.get('avg_critic_score_unfiltered'), '.4f')
                f.write(f"- Unfiltered Avg Critic:  {score_u_str}\n")
            else:
                f.write("- Analysis metrics were not calculated or analysis was skipped.\n")
            f.write("\n")

            # --- CONFIGURATION SUMMARY ---
            f.write("--- CONFIGURATION SUMMARY ---\n")
            f.write(f"- Quality Filtering: {'Enabled' if config.USE_QUALITY_FILTERING else 'Disabled'} (Threshold: {config.QUALITY_THRESHOLD_PERCENTILE}%)\n")
            f.write(f"- Blending Method: {config.BLENDING_METHOD}\n")
            f.write(f"- Content-aware Masks: {'Enabled' if config.USE_CONTENT_AWARE_MASKS else 'Disabled'}\n")
            f.write(f"- Statistics Matching (Size): {'Enabled' if config.USE_STATS_MATCHING_SIZE else 'Disabled'}\n")
            f.write(f"- Statistics Matching (Hist): {'Enabled' if config.USE_STATS_MATCHING_HISTOGRAM else 'Disabled'}\n")
            f.write(f"- Parallel Processing Used: {'Enabled' if config.USE_PARALLEL_PROCESSING else 'Disabled'} (Workers: {config.NUM_WORKERS})\n\n")

            # --- PEAK MEMORY ---
            peak = generation_results.get('peak_memory')
            if peak is not None:
                f.write("--- PEAK MEMORY USAGE ---\n")
                ram = _summary_number(peak.get('ram_used_percent'), '.1f')
                f.write("- RAM: N/A\n" if ram == "N/A" else f"- RAM: {ram}%\n")

                gpu_perc = _summary_number(peak.get('gpu_used_percent'), '.1f')
                if gpu_perc != "N/A":
                    # Used/total are formatted independently so a partial GPU
                    # record still yields a line instead of an exception.
                    gpu_used = _summary_number(peak.get('gpu_used_gb'), '.3f')
                    gpu_total = _summary_number(peak.get('gpu_total_gb'), '.3f')
                    f.write(f"- GPU: {gpu_perc}% ({gpu_used} / {gpu_total} GB)\n")
                else:
                    f.write("- GPU: N/A\n")
                f.write("\n")

            f.write("=" * 60 + "\n")

        logger.info(f"Performance summary saved to {report_path}")

        # --- Generate Performance Visualization ---
        # NOTE: placeholder kept from the original source; no plotting code
        # (performance_visualization.png) is present in this function.

    except Exception as e:
        logger.error(f"Error creating performance summary: {e}", exc_info=True)

# --- Worker Functions for Parallel Processing ---

def _select_top_scored(scores, num_needed, keep_percentile):
    """Return indices of the best-scoring samples, highest score first.

    Samples scoring below the ``100 - keep_percentile`` percentile of
    ``scores`` are dropped; at most ``num_needed`` of the remaining indices
    are returned, ordered by descending score.
    """
    scores = np.asarray(scores)
    threshold = np.percentile(scores, 100.0 - keep_percentile)
    passing = np.where(scores >= threshold)[0]
    ranked = passing[np.argsort(scores[passing])[::-1]]
    return ranked[:num_needed]


def process_composed_batch(args):
    """Process a batch of composed images (sequential version for Jupyter compatibility).

    For every background path in the batch a pollen layout is drawn, pollen
    images are generated (and optionally filtered by critic score), blended
    onto the background, and the composed image plus its YOLO label file are
    written to disk.

    Args:
        args: 5-tuple ``(batch_idx, config_dict, bg_paths, real_stats, layout_stats)``.
            ``config_dict`` is a plain dict of configuration values; the keys
            read with ``[...]`` below are required and raise ``KeyError`` when
            absent. ``real_stats`` and ``layout_stats`` are passed through to
            the transform and layout helpers unchanged.

    Returns:
        Tuple ``(generated_images, generated_labels, timestamps)``: paths of
        the images written, paths of the label files written, and the
        ``time.time()`` value recorded for each. All three lists are empty
        if the models could not be loaded. Failures on individual backgrounds
        are logged and skipped.
    """
    batch_idx, config_dict, bg_paths, real_stats, layout_stats = args

    # Per-worker logger so messages can be attributed to a batch.
    logger = logging.getLogger(f"GenerationEngine.Worker-{batch_idx}")
    logger.setLevel(logging.INFO)

    # Required settings are read up front so a missing key fails fast,
    # before any model is constructed.
    output_dir = config_dict["OUTPUT_DIR"]
    composed_subdir = config_dict["COMPOSED_SUBDIR"]
    labels_subdir = config_dict["LABELS_SUBDIR"]
    chunk_size = config_dict["CHUNK_SIZE"]
    bg_size = config_dict["BG_SIZE"]
    pollen_size_base = config_dict["POLLEN_SIZE_BASE"]
    margin = config_dict["MARGIN"]
    surplus_factor = config_dict.get("FILTERING_SURPLUS_FACTOR", 1.25)
    use_quality_filtering = config_dict["USE_QUALITY_FILTERING"]
    object_class = config_dict["OBJECT_CLASS"]

    # The dict carries the device as a name; turn it into a torch.device object.
    device_name = config_dict.get("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
    try:
        device = torch.device(device_name)
    except Exception as e:
        logger.error(f"Failed to create torch.device from name '{device_name}': {e}. Defaulting to CPU.")
        device = torch.device("cpu")

    composed_dir = os.path.join(output_dir, composed_subdir)
    labels_dir = os.path.join(output_dir, labels_subdir)

    # Model hyper-parameters and checkpoint location.
    noise_dim = config_dict["NOISE_DIM"]
    channels_img = config_dict["CHANNELS_IMG"]
    g_features = config_dict["G_FEATURES"]
    c_features = config_dict["C_FEATURES"]
    checkpoint_path = config_dict["CHECKPOINT_PATH"]

    generator = Generator(
        noise_dim=noise_dim,
        channels_img=channels_img,
        features_g=g_features
    ).to(device)

    critic = None
    if use_quality_filtering:
        critic = Critic(
            channels_img=channels_img,
            features_c=c_features
        ).to(device)

    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        if 'generator_state_dict' in checkpoint:
            generator.load_state_dict(checkpoint['generator_state_dict'])
        elif 'model_state_dict' in checkpoint:
            generator.load_state_dict(checkpoint['model_state_dict'])
        else:
            raise KeyError("Generator state dict not found")

        if use_quality_filtering and critic is not None:
            # A missing or incompatible critic only disables filtering; the
            # generator alone is enough to produce images.
            try:
                if 'critic_state_dict' in checkpoint:
                    critic.load_state_dict(checkpoint['critic_state_dict'])
                else:
                    logger.warning("No critic_state_dict found, quality filtering disabled in worker")
                    use_quality_filtering = False
                    critic = None
            except Exception as e:
                logger.warning(f"Failed to load critic in worker: {e}")
                use_quality_filtering = False
                critic = None

        # The weights now live in the modules; drop the checkpoint immediately
        # so it does not occupy memory on `device` for the whole batch.
        del checkpoint

        generator.eval()
        if critic is not None:
            critic.eval()
        logger.info(f"Worker {batch_idx}: Models loaded to {device}")

    except Exception as e:
        logger.error(f"Worker {batch_idx}: Error loading models: {e}")
        return [], [], []

    generated_images = []
    generated_labels = []
    timestamps = []

    # Minimal Config-like object for the helper functions that expect
    # attribute access instead of a dict.
    class SimpleConfig:
        def __init__(self):
            self.USE_QUALITY_FILTERING = use_quality_filtering
            self.USE_CONTINUOUS_ROTATION = False
            self.USE_RANDOM_SCALES = config_dict.get("USE_RANDOM_SCALES", True)
            self.USE_STATS_MATCHING_SIZE = config_dict.get("USE_STATS_MATCHING_SIZE", True)
            self.SCALE_RANGE = config_dict.get("SCALE_RANGE", (0.75, 1.25))
            self.BLENDING_METHOD = config_dict.get("BLENDING_METHOD", "poisson")
            self.USE_CONTENT_AWARE_MASKS = config_dict.get("USE_CONTENT_AWARE_MASKS", True)
            self.FEATHER_AMOUNT = config_dict.get("FEATHER_AMOUNT", 15)
            self.PYRAMID_LEVELS = config_dict.get("PYRAMID_LEVELS", 4)
            self.DEVICE = device  # torch.device object, not the name
            self.NOISE_DIM = noise_dim
            self.BATCH_SIZE = config_dict.get("BATCH_SIZE", 64)
            self.SCORING_BATCH_SIZE = config_dict.get("SCORING_BATCH_SIZE", 128)
            self.POLLEN_SIZE_BASE = pollen_size_base
            self.FILTERING_SURPLUS_FACTOR = surplus_factor
            self.QUALITY_THRESHOLD_PERCENTILE = config_dict.get("QUALITY_THRESHOLD_PERCENTILE", 80.0)
            self.LOAD_CRITIC_TO_CPU = config_dict.get("LOAD_CRITIC_TO_CPU", False)

    mini_config = SimpleConfig()

    def compose_one(global_idx, bg_path):
        """Compose one background; return (img_path, label_path), or None if no layout fits.

        All per-image temporaries are locals of this helper, so they are
        released on return whether the image succeeded, was skipped or failed.
        """
        # Close the file handle deterministically; convert() yields a detached copy.
        with Image.open(bg_path) as src:
            bg_img = src.convert('L')
        if bg_img.size != (bg_size, bg_size):
            bg_img = bg_img.resize((bg_size, bg_size), Image.LANCZOS)
        bg_img_rgb = convert_to_3channel(bg_img)

        pollen_count = determine_pollen_count(layout_stats, config_dict.get("AVG_POLLEN_PER_IMAGE", 14))
        positions = generate_non_overlapping_positions(pollen_count, bg_size, pollen_size_base, margin)
        if not positions:
            logger.warning(f"Worker {batch_idx}: Could not generate valid positions for image {global_idx}")
            return None

        # Over-generate when filtering so enough samples survive the cut.
        num_needed = len(positions)
        num_to_generate = num_needed
        if use_quality_filtering:
            num_to_generate = int(np.ceil(num_needed * surplus_factor))

        pollen_batch_raw = generate_pollen_batch(generator, mini_config, num_to_generate)

        if use_quality_filtering and critic is not None:
            scores_raw = score_pollen_batch(critic, pollen_batch_raw, mini_config)
            if scores_raw is not None and len(scores_raw) == len(pollen_batch_raw):
                selected = _select_top_scored(
                    scores_raw, num_needed, mini_config.QUALITY_THRESHOLD_PERCENTILE
                )
                pollen_to_place = [pollen_batch_raw[idx] for idx in selected]
            else:
                logger.warning(f"Worker {batch_idx}: Scoring failed for bg {global_idx}, using raw images.")
                pollen_to_place = pollen_batch_raw[:num_needed]
        else:
            pollen_to_place = pollen_batch_raw[:num_needed]

        # Filtering may leave fewer images than planned positions.
        if len(pollen_to_place) < len(positions):
            positions = positions[:len(pollen_to_place)]

        result_img = bg_img_rgb.copy()
        placed = []  # (position, size of the transformed pollen image)
        for pos, pollen_img_raw in zip(positions, pollen_to_place):
            transformed_pollen = apply_geometric_transforms(pollen_img_raw, mini_config, real_stats)
            result_img = blend_pollen_onto_background(result_img, transformed_pollen, pos, mini_config)
            placed.append((pos, transformed_pollen.size))

        bg_name = os.path.splitext(os.path.basename(bg_path))[0]
        stem = f"synthetic_{bg_name}_{global_idx+1:04d}"
        img_path = os.path.join(composed_dir, f"{stem}.png")
        label_path = os.path.join(labels_dir, f"{stem}.txt")

        result_img.save(img_path)

        annotation = generate_yolo_annotation(placed, bg_size, object_class)
        with open(label_path, 'w') as f:
            f.write(annotation)

        return img_path, label_path

    for i, bg_path in enumerate(bg_paths):
        # Computed outside the try block so the error handler can always report it.
        global_idx = batch_idx * chunk_size + i
        try:
            outcome = compose_one(global_idx, bg_path)
        except Exception as e:
            logger.error(f"Worker {batch_idx}: Error processing background {bg_path} (global idx {global_idx}): {e}", exc_info=True)
            outcome = None

        if outcome is not None:
            img_path, label_path = outcome
            generated_images.append(img_path)
            generated_labels.append(label_path)
            timestamps.append(time.time())

        # Periodic (not per-image) collection keeps the overhead low.
        if i % 5 == 0:
            gc.collect()
            if device.type == 'cuda':
                torch.cuda.empty_cache()

    # Final cleanup for the worker
    del generator, critic, mini_config
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    return generated_images, generated_labels, timestamps

# --- Main Function ---
def generate_synthetic_dataset(config, logger):
    """Main function to generate the synthetic dataset."""
    try:
        start_time = time.time()
        logger.info("Starting synthetic dataset generation")
        
        # Setup memory monitoring if enabled
        memory_monitor_thread = None
        stop_monitor_event = threading.Event()
        
        if config.MONITOR_MEMORY:
            memory_monitor_thread = threading.Thread(
                target=memory_monitor,
                args=(config, logger, stop_monitor_event),
                daemon=True
            )
            memory_monitor_thread.start()
        
        # Create output directories
        logger.info("Creating output directories")
        output_dirs = create_output_directories(config)

        # --- NEW: Optional Cleanup ---
        if config.CLEAN_OUTPUT_BEFORE_RUN:
            logger.warning("CLEAN_OUTPUT_BEFORE_RUN is True. Deleting existing content in output subdirectories...")
            dirs_to_clean = [
                output_dirs['composed'],
                # Add others if needed (pollen? labels? graphs? stats?) but be careful
            ]
            for dir_path in dirs_to_clean:
                if os.path.exists(dir_path):
                    try:
                        # Remove contents, not the directory itself
                        for filename in os.listdir(dir_path):
                            file_path = os.path.join(dir_path, filename)
                            if os.path.isfile(file_path) or os.path.islink(file_path):
                                os.unlink(file_path)
                            elif os.path.isdir(file_path):
                                shutil.rmtree(file_path)
                        logger.info(f"Cleaned directory: {dir_path}")
                    except Exception as e:
                        logger.error(f"Failed to clean directory {dir_path}: {e}")
                else:
                     logger.warning(f"Directory not found for cleaning: {dir_path}")
        # --- End Cleanup ---
        
        # Load models
        logger.info("Loading generator and critic models")
        generator, critic, _ = load_models(config, logger)
        
        # Calculate or load real data statistics
        real_stats = None
        if config.CALCULATE_REAL_STATS:
            logger.info("Calculating statistics from real data")
            real_stats = calculate_real_data_stats(config, logger)
        
        # Load layout statistics
        layout_stats = None
        if config.GENERATE_COMPOSED_IMAGES:
            logger.info("Loading layout statistics")
            layout_stats = load_layout_statistics(config, logger)
        
        # Track results
        results = {
            'pollen_generated': 0,
            'composed_generated': 0,
            'pollen_timestamps': [],
            'composed_timestamps': [],
            'pollen_sizes': [],
            'peak_memory': get_memory_usage(),
            # Initialize these to None so they're always in the results dictionary
            'real_features_full': None,
            'filtered_gen_features_full': None, 
            'unfiltered_gen_features_full': None,
            'analysis_results': {}
        }

        # ---> INITIALIZATION HERE <---
        real_features_full = None
        filtered_gen_features_full = None
        unfiltered_gen_features_full = None
        analysis_results = {} # Also initialize the analysis results dict
        # ---> END INITIALIZATION <---
        
        all_pollen_paths = []
        all_composed_paths = []
        all_label_paths = []
        
        # --- Generate individual pollen images ---
        all_pollen_data = [] # <-- NEW: Will store {'path': path, 'score': score} dicts
        if config.GENERATE_INDIVIDUAL_POLLEN:
            logger.info(f"Generating {config.TARGET_POLLEN_IMAGES} individual pollen images")
            pollen_dir = output_dirs['pollen']
    
            pollen_generated_count = 0
            with tqdm(total=config.TARGET_POLLEN_IMAGES, desc="Generating pollen") as pbar:
                while pollen_generated_count < config.TARGET_POLLEN_IMAGES:
                    # Check memory periodically
                    if pollen_generated_count > 0 and pollen_generated_count % (config.BATCH_SIZE * config.MEMORY_CHECK_INTERVAL) == 0:
                        memory_stats = get_memory_usage()
                        results['peak_memory'] = max_memory_stats(results['peak_memory'], memory_stats)
                        if not check_memory_safe(config, logger):
                            logger.warning("Memory usage high during pollen generation, performing cleanup")
                            force_memory_cleanup(config)
    
                    # Determine batch size for this iteration
                    num_needed_in_loop = config.TARGET_POLLEN_IMAGES - pollen_generated_count
                    batch_size_actual = min(config.BATCH_SIZE, num_needed_in_loop)
    
                    # Determine number to generate initially (might be more if filtering)
                    num_to_generate_raw = batch_size_actual
                    if config.USE_QUALITY_FILTERING:
                        # Calculate how many raw images we need to generate to likely get 'batch_size_actual' after filtering
                        # Example: If threshold is 80% (keep top 80%), we need 1/0.8 = 1.25x
                        # Use FILTERING_SURPLUS_FACTOR for consistency
                        num_to_generate_raw = int(np.ceil(batch_size_actual * config.FILTERING_SURPLUS_FACTOR))
                        # Ensure we generate at least batch_size_actual + a few extra if factor is small
                        num_to_generate_raw = max(num_to_generate_raw, batch_size_actual + 2)
    
                    # Generate raw 128x128 batch
                    raw_pollen_batch = generate_pollen_batch(generator, config, num_to_generate_raw)
    
                    # Score the RAW batch
                    raw_scores = None
                    if config.USE_QUALITY_FILTERING and critic is not None:
                         raw_scores = score_pollen_batch(critic, raw_pollen_batch, config)
                    elif critic is not None: # Score even if not filtering, to store the score
                         logger.debug("Scoring raw batch even though filtering is off.")
                         raw_scores = score_pollen_batch(critic, raw_pollen_batch, config)
                    else:
                         logger.warning("Critic not available, cannot score generated images.")
                         raw_scores = np.zeros(len(raw_pollen_batch)) # Assign dummy scores if critic missing
    
                    images_to_process = [] # List of tuples: (raw_img, score)
                    if config.USE_QUALITY_FILTERING and critic is not None:
                        # --- Filtering Logic ---
                        # Calculate threshold based on percentile of scores IN THIS BATCH
                        threshold = np.percentile(raw_scores, 100.0 - config.QUALITY_THRESHOLD_PERCENTILE)
                        # Get indices and scores of images above threshold
                        quality_indices = np.where(raw_scores >= threshold)[0]
                        # Sort by score (highest first)
                        sorted_quality_indices = quality_indices[np.argsort(raw_scores[quality_indices])[::-1]]
                        # Take the top 'batch_size_actual' needed for the target count
                        selected_indices = sorted_quality_indices[:batch_size_actual]
    
                        # Add the selected raw images and their scores to the list to process
                        for idx in selected_indices:
                            images_to_process.append((raw_pollen_batch[idx], raw_scores[idx]))
                        logger.debug(f"Batch Filtering: Kept {len(selected_indices)}/{num_to_generate_raw} images.")
    
                    else:
                        # --- No Filtering ---
                        # Process the first 'batch_size_actual' generated raw images
                        for i in range(min(batch_size_actual, len(raw_pollen_batch))):
                             images_to_process.append((raw_pollen_batch[i], raw_scores[i]))
    
                    # Process (transform, save) and store data for the selected images
                    for raw_img, score in images_to_process:
                        if pollen_generated_count >= config.TARGET_POLLEN_IMAGES:
                            break # Stop if target reached mid-batch
    
                        # Apply geometric transforms to the raw image
                        transformed_img = apply_geometric_transforms(raw_img, config, real_stats)
    
                        # Save the TRANSFORMED image
                        filename = f"pollen_synthetic_{pollen_generated_count+1:06d}.png"
                        save_path = os.path.join(pollen_dir, filename)
                        transformed_img.save(save_path)
    
                        # Store path and the score of the corresponding RAW image
                        all_pollen_data.append({
                            'path': save_path,
                            'score': float(score) if score is not None else None # Store score as float
                        })
    
                        # Track results
                        pollen_generated_count += 1
                        results['pollen_timestamps'].append(time.time())
                        # Storing transformed size might still be useful for other stats
                        results['pollen_sizes'].append(transformed_img.size)
                        pbar.update(1)
    
                    # Memory cleanup after each raw batch generation/scoring/filtering
                    del raw_pollen_batch, raw_scores, images_to_process
                    force_memory_cleanup(config)
    
                    if pollen_generated_count >= config.TARGET_POLLEN_IMAGES:
                        break # Exit outer loop if target reached
    
            results['pollen_generated'] = pollen_generated_count
            logger.info(f"Generated {pollen_generated_count} individual pollen images and stored their raw scores.")
    
        # --- Generate composed images ---
        if config.GENERATE_COMPOSED_IMAGES:
            logger.info(f"Generating {config.TARGET_COMPOSED_IMAGES} composed images")
            
            # Load background images
            bg_paths = load_background_images(config, logger)
            
            # If we need more than available, use backgrounds multiple times
            if len(bg_paths) < config.TARGET_COMPOSED_IMAGES:
                multiplier = config.TARGET_COMPOSED_IMAGES // len(bg_paths) + 1
                bg_paths = bg_paths * multiplier
            
            # Shuffle backgrounds for variety
            random.shuffle(bg_paths)
            bg_paths = bg_paths[:config.TARGET_COMPOSED_IMAGES]
            
            # Process sequentially (safer in Jupyter)
            config.USE_PARALLEL_PROCESSING = False  # Force sequential processing
            logger.info("Processing composed images sequentially (safer in Jupyter)")
            
            # Convert config to dict
            config_dict = {k: v for k, v in vars(config).items() if k != 'DEVICE'}
            
            composed_count = 0
            all_composed_paths = []
            all_label_paths = []
            
            # Prepare backgrounds in chunks
            chunks = []
            chunk_size = config.CHUNK_SIZE
            for i in range(0, len(bg_paths), chunk_size):
                chunks.append(bg_paths[i:i+chunk_size])
            
            # Process each chunk sequentially
            for chunk_idx, bg_chunk in enumerate(tqdm(chunks, desc="Processing composed batches")):
                # Use process_composed_batch but call it directly
                img_paths, label_paths, timestamps = process_composed_batch(
                    (chunk_idx, config_dict, bg_chunk, real_stats, layout_stats)
                )
                
                all_composed_paths.extend(img_paths)
                all_label_paths.extend(label_paths)
                results['composed_timestamps'].extend(timestamps)
                composed_count += len(img_paths)
                
                # Update progress
                logger.info(f"Generated {composed_count}/{config.TARGET_COMPOSED_IMAGES} composed images")
                
                # Enforce memory cleanup between chunks
                force_memory_cleanup()
            
            # Update results
            results['composed_generated'] = composed_count
            logger.info(f"Generated {composed_count} composed images")
        
        """
        # --- Analysis and Visualization ---
        """

        # --- Analysis and Visualization ---

        # Log the visualization flags and sample sizes before deciding whether to run the analysis
        logger.info(f"DEBUG: Analysis section - VISUALIZE_TSNE={config.VISUALIZE_TSNE}, VISUALIZE_UMAP={config.VISUALIZE_UMAP}")
        logger.info(f"DEBUG: VISUALIZATION_SAMPLE_SIZE={config.VISUALIZATION_SAMPLE_SIZE}, ANALYSIS_SAMPLE_SIZE={config.ANALYSIS_SAMPLE_SIZE}")

        #
        # --- Analysis and Visualization ---
        if config.PERFORM_ANALYSIS and (all_pollen_data or all_composed_paths): # <-- Use all_pollen_data
            logger.info("="*40 + " STARTING ANALYSIS " + "="*40)
            analysis_start_time = time.time()
            graph_dir = output_dirs['graphs']
            stats_dir = output_dirs['stats'] # For temp files used by torch-fidelity
            os.makedirs(graph_dir, exist_ok=True)
            os.makedirs(stats_dir, exist_ok=True)
        
            # Track analysis results separately
            analysis_results = {
                "prdc_metrics_filtered": None,
                "prdc_metrics_unfiltered": None,
                "fid_kid_filtered": None,
                "fid_kid_unfiltered": None,
                "avg_critic_score_filtered": None,
                "avg_critic_score_unfiltered": None
            }
        
            # --- Common Setup for Analysis ---
            inception_model = None
            real_features_full = None
            prdc_possible = NEIGHBORS_AVAILABLE or TORCH_FIDELITY_AVAILABLE
            fid_kid_possible = FID_AVAILABLE and SKLEARN_KERNELS_AVAILABLE
            critic_available_for_scoring = critic is not None # Critic needs to exist for unfiltered scoring
        
            # Check if any analysis requiring features is possible and enabled
            feature_analysis_needed = (config.MEASURE_PRDC and prdc_possible) or \
                                      ((config.MEASURE_FID or config.MEASURE_KID) and fid_kid_possible)
        
            analysis_possible = feature_analysis_needed or (config.USE_QUALITY_FILTERING and critic_available_for_scoring)
        
            if feature_analysis_needed:
                if FID_AVAILABLE:
                    logger.info("Loading InceptionV3 model for analysis...")
                    try:
                        inception_model = get_inception_model(config)
                    except Exception as e:
                        logger.error(f"Failed to load InceptionV3 model: {e}. Skipping feature-based analysis (FID/KID/PRDC).")
                        prdc_possible = False
                        fid_kid_possible = False
                        feature_analysis_needed = False # Can't do feature analysis
                else:
                    logger.warning("InceptionV3 (pytorch-fid) not available. Skipping feature-based analysis (FID/KID/PRDC).")
                    prdc_possible = False
                    fid_kid_possible = False
                    feature_analysis_needed = False
        
            # Extract real features only if needed for feature-based metrics
            if feature_analysis_needed: # Only extract if we can and need to calculate feature metrics
                logger.info("Extracting features from real images for comparison...")
                real_data_path = config.get_real_data_path()
                real_image_paths = []
                for ext in ['.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff']:
                    real_image_paths.extend(list(Path(real_data_path).glob(f"*{ext}")))
        
                if real_image_paths:
                    metric_sample_size = min(config.ANALYSIS_SAMPLE_SIZE, len(real_image_paths))
                    if metric_sample_size < 100:
                         logger.warning(f"Not enough real images ({metric_sample_size}) for reliable analysis. Need >= 100. Skipping feature-based analysis.")
                         feature_analysis_needed = False
                         prdc_possible = False
                         fid_kid_possible = False
                    else:
                        random.seed(config.RANDOM_SEED)
                        np.random.seed(config.RANDOM_SEED)
                        real_metric_paths = random.sample(real_image_paths, metric_sample_size)
                        logger.info(f"Using {len(real_metric_paths)} real image samples for metric calculation.")
                        real_features_full = extract_features(real_metric_paths, inception_model, config, is_real=True)
        
                        if real_features_full is None:
                            logger.error("Failed to extract real features. Aborting feature-based analysis.")
                            feature_analysis_needed = False
                            prdc_possible = False
                            fid_kid_possible = False
                else:
                    logger.warning("No real images found for analysis comparison. Aborting feature-based analysis.")
                    feature_analysis_needed = False
                    prdc_possible = False
                    fid_kid_possible = False
        
            # <<< Start of the main analysis block >>>
            if analysis_possible:
        
                # --- Analysis 1: Using Saved (Potentially Filtered) Generated Pollen ---
                logger.info("--- Analysis 1: Saved (Potentially Filtered) Generated Images ---")
                filtered_gen_features_full = None # Initialize
                if not all_pollen_data:
                    logger.warning("No saved pollen data found (paths/scores), skipping filtered analysis.")
                else:
                    metric_gen_samples_f = min(config.ANALYSIS_SAMPLE_SIZE, len(all_pollen_data))
                    if metric_gen_samples_f < 100:
                        logger.warning(f"Not enough saved generated samples ({metric_gen_samples_f}) for filtered analysis (minimum 100 required).")
                    else:
                        random.seed(config.RANDOM_SEED)
                        np.random.seed(config.RANDOM_SEED)
                        # Sample the list of dictionaries
                        filtered_sample_data = random.sample(all_pollen_data, metric_gen_samples_f)
        
                        # --- Calculate Average Critic Score (Filtered) --- # <-- MODIFIED Block
                        # Extract scores from the sampled data
                        filtered_scores_sample = [item['score'] for item in filtered_sample_data if item['score'] is not None]
                        if filtered_scores_sample:
                            avg_score_f = np.mean(filtered_scores_sample)
                            analysis_results['avg_critic_score_filtered'] = float(avg_score_f)
                            logger.info(f"Average Critic Score (Filtered Sample): {avg_score_f:.4f}")
                        elif config.USE_QUALITY_FILTERING: # Only warn if filtering was on
                            logger.warning("No valid critic scores found for the filtered sample.")
                        # --- End Critic Score Block ---
        
                        # Extract features if needed for other metrics
                        if feature_analysis_needed and inception_model is not None:
                            # Extract paths from the sampled data
                            filtered_metric_paths = [item['path'] for item in filtered_sample_data]
                            logger.info(f"Extracting features from {len(filtered_metric_paths)} saved/filtered generated images for metrics...")
                            filtered_gen_features_full = extract_features(filtered_metric_paths, inception_model, config, is_real=False)
        
                            if filtered_gen_features_full is not None and real_features_full is not None:
                                # --- Subsample for VISUALIZATION (t-SNE/UMAP) ---
                                vis_samples_per_set = config.VISUALIZATION_SAMPLE_SIZE // 2
                                n_real_vis = min(vis_samples_per_set, len(real_features_full))
                                n_gen_vis = min(vis_samples_per_set, len(filtered_gen_features_full))
        
                                if n_real_vis > 0 and n_gen_vis > 0:
                                    # (Visualization logic remains the same - uses features)
                                    np.random.seed(config.RANDOM_SEED)
                                    real_vis_indices = np.random.choice(len(real_features_full), n_real_vis, replace=False)
                                    gen_vis_indices = np.random.choice(len(filtered_gen_features_full), n_gen_vis, replace=False)
                                    real_features_vis = real_features_full[real_vis_indices]
                                    filtered_gen_features_vis = filtered_gen_features_full[gen_vis_indices]
                                    logger.info(f"Subsampling {n_real_vis} real and {n_gen_vis} generated features for visualization.")
        
                                    # Generate t-SNE visualization (Filtered)
                                    if config.VISUALIZE_TSNE and TSNE_AVAILABLE:
                                        tsne_path_f = os.path.join(graph_dir, "tsne_visualization_filtered.png")
                                        logger.info(f"Generating t-SNE plot (Filtered): {tsne_path_f}")
                                        generate_tsne_plot(real_features_vis, filtered_gen_features_vis, "Real vs. Generated (Filtered)", tsne_path_f, config)
        
                                    # Generate UMAP visualization (Filtered)
                                    if config.VISUALIZE_UMAP and HAS_UMAP:
                                        umap_path_f = os.path.join(graph_dir, "umap_visualization_filtered.png")
                                        logger.info(f"Generating UMAP plot (Filtered): {umap_path_f}")
                                        generate_umap_plot(real_features_vis, filtered_gen_features_vis, "Real vs. Generated (Filtered)", umap_path_f, config)
        
                                    del real_features_vis, filtered_gen_features_vis, real_vis_indices, gen_vis_indices
                                else:
                                    logger.warning("Not enough features to create filtered visualizations.")
        
                                # --- Calculate METRICS (PRDC, FID, KID) ---
                                # (Metric calculation logic remains the same - uses features)
                                min_f_samples_metric = min(len(real_features_full), len(filtered_gen_features_full))
                                if min_f_samples_metric < 100:
                                     logger.warning(f"Too few samples ({min_f_samples_metric}) for reliable filtered metric calculation.")
                                else:
                                    real_features_f_metric = real_features_full[:min_f_samples_metric]
                                    filtered_gen_features_f_metric = filtered_gen_features_full[:min_f_samples_metric]
        
                                    # Calculate PRDC (Filtered)
                                    if config.MEASURE_PRDC and prdc_possible:
                                        if min_f_samples_metric < config.MANIFOLD_K + 1 and NEIGHBORS_AVAILABLE:
                                            logger.warning(f"Too few samples ({min_f_samples_metric}) for custom PRDC (k={config.MANIFOLD_K}). Skipping filtered PRDC.")
                                        else:
                                            logger.info(f"Calculating PRDC metrics using {min_f_samples_metric} real vs {min_f_samples_metric} filtered generated features.")
                                            prdc_metrics_f = calculate_prdc_metrics(real_features_f_metric, filtered_gen_features_f_metric, config)
                                            analysis_results['prdc_metrics_filtered'] = prdc_metrics_f
                                            radar_path_f = os.path.join(graph_dir, "prdc_radar_filtered.png")
                                            logger.info(f"Generating PRDC radar chart (Filtered): {radar_path_f}")
                                            generate_radar_chart(prdc_metrics_f, "Distribution Metrics (Filtered - Individual)", radar_path_f)
                                    else:
                                         logger.info("Skipping PRDC metric calculation (Filtered) as per config or dependencies.")
        
                                    # Calculate FID/KID (Filtered)
                                    if (config.MEASURE_FID or config.MEASURE_KID) and fid_kid_possible:
                                         logger.info(f"Calculating FID/KID using {min_f_samples_metric} real vs {min_f_samples_metric} filtered generated features.")
                                         fid_kid_f = calculate_fid_kid(real_features_f_metric, filtered_gen_features_f_metric, config, logger)
                                         analysis_results['fid_kid_filtered'] = fid_kid_f
                                    else:
                                         logger.info("Skipping FID/KID calculation (Filtered) as per config or dependencies.")
        
                                    del real_features_f_metric, filtered_gen_features_f_metric
        
                            else: # If feature extraction failed
                                 logger.warning("Feature extraction failed for filtered generated images. Skipping feature-based metrics.")
        
                        analysis_results["filtered_features_computed"] = filtered_gen_features_full is not None
                        # Add features to the main results dict to return them
                        results['analysis_results'] = analysis_results # Keep analysis metrics together
                        results['filtered_gen_features_full'] = filtered_gen_features_full

                        # Cleanup full features for this part if they exist
                        if filtered_gen_features_full is not None:
                            del filtered_gen_features_full
                        del filtered_sample_data # <-- NEW cleanup for sampled data
                        force_memory_cleanup(config) # Cleanup after Analysis 1
                        #
                        # --- CLEANUP POINT 1: After filtered analysis ---
                        #  (redacted this code block to avoid issues)
        
                # --- Analysis 2: Using Unfiltered Generated Pollen ---
                logger.info("--- Analysis 2: Unfiltered Generated Images ---")
                unfiltered_gen_features_full = None # Initialize
                num_unfiltered_samples_metric = config.ANALYSIS_SAMPLE_SIZE
                logger.info(f"Generating {num_unfiltered_samples_metric} new unfiltered images for analysis...")
        
                # (Generator device handling logic remains the same)
                gen_device = next(generator.parameters()).device
                temp_device_change = False
                if gen_device != config.DEVICE:
                    logger.info(f"Temporarily moving generator to {config.DEVICE} for unfiltered generation.")
                    generator.to(config.DEVICE)
                    temp_device_change = True
        
                unfiltered_images = [] # This list will hold raw 128x128 PIL images
                try:
                    batches_needed = math.ceil(num_unfiltered_samples_metric / config.BATCH_SIZE)
                    for _ in tqdm(range(batches_needed), desc="Generating unfiltered images"):
                        batch_size = min(config.BATCH_SIZE, num_unfiltered_samples_metric - len(unfiltered_images))
                        if batch_size <= 0: break
                        new_batch = generate_pollen_batch(generator, config, batch_size)
                        unfiltered_images.extend(new_batch)
                        if len(unfiltered_images) % (config.BATCH_SIZE * 5) == 0: force_memory_cleanup(config)
                except Exception as gen_err:
                     logger.error(f"Error during unfiltered image generation: {gen_err}", exc_info=True)
                finally:
                     if temp_device_change:
                         logger.info(f"Moving generator back to {gen_device}.")
                         generator.to(gen_device)
                         force_memory_cleanup(config)
        
                unfiltered_images = unfiltered_images[:num_unfiltered_samples_metric]
                logger.info(f"Generated {len(unfiltered_images)} unfiltered images.")
        
                if len(unfiltered_images) < 100:
                    logger.warning(f"Generated fewer than 100 unfiltered images ({len(unfiltered_images)}). Skipping unfiltered analysis.")
                else:
                    # --- Calculate Average Critic Score (Unfiltered) --- # <-- Keep this block
                    if critic_available_for_scoring:
                        logger.info(f"Calculating average critic score for {len(unfiltered_images)} unfiltered samples...")
                        try:
                            # Score the raw unfiltered images
                            unfiltered_scores = score_pollen_batch(critic, unfiltered_images, config)
                            if unfiltered_scores is not None and len(unfiltered_scores) > 0:
                                 avg_score_u = np.mean(unfiltered_scores)
                                 analysis_results['avg_critic_score_unfiltered'] = float(avg_score_u)
                                 logger.info(f"Average Critic Score (Unfiltered): {avg_score_u:.4f}")
                            else:
                                 logger.warning("No scores calculated for unfiltered images.")
                            del unfiltered_scores
                        except Exception as score_err:
                             logger.error(f"Error calculating critic scores for unfiltered images: {score_err}", exc_info=True)
                        finally:
                             force_memory_cleanup(config)
                    else:
                         logger.info("Skipping critic score calculation (Unfiltered) as critic is not available.")
                    # --- End Critic Score Block ---
        
                    # Extract features if needed
                    if feature_analysis_needed and inception_model is not None:
                        logger.info(f"Extracting features from {len(unfiltered_images)} unfiltered generated images for metrics...")
                        # Pass the raw 128x128 PIL images; extract_features handles resizing
                        unfiltered_gen_features_full = extract_features(unfiltered_images, inception_model, config, is_real=False)
        
                    # Cleanup PIL images (do this AFTER feature extraction AND scoring)
                    logger.debug("Cleaning up unfiltered PIL images from memory.")
                    del unfiltered_images
                    force_memory_cleanup(config)
        
                    if unfiltered_gen_features_full is not None and real_features_full is not None:
                        # --- Subsample for VISUALIZATION ---
                        # (Visualization logic remains the same)
                        vis_samples_per_set = config.VISUALIZATION_SAMPLE_SIZE // 2
                        n_real_vis = min(vis_samples_per_set, len(real_features_full))
                        n_gen_vis = min(vis_samples_per_set, len(unfiltered_gen_features_full))
                        if n_real_vis > 0 and n_gen_vis > 0:
                            np.random.seed(config.RANDOM_SEED)
                            real_vis_indices = np.random.choice(len(real_features_full), n_real_vis, replace=False)
                            gen_vis_indices = np.random.choice(len(unfiltered_gen_features_full), n_gen_vis, replace=False)
                            real_features_vis = real_features_full[real_vis_indices]
                            unfiltered_gen_features_vis = unfiltered_gen_features_full[gen_vis_indices]
                            logger.info(f"Subsampling {n_real_vis} real and {n_gen_vis} generated features for visualization.")
                            # t-SNE Unfiltered
                            if config.VISUALIZE_TSNE and TSNE_AVAILABLE:
                                tsne_path_u = os.path.join(graph_dir, "tsne_visualization_unfiltered.png")
                                logger.info(f"Generating t-SNE plot (Unfiltered): {tsne_path_u}")
                                generate_tsne_plot(real_features_vis, unfiltered_gen_features_vis, "Real vs. Generated (Unfiltered)", tsne_path_u, config)
                            # UMAP Unfiltered
                            if config.VISUALIZE_UMAP and HAS_UMAP:
                                umap_path_u = os.path.join(graph_dir, "umap_visualization_unfiltered.png")
                                logger.info(f"Generating UMAP plot (Unfiltered): {umap_path_u}")
                                generate_umap_plot(real_features_vis, unfiltered_gen_features_vis, "Real vs. Generated (Unfiltered)", umap_path_u, config)
                            del real_features_vis, unfiltered_gen_features_vis, real_vis_indices, gen_vis_indices
                        else:
                             logger.warning("Not enough features to create unfiltered visualizations.")
        
                        # --- Calculate METRICS (PRDC, FID, KID) ---
                        # (Metric calculation logic remains the same)
                        min_u_samples_metric = min(len(real_features_full), len(unfiltered_gen_features_full))
                        if min_u_samples_metric < 100:
                             logger.warning(f"Too few samples ({min_u_samples_metric}) for reliable unfiltered metric calculation.")
                        else:
                            real_features_u_metric = real_features_full[:min_u_samples_metric]
                            unfiltered_gen_features_u_metric = unfiltered_gen_features_full[:min_u_samples_metric]
                            # PRDC Unfiltered
                            if config.MEASURE_PRDC and prdc_possible:
                                if min_u_samples_metric < config.MANIFOLD_K + 1 and NEIGHBORS_AVAILABLE:
                                    logger.warning(f"Too few samples ({min_u_samples_metric}) for custom PRDC (k={config.MANIFOLD_K}). Skipping unfiltered PRDC.")
                                else:
                                    logger.info(f"Calculating PRDC metrics using {min_u_samples_metric} real vs {min_u_samples_metric} unfiltered generated features.")
                                    prdc_metrics_u = calculate_prdc_metrics(real_features_u_metric, unfiltered_gen_features_u_metric, config)
                                    analysis_results['prdc_metrics_unfiltered'] = prdc_metrics_u
                                    radar_path_u = os.path.join(graph_dir, "prdc_radar_unfiltered.png")
                                    logger.info(f"Generating PRDC radar chart (Unfiltered): {radar_path_u}")
                                    generate_radar_chart(prdc_metrics_u, "Distribution Metrics (Unfiltered - Individual)", radar_path_u)
                            else:
                                logger.info("Skipping PRDC metric calculation (Unfiltered) as per config or dependencies.")
                            # FID/KID Unfiltered
                            if (config.MEASURE_FID or config.MEASURE_KID) and fid_kid_possible:
                                 logger.info(f"Calculating FID/KID using {min_u_samples_metric} real vs {min_u_samples_metric} unfiltered generated features.")
                                 fid_kid_u = calculate_fid_kid(real_features_u_metric, unfiltered_gen_features_u_metric, config, logger)
                                 analysis_results['fid_kid_unfiltered'] = fid_kid_u
                            else:
                                 logger.info("Skipping FID/KID calculation (Unfiltered) as per config or dependencies.")
                            del real_features_u_metric, unfiltered_gen_features_u_metric
        
                        analysis_results["unfiltered_features_computed"] = unfiltered_gen_features_full is not None
                        # Add features to the main results dict to return them
                        results['analysis_results'] = analysis_results # Keep analysis metrics together
                        results['unfiltered_gen_features_full'] = unfiltered_gen_features_full

                        # Cleanup full features for this part if they exist
                        if unfiltered_gen_features_full is not None:
                             del unfiltered_gen_features_full
        
                    elif not feature_analysis_needed: # If only critic scoring was done
                         logger.info("Only critic scoring performed for unfiltered images as feature analysis was skipped or failed.")
                    else: # If feature extraction failed for unfiltered gen images
                        logger.warning("Feature extraction failed for unfiltered generated images. Skipping feature-based metrics.")

                force_memory_cleanup(config) # Cleanup after Analysis 2
                #
                # --- CLEANUP POINT 2: After unfiltered analysis ---
                #  (redacted this code block to avoid issues)
        
                # --- Generate Combined/Comparison Plots AFTER both analysis attempts ---
                # Combined PRDC Radar Chart
                metrics_f_prdc = analysis_results.get('prdc_metrics_filtered')
                metrics_u_prdc = analysis_results.get('prdc_metrics_unfiltered')
                if metrics_f_prdc and metrics_u_prdc:
                    combined_prdc_radar_path = os.path.join(graph_dir, "prdc_radar_comparison.png")
                    logger.info(f"Generating combined PRDC radar chart: {combined_prdc_radar_path}")
                    plot_combined_radar_chart(
                        metrics_f_prdc, metrics_u_prdc,
                        label1='Filtered', label2='Unfiltered',
                        title='PRDC Comparison (Individual Distances)',
                        output_path=combined_prdc_radar_path
                    )
                elif config.MEASURE_PRDC:
                    logger.warning("Skipping combined PRDC radar chart generation as one or both metric sets are missing.")
        
                # Combined FID/KID Bar Chart
                fid_kid_f = analysis_results.get('fid_kid_filtered')
                fid_kid_u = analysis_results.get('fid_kid_unfiltered')
                if fid_kid_f or fid_kid_u: # If at least one result exists
                     combined_fid_kid_path = os.path.join(graph_dir, "fid_kid_comparison.png")
                     logger.info(f"Generating FID/KID comparison chart: {combined_fid_kid_path}")
                     plot_fid_kid_comparison(analysis_results, combined_fid_kid_path, logger)
                elif config.MEASURE_FID or config.MEASURE_KID:
                     logger.warning("Skipping FID/KID comparison plot as no FID/KID results were calculated.")
                #
                # --- Sample Visualization ---
                if all_pollen_data and all_composed_paths: # Use all_pollen_data here
                    logger.info("Creating sample visualization")
                    sample_vis_path = os.path.join(graph_dir, "sample_visualization.png")
                    # Define how many samples the visualization function should display
                    num_samples = 4 # <-- DEFINE the variable here
                    # Extract more paths than needed initially, the function will select the first num_samples
                    # Ensure we don't try to slice more than available
                    num_paths_to_get = min(len(all_pollen_data), num_samples * 2)
                    pollen_paths_for_vis = [item['path'] for item in all_pollen_data[:num_paths_to_get]]
                    save_sample_visualization(
                        pollen_paths_for_vis, all_composed_paths, all_label_paths, # Pass paths list
                        sample_vis_path, num_samples=num_samples # Use the variable for clarity
                    )
                #
                # --- CLEANUP POINT 3: After comparison plots ---
                #  (redacted this code block to avoid issues)

                # Record whether real-image features were computed, then expose the
                # analysis metrics and the real features to the caller via `results`.
                
                analysis_results["real_features_computed"] = real_features_full is not None
                # Add features to the main results dict to return them
                results['analysis_results'] = analysis_results # Keep analysis metrics together
                results['real_features_full'] = real_features_full
                
                # ... (rest of the function, e.g., report generation) ...
        
                # --- Final Cleanup of Analysis Resources ---
                logger.debug("Cleaning up final analysis resources.")
                if inception_model is not None:
                    del inception_model
                    logger.debug("Deleted Inception model.")
                if real_features_full is not None:
                    del real_features_full
                    logger.debug("Deleted full real features.")
                force_memory_cleanup(config) # One last cleanup
        
                analysis_duration = time.time() - analysis_start_time
                logger.info(f"Analysis and Visualization completed in {analysis_duration:.3f} seconds")
                logger.info("="*40 + " FINISHED ANALYSIS " + "="*40)
        
            # <<< End of the main analysis block gated by initial analysis_possible check >>>
            else: # If analysis was not possible from the start
                logger.warning("Analysis was skipped because dependencies were missing, initial setup failed, or not enough data.")
        
        # <<< This else belongs to the initial if config.PERFORM_ANALYSIS check >>>
        else:
            logger.info("Skipping analysis as per configuration or no images were generated.")
        
        # --- Finalize Results ---
        end_time = time.time()
        results['total_time'] = end_time - start_time
        
        # Get peak memory usage
        memory_stats = get_memory_usage()
        results['peak_memory'] = max_memory_stats(results['peak_memory'], memory_stats)
        
        # Create summary report
        logger.info("Creating performance summary")
        create_performance_summary(config, results, analysis_results, logger)
        
        # Generate markdown report
        if config.GENERATE_REPORTS:
            logger.info("Generating markdown report")
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            report_path = os.path.join(output_dirs['reports'], f"generation_report_{timestamp}.md")
            create_markdown_report(config, results, analysis_results, real_stats, report_path)
            logger.info(f"Report saved to {report_path}")
        
        # Save final configuration
        if config.SAVE_CHECKPOINTS:
            logger.info("Saving final configuration")
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            config_path = os.path.join(output_dirs['configs'], f"config_{timestamp}.json")
            config.save(config_path)
            logger.info(f"Configuration saved to {config_path}")
        
        # Stop memory monitoring
        if memory_monitor_thread is not None:
            stop_monitor_event.set()
            memory_monitor_thread.join(timeout=5)
        
        logger.info("=" * 80)
        logger.info("Synthetic dataset generation completed")
        logger.info(f"Total execution time: {results['total_time']:.3f} seconds")
        logger.info(f"Individual pollen images: {results['pollen_generated']}/{config.TARGET_POLLEN_IMAGES}")
        logger.info(f"Composed images: {results['composed_generated']}/{config.TARGET_COMPOSED_IMAGES}")
        logger.info("=" * 80)
        #
        # --- CLEANUP POINT 4: Before returning results ---
        #  (redacted this code block to avoid issues)
        
        return results
    
    except Exception as e:
        logger.error(f"Critical error in synthetic dataset generation: {e}")
        logger.error(traceback.format_exc())
        
        # Stop memory monitoring
        if 'stop_monitor_event' in locals() and stop_monitor_event is not None:
            stop_monitor_event.set()
        
        raise
 
# --- Utility Functions ---
def max_memory_stats(stats1, stats2):
    """Return the per-key maximum of two memory statistics dictionaries.

    Args:
        stats1: Mapping of metric name to a comparable value.
        stats2: Mapping of metric name to a comparable value.

    Returns:
        A new dict containing every key found in either input. Keys present
        in both inputs map to the larger value; keys present in only one
        input keep that input's value. Neither argument is modified.
    """
    # Start from a copy so the caller's running peak is never mutated in place.
    merged = dict(stats1)
    for key, value in stats2.items():
        # Keys that exist only in ``stats2`` used to be dropped, which meant a
        # metric missing from the running peak could never be recorded.
        merged[key] = max(merged[key], value) if key in merged else value
    return merged

# --- Main Execution ---
if __name__ == "__main__":
    # Script entry point: build the configuration and logger, run the full
    # generation pipeline, persist the analysis feature arrays for the KID
    # comparison cell, and log a final summary. Exits with status 1 on failure.

    # Configure exception handling for better visibility in Jupyter
    configure_exception_handler()

    # Initialize configuration
    config = Config()

    # Inspect the checkpoint structure before running (comment out to skip)
    inspect_checkpoint(config.CHECKPOINT_PATH)

    # Quantitatively inspect the generated surplus and the selected top percentile
    print("\n'FILTERING_SURPLUS_FACTOR' = '", config.FILTERING_SURPLUS_FACTOR, "'; ")
    print("\n'QUALITY_THRESHOLD_PERCENTILE' = '", config.QUALITY_THRESHOLD_PERCENTILE, "'; ")

    # Setup logging with more console output
    logger = setup_logging(config)
    # NOTE(review): handlers[1] is presumably the console handler installed by
    # setup_logging - confirm. The index is guarded because this line runs
    # outside the try block below, so an IndexError here would abort the run
    # before anything could be logged.
    if len(logger.handlers) > 1:
        logger.handlers[1].setLevel(logging.DEBUG)  # Console handler gets all messages

    try:
        logger.info("=" * 80)
        logger.info("Starting Generation Engine v0.0")
        logger.info(f"Device: {config.DEVICE}")

        # Print hardware info
        logger.info("Hardware Information:")
        if torch.cuda.is_available():
            logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
            logger.info(f"CUDA Version: {torch.version.cuda}")
            gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            logger.info(f"GPU Memory: {gpu_mem:.3f} GB")
        else:
            logger.info("No GPU detected, using CPU")

        # Log CPU info
        cpu_count = os.cpu_count() or 0
        logger.info(f"CPU Cores: {cpu_count}")

        # Log RAM info
        ram = psutil.virtual_memory()
        logger.info(f"Total RAM: {ram.total / (1024**3):.3f} GB")
        logger.info(f"Available RAM: {ram.available / (1024**3):.3f} GB")

        # Run main function
        results = generate_synthetic_dataset(config, logger)

        # Expose the feature arrays and analysis metrics at module scope so a
        # later notebook cell can reuse them without re-running generation.
        real_features_full = results.get('real_features_full')
        filtered_gen_features_full = results.get('filtered_gen_features_full')
        unfiltered_gen_features_full = results.get('unfiltered_gen_features_full')
        analysis_results = results.get('analysis_results', {})

        # --- Save important variables for KID comparison ---
        feature_save_dir = os.path.join(config.OUTPUT_DIR, config.STATS_SUBDIR, "temp_features_for_kid_comp")
        os.makedirs(feature_save_dir, exist_ok=True)
        logger.info(f"Saving features for KID comparison to: {feature_save_dir}")

        feature_paths = {}  # JSON key -> saved .npy path, read back by the KID comparison cell

        # (JSON key, file name, label used in log messages, feature array)
        feature_sets = (
            ('real', "real_features.npy", "real", real_features_full),
            ('filtered', "filtered_gen_features.npy", "filtered gen", filtered_gen_features_full),
            ('unfiltered', "unfiltered_gen_features.npy", "unfiltered gen", unfiltered_gen_features_full),
        )
        for key, file_name, label, features in feature_sets:
            if features is None:
                logger.warning(f"{label.capitalize()} features are None, cannot save for KID comparison")
                continue
            try:
                save_path = os.path.join(feature_save_dir, file_name)
                np.save(save_path, features)
                # Only record the path once the file has actually been written
                feature_paths[key] = save_path
                logger.info(f"Saved {label} features to {save_path}")
            except Exception as e:
                logger.error(f"Failed to save or delete {label} features: {e}", exc_info=True)
        # Drop the loop's temporary references. The arrays themselves stay
        # reachable through `results` and the module-level names above, so this
        # does not free them; it only avoids leaving extra aliases behind.
        del feature_sets, features

        # Save the paths for later use
        paths_file = os.path.join(feature_save_dir, "feature_paths.json")
        with open(paths_file, 'w') as f:
            json.dump(feature_paths, f, indent=2)
        logger.info(f"Saved feature paths to {paths_file}")

        # Force cleanup after saving
        logger.info("Forcing memory cleanup after saving features")
        force_memory_cleanup(config)

        logger.info("=" * 80)
        logger.info("Generation Engine completed successfully")
        logger.info(f"Generated {results['pollen_generated']} individual pollen images")
        logger.info(f"Generated {results['composed_generated']} composed images with annotations")
        logger.info(f"Total execution time: {results['total_time']:.3f} seconds")
        logger.info("=" * 80)

    except Exception as e:
        # Top-level boundary: log the failure with its traceback, then exit non-zero
        logger.error(f"Unhandled exception in Generation Engine: {e}")
        logger.error(traceback.format_exc())
        logger.info("=" * 80)
        logger.info("Generation Engine failed")
        logger.info("=" * 80)
        sys.exit(1)

# --- Usage Example ---
"""
"""

In [None]:

# ============================================================================
# Define all required variables if running in a separate cell
if 'graph_dir' not in locals() or graph_dir is None:
    if 'output_dirs' in locals() and output_dirs is not None:
        graph_dir = output_dirs['graphs']
    else:
        # Fallback - recreate graph_dir directly
        graph_dir = os.path.join(config.OUTPUT_DIR, config.GRAPH_SUBDIR)
        os.makedirs(graph_dir, exist_ok=True)
        print(f"Created graph_dir: {graph_dir}")
# Now the KID Calculation Method Implementation Comparison code can access graph_dir
# ============================================================================

# ============================================================================
# --- START: KID Method Comparison Block ---
# ============================================================================

logger.info("=" * 40 + " STARTING KID METHOD COMPARISON " + "=" * 40)

# --- Load features from disk for KID comparison ---
logger.info("Loading features for KID comparison from disk...")

# Directory the main block wrote the feature arrays to
feature_save_dir = os.path.join(config.OUTPUT_DIR, config.STATS_SUBDIR, "temp_features_for_kid_comp")

# Load paths from JSON if available
try:
    paths_file = os.path.join(feature_save_dir, "feature_paths.json")
    if os.path.exists(paths_file):
        with open(paths_file, 'r') as f:
            feature_paths = json.load(f)
    else:
        logger.warning(f"Paths file not found: {paths_file}")
        feature_paths = {}
except Exception as e:
    logger.error(f"Error loading feature paths: {e}", exc_info=True)
    feature_paths = {}


def _load_saved_features(saved_path, label):
    """Load one saved feature array; return None when it is missing or unreadable."""
    if not (saved_path and os.path.exists(saved_path)):
        logger.warning(f"Saved {label} features path not found or not saved")
        return None
    try:
        loaded = np.load(saved_path)
        logger.info(f"Loaded {label} features (shape: {loaded.shape}) from {saved_path}")
        return loaded
    except Exception as e:
        logger.error(f"Failed to load {label} features from {saved_path}: {e}", exc_info=True)
        return None


# Load Real Features
real_path = feature_paths.get('real')
real_features_full_comp = _load_saved_features(real_path, "real")

# Load Filtered Features
filtered_path = feature_paths.get('filtered')
filtered_gen_features_full_comp = _load_saved_features(filtered_path, "filtered gen")

# Load Unfiltered Features
unfiltered_path = feature_paths.get('unfiltered')
unfiltered_gen_features_full_comp = _load_saved_features(unfiltered_path, "unfiltered gen")

# --- Configuration & Variable Check ---
# The comparison below reads these names:
# - real_features_full: Numpy array of real image features
# - filtered_gen_features_full: Numpy array of filtered generated image features (if analysis ran)
# - unfiltered_gen_features_full: Numpy array of unfiltered generated image features (if analysis ran)
# - config: The Config object instance
# - logger: The logger instance
# - graph_dir: Path to the main graphs directory (e.g., output_dirs['graphs'])
#
# Arrays still in memory from the main run are preferred. When a name is
# missing or None (for example when this cell runs in a fresh kernel), it is
# bound to the copy loaded from disk above, so the loaded data is actually used.
available_vars = locals()  # Variables available in the current scope

if available_vars.get('real_features_full') is None:
    real_features_full = real_features_full_comp
if available_vars.get('filtered_gen_features_full') is None:
    filtered_gen_features_full = filtered_gen_features_full_comp
if available_vars.get('unfiltered_gen_features_full') is None:
    unfiltered_gen_features_full = unfiltered_gen_features_full_comp

# Single validation pass over the resolved inputs
kid_comparison_possible = True
required_vars = ['real_features_full', 'config', 'logger', 'graph_dir']

for var_name in required_vars:
    if var_name not in available_vars or available_vars[var_name] is None:
        logger.error(f"KID Calculation Method Implementation Comparison Error: Required variable '{var_name}' not found or is None.")
        kid_comparison_possible = False

# At least one set of generated features is needed
if filtered_gen_features_full is None and unfiltered_gen_features_full is None:
    logger.warning("KID Calculation Method Implementation Comparison Warning: No generated features (filtered or unfiltered) available. Skipping comparison.")
    kid_comparison_possible = False

if kid_comparison_possible:
    try:
        # --- Define Helper Functions ---

        # 1. Method from training Continuation Script (custom version)
        def polynomial_kernel_custom(X, Y, degree=3, gamma=None, coef0=1.0):
            """Polynomial kernel on L2-normalised rows, as used for KID in the training script.

            Computes ``clip((gamma * <x, y> + coef0) ** degree, 1e-8, 1e6)`` for
            every pair of rows, after scaling each row to unit Euclidean norm.

            Args:
                X: 2-D array-like with one sample per row.
                Y: 2-D array-like with one sample per row and the same number
                    of columns as ``X``.
                degree: Exponent of the kernel.
                gamma: Scale applied to the dot product. ``None`` selects the
                    fixed value 0.2.
                coef0: Constant added before exponentiation.

            Returns:
                Array of shape ``(len(X), len(Y))`` clipped to ``[1e-8, 1e6]``.
            """
            # The arguments used to be overwritten with hard-coded values, so
            # they had no effect. They are honoured now; the defaults reproduce
            # the former constants (degree=3, gamma=0.2, coef0=1.0), which keeps
            # default calls unchanged. The Generation method instead relies on
            # sklearn's gamma default.
            if gamma is None:
                gamma = 0.2

            # 1. Type conversion
            X = np.asarray(X, dtype=np.float64)
            Y = np.asarray(Y, dtype=np.float64)

            # 2. Feature normalization (not in Generation Method); the epsilon
            #    avoids a division by zero for all-zero rows
            X_norm = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-8)
            Y_norm = Y / (np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8)

            # 3. Compute kernel and clip it into a bounded positive range
            dot_product = np.matmul(X_norm, Y_norm.T)
            return np.clip((gamma * dot_product + coef0) ** degree, 1e-8, 1e6)
        
        def calculate_kid_from_features_custom(real_features, fake_features, config, logger):
            """Estimate KID with the custom polynomial kernel (training-script variant).

            Both feature sets are cast to float64 and mean-centred independently.
            A per-subset statistic is then computed on random, equally sized
            subsets drawn without replacement, and summarised across subsets.

            Args:
                real_features: Array of real-image features, one row per sample.
                fake_features: Array of generated-image features, one row per sample.
                config: Object optionally providing ``KID_SUBSET_SIZE`` (default
                    1000) and ``KID_SUBSETS`` (default 100).
                logger: Logger receiving warnings and errors.

            Returns:
                Tuple ``(mean, std)`` of the per-subset values, each multiplied
                by 10000. ``(nan, nan)`` is returned when a feature set is empty,
                fewer than 2 samples are usable, no finite value was produced,
                or any exception occurs.
            """
            nan_pair = (np.nan, np.nan)
            try:
                # Step 1: work in double precision
                real_centered = real_features.astype(np.float64)
                fake_centered = fake_features.astype(np.float64)

                # Step 2: remove each set's own per-feature mean (not in Generation Method)
                real_centered = real_centered - np.mean(real_centered, axis=0, keepdims=True)
                fake_centered = fake_centered - np.mean(fake_centered, axis=0, keepdims=True)

                # Step 3: subset configuration, falling back to defaults
                requested_size = getattr(config, 'KID_SUBSET_SIZE', 1000)
                n_subsets = getattr(config, 'KID_SUBSETS', 100)

                n_real = real_centered.shape[0]
                n_fake = fake_centered.shape[0]
                if not (n_real != 0 and n_fake != 0):
                    logger.warning("KID Custom: Cannot calculate with empty feature sets.")
                    return nan_pair

                # A subset can never be larger than the smaller feature set
                subset_size = min(requested_size, n_real, n_fake)
                if subset_size < 2:
                    logger.warning(f"KID Custom: Subset size {subset_size} too small. Need at least 2.")
                    return nan_pair

                subset_mmds = []
                for _ in tqdm(range(n_subsets), desc="KID Custom Subsets"):
                    # Draw the real subset first, then the fake one (sampling without replacement)
                    real_pick = real_centered[np.random.choice(n_real, size=subset_size, replace=False)]
                    fake_pick = fake_centered[np.random.choice(n_fake, size=subset_size, replace=False)]

                    # Step 4: kernel matrices for the three pairings
                    gram_rr = polynomial_kernel_custom(real_pick, real_pick)
                    gram_rf = polynomial_kernel_custom(real_pick, fake_pick)
                    gram_ff = polynomial_kernel_custom(fake_pick, fake_pick)

                    # Step 5: within-set sums with the diagonal removed, minus twice the cross sum
                    sum_rr, trace_rr = np.sum(gram_rr), np.trace(gram_rr)
                    sum_ff, trace_ff = np.sum(gram_ff), np.trace(gram_ff)
                    numerator = sum_rr - trace_rr + sum_ff - trace_ff - 2 * np.sum(gram_rf)

                    # All three terms share the n*(n-1) normaliser
                    denominator = subset_size * (subset_size - 1)
                    if denominator > 0:
                        mmd = numerator / denominator
                    else:
                        logger.warning("KID Custom: Invalid denominator in MMD calculation!")
                        mmd = np.nan

                    # Step 6: keep finite values only, floored at 1e-8
                    if not np.isfinite(mmd):
                        continue
                    subset_mmds.append(max(1e-8, mmd))

                if subset_mmds:
                    # 100 squared, to account for the square-root difference between the methods
                    scaling_factor = 10000
                    return np.mean(subset_mmds) * scaling_factor, np.std(subset_mmds) * scaling_factor

                logger.warning("KID Custom: No valid MMD values calculated.")
                return nan_pair

            except Exception as e:
                logger.error(f"Error in calculate_kid_from_features_custom: {e}", exc_info=True)
                return nan_pair

        # 2. Method from Generation Engine Script (using sklearn)
        def calculate_kid_generation_method(real_features, fake_features, logger, chunk_size=1000):
            """Calculate KID with sklearn's polynomial kernel, accumulating kernel sums in chunks.

            Both feature sets are truncated to the same number of samples. The
            real-real, fake-fake and real-fake kernel matrices are never held in
            full: each is summed ``chunk_size`` rows at a time. The three sums
            are divided by ``n * n`` (diagonal entries included) and combined as
            ``mean_rr + mean_ff - 2 * mean_rf``.

            Args:
                real_features: Array of real-image features, one row per sample.
                fake_features: Array of generated-image features, one row per sample.
                logger: Logger receiving progress, warnings and errors.
                chunk_size: Number of rows per kernel chunk; must be at least 1.

            Returns:
                ``sqrt(max(0, mmd2)) * 100`` as a float, or ``np.nan`` when
                sklearn's kernel is unavailable, fewer than 10 samples are
                usable, ``chunk_size`` is invalid, the value is not finite, or
                an exception occurs.
            """
            def _chunked_kernel_sum(rows, cols, desc):
                """Return the sum of polynomial_kernel(rows, cols), processing rows in chunks."""
                total = 0.0
                with tqdm(total=len(rows), desc=desc) as pbar:
                    for chunk_idx, start in enumerate(range(0, len(rows), chunk_size)):
                        row_chunk = rows[start:start + chunk_size]
                        # gamma=None lets sklearn apply its default (1 / n_features)
                        k_chunk = polynomial_kernel(row_chunk, cols, degree=3, gamma=None, coef0=1)
                        total += np.sum(k_chunk)
                        pbar.update(len(row_chunk))
                        # Release the chunk before the next one is allocated to
                        # keep peak memory at one chunk
                        del k_chunk
                        # Collect every 10th chunk. The old test used the sample
                        # offset, which is a multiple of 10 for typical chunk
                        # sizes, so it collected on every iteration.
                        if chunk_idx % 10 == 0:
                            gc.collect()
                return total

            try:
                if not SKLEARN_KERNELS_AVAILABLE: # Use the global flag check
                    logger.warning("scikit-learn polynomial_kernel not available for Generation Method KID.")
                    return np.nan

                if chunk_size < 1:
                    # A negative size would yield empty loops and a silent KID of 0.0
                    logger.warning(f"KID Gen Method: Invalid chunk_size ({chunk_size}). Need >= 1.")
                    return np.nan

                min_samples = min(len(real_features), len(fake_features))
                if min_samples < 10: # Increased minimum for stability
                    logger.warning(f"KID Gen Method: Too few samples ({min_samples}) for reliable calculation. Need >= 10.")
                    return np.nan

                # Use features up to min_samples so both sets have equal size
                real_f = real_features[:min_samples]
                fake_f = fake_features[:min_samples]
                logger.info(f"Calculating KID (Generation Method) on {min_samples} samples.")

                num_chunks = math.ceil(min_samples / chunk_size)
                logger.info(f"Calculating KID kernel sums in {num_chunks} chunks of size approx {chunk_size}...")

                # Each call sums one full kernel matrix: chunk of rows vs ALL columns
                sum_k_rr = _chunked_kernel_sum(real_f, real_f, "KID Chunks (Real-Real)")
                sum_k_ff = _chunked_kernel_sum(fake_f, fake_f, "KID Chunks (Fake-Fake)")
                sum_k_rf = _chunked_kernel_sum(real_f, fake_f, "KID Chunks (Real-Fake)")

                # Calculate final means (divide sum by total number of elements: min_samples * min_samples)
                n_elements = float(min_samples * min_samples)
                mean_k_rr = sum_k_rr / n_elements
                mean_k_ff = sum_k_ff / n_elements
                mean_k_rf = sum_k_rf / n_elements
                logger.debug(f"KID Means: K_rr={mean_k_rr:.4f}, K_ff={mean_k_ff:.4f}, K_rf={mean_k_rf:.4f}")

                # Calculate MMD^2 using means
                mmd2 = mean_k_rr + mean_k_ff - 2 * mean_k_rf

                # Transform value with square root and scaling; clamp at 0 so
                # the square root is always defined
                kid_value = np.sqrt(max(0, mmd2)) * 100
                logger.info("KID (Generation Method) chunked calculation complete.")
                return float(kid_value) if np.isfinite(kid_value) else np.nan

            except Exception as e:
                logger.error(f"Error in calculate_kid_generation_method (chunked): {e}", exc_info=True)
                return np.nan
            #
            
        # 3. Plotting Function
        def plot_kid_method_comparison(results_dict, output_path, logger):
            """Save a grouped bar chart comparing KID scores from the two methods.

            Args:
                results_dict: Mapping of dataset label (e.g. 'Filtered',
                    'Unfiltered') to a dict that may hold the keys
                    'Continuation Method' and 'Generation Method'. Missing keys
                    are treated as NaN.
                output_path: File path the figure is written to.
                logger: Logger receiving progress, warnings and errors.

            Returns:
                True when the figure was saved; False when there was nothing
                valid to plot or saving failed.
            """
            logger.info(f"Generating the KID method comparison plot: {output_path}")

            labels = list(results_dict.keys()) # Should be ['Filtered', 'Unfiltered'] or just one
            if not labels:
                logger.warning("No results of the KID available to plot for comparison.")
                return False

            continuation_scores = [results_dict[label].get('Continuation Method', np.nan) for label in labels]
            generation_scores = [results_dict[label].get('Generation Method', np.nan) for label in labels]

            # Nothing to draw unless at least one score is finite
            has_cont_scores = any(np.isfinite(s) for s in continuation_scores)
            has_gen_scores = any(np.isfinite(s) for s in generation_scores)
            if not has_cont_scores and not has_gen_scores:
                 logger.warning("No valid KID scores found for comparison plot.")
                 return False

            x = np.arange(len(labels)) # Label locations
            width = 0.35 # Width of the bars

            plt.style.use('seaborn-v0_8-darkgrid')
            fig, ax = plt.subplots(figsize=(10, 6))
            try:
                # Non-finite scores are drawn as zero-height bars and labelled 'N/A' below
                rects1 = ax.bar(x - width/2, [s if np.isfinite(s) else 0 for s in continuation_scores], width, label='Continuation Method', color='purple')
                rects2 = ax.bar(x + width/2, [s if np.isfinite(s) else 0 for s in generation_scores], width, label='Generation Method', color='pink')

                # Add labels, title, ticks
                ax.set_ylabel('KID Score (Lower is Better)')
                ax.set_title('KID Score Comparison by Calculation Method')
                ax.set_xticks(x)
                ax.set_xticklabels(labels)
                ax.legend()

                def autolabel(rects, scores):
                    """Annotate each bar with its score, or 'N/A' when the score is not finite."""
                    for rect, score_val in zip(rects, scores):
                        label_text = f'{score_val:.3f}' if np.isfinite(score_val) else 'N/A'
                        ax.annotate(label_text,
                                    xy=(rect.get_x() + rect.get_width() / 2, rect.get_height()),
                                    xytext=(0, 3), # 3 points vertical offset
                                    textcoords="offset points",
                                    ha='center', va='bottom')

                autolabel(rects1, continuation_scores)
                autolabel(rects2, generation_scores)

                fig.tight_layout()
                try:
                    # Save this figure explicitly rather than whichever figure is current
                    fig.savefig(output_path, dpi=300)
                except Exception as e:
                    logger.error(f"Error saving the KID Calculation Method Implementation Comparison plot: {e}", exc_info=True)
                    return False
                logger.info(f"Saved the KID Calculation Method Implementation Comparison plot to {output_path}")
                return True
            finally:
                # Always release the figure, including when drawing or layout
                # raises before the save step is reached
                plt.close(fig)


        # --- Main Calculation Logic ---
        # Computes KID for the filtered and unfiltered generated feature sets with
        # two different implementations ("Continuation" and "Generation") on the
        # same sample subsets, collects the scores per dataset label, and hands
        # them to plot_kid_method_comparison. Any exception raised here is caught
        # and logged by the enclosing try/except.
        kid_comparison_results = {}  # {dataset label: {method name: KID score}}
        analysis_sample_size = config.ANALYSIS_SAMPLE_SIZE  # Upper bound on samples used per comparison

        # Ensure comparison subdirectory exists
        kid_comparison_dir = os.path.join(graph_dir, "kid_methods_comparison")
        os.makedirs(kid_comparison_dir, exist_ok=True)
        logger.info(f"Saving the KID Calculation Method Implementation Comparison results to: {kid_comparison_dir}")

        # 1. Calculate for Filtered data (if available)
        # NOTE(review): only the generated features are None-checked here;
        # real_features_full is assumed to be loaded — confirm against the code above.
        if filtered_gen_features_full is not None:
            logger.info("--- Comparing the KID methods for Filtered Data ---")
            # Use the same number of samples on both sides, capped by the configured size
            min_samples_f = min(len(real_features_full), len(filtered_gen_features_full), analysis_sample_size)
            if min_samples_f < 10: # Need at least a few samples
                # Skipped: no 'Filtered' entry is added to the results in this case
                logger.warning(f"Too few filtered samples ({min_samples_f}) for reliable KID Calculation Method Implementation Comparison.")
            else:
                # Take the leading min_samples_f entries (no shuffling); both methods
                # receive exactly these subsets
                real_f_comp = real_features_full[:min_samples_f]
                filt_f_comp = filtered_gen_features_full[:min_samples_f]

                # Calculate using Continuation method
                # The result is unpacked into two values and logged as "score ± spread";
                # only the first value is kept for the plot.
                logger.debug("Calculating the KID (Filtered) using Continuation Method...")
                kid_cont_f, kid_std_cont_f = calculate_kid_from_features_custom(real_f_comp, filt_f_comp, config, logger)
                logger.info(f"Filtered KID (Continuation Method): {kid_cont_f:.6f} ± {kid_std_cont_f:.6f}")

                # Calculate using Generation method
                # A single value is returned; it must support float formatting for the log line.
                logger.debug("Calculating the KID (Filtered) using Generation Method...")
                kid_gen_f = calculate_kid_generation_method(real_f_comp, filt_f_comp, logger)
                logger.info(f"Filtered KID (Generation Method): {kid_gen_f:.3f}")

                # Inner keys are the ones plot_kid_method_comparison looks up
                kid_comparison_results['Filtered'] = {
                    'Continuation Method': kid_cont_f,
                    'Generation Method': kid_gen_f
                }
                del real_f_comp, filt_f_comp # Drop the subset names once both scores are computed

        # 2. Calculate for Unfiltered data (if available)
        # Mirrors step 1, using the unfiltered generated features.
        if unfiltered_gen_features_full is not None:
            logger.info("--- Comparing the KID methods for Unfiltered Data ---")
            # Use the same number of samples on both sides, capped by the configured size
            min_samples_u = min(len(real_features_full), len(unfiltered_gen_features_full), analysis_sample_size)
            if min_samples_u < 10: # Need at least a few samples
                 # Skipped: no 'Unfiltered' entry is added to the results in this case
                 logger.warning(f"Too few unfiltered samples ({min_samples_u}) for reliable KID comparison.")
            else:
                # Take the leading min_samples_u entries (no shuffling); both methods
                # receive exactly these subsets
                real_u_comp = real_features_full[:min_samples_u]
                unfilt_u_comp = unfiltered_gen_features_full[:min_samples_u]

                # Calculate using Continuation method
                # The result is unpacked into two values and logged as "score ± spread";
                # only the first value is kept for the plot.
                logger.debug("Calculating the KID (Unfiltered) using Continuation Method...")
                kid_cont_u, kid_std_cont_u = calculate_kid_from_features_custom(real_u_comp, unfilt_u_comp, config, logger)
                logger.info(f"Unfiltered KID (Continuation Method): {kid_cont_u:.6f} ± {kid_std_cont_u:.6f}")

                # Calculate using Generation method
                # A single value is returned; it must support float formatting for the log line.
                logger.debug("Calculating the KID (Unfiltered) using Generation Method...")
                kid_gen_u = calculate_kid_generation_method(real_u_comp, unfilt_u_comp, logger)
                logger.info(f"Unfiltered KID (Generation Method): {kid_gen_u:.3f}")

                # Inner keys are the ones plot_kid_method_comparison looks up
                kid_comparison_results['Unfiltered'] = {
                    'Continuation Method': kid_cont_u,
                    'Generation Method': kid_gen_u
                }
                del real_u_comp, unfilt_u_comp # Drop the subset names once both scores are computed

        # 3. Generate Comparison Plot
        # Only plotted when at least one dataset produced scores. The file name is
        # timestamped so repeated runs do not overwrite earlier plots. The boolean
        # returned by plot_kid_method_comparison is not inspected here.
        if kid_comparison_results:
            plot_path = os.path.join(kid_comparison_dir, f"kid_methods_comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
            plot_kid_method_comparison(kid_comparison_results, plot_path, logger)
        else:
             logger.warning("No results for the KID Calculation Method Implementation Comparison were generated.")

    except Exception as e:
        logger.error(f"An error occurred during the KID Calculation Method Implementation Comparison: {e}", exc_info=True)

# --- Post-comparison cleanup ---
# These statements sit at module level, directly after the KID comparison
# block, and release the feature arrays and temporary subsets it created.

# Clear memory after comparison block
force_memory_cleanup(config) # Assuming force_memory_cleanup is defined globally

# Clean up loaded features
logger.info("Cleaning up after KID comparison...")

# Remove every module-level name that ends with '_comp' or contains '_comp_'.
# This covers the loaded feature sets (real_features_full_comp,
# filtered_gen_features_full_comp, unfiltered_gen_features_full_comp) as well
# as any subset variables left behind by the comparison, e.g. when an
# exception skipped their explicit `del`.
#
# globals() is used explicitly: mutating the mapping returned by locals() only
# affects real variables at module scope (where it is the same dict), and would
# silently do nothing if this code were ever moved into a function.
# The name list is snapshotted first so the namespace can be modified safely
# while looping; pop(..., None) tolerates names that are already gone.
for var_name in list(globals()):
    if var_name.endswith('_comp') or '_comp_' in var_name:
        globals().pop(var_name, None)

# Final cleanup (after the references above have been dropped)
force_memory_cleanup(config)
logger.info("KID comparison cleanup complete")

logger.info("=" * 40 + " FINISHED THE KID CALCULATION IMPLEMENTATION METHOD COMPARISON " + "=" * 40)

# ============================================================================
# --- END: KID Method Comparison Block ---
# ============================================================================
