In [None]:
# "script_v2-151_2A1_continuation-modified[dot]ipynb"

In [None]:
#!/usr/bin/env python
# Enhanced Continuation Script for "WGAN-SN v2.151" with Comprehensive Improvements
# Builds upon the original training with additional features for better monitoring,
# stability, metrics, and visualizations

import os
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend for server environments
import matplotlib.pyplot as plt
import json
import time
import glob
import logging
import sys
import gc
import random
import traceback
import importlib
import subprocess
import shutil
import re
from datetime import datetime
from math import ceil
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError
from collections import deque

print("!!! setting the 'TF_ENABLE_ONEDNN_OPTS' value to '0' for avoiding the 'oneDNN custom operations' message in powershell console !!!")
# Re-importing os is a no-op (it is already imported above); kept as in the original cell.
import os
# Silence TensorFlow's oneDNN custom-operations notice (see the message above).
# NOTE(review): environment flags like this are presumably read when TensorFlow is first
# imported, so this only has an effect if nothing has imported TensorFlow yet — verify.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# --- Auto-installation Block ---
def install_and_import(package_name, import_name=None, pip_name=None):
    """
    Import a module and bind it as a global of this module, pip-installing it first if needed.

    Args:
        package_name: dotted import path passed to ``importlib.import_module``
            (e.g. ``'sklearn.manifold'``).
        import_name: global name the module is bound to (defaults to ``package_name``).
        pip_name: distribution name given to ``pip install`` (defaults to ``package_name``);
            required whenever the PyPI name differs from the import path.

    Returns:
        True if the module was imported (possibly after installing it), False otherwise.
        Failures are printed rather than raised so the caller can decide what is critical.
    """
    if import_name is None:
        import_name = package_name
    if pip_name is None:
        pip_name = package_name

    try:
        module = importlib.import_module(package_name)
    except ImportError:
        print(f"{package_name} not found. Attempting installation using pip...")
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])
            # Import finders cache directory contents; without invalidating them, a package
            # installed while this interpreter is running may still be reported as missing.
            importlib.invalidate_caches()
            module = importlib.import_module(package_name)
        except (subprocess.CalledProcessError, ImportError) as e:  # ModuleNotFoundError is an ImportError
            print(f"ERROR: Failed to install/import {package_name} (pip name: {pip_name}). {e}")
            print("Please install required packages manually and restart the script.")
            return False
        globals()[import_name] = module
        print(f"Successfully installed and imported {package_name} as {import_name}")
        return True

    globals()[import_name] = module
    print(f"Successfully imported {package_name} as {import_name}")
    return True

print("--- Checking and Installing Dependencies ---")
# Each call is install_and_import(import path, global alias, pip distribution name).
# pip_name must be given whenever the PyPI name differs from the import path; otherwise the
# pip fallback asks for a package that does not exist.
# Core dependencies
numpy_success = install_and_import('numpy', 'np')
torch_success = install_and_import('torch')
torchvision_success = install_and_import('torchvision')
install_and_import('PIL', pip_name='Pillow')  # PIL is distributed on PyPI as "Pillow"
install_and_import('tqdm')
install_and_import('matplotlib.pyplot', 'plt', pip_name='matplotlib')  # a submodule is not a pip package
install_and_import('scipy')
install_and_import('pytorch_fid', pip_name='pytorch-fid')

# For advanced visualizations
install_and_import('sklearn.manifold', 'manifold', pip_name='scikit-learn')
install_and_import('umap', pip_name='umap-learn')

# For improved metrics
install_and_import('torchmetrics.image.kid', 'torchmetrics_kid', pip_name='torchmetrics')

# Check critical dependencies
critical_imports_successful = all([numpy_success, torch_success, torchvision_success])
if not critical_imports_successful:
    print("ERROR: Critical packages (numpy, torch, torchvision) failed to import.")
    print("Please install these packages manually and restart the script.")
    sys.exit(1)

# --- Core Imports ---
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader 
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.nn.utils import spectral_norm 
from torch.amp import autocast, GradScaler
import torch.nn.functional as F

# --- Import FID calculation utilities ---
# Each optional dependency sets an *_AVAILABLE flag instead of aborting the script.
try:
    from pytorch_fid.inception import InceptionV3
    from pytorch_fid.fid_score import calculate_frechet_distance
except ImportError:
    FID_AVAILABLE = False
    print("WARNING: Could not import FID utilities from 'pytorch_fid'. FID calculation will be disabled.")
else:
    FID_AVAILABLE = True
    print("Successfully imported FID utilities")

# --- Import KID calculation (TorchMetrics) ---
# When unavailable, the NumPy fallback (calculate_kid_from_features_custom) is used instead.
try:
    from torchmetrics.image.kid import KernelInceptionDistance
except ImportError:
    TORCHMETRICS_KID_AVAILABLE = False
    print("WARNING: Could not import KID from torchmetrics. Will use custom implementation.")
else:
    TORCHMETRICS_KID_AVAILABLE = True
    print("Successfully imported KID utilities from torchmetrics")

# --- Import visualization utilities ---
try:
    from sklearn.manifold import TSNE
except ImportError:
    TSNE_AVAILABLE = False
    print("WARNING: Could not import t-SNE. Feature space visualization will be limited.")
else:
    TSNE_AVAILABLE = True
    print("Successfully imported t-SNE from sklearn")

try:
    import umap
except ImportError:
    UMAP_AVAILABLE = False
    print("WARNING: Could not import UMAP. Feature space visualization will be limited.")
else:
    UMAP_AVAILABLE = True
    print("Successfully imported UMAP")

# --- Import models and dataset class from separate files ---
# These project modules must sit next to this script; without them training cannot run.
try:
    from wgan_models_v2A1 import Generator, CriticSN, initialize_weights
except ImportError as e:
    MODELS_AVAILABLE = False
    print(f"ERROR: Could not import models from .py files: {e}")
    print("Ensure wgan_models_v2A1.py exists in the same directory.")
else:
    MODELS_AVAILABLE = True
    print("Successfully imported models from local .py files.")

try:
    from pollen_datasets_v2A1 import PollenDataset
except ImportError as e:
    DATASET_AVAILABLE = False
    print(f"ERROR: Could not import dataset from .py files: {e}")
    print("Ensure pollen_datasets_v2A1.py exists in the same directory.")
else:
    DATASET_AVAILABLE = True
    print("Successfully imported dataset from local .py files.")

# --- Verify critical imports ---
# Abort early (exit code 1) unless both the model and dataset definitions were found.
if not (MODELS_AVAILABLE and DATASET_AVAILABLE):
    print("Critical model or dataset definitions missing. Please fix import issues.")
    sys.exit(1)

# ====== CONFIGURATION SECTION - MODIFY THESE SETTINGS ======
# User-editable module-level settings; the helper functions below read these names directly.

# --- Original Paths ---
# NOTE(review): absolute, machine-specific Windows paths — adjust before running elsewhere.
ORIGINAL_OUTPUT_DIR = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\WGAN-SN_training-output_v2-151"
PREPROCESSED_DATA_DIR = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\pre-processing_px-128_step_automated-labels_pc-150"

# --- Continuation Settings ---
# All outputs of this run go to ORIGINAL_OUTPUT_DIR/CONTINUATION_DIR_NAME.
CONTINUATION_DIR_NAME = "continuation_enhanced"  # Change this for different continuation runs
CONTINUATION_DIR = os.path.join(ORIGINAL_OUTPUT_DIR, CONTINUATION_DIR_NAME)
CONTINUATION_PREFIX = "cont_"  # Prefix for continuation files

# --- Checkpoint Selection ---
# find_checkpoint_file() tries, in order: the epoch checkpoint (if START_FROM_EPOCH_CKPT is set),
# the best-FID checkpoint (if USE_BEST_FID_CKPT), LATEST_CKPT_FILENAME, then any *.pth.tar.
START_FROM_EPOCH_CKPT = None  # Specific epoch number or None
USE_BEST_FID_CKPT = True  # Use best FID checkpoint (also the fallback when the START_FROM_EPOCH_CKPT checkpoint is not found)
LATEST_CKPT_FILENAME = "latest_checkpoint_sn_v2151.pth.tar"
BEST_FID_CKPT_FILENAME = "best_fid_checkpoint_v2151.pth.tar"

# --- Training Settings ---
USE_ORIGINAL_SETTINGS = False  # Load HPs from original config or use NEW ones below?
NEW_LEARNING_RATE = 0.00002  # Used if USE_ORIGINAL_SETTINGS is False
NEW_BATCH_SIZE = 64  # Used if USE_ORIGINAL_SETTINGS is False
NEW_CRITIC_ITERATIONS = 5  # Used if USE_ORIGINAL_SETTINGS is False
USE_GRADIENT_CLIPPING = False  # Enable/disable gradient clipping (SN should suffice)
ADDITIONAL_EPOCHS = 250  # Number of additional epochs to train

# --- Early Stopping Settings ---
NEW_EARLY_STOPPING_PATIENCE = 50  # New patience value for the continuation run (the original run's value is only logged at load time)

# --- Evaluation Settings ---
PRIMARY_EVAL_METRIC = "FID"  # "FID" or "KID"
FID_FREQ_EPOCHS = 1  # How often (in epochs) to calculate metrics
FID_NUM_IMAGES = 10000  # Number of images for FID/KID calculation
FID_BATCH_SIZE = 64  # Batch size for metric calculations
CALCULATE_KID = True  # Also calculate KID alongside FID
KID_SUBSET_SIZE = 1000  # Subset size for KID calculation
KID_SUBSETS = 100  # Number of subsets for KID calculation

# --- Plotting Settings ---
PLOT_PER_EPOCH = True  # Generate plots after each evaluation epoch
PLOT_FEATURE_SPACE = True  # Generate t-SNE/UMAP visualizations
USE_UMAP = True  # Prefer UMAP over t-SNE if available
TSNE_UMAP_SAMPLE_SIZE = 2000  # Total samples for feature space visualization (half real, half generated)
TSNE_PERPLEXITY = 30  # t-SNE perplexity parameter
TSNE_UMAP_RANDOM_STATE = 42  # Random seed for reproducibility

# --- Sample Settings ---
SAMPLE_FREQ_STEPS = 500  # How often (in training steps) to save sample images during training

# --- Stability Monitoring ---
MONITOR_LOSS_STABILITY = True  # Enable loss stability monitoring
LOSS_STABILITY_WINDOW = 100  # Window size for moving average
LOSS_STABILITY_THRESHOLD = 5.0  # Threshold for abnormal loss spikes
HEARTBEAT_LOG_FREQ = 200  # How often to log "heartbeat" messages (0 to disable)

# --- Misc Settings ---
FORCE_RECALCULATE_REAL_STATS = False  # Force recalculation of real image statistics
AMP_ENABLED = False  # Automatic Mixed Precision (keep disabled as in original script)

# --- Random Seed for Reproducibility ---
MANUAL_SEED = 42  # None disables seeding

# ====== DO NOT MODIFY BELOW THIS LINE UNLESS YOU KNOW WHAT YOU'RE DOING ======

# --- Create Continuation Directories ---
# Every artifact of this continuation run lives under CONTINUATION_DIR. These module-level
# paths are read by the logging setup and the save helpers further down.
CONT_CHKPT_DIR = os.path.join(CONTINUATION_DIR, "checkpoints")
CONT_SAMPLE_DIR = os.path.join(CONTINUATION_DIR, "samples")
CONT_LOG_DIR = os.path.join(CONTINUATION_DIR, "logs")
CONT_PLOT_DIR = os.path.join(CONTINUATION_DIR, "plots")
CONT_ANALYSIS_DIR = os.path.join(CONTINUATION_DIR, "analysis_results")

# Create all required directories (exist_ok=True makes re-running this cell safe).
# Note: being module-level, the loop variable `directory` remains defined afterwards.
for directory in [CONTINUATION_DIR, CONT_CHKPT_DIR, CONT_SAMPLE_DIR, CONT_LOG_DIR, CONT_PLOT_DIR, CONT_ANALYSIS_DIR]:
    os.makedirs(directory, exist_ok=True)

# --- Setup Logging ---
# One named logger writing the same formatted records to stdout and to a timestamped file
# in CONT_LOG_DIR.
log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s - %(message)s')
logger = logging.getLogger("WGAN_Continuation_Enhanced")
logger.setLevel(logging.INFO)
# Detach AND close handlers left over from an earlier execution of this cell. Merely clearing
# logger.handlers dropped the previous FileHandler without closing it, leaving its log file
# open for the rest of the session on every re-run.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)
    _stale_handler.close()

# Add console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(log_formatter)
logger.addHandler(console_handler)

# Add file handler. An explicit UTF-8 encoding avoids depending on the OS locale code page
# (e.g. cp1252 on Windows) for messages containing non-ASCII text such as paths.
log_file = os.path.join(CONT_LOG_DIR, f"{CONTINUATION_PREFIX}training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
file_handler.setFormatter(log_formatter)
logger.addHandler(file_handler)

logger.info("="*80)
logger.info(f"WGAN-SN Training Continuation Script (Enhanced) - {datetime.now()}")
logger.info(f"Output directory: {CONTINUATION_DIR}")
logger.info("="*80)

# --- Set Random Seeds ---
# Seed Python's, NumPy's and PyTorch's global RNGs (plus every CUDA device when present).
if MANUAL_SEED is not None:
    logger.info(f"Using manual seed: {MANUAL_SEED}")
    for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        _seed_fn(MANUAL_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(MANUAL_SEED)

# --- Setup Device ---
# Prefer the GPU when one is visible; otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

if device.type == "cuda":
    logger.info(f"CUDA Version: {torch.version.cuda}")
    gpu_name = torch.cuda.get_device_name(0)  # only defined when running on CUDA
    logger.info(f"GPU Name: {gpu_name}")

logger.info(f"AMP Enabled: {AMP_ENABLED}")
logger.info(f"Gradient Clipping Enabled: {USE_GRADIENT_CLIPPING}")

# ==============================================
# Utility Functions
# ==============================================

# --- GPU Memory Logger ---
def log_gpu_memory_usage(step=''):
    """
    Log one INFO line with the CUDA memory currently allocated and reserved by PyTorch.

    Args:
        step: free-form label shown in brackets to identify where the snapshot was taken.

    Does nothing when CUDA is unavailable. Sizes are reported in GiB (1024**3 bytes).
    """
    if torch.cuda.is_available():
        bytes_per_gb = 1024 ** 3
        allocated = torch.cuda.memory_allocated() / bytes_per_gb
        reserved = torch.cuda.memory_reserved() / bytes_per_gb
        logger.info(f"GPU Memory [{step}]: Allocated: {allocated:.2f} GB | Reserved: {reserved:.2f} GB")

# --- KID Calculation Fallback Functions ---
def polynomial_kernel_custom(X, Y, degree=3, gamma=None, coef0=1.0):
    """
    Polynomial kernel for KID on L2-normalised rows: k(x, y) = (gamma * <x, y> + coef0) ** degree.

    Rows of X and Y are scaled to (almost) unit length first, so <x, y> is a cosine similarity
    in [-1, 1] and the kernel stays in a numerically safe range.

    Args:
        X: array-like of shape (n, d).
        Y: array-like of shape (m, d).
        degree: polynomial degree (default 3).
        gamma: scale applied to the dot product; ``None`` selects 0.2, the value this script
            has always used. Explicit values are now honoured (they used to be overwritten).
        coef0: additive constant (default 1.0).

    Returns:
        float64 array of shape (n, m), clipped to [1e-8, 1e6].
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    if gamma is None:
        gamma = 0.2

    # Normalise rows; the small epsilon guards against all-zero feature vectors.
    X_norm = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-8)
    Y_norm = Y / (np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8)

    dot_product = X_norm @ Y_norm.T
    # Bounds keep later sums finite and strictly positive.
    return np.clip((gamma * dot_product + coef0) ** degree, 1e-8, 1e6)

def calculate_kid_from_features_custom(real_features, fake_features, subset_size=1000, num_subsets=100):
    """
    Estimate KID (squared MMD with a polynomial kernel) from Inception features by subsampling.

    For each of ``num_subsets`` rounds, ``subset_size`` rows are drawn without replacement from
    each feature set. The unbiased MMD^2 estimate is computed as::

        [sum_{i!=j} k(r_i, r_j) + sum_{i!=j} k(f_i, f_j)] / (n (n - 1))  -  2 * sum_{i,j} k(r_i, f_j) / n^2

    Each round's value is floored at 1e-8, so reported values are non-negative. The floor
    slightly biases the mean upward for near-identical distributions.

    Both sets are centred with the SAME (real-data) mean. Centring each set with its own mean
    would erase any constant offset between the two distributions and hide a real mismatch.

    Args:
        real_features: array-like of shape (n_real, d).
        fake_features: array-like of shape (n_fake, d).
        subset_size: rows per subset (capped at the smaller set size).
        num_subsets: number of subsampling rounds.

    Returns:
        (mean, std) of the per-subset MMD^2 values.
    """
    real_features = np.asarray(real_features, dtype=np.float64)
    fake_features = np.asarray(fake_features, dtype=np.float64)

    # Identical inputs usually mean the caller passed the same array twice.
    if np.array_equal(real_features, fake_features):
        logger.warning("WARNING: real_features and fake_features are identical arrays! KID calculation will be biased.")

    reference_mean = np.mean(real_features, axis=0, keepdims=True)
    real_features = real_features - reference_mean
    fake_features = fake_features - reference_mean

    n_r, n_f = real_features.shape[0], fake_features.shape[0]
    subset_size = min(subset_size, n_r, n_f)
    n = subset_size

    kid_values = []
    for _ in range(num_subsets):
        r_subset = real_features[np.random.choice(n_r, size=subset_size, replace=False)]
        f_subset = fake_features[np.random.choice(n_f, size=subset_size, replace=False)]

        k_rr = polynomial_kernel_custom(r_subset, r_subset)
        k_rf = polynomial_kernel_custom(r_subset, f_subset)
        k_ff = polynomial_kernel_custom(f_subset, f_subset)

        if n < 2:
            # The within-set terms need at least two samples (n * (n - 1) > 0).
            logger.warning("WARNING: Invalid denominator in KID calculation!")
            mmd = 0.01  # Fallback value
        else:
            within = (k_rr.sum() - np.trace(k_rr) + k_ff.sum() - np.trace(k_ff)) / (n * (n - 1))
            # The cross term averages over all n*n pairs (no diagonal to exclude).
            cross = 2.0 * k_rf.sum() / (n * n)
            mmd = within - cross

        kid_values.append(max(1e-8, mmd))

    return np.mean(kid_values), np.std(kid_values)

# --- Checkpoint Finding Utilities ---
def find_checkpoint_file(original_ckpt_dir):
    """
    Choose which checkpoint in ``original_ckpt_dir`` to resume training from.

    Resolution order, driven by the module-level settings:
      1. START_FROM_EPOCH_CKPT set -> first ``checkpoint_epoch_<NNNN>*.pth.tar`` match;
      2. USE_BEST_FID_CKPT -> BEST_FID_CKPT_FILENAME if present, else the
         ``best_fid_checkpoint*.pth.tar`` whose name encodes the lowest FID
         (``..._fid68.32...`` or ``..._fid68...``), else the first such file;
      3. LATEST_CKPT_FILENAME;
      4. the first ``*.pth.tar`` in the directory.

    "First" means first in sorted filename order, so the choice no longer depends on the
    platform's directory-listing order.

    Args:
        original_ckpt_dir: directory holding the original run's checkpoints.

    Returns:
        Path of the selected checkpoint, or None when no checkpoint exists.
    """
    # Escape the directory so characters like '[' in the path are not read as glob syntax.
    ckpt_dir_pattern = glob.escape(original_ckpt_dir)

    def _sorted_glob(pattern):
        # Deterministic, sorted glob inside the checkpoint directory.
        return sorted(glob.glob(os.path.join(ckpt_dir_pattern, pattern)))

    if START_FROM_EPOCH_CKPT is not None:
        specific_ckpt = _sorted_glob(f"checkpoint_epoch_{START_FROM_EPOCH_CKPT:04d}*.pth.tar")
        if specific_ckpt:
            logger.info(f"Found checkpoint for specified epoch {START_FROM_EPOCH_CKPT}: {os.path.basename(specific_ckpt[0])}")
            return specific_ckpt[0]
        logger.warning(f"No checkpoint found for epoch {START_FROM_EPOCH_CKPT}. Will try best FID/latest checkpoint.")

    if USE_BEST_FID_CKPT:
        best_fid_path = os.path.join(original_ckpt_dir, BEST_FID_CKPT_FILENAME)
        if os.path.exists(best_fid_path):
            logger.info(f"Using best FID checkpoint: {BEST_FID_CKPT_FILENAME}")
            return best_fid_path

        best_fid_ckpts = _sorted_glob("best_fid_checkpoint*.pth.tar")
        if best_fid_ckpts:
            # Prefer the file whose name carries the lowest FID, e.g.
            # best_fid_checkpoint_e0074_fid68.32.pth.tar; integer scores ("fid68") also match.
            fid_in_name = re.compile(r'fid(\d+(?:\.\d+)?)')
            scores = []
            for ckpt in best_fid_ckpts:
                match = fid_in_name.search(os.path.basename(ckpt))
                if match:
                    scores.append((float(match.group(1)), ckpt))
            if scores:
                best_score, best_ckpt = min(scores, key=lambda x: x[0])
                logger.info(f"Using best FID checkpoint with score {best_score:.4f}: {os.path.basename(best_ckpt)}")
                return best_ckpt

            # No score encoded in any filename: fall back to the first file.
            logger.info(f"Using best FID checkpoint: {os.path.basename(best_fid_ckpts[0])}")
            return best_fid_ckpts[0]

    # Fall back to latest checkpoint
    latest_ckpt = os.path.join(original_ckpt_dir, LATEST_CKPT_FILENAME)
    if os.path.exists(latest_ckpt):
        logger.info(f"Using latest checkpoint: {LATEST_CKPT_FILENAME}")
        return latest_ckpt

    # If all else fails, look for any checkpoint
    all_ckpts = _sorted_glob("*.pth.tar")
    if all_ckpts:
        logger.warning(f"Specified checkpoints not found. Using available checkpoint: {os.path.basename(all_ckpts[0])}")
        return all_ckpts[0]

    logger.error(f"No checkpoints found in {original_ckpt_dir}. Cannot continue.")
    return None

# --- Checkpoint Loading/Saving ---
def load_checkpoint(checkpoint_path):
    """
    Load a checkpoint and extract the state needed to resume training.

    Args:
        checkpoint_path: path of a file written by ``torch.save``.

    Returns:
        ``(checkpoint, history_data)`` on success. ``checkpoint`` is the raw loaded dict.
        ``history_data`` contains:
          - the loss/FID/KID histories;
          - ``best_fid`` and ``best_kid`` (``inf`` when absent or stored as None);
          - ``start_epoch`` and ``global_step``;
          - ``epochs_no_improve``.
        ``(None, None)`` if the file cannot be loaded or does not contain a dict.

    Security:
        ``weights_only=False`` lets torch.load unpickle arbitrary objects, which can execute
        code. Only load checkpoints produced by this project or another trusted source.
    """
    logger.info(f"Loading checkpoint from: {checkpoint_path}")
    try:
        print("!!! maintain the 'weights_only' value as 'false' to avoid issues even if a warning appears in jupyter console !!!")
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    except Exception as e:
        logger.error(f"Error loading checkpoint: {e}", exc_info=True)
        return None, None

    if not isinstance(checkpoint, dict):
        logger.error(f"Error loading checkpoint: expected a dict, got {type(checkpoint).__name__}")
        return None, None
    logger.info("Checkpoint loaded successfully!")

    def _score_or_inf(key):
        # A stored None would otherwise break the ':.4f' formatting below.
        value = checkpoint.get(key)
        return float('inf') if value is None else value

    # Extract training history
    g_losses_hist = checkpoint.get('g_losses_history', [])
    c_losses_hist = checkpoint.get('c_losses_history', [])
    fid_scores_hist = checkpoint.get('fid_scores_history', [])
    fid_epochs_hist = checkpoint.get('fid_epochs_history', [])
    kid_scores_hist = checkpoint.get('kid_scores_history', [])
    kid_std_hist = checkpoint.get('kid_std_history', [])
    kid_epochs_hist = checkpoint.get('kid_epochs_history', [])
    best_fid = _score_or_inf('best_fid')
    best_kid = _score_or_inf('best_kid')
    epochs_no_improve = checkpoint.get('epochs_no_improve', 0)
    start_epoch = checkpoint.get('epoch', 0)
    global_step = checkpoint.get('step', 0)

    logger.info(f"Loaded checkpoint from Epoch {start_epoch}, Step {global_step}")
    logger.info(f"Training history: {len(g_losses_hist)} loss points, {len(fid_scores_hist)} FID scores, {len(kid_scores_hist)} KID scores")
    logger.info(f"Best scores: FID: {best_fid:.4f}, KID: {best_kid:.6f}")
    logger.info(f"Original early stopping status: {epochs_no_improve}/{checkpoint.get('early_stopping_patience', 10)}")

    history_data = {
        'g_losses_hist': g_losses_hist,
        'c_losses_hist': c_losses_hist,
        'fid_scores_hist': fid_scores_hist,
        'fid_epochs_hist': fid_epochs_hist,
        'kid_scores_hist': kid_scores_hist,
        'kid_std_hist': kid_std_hist,
        'kid_epochs_hist': kid_epochs_hist,
        'best_fid': best_fid,
        'best_kid': best_kid,
        'start_epoch': start_epoch,
        'global_step': global_step,
        # Additional key: the original run's early-stopping counter (was read but dropped).
        'epochs_no_improve': epochs_no_improve,
    }

    return checkpoint, history_data

def save_checkpoint(state, filename):
    """
    Save ``state`` with ``torch.save`` to CONT_CHKPT_DIR/``filename``, atomically.

    The data is written to ``<path>.tmp`` first and then moved over the target with
    ``os.replace``. An interruption or serialization error mid-write therefore never leaves a
    truncated file in place of the previous good checkpoint. This matters because files such
    as the "latest" checkpoint are overwritten repeatedly.

    Failures are logged with traceback and not raised (best-effort, as before); any partial
    temp file is removed.
    """
    save_path = os.path.join(CONT_CHKPT_DIR, filename)
    tmp_path = f"{save_path}.tmp"
    logger.info(f"Saving checkpoint to {save_path}")
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, save_path)
        logger.info("Checkpoint saved successfully.")
    except Exception as e:
        logger.error(f"Failed to save checkpoint: {e}", exc_info=True)
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError as cleanup_err:
                logger.warning(f"Could not remove temporary checkpoint file {tmp_path}: {cleanup_err}")

def save_best_model_inference(generator, model_config, filename=f"{CONTINUATION_PREFIX}best_model.pt"):
    """
    Persist the generator for inference as CONTINUATION_DIR/``filename``.

    The saved object is a dict with keys 'model_state_dict' (``generator.state_dict()``) and
    'model_config' (passed through unchanged). Errors are logged with traceback and
    swallowed; the function always returns None.
    """
    target_path = os.path.join(CONTINUATION_DIR, filename)
    logger.info(f"Saving best model for inference to {target_path}")
    try:
        payload = {
            'model_state_dict': generator.state_dict(),
            'model_config': model_config,
        }
        torch.save(payload, target_path)
    except Exception as err:
        logger.error(f"Failed to save best model: {err}", exc_info=True)
    else:
        logger.info("Best model saved successfully.")

def save_best_fake_features(fake_features, filename=f"{CONTINUATION_PREFIX}best_fake_features.npy"):
    """
    Write ``fake_features`` with ``np.save`` to CONT_ANALYSIS_DIR/``filename``.

    Returns:
        True when the array was written; False when ``fake_features`` is None or saving failed
        (the failure is logged with traceback).
    """
    if fake_features is None:
        logger.warning("No fake features to save.")
        return False

    target_path = os.path.join(CONT_ANALYSIS_DIR, filename)
    logger.info(f"Saving best fake features to {target_path}")
    try:
        np.save(target_path, fake_features)
    except Exception as err:
        logger.error(f"Failed to save fake features: {err}", exc_info=True)
        return False
    logger.info("Best fake features saved successfully.")
    return True

def save_best_samples(generator, fixed_noise, epoch, filename=None):
    """
    Render ``generator(fixed_noise)`` as an image grid and save it to CONT_SAMPLE_DIR.

    Args:
        generator: the generator module; it is switched to eval mode for sampling.
        fixed_noise: latent batch fed to the generator (expected shape defined by the model).
        epoch: epoch number used in the default filename.
        filename: output name; defaults to ``<prefix>best_samples_e<epoch>_<timestamp>.png``.

    Returns:
        True if the grid was saved, False otherwise (the error is logged).

    The generator's previous train/eval mode is restored afterwards, even on error. It used to
    be forced into train mode unconditionally.
    """
    if filename is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"{CONTINUATION_PREFIX}best_samples_e{epoch:04d}_{timestamp}.png"

    save_path = os.path.join(CONT_SAMPLE_DIR, filename)
    logger.info(f"Saving best samples to {save_path}")

    was_training = generator.training
    try:
        generator.eval()
        with torch.no_grad():
            fake_samples = generator(fixed_noise)
            # Assumes generator output in [-1, 1] (e.g. tanh) — mapped to [0, 1] for saving.
            img_grid = vutils.make_grid(fake_samples * 0.5 + 0.5, normalize=False)
            vutils.save_image(img_grid, save_path)
        logger.info("Best samples saved successfully.")
        return True
    except Exception as e:
        logger.error(f"Failed to save best samples: {e}", exc_info=True)
        return False
    finally:
        generator.train(was_training)

# --- Configuration Loading ---
def extract_original_config(original_dir):
    """
    Load the original run's hyper-parameters from ``training_config_v2151.json``.

    Args:
        original_dir: output directory of the original training run.

    Returns:
        The parsed JSON object when the file exists and can be read. Otherwise a fresh dict
        of built-in defaults (a new dict per call, so callers may mutate it).
    """
    config_path = os.path.join(original_dir, "training_config_v2151.json")
    if os.path.exists(config_path):
        try:
            # JSON is UTF-8 by specification. Without an explicit encoding, open() would use the
            # locale code page (e.g. cp1252 on Windows) and could mis-decode non-ASCII text.
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
            logger.error(f"Error loading configuration: {e}", exc_info=True)
            logger.warning("Original configuration file could not be read. Using default values.")
        else:
            logger.info(f"Loaded original configuration from {config_path}")
            return config
    else:
        logger.warning("Original configuration file not found. Using default values.")

    return {
        "LEARNING_RATE": 0.00005,
        "BATCH_SIZE": 64,
        "CRITIC_ITERATIONS": 5,
        "NUM_EPOCHS": 250,
        "NOISE_DIM": 100,
        "CHANNELS_IMG": 1,
        "G_FEATURES": 64,
        "C_FEATURES": 64,
        "FID_FREQ_EPOCHS": 1,
        "FID_NUM_IMAGES": 10000,
        "FID_BATCH_SIZE": 64,
        "SAMPLE_FREQ_STEPS": 500,
        "EARLY_STOPPING_PATIENCE": 10,
        "CHECKPOINT_FREQ_EPOCHS": 5,
        "IMAGE_SIZE": 128
    }

# --- Plotting Utilities ---
def _annotate_plot(fig, ax, final_epoch, stop_reason=None, status_text=None):
    """
    Write a left-aligned training-status caption near the bottom edge of ``fig``.

    Args:
        fig: figure that receives the caption (placed at figure coords x=0.1, y=0.01).
        ax: the plot's axes; currently unused, kept so existing call sites keep working.
        final_epoch: epoch number used in the default caption.
        stop_reason: optional reason, appended on its own line as "Stop reason: ...".
        status_text: caption override; defaults to "Training status: Finished during Epoch N".
    """
    caption = f"Training status: Finished during Epoch {final_epoch}" if status_text is None else status_text
    if stop_reason is not None:
        caption += f"\nStop reason: {stop_reason}"
    fig.text(0.1, 0.01, caption, ha='left', fontsize=9)

def plot_losses(g_losses, c_losses, save_dir, filename_base, final_epoch, timestamp, is_final=False, stop_reason=None):
    """
    Plot generator and critic loss histories and save the figure as a PNG (300 dpi).

    Args:
        g_losses, c_losses: loss sequences, plotted against 1..len(series). The x-axis is
            labelled "Epoch"; presumably one value per epoch — verify against the caller.
            The two series may differ in length.
        save_dir, filename_base: output location. The file is ``<base>_current.png``, or
            ``<base>_final_<timestamp>.png`` (with a status caption) when ``is_final``.
        final_epoch: most recent epoch, shown in the title.
        stop_reason: optional reason shown in the final caption.

    Returns:
        True on success, False if there is no data or plotting failed (error logged).
        The figure is always closed, so repeated failures cannot accumulate open figures.
    """
    fig = None
    try:
        if g_losses is None or c_losses is None or len(g_losses) == 0 or len(c_losses) == 0:
            logger.warning("No loss data to plot")
            return False

        fig, ax = plt.subplots(figsize=(10, 6))

        # Each series gets its own x range so unequal-length histories still plot.
        ax.plot(range(1, len(g_losses) + 1), g_losses, 'r-', label="Generator Loss", alpha=0.8)
        ax.plot(range(1, len(c_losses) + 1), c_losses, 'b-', label="Critic Loss", alpha=0.8)

        ax.set_title(f"Generator and Critic Loss vs. Epoch (Most recent: {final_epoch})")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.6)

        if is_final:
            _annotate_plot(fig, ax, final_epoch, stop_reason)
            filename = f"{filename_base}_final_{timestamp}.png"
        else:
            filename = f"{filename_base}_current.png"

        fig.tight_layout()
        save_path = os.path.join(save_dir, filename)
        fig.savefig(save_path, dpi=300)

        logger.info(f"Saved loss plot to {save_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to generate or save loss plot: {e}", exc_info=True)
        return False
    finally:
        if fig is not None:
            plt.close(fig)

def plot_metric(scores, epochs_hist, best_score, metric_name, save_dir, filename_base, final_epoch, timestamp,
               is_final=False, stop_reason=None, std_devs=None):
    """
    Plot a lower-is-better metric (FID or KID) against epoch and save the figure as a PNG.

    Args:
        scores: metric values (list or array).
        epochs_hist: epoch of each score. If the lengths differ, both are trimmed to the
            common length instead of raising.
        best_score: accepted for call compatibility. The highlighted best point is recomputed
            as the minimum of ``scores``.
        metric_name: label used in titles, axis labels and log messages.
        save_dir, filename_base: output location. The file is ``<base>_current.png``, or
            ``<base>_final_<timestamp>.png`` (with a status caption) when ``is_final``.
        final_epoch: most recent epoch, shown in the title.
        stop_reason: optional reason shown in the final caption.
        std_devs: optional per-score standard deviations. They are drawn as error bars only
            when their length equals that of ``scores``.

    Returns:
        True on success, False if there is no data or plotting failed (error logged).
        The figure is always closed.
    """
    fig = None
    try:
        if scores is None or epochs_hist is None or len(scores) == 0 or len(epochs_hist) == 0:
            logger.warning(f"No {metric_name} data to plot")
            return False

        # Explicit None/length checks work for numpy arrays too (array truthiness raises).
        use_error_bars = std_devs is not None and len(std_devs) == len(scores)

        n_points = min(len(scores), len(epochs_hist))
        scores = list(scores)[:n_points]
        epochs_hist = list(epochs_hist)[:n_points]

        fig, ax = plt.subplots(figsize=(10, 6))

        if use_error_bars:
            ax.errorbar(epochs_hist, scores, yerr=list(std_devs)[:n_points], fmt='g-o',
                        label=f"{metric_name} Score", capsize=4)
        else:
            ax.plot(epochs_hist, scores, marker='o', linestyle='-', color='green', label=f"{metric_name} Score")

        # Highlight the first occurrence of the minimum score.
        best_idx = min(range(n_points), key=scores.__getitem__)
        best_score_val = scores[best_idx]
        best_epoch = epochs_hist[best_idx]
        ax.scatter([best_epoch], [best_score_val], color='purple', s=100, zorder=5, marker='o')
        ax.annotate(f'Best: {best_score_val:.4f}\nEpoch: {best_epoch}',
                    xy=(best_epoch, best_score_val), xytext=(10, -30),
                    textcoords='offset points', arrowprops=dict(arrowstyle="->", color='purple'),
                    color='purple')

        ax.set_title(f"{metric_name} Score vs. Epoch (Most recent: {final_epoch})")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(f"{metric_name} Score (Lower is Better)")
        ax.grid(True, linestyle='--', alpha=0.6)

        if is_final:
            _annotate_plot(fig, ax, final_epoch, stop_reason)
            filename = f"{filename_base}_final_{timestamp}.png"
        else:
            filename = f"{filename_base}_current.png"

        fig.tight_layout()
        save_path = os.path.join(save_dir, filename)
        fig.savefig(save_path, dpi=300)

        logger.info(f"Saved {metric_name} plot to {save_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to generate or save {metric_name} plot: {e}", exc_info=True)
        return False
    finally:
        if fig is not None:
            plt.close(fig)

def plot_combined_metrics(fid_scores, kid_scores, epochs_hist, save_dir, filename_base, final_epoch, timestamp,
                         is_final=False, stop_reason=None, kid_epochs_hist=None):
    """
    Plot FID (left axis, blue) and KID (right axis, red) against epoch on one figure.

    Args:
        fid_scores, kid_scores: metric histories (lower is better).
        epochs_hist: epochs at which FID was measured. Also used for KID unless
            ``kid_epochs_hist`` is given.
        save_dir, filename_base: output location. The file is ``<base>_current.png``, or
            ``<base>_final_<timestamp>.png`` (with a status caption) when ``is_final``.
        final_epoch: most recent epoch, shown in the title.
        stop_reason: optional reason shown in the final caption.
        kid_epochs_hist: optional epochs at which KID was measured. Without it, KID scores are
            paired with the leading entries of ``epochs_hist``. That pairing is only correct
            if both metrics were computed on the same epochs.

    Each (epochs, scores) pair is trimmed to its common length, so mismatched histories plot
    instead of raising.

    Returns:
        True on success, False if data is missing or plotting failed (error logged).
        The figure is always closed.
    """
    fig = None
    try:
        if fid_scores is None or kid_scores is None or len(fid_scores) == 0 or len(kid_scores) == 0:
            logger.warning("Missing data for combined metrics plot")
            return False

        def _paired(xs, ys):
            # Trim x and y to a common length so matplotlib never sees mismatched arrays.
            n = min(len(xs), len(ys))
            return list(xs)[:n], list(ys)[:n]

        fid_epochs, fid_vals = _paired(epochs_hist, fid_scores)
        kid_epochs, kid_vals = _paired(epochs_hist if kid_epochs_hist is None else kid_epochs_hist, kid_scores)

        fig, ax1 = plt.subplots(figsize=(10, 6))

        # FID on the primary y-axis
        line1, = ax1.plot(fid_epochs, fid_vals, 'b-o', label="FID Score")
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('FID Score', color='b')
        ax1.tick_params(axis='y', labelcolor='b')

        # KID on a twin y-axis (the two metrics have very different scales)
        ax2 = ax1.twinx()
        line2, = ax2.plot(kid_epochs, kid_vals, 'r-o', label="KID Score")
        ax2.set_ylabel('KID Score', color='r')
        ax2.tick_params(axis='y', labelcolor='r')

        ax1.grid(True, linestyle='--', alpha=0.3)

        # One legend covering lines from both axes
        lines = [line1, line2]
        ax1.legend(lines, [line.get_label() for line in lines], loc="upper right")

        ax1.set_title(f'FID and KID Scores vs. Epoch (Most recent: {final_epoch})')

        if is_final:
            _annotate_plot(fig, ax1, final_epoch, stop_reason)
            filename = f"{filename_base}_final_{timestamp}.png"
        else:
            filename = f"{filename_base}_current.png"

        fig.tight_layout()
        save_path = os.path.join(save_dir, filename)
        fig.savefig(save_path, dpi=300)

        logger.info(f"Saved combined metrics plot to {save_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to generate or save combined metrics plot: {e}", exc_info=True)
        return False
    finally:
        if fig is not None:
            plt.close(fig)

def _subsample_rows(features, max_rows):
    """Return at most ``max_rows`` randomly chosen rows of ``features`` (all rows if small enough)."""
    if len(features) > max_rows:
        indices = np.random.choice(len(features), max_rows, replace=False)
        return features[indices]
    return features


def plot_feature_space(real_features, fake_features, method, save_dir, filename_base, final_epoch, timestamp, stop_reason=None):
    """Save a 2-D t-SNE or UMAP scatter plot of real vs. generated feature vectors.

    Args:
        real_features: Real-image feature rows (indexable with an integer index array and
            stackable with ``np.vstack``).
        fake_features: Generated-image feature rows, same layout as ``real_features``.
        method: ``'umap'`` (case-insensitive) requests UMAP; anything else uses t-SNE.
            If UMAP is requested but ``UMAP_AVAILABLE`` is false, t-SNE is used and the
            output file is labelled ``tsne`` so the name matches the actual plot.
        save_dir: Directory the PNG is written to.
        filename_base: Prefix of the output file name.
        final_epoch: Epoch number shown in the title and footer.
        timestamp: String appended to the output file name.
        stop_reason: Optional text appended to the footer.

    Returns:
        True if the plot was saved; False on missing input or any error (which is logged).
    """
    import inspect  # local import: only used to pick TSNE's iteration-count keyword

    fig = None
    try:
        if real_features is None or fake_features is None:
            logger.warning("Missing features for visualization")
            return False

        # Cap the number of plotted points: half the budget for real rows, half for fake.
        max_samples = TSNE_UMAP_SAMPLE_SIZE // 2
        real_sample = _subsample_rows(real_features, max_samples)
        fake_sample = _subsample_rows(fake_features, max_samples)
        n_real = len(real_sample)
        combined_features = np.vstack([real_sample, fake_sample])
        n_points = len(combined_features)

        # Resolve which reducer actually runs so the file name and log match the plot.
        method_tag = method.lower()
        if method_tag == 'umap' and not UMAP_AVAILABLE:
            logger.warning("UMAP requested but not available; falling back to t-SNE.")
            method_tag = 'tsne'

        if method_tag == 'umap':
            logger.info("Computing UMAP embedding...")
            reducer = umap.UMAP(random_state=TSNE_UMAP_RANDOM_STATE)
            embedding = reducer.fit_transform(combined_features)
            title = f'UMAP Visualization of Real vs Generated Feature Distributions (Epoch {final_epoch})'
        else:
            logger.info("Computing t-SNE embedding...")
            # scikit-learn rejects perplexity >= n_samples; shrink it for small inputs.
            perplexity = TSNE_PERPLEXITY
            if perplexity >= n_points:
                perplexity = max(1, n_points - 1)
                logger.warning(f"Reducing t-SNE perplexity to {perplexity} for {n_points} points.")
            # scikit-learn 1.5 renamed `n_iter` to `max_iter`; use whichever this version accepts.
            tsne_params = inspect.signature(manifold.TSNE).parameters
            iter_kwarg = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'
            tsne = manifold.TSNE(n_components=2, perplexity=perplexity,
                                 random_state=TSNE_UMAP_RANDOM_STATE, **{iter_kwarg: 1000})
            embedding = tsne.fit_transform(combined_features)
            title = f't-SNE Visualization of Real vs Generated Feature Distributions (Epoch {final_epoch})'

        fig, ax = plt.subplots(figsize=(10, 8))

        # Rows were stacked real-first, so the embedding splits at n_real.
        real_points = embedding[:n_real]
        fake_points = embedding[n_real:]

        ax.scatter(real_points[:, 0], real_points[:, 1], c='blue', alpha=0.6, label='Real', s=20)
        ax.scatter(fake_points[:, 0], fake_points[:, 1], c='red', alpha=0.6, label='Generated', s=20)
        ax.set_title(title)
        ax.legend()

        # Reading guide for the plot.
        ax.annotate("Note: Points closer together have similar feature representations\n"
                    "Good generation = red points distributed similarly to blue points",
                    xy=(0.5, -0.01), xycoords='axes fraction',
                    ha='center', va='top', fontsize=9)

        status_text = f"Feature space visualization after Epoch {final_epoch}"
        if stop_reason:
            status_text += f"\nStop reason: {stop_reason}"
        fig.text(0.5, 0.01, status_text, ha='center', fontsize=9)

        filename = f"{filename_base}_{method_tag}_{timestamp}.png"
        save_path = os.path.join(save_dir, filename)

        fig.tight_layout()
        fig.savefig(save_path, dpi=300)

        logger.info(f"Saved {method_tag} visualization to {save_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to generate or save feature space visualization: {e}", exc_info=True)
        return False
    finally:
        # Release the figure on every path, including failures after it was created.
        if fig is not None:
            plt.close(fig)

def _natural_sort_key(path):
    """Sort key for file paths that compares embedded digit runs numerically ('e9' < 'e10')."""
    return [int(part) if part.isdigit() else part
            for part in re.split(r'(\d+)', os.path.basename(path))]


def generate_final_sample_grid(samples_dir, grid_size=(4, 8), max_samples=32, epoch=None, timestamp=None):
    """Save a summary sample image into ``CONT_ANALYSIS_DIR``.

    If any ``{CONTINUATION_PREFIX}best_samples*.png`` exists in ``samples_dir``, the newest
    one (numeric-aware filename order) is re-saved as
    ``{CONTINUATION_PREFIX}best_epoch_<E>_sample_grid_<timestamp>.png``. Here ``<E>`` is the
    first ``e<digits>`` number in its file name, or "best". Otherwise the newest
    ``{CONTINUATION_PREFIX}sample_*.png`` files are tiled into a ``grid_size`` figure. That
    figure is saved as ``{CONTINUATION_PREFIX}sample_grid_e<epoch|final>_<timestamp>.png``.

    Args:
        samples_dir: Directory searched for sample PNGs.
        grid_size: ``(rows, cols)`` of the tiled figure.
        max_samples: How many of the newest sample files are considered. Only the first
            ``rows * cols`` of those are drawn.
        epoch: Epoch shown in the figure title and used, zero-padded to 4 digits, in the
            file name. ``None`` names the file "final".
        timestamp: Suffix for the output name; defaults to the current local time.

    Returns:
        True if an image was saved. False if no samples were found or an error occurred
        (errors are logged).
    """
    fig = None
    try:
        if timestamp is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        # Prefer the best-epoch sample image if one exists. A numeric-aware sort keeps
        # e.g. "e10" after "e9" so [-1] really is the newest.
        best_samples = sorted(
            glob.glob(os.path.join(samples_dir, f"{CONTINUATION_PREFIX}best_samples*.png")),
            key=_natural_sort_key,
        )
        if best_samples:
            best_sample = best_samples[-1]

            epoch_match = re.search(r'e(\d+)', os.path.basename(best_sample))
            best_epoch = int(epoch_match.group(1)) if epoch_match else "best"

            filename = f"{CONTINUATION_PREFIX}best_epoch_{best_epoch}_sample_grid_{timestamp}.png"
            save_path = os.path.join(CONT_ANALYSIS_DIR, filename)

            # The file is already a grid image: re-save it, closing the source handle.
            with Image.open(best_sample) as best_img:
                best_img.save(save_path)
            logger.info(f"Saved best epoch sample grid to {save_path}")
            return True

        sample_files = sorted(
            glob.glob(os.path.join(samples_dir, f"{CONTINUATION_PREFIX}sample_*.png")),
            key=_natural_sort_key,
        )
        if not sample_files:
            logger.warning(f"No sample images found in {samples_dir}")
            return False

        selected_samples = sample_files[-max_samples:] if len(sample_files) > max_samples else sample_files
        n_cells = grid_size[0] * grid_size[1]

        fig = plt.figure(figsize=(20, 10))
        fig.suptitle(f"Generated Samples (Current Epoch: {epoch})", fontsize=16)

        for idx, sample_path in enumerate(selected_samples[:n_cells]):
            ax = fig.add_subplot(grid_size[0], grid_size[1], idx + 1)
            with Image.open(sample_path) as img:
                ax.imshow(np.array(img), cmap='gray')

            # Per-tile epoch/step come from the file name. Use separate names so the
            # `epoch` argument (needed for the output file name) is not overwritten.
            sample_name = os.path.basename(sample_path)
            match = re.search(r'sample_(\d+)_(\d+)', sample_name)
            if match:
                sample_epoch, sample_step = (int(g) for g in match.groups())
                ax.set_title(f"Epoch {sample_epoch}, Step {sample_step}")
            else:
                ax.set_title(sample_name)
            ax.axis('off')

        epoch_str = "final" if epoch is None else f"{int(epoch):04d}"
        filename = f"{CONTINUATION_PREFIX}sample_grid_e{epoch_str}_{timestamp}.png"
        save_path = os.path.join(CONT_ANALYSIS_DIR, filename)

        fig.tight_layout()
        fig.savefig(save_path, dpi=300)

        logger.info(f"Saved sample grid to {save_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to generate or save sample grid: {e}", exc_info=True)
        return False
    finally:
        if fig is not None:
            plt.close(fig)

# --- FID/KID Calculation ---
def get_inception_model(device):
    """Build the pytorch-fid InceptionV3 feature extractor for the 2048-dim block.

    The network is moved to ``device`` and switched to eval mode before being returned.

    Raises:
        RuntimeError: if the pytorch-fid library could not be imported.
    """
    if FID_AVAILABLE:
        feature_block = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        net = InceptionV3([feature_block]).to(device)
        net.eval()
        return net
    raise RuntimeError("pytorch-fid library not available.")

def get_activations(dataloader_or_generator, model, device, num_images, batch_size, desc="", noise_dim=None):
    """Extract Inception features for real images (from a loader) or for generated images.

    Args:
        dataloader_or_generator: If ``noise_dim`` is None, an iterable yielding image
            tensors or ``(images, ...)`` sequences. Otherwise, a generator module that is
            called with noise of shape ``(n, noise_dim, 1, 1)``.
        model: Feature extractor. ``model(batch)[0]`` must be a 4-D ``(n, C, H, W)``
            tensor; it is average-pooled to 1x1 if it is not already.
        device: Device for noise and input batches.
        num_images: Maximum number of feature rows to return.
        batch_size: Images per generated batch; also used to count loader batches.
        desc: Suffix for the progress-bar label.
        noise_dim: Latent size. Selects generator mode when not None.

    Images are assumed to be in [-1, 1]; they are mapped to [0, 1] and clamped.
    Single-channel images are repeated to 3 channels. A generator's train/eval mode is
    restored to its state on entry, even if a batch fails.

    Returns:
        ``np.ndarray`` with one feature row per image (at most ``num_images`` rows), or
        None if pytorch-fid is unavailable, nothing was collected, or any batch failed.
    """
    if not FID_AVAILABLE:
        logger.warning("pytorch-fid not available.")
        return None

    # The file does `from tqdm import tqdm`, so this name may be the tqdm class itself
    # (no `.tqdm` attribute) rather than the module. Accept either.
    progress_bar = getattr(tqdm, "tqdm", tqdm)

    n_batches = ceil(num_images / batch_size)
    n_used_imgs = 0
    pred_list = []

    # Real images come from a loader; generated images are produced on the fly.
    is_dataloader = noise_dim is None
    generator = None
    was_training = False

    if is_dataloader:
        logger.info(f"Getting activations from {num_images} real images ({n_batches} batches)...")
        iterator = iter(dataloader_or_generator)
    else:
        logger.info(f"Generating {num_images} fake images & activations ({n_batches} batches)...")
        generator = dataloader_or_generator
        was_training = generator.training
        generator.eval()

    try:
        with torch.no_grad():
            for i in progress_bar(range(n_batches), desc=f"Activations {desc}", leave=False):
                try:
                    if is_dataloader:
                        try:
                            batch = next(iterator)
                        except StopIteration:
                            logger.warning(f"Dataloader exhausted early @ batch {i}. Using {n_used_imgs} images.")
                            break
                        # Loaders often yield (images, labels). Unwrap *before* moving to the
                        # device, because a tuple has no `.to`.
                        if isinstance(batch, (list, tuple)):
                            batch = batch[0]
                        batch = batch.to(device)
                        if batch.shape[0] == 0:
                            continue
                    else:
                        current_batch_size = min(batch_size, num_images - n_used_imgs)
                        if current_batch_size <= 0:
                            break
                        noise = torch.randn(current_batch_size, noise_dim, 1, 1, device=device)
                        batch = generator(noise)
                        del noise

                    current_batch_size = batch.shape[0]

                    # Inception preprocessing: float32, 3 channels, values in [0, 1].
                    if batch.dtype != torch.float32:
                        batch = batch.float()
                    if batch.shape[1] == 1:
                        batch = batch.repeat(1, 3, 1, 1)  # expand grayscale to RGB
                    if batch.shape[1] != 3:
                        raise ValueError(f"Batch needs 3 channels, got {batch.shape[1]}")
                    batch = torch.clamp(batch * 0.5 + 0.5, 0.0, 1.0)

                    pred = model(batch)[0]
                    if pred.size(2) != 1 or pred.size(3) != 1:
                        pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))

                    pred_list.append(pred.squeeze(3).squeeze(2).cpu().numpy())
                    n_used_imgs += current_batch_size
                    del batch, pred

                    if n_used_imgs >= num_images:
                        break
                except Exception as e:
                    logger.error(f"Error during activation batch {i}: {e}", exc_info=True)
                    return None
    finally:
        # Restore the generator's original mode on every exit path.
        if generator is not None:
            generator.train(was_training)

    if not pred_list:
        return None

    # Drop any overshoot from the last real batch.
    pred_arr = np.concatenate(pred_list, axis=0)[:num_images]

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return pred_arr

def _load_real_stats_and_features(stats_path, features_path):
    """Load ``(mu, sigma, features)`` from disk, or return None if the shapes are not 2048-d.

    Propagates any error raised while reading, such as a missing file, a corrupt archive,
    or a missing 'mu'/'sigma' key.
    """
    # The context manager closes the .npz handle; NpzFile keeps the file open otherwise.
    with np.load(stats_path) as stats:
        mu_real, sigma_real = stats['mu'], stats['sigma']
    real_features = np.load(features_path)
    if (mu_real.shape == (2048,) and sigma_real.shape == (2048, 2048)
            and real_features.ndim == 2 and real_features.shape[1] == 2048):
        return mu_real, sigma_real, real_features
    return None


def _ensure_parent_dir(path):
    """Create the directory containing ``path`` if it names one (bare file names are skipped)."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def get_real_stats_and_features(real_dataloader, paths, inception_model, device, num_images, batch_size, force_recalculate=False):
    """Return ``(mu_real, sigma_real, real_features)`` for FID/KID, reusing cached files when possible.

    Lookup order:
      1. Files already at ``paths['real_stats_path']`` and ``paths['real_features_path']``.
      2. Files from the original run (``paths['original_real_stats_path']`` and
         ``paths['original_real_features_path']``), copied into the locations above.
      3. Fresh computation with ``get_activations`` over ``real_dataloader``. Results are
         saved to the locations above; save failures are logged, not raised.
    ``force_recalculate=True`` skips steps 1 and 2.

    Returns:
        The triple on success, or ``(None, None, None)`` if fewer than ``num_images // 2``
        real activations could be computed.
    """
    real_stats_path = paths.get('real_stats_path')
    real_features_path = paths.get('real_features_path')
    original_real_stats_path = paths.get('original_real_stats_path')
    original_real_features_path = paths.get('original_real_features_path')
    have_targets = bool(real_stats_path and real_features_path)

    # 1. Reuse stats/features already cached for this run.
    if not force_recalculate and have_targets:
        if os.path.exists(real_stats_path) and os.path.exists(real_features_path):
            logger.info("Loading pre-calculated real stats and features")
            try:
                loaded = _load_real_stats_and_features(real_stats_path, real_features_path)
            except Exception as e:
                logger.warning(f"Could not load real stats or features file: {e}")
            else:
                if loaded is not None:
                    logger.info("Loaded real stats and features successfully.")
                    return loaded
                logger.warning("Loaded real stats or features invalid.")

    # 2. Copy from the original run into this run's locations.
    if (not force_recalculate and have_targets
            and original_real_stats_path and original_real_features_path
            and os.path.exists(original_real_stats_path)
            and os.path.exists(original_real_features_path)):
        logger.info("Copying real stats and features from original run")
        try:
            _ensure_parent_dir(real_stats_path)
            _ensure_parent_dir(real_features_path)
            shutil.copy2(original_real_stats_path, real_stats_path)
            shutil.copy2(original_real_features_path, real_features_path)
            loaded = _load_real_stats_and_features(real_stats_path, real_features_path)
        except Exception as e:
            logger.warning(f"Failed to copy/load real stats or features: {e}")
        else:
            if loaded is not None:
                logger.info("Copied and loaded real stats and features successfully.")
                return loaded
            logger.warning("Copied real stats or features invalid.")

    # 3. Compute from scratch.
    logger.info(f"Calculating FID stats and features for {num_images} real images...")
    real_activations = get_activations(real_dataloader, inception_model, device, num_images, batch_size, desc="Real")

    n_real = 0 if real_activations is None else len(real_activations)
    if real_activations is None or n_real < num_images // 2:  # tolerate a partially exhausted loader
        logger.error(f"Failed to get enough real activations. Got {n_real}.")
        return None, None, None

    mu_real = np.mean(real_activations, axis=0)
    sigma_real = np.cov(real_activations, rowvar=False)

    logger.info(f"Calculated real stats (mu: {mu_real.shape}, sigma: {sigma_real.shape}).")

    try:
        # Write through file handles. For string paths, numpy would otherwise append
        # '.npz'/'.npy' when the extension is missing, and later reloads would miss the file.
        _ensure_parent_dir(real_stats_path)
        with open(real_stats_path, 'wb') as fh:
            np.savez(fh, mu=mu_real, sigma=sigma_real)
        logger.info(f"Saved real FID stats to: {real_stats_path}")

        # Raw features are kept for KID and feature-space plots.
        _ensure_parent_dir(real_features_path)
        with open(real_features_path, 'wb') as fh:
            np.save(fh, real_activations)
        logger.info(f"Saved real features to: {real_features_path}")
    except Exception as e:
        logger.error(f"Failed to save real stats or features: {e}", exc_info=True)

    return mu_real, sigma_real, real_activations

def calculate_metrics(generator, inception_model, real_mu, real_sigma, real_features, device, noise_dim, num_images, batch_size):
    """Score ``generator`` with FID and, optionally, KID from one shared set of fake features.

    A single batch of generated activations feeds both the Frechet distance, computed
    against ``(real_mu, real_sigma)``, and the custom KID, computed against
    ``real_features``. KID is computed only when ``CALCULATE_KID`` is set.

    Returns:
        ``(fid, (kid_mean, kid_std), fake_features)``. Any metric that could not be
        computed is ``float('inf')``. ``fake_features`` is None when the early checks
        fail or too few fake activations were produced.
    """
    failure = (float('inf'), (float('inf'), float('inf')), None)

    if not FID_AVAILABLE:
        logger.warning("pytorch-fid not available.")
        return failure
    if any(stat is None for stat in (real_mu, real_sigma, real_features)):
        logger.error("Real stats or features not available.")
        return failure

    logger.info(f"Calculating FID and KID using {num_images} generated images...")
    fake_features = get_activations(
        generator, inception_model, device, num_images, batch_size,
        desc="Fake (FID/KID)", noise_dim=noise_dim,
    )

    n_fake = 0 if fake_features is None else len(fake_features)
    if fake_features is None or n_fake < num_images // 2:  # allow some shortfall
        logger.error(f"Failed to get enough fake activations. Got {n_fake}.")
        return failure

    # FID: fit a Gaussian to the fake features and compare it with the real one.
    fake_mu = np.mean(fake_features, axis=0)
    fake_sigma = np.cov(fake_features, rowvar=False)

    logger.info("Calculating Frechet distance...")
    try:
        fid_value = calculate_frechet_distance(fake_mu, fake_sigma, real_mu, real_sigma)
        logger.info(f"Calculated FID: {fid_value:.4f}")
    except Exception as e:
        logger.error(f"Error calculating Frechet distance: {e}", exc_info=True)
        fid_value = float('inf')

    # KID: always via the custom implementation that consumes precomputed features.
    kid_mean = kid_std = float('inf')
    if CALCULATE_KID:
        logger.info("Calculating KID score...")
        try:
            logger.info("Using custom KID implementation with pre-computed features")
            kid_mean, kid_std = calculate_kid_from_features_custom(
                real_features, fake_features,
                subset_size=KID_SUBSET_SIZE,
                num_subsets=KID_SUBSETS,
            )
            logger.info(f"Calculated KID (custom): {kid_mean:.6f} ± {kid_std:.6f}")
        except Exception as e:
            logger.error(f"Error calculating KID: {e}", exc_info=True)
            kid_mean, kid_std = float('inf'), float('inf')

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return fid_value, (kid_mean, kid_std), fake_features

# --- Markdown Report ---
def _format_improvement(original_best, continuation_best, precision):
    """Format ``original_best - continuation_best``, or 'N/A' if either value is not finite."""
    import math  # local import: only this helper needs it

    if math.isfinite(original_best) and math.isfinite(continuation_best):
        return f"{original_best - continuation_best:.{precision}f}"
    return "N/A"


def _write_metric_analysis(f, label, scores, epochs, original_epochs, precision, stable_tol):
    """Write the '## <label> Analysis' section.

    A score belongs to the original run when its epoch is <= ``original_epochs`` and to
    the continuation otherwise. ``scores`` and ``epochs`` are paired positionally; extra
    entries in the longer list are ignored rather than raising. Lower scores count as
    better.

    Returns:
        True if some continuation score beat the original run's best score.
    """
    paired = list(zip(scores, epochs))
    original_scores = [s for s, e in paired if e <= original_epochs]
    cont_scores = [s for s, e in paired if e > original_epochs]
    original_best = min(original_scores) if original_scores else float('inf')
    cont_best = min(cont_scores) if cont_scores else float('inf')

    f.write(f"\n## {label} Analysis\n\n")
    f.write(f"- Original best {label} score: {original_best:.{precision}f}\n")
    f.write(f"- Continuation best {label} score: {cont_best:.{precision}f}\n")
    f.write(f"- {label} improvement: {_format_improvement(original_best, cont_best, precision)}\n")
    f.write(f"- Final {label} score: {scores[-1]}\n")

    # Trend over the last three evaluations.
    if len(scores) >= 3:
        first, middle, last = scores[-3:]
        if first > last and middle > last:
            f.write(f"\n- {label} scores were continuing to improve at the end.\n")
        elif abs(first - middle) < stable_tol and abs(first - last) < stable_tol:
            f.write(f"\n- {label} scores had stabilized at the end.\n")
        elif first < last and middle < last:
            f.write(f"\n- {label} scores were worsening at the end.\n")

    return bool(cont_scores) and cont_best < original_best


def generate_markdown_report(history_data, training_info, output_path):
    """Write a Markdown summary of the continuation run to ``output_path``.

    Args:
        history_data: Dict with optional keys 'g_losses_hist', 'c_losses_hist',
            'fid_scores_hist', 'fid_epochs_hist', 'kid_scores_hist' and
            'kid_epochs_hist'. Missing keys are treated as empty lists.
        training_info: Dict with optional keys 'original_epochs', 'current_epoch',
            'stop_reason' and 'training_time'.
        output_path: Destination file; it is overwritten.

    Returns:
        True on success; False if writing failed (the error is logged).
    """
    g_losses_hist = history_data.get('g_losses_hist', [])
    c_losses_hist = history_data.get('c_losses_hist', [])
    fid_scores_hist = history_data.get('fid_scores_hist', [])
    fid_epochs_hist = history_data.get('fid_epochs_hist', [])
    kid_scores_hist = history_data.get('kid_scores_hist', [])
    kid_epochs_hist = history_data.get('kid_epochs_hist', [])

    original_epochs = training_info.get('original_epochs', 0)
    training_time = training_info.get('training_time') or 0  # tolerate an explicit None
    fid_improved = kid_improved = False

    try:
        with open(output_path, "w", encoding="utf-8") as f:
            f.write("# WGAN-SN Training Continuation Report (Enhanced)\n\n")
            f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            # Continuation settings
            f.write("## Continuation Settings\n\n")
            f.write(f"- Continued from: {'Best FID model' if USE_BEST_FID_CKPT else 'Latest checkpoint'}\n")
            f.write(f"- Original epochs: {training_info.get('original_epochs', 'N/A')}\n")
            f.write(f"- Additional epochs: {ADDITIONAL_EPOCHS}\n")
            f.write(f"- Total epochs completed: {training_info.get('current_epoch', 'N/A')}\n")
            f.write(f"- Used original settings: {USE_ORIGINAL_SETTINGS}\n")

            if not USE_ORIGINAL_SETTINGS:
                f.write(f"- Learning rate: {NEW_LEARNING_RATE}\n")
                f.write(f"- Batch size: {NEW_BATCH_SIZE}\n")
                f.write(f"- Critic iterations: {NEW_CRITIC_ITERATIONS}\n")
                f.write(f"- Gradient clipping: {USE_GRADIENT_CLIPPING}\n")

            f.write(f"- Early stopping patience: {NEW_EARLY_STOPPING_PATIENCE}\n")

            # Training summary
            f.write("\n## Training Summary\n\n")
            f.write(f"- Final status: {training_info.get('stop_reason', 'Completed')}\n")
            f.write(f"- Training time: {training_time:.2f} seconds\n")
            f.write(f"- Final generator loss: {g_losses_hist[-1] if g_losses_hist else 'N/A'}\n")
            f.write(f"- Final critic loss: {c_losses_hist[-1] if c_losses_hist else 'N/A'}\n")

            # Per-metric sections
            if fid_scores_hist:
                fid_improved = _write_metric_analysis(
                    f, "FID", fid_scores_hist, fid_epochs_hist, original_epochs,
                    precision=4, stable_tol=1.0)
            if kid_scores_hist:
                kid_improved = _write_metric_analysis(
                    f, "KID", kid_scores_hist, kid_epochs_hist, original_epochs,
                    precision=6, stable_tol=0.0005)

            both_metrics = bool(fid_scores_hist) and bool(kid_scores_hist)

            # Overall improvement analysis
            if both_metrics:
                f.write("\n## Overall Improvement Analysis\n\n")
                if fid_improved and kid_improved:
                    f.write("- **Both FID and KID scores improved** during continuation training.\n")
                    f.write("- The model is generating more realistic images with better feature distributions.\n")
                elif fid_improved:
                    f.write("- **FID improved but KID did not** during continuation training.\n")
                    f.write("- The model is generating more realistic images, but the feature distribution may not have improved.\n")
                elif kid_improved:
                    f.write("- **KID improved but FID did not** during continuation training.\n")
                    f.write("- The model's feature distribution improved, but overall image realism may not have increased.\n")
                else:
                    f.write("- **Neither FID nor KID scores improved** during continuation training.\n")
                    f.write("- The continuation training did not yield better results than the original training.\n")

            # Recommendations
            f.write("\n## Recommendations\n\n")

            if training_info.get('stop_reason', '') == "Early stopping":
                f.write("- Training stopped early due to metric plateau. Consider:\n")
                f.write("  - Using an even lower learning rate (e.g., 0.00001)\n")
                f.write("  - Experimenting with different data augmentations\n")
                f.write("  - Trying a different architecture or hyperparameter configuration\n")

            if both_metrics and (fid_improved or kid_improved):
                f.write("- The continuation strategy was successful. For further improvement:\n")
                f.write("  - Try additional training with current parameters\n")
                f.write("  - Consider minor learning rate adjustments\n")
                f.write("  - Experiment with different critic iteration counts\n")
            elif both_metrics:
                f.write("- The continuation strategy did not yield improvements. Consider:\n")
                f.write("  - More significant learning rate adjustments\n")
                f.write("  - Different batch size or critic iterations\n")
                f.write("  - Alternative GAN architecture or training methodology\n")

            # Sample Images
            f.write("\n## Generated Images\n\n")
            f.write("Sample images are saved in the `samples` directory.\n")
            f.write("Best samples grid is available in the `analysis_results` directory.\n")

            # Plots
            f.write("\n## Analysis Plots\n\n")
            f.write("- Loss plot: `plots/cont_losses_*.png`\n")
            f.write("- FID plot: `plots/cont_fid_*.png`\n")
            f.write("- KID plot: `plots/cont_kid_*.png`\n")
            f.write("- Combined metrics plot: `plots/cont_combined_metrics_*.png`\n")
            if PLOT_FEATURE_SPACE:
                if USE_UMAP and UMAP_AVAILABLE:
                    f.write("- UMAP feature space visualization: `plots/cont_feature_space_umap_*.png`\n")
                else:
                    f.write("- t-SNE feature space visualization: `plots/cont_feature_space_tsne_*.png`\n")

        logger.info(f"Generated markdown report at {output_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to generate markdown report: {e}", exc_info=True)
        return False

# ==============================================
# Initialization Phase
# ==============================================
logger.info("--- Initialization Phase ---")
timestamp_start = datetime.now().strftime('%Y%m%d_%H%M%S')

# --- Determine & Load Checkpoint ---
original_ckpt_dir = os.path.join(ORIGINAL_OUTPUT_DIR, "checkpoints")
checkpoint_to_load_path = find_checkpoint_file(original_ckpt_dir)

if checkpoint_to_load_path is None:
    logger.error("No checkpoint found to continue from. Exiting.")
    sys.exit(1)

checkpoint, history_data = load_checkpoint(checkpoint_to_load_path)
if checkpoint is None or history_data is None:
    logger.error("Failed to load checkpoint data. Exiting.")
    sys.exit(1)

# --- Load Original Configuration ---
original_config = extract_original_config(ORIGINAL_OUTPUT_DIR)

# --- Set Effective Configuration ---
if USE_ORIGINAL_SETTINGS:
    logger.info("Using original hyperparameters from checkpoint/config")
    LEARNING_RATE = original_config.get("LEARNING_RATE", 0.00005)
    BATCH_SIZE = original_config.get("BATCH_SIZE", 64)
    CRITIC_ITERATIONS = original_config.get("CRITIC_ITERATIONS", 5)
else:
    logger.info("Using new hyperparameters for continuation")
    LEARNING_RATE = NEW_LEARNING_RATE
    BATCH_SIZE = NEW_BATCH_SIZE
    CRITIC_ITERATIONS = NEW_CRITIC_ITERATIONS

# Load other necessary parameters from original config
NOISE_DIM = original_config.get("NOISE_DIM", 100)
CHANNELS_IMG = original_config.get("CHANNELS_IMG", 1)
G_FEATURES = original_config.get("G_FEATURES", 64)
C_FEATURES = original_config.get("C_FEATURES", 64)
IMAGE_SIZE = original_config.get("IMAGE_SIZE", 128)
CHECKPOINT_FREQ_EPOCHS = original_config.get("CHECKPOINT_FREQ_EPOCHS", 5)

# Override early stopping patience
EARLY_STOPPING_PATIENCE = NEW_EARLY_STOPPING_PATIENCE
logger.info(f"Early stopping patience set to {EARLY_STOPPING_PATIENCE}")

# Log effective configuration
logger.info("\n=== Continuation Configuration ===")
logger.info(f"Starting from epoch: {history_data['start_epoch']}")
logger.info(f"Additional epochs: {ADDITIONAL_EPOCHS}")
logger.info(f"Learning rate: {LEARNING_RATE}")
logger.info(f"Batch size: {BATCH_SIZE}")
logger.info(f"Critic iterations: {CRITIC_ITERATIONS}")
logger.info(f"Gradient clipping: {USE_GRADIENT_CLIPPING}")
logger.info(f"Primary evaluation metric: {PRIMARY_EVAL_METRIC}")
logger.info(f"Early stopping patience: {EARLY_STOPPING_PATIENCE}")
logger.info(f"Using {'best FID' if USE_BEST_FID_CKPT else 'latest'} checkpoint")
logger.info("================================\n")

# Save continuation configuration
cont_config = {
    "ORIGINAL_OUTPUT_DIR": ORIGINAL_OUTPUT_DIR,
    "CONTINUATION_DIR": CONTINUATION_DIR,
    "USE_BEST_FID_CKPT": USE_BEST_FID_CKPT,
    "USE_ORIGINAL_SETTINGS": USE_ORIGINAL_SETTINGS,
    "LEARNING_RATE": LEARNING_RATE,
    "BATCH_SIZE": BATCH_SIZE,
    "CRITIC_ITERATIONS": CRITIC_ITERATIONS,
    "USE_GRADIENT_CLIPPING": USE_GRADIENT_CLIPPING,
    "ADDITIONAL_EPOCHS": ADDITIONAL_EPOCHS,
    "EARLY_STOPPING_PATIENCE": EARLY_STOPPING_PATIENCE,
    "ORIGINAL_EPOCHS": history_data['start_epoch'],
    "ORIGINAL_BEST_FID": float(history_data['best_fid']),
    "ORIGINAL_BEST_KID": float(history_data['best_kid']) if 'best_kid' in history_data else float('inf'),
    "NOISE_DIM": NOISE_DIM,
    "CHANNELS_IMG": CHANNELS_IMG,
    "G_FEATURES": G_FEATURES,
    "C_FEATURES": C_FEATURES
}

try:
    with open(os.path.join(CONTINUATION_DIR, f"{CONTINUATION_PREFIX}config.json"), 'w') as f:
        json.dump(cont_config, f, indent=4, default=str)
    logger.info(f"Saved continuation configuration")
except Exception as e:
    logger.warning(f"Failed to save continuation configuration: {e}")

# --- Initialize Models ---
logger.info("Initializing models...")
generator = Generator(NOISE_DIM, CHANNELS_IMG, G_FEATURES).to(device)
critic = CriticSN(CHANNELS_IMG, C_FEATURES).to(device)

# Load model weights from checkpoint
logger.info("Loading model weights from checkpoint...")
generator.load_state_dict(checkpoint['generator_state_dict'])
critic.load_state_dict(checkpoint['critic_state_dict'])

# Initialize optimizers (Adam, betas=(0.0, 0.9)). Loading the saved optimizer
# state restores the Adam moments and also the saved param-group lr.
opt_gen = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_gen.load_state_dict(checkpoint['optimizer_gen_state_dict'])
opt_critic.load_state_dict(checkpoint['optimizer_critic_state_dict'])

# The checkpoint's lr replaced the constructor value above, so the new learning
# rate must be re-applied explicitly when not reusing the original settings.
if not USE_ORIGINAL_SETTINGS:
    for _opt in (opt_gen, opt_critic):
        for param_group in _opt.param_groups:
            param_group['lr'] = LEARNING_RATE
    logger.info(f"Updated optimizer learning rates to {LEARNING_RATE}")

# Initialize GradScalers
scaler_critic = GradScaler(enabled=AMP_ENABLED)
scaler_gen = GradScaler(enabled=AMP_ENABLED)


def _restore_grad_scaler(scaler, state_key, label):
    """Restore `scaler` from `checkpoint[state_key]` when that state is usable.

    A GradScaler saved while AMP was disabled serializes to an empty dict.
    Loading an empty dict into an *enabled* scaler raises RuntimeError, and a
    disabled scaler ignores loaded state. Both cases are skipped with a log
    message instead of crashing or falsely reporting a restore.

    Args:
        scaler: Freshly constructed GradScaler, updated in place.
        state_key: Checkpoint key holding that scaler's state_dict.
        label: Name used in log messages ("Generator" / "Critic").
    """
    if state_key not in checkpoint:
        logger.warning(f"{label} GradScaler state not found in checkpoint.")
    elif not AMP_ENABLED:
        logger.info(f"AMP is disabled; ignoring saved {label} GradScaler state.")
    elif not checkpoint[state_key]:
        logger.warning(f"{label} GradScaler state in checkpoint is empty "
                       f"(saved with AMP disabled); starting from a fresh scaler.")
    else:
        scaler.load_state_dict(checkpoint[state_key])
        logger.info(f"Loaded GradScaler state for {label}.")


# Load scaler states if they exist (and are usable) in the checkpoint
_restore_grad_scaler(scaler_gen, 'scaler_gen_state_dict', "Generator")
_restore_grad_scaler(scaler_critic, 'scaler_critic_state_dict', "Critic")

# --- Setup Dataset and DataLoader ---
logger.info("Setting up Dataset and DataLoader...")
# Random horizontal/vertical flips are the only augmentation. Normalize with
# mean=std=0.5 per channel, matching the FID transform below.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])

try:
    dataset = PollenDataset(
        root_dir=PREPROCESSED_DATA_DIR,
        transform=transform,
        image_size=IMAGE_SIZE,
        channels_img=CHANNELS_IMG
    )

    if len(dataset) == 0:
        raise ValueError("Dataset is empty.")

    # Half the CPU cores, clamped to [1, 4]. os.cpu_count() returns None when
    # the core count cannot be determined, so fall back to a single worker.
    dataloader_num_workers = min(max((os.cpu_count() or 2) // 2, 1), 4)

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=dataloader_num_workers,
        pin_memory=(device.type == 'cuda'),
        persistent_workers=False,
        drop_last=True
    )

    # With drop_last=True a dataset smaller than one batch produces zero
    # batches; the epoch loop would then run without training anything.
    if len(dataloader) == 0:
        raise ValueError(
            f"Dataset has {len(dataset)} images, fewer than BATCH_SIZE={BATCH_SIZE}; "
            "no full batch can be formed (drop_last=True)."
        )

    logger.info(f"DataLoader created with {len(dataloader)} batches per epoch.")

except Exception as e:
    logger.error(f"Failed to create Dataset/DataLoader: {e}", exc_info=True)
    sys.exit(1)

# --- Prepare for Metric Calculations ---
logger.info("Preparing for FID/KID calculation...")
inception_model = None
real_mu, real_sigma, real_features = None, None, None

if FID_AVAILABLE:
    # Pre-bind the temporary objects so the `finally` cleanup is safe even when
    # setup fails before they are created.
    fid_transform = fid_dataset = fid_dataloader = None
    try:
        inception_model = get_inception_model(device)
        logger.info("InceptionV3 model loaded for metrics.")

        # Deterministic (no augmentation) transform for the real-image statistics
        fid_transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG)
        ])

        fid_dataset = PollenDataset(
            PREPROCESSED_DATA_DIR,
            transform=fid_transform,
            image_size=IMAGE_SIZE,
            channels_img=CHANNELS_IMG
        )

        actual_fid_num_images = min(FID_NUM_IMAGES, len(fid_dataset))
        if actual_fid_num_images < FID_NUM_IMAGES:
            logger.warning(f"Using {actual_fid_num_images} images for metrics calculation (limited by dataset size).")

        fid_dataloader = DataLoader(
            fid_dataset,
            batch_size=FID_BATCH_SIZE,
            shuffle=False,
            num_workers=dataloader_num_workers,
            pin_memory=(device.type == 'cuda')
        )

        # Cached stats/features: the original run's files and this continuation's own copies
        original_real_stats_path = os.path.join(ORIGINAL_OUTPUT_DIR, "real_fid_stats_10k.npz")
        original_real_features_path = os.path.join(ORIGINAL_OUTPUT_DIR, "real_inception_features_10k.npy")
        cont_real_stats_path = os.path.join(CONTINUATION_DIR, f"{CONTINUATION_PREFIX}real_stats.npz")
        cont_real_features_path = os.path.join(CONTINUATION_DIR, f"{CONTINUATION_PREFIX}real_features.npy")

        paths = {
            'original_real_stats_path': original_real_stats_path,
            'original_real_features_path': original_real_features_path,
            'real_stats_path': cont_real_stats_path,
            'real_features_path': cont_real_features_path
        }

        # Get real stats and features (loaded from cache or computed)
        real_mu, real_sigma, real_features = get_real_stats_and_features(
            fid_dataloader,
            paths,
            inception_model,
            device,
            actual_fid_num_images,
            FID_BATCH_SIZE,
            FORCE_RECALCULATE_REAL_STATS
        )

        if real_mu is None or real_sigma is None or real_features is None:
            logger.error("Failed to obtain real statistics or features. Metrics will be disabled.")
            FID_AVAILABLE = False
            CALCULATE_KID = False
            # Metrics are off for the rest of the run, so release the Inception
            # network instead of keeping it resident on the device.
            inception_model = None
        else:
            logger.info("Successfully loaded/calculated real image statistics and features.")

    except Exception as e:
        logger.error(f"Error during metrics setup: {e}", exc_info=True)
        logger.error("Metrics calculation will be disabled.")
        FID_AVAILABLE = False
        CALCULATE_KID = False
        inception_model = None
    finally:
        # Free the temporary dataloader/dataset on every path, not just on success
        del fid_dataloader, fid_dataset, fid_transform
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# --- Initialize Training State Variables ---
# Resume counters, metric histories and best scores from the loaded history.
# Each missing key falls back to a fresh-run default.
global_step = history_data.get('global_step', 0)
start_epoch = history_data.get('start_epoch', 0)

(g_losses_hist, c_losses_hist,
 fid_scores_hist, fid_epochs_hist,
 kid_scores_hist, kid_std_hist, kid_epochs_hist) = (
    history_data.get(history_key, [])
    for history_key in ('g_losses_hist', 'c_losses_hist',
                        'fid_scores_hist', 'fid_epochs_hist',
                        'kid_scores_hist', 'kid_std_hist', 'kid_epochs_hist')
)

# "No best yet" is represented as +inf so any finite score counts as an improvement
best_fid, best_kid = (history_data.get(score_key, float('inf'))
                      for score_key in ('best_fid', 'best_kid'))

# The early-stopping patience counter deliberately starts over for this continuation
epochs_no_improve = 0

# Rolling windows of recent per-step losses, used for spike detection
critic_loss_window, gen_loss_window = (deque(maxlen=LOSS_STABILITY_WINDOW) for _ in range(2))

# 32 fixed latent vectors so periodic sample grids are comparable within this run
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)

# Best fake features / epoch reached during this continuation run (none yet)
current_best_fake_features = current_best_epoch = None

# ==============================================
# Training Loop
# ==============================================
logger.info(f"\n--- Starting Continuation Training from Epoch {start_epoch + 1} ---")
log_gpu_memory_usage("Before Training Loop")

# Set models to training mode
generator.train()
critic.train()

# Main training parameters and tracking variables
training_start_time = time.time()
early_stop_triggered = False
stop_reason = "Completed all epochs"
total_epochs = start_epoch + ADDITIONAL_EPOCHS

# The header does `from tqdm import tqdm` (binding the class), but an auto-install
# helper may rebind `tqdm` to the module; resolve the progress-bar class either way.
_progress_bar_cls = getattr(tqdm, 'tqdm', tqdm)


def _build_checkpoint_state(epoch_value):
    """Return a snapshot of the full training state for `save_checkpoint`.

    Every value is read from the module-level training globals at call time,
    so call this only after fields such as `epochs_no_improve` are up to date.

    Args:
        epoch_value: Stored under 'epoch'. Regular saves pass `epoch + 1`;
            emergency saves pass the index of the interrupted epoch.

    Returns:
        dict of model/optimizer/scaler state dicts, loss and metric histories,
        best scores and the early-stopping counter.
    """
    return {
        'epoch': epoch_value,
        'step': global_step,
        'generator_state_dict': generator.state_dict(),
        'critic_state_dict': critic.state_dict(),
        'optimizer_gen_state_dict': opt_gen.state_dict(),
        'optimizer_critic_state_dict': opt_critic.state_dict(),
        'scaler_gen_state_dict': scaler_gen.state_dict(),
        'scaler_critic_state_dict': scaler_critic.state_dict(),
        'g_losses_history': g_losses_hist,
        'c_losses_history': c_losses_hist,
        'fid_scores_history': fid_scores_hist,
        'fid_epochs_history': fid_epochs_hist,
        'kid_scores_history': kid_scores_hist,
        'kid_std_history': kid_std_hist,
        'kid_epochs_history': kid_epochs_hist,
        'best_fid': best_fid,
        'best_kid': best_kid,
        'epochs_no_improve': epochs_no_improve,
    }


def _warn_on_loss_spike(window, label):
    """Warn when the newest loss in `window` is an outlier.

    The newest entry is compared with the mean/std of the earlier entries and
    flagged when it deviates by more than LOSS_STABILITY_THRESHOLD stds.
    """
    values = list(window)
    previous, current = values[:-1], values[-1]
    mean, std = np.mean(previous), np.std(previous)
    if abs(current - mean) > LOSS_STABILITY_THRESHOLD * std:
        logger.warning(f"{label} loss spike detected at Step {global_step}: {current:.4f} "
                       f"(mean: {mean:.4f}, std: {std:.4f})")


def _plot_epoch_history(epoch_number):
    """Write loss, FID, KID and combined FID/KID plots for the history so far."""
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    plot_losses(g_losses_hist, c_losses_hist, CONT_PLOT_DIR,
                f"{CONTINUATION_PREFIX}losses", epoch_number, timestamp)

    if fid_scores_hist:
        plot_metric(fid_scores_hist, fid_epochs_hist, best_fid, "FID", CONT_PLOT_DIR,
                    f"{CONTINUATION_PREFIX}fid", epoch_number, timestamp)

    if CALCULATE_KID and kid_scores_hist:
        plot_metric(kid_scores_hist, kid_epochs_hist, best_kid, "KID", CONT_PLOT_DIR,
                    f"{CONTINUATION_PREFIX}kid", epoch_number, timestamp, std_devs=kid_std_hist)

    if fid_scores_hist and kid_scores_hist:
        try:
            # Pair FID and KID by the epoch they were measured at. Index-based
            # truncation misaligns them when one history is missing entries
            # (e.g. KID was not computed in part of the original run).
            kid_by_epoch = dict(zip(kid_epochs_hist, kid_scores_hist))
            paired = [(ep, fid, kid_by_epoch[ep])
                      for ep, fid in zip(fid_epochs_hist, fid_scores_hist)
                      if ep in kid_by_epoch]
            if paired:
                epochs_subset, fid_subset, kid_subset = (list(col) for col in zip(*paired))
                plot_combined_metrics(fid_subset, kid_subset, epochs_subset, CONT_PLOT_DIR,
                                      f"{CONTINUATION_PREFIX}combined_metrics", epoch_number, timestamp)
        except Exception as plot_e:
            logger.error(f"Error creating combined metrics plot: {plot_e}")
            # Continue training even if plotting fails


try:  # Wrap main loop in try/except to handle interruptions
    for epoch in range(start_epoch, total_epochs):
        epoch_start_time = time.time()

        # Setup progress bar for this epoch
        loop_pbar = _progress_bar_cls(enumerate(dataloader), total=len(dataloader), leave=True,
                                      desc=f"Epoch [{epoch+1}/{total_epochs}]")

        # Track average losses for the epoch
        avg_loss_c_epoch = 0.0
        avg_loss_g_epoch = 0.0
        batches_in_epoch = 0

        for batch_idx, real_images in loop_pbar:
            # --- Heartbeat Log ---
            if HEARTBEAT_LOG_FREQ > 0 and global_step % HEARTBEAT_LOG_FREQ == 0:
                logger.debug(f"Heartbeat: Still running Step {global_step} in Epoch {epoch+1}")

            # --- Main Training Step ---
            try:
                if real_images is None:
                    logger.warning(f"Skipping batch {batch_idx} due to None data.")
                    continue

                real_images = real_images.to(device)
                cur_batch_size = real_images.shape[0]
                if cur_batch_size == 0:
                    continue

                # --- Train Critic ---
                # WGAN performs CRITIC_ITERATIONS separate critic updates per
                # generator update, so zero_grad/backward/step all run inside the
                # loop. Previously gradients were summed across iterations and
                # applied as one step.
                critic_loss_accum_iter = 0.0

                for _ in range(CRITIC_ITERATIONS):
                    noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1).to(device)

                    with autocast(device_type='cuda', enabled=AMP_ENABLED):
                        with torch.no_grad():
                            fake_images = generator(noise)
                        critic_real = critic(real_images).reshape(-1)
                        critic_fake = critic(fake_images).reshape(-1)
                        loss_critic = torch.mean(critic_fake) - torch.mean(critic_real)

                    # Check for NaN/Inf in critic loss (MANDATORY)
                    if not torch.isfinite(loss_critic):
                        logger.critical(f"Non-finite critic loss detected at Step {global_step}: {loss_critic.item()}. Stopping training.")
                        early_stop_triggered = True
                        stop_reason = "Non-finite loss detected"
                        break

                    critic_loss_accum_iter += loss_critic.item()

                    # Backprop and update the critic for this iteration
                    opt_critic.zero_grad(set_to_none=True)
                    scaler_critic.scale(loss_critic).backward()
                    if USE_GRADIENT_CLIPPING:
                        # Gradients must be unscaled before clipping to a real-valued norm
                        scaler_critic.unscale_(opt_critic)
                        torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=1.0)
                    scaler_critic.step(opt_critic)
                    scaler_critic.update()

                # Abandon the batch (and the epoch) after a non-finite loss
                if early_stop_triggered:
                    break

                avg_loss_c_iter = critic_loss_accum_iter / max(CRITIC_ITERATIONS, 1)

                # --- Train Generator ---
                opt_gen.zero_grad(set_to_none=True)

                with autocast(device_type='cuda', enabled=AMP_ENABLED):
                    noise_for_g = torch.randn(cur_batch_size, NOISE_DIM, 1, 1).to(device)
                    fake_images_for_g = generator(noise_for_g)
                    critic_fake_for_gen = critic(fake_images_for_g).reshape(-1)
                    loss_gen = -torch.mean(critic_fake_for_gen)

                # Check for NaN/Inf in generator loss (MANDATORY)
                if not torch.isfinite(loss_gen):
                    logger.critical(f"Non-finite generator loss detected at Step {global_step}: {loss_gen.item()}. Stopping training.")
                    early_stop_triggered = True
                    stop_reason = "Non-finite loss detected"
                    break

                # Backprop generator
                scaler_gen.scale(loss_gen).backward()

                # Unscale for gradient clipping if enabled
                if USE_GRADIENT_CLIPPING:
                    scaler_gen.unscale_(opt_gen)
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)

                # Optimizer step and scaler update
                scaler_gen.step(opt_gen)
                scaler_gen.update()

                # --- Tracking and Logging ---
                loss_g_item = loss_gen.item()
                avg_loss_c_epoch += avg_loss_c_iter
                avg_loss_g_epoch += loss_g_item
                batches_in_epoch += 1

                # Update loss stability tracking
                critic_loss_window.append(avg_loss_c_iter)
                gen_loss_window.append(loss_g_item)

                # Both windows fill in lockstep, so one length check covers both
                if MONITOR_LOSS_STABILITY and len(critic_loss_window) == LOSS_STABILITY_WINDOW:
                    _warn_on_loss_spike(critic_loss_window, "Critic")
                    _warn_on_loss_spike(gen_loss_window, "Generator")

                # Periodic logging (every 100 steps)
                if global_step % 100 == 0:
                    logger.info(f"Step {global_step} | Loss C: {avg_loss_c_iter:.4f}, Loss G: {loss_g_item:.4f}")
                    # Track memory usage
                    log_gpu_memory_usage(f"Step {global_step}")

                # Save periodic samples
                if global_step % SAMPLE_FREQ_STEPS == 0:
                    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                    logger.info(f"Saving samples at step {global_step}")

                    generator.eval()
                    with torch.no_grad():
                        fake_samples = generator(fixed_noise)
                        # Undo the [-1, 1] normalization before writing the grid
                        img_grid = vutils.make_grid(fake_samples * 0.5 + 0.5, normalize=False)
                        vutils.save_image(img_grid, os.path.join(CONT_SAMPLE_DIR,
                                          f"{CONTINUATION_PREFIX}sample_{epoch+1:04d}_{global_step:07d}_{timestamp}.png"))
                    generator.train()

                    # Clean up sample generation tensors
                    del fake_samples, img_grid
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                # Update progress bar
                loop_pbar.set_description(f"Epoch [{epoch+1}/{total_epochs}]")
                loop_pbar.set_postfix(loss_C=avg_loss_c_iter, loss_G=loss_g_item, step=global_step)

                # Clean up batch tensors
                del real_images, noise, fake_images, critic_real, critic_fake, loss_critic
                del noise_for_g, fake_images_for_g, critic_fake_for_gen, loss_gen

                global_step += 1

            except RuntimeError as e:  # Handle CUDA OOM and other runtime errors
                if "out of memory" in str(e).lower():
                    logger.error(f"CUDA out of memory at Step {global_step}! Consider reducing batch size.")
                    logger.warning("Attempting to save checkpoint before stopping...")
                    try:
                        # Fresh timestamp, consistent with the interrupted_/error_ emergency saves
                        oom_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                        checkpoint_state = _build_checkpoint_state(epoch)
                        save_checkpoint(checkpoint_state, f"{CONTINUATION_PREFIX}emergency_oom_{oom_timestamp}.pth.tar")
                    except Exception as save_e:
                        logger.error(f"Failed to save emergency checkpoint: {save_e}")

                    early_stop_triggered = True
                    stop_reason = "CUDA Out of Memory"
                    break
                else:
                    logger.error(f"Runtime error at Step {global_step}: {e}", exc_info=True)
                    early_stop_triggered = True
                    stop_reason = f"Runtime error: {str(e)[:100]}..."
                    break
            except Exception as e:  # Handle general exceptions
                logger.error(f"Error at batch {batch_idx}, Step {global_step}: {e}", exc_info=True)
                early_stop_triggered = True
                stop_reason = f"Training error: {str(e)[:100]}..."
                break

        # --- End of Batch Loop ---
        loop_pbar.close()
        if early_stop_triggered:
            break

        # --- End of Epoch ---
        epoch_duration = time.time() - epoch_start_time
        if batches_in_epoch > 0:
            avg_loss_c_epoch /= batches_in_epoch
            avg_loss_g_epoch /= batches_in_epoch

        # Append losses to history
        g_losses_hist.append(avg_loss_g_epoch)
        c_losses_hist.append(avg_loss_c_epoch)

        logger.info(f"Epoch [{epoch+1}/{total_epochs}] Completed in {epoch_duration:.2f}s | "
                    f"Avg Loss C: {avg_loss_c_epoch:.4f} | Avg Loss G: {avg_loss_g_epoch:.4f}")

        # --- Memory cleanup ---
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        log_gpu_memory_usage(f"After Epoch {epoch+1}")

        # --- Metric Calculation & Evaluation ---
        primary_metric_improved_this_epoch = False
        best_fid_updated = False
        best_kid_updated = False

        if FID_AVAILABLE and inception_model is not None and (epoch + 1) % FID_FREQ_EPOCHS == 0:
            current_fid, (current_kid, current_kid_std), fake_features_epoch = calculate_metrics(
                generator,
                inception_model,
                real_mu,
                real_sigma,
                real_features,
                device,
                NOISE_DIM,
                FID_NUM_IMAGES,
                FID_BATCH_SIZE
            )

            # Track FID if calculated and valid (inf marks a failed calculation)
            if current_fid != float('inf'):
                fid_scores_hist.append(current_fid)
                fid_epochs_hist.append(epoch + 1)
                logger.info(f"--- FID Score @ Epoch {epoch+1}: {current_fid:.4f} ---")

            # Track KID if calculated and valid
            if CALCULATE_KID and current_kid != float('inf'):
                kid_scores_hist.append(current_kid)
                kid_std_hist.append(current_kid_std)
                kid_epochs_hist.append(epoch + 1)
                logger.info(f"--- KID Score @ Epoch {epoch+1}: {current_kid:.6f} ± {current_kid_std:.6f} ---")

            # Check if the primary metric improved
            if PRIMARY_EVAL_METRIC == "FID" and current_fid != float('inf'):
                if current_fid < best_fid:
                    logger.info(f"FID improved: {best_fid:.4f} -> {current_fid:.4f}. Saving best FID checkpoint.")
                    best_fid = current_fid
                    best_fid_updated = True
                    primary_metric_improved_this_epoch = True

                    # Save the best fake features for visualization
                    if fake_features_epoch is not None:
                        current_best_fake_features = fake_features_epoch
                        current_best_epoch = epoch + 1
                        save_best_fake_features(fake_features_epoch)

            elif PRIMARY_EVAL_METRIC == "KID" and CALCULATE_KID and current_kid != float('inf'):
                if current_kid < best_kid:
                    logger.info(f"KID improved: {best_kid:.6f} -> {current_kid:.6f}. Saving best KID checkpoint.")
                    best_kid = current_kid
                    best_kid_updated = True
                    primary_metric_improved_this_epoch = True

                    # Save the best fake features for visualization
                    if fake_features_epoch is not None:
                        current_best_fake_features = fake_features_epoch
                        current_best_epoch = epoch + 1
                        save_best_fake_features(fake_features_epoch)

            # Track non-primary metric improvements too
            if PRIMARY_EVAL_METRIC == "KID" and current_fid != float('inf'):
                if current_fid < best_fid:
                    logger.info(f"FID improved: {best_fid:.4f} -> {current_fid:.4f}. (Not primary metric)")
                    best_fid = current_fid
                    best_fid_updated = True

            elif PRIMARY_EVAL_METRIC == "FID" and CALCULATE_KID and current_kid != float('inf'):
                if current_kid < best_kid:
                    logger.info(f"KID improved: {best_kid:.6f} -> {current_kid:.6f}. (Not primary metric)")
                    best_kid = current_kid
                    best_kid_updated = True

            # Reset patience *before* snapshotting, so best checkpoints record the
            # counter the run actually continues with.
            if primary_metric_improved_this_epoch:
                epochs_no_improve = 0

            checkpoint_state = _build_checkpoint_state(epoch + 1)

            # --- Save Best State ---
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

            if primary_metric_improved_this_epoch:
                # Save best model for inference
                model_config = {
                    'noise_dim': NOISE_DIM,
                    'channels_img': CHANNELS_IMG,
                    'features_g': G_FEATURES
                }
                save_best_model_inference(generator, model_config)

                # Save best samples
                save_best_samples(generator, fixed_noise, epoch + 1)

            if best_fid_updated:
                # Save best FID checkpoint (overwrite previous best)
                save_checkpoint(checkpoint_state, f"{CONTINUATION_PREFIX}best_fid_checkpoint.pth.tar")

                # Also save a timestamped version for archiving
                if primary_metric_improved_this_epoch and PRIMARY_EVAL_METRIC == "FID":
                    save_checkpoint(checkpoint_state,
                                    f"{CONTINUATION_PREFIX}best_fid_checkpoint_e{epoch+1:04d}_fid{current_fid:.2f}_{timestamp}.pth.tar")

            if best_kid_updated:
                # Save best KID checkpoint (overwrite previous best)
                save_checkpoint(checkpoint_state, f"{CONTINUATION_PREFIX}best_kid_checkpoint.pth.tar")

                # Also save a timestamped version for archiving
                if primary_metric_improved_this_epoch and PRIMARY_EVAL_METRIC == "KID":
                    save_checkpoint(checkpoint_state,
                                    f"{CONTINUATION_PREFIX}best_kid_checkpoint_e{epoch+1:04d}_kid{current_kid:.6f}_{timestamp}.pth.tar")

            # --- Early Stopping Check (counted in evaluation epochs) ---
            if not primary_metric_improved_this_epoch:
                epochs_no_improve += 1
                logger.info(f"{PRIMARY_EVAL_METRIC} did not improve. "
                            f"Patience: {epochs_no_improve}/{EARLY_STOPPING_PATIENCE}.")

                if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
                    logger.warning(f"--- Early stopping triggered after {epochs_no_improve} epochs "
                                   f"without {PRIMARY_EVAL_METRIC} improvement. ---")
                    stop_reason = f"Early stopping ({PRIMARY_EVAL_METRIC})"
                    early_stop_triggered = True

        # --- Optional Per-Epoch Plotting ---
        # Runs every epoch, not only on metric epochs, so loss curves stay
        # current even when FID/KID are disabled or computed less often.
        if PLOT_PER_EPOCH:
            _plot_epoch_history(epoch + 1)

        # --- Memory cleanup after metrics ---
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # --- Checkpointing ---
        checkpoint_state = _build_checkpoint_state(epoch + 1)

        # Save checkpoint every N epochs or if it's the last epoch
        if (epoch + 1) % CHECKPOINT_FREQ_EPOCHS == 0 or (epoch + 1) == total_epochs:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            save_checkpoint(checkpoint_state, f"{CONTINUATION_PREFIX}checkpoint_epoch_{epoch+1:04d}_{timestamp}.pth.tar")

        # Always save the latest checkpoint
        save_checkpoint(checkpoint_state, f"{CONTINUATION_PREFIX}latest_checkpoint.pth.tar")

        # Check for early stopping
        if early_stop_triggered:
            break

    # --- End of Epoch Loop ---

except KeyboardInterrupt:
    logger.warning("--- Training Interrupted by User ---")
    stop_reason = "Manual Interruption"

    logger.warning("Attempting to save checkpoint and generate plots before exit...")

    # Save emergency checkpoint
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    try:
        checkpoint_state = _build_checkpoint_state(epoch if 'epoch' in locals() else start_epoch)
        save_checkpoint(checkpoint_state, f"{CONTINUATION_PREFIX}interrupted_{timestamp}.pth.tar")
    except Exception as e:
        logger.error(f"Failed to save checkpoint after interruption: {e}", exc_info=True)

except Exception as e:
    logger.critical(f"Critical error during training: {e}", exc_info=True)
    stop_reason = f"Error: {str(e)[:100]}..."

    # Try to save emergency checkpoint
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    try:
        checkpoint_state = _build_checkpoint_state(epoch if 'epoch' in locals() else start_epoch)
        save_checkpoint(checkpoint_state, f"{CONTINUATION_PREFIX}error_{timestamp}.pth.tar")
    except Exception as save_e:
        logger.error(f"Failed to save checkpoint after error: {save_e}", exc_info=True)

# ==============================================
# Post-Training Analysis
# ==============================================
logger.info("="*80)
logger.info("--- Post-Training Analysis Phase ---")

# Calculate training duration and final status.
total_training_time = time.time() - training_start_time

# `epoch` is the 0-based epoch-loop variable (checkpoints above are named with
# `epoch+1`). A `for` loop only binds it once at least one iteration has run,
# so fall back to `start_epoch` when it is missing.
_epoch_loop_ran = 'epoch' in locals()
final_epoch = epoch if _epoch_loop_ran else start_epoch
# Report 1-based epoch numbers so this log line agrees with the checkpoint
# filenames and with the markdown report, whose `current_epoch` and
# `total_additional_epochs` are both derived from `final_epoch + 1`.
_epochs_run = (final_epoch + 1 - start_epoch) if _epoch_loop_ran else 0
logger.info(f"Training finished after {_epochs_run} epochs, during Epoch {final_epoch+1} ({total_training_time:.2f} seconds)")
logger.info(f"Stop reason: {stop_reason}")
logger.info(f"Final Global Step: {global_step}")

if fid_scores_hist:
    logger.info(f"Best FID Score: {best_fid:.4f}")

    # Compare with the best FID recorded by the original (pre-continuation) run;
    # a missing value is treated as +inf so any computed score counts as an improvement.
    original_best_fid = history_data.get('best_fid', float('inf'))
    if best_fid < original_best_fid:
        improvement = original_best_fid - best_fid
        logger.info(f"FID improved by {improvement:.4f} points from original best of {original_best_fid:.4f}")
    else:
        logger.info(f"FID did not improve from original best of {original_best_fid:.4f}")

if kid_scores_hist:
    logger.info(f"Best KID Score: {best_kid:.6f}")

    # Same comparison for KID (lower is better), logged with 6 decimals.
    original_best_kid = history_data.get('best_kid', float('inf'))
    if best_kid < original_best_kid:
        improvement = original_best_kid - best_kid
        logger.info(f"KID improved by {improvement:.6f} points from original best of {original_best_kid:.6f}")
    else:
        logger.info(f"KID did not improve from original best of {original_best_kid:.6f}")

# Generate timestamp for final files (shared by every plot/report written below).
final_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# --- Generate Final Plots ---
logger.info("Generating final visualization plots...")

# Each plot is best-effort: a failure is logged with its traceback and the
# script continues, so one broken plot does not prevent the remaining
# visualizations, sample grid and report from being produced.
try:
    plot_losses(g_losses_hist, c_losses_hist, CONT_PLOT_DIR,
                f"{CONTINUATION_PREFIX}losses", final_epoch+1, final_timestamp,
                is_final=True, stop_reason=stop_reason)
except Exception as e:
    logger.error(f"Failed to generate loss plot: {e}", exc_info=True)

if fid_scores_hist:
    try:
        plot_metric(fid_scores_hist, fid_epochs_hist, best_fid, "FID", CONT_PLOT_DIR,
                    f"{CONTINUATION_PREFIX}fid", final_epoch+1, final_timestamp,
                    is_final=True, stop_reason=stop_reason)
    except Exception as e:
        logger.error(f"Failed to generate FID plot: {e}", exc_info=True)

if kid_scores_hist:
    try:
        plot_metric(kid_scores_hist, kid_epochs_hist, best_kid, "KID", CONT_PLOT_DIR,
                    f"{CONTINUATION_PREFIX}kid", final_epoch+1, final_timestamp,
                    is_final=True, stop_reason=stop_reason, std_devs=kid_std_hist)
    except Exception as e:
        logger.error(f"Failed to generate KID plot: {e}", exc_info=True)

    # Generate combined FID/KID plot if both are available
    if fid_scores_hist:
        try:
            # Pair FID and KID by the epoch at which each was measured rather than
            # by list position: if the two metrics were evaluated at different
            # epochs, positional truncation would plot mismatched scores.
            kid_by_epoch = dict(zip(kid_epochs_hist, kid_scores_hist))
            aligned = [
                (ep, fid, kid_by_epoch[ep])
                for ep, fid in zip(fid_epochs_hist, fid_scores_hist)
                if ep in kid_by_epoch
            ]
            if aligned:
                epochs_subset, fid_subset, kid_subset = (list(col) for col in zip(*aligned))
                plot_combined_metrics(fid_subset, kid_subset, epochs_subset, CONT_PLOT_DIR,
                                      f"{CONTINUATION_PREFIX}combined_metrics", final_epoch+1, final_timestamp,
                                      is_final=True, stop_reason=stop_reason)
            else:
                logger.warning("No epochs with both FID and KID scores; skipping combined metrics plot.")
        except Exception as e:
            logger.error(f"Failed to generate combined metrics plot: {e}", exc_info=True)

# --- Generate Feature Space Visualization ---
if PLOT_FEATURE_SPACE and real_features is not None:
    # Fall back to the best fake features saved on disk if none are in memory.
    if current_best_fake_features is None:
        try:
            best_fake_features_path = os.path.join(CONT_ANALYSIS_DIR, f"{CONTINUATION_PREFIX}best_fake_features.npy")
            if os.path.exists(best_fake_features_path):
                current_best_fake_features = np.load(best_fake_features_path)
                logger.info("Loaded best fake features for visualization")
        except Exception as e:
            logger.warning(f"Failed to load best fake features: {e}")

    if current_best_fake_features is None:
        logger.warning("No fake features available for visualization.")
    else:
        # Prefer UMAP when it is both enabled and importable; otherwise fall back to t-SNE.
        if USE_UMAP and UMAP_AVAILABLE:
            projection_method = "umap"
        elif TSNE_AVAILABLE:
            projection_method = "tsne"
        else:
            projection_method = None

        if projection_method is None:
            logger.warning("Neither UMAP nor t-SNE available. Skipping feature space visualization.")
        else:
            # Best-effort: a projection failure is logged instead of aborting
            # the remaining post-training steps.
            try:
                plot_feature_space(real_features, current_best_fake_features, projection_method, CONT_PLOT_DIR,
                                   f"{CONTINUATION_PREFIX}feature_space", final_epoch+1, final_timestamp,
                                   stop_reason=stop_reason)
            except Exception as e:
                logger.error(f"Failed to generate feature space plot: {e}", exc_info=True)

# --- Generate Final Sample Grid ---
logger.info("Generating final sample grid...")
generate_final_sample_grid(CONT_SAMPLE_DIR, epoch=final_epoch + 1, timestamp=final_timestamp)

# --- Generate Markdown Report ---
# The report receives three inputs: the metric histories, run-level metadata,
# and the destination path of the .md file.
logger.info("Generating markdown report...")

# `final_epoch` is the 0-based loop index, so `final_epoch + 1` is the
# 1-based epoch number.
training_info = dict(
    original_epochs=start_epoch,
    current_epoch=final_epoch + 1,
    total_additional_epochs=(final_epoch + 1) - start_epoch,
    stop_reason=stop_reason,
    training_time=total_training_time,
)

report_path = os.path.join(
    CONT_ANALYSIS_DIR,
    f"{CONTINUATION_PREFIX}training_report_{final_timestamp}.md",
)

_report_history = dict(
    g_losses_hist=g_losses_hist,
    c_losses_hist=c_losses_hist,
    fid_scores_hist=fid_scores_hist,
    fid_epochs_hist=fid_epochs_hist,
    kid_scores_hist=kid_scores_hist,
    kid_epochs_hist=kid_epochs_hist,
    best_fid=best_fid,
    best_kid=best_kid,
)
generate_markdown_report(_report_history, training_info, report_path)

# ==============================================
# Optional Precision/Recall Calculation
# ==============================================
# Not implemented: Precision/Recall metrics would require additional libraries
# and computation. If added here, wrap the computation in its own try/except
# so a failure does not abort the rest of the script.
logger.info("Skipping optional Precision/Recall calculation (not implemented)")

# ==============================================
# Final Cleanup
# ==============================================
# Log GPU memory one last time, print a closing summary banner, then
# flush and close every logging handler.
logger.info("--- Running Final Cleanup ---")
log_gpu_memory_usage("Final")

_separator = "=" * 80
logger.info(_separator)
logger.info(f"Continuation script execution complete. Results saved in: {CONTINUATION_DIR}")
logger.info(f"Final best FID: {best_fid:.4f}")
if kid_scores_hist:  # the KID line is only emitted when KID scores exist
    logger.info(f"Final best KID: {best_kid:.6f}")
logger.info(f"Final Stop Reason: {stop_reason}")
logger.info(_separator)

# Flush and close all logging handlers before the process exits.
logging.shutdown()