# In [None]:
# "script_v2-151_2A1_primary-modified[dot]ipynb"

# Script "v2.151" (2A1): WGAN-SN Model Definition & Training (Enhanced Edition)
# With added metrics (FID/KID) and visualization capabilities

import importlib
import subprocess
import sys
import os
import gc
import random
import time
import json
import logging 
import traceback 
from datetime import datetime
from math import ceil
import numpy as np

# --- Auto-installation Block ---
def install_and_import(package_name, import_name=None, pip_name=None):
    """Import a module, installing it with pip first if the import fails.

    On success the module object is bound in this module's globals under
    ``import_name`` so later code can refer to it by that name.

    Args:
        package_name: Dotted module path passed to ``importlib.import_module``.
        import_name: Global name to bind the module to. Defaults to
            ``package_name``.
        pip_name: Distribution name passed to ``pip install``. Defaults to
            ``package_name``; pass it explicitly when the pip name differs
            from the importable name.

    Returns:
        True if the module was imported (directly or after installation),
        False if installation or the retried import failed.
    """
    if import_name is None:
        import_name = package_name
    if pip_name is None:
        pip_name = package_name
    try:
        module = importlib.import_module(package_name)
        globals()[import_name] = module
        print(f"Successfully imported {package_name} as {import_name}")
        return True
    except ImportError:
        print(f"{package_name} not found. Attempting installation using pip...")
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])
            # The path finders cache directory listings; without invalidating
            # them a package installed while the interpreter is running may
            # still be reported as missing on the retry below.
            importlib.invalidate_caches()
            module = importlib.import_module(package_name)
            globals()[import_name] = module
            print(f"Successfully installed and imported {package_name} as {import_name}")
            return True
        except (subprocess.CalledProcessError, ImportError, ModuleNotFoundError) as e:
            print(f"ERROR: Failed to install/import {package_name} (pip name: {pip_name}). {e}")
            print("Please install required packages manually and restart kernel.")
            return False

print("--- Checking and Installing Dependencies ---")
numpy_success = install_and_import('numpy', 'np')
torch_success = install_and_import('torch')
torchvision_success = install_and_import('torchvision')
# Where the pip distribution name differs from the importable module path,
# pip_name must be given explicitly or the install fallback cannot succeed.
install_and_import('PIL', pip_name='Pillow')
install_and_import('tqdm')
# Imported as the pyplot submodule (bound to `plt`) to resolve a function referencing error
install_and_import('matplotlib.pyplot', 'plt', pip_name='matplotlib')
install_and_import('scipy')
install_and_import('pytorch_fid', pip_name='pytorch-fid')
install_and_import('pynvml', pip_name='nvidia-ml-py3')
# New dependency for t-SNE visualization
install_and_import('sklearn.manifold', 'manifold', pip_name='scikit-learn')

# Check critical dependencies
critical_imports_successful = all([numpy_success, torch_success, torchvision_success])
if not critical_imports_successful:
    print("ERROR: Critical packages (numpy, torch, torchvision) failed to import.")
    print("Please install these packages manually and restart the script.")
    sys.exit(1)

# --- Core Imports ---
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader 
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.nn.utils import spectral_norm 
from torch.amp import autocast, GradScaler
import torch.nn.functional as F

# Import FID calculation utilities
# FID_AVAILABLE / KID_AVAILABLE are module-level flags recording whether the
# metric helpers could be imported from pytorch_fid.
try:
    from pytorch_fid.inception import InceptionV3
    from pytorch_fid.fid_score import calculate_frechet_distance
    # Import KID calculation functions if available
    try:
        from pytorch_fid.kid_score import polynomial_kernel, calculate_kid_given_features
        KID_AVAILABLE = True
    except ImportError:
        print("WARNING: KID calculation not available in pytorch_fid. Implementing custom KID calculation.")
        # Intentionally still True: KID is expected to be computed by a custom
        # implementation instead of the library one.
        # NOTE(review): that custom implementation is not defined in this section - confirm it exists.
        KID_AVAILABLE = True  # We'll implement it ourselves if not in library
        
    FID_AVAILABLE = True
except ImportError:
    # Without InceptionV3 / calculate_frechet_distance both metrics are turned off.
    print("WARNING: Could not import FID utilities from 'pytorch_fid'. FID/KID calculation will be disabled.")
    FID_AVAILABLE = False
    KID_AVAILABLE = False

# --- Import models and dataset class from separate files ---
# These are project-local modules; a failure here is fatal, so the ImportError
# is re-raised after printing a hint about the expected file names.
try:
    from pollen_datasets_v2A1 import PollenDataset 
    from wgan_models_v2A1 import Generator, CriticSN, initialize_weights 
    print("Successfully imported models and dataset classes from local .py files.")
except ImportError as e:
    print(f"ERROR: Could not import classes from .py files: {e}")
    print("Ensure pollen_datasets_v2A1.py and wgan_models_v2A1.py exist in the same directory.")
    raise 

# --- Check NVML Availability ---
# One-off probe: NVML_AVAILABLE ends up True only if pynvml was bound as a
# global AND both nvmlInit() and nvmlDeviceGetCount() succeed.
NVML_AVAILABLE = False
if 'pynvml' not in globals():
    print("Warning: pynvml library not imported/installed. GPU monitoring disabled.")
else:
    try:
        pynvml.nvmlInit()
        device_count = pynvml.nvmlDeviceGetCount()
        print(f"Successfully initialized NVML. Found {device_count} NVIDIA GPU(s).")
        NVML_AVAILABLE = True
        # The probe does not keep NVML open; a failed shutdown is only reported.
        try:
            pynvml.nvmlShutdown()
        except pynvml.NVMLError:
            print("Warning: NVML shutdown failed after check (might be harmless).")
    except pynvml.NVMLError as nvml_err:
        print("Warning: pynvml library imported but failed to initialize. GPU monitoring disabled.")
        print(f"NVML Error: {nvml_err}")
        NVML_AVAILABLE = False

# ==============================================================================
# --- Configuration Section --- (EDIT THESE VALUES) ---
# ==============================================================================

# --- Paths ---
# Machine-specific absolute Windows paths (raw strings, so backslashes are literal).
# Edit both for your environment before running.
PREPROCESSED_DATA_DIR = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\pre-processing_px-128_step_automated-labels_pc-150" 
OUTPUT_DIR = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\WGAN-SN_training-output_v2-151" 

# --- Model Hyperparameters ---
# NOISE_DIM, CHANNELS_IMG and G_FEATURES are also stored in the saved best-model config.
IMAGE_SIZE = 128       
CHANNELS_IMG = 1       
NOISE_DIM = 100        
G_FEATURES = 64        
C_FEATURES = 64        

# --- Training Hyperparameters --- 
LEARNING_RATE = 0.00005  # Restored to 0.00005 after a trial at 0.00002
BETA1 = 0.0            
BETA2 = 0.9            
BATCH_SIZE = 64        
NUM_EPOCHS = 250       # Raised from 100 to 250 epochs
CRITIC_ITERATIONS = 5  # Raised from 3 to 5

# --- Checkpointing & Resuming ---
CHECKPOINT_FREQ_EPOCHS = 5 
RESUME_TRAINING = False   
CHECKPOINT_FILE = "latest_checkpoint_sn_v2151.pth.tar" 
# Earlier per-epoch file name templates, superseded by the fixed names below:
#BEST_FID_CHECKPOINT_TPL = "best_fid_checkpoint_e{epoch:04d}_fid{fid:.2f}_v2151.pth.tar"
#BEST_KID_CHECKPOINT_TPL = "best_kid_checkpoint_e{epoch:04d}_kid{kid:.2f}_v2151.pth.tar"
BEST_FID_CHECKPOINT_FILE = "best_fid_checkpoint_v2151.pth.tar"  # Fixed name for overwriting
BEST_KID_CHECKPOINT_FILE = "best_kid_checkpoint_v2151.pth.tar"  # Fixed name for overwriting
BEST_MODEL_FILE = "best_model_v2151.pt"

# --- Logging & Monitoring ---
LOG_FILE = "training_log_sn_v2151.log" 
SAMPLE_FREQ_STEPS = 500    
MONITOR_TEMP = True        
GPU_TEMP_THRESHOLD = 89    # Updated to 89C
GPU_ID = 0                 
TRACK_MEMORY_USAGE = True  # Gates log_gpu_memory_usage()

# --- FID/KID Calculation ---
CALCULATE_FID = True
CALCULATE_KID = True                   # New: Calculate KID alongside FID
PRIMARY_EVAL_METRIC = "FID"            # New: Choose "FID" or "KID" for early stopping
KID_SUBSET_SIZE = 1000                 # New: Subset size for KID calculation
KID_SUBSETS = 100                      # New: Number of subsets for KID calculation
FID_FREQ_EPOCHS = 1                    # Check every epoch
FID_NUM_IMAGES = 10000                 # Use 10k images
FID_BATCH_SIZE = 64
REAL_STATS_PATH = os.path.join(OUTPUT_DIR, "real_fid_stats_10k.npz")
REAL_FEATURES_PATH = os.path.join(OUTPUT_DIR, "real_inception_features_10k.npy")  # New: Path to save real features
BEST_FEATURES_PATH = os.path.join(OUTPUT_DIR, "best_fake_features_10k.npy")  # New: Path to save best fake features
FORCE_RECALCULATE_REAL_STATS = False
USE_EARLY_STOPPING = True
EARLY_STOPPING_PATIENCE = 200           # Raised over time: 10 -> 25 -> 200

# --- T-SNE Visualization ---                   # New section
VISUALIZE_TSNE = True                          # Enable t-SNE visualization
TSNE_SAMPLE_SIZE = 2000                        # Max number of samples to use for t-SNE (to keep computation manageable)
TSNE_PERPLEXITY = 30                           # t-SNE perplexity parameter
TSNE_RANDOM_STATE = 42                         # Random seed for reproducibility

# --- Reproducibility ---
# Set to None to skip seeding (the setup code then logs "Using random seed.").
MANUAL_SEED = 42

# ==============================================================================
# --- Setup ---
# ==============================================================================

# --- Create output directories ---
# All run artifacts live in fixed subfolders of OUTPUT_DIR.
CHKPT_DIR = os.path.join(OUTPUT_DIR, "checkpoints")
SAMPLE_DIR = os.path.join(OUTPUT_DIR, "samples")
LOG_DIR = os.path.join(OUTPUT_DIR, "logs")
PLOT_DIR = os.path.join(OUTPUT_DIR, "plots")
ANALYSIS_DIR = os.path.join(OUTPUT_DIR, "analysis_results")

# Create all required directories
# exist_ok=True makes this safe to re-run; missing parents (OUTPUT_DIR itself) are created too.
for directory in [CHKPT_DIR, SAMPLE_DIR, LOG_DIR, PLOT_DIR, ANALYSIS_DIR]:
    os.makedirs(directory, exist_ok=True)

# --- Setup Logging ---
# Messages go both to an append-mode log file in LOG_DIR and to the console.
log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s - %(message)s')
logger = logging.getLogger("WGAN_SN_Trainer_v2151")
logger.setLevel(logging.INFO)
# getLogger returns the same object every time this code is re-run in one
# interpreter session. Detach AND close any handlers it already carries:
# merely clearing the list would leave the previous FileHandler's file open.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)
    _stale_handler.close()
# File Handler
file_handler = logging.FileHandler(os.path.join(LOG_DIR, LOG_FILE), mode='a')
file_handler.setFormatter(log_formatter)
logger.addHandler(file_handler)
# Console Handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)
logger.addHandler(console_handler)
logger.info("="*60)
logger.info(f"Starting WGAN-SN Training (v2.151) at {datetime.now()}")
logger.info("="*60)
logger.info("Logger configured.")

# --- Set Seed ---
# With a manual seed, Python's `random`, NumPy and PyTorch (CPU, plus every
# CUDA device when CUDA is available) are all seeded with the same value.
if MANUAL_SEED is None:
    logger.info("Using random seed.")
else:
    logger.info(f"Using manual seed: {MANUAL_SEED}")
    random.seed(MANUAL_SEED)
    np.random.seed(MANUAL_SEED)
    torch.manual_seed(MANUAL_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(MANUAL_SEED)

# --- Memory Utility ---
def log_gpu_memory_usage(step=''):
    """Log the CUDA memory currently allocated and reserved, in GB.

    Does nothing when CUDA is unavailable or TRACK_MEMORY_USAGE is falsy.

    Args:
        step: Free-form tag inserted into the log line to identify the call site.
    """
    if torch.cuda.is_available() and TRACK_MEMORY_USAGE:
        bytes_per_gb = 1024 ** 3
        allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
        reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb
        logger.info(f"GPU Memory [{step}]: Allocated: {allocated_gb:.2f} GB | Reserved: {reserved_gb:.2f} GB")

# --- NVML Utility (Enhanced with device validation) ---
def get_gpu_temp(gpu_id_to_check):
    """Return the NVML temperature reading for one GPU, or None.

    NVML is initialised and shut down again on every call. An index at or
    beyond the device count is replaced by 0 after a warning. Any failure is
    logged as a warning and yields None.

    Args:
        gpu_id_to_check: NVML device index to query.

    Returns:
        The value from ``nvmlDeviceGetTemperature`` or None when NVML is
        unavailable or the query failed.
    """
    if not NVML_AVAILABLE:
        return None

    target_id = gpu_id_to_check
    reading = None
    session_open = False
    try:
        pynvml.nvmlInit()
        session_open = True

        # Guard against an index that NVML does not know about.
        total_gpus = pynvml.nvmlDeviceGetCount()
        if target_id >= total_gpus:
            logger.warning(f"GPU ID {target_id} out of range (max: {total_gpus-1}). Using GPU 0.")
            target_id = 0

        device = pynvml.nvmlDeviceGetHandleByIndex(target_id)
        reading = pynvml.nvmlDeviceGetTemperature(device, pynvml.NVML_TEMPERATURE_GPU)
    except pynvml.NVMLError as nvml_err:
        logger.warning(f"Could not get GPU temp (ID {target_id}): {nvml_err}", exc_info=False)
    except Exception as e:
        logger.warning(f"Non-NVML error getting GPU temp: {e}", exc_info=False)
    finally:
        # Only shut down a session this call actually opened; a failing
        # shutdown is deliberately ignored (best effort).
        if session_open:
            try:
                pynvml.nvmlShutdown()
            except pynvml.NVMLError:
                pass
    return reading

# --- Directory Validation Function ---
def validate_directory(dir_path):
    """Return True if ``dir_path`` is an existing directory; otherwise log an error and return False."""
    is_dir = os.path.isdir(dir_path)
    if not is_dir:
        logger.error(f"Directory not found: {dir_path}")
    return is_dir

# --- Checkpoint Utilities ---
def save_checkpoint(state, filename):
    """Write ``state`` with ``torch.save`` to ``filename`` inside CHKPT_DIR.

    Failures are logged (with traceback) rather than raised, so a failed save
    does not interrupt the caller.
    """
    destination = os.path.join(CHKPT_DIR, filename)
    logger.info(f"=> Saving checkpoint to {destination}")
    try:
        torch.save(state, destination)
        logger.info("   Checkpoint saved successfully.")
    except Exception as save_error:
        logger.error(f"!! Failed to save checkpoint: {save_error}", exc_info=True)

def load_checkpoint(filename, generator, critic, opt_gen, opt_critic, scaler_gen, scaler_critic):
    """Restore models, optimizers, scalers and training history from a checkpoint.

    The file is looked up inside CHKPT_DIR and loaded onto the CPU. The four
    model/optimizer state dicts are mandatory; GradScaler states and all
    history/bookkeeping entries are optional and fall back to defaults.

    Args:
        filename: Checkpoint file name relative to CHKPT_DIR.
        generator, critic: Objects whose ``load_state_dict`` receives the saved model states.
        opt_gen, opt_critic: Objects whose ``load_state_dict`` receives the saved optimizer states.
        scaler_gen, scaler_critic: GradScaler-like objects, or None to skip restoring them.

    Returns:
        A 12-tuple ``(start_epoch, global_step, g_losses, c_losses, fid_scores,
        fid_epochs, kid_scores, kid_stds, kid_epochs, best_fid, best_kid,
        epochs_no_improve)``. When the file is missing or cannot be loaded the
        "from scratch" values ``(0, 0, [], ..., inf, inf, 0)`` are returned;
        errors are logged, never raised.
    """
    def _fresh_state():
        # Single source for the "start from scratch" result; builds new list
        # objects on every call so callers never share history lists.
        return 0, 0, [], [], [], [], [], [], [], float('inf'), float('inf'), 0

    load_path = os.path.join(CHKPT_DIR, filename)
    logger.info(f"=> Attempting to load checkpoint from {load_path}")
    if not os.path.exists(load_path):
        logger.warning(f"=> No checkpoint found at '{load_path}'. Starting from scratch.")
        return _fresh_state()
    try:
        checkpoint = torch.load(load_path, map_location=torch.device('cpu'))

        # Validate the mandatory entries BEFORE mutating anything. Otherwise a
        # checkpoint lacking e.g. the critic state would leave the generator
        # already overwritten while this function reports a fresh start.
        # (A load_state_dict failure part-way through can still leave earlier
        # objects updated; that case is only logged.)
        for required_key in ('generator_state_dict', 'critic_state_dict',
                             'optimizer_gen_state_dict', 'optimizer_critic_state_dict'):
            if required_key not in checkpoint:
                raise KeyError(required_key)

        start_epoch = checkpoint.get('epoch', 0)
        global_step = checkpoint.get('step', 0)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        critic.load_state_dict(checkpoint['critic_state_dict'])
        opt_gen.load_state_dict(checkpoint['optimizer_gen_state_dict'])
        opt_critic.load_state_dict(checkpoint['optimizer_critic_state_dict'])
        if 'scaler_gen_state_dict' in checkpoint and scaler_gen is not None:
            scaler_gen.load_state_dict(checkpoint['scaler_gen_state_dict'])
            logger.info("   Loaded GradScaler state for Generator.")
        else:
            logger.warning("   Generator GradScaler state not found in checkpoint.")
        if 'scaler_critic_state_dict' in checkpoint and scaler_critic is not None:
            scaler_critic.load_state_dict(checkpoint['scaler_critic_state_dict'])
            logger.info("   Loaded GradScaler state for Critic.")
        else:
            logger.warning("   Critic GradScaler state not found in checkpoint.")

        # Optional history / bookkeeping; absent keys mean "no history yet".
        g_losses_hist = checkpoint.get('g_losses_history', [])
        c_losses_hist = checkpoint.get('c_losses_history', [])
        fid_scores_hist = checkpoint.get('fid_scores_history', [])
        fid_epochs_hist = checkpoint.get('fid_epochs_history', [])
        kid_scores_hist = checkpoint.get('kid_scores_history', [])
        kid_std_hist = checkpoint.get('kid_std_history', [])
        kid_epochs_hist = checkpoint.get('kid_epochs_history', [])
        best_fid_val = checkpoint.get('best_fid', float('inf'))
        best_kid_val = checkpoint.get('best_kid', float('inf'))
        epochs_no_improve_val = checkpoint.get('epochs_no_improve', 0)

        logger.info(f"=> Loaded checkpoint successfully. Resuming from Epoch {start_epoch}, Step {global_step}")
        logger.info(f"   History: {len(g_losses_hist)} loss points, {len(fid_scores_hist)} FID points, {len(kid_scores_hist)} KID points.")
        logger.info(f"   Best FID: {best_fid_val:.4f}, Best KID: {best_kid_val:.4f}. Patience: {epochs_no_improve_val}")
        return (start_epoch, global_step, g_losses_hist, c_losses_hist,
                fid_scores_hist, fid_epochs_hist, kid_scores_hist, kid_std_hist,
                kid_epochs_hist, best_fid_val, best_kid_val, epochs_no_improve_val)
    except KeyError as e:
        logger.error(f"=> Error loading checkpoint: Missing key {e}. Checkpoint might be incompatible or corrupt. Starting from scratch.")
        return _fresh_state()
    except Exception as e:
        logger.error(f"=> Error loading checkpoint: {e}. Starting from scratch.", exc_info=True)
        return _fresh_state()

# --- Save Best Model in Loadable Format ---
def save_best_model(generator, filename=BEST_MODEL_FILE):
    """Save the generator's weights plus its construction settings to OUTPUT_DIR.

    The saved dict has two entries: 'model_state_dict' (the generator's state
    dict) and 'model_config' (noise dimension, image channels, generator
    feature width). Failures are logged, not raised.
    """
    target_path = os.path.join(OUTPUT_DIR, filename)
    logger.info(f"Saving best model to {target_path}")
    try:
        # Bundle weights with the settings needed to rebuild the generator for inference.
        payload = {
            'model_state_dict': generator.state_dict(),
            'model_config': dict(
                noise_dim=NOISE_DIM,
                channels_img=CHANNELS_IMG,
                features_g=G_FEATURES,
            ),
        }
        torch.save(payload, target_path)
        logger.info("Best model saved successfully.")
    except Exception as save_error:
        logger.error(f"Failed to save best model: {save_error}", exc_info=True)

# --- Save features from best model checkpoint ---
def save_fake_features(fake_features, path=BEST_FEATURES_PATH):
    """Persist the given feature array with ``np.save`` for later visualization.

    A None input only produces a warning; save errors are logged, not raised.
    """
    if fake_features is not None:
        try:
            np.save(path, fake_features)
            logger.info(f"Saved best fake features to: {path}")
        except Exception as save_error:
            logger.error(f"Failed to save fake features: {save_error}", exc_info=True)
    else:
        logger.warning("No fake features to save.")

# --- Plotting Utilities ---
# Original plot_metrics function is kept for backward compatibility but not used
def plot_metrics(g_losses, c_losses, fid_scores, fid_epochs, save_dir):
    """Legacy two-panel figure: losses on the left, FID on the right.

    Writes "training_metrics_plot_v2151.png" into ``save_dir``. Any failure is
    logged and swallowed.
    """
    try:
        plt.figure(figsize=(12, 5))

        # Left panel: loss curves, x axis sized from the generator history.
        plt.subplot(1, 2, 1)
        loss_x = range(1, len(g_losses) + 1)
        for series, series_label in ((g_losses, "Generator Loss"), (c_losses, "Critic Loss")):
            if series:
                plt.plot(loss_x, series, label=series_label, alpha=0.8)
        plt.title("Losses per Epoch")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        if g_losses or c_losses:
            plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)

        # Right panel: FID curve with its minimum highlighted, or a placeholder text.
        plt.subplot(1, 2, 2)
        if not (fid_scores and fid_epochs):
            plt.text(0.5, 0.5, 'FID not calculated or no data.',
                     ha='center', va='center', transform=plt.gca().transAxes)
        else:
            plt.plot(fid_epochs, fid_scores, marker='o', linestyle='-', label="FID Score")
            lowest_fid = min(fid_scores)
            lowest_epoch = fid_epochs[fid_scores.index(lowest_fid)]
            plt.scatter([lowest_epoch], [lowest_fid], color='red', s=100, zorder=5,
                        label=f'Best FID: {lowest_fid:.2f} (Epoch {lowest_epoch})')
            plt.legend()

        plt.title("FID Score per Check")
        plt.xlabel("Epoch")
        plt.ylabel("FID Score (Lower is Better)")
        plt.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        out_file = os.path.join(save_dir, "training_metrics_plot_v2151.png")
        plt.savefig(out_file)
        logger.info(f"Saved metrics plot to {out_file}")
        plt.close()
    except Exception as plot_error:
        logger.error(f"Failed to generate or save plots: {plot_error}", exc_info=True)

# New separate plotting functions
def plot_losses(g_losses, c_losses, save_dir):
    """Draw generator and critic loss curves on one axis and save "loss_plot_v2151.png".

    Both histories must be non-empty, otherwise only a warning is logged.
    Failures are logged and swallowed.
    """
    try:
        if not (g_losses and c_losses):
            logger.warning("No loss data to plot")
            return

        fig, ax = plt.subplots(figsize=(10, 6))
        epoch_axis = range(1, len(g_losses) + 1)

        # (history, line format, legend label) - generator in red, critic in blue
        curves = (
            (g_losses, 'r-', "Generator Loss"),
            (c_losses, 'b-', "Critic Loss"),
        )
        for history, line_format, legend_label in curves:
            ax.plot(epoch_axis, history, line_format, label=legend_label, alpha=0.8)

        ax.set_title("Generator and Critic Loss vs. Epoch")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        out_file = os.path.join(save_dir, "loss_plot_v2151.png")
        plt.savefig(out_file)
        logger.info(f"Saved loss plot to {out_file}")
        plt.close(fig)
    except Exception as plot_error:
        logger.error(f"Failed to generate or save loss plot: {plot_error}", exc_info=True)

def plot_fid(fid_scores, fid_epochs, save_dir):
    """Plot FID against epoch, marking the lowest and highest score, and save "fid_plot_v2151.png".

    Logs a warning and returns when either input is empty. Failures are
    logged and swallowed.
    """
    try:
        if not (fid_scores and fid_epochs):
            logger.warning("No FID data to plot")
            return

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(fid_epochs, fid_scores, marker='o', linestyle='-', label="FID Score")

        # (selector, marker colour, caption prefix, caption offset in points):
        # the minimum is the best score, the maximum the worst.
        extremes = (
            (min, 'red', 'Best', (10, -30)),
            (max, 'orange', 'Worst', (-30, 10)),
        )
        for select, colour, caption, text_offset in extremes:
            score = select(fid_scores)
            epoch = fid_epochs[fid_scores.index(score)]
            ax.scatter([epoch], [score], color=colour, s=100, zorder=5)
            ax.annotate(f'{caption}: {score:.2f}\nEpoch: {epoch}',
                        xy=(epoch, score), xytext=text_offset,
                        textcoords='offset points', arrowprops=dict(arrowstyle="->"))

        ax.set_title("FID Score vs. Epoch")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("FID Score (Lower is Better)")
        ax.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        out_file = os.path.join(save_dir, "fid_plot_v2151.png")
        plt.savefig(out_file)
        logger.info(f"Saved FID plot to {out_file}")
        plt.close(fig)
    except Exception as plot_error:
        logger.error(f"Failed to generate or save FID plot: {plot_error}", exc_info=True)

def plot_kid(kid_scores, kid_stds, kid_epochs, save_dir):
    """Plot KID against epoch, marking the lowest and highest score, and save "kid_plot_v2151.png".

    Error bars are drawn only when ``kid_stds`` is non-empty and has one entry
    per score. Logs a warning and returns when scores or epochs are empty.
    Failures are logged and swallowed.
    """
    try:
        if not (kid_scores and kid_epochs):
            logger.warning("No KID data to plot")
            return

        fig, ax = plt.subplots(figsize=(10, 6))

        # Fall back to a plain line when the std history is unusable.
        stds_usable = kid_stds and len(kid_stds) == len(kid_scores)
        if stds_usable:
            ax.errorbar(kid_epochs, kid_scores, yerr=kid_stds, fmt='o-',
                        label="KID Score", capsize=4)
        else:
            ax.plot(kid_epochs, kid_scores, marker='o', linestyle='-', label="KID Score")

        # (selector, marker colour, caption prefix, caption offset in points):
        # the minimum is the best score, the maximum the worst.
        extremes = (
            (min, 'red', 'Best', (10, -30)),
            (max, 'orange', 'Worst', (-30, 10)),
        )
        for select, colour, caption, text_offset in extremes:
            score = select(kid_scores)
            epoch = kid_epochs[kid_scores.index(score)]
            ax.scatter([epoch], [score], color=colour, s=100, zorder=5)
            ax.annotate(f'{caption}: {score:.4f}\nEpoch: {epoch}',
                        xy=(epoch, score), xytext=text_offset,
                        textcoords='offset points', arrowprops=dict(arrowstyle="->"))

        ax.set_title("KID Score vs. Epoch")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("KID Score (Lower is Better)")
        ax.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        out_file = os.path.join(save_dir, "kid_plot_v2151.png")
        plt.savefig(out_file)
        logger.info(f"Saved KID plot to {out_file}")
        plt.close(fig)
    except Exception as plot_error:
        logger.error(f"Failed to generate or save KID plot: {plot_error}", exc_info=True)

def plot_combined_metrics(fid_scores, kid_scores, epochs, save_dir):
    """Plot FID and KID against epoch in one figure and save "combined_metrics_plot_v2151.png".

    When the two series differ in scale by more than a factor of 5, KID gets
    its own right-hand y-axis; otherwise both share one axis.

    Args:
        fid_scores: FID values, one per entry of ``epochs``.
        kid_scores: KID values, one per entry of ``epochs``.
        epochs: x positions shared by both series.
        save_dir: Directory the PNG is written to.

    Logs a warning and returns when any input is empty. Failures are logged
    and swallowed; the figure is closed in every case.
    """
    fig = None
    try:
        if not fid_scores or not kid_scores or not epochs:
            logger.warning("Missing data for combined metrics plot")
            return

        fig, ax = plt.subplots(figsize=(10, 6))

        # Scale of each series: its value range, or - for a flat series such
        # as a single data point - the magnitude of its value.
        fid_max, fid_min = max(fid_scores), min(fid_scores)
        kid_max, kid_min = max(kid_scores), min(kid_scores)
        fid_span = (fid_max - fid_min) or abs(fid_max)
        kid_span = (kid_max - kid_min) or abs(kid_max)

        # Compare by multiplication rather than division so that a zero span
        # cannot raise ZeroDivisionError (this used to abort the plot after
        # the very first evaluation, when each series has one point).
        if fid_span > 5 * kid_span or kid_span > 5 * fid_span:
            # Plot FID on left axis
            line1, = ax.plot(epochs, fid_scores, 'b-o', label="FID Score")
            ax.set_ylabel("FID Score", color='b')
            ax.tick_params(axis='y', labelcolor='b')

            # Create right y-axis for KID
            ax2 = ax.twinx()
            line2, = ax2.plot(epochs, kid_scores, 'r-o', label="KID Score")
            ax2.set_ylabel("KID Score", color='r')
            ax2.tick_params(axis='y', labelcolor='r')

            # Lines live on two different axes, so build one shared legend by hand.
            lines = [line1, line2]
            labels = [line.get_label() for line in lines]
            ax.legend(lines, labels, loc="upper right")
        else:
            # Plot both on same scale
            ax.plot(epochs, fid_scores, 'b-o', label="FID Score")
            ax.plot(epochs, kid_scores, 'r-o', label="KID Score")
            ax.set_ylabel("Score Value (Lower is Better)")
            ax.legend()

        ax.set_title("FID and KID Scores vs. Epoch")
        ax.set_xlabel("Epoch")
        ax.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        plot_filename = os.path.join(save_dir, "combined_metrics_plot_v2151.png")
        plt.savefig(plot_filename)
        logger.info(f"Saved combined metrics plot to {plot_filename}")
    except Exception as e:
        logger.error(f"Failed to generate or save combined metrics plot: {e}", exc_info=True)
    finally:
        # Release the figure on the failure path too; this is called
        # repeatedly during training and open figures accumulate otherwise.
        if fig is not None:
            plt.close(fig)

def plot_feature_space(real_features, fake_features, save_path):
    """Create a 2-D t-SNE scatter plot comparing real and generated features.

    At most ``TSNE_SAMPLE_SIZE // 2`` rows are drawn at random (without
    replacement) from each feature set, the two samples are embedded
    together, and real points are drawn in blue, generated ones in red.

    Args:
        real_features: 2-D array-like of features from real images, or None.
        fake_features: 2-D array-like of features from generated images, or None.
        save_path: File path the figure is written to.

    Logs a warning and returns when a feature set is None or there are too
    few samples to embed. Failures are logged and swallowed; the figure is
    closed in every case.
    """
    fig = None
    try:
        if real_features is None or fake_features is None:
            logger.warning("Missing features for t-SNE visualization")
            return

        # Sample if too many points
        max_samples = TSNE_SAMPLE_SIZE // 2  # Half for real, half for fake

        def _subsample(features):
            # asarray lets plain lists be index-selected like arrays.
            features = np.asarray(features)
            if len(features) > max_samples:
                picked = np.random.choice(len(features), max_samples, replace=False)
                return features[picked]
            return features

        real_sample = _subsample(real_features)
        fake_sample = _subsample(fake_features)

        # Combine features for t-SNE; rows [0, len(real_sample)) are real.
        combined_features = np.vstack([real_sample, fake_sample])
        n_samples = len(combined_features)
        if n_samples < 2:
            logger.warning("Not enough samples for t-SNE visualization")
            return

        # scikit-learn requires perplexity < n_samples, so clamp it for small sets.
        perplexity = min(TSNE_PERPLEXITY, n_samples - 1)

        # Perform t-SNE. The iteration count is left at the library default
        # (1000): the keyword was renamed from `n_iter` to `max_iter` in
        # scikit-learn 1.5, so naming either one breaks on some versions.
        logger.info("Computing t-SNE embedding...")
        tsne = manifold.TSNE(n_components=2, perplexity=perplexity,
                             random_state=TSNE_RANDOM_STATE)
        embedding = tsne.fit_transform(combined_features)

        # Create plot
        fig, ax = plt.subplots(figsize=(10, 8))

        real_points = embedding[:len(real_sample)]
        fake_points = embedding[len(real_sample):]

        ax.scatter(real_points[:, 0], real_points[:, 1], c='blue', alpha=0.6, label='Real', s=20)
        ax.scatter(fake_points[:, 0], fake_points[:, 1], c='red', alpha=0.6, label='Generated', s=20)

        ax.set_title('t-SNE Visualization of Real vs Generated Feature Distributions')
        ax.legend()

        # Add a description of what plot shows
        ax.annotate("Note: Points closer together have similar feature representations\n"
                    "Good generation = red points distributed similarly to blue points",
                    xy=(0.5, -0.01), xycoords='axes fraction',
                    ha='center', va='top', fontsize=9)

        plt.tight_layout()
        plt.savefig(save_path)
        logger.info(f"Saved t-SNE visualization to {save_path}")
    except Exception as e:
        logger.error(f"Failed to generate or save feature space visualization: {e}", exc_info=True)
    finally:
        # Release the figure on the failure path too.
        if fig is not None:
            plt.close(fig)

# --- Generate Markdown Report ---
def generate_markdown_report(g_losses_hist, c_losses_hist, fid_scores_hist, fid_epochs_hist, kid_scores_hist, kid_epochs_hist, training_info):
    """Write a markdown summary of a training run to ``ANALYSIS_DIR/training_report.md``.

    Args:
        g_losses_hist: Generator loss history (the report text treats each
            entry as one epoch).
        c_losses_hist: Critic loss history.
        fid_scores_hist: Recorded FID values; entry ``i`` is looked up against
            ``fid_epochs_hist[i]``.
        fid_epochs_hist: Epoch number for each FID value.
        kid_scores_hist: Recorded KID values; entry ``i`` is looked up against
            ``kid_epochs_hist[i]``.
        kid_epochs_hist: Epoch number for each KID value.
        training_info: Mapping read via ``.get`` for the keys ``'epochs'``,
            ``'stop_reason'`` and ``'training_time'`` (formatted as seconds).

    Returns:
        The report path on success, or ``None`` if anything failed (the
        error is logged, never raised).
    """
    report_path = os.path.join(ANALYSIS_DIR, "training_report.md")

    try:
        with open(report_path, "w", encoding="utf-8") as f:
            f.write("# WGAN-SN Training Report (v2.151-modified)\n\n")
            f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            # Training summary
            f.write("## Training Summary\n\n")
            f.write(f"- Total Epochs: {training_info.get('epochs', 'N/A')}\n")
            f.write(f"- Final Status: {training_info.get('stop_reason', 'N/A')}\n")
            f.write(f"- Total Training Time: {training_info.get('training_time', 0):.2f} seconds\n")
            f.write(f"- Final Generator Loss: {g_losses_hist[-1] if g_losses_hist else 'N/A'}\n")
            f.write(f"- Final Critic Loss: {c_losses_hist[-1] if c_losses_hist else 'N/A'}\n")

            # FID information (best epoch is reused by the comparison section)
            best_fid_epoch = None
            if fid_scores_hist:
                best_fid = min(fid_scores_hist)
                best_fid_epoch = fid_epochs_hist[fid_scores_hist.index(best_fid)]
                f.write(f"- Best FID Score: {best_fid:.4f} (Epoch {best_fid_epoch})\n")
                f.write(f"- Final FID Score: {fid_scores_hist[-1]:.4f}\n")

            # KID information
            best_kid_epoch = None
            if kid_scores_hist:
                best_kid = min(kid_scores_hist)
                best_kid_epoch = kid_epochs_hist[kid_scores_hist.index(best_kid)]
                f.write(f"- Best KID Score: {best_kid:.4f} (Epoch {best_kid_epoch})\n")
                f.write(f"- Final KID Score: {kid_scores_hist[-1]:.4f}\n")

            # Loss trends (needs more than 10 entries so that [-10] is valid)
            f.write("\n## Loss Trends\n\n")
            if g_losses_hist and len(g_losses_hist) > 10:
                recent_g_loss = g_losses_hist[-5:]
                early_g_loss = g_losses_hist[:5]
                avg_recent = sum(recent_g_loss) / len(recent_g_loss)
                avg_early = sum(early_g_loss) / len(early_g_loss)

                f.write(f"- Initial Generator Loss Avg (first 5 epochs): {avg_early:.4f}\n")
                f.write(f"- Final Generator Loss Avg (last 5 epochs): {avg_recent:.4f}\n")

                if avg_recent < -0.5 and avg_recent > -2.0:
                    f.write("- Generator loss has stabilized in the appropriate negative range typically seen in successful WGAN training.\n")
                elif avg_recent > -0.5:
                    f.write("- Generator loss may be too close to zero, which could indicate mode collapse or poor convergence.\n")
                elif avg_recent < -2.0:
                    f.write("- Generator loss is very negative, which might indicate training instability.\n")

                if avg_recent > avg_early and g_losses_hist[-1] > g_losses_hist[-10]:
                    f.write("- Warning: Generator loss is increasing, which suggests potential training instability.\n")

            # FID analysis over the last three evaluations
            if fid_scores_hist and len(fid_scores_hist) >= 3:
                f.write("\n## FID Analysis\n\n")
                last_fids = fid_scores_hist[-3:]

                if last_fids[0] > last_fids[-1] and last_fids[1] > last_fids[-1]:
                    f.write("- FID scores are continuing to improve, indicating the model is still learning to generate more realistic images.\n")
                elif all(abs(last_fids[0] - fid) < 1.0 for fid in last_fids[1:]):
                    f.write("- FID scores have stabilized, suggesting the model has reached convergence.\n")
                elif last_fids[0] < last_fids[-1] and last_fids[1] < last_fids[-1]:
                    f.write("- Warning: FID scores are worsening, which could indicate overfitting or training instability.\n")

            # KID analysis over the last three evaluations
            if kid_scores_hist and len(kid_scores_hist) >= 3:
                f.write("\n## KID Analysis\n\n")
                last_kids = kid_scores_hist[-3:]

                if last_kids[0] > last_kids[-1] and last_kids[1] > last_kids[-1]:
                    f.write("- KID scores are continuing to improve, suggesting better distributional match between real and generated images.\n")
                elif all(abs(last_kids[0] - kid) < 0.001 for kid in last_kids[1:]):
                    f.write("- KID scores have stabilized, indicating convergence in the feature distribution matching.\n")
                elif last_kids[0] < last_kids[-1] and last_kids[1] < last_kids[-1]:
                    f.write("- Warning: KID scores are worsening, which could indicate mode collapse or overfitting.\n")

            # FID vs KID comparison
            if fid_scores_hist and kid_scores_hist and len(fid_scores_hist) == len(kid_scores_hist):
                f.write("\n## FID vs KID Comparison\n\n")

                # Pearson correlation is undefined for fewer than two points,
                # for a constant series, or when a score is inf/nan. In those
                # cases report "could not calculate" instead of printing "nan"
                # and mislabelling it as a weak correlation.
                correlation = float('nan')
                if len(fid_scores_hist) >= 2:
                    try:
                        with np.errstate(divide='ignore', invalid='ignore'):
                            correlation = float(np.corrcoef(fid_scores_hist, kid_scores_hist)[0, 1])
                    except Exception as corr_err:
                        logger.warning(f"Could not compute FID/KID correlation: {corr_err}")

                if np.isfinite(correlation):
                    f.write(f"- Correlation between FID and KID scores: {correlation:.4f}\n")

                    if correlation > 0.8:
                        f.write("- FID and KID are strongly correlated, suggesting they are measuring similar aspects of generation quality.\n")
                    elif correlation > 0.5:
                        f.write("- FID and KID show moderate correlation, suggesting they capture somewhat different aspects of generation quality.\n")
                    else:
                        f.write("- FID and KID show weak correlation, suggesting they may be measuring different aspects of generation quality.\n")
                else:
                    f.write("- Could not calculate correlation between FID and KID.\n")

                # Check if best epochs align (both were computed above because
                # both histories are non-empty in this branch)
                if best_fid_epoch == best_kid_epoch:
                    f.write(f"- Best FID and KID scores both occurred at the same epoch ({best_fid_epoch}), strongly validating this as the optimal model.\n")
                else:
                    f.write(f"- Best FID occurred at epoch {best_fid_epoch}, while best KID occurred at epoch {best_kid_epoch}.\n")
                    f.write("- This divergence suggests different aspects of quality peaked at different times during training.\n")

            # Generated Images
            f.write("\n## Generated Images\n\n")
            f.write("Sample images are saved in the `samples` directory.\n")

            # Plots
            f.write("\n## Plots\n\n")
            f.write("- Loss plot: `plots/loss_plot_v2151.png`\n")
            f.write("- FID plot: `plots/fid_plot_v2151.png`\n")
            f.write("- KID plot: `plots/kid_plot_v2151.png`\n")
            f.write("- Combined metrics plot: `plots/combined_metrics_plot_v2151.png`\n")
            f.write("- t-SNE feature space visualization: `plots/feature_space_tsne_v2151.png`\n")

            # Recommendations
            f.write("\n## Recommendations\n\n")

            if not g_losses_hist:
                f.write("- No training data available to make recommendations.\n")
            else:
                if fid_scores_hist and min(fid_scores_hist) > 100:
                    f.write("- FID scores are high. Consider training for more epochs or adjusting hyperparameters.\n")
                elif fid_scores_hist and min(fid_scores_hist) < 50:
                    f.write("- FID scores are good. The model is generating realistic images.\n")

                if kid_scores_hist and min(kid_scores_hist) > 0.1:
                    f.write("- KID scores are high. Consider training for more epochs or adjusting training dynamics.\n")
                elif kid_scores_hist and min(kid_scores_hist) < 0.05:
                    f.write("- KID scores are good. The model is generating realistic feature distributions.\n")

                # Fewer epochs than configured without an early-stopping
                # reason is treated as an interrupted run.
                if training_info.get('epochs', 0) < NUM_EPOCHS and training_info.get('stop_reason', '') not in ["Early stopping (FID)", "Early stopping (KID)"]:
                    f.write("- Training was interrupted before completion. Consider resuming training.\n")

        logger.info(f"Generated markdown report at {report_path}")
        return report_path
    except Exception as e:
        logger.error(f"Failed to generate markdown report: {e}", exc_info=True)
        return None

# --- Helper functions for KID calculation ---
def polynomial_kernel(X, Y):
    """
    Cubic polynomial kernel evaluated on L2-normalised rows (used for KID).

    Every row of X and Y is cast to float64 and divided by its Euclidean
    norm (plus 1e-8 so all-zero rows do not divide by zero), which makes the
    pairwise inner products cosine similarities. The kernel value is
    (0.2 * cos + 1.0) ** 3, finally clipped into [1e-8, 1e6].

    Returns a matrix with one row per row of X and one column per row of Y.
    """
    gamma, coef0, degree = 0.2, 1.0, 3

    def _unit_rows(features):
        # float64 first, then scale each row to (approximately) unit length
        as_f64 = features.astype(np.float64)
        row_lengths = np.linalg.norm(as_f64, axis=1, keepdims=True)
        return as_f64 / (row_lengths + 1e-8)

    cosine = _unit_rows(X) @ _unit_rows(Y).T
    kernel = (gamma * cosine + coef0) ** degree

    # Bounds keep the result strictly positive and finite
    return np.clip(kernel, 1e-8, 1e6)

def calculate_kid_from_features(real_features, fake_features, subset_size=1000, num_subsets=100):
    """
    Estimate KID (squared MMD under ``polynomial_kernel``) between two feature sets.

    For each of ``num_subsets`` rounds, ``subset_size`` rows are drawn without
    replacement from each set and the MMD^2 estimate

        sum_{i != j} k(r_i, r_j) / (n (n-1))
      + sum_{i != j} k(f_i, f_j) / (n (n-1))
      - 2 * sum_{i, j} k(r_i, f_j) / n^2

    is computed. The cross term is averaged over all ``n * n`` pairs because
    its diagonal is not removed; dividing it by ``n * (n - 1)`` would
    over-weight it and bias the estimate negative.

    Args:
        real_features: 2-D numpy array of real-image features (rows = samples).
        fake_features: 2-D numpy array of generated-image features.
        subset_size: Rows per subset; capped at the smaller set's row count.
        num_subsets: Number of random subsets to average over.

    Returns:
        ``(mean, std)`` of the per-subset estimates. Each estimate is floored
        at 1e-8, so the returned mean is never negative or exactly zero.
    """
    # Use high precision
    real_features = real_features.astype(np.float64)
    fake_features = fake_features.astype(np.float64)

    # NOTE(review): each set is centred on its OWN mean, which removes any
    # mean shift between real and generated features before the kernel sees
    # them. Left unchanged so values stay comparable with previously recorded
    # histories - confirm this is intended.
    real_features = real_features - np.mean(real_features, axis=0, keepdims=True)
    fake_features = fake_features - np.mean(fake_features, axis=0, keepdims=True)

    n_r, n_f = real_features.shape[0], fake_features.shape[0]

    subset_size = min(subset_size, min(n_r, n_f))
    kid_values = []

    # Verify inputs aren't identical
    if np.array_equal(real_features, fake_features):
        logger.warning("WARNING: real_features and fake_features are identical arrays! KID calculation will be biased.")

    n = subset_size
    within_pairs = n * (n - 1)  # ordered pairs i != j inside one set
    cross_pairs = n * n         # every (real, fake) pair

    for _ in range(num_subsets):
        # Sample subset_size features from both distributions
        r_idx = np.random.choice(n_r, size=subset_size, replace=False)
        f_idx = np.random.choice(n_f, size=subset_size, replace=False)

        r_subset = real_features[r_idx]
        f_subset = fake_features[f_idx]

        # Kernel matrices for the MMD (Maximum Mean Discrepancy) estimate
        k_rr = polynomial_kernel(r_subset, r_subset)
        k_rf = polynomial_kernel(r_subset, f_subset)
        k_ff = polynomial_kernel(f_subset, f_subset)

        # Fewer than two samples per subset leaves no off-diagonal pairs
        if within_pairs <= 0:
            logger.warning("WARNING: Invalid denominator in KID calculation!")
            mmd = 0.01  # Fallback value
        else:
            # Within-set terms exclude the diagonal (self-similarity)
            within_sum = (np.sum(k_rr) - np.trace(k_rr)) + (np.sum(k_ff) - np.trace(k_ff))
            mmd = within_sum / within_pairs - 2.0 * np.sum(k_rf) / cross_pairs

        # Ensure non-negative MMD and prevent exact zeros
        mmd = max(1e-8, mmd)
        kid_values.append(mmd)

    return np.mean(kid_values), np.std(kid_values)

# --- FID Calculation Utilities ---
def get_inception_model(device):
    """Return an InceptionV3 feature extractor (2048-dim block) on ``device``, in eval mode.

    Raises:
        RuntimeError: if ``FID_AVAILABLE`` is falsy.
    """
    if FID_AVAILABLE:
        # Select the block that yields 2048-dimensional features
        feature_block = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        inception = InceptionV3([feature_block])
        inception = inception.to(device)
        inception.eval()
        return inception
    raise RuntimeError("pytorch-fid library not available.")

def get_activations_from_data(dataloader, model, device, num_images, batch_size, desc=""):
    """Run ``model`` over batches from ``dataloader`` and collect its features.

    Batches may be plain tensors or ``(images, ...)`` lists/tuples; in the
    latter case only the first element is used. Images are converted to
    float32, single-channel inputs are repeated to 3 channels, and values are
    mapped from [-1, 1] to [0, 1] (then clamped) before being fed to ``model``.

    Args:
        dataloader: Iterable yielding image batches.
        model: Feature extractor; the first element of its output is used and
            average-pooled to 1x1 if it is not already.
        device: Device the batches are moved to.
        num_images: Number of feature rows wanted.
        batch_size: Used only to work out how many batches to pull.
        desc: Label used in log and progress-bar text.

    Returns:
        A numpy array with at most ``num_images`` rows (fewer if the
        dataloader runs out), or ``None`` if pytorch-fid is unavailable,
        no batch was processed, or any batch raised an error.
    """
    if not FID_AVAILABLE:
        logger.warning("pytorch-fid not available.")
        return None

    n_batches = ceil(num_images / batch_size)
    n_used_imgs = 0
    pred_list = []

    iterator = iter(dataloader)
    logger.info(f"Calculating activations for {num_images} {desc} images ({n_batches} batches)...")

    with torch.no_grad():
        for i in tqdm(range(n_batches), desc=f"Activations {desc}", leave=False):
            try:
                batch = next(iterator)
                # Unwrap (images, labels, ...) BEFORE moving to the device:
                # a list/tuple has no .to(), so the previous order made this
                # branch unreachable and aborted the whole pass.
                if isinstance(batch, (list, tuple)):
                    batch = batch[0]
                batch = batch.to(device)
                if batch.shape[0] == 0:
                    continue

                current_batch_size = batch.shape[0]

                if batch.dtype != torch.float32:
                    batch = batch.float()
                if batch.shape[1] == 1:
                    batch = batch.repeat(1, 3, 1, 1)
                if batch.shape[1] != 3:
                    raise ValueError(f"Batch needs 3 channels, got {batch.shape[1]}")

                # Rescale [-1,1] to [0,1]; clamp guards against overshoot
                batch = (batch * 0.5) + 0.5
                batch = torch.clamp(batch, 0.0, 1.0)

                pred = model(batch)[0]
                if pred.size(2) != 1 or pred.size(3) != 1:
                    pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

                # (B, C, 1, 1) -> (B, C) on the CPU
                pred_list.append(pred.squeeze(3).squeeze(2).cpu().numpy())
                n_used_imgs += current_batch_size

                if n_used_imgs >= num_images:
                    break

                # Clean up to reduce memory usage
                del batch, pred

            except StopIteration:
                logger.warning(f"Dataloader exhausted early @ batch {i}. Using {n_used_imgs} images.")
                break
            except Exception as e:
                logger.error(f"Error during activation batch {i}: {e}", exc_info=True)
                return None

    if not pred_list:
        return None

    # The last batch may overshoot; trim to the requested count
    pred_arr = np.concatenate(pred_list, axis=0)
    pred_arr = pred_arr[:num_images]

    return pred_arr

def get_generated_activations(generator, inception_model, device, noise_dim, num_images, batch_size, desc=""):
    """Sample images from ``generator`` and return their Inception features.

    The generator is put in eval mode while sampling and switched back to
    train mode before returning (on both the success and the error path).
    Generated images are cast to float32, expanded to 3 channels if they have
    one, and mapped from [-1, 1] to [0, 1] before feature extraction.

    Returns:
        A numpy array with ``num_images`` rows, or ``None`` when pytorch-fid
        is unavailable, nothing was generated, or a batch raised an error.
    """
    if not FID_AVAILABLE:
        logger.warning("pytorch-fid not available.")
        return None

    total_batches = ceil(num_images / batch_size)
    produced = 0
    feature_chunks = []

    generator.eval()
    logger.info(f"Generating {num_images} fake images & activations ({total_batches} batches)...")

    with torch.no_grad():
        progress = tqdm(range(total_batches), desc=f"Generating & Activating {desc}", leave=False)
        for batch_no in progress:
            try:
                # The final batch may be smaller than batch_size
                this_batch = min(batch_size, num_images - produced)
                if this_batch <= 0:
                    break

                latent = torch.randn(this_batch, noise_dim, 1, 1, device=device)
                images = generator(latent)

                if images.dtype != torch.float32:
                    images = images.float()
                if images.shape[1] == 1:
                    images = images.repeat(1, 3, 1, 1)
                if images.shape[1] != 3:
                    raise ValueError(f"Generated batch needs 3 channels, got {images.shape[1]}")

                # Rescale from [-1, 1] to [0, 1]
                images = torch.clamp((images * 0.5) + 0.5, 0.0, 1.0)

                features = inception_model(images)[0]
                already_pooled = features.size(2) == 1 and features.size(3) == 1
                if not already_pooled:
                    features = F.adaptive_avg_pool2d(features, output_size=(1, 1))

                feature_chunks.append(features.squeeze(3).squeeze(2).cpu().numpy())
                produced += this_batch

                # Drop references early to reduce memory usage
                del latent, images, features

            except Exception as e:
                logger.error(f"Error during generated activation batch {batch_no}: {e}", exc_info=True)
                generator.train()
                return None

    generator.train()

    if not feature_chunks:
        return None

    stacked = np.concatenate(feature_chunks, axis=0)[:num_images]

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return stacked

def get_real_stats(real_dataloader, real_stats_path, real_features_path, inception_model, device, num_images, batch_size, force_recalculate=False):
    """Load or compute FID statistics and raw features for the real images.

    If both cache files exist and ``force_recalculate`` is false, the cached
    ``mu``/``sigma`` (from the ``.npz`` at ``real_stats_path``) and the raw
    feature array (from ``real_features_path``) are loaded and validated
    against the expected 2048-dimensional shapes. Otherwise, or if the cache
    is unreadable/invalid, features are recomputed from ``real_dataloader``
    and both files are rewritten (a save failure is logged, not raised).

    Args:
        real_dataloader: Source of real image batches.
        real_stats_path: ``.npz`` path holding ``mu`` and ``sigma``.
        real_features_path: Path holding the raw feature array.
        inception_model: Feature extractor passed to ``get_activations_from_data``.
        device: Device used for feature extraction.
        num_images: Number of real images required.
        batch_size: Batch size passed to ``get_activations_from_data``.
        force_recalculate: Ignore any cached files when true.

    Returns:
        ``(mu, sigma, features)``, or ``(None, None, None)`` if fewer than
        ``num_images`` activations could be obtained.
    """
    if os.path.exists(real_stats_path) and os.path.exists(real_features_path) and not force_recalculate:
        logger.info(f"Loading pre-calculated real stats and features from: {real_stats_path} and {real_features_path}")
        try:
            # np.load on an .npz returns a lazily-read NpzFile that keeps the
            # archive open; read both arrays inside the context so the file
            # handle is released.
            with np.load(real_stats_path) as stats:
                mu_real, sigma_real = stats['mu'], stats['sigma']

            # Load real features (for KID calculation and visualization)
            real_features = np.load(real_features_path)

            if (mu_real is not None and sigma_real is not None and
                mu_real.shape == (2048,) and sigma_real.shape == (2048, 2048) and
                real_features.ndim == 2 and real_features.shape[1] == 2048):
                logger.info("Loaded real stats and features successfully.")
                return mu_real, sigma_real, real_features
            else:
                logger.warning("Loaded real stats or features invalid. Recalculating...")
        except Exception as e:
            logger.warning(f"Could not load real stats or features file ({e}). Recalculating...")

    logger.info(f"Calculating FID stats and features for {num_images} real images...")
    real_activations = get_activations_from_data(real_dataloader, inception_model, device, num_images, batch_size, desc="Real")

    if real_activations is None or len(real_activations) < num_images:
        logger.error(f"Failed to get enough real activations. Cannot calculate stats.")
        return None, None, None

    mu_real = np.mean(real_activations, axis=0)
    sigma_real = np.cov(real_activations, rowvar=False)

    logger.info(f"Calculated real stats (mu: {mu_real.shape}, sigma: {sigma_real.shape}).")

    try:
        # Make sure the parent directory of BOTH outputs exists. A bare
        # filename has an empty dirname, which os.makedirs would reject.
        for target_path in (real_stats_path, real_features_path):
            parent_dir = os.path.dirname(target_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

        # Save stats for FID
        np.savez(real_stats_path, mu=mu_real, sigma=sigma_real)
        logger.info(f"Saved real FID stats to: {real_stats_path}")

        # Save raw features for KID and visualization
        np.save(real_features_path, real_activations)
        logger.info(f"Saved real features to: {real_features_path}")
    except Exception as e:
        # Caching is best-effort; the computed values are still returned
        logger.error(f"Failed to save real stats or features: {e}", exc_info=True)

    return mu_real, sigma_real, real_activations

def calculate_fid_and_kid(generator, inception_model, real_mu, real_sigma, real_features, device, noise_dim, num_images, batch_size):
    """Compute FID and KID from a single batch of freshly generated features.

    Returns:
        ``(fid, (kid_mean, kid_std), fake_features)``. When pytorch-fid is
        unavailable, real statistics are missing, or not enough fake
        activations were produced, the result is
        ``(inf, (inf, inf), None)``. If only one metric fails, that metric
        alone is reported as ``inf``.
    """
    inf = float('inf')
    failure = (inf, (inf, inf), None)

    if not FID_AVAILABLE:
        logger.warning("pytorch-fid not available.")
        return failure

    if any(item is None for item in (real_mu, real_sigma, real_features)):
        logger.error("Real stats or features not available.")
        return failure

    logger.info(f"Calculating FID and KID using {num_images} generated images...")
    fake_features = get_generated_activations(
        generator, inception_model, device, noise_dim, num_images, batch_size,
        desc="Fake (FID/KID)",
    )

    if fake_features is None or len(fake_features) < num_images:
        logger.error("Failed to get enough fake activations.")
        return failure

    # --- FID: Frechet distance between Gaussian fits of both feature sets ---
    mu_fake = np.mean(fake_features, axis=0)
    sigma_fake = np.cov(fake_features, rowvar=False)

    logger.info("Calculating Frechet distance...")

    try:
        fid_value = calculate_frechet_distance(mu_fake, sigma_fake, real_mu, real_sigma)
        logger.info(f"Calculated FID: {fid_value:.4f}")
    except Exception as e:
        logger.error(f"Error calculating Frechet distance: {e}", exc_info=True)
        fid_value = inf

    # --- KID ---
    logger.info(f"Calculating KID with {KID_SUBSET_SIZE} samples, {KID_SUBSETS} subsets...")
    try:
        # Prefer a library-provided implementation when one was imported into
        # the module namespace; otherwise use the local one.
        if 'calculate_kid_given_features' in globals():
            kid_impl = globals()['calculate_kid_given_features']
        else:
            kid_impl = calculate_kid_from_features
        kid_mean, kid_std = kid_impl(
            real_features,
            fake_features,
            subset_size=KID_SUBSET_SIZE,
            num_subsets=KID_SUBSETS,
        )
        logger.info(f"Calculated KID: {kid_mean:.6f} ± {kid_std:.6f}")
    except Exception as e:
        logger.error(f"Error calculating KID: {e}", exc_info=True)
        kid_mean, kid_std = inf, inf

    # Release memory held by intermediate tensors
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return fid_value, (kid_mean, kid_std), fake_features

# Function to generate all visualizations after training
def generate_all_visualizations(g_losses_hist, c_losses_hist, fid_scores_hist, fid_epochs_hist, kid_scores_hist, kid_std_hist, kid_epochs_hist, real_features=None, fake_features=None, save_dir=PLOT_DIR):
    """Produce every post-training plot, each written to its own file in ``save_dir``.

    Loss and FID plots are always requested. The KID plot needs a non-empty
    KID history; the combined FID/KID plot additionally needs a FID history
    of the same length. The t-SNE plot is drawn only when ``VISUALIZE_TSNE``
    is set and both feature arrays were supplied.
    """
    logger.info("Generating all visualization plots...")

    # Always-on plots
    plot_losses(g_losses_hist, c_losses_hist, save_dir)
    plot_fid(fid_scores_hist, fid_epochs_hist, save_dir)

    # KID-dependent plots
    have_kid = bool(kid_scores_hist and kid_epochs_hist)
    if have_kid:
        plot_kid(kid_scores_hist, kid_std_hist, kid_epochs_hist, save_dir)

    have_fid = bool(fid_scores_hist and fid_epochs_hist)
    if have_kid and have_fid and len(fid_scores_hist) == len(kid_scores_hist):
        plot_combined_metrics(fid_scores_hist, kid_scores_hist, fid_epochs_hist, save_dir)

    # Feature-space (t-SNE) plot
    features_supplied = real_features is not None and fake_features is not None
    if VISUALIZE_TSNE and features_supplied:
        tsne_file = os.path.join(save_dir, "feature_space_tsne_v2151.png")
        plot_feature_space(real_features, fake_features, tsne_file)

    # The legacy single-figure plot_metrics(...) call stays disabled.

    logger.info("All visualization plots generated.")

# ==============================================================================
# --- Main Training Execution ---
# ==============================================================================
if __name__ == "__main__": 
    
    # Wrap execution in try/finally for cleanup
    training_successful = False
    stop_reason = "Unknown"
    g_losses_hist = []
    c_losses_hist = []
    fid_scores_hist = []
    fid_epochs_hist = []
    kid_scores_hist = []  
    kid_std_hist = []     
    kid_epochs_hist = []  
    best_fid = float('inf')
    best_kid = float('inf')  
    epochs_no_improve = 0
    final_real_features = None  # For feature space visualization
    best_fake_features = None  # For feature space visualization from best model
    
    try: 
        # --- Log parameters ---
        logger.info("--- Starting WGAN-SN Training with enhancements (v2.151-modified) ---") 
        config = {k: v for k, v in globals().items() if k.isupper() and not k.startswith('_')}
        logger.info("Configuration:\n" + json.dumps(config, indent=4, default=str))
        
        # Save configuration
        config_save_path = os.path.join(OUTPUT_DIR, "training_config_v2151.json")
        try:
            with open(config_save_path, 'w') as f:
                json.dump(config, f, indent=4, default=str)
            logger.info(f"Saved configuration to {config_save_path}")
        except Exception as e:
            logger.error(f"Failed to save configuration: {e}")

        # --- Setup Device ---
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        
        if device.type == "cuda":
            logger.info(f"CUDA Version: {torch.version.cuda}")
            gpu_name = torch.cuda.get_device_name(0)
            logger.info(f"GPU Name: {gpu_name}")
        
        # Disable AMP as requested
        amp_enabled = False
        logger.info(f"AMP (Automatic Mixed Precision) enabled: {amp_enabled}")
        
        # Track initial memory
        log_gpu_memory_usage("Initialization")

        # --- Validate directories before continuing ---
        if not validate_directory(PREPROCESSED_DATA_DIR):
            logger.critical(f"Input data directory not found: {PREPROCESSED_DATA_DIR}")
            raise FileNotFoundError(f"Input data directory not found: {PREPROCESSED_DATA_DIR}")

        # --- Setup Dataset and DataLoader ---
        logger.info("Setting up Dataset and DataLoader...")
        # Define transformations, including normalization to [-1, 1] and augmentations
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        
        try:
            # Instantiate imported dataset class
            dataset = PollenDataset(
                root_dir=PREPROCESSED_DATA_DIR, 
                transform=transform,
                image_size=IMAGE_SIZE,      
                channels_img=CHANNELS_IMG 
            )
            if len(dataset) == 0:
                raise ValueError("Dataset is empty.")
            
            # Optimize worker count based on CPU cores
            try:
                dataloader_num_workers = min(max(os.cpu_count() // 2, 1), 4)  # Cap at 4
            except:
                dataloader_num_workers = 2  # Default if can't determine
                
            dataloader_pin_memory = (device.type == 'cuda')
            dataloader_persistent_workers = False  # Disabled as requested for memory efficiency
            
            logger.info(f"DataLoader using num_workers={dataloader_num_workers}, pin_memory={dataloader_pin_memory}, persistent_workers={dataloader_persistent_workers}.")
            
            dataloader = DataLoader(
                dataset, 
                batch_size=BATCH_SIZE, 
                shuffle=True, 
                num_workers=dataloader_num_workers, 
                pin_memory=dataloader_pin_memory, 
                persistent_workers=dataloader_persistent_workers,
                drop_last=True
            ) 
            
            logger.info(f"DataLoader created with {len(dataloader)} batches per epoch.")
            
        except Exception as e:
            logger.error(f"Failed to create Dataset/DataLoader: {e}", exc_info=True)
            raise  # Let the outer try/except handle the error

        # --- Initialize Models and Optimizers ---
        logger.info("Initializing models and optimizers...")
        generator = Generator(NOISE_DIM, CHANNELS_IMG, G_FEATURES).to(device)
        critic = CriticSN(CHANNELS_IMG, C_FEATURES).to(device) 
        initialize_weights(generator)
        initialize_weights(critic)
        logger.info("Models initialized with specified weights.")

        opt_gen = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))
        opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))

        # Initialize GradScalers using updated API
        scaler_critic = GradScaler(enabled=amp_enabled)
        scaler_gen = GradScaler(enabled=amp_enabled)

        # --- Prepare for FID/KID Calculation if enabled ---
        inception_model = None
        real_mu = None
        real_sigma = None
        real_features = None
        
        if (CALCULATE_FID or CALCULATE_KID) and FID_AVAILABLE:
            logger.info("Preparing for FID/KID calculation...")
            try:
                inception_model = get_inception_model(device)
                logger.info("InceptionV3 model loaded for FID/KID.")
                
                fid_transform = transforms.Compose([
                    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG)
                ])
                
                fid_dataset = PollenDataset(
                    PREPROCESSED_DATA_DIR, 
                    transform=fid_transform, 
                    image_size=IMAGE_SIZE, 
                    channels_img=CHANNELS_IMG
                )
                
                actual_fid_num_images = min(FID_NUM_IMAGES, len(fid_dataset))
                if actual_fid_num_images < FID_NUM_IMAGES:
                    logger.warning(f"Using {actual_fid_num_images} images for FID/KID based on dataset size.")
                
                fid_dataloader = DataLoader(
                    fid_dataset, 
                    batch_size=FID_BATCH_SIZE, 
                    shuffle=False, 
                    num_workers=dataloader_num_workers, 
                    pin_memory=dataloader_pin_memory
                )
                
                logger.info(f"Created DataLoader for real image FID/KID stats ({len(fid_dataloader)} batches).")
                
                real_mu, real_sigma, real_features = get_real_stats(
                    fid_dataloader, 
                    REAL_STATS_PATH,
                    REAL_FEATURES_PATH,
                    inception_model, 
                    device, 
                    actual_fid_num_images, 
                    FID_BATCH_SIZE, 
                    FORCE_RECALCULATE_REAL_STATS
                )
                
                if real_mu is None or real_sigma is None or real_features is None:
                    logger.error("Failed to get real image FID/KID statistics. Disabling FID/KID calculation.")
                    CALCULATE_FID = False
                    CALCULATE_KID = False
                else:
                    logger.info("Successfully obtained real image statistics and features.")
                    # Save features for visualization
                    final_real_features = real_features
                
                # Cleanup for memory efficiency
                del fid_dataloader, fid_dataset, fid_transform
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                
            except Exception as fid_setup_e:
                logger.error(f"Error during FID/KID setup: {fid_setup_e}", exc_info=True)
                logger.error("Disabling FID/KID calculation.")
                CALCULATE_FID = False
                CALCULATE_KID = False
                inception_model = None

        # --- Load Checkpoint if Resuming ---
        # Counters for a fresh run; a resumed run overwrites both from the checkpoint.
        start_epoch, global_step = 0, 0

        if RESUME_TRAINING:
            # load_checkpoint hands back the counters, the loss/metric histories and
            # the best-score / patience bookkeeping as a single 12-element tuple.
            (
                start_epoch,
                global_step,
                g_losses_hist,
                c_losses_hist,
                fid_scores_hist,
                fid_epochs_hist,
                kid_scores_hist,
                kid_std_hist,
                kid_epochs_hist,
                best_fid,
                best_kid,
                epochs_no_improve,
            ) = load_checkpoint(
                CHECKPOINT_FILE,
                generator,
                critic,
                opt_gen,
                opt_critic,
                scaler_gen,
                scaler_critic,
            )

            # Put both networks on the training device after loading ...
            for restored_net in (generator, critic):
                restored_net.to(device)

            # ... and move every tensor held in the optimizers' per-parameter
            # state to the same device (generator optimizer first, then critic).
            for restored_opt in (opt_gen, opt_critic):
                for slot_state in restored_opt.state.values():
                    for slot_name, slot_value in slot_state.items():
                        if isinstance(slot_value, torch.Tensor):
                            slot_state[slot_name] = slot_value.to(device)

        # --- Prepare for Training ---
        # One fixed batch of 32 latent vectors, reused for every periodic sample grid
        fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)

        # Switch both networks to training mode before the first batch
        for trainable_net in (generator, critic):
            trainable_net.train()
        logger.info(f"--- Starting Training Loop from Epoch {start_epoch+1}, Step {global_step} ---")

        # Memory baseline taken before any batch is processed
        log_gpu_memory_usage("Before Training Loop")

        # ==================== TRAINING LOOP ====================
        # early_stop_triggered is raised by the GPU-temperature guard and by the
        # patience check; training_start_time feeds the final duration report.
        early_stop_triggered = False
        training_start_time = time.time()
        
        for epoch in range(start_epoch, NUM_EPOCHS):
            epoch_start_time = time.time()
            # Progress bar for batches within the epoch
            loop_pbar = tqdm(enumerate(dataloader), total=len(dataloader), leave=True, desc=f"Epoch [{epoch+1}/{NUM_EPOCHS}]") 
            
            # Track average losses for the epoch
            avg_loss_c_epoch = 0.0
            avg_loss_g_epoch = 0.0
            batches_in_epoch = 0

            for batch_idx, real_images in loop_pbar:
                
                # --- Optional: GPU Temperature Check ---
                if MONITOR_TEMP and NVML_AVAILABLE and (global_step % 200 == 0): 
                     gpu_temp = get_gpu_temp(GPU_ID)
                     if gpu_temp is not None:
                          logger.info(f"Step {global_step} | GPU Temp: {gpu_temp}°C")
                          if gpu_temp > GPU_TEMP_THRESHOLD:
                               logger.warning(f"GPU Temperature {gpu_temp}°C exceeded threshold {GPU_TEMP_THRESHOLD}°C!")
                               logger.warning("Saving checkpoint and stopping training gracefully.")
                               save_checkpoint({
                                   'epoch': epoch, 
                                   'step': global_step, 
                                   'generator_state_dict': generator.state_dict(),
                                   'critic_state_dict': critic.state_dict(),
                                   'optimizer_gen_state_dict': opt_gen.state_dict(),
                                   'optimizer_critic_state_dict': opt_critic.state_dict(),
                                   'scaler_gen_state_dict': scaler_gen.state_dict(),
                                   'scaler_critic_state_dict': scaler_critic.state_dict(),
                                   'g_losses_history': g_losses_hist,
                                   'c_losses_history': c_losses_hist,
                                   'fid_scores_history': fid_scores_hist,
                                   'fid_epochs_history': fid_epochs_hist,
                                   'kid_scores_history': kid_scores_hist,
                                   'kid_std_history': kid_std_hist,
                                   'kid_epochs_history': kid_epochs_hist,
                                   'best_fid': best_fid,
                                   'best_kid': best_kid,
                                   'epochs_no_improve': epochs_no_improve
                               }, CHECKPOINT_FILE)
                               stop_reason = f"GPU Temp {gpu_temp}C > Threshold {GPU_TEMP_THRESHOLD}C"
                               early_stop_triggered = True
                               break

                # --- Main Training Step ---
                # One optimisation step on the current batch: a critic update, then a
                # generator update, followed by periodic logging/sampling. Any error is
                # logged and re-raised to the enclosing handler.
                try:
                    # Skipped batches leave global_step unchanged.
                    if real_images is None: 
                         logger.warning(f"Skipping batch {batch_idx} due to None data.")
                         continue
                         
                    real_images = real_images.to(device)
                    cur_batch_size = real_images.shape[0]
                    if cur_batch_size == 0:
                        continue 

                    # --- Train Critic ---
                    # Gradients are cleared once here and then accumulated over all
                    # CRITIC_ITERATIONS backward passes; a single clipped optimizer step
                    # follows the loop. Every pass reuses the same real batch and only
                    # the noise is resampled.
                    # NOTE(review): this is gradient accumulation, not CRITIC_ITERATIONS
                    # separate critic updates per generator update - confirm this is intended.
                    critic_loss_accum_iter = 0.0
                    opt_critic.zero_grad(set_to_none=True) 
                    
                    for _ in range(CRITIC_ITERATIONS): 
                        noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1).to(device)
                        
                        with autocast(device_type='cuda', enabled=amp_enabled):
                             # Generator runs under no_grad: only the critic receives gradients here
                             with torch.no_grad():
                                 fake_images = generator(noise) 
                             critic_real = critic(real_images).reshape(-1)
                             critic_fake = critic(fake_images).reshape(-1) 
                             # Critic loss: mean score on fakes minus mean score on reals
                             loss_critic = torch.mean(critic_fake) - torch.mean(critic_real)
                             
                        critic_loss_accum_iter += loss_critic.item()
                        
                        # Scale the loss for backprop
                        scaler_critic.scale(loss_critic).backward()
                    
                    # Update critic with gradient clipping (unscale first so the norm
                    # threshold of 1.0 applies to the unscaled gradients)
                    scaler_critic.unscale_(opt_critic)
                    torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=1.0)
                    scaler_critic.step(opt_critic)
                    scaler_critic.update()
                    
                    # Mean critic loss over the passes; raises ZeroDivisionError if
                    # CRITIC_ITERATIONS is 0
                    avg_loss_c_iter = critic_loss_accum_iter / CRITIC_ITERATIONS 

                    # --- Train Generator --- 
                    opt_gen.zero_grad(set_to_none=True)
                    
                    with autocast(device_type='cuda', enabled=amp_enabled):
                         noise_for_g = torch.randn(cur_batch_size, NOISE_DIM, 1, 1).to(device)
                         fake_images_for_g = generator(noise_for_g)
                         critic_fake_for_gen = critic(fake_images_for_g).reshape(-1) 
                         # Generator loss: negated mean critic score on fresh fakes
                         loss_gen = -torch.mean(critic_fake_for_gen)
                    
                    # Scale, backward, clip, and step.
                    # The backward pass also reaches the critic; those gradients are
                    # discarded by opt_critic.zero_grad at the start of the next batch.
                    scaler_gen.scale(loss_gen).backward()
                    scaler_gen.unscale_(opt_gen)
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
                    scaler_gen.step(opt_gen)
                    scaler_gen.update()

                    # --- Logging and Visualization ---
                    # Running sums; divided by batches_in_epoch at the end of the epoch
                    loss_g_item = loss_gen.item() 
                    avg_loss_c_epoch += avg_loss_c_iter
                    avg_loss_g_epoch += loss_g_item
                    batches_in_epoch += 1

                    # Periodically log losses
                    if global_step % 100 == 0: 
                         logger.debug(f"Step {global_step} | Loss C: {avg_loss_c_iter:.4f}, Loss G: {loss_g_item:.4f}")
                         # Track memory usage
                         log_gpu_memory_usage(f"Step {global_step}")

                    # Save sample images periodically based on global step
                    if global_step > 0 and global_step % SAMPLE_FREQ_STEPS == 0:
                        logger.info(f"Saving samples at step {global_step}")
                        # Eval mode while sampling; training mode is restored right after
                        generator.eval()
                        critic.eval() 
                        
                        with torch.no_grad():
                            with autocast(device_type='cuda', enabled=amp_enabled): 
                                fake_samples = generator(fixed_noise) 
                            # Affine map x * 0.5 + 0.5 before saving
                            # (presumably generator output is in [-1, 1] - TODO confirm)
                            img_grid = vutils.make_grid(fake_samples * 0.5 + 0.5, normalize=False) 
                            vutils.save_image(img_grid, os.path.join(SAMPLE_DIR, f"sample_{epoch+1:04d}_{global_step:07d}.png"))
                        
                        generator.train()
                        critic.train() 
                        
                        # Clean up sample generation tensors
                        del fake_samples, img_grid
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()

                    # Update progress bar description
                    loop_pbar.set_description(f"Epoch [{epoch+1}/{NUM_EPOCHS}]") 
                    loop_pbar.set_postfix(loss_C=avg_loss_c_iter, loss_G=loss_g_item, step=global_step)
                    
                    # Clean up tensors to reduce memory usage.
                    # The names from the critic loop exist only if it ran at least once.
                    del real_images, noise, fake_images, critic_real, critic_fake, loss_critic
                    del noise_for_g, fake_images_for_g, critic_fake_for_gen, loss_gen
                    
                    # Counts successfully processed batches only
                    global_step += 1

                # --- Error Handling for Batch ---
                except RuntimeError as e:
                    logger.error(f"Runtime error processing batch {batch_idx} at Step {global_step}: {e}", exc_info=True) 
                    if "out of memory" in str(e).lower():
                        logger.error(f"CUDA Out Of Memory! Batch Size: {BATCH_SIZE}. Consider reducing batch size.")
                        logger.warning("Attempting to save checkpoint before stopping...")
                        # 'epoch' is the index of the epoch in progress (not epoch + 1)
                        save_checkpoint({
                            'epoch': epoch, 
                            'step': global_step, 
                            'generator_state_dict': generator.state_dict(),
                            'critic_state_dict': critic.state_dict(),
                            'optimizer_gen_state_dict': opt_gen.state_dict(),
                            'optimizer_critic_state_dict': opt_critic.state_dict(),
                            'scaler_gen_state_dict': scaler_gen.state_dict(),
                            'scaler_critic_state_dict': scaler_critic.state_dict(),
                            'g_losses_history': g_losses_hist,
                            'c_losses_history': c_losses_hist,
                            'fid_scores_history': fid_scores_hist,
                            'fid_epochs_history': fid_epochs_hist,
                            'kid_scores_history': kid_scores_hist,
                            'kid_std_history': kid_std_hist,
                            'kid_epochs_history': kid_epochs_hist,
                            'best_fid': best_fid,
                            'best_kid': best_kid,
                            'epochs_no_improve': epochs_no_improve
                        }, CHECKPOINT_FILE)
                        stop_reason = "CUDA Out of Memory"
                        raise e 
                    # Every other RuntimeError is re-raised unchanged
                    raise e 
                    
                except Exception as e: 
                    # Non-RuntimeError failures: log with traceback, then propagate
                    logger.error(f"Generic error processing batch {batch_idx} at step {global_step}: {e}", exc_info=True)
                    raise e
            
            # Check for early stopping
            if early_stop_triggered:
                break
                
            # --- End of Batch Loop ---

            # --- End of Epoch ---
            epoch_duration = time.time() - epoch_start_time

            # Turn the running sums into per-batch means; with no processed batch
            # the accumulators keep their initial 0.0.
            if batches_in_epoch:
                avg_loss_c_epoch = avg_loss_c_epoch / batches_in_epoch
                avg_loss_g_epoch = avg_loss_g_epoch / batches_in_epoch

            # One history entry per completed epoch
            g_losses_hist.append(avg_loss_g_epoch)
            c_losses_hist.append(avg_loss_c_epoch)

            logger.info(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Completed in {epoch_duration:.2f}s | Avg Loss C: {avg_loss_c_epoch:.4f} | Avg Loss G: {avg_loss_g_epoch:.4f}")

            # --- Memory cleanup ---
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            log_gpu_memory_usage(f"After Epoch {epoch+1}")
            
            # --- FID/KID Calculation ---
            if (CALCULATE_FID or CALCULATE_KID) and FID_AVAILABLE and inception_model is not None:
                # Score the current generator against the real-image statistics gathered during setup.
                current_fid, (current_kid, current_kid_std), fake_features_epoch = calculate_fid_and_kid(
                    generator, 
                    inception_model, 
                    real_mu, 
                    real_sigma,
                    real_features,
                    device, 
                    NOISE_DIM, 
                    FID_NUM_IMAGES, 
                    FID_BATCH_SIZE
                )
                
                # Record each enabled metric; a value of +inf is not added to the history.
                if CALCULATE_FID and current_fid != float('inf'):
                    fid_scores_hist.append(current_fid)
                    fid_epochs_hist.append(epoch + 1)
                    logger.info(f"--- FID Score @ Epoch {epoch+1}: {current_fid:.4f} ---")
                
                if CALCULATE_KID and current_kid != float('inf'):
                    kid_scores_hist.append(current_kid)
                    kid_std_hist.append(current_kid_std)
                    kid_epochs_hist.append(epoch + 1)
                    logger.info(f"--- KID Score @ Epoch {epoch+1}: {current_kid:.6f} ± {current_kid_std:.6f} ---")
                
                # --- Best-score tracking ---
                # Runs whether or not early stopping is enabled, so best_fid/best_kid,
                # the best checkpoints and the best model reflect the recorded scores
                # (previously all of this was skipped when USE_EARLY_STOPPING was False).
                # As before, nothing is tracked unless PRIMARY_EVAL_METRIC is "FID" or "KID".
                metric_is_tracked = PRIMARY_EVAL_METRIC in ("FID", "KID")
                fid_improved = metric_is_tracked and CALCULATE_FID and current_fid < best_fid
                kid_improved = metric_is_tracked and CALCULATE_KID and current_kid < best_kid
                # Only the primary metric drives patience, best features and the best model.
                metric_improved = fid_improved if PRIMARY_EVAL_METRIC == "FID" else kid_improved
                
                if fid_improved:
                    fid_note = "Saving best FID checkpoint." if PRIMARY_EVAL_METRIC == "FID" else "(Not primary metric)"
                    logger.info(f"FID improved: {best_fid:.4f} -> {current_fid:.4f}. {fid_note}")
                    best_fid = current_fid
                    
                if kid_improved:
                    kid_note = "Saving best KID checkpoint." if PRIMARY_EVAL_METRIC == "KID" else "(Not primary metric)"
                    logger.info(f"KID improved: {best_kid:.6f} -> {current_kid:.6f}. {kid_note}")
                    best_kid = current_kid
                
                # --- Early stopping (patience on the primary metric) ---
                # Updated BEFORE the best checkpoints are written so they store the
                # same epochs_no_improve value as the regular end-of-epoch checkpoint.
                if USE_EARLY_STOPPING:
                    if metric_improved:
                        epochs_no_improve = 0
                    else:
                        epochs_no_improve += 1
                        logger.info(f"{PRIMARY_EVAL_METRIC} did not improve. Patience: {epochs_no_improve}/{EARLY_STOPPING_PATIENCE}.")
                        
                        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
                            logger.warning(f"--- Early stopping triggered after {epochs_no_improve} epochs without {PRIMARY_EVAL_METRIC} improvement. ---")
                            stop_reason = f"Early stopping ({PRIMARY_EVAL_METRIC})"
                            early_stop_triggered = True
                
                # --- Best checkpoints (fixed file names, overwritten on each improvement) ---
                if fid_improved or kid_improved:
                    # Built once, after both best values and the patience counter are final.
                    best_state = {
                        'epoch': epoch + 1, 
                        'step': global_step, 
                        'generator_state_dict': generator.state_dict(),
                        'critic_state_dict': critic.state_dict(),
                        'optimizer_gen_state_dict': opt_gen.state_dict(),
                        'optimizer_critic_state_dict': opt_critic.state_dict(),
                        'scaler_gen_state_dict': scaler_gen.state_dict(),
                        'scaler_critic_state_dict': scaler_critic.state_dict(),
                        'g_losses_history': g_losses_hist,
                        'c_losses_history': c_losses_hist,
                        'fid_scores_history': fid_scores_hist,
                        'fid_epochs_history': fid_epochs_hist,
                        'kid_scores_history': kid_scores_hist,
                        'kid_std_history': kid_std_hist,
                        'kid_epochs_history': kid_epochs_hist,
                        'best_fid': best_fid,
                        'best_kid': best_kid,
                        'epochs_no_improve': epochs_no_improve
                    }
                    if fid_improved:
                        save_checkpoint(best_state, BEST_FID_CHECKPOINT_FILE)
                    if kid_improved:
                        save_checkpoint(best_state, BEST_KID_CHECKPOINT_FILE)
                
                if metric_improved:
                    # Keep the features of the best epoch for the final visualizations
                    if fake_features_epoch is not None:
                        best_fake_features = fake_features_epoch
                        save_fake_features(fake_features_epoch)
                    # Always save the best model based on the primary metric
                    save_best_model(generator)
            
            # --- Memory cleanup after FID/KID --- 
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            # --- Checkpointing ---
            checkpoint_state = {
                'epoch': epoch + 1, 
                'step': global_step,
                'generator_state_dict': generator.state_dict(),
                'critic_state_dict': critic.state_dict(),
                'optimizer_gen_state_dict': opt_gen.state_dict(),
                'optimizer_critic_state_dict': opt_critic.state_dict(),
                'scaler_gen_state_dict': scaler_gen.state_dict(),
                'scaler_critic_state_dict': scaler_critic.state_dict(),
                'g_losses_history': g_losses_hist,
                'c_losses_history': c_losses_hist,
                'fid_scores_history': fid_scores_hist,
                'fid_epochs_history': fid_epochs_hist,
                'kid_scores_history': kid_scores_hist,
                'kid_std_history': kid_std_hist,
                'kid_epochs_history': kid_epochs_hist,
                'best_fid': best_fid,
                'best_kid': best_kid,
                'epochs_no_improve': epochs_no_improve
            }
            
            # Save checkpoint every N epochs or if it's the last epoch
            if (epoch + 1) % CHECKPOINT_FREQ_EPOCHS == 0 or (epoch + 1) == NUM_EPOCHS:
                save_checkpoint(checkpoint_state, f"checkpoint_epoch_{epoch+1:04d}_sn_v2151.pth.tar")
                
            # Always save the latest checkpoint
            save_checkpoint(checkpoint_state, CHECKPOINT_FILE) 
            
            # Check for early stopping
            if early_stop_triggered:
                break

        # --- End of Epoch Loop ---

        # --- Training Finished ---
        # Post-training summary: log the outcome, produce plots, make sure a model
        # file exists, and write the markdown report.
        total_training_time = time.time() - training_start_time
        
        logger.info("="*60)
        # stop_reason was already set by whichever path raised early_stop_triggered;
        # otherwise the loop ran through all configured epochs.
        if not early_stop_triggered:
            stop_reason = f"Completed {NUM_EPOCHS} epochs"
            
        training_successful = True
        logger.info("--- Training Finished Successfully ---")
        logger.info(f"Reason: {stop_reason}")
        logger.info(f"Total Training Time: {total_training_time:.2f} seconds")
        logger.info(f"Final Global Step: {global_step}")
        
        # Best scores are reported only when at least one value was recorded
        if fid_scores_hist:
            logger.info(f"Best FID Score Achieved: {best_fid:.4f}")
        if kid_scores_hist:
            logger.info(f"Best KID Score Achieved: {best_kid:.6f}")
            
        # Generate final plots and visualizations
        logger.info("Generating final visualization plots...")
        
        # If no best features are held in memory, fall back to the copy on disk.
        # A failed load is only logged, leaving best_fake_features as None.
        if best_fake_features is None and os.path.exists(BEST_FEATURES_PATH):
            try:
                best_fake_features = np.load(BEST_FEATURES_PATH)
                logger.info(f"Loaded best fake features for visualization from: {BEST_FEATURES_PATH}")
            except Exception as e:
                logger.warning(f"Failed to load best fake features: {e}")
        
        generate_all_visualizations(
            g_losses_hist, c_losses_hist, 
            fid_scores_hist, fid_epochs_hist, 
            kid_scores_hist, kid_std_hist, kid_epochs_hist,
            real_features=final_real_features,
            fake_features=best_fake_features,  # Best features if available, may be None
            save_dir=PLOT_DIR
        )
        
        # If no best-model file exists on disk, save the final generator instead.
        # NOTE(review): this only tests file existence, so a file left over from an
        # earlier run would suppress the save - confirm that is acceptable.
        if not os.path.exists(os.path.join(OUTPUT_DIR, BEST_MODEL_FILE)):
            logger.info("No best model saved during training. Saving final model...")
            save_best_model(generator)
            
        # Collect training info for report.
        # 'epochs' falls back to 0 when the epoch loop never ran; best scores that
        # are still +inf are reported as None.
        training_info = {
            'epochs': epoch + 1 if 'epoch' in locals() else 0,
            'stop_reason': stop_reason,
            'training_time': total_training_time,
            'best_fid': best_fid if best_fid != float('inf') else None,
            'best_kid': best_kid if best_kid != float('inf') else None
        }
            
        # Generate markdown report
        generate_markdown_report(g_losses_hist, c_losses_hist, fid_scores_hist, fid_epochs_hist, kid_scores_hist, kid_epochs_hist, training_info)

    except KeyboardInterrupt:
        # Manual interruption: record the reason, then try to checkpoint and plot
        # whatever has been collected so far.
        logger.warning("--- Training Interrupted by User ---")
        stop_reason = "Manual Interruption"
        
        logger.warning("Attempting to save checkpoint and generate plots before exit...")
        
        # Only generator, critic and g_losses_hist are checked for existence here.
        # NOTE(review): the optimizers, the remaining histories, best_fid/best_kid,
        # epochs_no_improve, best_fake_features and final_real_features are used
        # unguarded below - presumably all are created before an interrupt can reach
        # this point; verify against the setup order.
        if 'generator' in locals() and 'critic' in locals() and 'g_losses_hist' in locals():
            # 'epoch' stores the current loop index (not epoch + 1); 0 if the loop never started
            save_checkpoint({
                'epoch': epoch if 'epoch' in locals() else 0, 
                'step': global_step if 'global_step' in locals() else 0, 
                'generator_state_dict': generator.state_dict(),
                'critic_state_dict': critic.state_dict(),
                'optimizer_gen_state_dict': opt_gen.state_dict(),
                'optimizer_critic_state_dict': opt_critic.state_dict(),
                'scaler_gen_state_dict': scaler_gen.state_dict() if 'scaler_gen' in locals() else None,
                'scaler_critic_state_dict': scaler_critic.state_dict() if 'scaler_critic' in locals() else None,
                'g_losses_history': g_losses_hist,
                'c_losses_history': c_losses_hist,
                'fid_scores_history': fid_scores_hist,
                'fid_epochs_history': fid_epochs_hist,
                'kid_scores_history': kid_scores_hist,
                'kid_std_history': kid_std_hist,
                'kid_epochs_history': kid_epochs_hist,
                'best_fid': best_fid,
                'best_kid': best_kid,
                'epochs_no_improve': epochs_no_improve
            }, CHECKPOINT_FILE)
            
            # If no best features are held in memory, fall back to the copy on disk;
            # a failed load is only logged.
            if best_fake_features is None and os.path.exists(BEST_FEATURES_PATH):
                try:
                    best_fake_features = np.load(BEST_FEATURES_PATH)
                    logger.info(f"Loaded best fake features for visualization from: {BEST_FEATURES_PATH}")
                except Exception as e:
                    logger.warning(f"Failed to load best fake features: {e}")
            
            # Generate plots with available data
            generate_all_visualizations(
                g_losses_hist, c_losses_hist, 
                fid_scores_hist, fid_epochs_hist, 
                kid_scores_hist, kid_std_hist, kid_epochs_hist,
                real_features=final_real_features,
                fake_features=best_fake_features,  # Use best features for visualization if available
                save_dir=PLOT_DIR
            )
        
    except Exception as main_e:
        # Catch any unexpected error during setup or the main loop that wasn't handled inside.
        # The exception is logged with its traceback and is not re-raised from this handler.
        logger.critical(f"Critical error during training setup or execution: {main_e}", exc_info=True)
        stop_reason = f"Error: {str(main_e)}"
        
        # Try to save checkpoint if models exist.
        # Only generator, critic and g_losses_hist are checked for existence.
        # NOTE(review): the optimizers, remaining histories, best_fid/best_kid and
        # epochs_no_improve are used unguarded - presumably they exist whenever the
        # three checked names do; verify against the setup order.
        if 'generator' in locals() and 'critic' in locals() and 'g_losses_hist' in locals():
            logger.warning("Attempting to save checkpoint before exit...")
            # 'epoch' stores the current loop index (not epoch + 1); 0 if the loop never started
            save_checkpoint({
                'epoch': epoch if 'epoch' in locals() else 0, 
                'step': global_step if 'global_step' in locals() else 0, 
                'generator_state_dict': generator.state_dict(),
                'critic_state_dict': critic.state_dict(),
                'optimizer_gen_state_dict': opt_gen.state_dict(),
                'optimizer_critic_state_dict': opt_critic.state_dict(),
                'scaler_gen_state_dict': scaler_gen.state_dict() if 'scaler_gen' in locals() else None,
                'scaler_critic_state_dict': scaler_critic.state_dict() if 'scaler_critic' in locals() else None,
                'g_losses_history': g_losses_hist,
                'c_losses_history': c_losses_hist,
                'fid_scores_history': fid_scores_hist,
                'fid_epochs_history': fid_epochs_hist,
                'kid_scores_history': kid_scores_hist,
                'kid_std_history': kid_std_hist,
                'kid_epochs_history': kid_epochs_hist,
                'best_fid': best_fid,
                'best_kid': best_kid,
                'epochs_no_improve': epochs_no_improve
            }, CHECKPOINT_FILE)
            
            # Persist the features from the most recent metric evaluation, if any.
            # NOTE(review): save_fake_features is the same helper used for the best
            # epoch's features, so this may replace them with the latest epoch's -
            # confirm against the helper.
            if 'fake_features_epoch' in locals() and fake_features_epoch is not None:
                save_fake_features(fake_features_epoch)
        
    finally: 
        # This block ALWAYS runs, whether the try succeeded or an exception occurred
        logger.info("--- Running Final Cleanup ---")
        
        # Final memory usage report
        log_gpu_memory_usage("Final")
        
        # Cleanup NVML
        if NVML_AVAILABLE:
            # Best-effort shutdown: an NVML failure must not mask the training
            # outcome, so ordinary errors are logged and ignored. Catching
            # Exception (instead of a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            try:
                pynvml.nvmlShutdown()
            except Exception as nvml_shutdown_e:
                logger.debug(f"Ignoring NVML shutdown error: {nvml_shutdown_e}")
            
        # Run analysis scripts after training is complete
        logger.info("--- Running Analysis Scripts ---")
        try:
            # Import analysis modules - only re-import if needed to avoid namespace conflicts
            if 'plt' not in globals():
                import matplotlib.pyplot as plt
            import re
            import glob
            from PIL import Image
            
            # Log script execution
            logger.info(f"Starting training progress analysis...")
            
            # If we didn't already generate visualizations, do it now
            if 'g_losses_hist' in locals() and len(g_losses_hist) > 0 and 'c_losses_hist' in locals() and len(c_losses_hist) > 0:
                # Collect training info for report
                if 'training_info' not in locals():
                    training_info = {
                        'epochs': len(g_losses_hist),
                        'stop_reason': stop_reason if 'stop_reason' in locals() else "Unknown",
                        'training_time': total_training_time if 'total_training_time' in locals() else 0,
                        'best_fid': best_fid if 'best_fid' in locals() and best_fid != float('inf') else None,
                        'best_kid': best_kid if 'best_kid' in locals() and best_kid != float('inf') else None
                    }
                
                # Try to load best features for visualization if not already in memory
                if best_fake_features is None and os.path.exists(BEST_FEATURES_PATH):
                    try:
                        best_fake_features = np.load(BEST_FEATURES_PATH)
                        logger.info(f"Loaded best fake features for visualization from: {BEST_FEATURES_PATH}")
                    except Exception as e:
                        logger.warning(f"Failed to load best fake features: {e}")
                
                # Generate visualizations if not already done
                generate_all_visualizations(
                    g_losses_hist, c_losses_hist, 
                    fid_scores_hist, fid_epochs_hist, 
                    kid_scores_hist, kid_std_hist, kid_epochs_hist,
                    real_features=final_real_features if 'final_real_features' in locals() else None,
                    fake_features=best_fake_features if 'best_fake_features' in locals() else None,
                    save_dir=PLOT_DIR
                )
                    
                # Generate markdown report if not already generated
                generate_markdown_report(
                    g_losses_hist, c_losses_hist, 
                    fid_scores_hist, fid_epochs_hist, 
                    kid_scores_hist, kid_epochs_hist, 
                    training_info
                )
                
            else:
                logger.warning("No loss history available for analysis or visualizations already generated.")
                
        except Exception as analysis_error:
            logger.error(f"Error during analysis script execution: {analysis_error}", exc_info=True)
        
        logger.info("--- Shutting down logging ---")
        logging.shutdown()

# ==============================================================================
# --- End of Script ---
# ==============================================================================