# Deep Analysis of Autoencoder Performance (L2 + Cosine + PSNR/SSIM)

This notebook analyzes the performance of different autoencoder models trained with varying latent and base channel configurations. We evaluate the models based on several metrics:

* **L2 Distance (MSE):** The mean of the squared pixel-wise differences between the original and reconstructed images. Lower is better.
* **Cosine Similarity:** The cosine of the angle between the original and reconstructed images, each flattened to a single vector. Higher is better, with a maximum of 1.
* **Peak Signal-to-Noise Ratio (PSNR):** The ratio, in decibels, between the maximum possible signal power and the power of the reconstruction error. Higher is better.
* **Structural Similarity Index (SSIM):** Measures the similarity between two images, considering luminance, contrast, and structure. Higher is better.
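
To make the first three definitions concrete, here is a minimal NumPy sketch that computes them by hand on a toy image. The helper names (`mse`, `cosine_similarity`, `psnr`) are illustrative and are not the functions the notebook calls later; SSIM is left out because its windowed computation is better delegated to `skimage`.

```python
import numpy as np

def mse(a, b):
    """Mean squared error over all pixels and channels."""
    return float(np.mean((a - b) ** 2))

def cosine_similarity(a, b, eps=1e-8):
    """Cosine of the angle between the two flattened images."""
    a, b = a.ravel(), b.ravel()
    denom = max(np.linalg.norm(a) * np.linalg.norm(b), eps)
    return float(np.dot(a, b) / denom)

def psnr(a, b, data_range=1.0):
    """PSNR in dB: 10 * log10(data_range^2 / MSE)."""
    err = mse(a, b)
    if err == 0:
        return float("inf")  # identical images
    return float(10 * np.log10(data_range ** 2 / err))

# Toy 2x2 RGB image in [0, 1] and a copy brightened by a constant 0.1
original = np.full((2, 2, 3), 0.4)
reconstruction = original + 0.1

print(mse(original, reconstruction))                # about 0.01
print(psnr(original, reconstruction))               # about 20 dB
print(cosine_similarity(original, reconstruction))  # about 1.0
```

The toy case also shows why the metrics complement each other: a uniform brightness shift leaves the cosine similarity at essentially 1 while MSE and PSNR both register the error.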

The analysis consists of the following sections:
1.  **Model and Dataset Definitions:** The core `Encoder`, `Decoder`, and `ImageDirectoryDataset` classes.
2.  **Configuration and Data Loading:** Setting up paths, parameters, and loading the analysis dataset.
3.  **Deep Analysis:** Iterating through trained models, generating reconstructions, and calculating the performance metrics.
4.  **Quantitative Visualization:** Creating bar plots to compare the average performance of all models across all metrics.
5.  **Qualitative Visualization:** Generating visual comparisons of original vs. reconstructed images for each model.

## SECTION 1: Core Imports and Definitions

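First, we import the necessary libraries. We use `torch` for model loading and inference, `PIL` for image handling, `numpy` and `pandas` for numerical operations and result tables, and `matplotlib`/`seaborn` for plotting. `tqdm` provides progress bars and `natsort` sorts the run directories in natural order. We also import `peak_signal_noise_ratio` and `structural_similarity` from `skimage.metrics` to compute PSNR and SSIM.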

In [None]:
!pip install scikit-image natsort seaborn tqdm


In [None]:
# --- Core Imports ---
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import random
import os
import re
import pandas as pd
from pathlib import Path

# --- Analysis and Plotting Imports ---
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from natsort import natsorted
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# --- Set a consistent plotting style ---
sns.set_theme(style="whitegrid")

### Model and Dataset Classes

The `Encoder`, `Decoder`, and `ImageDirectoryDataset` classes are imported from `modules.encode_decoder`. These are the same definitions used by the training script, which ensures the model architecture matches the saved checkpoints.

In [None]:
from modules.encode_decoder import Encoder, Decoder, ImageDirectoryDataset

## SECTION 2: Configuration & Data Loading

We define the configuration parameters for the analysis. This includes the project directory, the experiment group to analyze, the directory with the source images, and other parameters. We then create a `DataLoader` for the analysis images.

In [None]:
BASE_PROJECT_DIR = Path("C:/Users/Hagai.LAPTOP-QAG9263N/Desktop/Thesis/notebooks")
EXPERIMENT_GROUP = 'latent_sweep'
SOURCE_DATA_DIR = BASE_PROJECT_DIR / "recolored_images"
NUM_SAMPLES_TO_ANALYZE = 500
BATCH_SIZE = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 1024
CHECKPOINT_PATH = BASE_PROJECT_DIR / "checkpoints" / "checkpoints" / EXPERIMENT_GROUP
PLOTS_DIR = BASE_PROJECT_DIR / "plots"

print(f"Using device: {DEVICE}")
print(f"Analyzing experiment group: '{EXPERIMENT_GROUP}'")
if not CHECKPOINT_PATH.exists():
    print(f"!!! FATAL ERROR: Checkpoint directory not found at {CHECKPOINT_PATH}")
    run_dirs = []
else:
    print(f"Loading models from: {CHECKPOINT_PATH}")
    run_dirs = natsorted([d.name for d in CHECKPOINT_PATH.iterdir() if d.is_dir()])

if not SOURCE_DATA_DIR.exists():
    print(f"!!! WARNING: Source data directory for analysis not found at {SOURCE_DATA_DIR}")

PLOTS_DIR.mkdir(parents=True, exist_ok=True)
print(f"Plots will be saved to: {PLOTS_DIR}")

analysis_dataloader = None
try:
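    # Sort first: iterdir() order is OS-dependent, which would make the seeded sample non-reproducible
    all_files = sorted(f.name for f in SOURCE_DATA_DIR.iterdir() if f.is_file() and f.suffix.lower() in ['.png', '.jpg', '.jpeg'])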
    if not all_files:
        raise FileNotFoundError(f"No images found in {SOURCE_DATA_DIR}")
    random.seed(42)
    analysis_files = random.sample(all_files, min(len(all_files), NUM_SAMPLES_TO_ANALYZE))
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    analysis_dataset = ImageDirectoryDataset(SOURCE_DATA_DIR, analysis_files, transform)
    analysis_dataloader = DataLoader(analysis_dataset, batch_size=BATCH_SIZE, shuffle=False)
    print(f"Created analysis dataset with {len(analysis_dataset)} images.")
except Exception as e:
    print(f"Could not create dataset. Please check the SOURCE_DATA_DIR path. Error: {e}")
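
The `Normalize(mean=0.5, std=0.5)` step in the transform maps pixel values from [0, 1] to [-1, 1], so anything that expects [0, 1] images (such as PSNR and SSIM with `data_range=1.0`) needs the inverse mapping first. A minimal NumPy sketch of that round trip; the helper names `normalize` and `unnormalize` are illustrative and not part of the notebook:

```python
import numpy as np

def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as Normalize(mean, std): [0, 1] -> [-1, 1]."""
    return (x - mean) / std

def unnormalize(x, mean=0.5, std=0.5):
    """Inverse mapping, clipped so out-of-range model outputs stay in [0, 1]."""
    return np.clip(x * std + mean, 0.0, 1.0)

pixels = np.array([0.0, 0.25, 0.5, 1.0])
normalized = normalize(pixels)      # [-1.0, -0.5, 0.0, 1.0]
restored = unnormalize(normalized)  # back to the original values

# A decoder output slightly outside [-1, 1] is clipped after un-normalizing
overshoot = unnormalize(np.array([1.2, -1.3]))  # [1.0, 0.0]
```

Note that `x * 0.5 + 0.5` and `(x + 1) / 2` are the same mapping, so either form can be used when preparing images for the metrics.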

## SECTION 3: Deep Analysis with L2, Cosine, PSNR, and SSIM

This is the main analysis loop. For each model found in the checkpoints directory:
1.  We parse the latent and base channels from the directory name.
2.  We load the corresponding trained `Encoder` and `Decoder`.
3.  We iterate through the analysis `DataLoader`, generate reconstructions for each image, and calculate L2 distance, Cosine Similarity, PSNR, and SSIM.
4.  We store the average of these metrics for each model in a list.

Finally, we create a Pandas DataFrame from the results for easy plotting and inspection.
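
Step 1 relies on each run directory name containing a `latent<N>_base<M>` fragment. The sketch below isolates that parsing step with the same regular expression used in the loop; the function name `parse_run_name` and the example directory names are hypothetical.

```python
import re

def parse_run_name(run_name):
    """Return (latent_channels, base_channels), or None if the name does not match."""
    match = re.search(r"latent(\d+)_base(\d+)", run_name)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))

print(parse_run_name("latent8_base32"))          # (8, 32)
print(parse_run_name("run_latent16_base64_v2"))  # (16, 64)
print(parse_run_name("base32_latent8"))          # None
```

Because `re.search` matches anywhere in the string, prefixes and suffixes around the fragment are tolerated, but the order `latent` then `base` is required.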

In [None]:
analysis_results = []
for run_name in tqdm(run_dirs, desc="Analyzing All Runs"):
    checkpoint_run_path = CHECKPOINT_PATH / run_name
    encoder_path = checkpoint_run_path / "encoder_final.pt"
    decoder_path = checkpoint_run_path / "decoder_final.pt"

    if not encoder_path.exists() or not decoder_path.exists():
        print(f"Skipping '{run_name}': Missing 'encoder_final.pt' or 'decoder_final.pt'.")
        continue

    try:
        match = re.search(r"latent(\d+)_base(\d+)", run_name)
        if not match: raise ValueError("Directory name format is incorrect.")
        lc, bc = int(match.group(1)), int(match.group(2))
    except (ValueError, IndexError):
        print(f"Skipping '{run_name}': Could not parse latent/base channels from name.")
        continue

    try:
        encoder = Encoder(base_channels=bc, latent_channels=lc).to(DEVICE)
        decoder = Decoder(base_channels=bc, latent_channels=lc).to(DEVICE)
        encoder.load_state_dict(torch.load(encoder_path, map_location=DEVICE))
        decoder.load_state_dict(torch.load(decoder_path, map_location=DEVICE))
        encoder.eval()
        decoder.eval()
    except Exception as e:
        print(f"Skipping '{run_name}': ARCHITECTURE MISMATCH or other loading error. Details: {e}")
        continue

    all_l2, all_cosine, all_psnr, all_ssim = [], [], [], []
    if analysis_dataloader:
        with torch.no_grad():
            for images in tqdm(analysis_dataloader, desc=f"Run: base={bc}, lat={lc}", leave=False):
                images = images.to(DEVICE)
                recon_images = decoder(encoder(images))
                
                # --- Metric Calculations ---
                l2_dist = nn.functional.mse_loss(recon_images, images, reduction='none').mean(dim=[1,2,3])
                cosine_sim = nn.functional.cosine_similarity(recon_images.view(images.size(0), -1), images.view(images.size(0), -1), dim=1)
                all_l2.extend(l2_dist.cpu().numpy())
                all_cosine.extend(cosine_sim.cpu().numpy())
                
                # For PSNR and SSIM, we need to un-normalize the images from [-1, 1] to [0, 1]
                images_un = (images + 1) / 2
                recon_images_un = (recon_images + 1) / 2
                
                for i in range(images.size(0)):
                    img_orig = images_un[i].permute(1, 2, 0).cpu().numpy()
                    img_recon = recon_images_un[i].permute(1, 2, 0).cpu().numpy()
                    
                    # Ensure pixel values are clipped to [0, 1] range for metrics
                    img_orig = np.clip(img_orig, 0, 1)
                    img_recon = np.clip(img_recon, 0, 1)

                    psnr_val = peak_signal_noise_ratio(img_orig, img_recon, data_range=1.0)
                    # channel_axis=-1 marks the RGB axis; it replaces the deprecated multichannel flag
                    ssim_val = structural_similarity(img_orig, img_recon, data_range=1.0, channel_axis=-1)
                    
                    all_psnr.append(psnr_val)
                    all_ssim.append(ssim_val)

    if all_l2 and all_cosine:
        analysis_results.append({
            "run_id": run_name, "base_channels": bc, "latent_channels": lc,
            "avg_l2_distance": np.mean(all_l2), 
            "avg_cosine_similarity": np.mean(all_cosine),
            "avg_psnr": np.mean(all_psnr),
            "avg_ssim": np.mean(all_ssim)
        })

df_analysis = pd.DataFrame(analysis_results)
if not df_analysis.empty:
    df_analysis = df_analysis.sort_values(by=["base_channels", "latent_channels"]).reset_index(drop=True)

print("\n--- Analysis Complete ---")
if not df_analysis.empty: 
    display(df_analysis)
else: 
    print("No models were successfully analyzed.")

## SECTION 4: Quantitative Visualization

Now we create a 2x2 grid of bar plots to compare the performance of all models. Each plot represents one of our four metrics. This allows for a quick and easy comparison of the different model configurations.

In [None]:
if not df_analysis.empty:
    df_analysis['label'] = df_analysis.apply(
        lambda row: f"Latent={row['latent_channels']}\n(Base={row['base_channels']})", axis=1
    )
    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle('Autoencoder Performance Comparison', fontsize=24)

    # L2 Distance Plot
    sns.barplot(x="label", y="avg_l2_distance", data=df_analysis, ax=axes[0, 0], palette="plasma")
    axes[0, 0].set_title("Average L2 Distance (MSE)", fontsize=16)
    axes[0, 0].set_xlabel("Model Configuration", fontsize=12)
    axes[0, 0].set_ylabel("Average L2 Distance (Lower is Better)", fontsize=12)
    axes[0, 0].tick_params(axis='x', rotation=45)

    # Cosine Similarity Plot
    sns.barplot(x="label", y="avg_cosine_similarity", data=df_analysis, ax=axes[0, 1], palette="viridis")
    axes[0, 1].set_title("Average Cosine Similarity", fontsize=16)
    axes[0, 1].set_xlabel("Model Configuration", fontsize=12)
    axes[0, 1].set_ylabel("Average Cosine Similarity (Higher is Better)", fontsize=12)
    axes[0, 1].tick_params(axis='x', rotation=45)
    
    # PSNR Plot
    sns.barplot(x="label", y="avg_psnr", data=df_analysis, ax=axes[1, 0], palette="cividis")
    axes[1, 0].set_title("Average Peak Signal-to-Noise Ratio (PSNR)", fontsize=16)
    axes[1, 0].set_xlabel("Model Configuration", fontsize=12)
    axes[1, 0].set_ylabel("Average PSNR (Higher is Better)", fontsize=12)
    axes[1, 0].tick_params(axis='x', rotation=45)

    # SSIM Plot
    sns.barplot(x="label", y="avg_ssim", data=df_analysis, ax=axes[1, 1], palette="magma")
    axes[1, 1].set_title("Average Structural Similarity Index (SSIM)", fontsize=16)
    axes[1, 1].set_xlabel("Model Configuration", fontsize=12)
    axes[1, 1].set_ylabel("Average SSIM (Higher is Better)", fontsize=12)
    axes[1, 1].tick_params(axis='x', rotation=45)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    save_path = PLOTS_DIR / f"quantitative_comparison_{EXPERIMENT_GROUP}_with_psnr_ssim.png"
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\nSaved quantitative comparison plot to: {save_path}")
    plt.show()
else:
    print("\nAnalysis DataFrame is empty. Cannot generate quantitative plots.")

## SECTION 5: Qualitative Visualization

Finally, we generate qualitative comparisons for each model. We select a few random sample images once and, for each model, plot every original next to its reconstruction. The title of each reconstruction includes all four metrics (L2, Cosine Similarity, PSNR, and SSIM), computed for that specific image.
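
For the comparison to be fair, every model must be shown the same sample images. One way to guarantee this regardless of how often the cell is re-run is to draw the indices from a dedicated, seeded random generator. This is a sketch of an alternative, not what the cell below does (it uses the global `random` module); the function name `pick_sample_indices` and the default seed are assumptions.

```python
import random

def pick_sample_indices(dataset_size, num_samples, seed=42):
    """Draw distinct indices from a local RNG so the global random state is untouched."""
    rng = random.Random(seed)
    return rng.sample(range(dataset_size), min(num_samples, dataset_size))

first = pick_sample_indices(dataset_size=500, num_samples=4)
second = pick_sample_indices(dataset_size=500, num_samples=4)
print(first == second)  # True: the same seed always yields the same indices
```

Using a local `random.Random` instance means other code that calls `random.sample` or `random.seed` in between cannot change which images are displayed.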

In [None]:
if not df_analysis.empty and analysis_dataloader:
    NUM_VISUAL_SAMPLES = 4 # Show 4 samples per model
    print(f"\nGenerating {len(df_analysis)} qualitative plots with {NUM_VISUAL_SAMPLES} samples each...")

    # Select 4 random sample images ONCE to use for all models for fair comparison
    sample_indices = random.sample(range(len(analysis_dataset)), NUM_VISUAL_SAMPLES)
    sample_images = torch.stack([analysis_dataset[i] for i in sample_indices]).to(DEVICE)

    # Loop through each model in the analysis results
    for idx, row in df_analysis.iterrows():
        # Cast pandas/NumPy scalars to plain ints before passing them to the model constructors
        lc, bc = int(row['latent_channels']), int(row['base_channels'])
        run_id = row['run_id']
        checkpoint_run_path = CHECKPOINT_PATH / run_id

        fig, axes = plt.subplots(NUM_VISUAL_SAMPLES, 2, figsize=(12, 22))
        fig.suptitle(f'Reconstruction Quality for Model: {run_id}\n($C_{{base}}={bc}$, $C_{{lat}}={lc}$)', fontsize=18)

        # Reload the specific model for this row
        encoder = Encoder(base_channels=bc, latent_channels=lc).to(DEVICE)
        decoder = Decoder(base_channels=bc, latent_channels=lc).to(DEVICE)
        encoder.load_state_dict(torch.load(checkpoint_run_path / "encoder_final.pt", map_location=DEVICE))
        decoder.load_state_dict(torch.load(checkpoint_run_path / "decoder_final.pt", map_location=DEVICE))
        encoder.eval()
        decoder.eval()

        with torch.no_grad():
            recon_samples = decoder(encoder(sample_images))

        def prep_img(tensor_img):
            return np.clip(tensor_img.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5, 0, 1)

        for i in range(NUM_VISUAL_SAMPLES):
            original, recon = sample_images[i], recon_samples[i]
            
            # Calculate metrics for this specific sample
            l2 = nn.functional.mse_loss(recon, original).item()
            cosine = nn.functional.cosine_similarity(recon.view(1, -1), original.view(1, -1)).item()
            
            original_un = (original + 1) / 2
            recon_un = (recon + 1) / 2
            img_orig_np = np.clip(original_un.permute(1, 2, 0).cpu().numpy(), 0, 1)
            img_recon_np = np.clip(recon_un.permute(1, 2, 0).cpu().numpy(), 0, 1)
            
            psnr_val = peak_signal_noise_ratio(img_orig_np, img_recon_np, data_range=1.0)
            ssim_val = structural_similarity(img_orig_np, img_recon_np, data_range=1.0, channel_axis=-1)

            axes[i, 0].imshow(prep_img(original))
            axes[i, 0].set_title(f"Original Sample #{i+1}")
            axes[i, 0].axis('off')

            axes[i, 1].imshow(prep_img(recon))
            axes[i, 1].set_title(f"Reconstruction\nL2: {l2:.4f} | CosSim: {cosine:.4f}\nPSNR: {psnr_val:.2f} | SSIM: {ssim_val:.4f}")
            axes[i, 1].axis('off')

        plt.tight_layout(rect=[0, 0, 1, 0.95])
        save_path = PLOTS_DIR / f"qualitative_recon_{run_id}_with_psnr_ssim.png"
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved qualitative plot to: {save_path}")
        plt.show()
        plt.close(fig)

else:
    print("\nAnalysis DataFrame is empty or the analysis dataset is unavailable. Cannot generate qualitative comparisons.")