In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import math
from tqdm import tqdm 

In [None]:
import torch
import piq
# Quality metrics run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Root folder of the input images to enhance (searched recursively).
PREPROCESSED_DIR = "../ProcessedImages"
# Root folder for AMSR results; later cells mirror each input's path relative to PREPROCESSED_DIR.
OUTPUT_DIR = "../AMSRResult"
# Gaussian surround scales for the three SSR components, ordered small, medium, large.
SIGMA_LIST = [15, 80, 250]

# Create the output root once up front; exist_ok makes re-running this cell harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# --- Fungsi Helper ---
def plot_comparison(img_original, img_processed, title_original="Original", title_processed="Processed"):
    """
    Display two BGR images side by side in one figure.

    Each image is converted BGR -> RGB for matplotlib, titled, and shown without axes.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    panels = ((img_original, title_original), (img_processed, title_processed))
    for panel_ax, (bgr_img, panel_title) in zip(axes, panels):
        panel_ax.imshow(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
        panel_ax.set_title(panel_title)
        panel_ax.axis("off")
    plt.tight_layout()
    plt.show()

def load_image_paths(directory, extensions=(".jpg", ".jpeg", ".png", ".bmp")):
    """
    Recursively collect image file paths under `directory`.

    Args:
        directory: root folder to walk (subfolders included).
        extensions: accepted file extensions including the leading dot; matching is
            case-insensitive. Defaults to the original .jpg/.jpeg/.png/.bmp set.

    Returns:
        A sorted list of os.path.join(root, file) paths. os.walk yields entries in an
        arbitrary, filesystem-dependent order; sorting makes the result (and thus the
        "first sample" later cells pick with [0]) reproducible.
    """
    # os.walk silently yields nothing for a missing directory; make that visible.
    if not os.path.isdir(directory):
        print(f"Peringatan: direktori tidak ditemukan: {directory}")
    valid_extensions = {ext.lower() for ext in extensions}
    image_paths = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(directory)
        for file in files
        if os.path.splitext(file)[1].lower() in valid_extensions
    )
    print(f"Total {len(image_paths)} gambar ditemukan.")
    return image_paths

In [None]:
# --- Implementasi Algoritma Retinex (Inti) ---

def single_scale_retinex(image, sigma):
    """
    Single-Scale Retinex (SSR) in the log10 domain.

    The illumination is estimated with a Gaussian blur of standard deviation `sigma`
    (kernel size derived by OpenCV from sigma). The result is
    log10(I + 1) - log10(blur(I) + 1), returned as a float array.
    """
    as_float = image.astype(float)
    surround = cv2.GaussianBlur(as_float, (0, 0), sigma)
    # The +1 offset keeps log10 finite for zero-valued pixels.
    return np.log10(as_float + 1.0) - np.log10(surround + 1.0)

def get_adaptive_weights(image):
    """
    Choose the [small, medium, large] scale weights from the image's mean gray level.

    - mean < 60 (dark): [0.4, 0.2, 0.4]
    - mean < 120:       [0.25, 0.5, 0.25]
    - otherwise:        equal weights of 1/3

    `image` is converted with COLOR_BGR2GRAY, so a BGR color image is expected.
    """
    brightness = np.mean(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
    if brightness < 60:
        return [0.4, 0.2, 0.4]
    if brightness < 120:
        return [0.25, 0.5, 0.25]
    return [1/3.0, 1/3.0, 1/3.0]

def normalize_and_convert(msr_log_image):
    """
    Min-max stretch each channel of a float (log-domain) image to [0, 255] and cast to uint8.

    Channels are stretched independently. Values are truncated toward zero by the cast, and
    each channel's maximum maps exactly to 255.

    A channel with no spread (max == min, or NaN) carries no detail and is returned as zeros.
    Casting its raw log values to uint8 instead would be undefined for negative or
    out-of-range floats.

    Args:
        msr_log_image: HxWxC or HxW numeric array.

    Returns:
        uint8 array with the same shape as the input.
    """
    is_2d = msr_log_image.ndim == 2
    # Give 2-D input a channel axis so the per-channel loop below handles both cases.
    stack = msr_log_image[:, :, np.newaxis] if is_2d else msr_log_image
    msr_8bit = np.zeros(stack.shape, dtype=np.uint8)
    for i in range(stack.shape[2]):
        channel = stack[:, :, i]
        min_val, max_val = np.min(channel), np.max(channel)
        if max_val > min_val:
            msr_8bit[:, :, i] = ((channel - min_val) / (max_val - min_val) * 255.0).astype(np.uint8)
        # else: leave the flat channel at 0.
    return msr_8bit[:, :, 0] if is_2d else msr_8bit

def apply_amsr(image, sigmas=SIGMA_LIST):
    """
    Adaptive Multi-Scale Retinex (AMSR).

    Computes one SSR output per scale in sigmas[0], sigmas[1], sigmas[2] (small, medium,
    large). It blends them with the brightness-dependent weights from get_adaptive_weights,
    then stretches the blend to a uint8 image via normalize_and_convert.
    """
    scale_weights = get_adaptive_weights(image)
    per_scale = [single_scale_retinex(image, sigmas[k]) for k in range(3)]
    blended_log = (per_scale[0] * scale_weights[0]
                   + per_scale[1] * scale_weights[1]
                   + per_scale[2] * scale_weights[2])
    return normalize_and_convert(blended_log)

In [None]:
# --- Fungsi Evaluasi (tahan banting: coba PIQ, lalu fallback ke imquality) ---
def _safe_cvt_color(image, code):
    """cv2.cvtColor that returns None instead of raising (e.g. for unsupported channel counts)."""
    try:
        return cv2.cvtColor(image, code)
    except Exception:
        return None


def _to_unit_tensor(array):
    """
    Convert an HxW or HxWxC array to a 1xCxHxW float tensor divided by 255, on `device`.

    Dividing by 255 assumes 8-bit input. Returns None if `array` is None or the conversion fails.
    """
    if array is None:
        return None
    try:
        t = torch.from_numpy(array)
        t = t.unsqueeze(0) if t.ndim == 2 else t.permute(2, 0, 1)
        return t.unsqueeze(0).float().to(device) / 255.0
    except Exception:
        return None


def _piq_score(metric_name, tensors):
    """
    Return the first score a PIQ implementation of `metric_name` produces, else None.

    Tries the top-level API (piq.<name>) and then piq.functional.<name>. Each is tried on
    every non-None tensor in `tensors`, in order.
    """
    implementations = []
    for namespace in (piq, getattr(piq, 'functional', None)):
        fn = getattr(namespace, metric_name, None)
        if callable(fn) and fn not in implementations:
            implementations.append(fn)
    for fn in implementations:
        for t in tensors:
            if t is None:
                continue
            try:
                with torch.no_grad():
                    return float(fn(t, data_range=1.0).cpu().item())
            except Exception:
                # Best effort: unsupported shape/device etc. -> try the next option.
                continue
    return None


def evaluate_quality(image):
    """
    Compute no-reference quality scores for a BGR image (lower is better for both).

    Returns:
        (brisque_score, niqe_score): each a float, or None if no backend could score it.

    Strategy per metric: PIQ (top-level API, then piq.functional), then imquality if it
    is installed. This is best effort by design: backend errors are swallowed and only
    surface as None.
    """
    rgb = _safe_cvt_color(image, cv2.COLOR_BGR2RGB)
    gray = _safe_cvt_color(image, cv2.COLOR_BGR2GRAY)
    rgb_t = _to_unit_tensor(rgb)

    # 1) BRISQUE: PIQ on the RGB tensor, then imquality on the RGB array.
    score_b = _piq_score('brisque', (rgb_t,))
    if score_b is None and rgb is not None:
        try:
            import imquality.brisque as brisque_lib
            score_b = float(brisque_lib.score(rgb))
        except Exception:
            score_b = None

    # 2) NIQE: PIQ on the RGB tensor, then on a grayscale tensor (this fallback was previously
    #    unreachable), then imquality on the grayscale array.
    score_n = _piq_score('niqe', (rgb_t, _to_unit_tensor(gray)))
    if score_n is None and gray is not None:
        try:
            import imquality.niqe as niqe_lib
            score_n = float(niqe_lib.score(gray))
        except Exception:
            score_n = None
    return score_b, score_n

# --- 1. Proses 1 Contoh Gambar ---
print("--- Memproses 1 Contoh Gambar dengan AMSR ---")
preprocessed_paths = load_image_paths(PREPROCESSED_DIR)

if not preprocessed_paths:
    print(f"Tidak ada gambar ditemukan di {PREPROCESSED_DIR}.")
else:
    sample_path = preprocessed_paths[0]
    input_image = cv2.imread(sample_path)

    if input_image is None:
        # cv2.imread signals an unreadable file by returning None, not by raising.
        print(f"Gagal membaca gambar: {sample_path}")
    else:
        print(f"Menerapkan AMSR pada: {sample_path}")
        enhanced_image = apply_amsr(input_image)
        plot_comparison(input_image, enhanced_image, "Input (Denoised)", "Hasil AMSR")

        # Mirror the input's relative location under OUTPUT_DIR.
        relative_path = os.path.relpath(sample_path, PREPROCESSED_DIR)
        output_path = os.path.join(OUTPUT_DIR, relative_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # cv2.imwrite reports failure through its return value, not an exception.
        if cv2.imwrite(output_path, enhanced_image):
            print(f"Hasil disimpan di: {output_path}")
        else:
            print(f"Gagal menyimpan hasil ke: {output_path}")

        # --- Quantitative evaluation for the single sample ---
        print("\n--- Evaluasi Kuantitatif (BRISQUE & NIQE) untuk 1 Contoh ---")
        score_input_b, score_input_n = evaluate_quality(input_image)
        score_enhanced_b, score_enhanced_n = evaluate_quality(enhanced_image)

        # Report each metric independently: either backend may be unavailable (None).
        for label, score in (
            ("Skor BRISQUE (Input):", score_input_b),
            ("Skor BRISQUE (AMSR): ", score_enhanced_b),
            ("Skor NIQE (Input):", score_input_n),
            ("Skor NIQE (AMSR): ", score_enhanced_n),
        ):
            shown = f"{score:.2f}" if score is not None else "N/A"
            print(f"{label} {shown} (Lebih rendah lebih baik)")

In [None]:
# --- 1. Proses 1 Contoh Gambar ---
# NOTE(review): this cell repeats the previous sample-processing cell; consider keeping only one.
print("--- Memproses 1 Contoh Gambar dengan AMSR ---")
preprocessed_paths = load_image_paths(PREPROCESSED_DIR)

if not preprocessed_paths:
    print(f"Tidak ada gambar ditemukan di {PREPROCESSED_DIR}.")
else:
    sample_path = preprocessed_paths[0]
    input_image = cv2.imread(sample_path)

    if input_image is None:
        # cv2.imread signals an unreadable file by returning None, not by raising.
        print(f"Gagal membaca gambar: {sample_path}")
    else:
        print(f"Menerapkan AMSR pada: {sample_path}")
        enhanced_image = apply_amsr(input_image)
        plot_comparison(input_image, enhanced_image, "Input (Denoised)", "Hasil AMSR")

        # Mirror the input's relative location under OUTPUT_DIR.
        relative_path = os.path.relpath(sample_path, PREPROCESSED_DIR)
        output_path = os.path.join(OUTPUT_DIR, relative_path)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # cv2.imwrite reports failure through its return value, not an exception.
        if cv2.imwrite(output_path, enhanced_image):
            print(f"Hasil disimpan di: {output_path}")
        else:
            print(f"Gagal menyimpan hasil ke: {output_path}")

        # --- Quantitative evaluation ---
        print("\n--- Evaluasi Kuantitatif (BRISQUE & NIQE) untuk 1 Contoh ---")
        score_input_b, score_input_n = evaluate_quality(input_image)
        score_enhanced_b, score_enhanced_n = evaluate_quality(enhanced_image)

        # evaluate_quality returns None for a metric with no working backend. Formatting None
        # with :.2f raised TypeError here before, so fall back to "N/A".
        for label, score in (
            ("Skor BRISQUE (Input):", score_input_b),
            ("Skor BRISQUE (AMSR): ", score_enhanced_b),
            ("Skor NIQE (Input):", score_input_n),
            ("Skor NIQE (AMSR): ", score_enhanced_n),
        ):
            shown = f"{score:.2f}" if score is not None else "N/A"
            print(f"{label} {shown} (Lebih rendah lebih baik)")

In [None]:
# --- 2. Batch Processing (Proses & Evaluasi Semua Gambar) ---
print("\n--- Menjalankan Batch Processing & Evaluasi (Semua Gambar) ---")

if not preprocessed_paths:
    print("Tidak ada gambar untuk diproses. Evaluasi batch dilewati.")
else:
    # Parallel per-image score lists plus one CSV row per evaluated image. An image is
    # recorded only when both BRISQUE scores exist; its NIQE entries may still be None.
    scores_input_b, scores_input_n = [], []
    scores_enhanced_b, scores_enhanced_n = [], []
    rows = []

    for path in tqdm(preprocessed_paths, desc="Menerapkan AMSR ke semua gambar"):
        try:
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals an unreadable file by returning None, not by raising.
                tqdm.write(f"Gagal membaca {path}, dilewati.")
                continue

            enhanced_img = apply_amsr(img)

            # Mirror the input's relative location under OUTPUT_DIR.
            relative_path = os.path.relpath(path, PREPROCESSED_DIR)
            output_path = os.path.join(OUTPUT_DIR, relative_path)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            # cv2.imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(output_path, enhanced_img):
                tqdm.write(f"Gagal menyimpan {output_path}")

            # Evaluate and collect scores.
            score_in_b, score_in_n = evaluate_quality(img)
            score_out_b, score_out_n = evaluate_quality(enhanced_img)

            if score_in_b is not None and score_out_b is not None:
                scores_input_b.append(score_in_b)
                scores_input_n.append(score_in_n)
                scores_enhanced_b.append(score_out_b)
                scores_enhanced_n.append(score_out_n)
                rows.append([relative_path, score_in_b, score_in_n, score_out_b, score_out_n])

        except Exception as e:
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(f"Error memproses {path}: {e}")

    # --- Evaluation summary ---
    print("\n--- Rangkuman Hasil Evaluasi Kuantitatif (BRISQUE & NIQE) ---")
    if scores_input_b:
        avg_input_b = np.mean(scores_input_b)
        avg_enhanced_b = np.mean(scores_enhanced_b)

        # NIQE may be None when no backend is available, and np.mean on a list containing
        # None raises TypeError. Average only images where both NIQE scores exist, so the
        # two means cover the same images.
        niqe_pairs = [
            (n_in, n_out)
            for n_in, n_out in zip(scores_input_n, scores_enhanced_n)
            if n_in is not None and n_out is not None
        ]
        if niqe_pairs:
            avg_input_n = np.mean([n_in for n_in, _ in niqe_pairs])
            avg_enhanced_n = np.mean([n_out for _, n_out in niqe_pairs])
        else:
            avg_input_n = avg_enhanced_n = None

        print(f"Total gambar dievaluasi: {len(scores_input_b)}")
        print("\n--- Rata-rata Skor BRISQUE --- (Lebih rendah lebih baik)")
        print(f"Input (Asli):   {avg_input_b:.2f}")
        print(f"Hasil (AMSR):   {avg_enhanced_b:.2f}")

        print("\n--- Rata-rata Skor NIQE --- (Lebih rendah lebih baik)")
        if niqe_pairs:
            print(f"Input (Asli):   {avg_input_n:.2f}")
            print(f"Hasil (AMSR):   {avg_enhanced_n:.2f}")
        else:
            print("Input (Asli):   N/A")
            print("Hasil (AMSR):   N/A")

        brisque_improved = avg_enhanced_b < avg_input_b
        if niqe_pairs:
            if brisque_improved and avg_enhanced_n < avg_input_n:
                print("\nKesimpulan: Kualitas perseptual dan statistik rata-rata MENINGKAT.")
            else:
                print("\nKesimpulan: Perlu dicek, salah satu atau kedua metrik tidak membaik.")
        elif brisque_improved:
            # Without NIQE, only the BRISQUE improvement can be claimed.
            print("\nKesimpulan: Skor BRISQUE rata-rata MEMBAIK (NIQE tidak tersedia).")
        else:
            print("\nKesimpulan: Perlu dicek, skor BRISQUE rata-rata tidak membaik (NIQE tidak tersedia).")
    else:
        print("Tidak ada gambar yang berhasil dievaluasi.")

    # --- Export per-image results to CSV (None scores become empty cells) ---
    try:
        import csv
        csv_path = os.path.join(OUTPUT_DIR, 'evaluation_summary_amsr.csv')
        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['relative_path', 'input_brisque', 'input_niqe', 'enhanced_brisque', 'enhanced_niqe'])
            writer.writerows(rows)
        print(f"CSV evaluasi disimpan di: {csv_path}")
    except Exception as e:
        print(f"Gagal menyimpan CSV: {e}")