In [None]:
# ==============================================================================
# Cell 1: Project Setup, Configuration, and Path Management
# ==============================================================================
# This cell imports all necessary libraries and defines all the high-level
# parameters and paths for the training run. It is the main control panel
# for any experiment.
# ==============================================================================

# --- Core Libraries ---
import os
import random
import time
import logging
from datetime import datetime
import itertools
import glob

# --- Deep Learning & Data Processing Libraries ---
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import tifffile
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torch_fidelity import calculate_metrics
from tqdm.auto import tqdm

# ==============================================================================
# 1. EXPERIMENT CONFIGURATION
# ==============================================================================
# All user-configurable parameters are grouped here for easy access.

# --- Training Control ---
LOAD_MODEL = False                # Set to True to resume training from a checkpoint.
EPOCH_TO_LOAD_FROM = 0            # Epoch number in the checkpoint file names to load (files end in "_epoch<N>.pth.tar"); 0 = start fresh.
NUM_EPOCHS = 50                   # The total number of epochs to train for (including any resumed epochs).

# --- Model & Training Hyperparameters ---
BATCH_SIZE = 16                   # Number of images per training step. Adjust based on GPU memory.
IMG_SIZE = 512                    # All images are resized to IMG_SIZE x IMG_SIZE. The U-Net halves the
                                  # resolution six times, so keep this a multiple of 64.
IMG_CHANNELS = 3                  # Number of channels for the images (3 for RGB).
LEARNING_RATE_GEN = 2e-4          # Learning rate for the Generators' shared Adam optimizer.
LEARNING_RATE_DISC = 2e-4         # Learning rate for each Discriminator's Adam optimizer.
LAMBDA_CYCLE = 10.0               # Weight for the cycle-consistency loss.
LAMBDA_IDENTITY = 0.5 * LAMBDA_CYCLE  # Weight for the identity loss; derived so it tracks LAMBDA_CYCLE (set 0 to disable).
STEPS_PER_EPOCH = 2000            # Batches per "epoch". The training set samples files at random, so an
                                  # epoch is a fixed number of steps rather than one full pass over the data.

# --- Logging & Saving Frequency ---
SAVE_MODEL = True                 # Set to True to save model checkpoints.
SAVE_MODEL_EVERY_N_EPOCHS = 2     # How often to save a checkpoint.
SAVE_SAMPLES_EVERY_N_EPOCHS = 5   # How often to save generated sample images.
CALC_METRICS_EVERY_N_EPOCHS = 5   # How often to save the loss-history plot.


# ==============================================================================
# 2. SYSTEM & PATH SETUP (AUTOMATED)
# ==============================================================================
# This section automatically sets up devices, paths, and logging based on the
# configuration above. No user changes should be needed here.

# --- Device Configuration ---
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Directory & Path Definitions (Using H&E / Reticulin Naming) ---
# Every dataset path is anchored at the directory the notebook runs from.
NOTEBOOK_CWD = os.getcwd()
H_AND_E_BASE_DIR = "H&E_split_dataset"
RETICULIN_BASE_DIR = "Retic_split_dataset"

PATH_TRAIN_H_FOLDER = os.path.join(NOTEBOOK_CWD, H_AND_E_BASE_DIR, "train")
PATH_TEST_H_FOLDER = os.path.join(NOTEBOOK_CWD, H_AND_E_BASE_DIR, "test")
PATH_TRAIN_R_FOLDER = os.path.join(NOTEBOOK_CWD, RETICULIN_BASE_DIR, "train")
PATH_TEST_R_FOLDER = os.path.join(NOTEBOOK_CWD, RETICULIN_BASE_DIR, "test")

# --- Dynamic Naming for Outputs ---
# The run ID depends only on BATCH_SIZE and IMG_SIZE. It is therefore NOT unique
# per run: runs with the same settings share output folders and checkpoint names.
# That determinism is what lets a resumed run find earlier checkpoints.
run_id_string = f"cyclegan_bs{BATCH_SIZE}_img{IMG_SIZE}"
LOG_DIR = os.path.join(NOTEBOOK_CWD, "logs")
CHECKPOINT_SAVE_DIR = os.path.join(NOTEBOOK_CWD, "checkpoints", run_id_string)
OUTPUT_IMAGE_DIR = os.path.join(NOTEBOOK_CWD, "saved_images", run_id_string)

# Checkpoint base names; an "_epoch<N>" suffix is inserted before ".pth.tar" when saving/loading.
CHECKPOINT_GEN_H_BASE = f"genh_{run_id_string}.pth.tar"
CHECKPOINT_GEN_R_BASE = f"genr_{run_id_string}.pth.tar"
CHECKPOINT_DISC_H_BASE = f"disch_{run_id_string}.pth.tar"
CHECKPOINT_DISC_R_BASE = f"discr_{run_id_string}.pth.tar"

# --- Create Output Directories ---
# exist_ok=True makes re-running this cell harmless.
for _output_dir in (OUTPUT_IMAGE_DIR, CHECKPOINT_SAVE_DIR, LOG_DIR):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir


# ==============================================================================
# 3. INITIALIZATION & PRE-RUN CHECKS
# ==============================================================================

# --- Configure Logging ---
# Each session writes to its own timestamped log file and also echoes to the console.
log_filename = f"train_log_{run_id_string}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
log_filepath = os.path.join(LOG_DIR, log_filename)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler(log_filepath),
        logging.StreamHandler()
    ],
    # Without force=True, basicConfig does nothing if the root logger already has
    # handlers (common in notebook kernels and whenever this cell is re-run).
    # The new log file would then never be written to. force=True (Python 3.8+)
    # closes and replaces any existing root handlers.
    force=True,
)

logging.info("--- Starting New Training Session ---")
logging.info(f"Log file will be saved to: {os.path.abspath(log_filepath)}")

# --- Log System & Hyperparameter Details ---
logging.info("--- Device Setup ---")
logging.info(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    logging.info(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

logging.info("--- Key Hyperparameters ---")
logging.info(f"Batch Size: {BATCH_SIZE}, Image Size: {IMG_SIZE}x{IMG_SIZE}")
logging.info(f"Learning Rates (Gen/Disc): {LEARNING_RATE_GEN}/{LEARNING_RATE_DISC}")
logging.info(f"Total Epochs: {NUM_EPOCHS}, Steps per Epoch: {STEPS_PER_EPOCH}")

# --- Verify Dataset Paths ---
# A missing folder is fatal. An empty folder only warns here, but the dataset
# constructor in Cell 2 will raise for it.
logging.info("\n--- Verifying Dataset Paths ---")
all_paths_ok = True
for domain_label, path_folder in [("Train H&E", PATH_TRAIN_H_FOLDER), ("Train Reticulin", PATH_TRAIN_R_FOLDER), ("Test H&E", PATH_TEST_H_FOLDER), ("Test Reticulin", PATH_TEST_R_FOLDER)]:
    if not os.path.isdir(path_folder):
        logging.error(f"CRITICAL ERROR: {domain_label} folder NOT FOUND at: {os.path.abspath(path_folder)}")
        all_paths_ok = False
        continue
    num_files = len(glob.glob(os.path.join(path_folder, "*.tif")))
    if num_files == 0:
        logging.warning(f"WARNING: Folder for {domain_label} exists, but contains no '.tif' files.")
    else:
        logging.info(f"SUCCESS: Found {domain_label} folder with {num_files} '.tif' files.")
if not all_paths_ok:
    raise FileNotFoundError("One or more essential data folders are missing. Please check the paths defined in Cell 1.")

# --- Set Seed for Reproducibility ---
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if DEVICE == 'cuda':
    torch.cuda.manual_seed_all(42)
    # NOTE: benchmark mode lets cuDNN pick the fastest kernels. This trades
    # bit-for-bit reproducibility for speed, even with the seeds above.
    torch.backends.cudnn.benchmark = True

logging.info("\n--- Cell 1 Setup Complete. Ready to proceed. ---")

In [None]:
# ==============================================================================
# Cell 2: Data Loading Pipeline (Final Clean Version)
# ==============================================================================
# This cell defines the process for loading images from disk and preparing them
# for the model. It assumes all necessary libraries were imported in Cell 1.
# It is optimized for performance by deferring heavy transformations (like resizing)
# to the GPU, which will be handled in the training loop (Cell 5).
# ==============================================================================

# --- Utility Function: Image Loading ---
def load_tiff_to_tensor_raw(path, target_channels):
    try:
        img_np = tifffile.imread(path)
        if img_np is None: raise IOError("tifffile.imread returned None")
        if img_np.ndim == 2:  img_np = np.stack([img_np] * 3, axis=-1)
        elif img_np.shape[-1] == 4:  img_np = img_np[..., :3]
        elif img_np.shape[-1] == 1:  img_np = np.concatenate([img_np] * 3, axis=-1)
        if img_np.dtype == np.uint8:
            tensor = torch.from_numpy(img_np.astype(np.float32)).permute(2, 0, 1) / 255.0
        else:
            tensor = torch.from_numpy(img_np.astype(np.float32)).permute(2, 0, 1) / 65535.0
        return (tensor * 2.0) - 1.0
    except Exception as e:
        logging.error(f"LOAD_TENSOR_ERROR for {path}: {e}")
        return None # Return None on failure, handled by the caller

# --- Dataset Class Definition ---
class PairedImageDataset(Dataset):
    """Two-domain dataset returning randomly drawn (H&E, Reticulin) tensor pairs.

    The two images in a pair are NOT aligned. Every ``__getitem__`` call picks an
    independent random H&E file and an independent random Reticulin file, and the
    ``index`` argument is ignored.

    Args:
        root_H_folder: directory containing the H&E ``*.tif`` files.
        root_R_folder: directory containing the Reticulin ``*.tif`` files.
        domain_name: label used only in log and error messages (e.g. "Train").
        steps_per_epoch: if given, the dataset reports ``steps_per_epoch * batch_size``
            items so that one epoch is exactly that many batches.
        batch_size: batch size used together with ``steps_per_epoch``; defaults to
            the global ``BATCH_SIZE``.
        max_load_attempts: random pairs tried in ``__getitem__`` before giving up.

    Raises:
        ValueError: if either folder contains no ``*.tif`` files.
    """

    def __init__(self, root_H_folder, root_R_folder, domain_name, steps_per_epoch=None,
                 batch_size=None, max_load_attempts=100):
        self.paths_H = sorted(glob.glob(os.path.join(root_H_folder, "*.tif")))
        self.paths_R = sorted(glob.glob(os.path.join(root_R_folder, "*.tif")))
        self.len_H = len(self.paths_H)
        self.len_R = len(self.paths_R)
        self.steps_per_epoch = steps_per_epoch
        self.max_load_attempts = max_load_attempts

        logging.info(f"--- {domain_name} Dataset Initialized ---")
        logging.info(f"Found {self.len_H} H&E images and {self.len_R} Reticulin images.")

        if self.len_H == 0 or self.len_R == 0:
            self.length = 0
            raise ValueError(f"CRITICAL: {domain_name} dataset is empty.")
        if self.steps_per_epoch is not None:
            effective_batch_size = BATCH_SIZE if batch_size is None else batch_size
            self.length = self.steps_per_epoch * effective_batch_size
            logging.info(f"Using fixed steps per epoch. Epoch will contain {self.length} items.")
        else:
            self.length = max(self.len_H, self.len_R)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Files that fail to load are skipped by drawing a new random pair. The
        # number of attempts is bounded so that a folder full of unreadable files
        # raises instead of hanging the training loop forever.
        for _ in range(self.max_load_attempts):
            img_H_path = self.paths_H[random.randint(0, self.len_H - 1)]
            img_R_path = self.paths_R[random.randint(0, self.len_R - 1)]
            img_H_tensor = load_tiff_to_tensor_raw(img_H_path, IMG_CHANNELS)
            img_R_tensor = load_tiff_to_tensor_raw(img_R_path, IMG_CHANNELS)
            if img_H_tensor is not None and img_R_tensor is not None:
                return img_H_tensor, img_R_tensor
        raise RuntimeError(
            f"Could not load a valid image pair after {self.max_load_attempts} attempts; "
            "check the dataset for corrupt or unsupported TIFF files."
        )
            # If a file fails to load, the loop continues and gets a new random pair.

# --- DataLoader Instantiation ---
# Images are decoded in the main process; raise this to parallelise TIFF decoding.
num_dataloader_workers = 0
logging.info(f"Using num_workers = {num_dataloader_workers} for DataLoaders.")

# Page-locked host memory only speeds up host-to-GPU copies. On a CPU-only
# machine, DataLoader warns about pin_memory=True and does not use it.
pin_host_memory = DEVICE == "cuda"

train_dataset = PairedImageDataset(
    root_H_folder=PATH_TRAIN_H_FOLDER, root_R_folder=PATH_TRAIN_R_FOLDER,
    domain_name="Train", steps_per_epoch=STEPS_PER_EPOCH
)
# shuffle=False is fine: PairedImageDataset already draws files at random on every access.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=num_dataloader_workers, pin_memory=pin_host_memory, drop_last=True
)
logging.info(f"Training DataLoader created with {len(train_dataloader)} batches per epoch.")

vis_dataset = PairedImageDataset(
    root_H_folder=PATH_TEST_H_FOLDER, root_R_folder=PATH_TEST_R_FOLDER,
    domain_name="Visualization"
)
vis_dataloader = DataLoader(
    vis_dataset, batch_size=16, shuffle=False,
    num_workers=num_dataloader_workers, pin_memory=pin_host_memory
)
logging.info(f"Visualization DataLoader created with {len(vis_dataloader)} total batches.")

logging.info("\n--- Cell 2 Data Pipeline Setup Complete ---")

In [None]:
# ==============================================================================
# Cell 3: Model Architectures (U-Net as Generator and PatchGAN as Discriminator)
# ==============================================================================
# This cell defines the neural network architectures. The Generator is a U-Net
# whose skip connections carry high-resolution encoder features directly to the
# decoder, which helps preserve fine detail.
# ==============================================================================

# --- Helper Block for U-Net ---
class ConvBlock(nn.Module):
    """Convolution -> InstanceNorm2d -> ReLU (or identity) building block.

    ``down=True`` uses an ``nn.Conv2d`` with reflect padding; ``down=False`` uses an
    ``nn.ConvTranspose2d``. Any extra keyword arguments (kernel_size, stride,
    padding, ...) are forwarded unchanged to the convolution.
    """
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        self.conv = (
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        )
        self.norm = nn.InstanceNorm2d(out_channels)
        if use_act:
            self.act = nn.ReLU(inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        features = self.conv(x)
        features = self.norm(features)
        return self.act(features)

# --- 1. GENERATOR (U-Net Architecture) ---
class UNetGenerator(nn.Module):
    """U-Net generator with skip connections.

    Six stride-2 down-sampling stages (encoder1-4, bottleneck, bottleneck2) are
    followed by six stride-2 up-sampling stages (up0-up4 and the final layer).
    Every encoder activation except the deepest one is concatenated onto the
    matching decoder input. Height and width are halved six times, so they must be
    divisible by 64 for the skip tensors to line up. The final Tanh bounds the
    output to [-1, 1].
    """
    def __init__(self, img_channels=3, features=64):
        super().__init__()
        f = features
        down_kw = dict(kernel_size=4, stride=2, padding=1)
        up_kw = dict(down=False, kernel_size=4, stride=2, padding=1)

        # Encoder (sizes in comments assume a 512x512 input).
        self.encoder1 = ConvBlock(img_channels, f, **down_kw)    # f  x 256 x 256
        self.encoder2 = ConvBlock(f, 2 * f, **down_kw)           # 2f x 128 x 128
        self.encoder3 = ConvBlock(2 * f, 4 * f, **down_kw)       # 4f x 64 x 64
        self.encoder4 = ConvBlock(4 * f, 8 * f, **down_kw)       # 8f x 32 x 32

        # Two bottleneck stages that keep 8f channels.
        self.bottleneck = ConvBlock(8 * f, 8 * f, **down_kw)     # 8f x 16 x 16
        self.bottleneck2 = ConvBlock(8 * f, 8 * f, **down_kw)    # 8f x 8 x 8

        # Decoder: after up0, each stage receives [previous output, skip]
        # concatenated on the channel axis, hence the doubled input widths.
        self.up0 = ConvBlock(8 * f, 8 * f, **up_kw)
        self.up1 = ConvBlock(16 * f, 8 * f, **up_kw)
        self.up2 = ConvBlock(16 * f, 4 * f, **up_kw)
        self.up3 = ConvBlock(8 * f, 2 * f, **up_kw)
        self.up4 = ConvBlock(4 * f, f, **up_kw)

        # Output: back to full resolution and img_channels, squashed to [-1, 1].
        self.final = nn.Sequential(
            nn.ConvTranspose2d(2 * f, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Encoder pass; remember every stage output (shallow -> deep) for the skips.
        skips = []
        out = x
        for stage in (self.encoder1, self.encoder2, self.encoder3, self.encoder4, self.bottleneck):
            out = stage(out)
            skips.append(out)

        # The deepest stage has no skip partner.
        out = self.up0(self.bottleneck2(out))

        # Decoder pass; consume skips deepest-first.
        for stage in (self.up1, self.up2, self.up3, self.up4):
            out = stage(torch.cat([out, skips.pop()], dim=1))
        return self.final(torch.cat([out, skips.pop()], dim=1))

# --- 2. DISCRIMINATOR (PatchGAN) ---
class Discriminator(nn.Module):
    """PatchGAN discriminator (critic).

    The output is a 1-channel 2-D map of scores rather than a single scalar.

    Args:
        in_channels: channels of the input images.
        features: output channels of each stage. The first stage is a plain conv +
            LeakyReLU. Later stages are Conv -> InstanceNorm -> LeakyReLU blocks with
            stride 2, except the last block, which uses stride 1. A final 1-channel
            conv produces the score map. The default reproduces the classic
            4-stage (64, 128, 256, 512) layout.

    Raises:
        ValueError: if ``features`` is empty.
    """
    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; any sequence is accepted.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel count")

        layers = [
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        last_index = len(features) - 1
        for i in range(1, len(features)):
            stride = 1 if i == last_index else 2  # last block keeps the resolution
            layers.append(self._discriminator_block(features[i - 1], features[i], stride=stride))
        layers.append(nn.Conv2d(features[-1], 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        # Same Sequential layout (and thus state_dict keys) as before for the default features.
        self.model = nn.Sequential(*layers)

    def _discriminator_block(self, in_channels, out_channels, stride):
        """Conv (no bias, since InstanceNorm follows) -> InstanceNorm2d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=False, padding_mode="reflect"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        return self.model(x)

logging.info("Model architectures defined (U-Net Generator, PatchGAN Discriminator).")

In [None]:
# ==============================================================================
# Cell 4: Model and Training Initialization
# ==============================================================================
# This cell instantiates all the necessary components for training:
# 1. Loss Functions: Defines the objectives for the networks.
# 2. Models: Creates instances of the U-Net Generator and PatchGAN Discriminator.
# 3. Optimizers: Defines how the model weights are updated.
# 4. Utilities: Sets up the Replay Buffer and Checkpoint functions.
# 5. Resume Logic: Handles loading a saved model to continue training.
# ==============================================================================

# --- 1. Loss Function Definitions ---
# Adversarial objective: mean-squared error against real/fake target maps
# (a least-squares GAN loss).
adv_loss_fn = nn.MSELoss()
# Cycle-consistency (round-trip reconstruction) and identity (leave in-domain
# images unchanged) both use L1 / mean absolute error, which tends to blur less
# than an L2 penalty.
cycle_loss_fn, identity_loss_fn = nn.L1Loss(), nn.L1Loss()
logging.info("Loss functions defined (Adversarial, Cycle, Identity).")


# --- 2. Model, Optimizer, and Scaler Initialization ---
# gen_H translates H&E -> Reticulin; gen_R translates Reticulin -> H&E.
# disc_H scores H&E images; disc_R scores Reticulin images.
# The construction order is kept fixed so that seeded weight initialisation is reproducible.
gen_H = UNetGenerator(img_channels=IMG_CHANNELS, features=64).to(DEVICE)
gen_R = UNetGenerator(img_channels=IMG_CHANNELS, features=64).to(DEVICE)
disc_H = Discriminator(in_channels=IMG_CHANNELS).to(DEVICE)
disc_R = Discriminator(in_channels=IMG_CHANNELS).to(DEVICE)
logging.info("U-Net Generators and Discriminator models initialized and moved to device.")

# Adam with beta1 = 0.5. Both generators share ONE optimizer (their parameters
# are chained together) and are updated from the combined generator loss.
_adam_betas = (0.5, 0.999)
opt_gen = optim.Adam(
    itertools.chain(gen_H.parameters(), gen_R.parameters()),
    lr=LEARNING_RATE_GEN,
    betas=_adam_betas,
)
opt_disc_H = optim.Adam(disc_H.parameters(), lr=LEARNING_RATE_DISC, betas=_adam_betas)
opt_disc_R = optim.Adam(disc_R.parameters(), lr=LEARNING_RATE_DISC, betas=_adam_betas)
logging.info("Adam optimizers initialized.")

# Mixed-precision loss scalers (active only on CUDA). They scale the loss before
# backward() so small float16 gradients do not underflow to zero.
_use_amp = DEVICE == "cuda"
scaler_gen, scaler_disc_H, scaler_disc_R = (
    torch.cuda.amp.GradScaler(enabled=_use_amp) for _ in range(3)
)
logging.info("GradScalers initialized for AMP.")


# --- 3. Training Utilities ---

class ReplayBuffer:
    """Pool of previously generated images used when training a discriminator.

    Instead of showing the discriminator only the newest fakes, ``push_and_pop``
    mixes in older ones, a common technique for stabilising GAN training.

    While the pool has fewer than ``max_size`` images, new images are stored and
    returned unchanged. Once it is full, each incoming image has roughly a 50%
    chance of being swapped for a random stored image (and taking its slot).
    Otherwise it is returned unchanged.
    """
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data_batch):
        """Store images from ``data_batch`` and return a same-sized batch.

        Args:
            data_batch: tensor whose first dimension indexes images. It is detached
                from autograd before anything is stored or returned.

        Returns:
            Tensor with the same shape as ``data_batch``, containing a mix of new
            and previously stored images.
        """
        batch = data_batch.detach()
        if self.max_size <= 0 or batch.shape[0] == 0:
            # Nothing can be stored or returned from the pool; pass the batch
            # through instead of failing in random.randint(0, -1) / torch.cat([]).
            return batch

        images_to_return = []
        for element in batch:
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Store an owned copy. A slice is a view that would keep the whole
                # source batch (and its GPU memory) alive while it sits in the pool.
                self.data.append(element.clone())
                images_to_return.append(element)
            elif random.random() > 0.5:
                # Return a random old image and put the new one in its slot.
                i = random.randint(0, self.max_size - 1)
                images_to_return.append(self.data[i])
                self.data[i] = element.clone()
            else:
                # Return the new image without storing it.
                images_to_return.append(element)
        return torch.cat(images_to_return)

buffer_fake_H = ReplayBuffer()
buffer_fake_R = ReplayBuffer()
logging.info("Replay buffers initialized.")


# --- Checkpoint Saving & Loading Functions ---

def save_checkpoint(model, optimizer, scaler, filename, checkpoint_dir=None):
    """Save model, optimizer and GradScaler state dicts to one checkpoint file.

    Args:
        model: module whose ``state_dict`` is stored under "state_dict".
        optimizer: optimizer whose ``state_dict`` is stored under "optimizer".
        scaler: GradScaler whose ``state_dict`` is stored under "scaler".
        filename: file name, relative to ``checkpoint_dir``.
        checkpoint_dir: target directory; defaults to ``CHECKPOINT_SAVE_DIR``.

    The data is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``. An interrupted save therefore never leaves a truncated file
    under the final checkpoint name.
    """
    directory = CHECKPOINT_SAVE_DIR if checkpoint_dir is None else checkpoint_dir
    full_path = os.path.join(directory, filename)
    logging.info(f"Saving checkpoint => {full_path}")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
    }
    tmp_path = full_path + ".tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, full_path)

def load_checkpoint(model, optimizer, scaler, lr, base_filename, epoch, device, checkpoint_dir=None):
    """Restore model (and optionally optimizer / GradScaler) state from a checkpoint.

    The file name is ``base_filename`` with ``.pth.tar`` replaced by
    ``_epoch{epoch}.pth.tar``. It is looked up in ``checkpoint_dir``, which
    defaults to ``CHECKPOINT_SAVE_DIR``.

    Args:
        model: module to load "state_dict" into.
        optimizer: optimizer to restore, or None to skip. After loading, every
            param group's learning rate is overwritten with ``lr``.
        scaler: GradScaler to restore, or None to skip.
        lr: learning rate to apply after loading the optimizer state.
        base_filename: checkpoint base name without the epoch suffix.
        epoch: epoch number embedded in the file name.
        device: passed to ``torch.load`` as ``map_location``.
        checkpoint_dir: directory to look in (optional).

    Returns:
        True on success. False if the file is missing or loading fails; the
        error and traceback are logged. A failure after the model state was
        loaded can leave the model updated but the optimizer/scaler not.
    """
    filename = base_filename.replace(".pth.tar", f"_epoch{epoch}.pth.tar")
    directory = CHECKPOINT_SAVE_DIR if checkpoint_dir is None else checkpoint_dir
    full_path = os.path.join(directory, filename)

    if not os.path.exists(full_path):
        logging.warning(f"Checkpoint not found at: {full_path}")
        return False

    logging.info(f"Loading checkpoint: {full_path}")
    try:
        checkpoint = torch.load(full_path, map_location=device)
        model.load_state_dict(checkpoint["state_dict"])
        if optimizer is not None and "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
            # Use the currently configured learning rate, not the saved one.
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
        if scaler is not None and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])
    except Exception:
        # Best-effort loader: report with traceback and let the caller decide.
        logging.exception(f"Failed to load checkpoint {full_path}")
        return False
    logging.info("Checkpoint loaded successfully.")
    return True


# --- 4. Resume Logic ---

# Last *completed* epoch; the training loop starts at start_epoch + 1.
start_epoch = 0

if LOAD_MODEL and EPOCH_TO_LOAD_FROM > 0:
    logging.info(f"\n--- Attempting to resume training from epoch {EPOCH_TO_LOAD_FROM} ---")

    # (model, optimizer, scaler, lr, checkpoint base name) for every network.
    # gen_H and gen_R share opt_gen / scaler_gen, so those states are loaded twice.
    # The training loop always saves both copies at the same time.
    resume_specs = [
        (gen_H, opt_gen, scaler_gen, LEARNING_RATE_GEN, CHECKPOINT_GEN_H_BASE),
        (gen_R, opt_gen, scaler_gen, LEARNING_RATE_GEN, CHECKPOINT_GEN_R_BASE),
        (disc_H, opt_disc_H, scaler_disc_H, LEARNING_RATE_DISC, CHECKPOINT_DISC_H_BASE),
        (disc_R, opt_disc_R, scaler_disc_R, LEARNING_RATE_DISC, CHECKPOINT_DISC_R_BASE),
    ]

    # Check that every file exists BEFORE loading any. Otherwise a missing file
    # would leave some models restored and others freshly initialised, while
    # the log claims training starts from scratch.
    missing_checkpoints = [
        os.path.join(CHECKPOINT_SAVE_DIR, base.replace(".pth.tar", f"_epoch{EPOCH_TO_LOAD_FROM}.pth.tar"))
        for _, _, _, _, base in resume_specs
    ]
    missing_checkpoints = [path for path in missing_checkpoints if not os.path.exists(path)]

    if missing_checkpoints:
        for path in missing_checkpoints:
            logging.error(f"Checkpoint not found at: {path}")
        logging.error("--- Resume failed. One or more checkpoints could not be loaded. Starting from scratch. ---")
        # Nothing has been loaded yet, so the models really are fresh.
        LOAD_MODEL = False
        start_epoch = 0
    else:
        load_results = [
            load_checkpoint(model, opt, scaler, lr, base, EPOCH_TO_LOAD_FROM, DEVICE)
            for model, opt, scaler, lr, base in resume_specs
        ]
        if all(load_results):
            start_epoch = EPOCH_TO_LOAD_FROM
            logging.info(f"--- Resume successful. Training will continue from epoch {start_epoch + 1}. ---")
        else:
            # Some models were already overwritten; continuing would train from
            # an inconsistent mix of restored and fresh weights.
            raise RuntimeError(
                "Resume failed part-way: some checkpoints loaded and others did not, so the "
                "models are in an inconsistent state. Fix the checkpoint files or set LOAD_MODEL = False."
            )
elif LOAD_MODEL:
    logging.info("LOAD_MODEL is True, but EPOCH_TO_LOAD_FROM is 0. Starting training from scratch.")
else:
    logging.info("\nLOAD_MODEL is False. Starting training from scratch.")

logging.info(f"Effective starting epoch for training loop: {start_epoch + 1}")
logging.info("\n--- Cell 4 Initialization Complete ---")

In [None]:
# ==============================================================================
# Cell 5: The Training Loop
# ==============================================================================
# This cell contains the main training logic. It orchestrates the data loading,
# model training for both generators and discriminators, and periodic saving of
# checkpoints, sample images, and performance graphs.
# ==============================================================================

# ==============================================================================
# 1. UTILITY FUNCTIONS
# ==============================================================================
# Helper functions called during the training loop.

def denormalize(image_tensor):
    """Map values from the model's [-1, 1] range to [0, 1], clipping anything outside."""
    shifted = image_tensor.add(1.0).div(2.0)
    return shifted.clamp(min=0.0, max=1.0)

def generate_and_save_samples(gen_H, gen_R, epoch, dataloader):
    """Translate one visualization pair with both generators and save a comparison strip.

    Draws one batch from ``dataloader`` and resizes it with ``gpu_vis_transform``
    on ``DEVICE``. The first item is translated without gradients. The saved PNG
    shows, left to right: real H&E, gen_H(real H&E), real Reticulin,
    gen_R(real Reticulin).

    The generators are switched to eval mode for inference and restored to their
    previous mode afterwards, including on early return or error.
    """
    if dataloader is None:
        logging.warning(f"Epoch {epoch}: Visualization dataloader not available, skipping sample generation.")
        return

    previous_modes = (gen_H.training, gen_R.training)
    gen_H.eval()
    gen_R.eval()
    try:
        try:
            real_H_raw, real_R_raw = next(iter(dataloader))
        except StopIteration:
            logging.warning("Visualization dataloader exhausted. Cannot generate new samples.")
            return

        with torch.no_grad():
            # Move raw data to the device and apply the device-side resize.
            real_H = gpu_vis_transform(real_H_raw.to(DEVICE))
            real_R = gpu_vis_transform(real_R_raw.to(DEVICE))

            fake_R = gen_H(real_H)
            fake_H = gen_R(real_R)

        # Only the first image of the batch is shown, concatenated along width.
        combined_image = torch.cat([
            denormalize(real_H[0].cpu()),
            denormalize(fake_R[0].cpu()),
            denormalize(real_R[0].cpu()),
            denormalize(fake_H[0].cpu())
        ], dim=2)

        save_path = os.path.join(OUTPUT_IMAGE_DIR, f'sample_epoch_{epoch:04d}.png')
        save_image(combined_image, save_path)
        logging.info(f"Saved sample image to {save_path}")
    finally:
        gen_H.train(previous_modes[0])
        gen_R.train(previous_modes[1])


def plot_and_save_metrics(metrics, epoch, save_dir):
    """Plot per-epoch average losses and save the figure as a PNG.

    Args:
        metrics: dict mapping loss name -> list of per-epoch averages, one entry
            per epoch trained in this session; the last entry belongs to ``epoch``.
        epoch: epoch just completed. Used for the x-axis (so resumed runs are
            labelled with their true epoch numbers) and for the file name.
        save_dir: directory that receives ``metrics_graph_epoch_XXXX.png``.

    Does nothing if there is no history yet, or if matplotlib is unavailable
    (a warning is logged in that case).
    """
    num_epochs = max((len(v) for v in metrics.values()), default=0)
    if num_epochs == 0:
        return

    try:
        import matplotlib.pyplot as plt  # not imported at notebook level
    except ImportError:
        logging.warning("matplotlib is not installed; skipping metrics plot.")
        return

    plot_groups = [
        # 'gen_G_loss' is the full weighted generator objective
        # (adversarial + cycle + identity), not only the adversarial term.
        {'title': 'Total Generator Loss', 'keys': ['gen_G_loss']},
        {'title': 'Discriminator Loss', 'keys': ['disc_H_loss', 'disc_R_loss']},
        {'title': 'Cycle Consistency Loss', 'keys': ['cycle_H_loss', 'cycle_R_loss']},
        {'title': 'Identity Loss', 'keys': ['identity_H_loss', 'identity_R_loss']},
    ]

    first_epoch = epoch - num_epochs + 1
    tick_step = 5 if num_epochs >= 50 else (2 if num_epochs > 20 else 1)
    ticks = [e for e in range(first_epoch, epoch + 1) if e % tick_step == 0 or e == first_epoch]

    fig, axes = plt.subplots(len(plot_groups), 1, figsize=(16, 7 * len(plot_groups)),
                             sharex=True, squeeze=False)
    try:
        for ax, group in zip(axes[:, 0], plot_groups):
            for key in group['keys']:
                values = metrics.get(key, [])
                if values and any(not np.isnan(v) for v in values):
                    # Align each series so that its last point sits on `epoch`.
                    xs = list(range(epoch - len(values) + 1, epoch + 1))
                    ax.plot(xs, values, marker='o', linestyle='-', markersize=4, label=key)
            ax.set_xticks(ticks)
            ax.set_title(group['title'], fontsize=14)
            ax.set_ylabel("Loss")
            if ax.get_legend_handles_labels()[0]:
                ax.legend()
            ax.grid(True)

        axes[-1, 0].set_xlabel("Epochs")
        fig.suptitle(f'Training Metrics for Epochs {first_epoch}-{epoch}', fontsize=18)
        fig.tight_layout(rect=[0, 0.03, 1, 0.97])
        save_path = os.path.join(save_dir, f'metrics_graph_epoch_{epoch:04d}.png')
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    logging.info(f"Saved metrics graph to {save_path}")

# ==============================================================================
# 2. MODULAR TRAINING FUNCTIONS
# ==============================================================================

def train_discriminators(real_H, real_R, gen_H, gen_R, disc_H, disc_R, opt_disc_H, opt_disc_R, scaler_disc_H, scaler_disc_R, buffer_fake_H, buffer_fake_R):
    """Run one optimizer step for each discriminator.

    Each discriminator is trained to output 1 on real images and 0 on fakes drawn
    through its replay buffer (MSE via ``adv_loss_fn``).

    Returns:
        (losses, fake_H, fake_R): ``losses`` has float 'disc_H_loss' and
        'disc_R_loss'. ``fake_H`` / ``fake_R`` are the freshly generated batches,
        still attached to the generator graphs, for reuse in train_generators.
    """
    losses = {}
    # Mixed precision only on CUDA, matching the GradScalers. torch.autocast avoids
    # the deprecated torch.cuda.amp.autocast and its "CUDA is not available"
    # warning on CPU-only machines.
    amp_enabled = DEVICE == "cuda"

    # --- Train Discriminator H (for H&E images) ---
    with torch.autocast(device_type=DEVICE, enabled=amp_enabled):
        fake_H = gen_R(real_R)
        D_H_real = disc_H(real_H)
        D_H_fake = disc_H(buffer_fake_H.push_and_pop(fake_H.detach()))
        D_H_real_loss = adv_loss_fn(D_H_real, torch.ones_like(D_H_real))
        D_H_fake_loss = adv_loss_fn(D_H_fake, torch.zeros_like(D_H_fake))
        D_H_loss = (D_H_real_loss + D_H_fake_loss) / 2

    opt_disc_H.zero_grad()
    scaler_disc_H.scale(D_H_loss).backward()
    scaler_disc_H.step(opt_disc_H)
    scaler_disc_H.update()
    losses['disc_H_loss'] = D_H_loss.item()

    # --- Train Discriminator R (for Reticulin images) ---
    with torch.autocast(device_type=DEVICE, enabled=amp_enabled):
        fake_R = gen_H(real_H)
        D_R_real = disc_R(real_R)
        D_R_fake = disc_R(buffer_fake_R.push_and_pop(fake_R.detach()))
        D_R_real_loss = adv_loss_fn(D_R_real, torch.ones_like(D_R_real))
        D_R_fake_loss = adv_loss_fn(D_R_fake, torch.zeros_like(D_R_fake))
        D_R_loss = (D_R_real_loss + D_R_fake_loss) / 2

    opt_disc_R.zero_grad()
    scaler_disc_R.scale(D_R_loss).backward()
    scaler_disc_R.step(opt_disc_R)
    scaler_disc_R.update()
    losses['disc_R_loss'] = D_R_loss.item()

    return losses, fake_H, fake_R  # Return fakes for generator training

def train_generators(real_H, real_R, fake_H, fake_R, gen_H, gen_R, disc_H, disc_R, opt_gen, scaler_gen):
    """Run one shared optimizer step for both generators.

    The objective is:
        adversarial(gen_R) + adversarial(gen_H)
        + LAMBDA_CYCLE * (cycle_H + cycle_R)
        + LAMBDA_IDENTITY * (identity_H + identity_R)   [identity only if LAMBDA_IDENTITY > 0]

    ``fake_H`` / ``fake_R`` must be the non-detached outputs returned by
    train_discriminators, so gradients can flow back into the generators.

    Returns:
        dict of floats: 'gen_G_loss' (the total objective above),
        'cycle_H_loss', 'cycle_R_loss', and, when the identity term is active,
        'identity_H_loss' / 'identity_R_loss'.
    """
    losses = {}

    # Mixed precision only on CUDA, matching the GradScalers (see train_discriminators).
    with torch.autocast(device_type=DEVICE, enabled=(DEVICE == "cuda")):
        # --- Adversarial Loss: generators want fakes scored as real (1) ---
        D_H_fake = disc_H(fake_H)
        D_R_fake = disc_R(fake_R)
        loss_G_H_adv = adv_loss_fn(D_H_fake, torch.ones_like(D_H_fake))
        loss_G_R_adv = adv_loss_fn(D_R_fake, torch.ones_like(D_R_fake))

        # --- Cycle Consistency Loss: H -> R -> H and R -> H -> R round trips ---
        cycled_H = gen_R(fake_R)
        loss_cycle_H = cycle_loss_fn(real_H, cycled_H)
        cycled_R = gen_H(fake_H)
        loss_cycle_R = cycle_loss_fn(real_R, cycled_R)

        # --- Identity Loss: a generator fed its own target domain should change nothing ---
        if LAMBDA_IDENTITY > 0:
            identity_H = gen_R(real_H)
            loss_identity_H = identity_loss_fn(real_H, identity_H)
            identity_R = gen_H(real_R)
            loss_identity_R = identity_loss_fn(real_R, identity_R)
            losses['identity_H_loss'] = loss_identity_H.item()
            losses['identity_R_loss'] = loss_identity_R.item()
        else:
            loss_identity_H = loss_identity_R = 0

        # --- Total Generator Loss ---
        total_G_loss = (
            loss_G_H_adv + loss_G_R_adv +
            (loss_cycle_H * LAMBDA_CYCLE) +
            (loss_cycle_R * LAMBDA_CYCLE) +
            (loss_identity_H * LAMBDA_IDENTITY) +
            (loss_identity_R * LAMBDA_IDENTITY)
        )

    opt_gen.zero_grad()
    scaler_gen.scale(total_G_loss).backward()
    scaler_gen.step(opt_gen)
    scaler_gen.update()

    losses['gen_G_loss'] = total_G_loss.item()
    losses['cycle_H_loss'] = loss_cycle_H.item()
    losses['cycle_R_loss'] = loss_cycle_R.item()

    return losses

# ==============================================================================
# 3. MAIN TRAINING EXECUTION
# ==============================================================================

# --- GPU-side Transforms ---
# Resizing (and, for training, flipping) is applied to whole batches after they
# have been moved to DEVICE.
# NOTE: on a batched tensor, RandomHorizontalFlip makes a single random decision,
# so either every image in the batch is flipped or none is.
gpu_train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True),
    transforms.RandomHorizontalFlip(p=0.5),
])
gpu_vis_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True),
])

# --- Metrics History Initialization ---
# History covers only the epochs trained in this session (it is reset on resume).
history_keys = ["gen_G_loss", "disc_H_loss", "disc_R_loss", "cycle_H_loss", "cycle_R_loss", "identity_H_loss", "identity_R_loss"]
metrics_history = {key: [] for key in history_keys}
logging.info("Initialized metrics history for this run.")


def _save_all_checkpoints(epoch):
    """Save all four networks (with their optimizers and scalers) for `epoch`.

    File names follow the "<base>_epoch<N>.pth.tar" pattern that the resume
    logic in Cell 4 loads, and resuming needs all four files.
    """
    for model, optimizer, scaler, base in (
        (gen_H, opt_gen, scaler_gen, CHECKPOINT_GEN_H_BASE),
        (gen_R, opt_gen, scaler_gen, CHECKPOINT_GEN_R_BASE),
        (disc_H, opt_disc_H, scaler_disc_H, CHECKPOINT_DISC_H_BASE),
        (disc_R, opt_disc_R, scaler_disc_R, CHECKPOINT_DISC_R_BASE),
    ):
        save_checkpoint(model, optimizer, scaler, base.replace(".pth.tar", f"_epoch{epoch}.pth.tar"))


# --- Main Training Loop ---
logging.info(f"\n--- Starting Training Loop from Epoch {start_epoch + 1} ---")
if train_dataloader is None:
    raise RuntimeError("CRITICAL ERROR: train_dataloader is None. Training cannot start.")

# Creating this file pauses training cleanly at the end of the current epoch.
stop_file_path = os.path.join(NOTEBOOK_CWD, "stop_training.txt")
last_completed_epoch = start_epoch
last_saved_epoch = None
paused = False

for current_epoch in range(start_epoch + 1, NUM_EPOCHS + 1):
    start_time = time.time()

    # Running sums of each loss, averaged at the end of the epoch.
    epoch_losses_sum = {key: 0.0 for key in history_keys}
    num_steps = 0

    loop = tqdm(train_dataloader, desc=f"Epoch [{current_epoch}/{NUM_EPOCHS}]", leave=True)
    for real_H_raw, real_R_raw in loop:
        # Move raw data to the device and apply the device-side transforms.
        real_H = gpu_train_transform(real_H_raw.to(DEVICE))
        real_R = gpu_train_transform(real_R_raw.to(DEVICE))

        # --- Run one training step (discriminators first, then generators) ---
        disc_losses, fake_H, fake_R = train_discriminators(real_H, real_R, gen_H, gen_R, disc_H, disc_R, opt_disc_H, opt_disc_R, scaler_disc_H, scaler_disc_R, buffer_fake_H, buffer_fake_R)
        gen_losses = train_generators(real_H, real_R, fake_H, fake_R, gen_H, gen_R, disc_H, disc_R, opt_gen, scaler_gen)

        # --- Update tracking and progress bar ---
        for key, value in {**disc_losses, **gen_losses}.items():
            epoch_losses_sum[key] += value
        num_steps += 1

        loop.set_postfix(G=gen_losses['gen_G_loss'], D=disc_losses['disc_H_loss'] + disc_losses['disc_R_loss'], refresh=True)
    loop.close()
    last_completed_epoch = current_epoch

    # --- End of Epoch Actions ---

    # 1. Log average losses (averaged over the steps actually run).
    for key in history_keys:
        metrics_history[key].append(epoch_losses_sum[key] / max(num_steps, 1))

    avg_loss_str = " ".join([f"{k.replace('_loss', '')}:{v[-1]:.3f}" for k, v in metrics_history.items() if v])
    logging.info(f"Epoch [{current_epoch}/{NUM_EPOCHS}] Summary | Time: {time.time()-start_time:.1f}s | Avgs: [{avg_loss_str}]")

    # 2. Save sample images
    if current_epoch % SAVE_SAMPLES_EVERY_N_EPOCHS == 0 or current_epoch == 1:
        generate_and_save_samples(gen_H, gen_R, current_epoch, vis_dataloader)

    # 3. Plot metrics
    if current_epoch % CALC_METRICS_EVERY_N_EPOCHS == 0:
        plot_and_save_metrics(metrics_history, current_epoch, OUTPUT_IMAGE_DIR)

    # 4. Save model checkpoints
    if SAVE_MODEL and (current_epoch % SAVE_MODEL_EVERY_N_EPOCHS == 0):
        _save_all_checkpoints(current_epoch)
        last_saved_epoch = current_epoch

    # 5. Check for stop signal (for graceful pausing)
    if os.path.exists(stop_file_path):
        logging.info(f"\n--- Stop file detected. Saving final model at epoch {current_epoch} and stopping. ---")
        # Save ALL four networks; resuming requires every checkpoint file.
        if last_saved_epoch != current_epoch:
            _save_all_checkpoints(current_epoch)
            last_saved_epoch = current_epoch
        os.remove(stop_file_path)
        logging.info("--- Training paused gracefully. ---")
        paused = True
        break

if paused:
    logging.info("\n--- Training stopped early via stop file. ---")
else:
    # Keep the last epoch's weights even if NUM_EPOCHS is not a multiple of the save interval.
    if SAVE_MODEL and last_completed_epoch > start_epoch and last_saved_epoch != last_completed_epoch:
        _save_all_checkpoints(last_completed_epoch)
    logging.info("\n--- Training Complete! ---")