In [1]:
import cv2
import os
import pandas as pd
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import torch

# importing the necessary functions
from restorations.traditional.Gaussian_Noise import apply_median_filter, apply_non_local_means
from restorations.deep_learning import apply_deep_learning_denoiser
from distortions.gaussian_noise import apply_GaussianNoise

# --- configuration ---
ORIGINAL_DIR = "data/original"
DISTORTED_DIR = "data/distorted/gaussian_noise"
RESULTS_DIR = "results/for_GaussianNoise"
NOISE_SIGMA = 50
IMAGES_FOR_VISUALIZATION = ["0267.png", "0060.png", "0268.png"]

CROP_REGIONS = {
    "0267.png": (1178, 623, 128, 128),
    "0060.png": (1026, 659, 128, 128),
    "0268.png": (1045, 562, 128, 128)
}

# === Function Definitions ===

#   Visualizer Function
def visualize_comparison(images, distortion_fn, traditional_fn, deep_learning_fn, save_path=None):
    num_images = len(images)
    if num_images == 0:
        print("No images provided for visualization.")
        return

    fig, axes = plt.subplots(num_images, 4, figsize=(16, 4 * num_images))
    if num_images < 3: fig.set_size_inches(16, 6 * num_images)
    fig.suptitle("Image Restoration Comparison", fontsize=16)
    col_titles = ["Original", "Distorted", "Traditional (NLM)", "Deep Learning"]
    if num_images == 1: axes = np.array([axes])

    for j, title in enumerate(col_titles):
        axes[0, j].set_title(title, fontsize=12)

    for i, img_path in enumerate(images):
        img = cv2.imread(img_path)
        if img is None:
            print(f"Warning: Could not read image at {img_path} for viz. Skipping row.")
            for j in range(4):
                 axes[i, j].imshow(np.zeros((100, 100, 3), dtype=np.uint8))
                 axes[i, j].axis("off")
            continue

        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        print(f"\nProcessing {os.path.basename(img_path)} for main visualization...")
        distorted = distortion_fn(img_rgb)
        restored_trad = traditional_fn(distorted)
        restored_dl = deep_learning_fn(distorted)
        display_images = [img_rgb, distorted, restored_trad, restored_dl]

        for j, disp_img in enumerate(display_images):
            axes[i, j].imshow(disp_img)
            axes[i, j].axis("off")

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save_path:
        try:
            plt.savefig(save_path)
            print(f"--> Main comparison plot SAVED to: {save_path}")
        except Exception as e:
            print(f"Error saving main comparison plot: {e}")
        plt.close(fig) 
    else:
        print("--> Displaying main comparison plot...")
        plt.show() 

def setup_directories():
    """Creates necessary directories."""
    print(f"Ensuring directories exist: {DISTORTED_DIR}, {RESULTS_DIR}")
    os.makedirs(DISTORTED_DIR, exist_ok=True)
    os.makedirs(RESULTS_DIR, exist_ok=True)

def generate_distorted_images(image_files):
    """Generates and saves noisy images if they don't already exist."""
    print(f"Checking/Generating distorted images with sigma={NOISE_SIGMA}...")
    count = 0
    image_paths = [os.path.join(ORIGINAL_DIR, f) for f in image_files]
    for img_path in image_paths:
        filename = os.path.basename(img_path)
        save_path = os.path.join(DISTORTED_DIR, filename)
        if not os.path.exists(save_path):
            img = cv2.imread(img_path)
            if img is None: continue
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            noisy_image = apply_GaussianNoise(img_rgb, sigma=NOISE_SIGMA)
            cv2.imwrite(save_path, cv2.cvtColor(noisy_image, cv2.COLOR_RGB2BGR))
            count += 1
    if count > 0: print(f"Generated {count} new distorted images.")
    else: print("Distorted images already exist.")

def run_full_analysis(image_files):
    """Calculates PSNR/SSIM for all restoration methods."""
    results = []
    try:
        if apply_deep_learning_denoiser is not None:
             print("Pre-loading DL model...")
             _ = apply_deep_learning_denoiser(np.zeros((64,64,3), dtype=np.uint8), sigma=NOISE_SIGMA)
             print("DL model pre-loaded/checked.")
    except Exception as e: print(f"Warning: Could not pre-load DL model: {e}")

    for i, filename in enumerate(image_files):
        print(f"\nProcessing metrics for image {i+1}/{len(image_files)}: {filename}")
        original_img = cv2.imread(os.path.join(ORIGINAL_DIR, filename))
        distorted_img = cv2.imread(os.path.join(DISTORTED_DIR, filename))
        if original_img is None or distorted_img is None: continue
        distorted_rgb = cv2.cvtColor(distorted_img, cv2.COLOR_BGR2RGB)

        restored_median = apply_median_filter(distorted_rgb)
        restored_nl_means = apply_non_local_means(distorted_rgb)
        restored_dl = apply_deep_learning_denoiser(distorted_rgb, sigma=NOISE_SIGMA)

        metrics_data = {'Image': filename}
        methods = {
            'Distorted': distorted_img, 'Median': cv2.cvtColor(restored_median, cv2.COLOR_RGB2BGR),
            'NL_Means': cv2.cvtColor(restored_nl_means, cv2.COLOR_RGB2BGR), 'DeepLearning': restored_dl
        }
        for name, restored_image in methods.items():
            if restored_image is None or restored_image.shape != original_img.shape:
                 metrics_data[f'PSNR_{name}'] = np.nan; metrics_data[f'SSIM_{name}'] = np.nan
                 continue
            metrics_data[f'PSNR_{name}'] = psnr(original_img, restored_image, data_range=255)
            h, w, _ = original_img.shape; win_size = min(7, h, w); win_size -= (1 - win_size % 2)
            if win_size < 3: win_size = 3
            metrics_data[f'SSIM_{name}'] = ssim(original_img, restored_image, data_range=255, channel_axis=2, win_size=win_size)
        results.append(metrics_data)
    return pd.DataFrame(results)

def plot_and_save_results(df):
    """Prints table and saves quantitative chart."""
    csv_path = os.path.join(RESULTS_DIR, f"metrics_sigma{NOISE_SIGMA}.csv")
    chart_path = os.path.join(RESULTS_DIR, f"metrics_chart_sigma{NOISE_SIGMA}.png")
    print("\n--- FINAL METRICS TABLE ---")
    print(df.round(4).to_string(index=False))
    df.to_csv(csv_path, index=False)
    print(f"\nResults table saved to {csv_path}")

    avg_psnr = df.filter(like='PSNR').mean(skipna=True)
    avg_ssim = df.filter(like='SSIM').mean(skipna=True)
    plot_df = pd.DataFrame({
        'Method': ['Distorted', 'Median Filter', 'Non-local Means', 'Deep Learning'],
        'PSNR': [avg_psnr.get('PSNR_Distorted',np.nan), avg_psnr.get('PSNR_Median',np.nan), avg_psnr.get('PSNR_NL_Means',np.nan), avg_psnr.get('PSNR_DeepLearning',np.nan)],
        'SSIM': [avg_ssim.get('SSIM_Distorted',np.nan), avg_ssim.get('SSIM_Median',np.nan), avg_ssim.get('SSIM_NL_Means',np.nan), avg_ssim.get('SSIM_DeepLearning',np.nan)]
    })

    fig, ax1 = plt.subplots(figsize=(12, 7))
    plot_df.plot(kind='bar', x='Method', y='PSNR', ax=ax1, color='skyblue', position=1, width=0.4, legend=False)
    ax1.set_ylabel('PSNR (dB) - Higher is Better', color='skyblue')
    plt.xticks(rotation=15)
    ax2 = ax1.twinx()
    plot_df.plot(kind='line', x='Method', y='SSIM', ax=ax2, color='salmon', marker='o', legend=False)
    ax2.set_ylabel('SSIM - Higher is Better', color='salmon')
    ssim_min = plot_df['SSIM'].min(skipna=True)
    if not np.isnan(ssim_min): ax2.set_ylim(bottom=max(0, ssim_min - 0.05), top=1.0)
    else: ax2.set_ylim(bottom=0, top=1.0)
    fig.legend(labels=['PSNR (dB)', 'SSIM'], loc='upper center', bbox_to_anchor=(0.5, 0.95), ncol=2)
    plt.title(f'Average Restoration Performance (Gaussian Noise Sigma={NOISE_SIGMA})', pad=40)
    plt.tight_layout()
    plt.savefig(chart_path)
    print(f"Comparison chart saved to {chart_path}")
    plt.close(fig)

def generate_close_up_plot(original_img, distorted_img, traditional_img, dl_img, crop_region, save_path_prefix):
    """Crops the four input images and saves a comparison plot."""
    x, y, w, h = crop_region
    h_orig, w_orig, _ = original_img.shape
    if y + h > h_orig or x + w > w_orig:
        print(f"  - Warning: Crop region {crop_region} is outside image bounds ({w_orig}x{h_orig}). Skipping close-up.")
        return

    images_to_crop = [original_img, distorted_img, traditional_img, dl_img]
    cropped_images = [img[y:y+h, x:x+w] for img in images_to_crop]
    titles = ["Original (Close-up)", "Distorted", "Traditional (NLM)", "Deep Learning"]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f"Close-up Comparison @ ({x},{y}) [{os.path.basename(save_path_prefix)}]", fontsize=16)

    for ax, img, title in zip(axes, cropped_images, titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.92])
    save_path = f"{save_path_prefix}_closeup_sigma{NOISE_SIGMA}.png"
    plt.savefig(save_path)
    plt.close(fig)
    print(f"  - Close-up plot saved to {save_path}")

if __name__ == "__main__":
    setup_directories()

    all_image_files = [f for f in os.listdir(ORIGINAL_DIR) if f.endswith(('.png', '.jpg', '.jpeg'))]
    if not all_image_files:
        print(f"Error: No images found in '{ORIGINAL_DIR}'.")
    else:
        generate_distorted_images(all_image_files)
        metrics_df = run_full_analysis(all_image_files)

        if not metrics_df.empty:
            plot_and_save_results(metrics_df)
        else:
            print("Metrics DataFrame is empty. Cannot plot results.")

        print("\n--- Generating Visual Plots ---")
        viz_paths = [os.path.join(ORIGINAL_DIR, f) for f in IMAGES_FOR_VISUALIZATION if f in all_image_files]
        if viz_paths:
            
            main_viz_save_path = os.path.join(RESULTS_DIR, f"main_comparison_sigma{NOISE_SIGMA}.png")
            print(f"\nGenerating and saving full-size comparison plot to {main_viz_save_path}...")
            
            visualize_comparison(
                images=viz_paths,
                distortion_fn=lambda img: apply_GaussianNoise(img, sigma=NOISE_SIGMA),
                traditional_fn=apply_non_local_means,
                deep_learning_fn=lambda img: apply_deep_learning_denoiser(img, sigma=NOISE_SIGMA),
                save_path=main_viz_save_path # Pass the save path
            )

            print("\nGenerating close-up comparison plots...")
            for img_path in viz_paths:
                filename = os.path.basename(img_path)
                if filename not in CROP_REGIONS:
                    print(f"  - Warning: No crop region defined for {filename}. Skipping close-up.")
                    continue

                print(f"  - Creating close-up for {filename}")
                original_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
                distorted_rgb = apply_GaussianNoise(original_rgb, sigma=NOISE_SIGMA)
                trad_restored = apply_non_local_means(distorted_rgb)
                dl_restored = apply_deep_learning_denoiser(distorted_rgb, sigma=NOISE_SIGMA)

                close_up_save_prefix = os.path.join(RESULTS_DIR, os.path.splitext(filename)[0])
                crop_region = CROP_REGIONS[filename]

                generate_close_up_plot(
                    original_img=original_rgb, distorted_img=distorted_rgb,
                    traditional_img=trad_restored, dl_img=dl_restored,
                    crop_region=crop_region, save_path_prefix=close_up_save_prefix
                )
        else:
            print("Skipping visual comparisons as no valid images were selected or found.")

    print("\n--- Analysis Script Finished ---")

Ensuring directories exist: data/distorted/gaussian_noise, results/for_GaussianNoise
Checking/Generating distorted images with sigma=50...
Distorted images already exist.
Pre-loading DL model...
Loading deep learning model (DRUNet Color)...
Downloading pre-trained model weights to restorations/deep_learning/models\drunet_color.pth...
Download complete.
Deep learning model loaded successfully.
DL model pre-loaded/checked.

Processing metrics for image 1/3: 0060.png

Processing metrics for image 2/3: 0267.png

Processing metrics for image 3/3: 0268.png

--- FINAL METRICS TABLE ---
   Image  PSNR_Distorted  SSIM_Distorted  PSNR_Median  SSIM_Median  PSNR_NL_Means  SSIM_NL_Means  PSNR_DeepLearning  SSIM_DeepLearning
0060.png         14.8762          0.1975      22.0417       0.4394        23.6842         0.7204            18.6585             0.8359
0267.png         14.8403          0.0972      24.7496       0.4348        26.5076         0.7778            10.4654             0.7824
0268.png 