In [None]:
# ==============================================================================
# Cell 1: Evaluation Setup and Configuration
# ==============================================================================
# This cell imports all necessary libraries and defines the configuration for
# evaluating a trained CycleGAN model. It sets which model checkpoint to load
# and where to find the necessary data and save the outputs.
# ==============================================================================

# --- Core Libraries ---
import os
import glob
import re
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# --- Deep Learning & Data Processing Libraries ---
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import tifffile
from torch_fidelity import calculate_metrics

# ==============================================================================
# 1. EVALUATION CONFIGURATION
# ==============================================================================
# All user-configurable parameters are grouped here.

# --- Model Selection ---
# Set the epoch number of the trained model you want to evaluate.
# This value is embedded in the checkpoint file names and in the output folder
# name ("epoch_<N>") built further down.
EPOCH_TO_EVALUATE = 50

# --- Training Run Parameters ---
# These parameters MUST exactly match the parameters of the training run you
# are evaluating to ensure the correct checkpoint paths are constructed.
# BATCH_SIZE and IMG_SIZE are both baked into `run_id_string` below; IMG_SIZE is
# additionally used as the resize target when images are generated, and
# IMG_CHANNELS is passed to the Generator constructor.
BATCH_SIZE = 16
IMG_SIZE = 512
IMG_CHANNELS = 3

# --- Evaluation Output Settings ---
# Number of example image pairs to generate for the visual gallery.
NUM_GALLERY_IMAGES = 10


# ==============================================================================
# 2. SYSTEM & PATH SETUP (AUTOMATED)
# ==============================================================================
# This section automatically sets up devices and constructs all necessary paths
# based on the configuration above.

# --- Device Configuration ---
# DEVICE is a plain string ("cuda" or "cpu"); later code compares it with
# `DEVICE == "cuda"`, so keep it a string rather than a torch.device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Dynamic Path Construction ---
# This ensures that all paths are consistent with the training run being evaluated.
# NOTE: every path is resolved relative to the current working directory at the
# time this cell runs, not relative to the notebook file's location.
NOTEBOOK_CWD = os.getcwd()
run_id_string = f"cyclegan_bs{BATCH_SIZE}_img{IMG_SIZE}" # Re-create the same run ID as training

# Source paths for test data
H_AND_E_BASE_DIR = "H&E_split_dataset"
RETICULIN_BASE_DIR = "Retic_split_dataset"
PATH_TEST_H_FOLDER  = os.path.join(NOTEBOOK_CWD, H_AND_E_BASE_DIR, "test")
PATH_TEST_R_FOLDER  = os.path.join(NOTEBOOK_CWD, RETICULIN_BASE_DIR, "test")

# Source path for model checkpoints
CHECKPOINT_DIR = os.path.join(NOTEBOOK_CWD, "checkpoints", run_id_string)

# Destination paths for evaluation outputs
# The output directory is created immediately (side effect), i.e. before the
# existence checks on the checkpoint and test-data folders are performed.
EVAL_OUTPUT_DIR = os.path.join(NOTEBOOK_CWD, "evaluation_results", run_id_string, f"epoch_{EPOCH_TO_EVALUATE}")
os.makedirs(EVAL_OUTPUT_DIR, exist_ok=True)


# ==============================================================================
# 3. INITIALIZATION & PRE-RUN CHECKS
# ==============================================================================
# Prints a summary of the run being evaluated, then fails fast (FileNotFoundError)
# if the checkpoint directory or either test-data folder is missing.

print("--- Starting Final Evaluation ---")
print(f"Device: {DEVICE}")
print(f"Evaluating Run ID: '{run_id_string}'")
print(f"Loading Checkpoint from Epoch: {EPOCH_TO_EVALUATE}")
print(f"Evaluation outputs will be saved to: {EVAL_OUTPUT_DIR}")

# --- Verify that the necessary checkpoint and data folders exist ---
if not os.path.isdir(CHECKPOINT_DIR):
    raise FileNotFoundError(f"CRITICAL: Checkpoint directory not found at {CHECKPOINT_DIR}. Please ensure the run ID and paths are correct.")

# Collect every missing test folder so the error names the exact path(s) at
# fault instead of leaving the user to work out which one is wrong.
missing_test_folders = [
    folder
    for folder in (PATH_TEST_H_FOLDER, PATH_TEST_R_FOLDER)
    if not os.path.isdir(folder)
]
if missing_test_folders:
    raise FileNotFoundError(
        "CRITICAL: One or more test data folders not found. Please check paths. "
        f"Missing: {missing_test_folders}"
    )

print("\n--- Cell 1 Setup Complete. Ready to load models and data. ---")

In [None]:
# ==============================================================================
# Cell 2: Load Trained Models and Test Data
# ==============================================================================
# This cell loads the trained generator models from the specified checkpoint
# epoch and prepares the test dataset for inference and evaluation.
# It assumes models and dataset classes are defined in separate .py files.
# ==============================================================================

# --- 1. Import Custom Modules ---
# Instead of copy-pasting class definitions, we import them.
# This makes the code cleaner and ensures you're always using the same
# model architecture as you did during training.
# (This assumes you have created src/models.py and src/dataset.py)
from src.models import Generator
from src.dataset import PairedImageDataset

# ==============================================================================
# 2. MODEL LOADING
# ==============================================================================

def load_generator_checkpoint(model, checkpoint_path):
    """
    Restore a generator's weights from disk and put it in evaluation mode.

    The checkpoint is expected to be a dictionary holding the weights under
    the "state_dict" key. All failures are reported on stdout rather than
    raised, so callers must check the return value.

    Args:
        model (nn.Module): An already-constructed generator to populate.
        checkpoint_path (str): Full path to the .pth.tar checkpoint file.

    Returns:
        nn.Module or None: The same model instance with weights loaded, or
        None if the file is missing or loading fails for any reason.
    """
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint: {os.path.basename(checkpoint_path)}")
        try:
            # map_location places the stored tensors on the active device,
            # so a GPU-trained checkpoint can still be opened on a CPU box.
            saved_state = torch.load(checkpoint_path, map_location=DEVICE)
            model.load_state_dict(saved_state["state_dict"])

            # Inference only: switch off training-time behaviour (dropout etc.).
            model.eval()

            print("... Success!")
        except Exception as load_error:
            print(f"!!! ERROR: Failed to load checkpoint. Reason: {load_error}")
            return None
        return model

    print(f"!!! ERROR: Checkpoint file not found at '{checkpoint_path}'.")
    return None

# --- Instantiate and Load the Generators ---
# We only need the generators for evaluation, not the discriminators.

# gen_H translates H&E -> Reticulin
# gen_R translates Reticulin -> H&E
# NOTE(review): num_residuals=9 is hard-coded here; presumably it must match the
# value used when the checkpoints were trained — confirm against the training code.
gen_H = Generator(img_channels=IMG_CHANNELS, num_residuals=9).to(DEVICE)
gen_R = Generator(img_channels=IMG_CHANNELS, num_residuals=9).to(DEVICE)

# Construct the full paths to the checkpoint files to load
# File names follow "<genh|genr>_<run id>_epoch<N>.pth.tar" inside CHECKPOINT_DIR.
checkpoint_path_H = os.path.join(CHECKPOINT_DIR, f"genh_{run_id_string}_epoch{EPOCH_TO_EVALUATE}.pth.tar")
checkpoint_path_R = os.path.join(CHECKPOINT_DIR, f"genr_{run_id_string}_epoch{EPOCH_TO_EVALUATE}.pth.tar")

# Load the weights
# load_generator_checkpoint returns None on failure, so gen_H / gen_R are
# rebound to None if their checkpoint is missing or cannot be loaded.
gen_H = load_generator_checkpoint(gen_H, checkpoint_path_H)
gen_R = load_generator_checkpoint(gen_R, checkpoint_path_R)


# ==============================================================================
# 3. TEST DATASET PREPARATION
# ==============================================================================
# We use the same PairedImageDataset class from training to load the test data.
# We do not need to specify steps_per_epoch, so it will use the full test set.

if gen_H is not None and gen_R is not None:
    # We only need to create the dataset if the models loaded successfully.
    vis_dataset = PairedImageDataset(
        root_H_folder=PATH_TEST_H_FOLDER,
        root_R_folder=PATH_TEST_R_FOLDER,
        domain_name="Evaluation"
    )
    
    # We create a DataLoader for efficient iteration during FID calculation.
    # The batch size can be larger here to speed up image generation.
    # shuffle=False keeps the iteration order fixed between runs.
    vis_dataloader = DataLoader(
        vis_dataset,
        batch_size=32, # Larger batch size for faster inference
        shuffle=False,
        num_workers=0
    )
    print(f"\nTest dataset loaded with {len(vis_dataset)} images.")
    print("--- Cell 2 Model and Data Loading Complete ---")
else:
    # This branch only prints: nothing is raised, and vis_dataset /
    # vis_dataloader are left undefined, so later cells must check gen_H / gen_R
    # themselves before using them.
    print("\n!!! Halting due to model loading failure. Cannot proceed with evaluation. ---")

In [None]:
# ==============================================================================
# Cell 3: Generate Visual Outputs (Gallery and FID Fakes)
# ==============================================================================
# This cell performs two tasks:
# 1. Creates a visual gallery of sample translations.
# 2. Generates a full set of fake images (as .tif) needed for FID calculation.
# ==============================================================================

def denormalize(t):
    """Map a tensor from the [-1, 1] range to a channels-last [0, 1] NumPy array.

    The tensor is moved to the CPU, rescaled with (t + 1) / 2, clipped to
    [0, 1], and its axes are reordered from (0, 1, 2) to (1, 2, 0), i.e. from
    (C, H, W) to (H, W, C) for a single image.
    """
    on_cpu = t.cpu()
    rescaled = (on_cpu + 1.0) / 2.0
    # Clip so that slight overshoot outside [-1, 1] cannot leave the valid range.
    clipped = rescaled.clamp(0.0, 1.0)
    channels_last = clipped.permute(1, 2, 0)
    return channels_last.numpy()

def generate_outputs(gen_H, gen_R, dataloader, num_gallery_images, fid_fake_dir_H, fid_fake_dir_R, gallery_dir):
    """
    Generate a sample gallery (PNG) and full sets of fake images (TIFF) for FID.

    Every batch of the dataset is resized to IMG_SIZE x IMG_SIZE and pushed
    through both generators. Each generated image is written as an 8-bit TIFF.
    For the first `num_gallery_images` batches, the first sample of the batch
    is also saved as one gallery row laid out left to right as:
    real H&E | fake Reticulin | real Reticulin | fake H&E.

    Args:
        gen_H (nn.Module): Generator applied to the H&E batch (produces fake Reticulin).
        gen_R (nn.Module): Generator applied to the Reticulin batch (produces fake H&E).
        dataloader (DataLoader): Only its `.dataset` is used; a fresh, unshuffled
            loader is built internally. A falsy value (None or empty) skips all work.
        num_gallery_images (int): Number of gallery rows to save.
        fid_fake_dir_H (str): Output folder for fake H&E TIFFs.
        fid_fake_dir_R (str): Output folder for fake Reticulin TIFFs.
        gallery_dir (str): Output folder for gallery PNG rows.

    Returns:
        None. All results are written to disk.

    Note:
        Assumes each dataset item is an (H&E, Reticulin) pair of image tensors
        normalized to [-1, 1] — TODO confirm against PairedImageDataset.
    """
    if not dataloader:
        print("!!! Dataloader not available, skipping output generation.")
        return

    gen_H.eval()
    gen_R.eval()

    # Create output directories
    os.makedirs(gallery_dir, exist_ok=True)
    os.makedirs(fid_fake_dir_H, exist_ok=True)
    os.makedirs(fid_fake_dir_R, exist_ok=True)

    # Use a larger batch size for efficient inference
    eval_batch_size = 32
    eval_dataloader = DataLoader(dataloader.dataset, batch_size=eval_batch_size, shuffle=False, num_workers=0)

    # Build the resize transform once instead of once per batch.
    resize = transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True)

    # `imsave` is the deprecated name for `imwrite`; prefer the current one and
    # fall back only for old tifffile releases that lack it.
    write_tiff = tifffile.imwrite if hasattr(tifffile, "imwrite") else tifffile.imsave

    # Running count of images written so far. File numbering derives from this
    # rather than from `batch_index * batch_size`, so it stays correct if the
    # batch size above is changed.
    saved_count = 0

    print("\n--- Generating all visual outputs ---")
    with torch.no_grad():
        for i, (real_H_raw, real_R_raw) in enumerate(tqdm(eval_dataloader, desc="Generating Images")):
            real_H = resize(real_H_raw.to(DEVICE))
            real_R = resize(real_R_raw.to(DEVICE))

            # --- Generate Fake Images ---
            fake_R_batch = gen_H(real_H)  # H&E -> Fake Reticulin
            fake_H_batch = gen_R(real_R)  # Reticulin -> Fake H&E

            # --- Save Full Set for FID Calculation (as .tif) ---
            batch_len = real_H.size(0)
            for j in range(batch_len):
                file_index = saved_count + j

                # Save fake Reticulin image (from real H&E)
                fake_R_np = denormalize(fake_R_batch[j])
                write_tiff(os.path.join(fid_fake_dir_R, f"fake_R_{file_index:05d}.tif"), (fake_R_np * 255).astype(np.uint8))

                # Save fake H&E image (from real Reticulin)
                fake_H_np = denormalize(fake_H_batch[j])
                write_tiff(os.path.join(fid_fake_dir_H, f"fake_H_{file_index:05d}.tif"), (fake_H_np * 255).astype(np.uint8))
            saved_count += batch_len

            # --- Save a few samples for the Visual Gallery (as .png) ---
            if i < num_gallery_images:
                # denormalize() already moves to CPU, permutes to (H, W, C) and
                # returns NumPy arrays, so pass the raw (C, H, W) tensors and
                # join the results with NumPy along the width axis.
                gallery_row = np.concatenate([
                    denormalize(real_H[0]),
                    denormalize(fake_R_batch[0]),
                    denormalize(real_R[0]),
                    denormalize(fake_H_batch[0]),
                ], axis=1)

                plt.imsave(os.path.join(gallery_dir, f"gallery_row_{i+1:02d}.png"), gallery_row)

    print(f"Visual gallery saved in: {gallery_dir}")
    print(f"Fake TIFFs for FID saved in: {fid_fake_dir_H} and {fid_fake_dir_R}")

# --- Define Paths and Execute ---
# All three output folders live under the evaluation directory for this epoch.
GALLERY_OUTPUT_DIR = os.path.join(EVAL_OUTPUT_DIR, "visual_gallery")
FID_FAKE_H_DIR = os.path.join(EVAL_OUTPUT_DIR, "fid_fakes_H_from_R_tif") # Fake H&E
FID_FAKE_R_DIR = os.path.join(EVAL_OUTPUT_DIR, "fid_fakes_R_from_H_tif") # Fake Reticulin

# A generator counts as unavailable when its name was never bound (the loading
# cell was not run) or when it is bound to None (its checkpoint failed to load).
if globals().get('gen_H') is None or globals().get('gen_R') is None:
    print("!!! Models not loaded correctly, skipping output generation.")
else:
    generate_outputs(
        gen_H,
        gen_R,
        vis_dataloader,
        NUM_GALLERY_IMAGES,
        fid_fake_dir_H=FID_FAKE_H_DIR,
        fid_fake_dir_R=FID_FAKE_R_DIR,
        gallery_dir=GALLERY_OUTPUT_DIR,
    )

print("\n--- Cell 3 Image Generation Complete ---")

In [None]:
# ==============================================================================
# Cell 4: Calculate Final Metrics and Generate Report
# ==============================================================================
# This cell uses the pre-generated fake TIFF images and the training log to
# calculate the final FID scores and plot the complete loss curves.
# ==============================================================================

# --- FID Calculation Function ---
def calculate_fid_from_tif_folders(real_tif_folder, fake_tif_folder, domain_label):
    """
    Calculate the FID between a folder of real and a folder of fake TIFF images.

    Args:
        real_tif_folder (str): Directory containing the real `.tif` images.
        fake_tif_folder (str): Directory containing the generated `.tif` images.
        domain_label (str): Human-readable label used only in printed messages.

    Returns:
        float: The FID score, or NaN if either folder is missing or has no
        `.tif` files, or if the metric computation fails. Failures are printed,
        never raised, so callers should test the result with `np.isnan`.
    """
    print(f"\n--- Calculating FID for {domain_label} ---")
    try:
        # Verify that both directories exist and contain .tif files
        if not os.path.isdir(real_tif_folder) or not glob.glob(os.path.join(real_tif_folder, "*.tif")):
            raise FileNotFoundError(f"Real TIF images not found in: {real_tif_folder}")
        if not os.path.isdir(fake_tif_folder) or not glob.glob(os.path.join(fake_tif_folder, "*.tif")):
            raise FileNotFoundError(f"Fake TIF images not found in: {fake_tif_folder}")

        print(f"Comparing real TIFs in '{os.path.basename(real_tif_folder)}' vs fake TIFs in '{os.path.basename(fake_tif_folder)}'")

        # torch-fidelity selects files by extension through `samples_find_ext`
        # (a comma-separated string defaulting to "png,jpg,jpeg"), and the
        # setting applies to both inputs. Without it the TIFF folders would be
        # treated as containing no images.
        metrics = calculate_metrics(
            input1=real_tif_folder,
            input2=fake_tif_folder,
            cuda=(DEVICE == "cuda"),
            isc=False,
            fid=True,
            samples_find_ext="tif",
        )
        fid_score = metrics['frechet_inception_distance']

        print("==========================================")
        print(f"  FID Score for {domain_label}: {fid_score:.4f}")
        print("==========================================")
        return fid_score
    except Exception as e:
        # Deliberate best-effort: report the problem and return NaN so the
        # rest of the evaluation report can still be produced.
        print(f"!!! FID calculation failed for {domain_label}: {e}")
        return float('nan')

# --- Log Parsing and Plotting Functions ---
# parse_log_for_metrics extracts the per-epoch average losses from a training
# log file; plot_final_metrics draws those histories (plus FID) as line charts.
def parse_log_for_metrics(log_filepath):
    """
    Collect per-epoch average losses from a training log file.

    Only lines containing both "Summary. Time:" and "Avgs:" are considered.
    The text between "Avgs:[" and the next "]" is split on whitespace into
    "name:value" tokens, and each name is stored with a "_loss" suffix.

    Args:
        log_filepath (str): Path to the training log.

    Returns:
        dict: Maps each tracked loss name to a list with one float per parsed
        summary line (NaN where that line did not report the loss). A summary
        line that cannot be parsed is skipped entirely. If the file does not
        exist, a message is printed and every list is empty.
    """
    tracked_names = (
        "gen_G_loss",
        "disc_H_loss",
        "disc_R_loss",
        "cycle_H_loss",
        "cycle_R_loss",
        "identity_H_loss",
        "identity_R_loss",
    )
    history = {name: [] for name in tracked_names}

    try:
        with open(log_filepath, 'r') as log_file:
            for raw_line in log_file:
                if "Summary. Time:" not in raw_line or "Avgs:" not in raw_line:
                    continue
                try:
                    bracket_body = raw_line.split("Avgs:[")[1].split("]")[0]
                    epoch_values = {}
                    for token in bracket_body.split():
                        fields = token.split(':')
                        epoch_values[fields[0] + "_loss"] = float(fields[1])
                except (IndexError, ValueError):
                    # Malformed summary line: drop it as a whole so that all
                    # series keep the same length.
                    continue
                for name in tracked_names:
                    history[name].append(epoch_values.get(name, np.nan))
    except FileNotFoundError:
        print(f"!!! Log file not found at {log_filepath}")

    return history
def plot_final_metrics(metrics_history, final_epoch, save_dir):
    """
    Plot the logged training metrics as stacked line charts and save a PNG.

    One subplot is drawn per metric group (generator loss, discriminator
    losses, cycle losses, identity losses, FID). A series is drawn only if it
    exists in `metrics_history` and holds at least one non-NaN value; shorter
    series are right-padded with NaN to the longest series' length.

    Args:
        metrics_history (dict): Maps metric name to a list of per-epoch floats.
        final_epoch (int): Epoch number shown in the title and file name.
        save_dir (str): Existing directory the PNG is written into.

    Returns:
        None. Returns without plotting when no series contains any entries.
    """
    # `default=0` covers the case where every series is empty (or the dict
    # itself is empty); a bare max() would raise ValueError there and the
    # early return below could never be reached.
    num_logged_epochs = max((len(v) for v in metrics_history.values()), default=0)
    if num_logged_epochs == 0:
        return

    epochs = list(range(1, num_logged_epochs + 1))
    plot_groups = [
        {'title': 'Generator Adversarial Loss', 'keys': ['gen_G_loss']},
        {'title': 'Discriminator Loss', 'keys': ['disc_H_loss', 'disc_R_loss']},
        {'title': 'Cycle Consistency Loss', 'keys': ['cycle_H_loss', 'cycle_R_loss']},
        {'title': 'Identity Loss', 'keys': ['identity_H_loss', 'identity_R_loss']},
        {'title': 'Fréchet Inception Distance (FID)', 'keys': ['fid_A_score', 'fid_B_score']},
    ]

    fig, axes = plt.subplots(len(plot_groups), 1, figsize=(14, 7 * len(plot_groups)), sharex=True)
    if len(plot_groups) == 1:
        # plt.subplots returns a bare Axes (not an array) for a single subplot.
        axes = [axes]

    for ax, group in zip(axes, plot_groups):
        has_series = False
        for key in group['keys']:
            series_data = metrics_history.get(key, [])
            if not any(not np.isnan(v) for v in series_data):
                continue
            padded_series = list(series_data) + [np.nan] * (num_logged_epochs - len(series_data))
            ax.plot(epochs, padded_series, marker='o', linestyle='-', label=key)
            has_series = True

        # Tick every epoch for short runs; otherwise epoch 1 and multiples of 5.
        if num_logged_epochs <= 20:
            ax.set_xticks(epochs)
        else:
            ax.set_xticks([e for e in epochs if e % 5 == 0 or e == 1])

        ax.set_title(group['title'], fontsize=14)
        ax.set_ylabel("Loss / Score")
        if has_series:
            # Skip the legend on empty axes to avoid matplotlib's
            # "no artists with labels" warning.
            ax.legend()
        ax.grid(True)

    axes[-1].set_xlabel("Epochs")
    fig.suptitle(f'Final Training Metrics up to Epoch {final_epoch}', fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.97])

    save_path = os.path.join(save_dir, f'FINAL_metrics_graph_epoch_{final_epoch:04d}.png')
    plt.savefig(save_path)
    plt.close(fig)
    print(f"Saved final metrics graph to {save_path}")

# --- Main Execution ---
# Define paths to the TIFF image folders
# (re-declared here so this cell can run without re-running the generation cell)
FID_FAKE_H_DIR = os.path.join(EVAL_OUTPUT_DIR, "fid_fakes_H_from_R_tif") # Fake H&E
FID_FAKE_R_DIR = os.path.join(EVAL_OUTPUT_DIR, "fid_fakes_R_from_H_tif") # Fake Reticulin

# Calculate FID scores (each is NaN if its calculation failed)
fid_H_score = calculate_fid_from_tif_folders(PATH_TEST_H_FOLDER, FID_FAKE_H_DIR, "Generated H&E")
fid_R_score = calculate_fid_from_tif_folders(PATH_TEST_R_FOLDER, FID_FAKE_R_DIR, "Generated Reticulin")

# Find, parse, and plot
# Sorting by name and taking the last entry selects the lexicographically
# greatest log file for this run ID.
log_files = sorted(glob.glob(os.path.join(NOTEBOOK_CWD, "logs", f"train_log_{run_id_string}_*.log")))
if log_files:
    latest_log_file = log_files[-1]
    final_metrics_history = parse_log_for_metrics(latest_log_file)
    num_epochs = len(final_metrics_history['gen_G_loss'])

    # FID is known only for the evaluated epoch; every other epoch stays NaN.
    final_metrics_history['fid_A_score'] = [np.nan] * num_epochs
    final_metrics_history['fid_B_score'] = [np.nan] * num_epochs

    # Use the configured EPOCH_TO_EVALUATE (the name EPOCH_EVALUATED was never
    # defined). The lower bound stops epoch 0 or a negative epoch from
    # silently indexing from the end of the list.
    if 1 <= EPOCH_TO_EVALUATE <= num_epochs:
        if not np.isnan(fid_H_score):
            final_metrics_history['fid_A_score'][EPOCH_TO_EVALUATE - 1] = fid_H_score
        if not np.isnan(fid_R_score):
            final_metrics_history['fid_B_score'][EPOCH_TO_EVALUATE - 1] = fid_R_score

    plot_final_metrics(final_metrics_history, num_epochs, EVAL_OUTPUT_DIR)
else:
    print("!!! Log file not found. Cannot plot metrics.")

print("\n--- Evaluation Complete ---")