# In [None]:
import cv2
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import traceback

# --- Import from your project structure ---
try:
    from distortions.down_sampling import apply_DownSampling
    from restorations.traditional.Down_Sampling.bicubic_interpolation import apply_bicubic_interpolation
    from restorations.deep_learning.esrgan_model import SuperResolutionModel
    print("All custom modules imported successfully.")
except ImportError as e:
    # Execution continues after this message; any later use of the names
    # above will then raise NameError.
    print(f"ERROR: Could not import modules. {e}")
    print("Please ensure 'restorations/traditional/Down_Sampling/bicubic_interpolation.py' "
          "and 'restorations/deep_learning/esrgan_model.py' exist.")
    print("And ensure 'apply_DownSampling' is in 'distortions/down_sampling.py'.")

# --- Import Metric Functions (from skimage) ---
try:
    from skimage.metrics import structural_similarity as ssim_sk
    print("Metrics library (skimage) imported.")
except ImportError:
    print("ERROR: skimage not found. Please run: pip install scikit-image")


# ===================================================================
# === METRIC FUNCTIONS (Added directly to the notebook) ===
# ===================================================================

def calculate_psnr(img1, img2):
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Both arrays are promoted to float64 before differencing, so uint8 inputs
    cannot wrap around. The peak value is fixed at 255 (8-bit range).
    Pixel-identical inputs give ``inf``.
    """
    peak = 255.0
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(peak / np.sqrt(mse))

def calculate_ssim(img1, img2):
    """Return the structural similarity (SSIM) between two images.

    2-D (grayscale) inputs use scikit-image's default window. Other inputs
    are treated as channel-last arrays (``channel_axis=2``) and use a 7x7
    window. ``data_range`` is fixed at 255, i.e. 8-bit-range images are
    assumed.
    """
    if img1.ndim == 2:
        return ssim_sk(img1, img2, data_range=255.0)
    # ``channel_axis`` supersedes the ``multichannel`` flag (deprecated in
    # scikit-image 0.19); passing both was redundant and only produced
    # deprecation warnings. win_size=7 matches common SSIM implementations.
    return ssim_sk(img1, img2, channel_axis=2, data_range=255.0, win_size=7)

def calculate_metrics_pair(original, restored):
    """
    Calculate PSNR and SSIM of ``restored`` against ``original``.

    If the two images differ in height/width, ``restored`` is resized to the
    size of ``original`` with ``cv2.INTER_AREA`` before scoring. A warning
    is printed in that case: a mismatch likely means the restoration did not
    upscale by the expected factor, and the scores are then computed on a
    resampled image rather than the method's raw output.

    Returns:
        tuple: ``(psnr, ssim)`` from ``calculate_psnr`` / ``calculate_ssim``.
    """
    h, w = original.shape[:2]
    if restored.shape[:2] != (h, w):
        print(f"  - Warning: restored size {restored.shape[1]}x{restored.shape[0]} "
              f"differs from original {w}x{h}; resizing before scoring.")
        # cv2.resize takes (width, height).
        restored = cv2.resize(restored, (w, h), interpolation=cv2.INTER_AREA)

    psnr = calculate_psnr(original, restored)
    ssim_val = calculate_ssim(original, restored)

    return psnr, ssim_val

print("All metric calculation functions are defined.")

# ===================================================================
# === CONFIGURATION ===
# ===================================================================

# --- Model Parameters ---
# Upscaling factor shared by the downsampling step, the bicubic baseline and
# the super-resolution model.
SCALE_FACTOR = 2
# NOTE(review): the file name says "x4" while SCALE_FACTOR is 2. A previous
# note claimed these weights are actually an x2 model; confirm this. If the
# network really upscales x4, its output no longer matches the original's
# size, calculate_metrics_pair resizes it, and the reported scores are
# misleading.
MODEL_FILE_NAME = 'RRDB_ESRGAN_x4.pth'
MODEL_PATH = os.path.join('models', MODEL_FILE_NAME)

# --- Data / output locations ---
# NOTE: Relative paths resolve against the current working directory (for the
# notebook, normally its root folder).
ORIGINAL_DIR = "data/original"  # Ground-truth (full-resolution) images.
# Both folders are derived from SCALE_FACTOR. generate_downsampled_images
# skips files that already exist, so a hard-coded folder name would silently
# reuse low-res images of a different scale after SCALE_FACTOR changes.
DISTORTED_DIR = f"data/distorted/downsampled_x{SCALE_FACTOR}"  # LOW-RES images
RESULTS_DIR = f"results/for_Downsampling_x{SCALE_FACTOR}"

# --- Visualization Parameters ---
# Files from ORIGINAL_DIR used for the qualitative comparison figures.
IMAGES_FOR_VISUALIZATION = ["0267.png", "0060.png", "0268.png"]
# Close-up crop per image as (x, y, width, height), in pixels of the
# full-resolution image.
CROP_REGIONS = {
    "0267.png": (1178, 623, 128, 128),
    "0060.png": (1026, 659, 128, 128),
    "0268.png": (1045, 562, 128, 128),
}

# ===================================================================
# === ANALYSIS & PLOTTING FUNCTIONS ===
# ===================================================================

def setup_directories():
    """Create the low-res image folder and the results folder if missing."""
    targets = (DISTORTED_DIR, RESULTS_DIR)
    print(f"Ensuring directories exist: {DISTORTED_DIR}, {RESULTS_DIR}")
    for folder in targets:
        os.makedirs(folder, exist_ok=True)

def generate_downsampled_images(image_files):
    """Create low-resolution copies of ``image_files`` in ``DISTORTED_DIR``.

    Each file is read from ``ORIGINAL_DIR`` and cropped so that both sides
    are a multiple of ``SCALE_FACTOR``. It is then downsampled with
    ``apply_DownSampling`` (on an RGB copy) and written back as BGR.
    Files already present in ``DISTORTED_DIR`` are left untouched.

    Args:
        image_files: File names relative to ``ORIGINAL_DIR``.

    Unreadable inputs and failed writes are reported and counted instead of
    being skipped silently.
    """
    print(f"Checking/Generating downsampled images with scale={SCALE_FACTOR}...")
    created = 0
    failed = 0

    for name in image_files:
        img_path = os.path.join(ORIGINAL_DIR, name)
        filename = os.path.basename(img_path)
        save_path = os.path.join(DISTORTED_DIR, filename)

        if os.path.exists(save_path):
            continue

        img = cv2.imread(img_path)
        if img is None:
            print(f"Warning: Could not read {img_path}, skipping.")
            failed += 1
            continue

        # Trim edge rows/columns so the low-res image maps back exactly.
        h, w = img.shape[:2]
        h_new = (h // SCALE_FACTOR) * SCALE_FACTOR
        w_new = (w // SCALE_FACTOR) * SCALE_FACTOR
        if h != h_new or w != w_new:
            print(f"  - Cropping {filename} from {h}x{w} to {h_new}x{w_new}")
            img = img[:h_new, :w_new]

        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        lr_image = apply_DownSampling(img_rgb, scale_factor=SCALE_FACTOR)
        # cv2.imwrite reports many failures (bad path, permissions) by
        # returning False rather than raising.
        if not cv2.imwrite(save_path, cv2.cvtColor(lr_image, cv2.COLOR_RGB2BGR)):
            print(f"Warning: Could not write {save_path}.")
            failed += 1
            continue
        created += 1

    if created > 0:
        print(f"Generated {created} new downsampled (low-res) images.")
    elif failed == 0:
        print("Downsampled (low-res) images already exist.")
    if failed > 0:
        print(f"Warning: {failed} image(s) could not be read or written.")

def run_full_analysis(image_files, model):
    """Score every restoration method on every image.

    For each file name, the ground truth is loaded from ``ORIGINAL_DIR`` and
    the low-resolution version from ``DISTORTED_DIR``. The ground truth is
    cropped to a multiple of ``SCALE_FACTOR``. Three reconstructions are
    then scored against it:

    * ``Distorted``    - nearest-neighbour resize of the low-res image
    * ``Bicubic``      - ``apply_bicubic_interpolation``
    * ``DeepLearning`` - ``model.upscale``

    Args:
        image_files: File names expected in both data directories.
        model: Object with an ``upscale(rgb_image)`` method whose result is
            treated as RGB.

    Returns:
        pandas.DataFrame with one row per loadable image and the columns
        ``Image``, then ``PSNR_<method>`` / ``SSIM_<method>`` per method.
        Images that cannot be loaded are skipped with a message.
    """
    total = len(image_files)
    rows = []

    for idx, filename in enumerate(image_files, start=1):
        print(f"\nProcessing metrics for image {idx}/{total}: {filename}")

        gt_bgr = cv2.imread(os.path.join(ORIGINAL_DIR, filename))
        lr_bgr = cv2.imread(os.path.join(DISTORTED_DIR, filename))
        if gt_bgr is None or lr_bgr is None:
            print(f"  - Skipping {filename}, could not load images.")
            continue

        # Mirror the crop applied when the low-res images were generated.
        keep_h = (gt_bgr.shape[0] // SCALE_FACTOR) * SCALE_FACTOR
        keep_w = (gt_bgr.shape[1] // SCALE_FACTOR) * SCALE_FACTOR
        gt_bgr = gt_bgr[:keep_h, :keep_w]

        lr_rgb = cv2.cvtColor(lr_bgr, cv2.COLOR_BGR2RGB)
        out_size = (gt_bgr.shape[1], gt_bgr.shape[0])  # cv2 wants (width, height)

        # Values are evaluated top to bottom, so the methods run in the same
        # order as before; everything is converted to BGR to match gt_bgr.
        candidates = {
            'Distorted': cv2.resize(lr_bgr, out_size, interpolation=cv2.INTER_NEAREST),
            'Bicubic': cv2.cvtColor(
                apply_bicubic_interpolation(lr_rgb, scale_factor=SCALE_FACTOR),
                cv2.COLOR_RGB2BGR,
            ),
            'DeepLearning': cv2.cvtColor(model.upscale(lr_rgb), cv2.COLOR_RGB2BGR),
        }

        row = {'Image': filename}
        for method, candidate in candidates.items():
            # Size mismatches are handled inside calculate_metrics_pair.
            psnr_val, ssim_val = calculate_metrics_pair(gt_bgr, candidate)
            row[f'PSNR_{method}'] = psnr_val
            row[f'SSIM_{method}'] = ssim_val
            print(f"  - {method}: PSNR={psnr_val:.2f}, SSIM={ssim_val:.4f}")

        rows.append(row)

    return pd.DataFrame(rows)

def plot_and_save_results(df):
    """Print the metrics table, save it as CSV, and save a summary chart.

    Both files go to ``RESULTS_DIR`` with ``SCALE_FACTOR`` in their names.
    The chart shows mean PSNR per method as bars (left axis) and mean SSIM
    per method as a line (right axis).

    Args:
        df: Table as built by ``run_full_analysis`` - an ``Image`` column plus
            ``PSNR_<method>`` / ``SSIM_<method>`` columns. Missing methods
            are plotted as NaN.
    """
    csv_path = os.path.join(RESULTS_DIR, f"metrics_scale{SCALE_FACTOR}.csv")
    chart_path = os.path.join(RESULTS_DIR, f"metrics_chart_scale{SCALE_FACTOR}.png")

    print("\n--- FINAL METRICS TABLE ---")
    print(df.round(4).to_string(index=False))
    df.to_csv(csv_path, index=False)
    print(f"\nResults table saved to {csv_path}")

    # calculate_psnr returns +inf for pixel-identical images. Drop those from
    # the chart average so a single perfect match cannot make a bar infinite
    # (the CSV above keeps the raw values).
    avg_psnr = df.filter(like='PSNR').replace([np.inf, -np.inf], np.nan).mean(skipna=True)
    avg_ssim = df.filter(like='SSIM').mean(skipna=True)

    methods = [
        ('Distorted', 'Distorted (Nearest)'),
        ('Bicubic', 'Traditional (Bicubic)'),
        ('DeepLearning', 'Deep Learning (ESRGAN)'),
    ]
    plot_df = pd.DataFrame({
        'Method': [label for _, label in methods],
        'PSNR': [avg_psnr.get(f'PSNR_{key}', np.nan) for key, _ in methods],
        'SSIM': [avg_ssim.get(f'SSIM_{key}', np.nan) for key, _ in methods],
    })

    positions = np.arange(len(plot_df))
    bar_width = 0.4

    fig, ax1 = plt.subplots(figsize=(12, 7))
    # Bars sit just left of each tick (the placement pandas' position=1 gave).
    psnr_bars = ax1.bar(positions - bar_width / 2, plot_df['PSNR'], width=bar_width, color='skyblue')
    ax1.set_xlabel('Method')
    ax1.set_ylabel('PSNR (dB) - Higher is Better', color='skyblue')
    ax1.tick_params(axis='y', labelcolor='skyblue')

    ax2 = ax1.twinx()
    (ssim_line,) = ax2.plot(positions, plot_df['SSIM'], color='salmon', marker='o')
    ax2.set_ylabel('SSIM - Higher is Better', color='salmon')
    ax2.tick_params(axis='y', labelcolor='salmon')

    ssim_min = plot_df['SSIM'].min(skipna=True)
    if np.isnan(ssim_min):
        ax2.set_ylim(bottom=0, top=1.0)
    else:
        ax2.set_ylim(bottom=max(0, ssim_min - 0.05), top=1.0)

    # Set the category labels after both series are drawn, so nothing plotted
    # on the twin axis can replace them.
    ax1.set_xticks(positions)
    ax1.set_xticklabels(plot_df['Method'], rotation=15)

    # Pass handles explicitly. With labels alone, matplotlib pairs them with
    # the first artists it finds (the PSNR bar patches), so the "SSIM" entry
    # would show a bar swatch instead of the line.
    fig.legend(handles=[psnr_bars, ssim_line], labels=['PSNR (dB)', 'SSIM'],
               loc='upper center', bbox_to_anchor=(0.5, 0.95), ncol=2)
    ax1.set_title(f'Average Restoration Performance (Downsampling x{SCALE_FACTOR})', pad=40)
    fig.tight_layout()
    fig.savefig(chart_path)
    print(f"Comparison chart saved to {chart_path}")
    plt.close(fig)

def generate_close_up_plot(original_img, distorted_img_view, traditional_img, dl_img, crop_region, save_path_prefix):
    """Save a 1x4 figure comparing the same crop of four images.

    Args:
        original_img: Ground-truth image (the caller passes RGB).
        distorted_img_view: Low-res image resized back to full size.
        traditional_img: Bicubic reconstruction.
        dl_img: Deep-learning reconstruction.
        crop_region: ``(x, y, width, height)`` in pixels of ``original_img``.
        save_path_prefix: Output path without suffix;
            ``_closeup_scale<SCALE_FACTOR>.png`` is appended.

    Nothing is plotted (a warning is printed instead) in three cases:
    the crop is empty; it has a negative origin or extends past the
    original; or any of the other images differs in size from the original.
    In the last case the same pixel coordinates would show different
    content.
    """
    x, y, w, h = crop_region
    h_orig, w_orig = original_img.shape[:2]

    if w <= 0 or h <= 0:
        print(f"   - Warning: Crop region {crop_region} is empty. Skipping close-up.")
        return
    # Negative offsets must be rejected explicitly: Python slicing would wrap
    # them around to the opposite edge instead of failing.
    if x < 0 or y < 0 or y + h > h_orig or x + w > w_orig:
        print(f"   - Warning: Crop region {crop_region} is outside image bounds ({w_orig}x{h_orig}). Skipping close-up.")
        return

    # All inputs should be RGB and share the original's size.
    images_to_crop = [original_img, distorted_img_view, traditional_img, dl_img]
    other_sizes = {img.shape[:2] for img in images_to_crop[1:]} - {(h_orig, w_orig)}
    if other_sizes:
        sizes = ", ".join(f"{sw}x{sh}" for sh, sw in sorted(other_sizes))
        print(f"   - Warning: Image sizes {sizes} differ from the original ({w_orig}x{h_orig}); "
              f"crops would not align. Skipping close-up.")
        return

    cropped_images = [img[y:y + h, x:x + w] for img in images_to_crop]
    titles = ["Original (Close-up)", "Distorted (Nearest)", "Traditional (Bicubic)", "Deep Learning (ESRGAN)"]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f"Close-up Comparison @ ({x},{y}) [{os.path.basename(save_path_prefix)}]", fontsize=16)

    for ax, img, title in zip(axes, cropped_images, titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.92])
    save_path = f"{save_path_prefix}_closeup_scale{SCALE_FACTOR}.png"
    plt.savefig(save_path)
    plt.close(fig)
    print(f"   - Close-up plot saved to {save_path}")

def visualize_comparison(images, distortion_fn, traditional_fn, deep_learning_fn, save_path=None):
    """Draw a grid with one row per image and four result columns.

    Columns are:
    1. the original;
    2. the low-res image shown via nearest-neighbour upscaling;
    3. the ``traditional_fn`` output;
    4. the ``deep_learning_fn`` output.

    Each image is read from disk, cropped to a multiple of ``SCALE_FACTOR``
    and converted to RGB before ``distortion_fn`` is applied. A file that
    cannot be read gets a row of black placeholders.

    Args:
        images: Paths of the images to show; an empty sequence does nothing.
        distortion_fn: Maps an RGB image to its low-res version.
        traditional_fn: Maps a low-res image to a restored image.
        deep_learning_fn: Maps a low-res image to a restored image.
        save_path: If given, the figure is saved there and closed;
            otherwise ``plt.show()`` is called.
    """
    n_rows = len(images)
    if n_rows == 0:
        return

    row_height = 6 if n_rows < 3 else 4
    # squeeze=False keeps ``axes`` 2-D even for a single row.
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, row_height * n_rows), squeeze=False)
    fig.suptitle("Image Super-Resolution Comparison", fontsize=16)

    headers = ("Original", "Distorted (Nearest)", "Traditional (Bicubic)", "Deep Learning (ESRGAN)")
    for col, header in enumerate(headers):
        axes[0, col].set_title(header, fontsize=12)

    for row_axes, img_path in zip(axes, images):
        bgr = cv2.imread(img_path)
        if bgr is None:
            print(f"Warning: Could not read image at {img_path} for viz. Skipping row.")
            placeholder = np.zeros((100, 100, 3), dtype=np.uint8)
            for ax in row_axes:
                ax.imshow(placeholder)
                ax.axis("off")
            continue

        keep_h = (bgr.shape[0] // SCALE_FACTOR) * SCALE_FACTOR
        keep_w = (bgr.shape[1] // SCALE_FACTOR) * SCALE_FACTOR
        rgb = cv2.cvtColor(bgr[:keep_h, :keep_w], cv2.COLOR_BGR2RGB)

        print(f"\nProcessing {os.path.basename(img_path)} for main visualization...")

        low_res = distortion_fn(rgb)
        panels = (
            rgb,
            cv2.resize(low_res, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST),
            traditional_fn(low_res),
            deep_learning_fn(low_res),
        )
        for ax, panel in zip(row_axes, panels):
            ax.imshow(panel)
            ax.axis("off")

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    if not save_path:
        plt.show()
        return
    plt.savefig(save_path)
    print(f"--> Main comparison plot SAVED to: {save_path}")
    plt.close(fig)

# ===================================================================
# === MAIN EXECUTION ===
# ===================================================================

# Module-level driver: `model`, `metrics_df` and the other variables assigned
# below remain defined as globals after this cell runs.
print("--- Analysis Script Started ---")
try:
    setup_directories()

    if not os.path.exists(MODEL_PATH):
        print(f"FATAL ERROR: Model file not found at {MODEL_PATH}")
        print(f"Please make sure '{MODEL_FILE_NAME}' is in the 'models/' directory.")
    else:
        print(f"Found model file: {MODEL_PATH}")
        print("Loading Super-Resolution model...")
        model = SuperResolutionModel(model_path=MODEL_PATH, scale=SCALE_FACTOR)
        print("Model loaded successfully.")

        # os.listdir order is arbitrary, so sort for reproducible processing
        # and CSV row order. Extensions are matched case-insensitively so
        # files such as "IMG.PNG" are not silently ignored.
        all_image_files = sorted(
            f for f in os.listdir(ORIGINAL_DIR)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        if not all_image_files:
            print(f"Error: No images found in '{ORIGINAL_DIR}'.")
        else:
            generate_downsampled_images(all_image_files)
            metrics_df = run_full_analysis(all_image_files, model)

            if not metrics_df.empty:
                plot_and_save_results(metrics_df)
            else:
                print("Metrics DataFrame is empty. Cannot plot results.")

            print("\n--- Generating Visual Plots ---")
            viz_paths = [os.path.join(ORIGINAL_DIR, f) for f in IMAGES_FOR_VISUALIZATION if f in all_image_files]

            if viz_paths:
                main_viz_save_path = os.path.join(RESULTS_DIR, f"main_comparison_scale{SCALE_FACTOR}.png")
                print(f"\nGenerating and saving full-size comparison plot to {main_viz_save_path}...")

                visualize_comparison(
                    images=viz_paths,
                    distortion_fn=lambda img: apply_DownSampling(img, scale_factor=SCALE_FACTOR),
                    traditional_fn=lambda lr_img: apply_bicubic_interpolation(lr_img, scale_factor=SCALE_FACTOR),
                    deep_learning_fn=lambda lr_img: model.upscale(lr_img),
                    save_path=main_viz_save_path
                )

                print("\nGenerating close-up comparison plots...")
                for img_path in viz_paths:
                    filename = os.path.basename(img_path)
                    if filename not in CROP_REGIONS:
                        print(f"   - Warning: No crop region defined for {filename}. Skipping close-up.")
                        continue

                    print(f"   - Creating close-up for {filename}")
                    original_img = cv2.imread(img_path)
                    if original_img is None:
                        continue

                    # Same multiple-of-SCALE_FACTOR crop as the metrics pass.
                    h, w = original_img.shape[:2]
                    h_new = (h // SCALE_FACTOR) * SCALE_FACTOR
                    w_new = (w // SCALE_FACTOR) * SCALE_FACTOR
                    original_img = original_img[:h_new, :w_new]
                    original_rgb = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)

                    distorted_lr = apply_DownSampling(original_rgb, scale_factor=SCALE_FACTOR)
                    distorted_view = cv2.resize(distorted_lr, (original_rgb.shape[1], original_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
                    trad_restored = apply_bicubic_interpolation(distorted_lr, scale_factor=SCALE_FACTOR)
                    dl_restored = model.upscale(distorted_lr)

                    close_up_save_prefix = os.path.join(RESULTS_DIR, os.path.splitext(filename)[0])
                    crop_region = CROP_REGIONS[filename]

                    generate_close_up_plot(
                        original_img=original_rgb,
                        distorted_img_view=distorted_view,
                        traditional_img=trad_restored,
                        dl_img=dl_restored,
                        crop_region=crop_region,
                        save_path_prefix=close_up_save_prefix
                    )
            else:
                print("Skipping visual comparisons as no valid images were selected or found.")

except Exception as e:
    # Top-level boundary: report any failure with a full traceback so the
    # final "Finished" line below is still printed.
    print("\n--- AN ERROR OCCURRED ---")
    print(f"An error occurred during model loading or analysis: {e}")
    traceback.print_exc()

print("\n--- Analysis Script Finished ---")