In [None]:
import cv2
import numpy as np
import os
from PIL import Image, ImageEnhance
from io import BytesIO
import base64
from IPython.display import display, HTML

# --------------------------------------
# Utility: Display images in Notebook (HTML)
# --------------------------------------
def pil_to_base64_html(img, width=400):
    """Render an image as an inline ``<img>`` tag backed by a base64 PNG data URI.

    Args:
        img: object with a PIL-style ``save(fp, format=...)`` method
            (normally a ``PIL.Image.Image``).
        width: display width in pixels.

    Returns:
        An HTML ``<img>`` element as a string.
    """
    with BytesIO() as buffer:
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    # The HTML width attribute takes a bare pixel count; "400px" is not a valid value.
    return f'<img src="data:image/png;base64,{encoded}" width="{width}" style="margin:5px;"/>'

# ==============================
# Enhancement Pipeline Components
# ==============================

# ---- Vignette Detection ----
def detect_vignette(img_bgr, threshold=0.65):
    """Heuristically decide whether an image shows vignetting (dark corners).

    The mean brightness of four corner patches (each 1/5 of the height and
    width) is compared with the mean brightness of the central third of the
    frame.

    Args:
        img_bgr: BGR image, e.g. as returned by ``cv2.imread``.
        threshold: corner/center brightness ratio below which the image is
            treated as vignetted.

    Returns:
        Truthy when the corners are darker than ``threshold`` x the center.
    """
    luma = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255
    height, width = luma.shape

    center = luma[height // 3:2 * height // 3, width // 3:2 * width // 3]
    top, bottom = slice(0, height // 5), slice(4 * height // 5, height)
    left, right = slice(0, width // 5), slice(4 * width // 5, width)
    # Order: top-left, top-right, bottom-left, bottom-right.
    corner_means = [np.mean(luma[rows, cols])
                    for rows in (top, bottom)
                    for cols in (left, right)]

    # Small epsilon avoids dividing by zero on an all-black center.
    ratio = np.mean(corner_means) / (np.mean(center) + 1e-8)
    return ratio < threshold

def remove_vignette(img_bgr):
    """Brighten image borders by dividing out a separable Gaussian fall-off.

    The mask is the outer product of two 1-D Gaussian kernels (sigma = half
    the height / width), normalised so its peak is 1. The center is therefore
    left unchanged, and the corners are boosted by roughly 1/e^-1 (about 2.7x)
    before clipping.

    Args:
        img_bgr: uint8 image, either 2-D grayscale or H x W x C. Only the first
            (up to) three channels are corrected; any extra channel such as
            alpha is copied through unchanged.

    Returns:
        Corrected uint8 image with the same shape as the input.
    """
    rows, cols = img_bgr.shape[:2]
    kernel_x = cv2.getGaussianKernel(cols, cols / 2)
    kernel_y = cv2.getGaussianKernel(rows, rows / 2)
    kernel = kernel_y * kernel_x.T
    mask = kernel / kernel.max()

    if img_bgr.ndim == 2:
        result = img_bgr / mask
    else:
        # Broadcast the 2-D mask over the colour channels. The previous
        # np.empty_like + range(3) left any 4th channel uninitialised.
        result = img_bgr.astype(np.float64)
        n_color = min(3, img_bgr.shape[2])
        result[..., :n_color] /= mask[..., np.newaxis]

    return np.clip(result, 0, 255).astype(np.uint8)


# ---- Dehaze ----
def dark_channel(img, size=15):
    """Compute the dark channel of an image (dark-channel prior).

    Takes the per-pixel minimum over the first three channels, then erodes it
    with a ``size`` x ``size`` rectangular window. Each pixel becomes the
    minimum over its neighbourhood.

    Args:
        img: array with at least three channels on the last axis (uint8 or float).
        size: side length of the square erosion window.

    Returns:
        2-D array with the same dtype as ``img``.
    """
    blue, green, red = img[..., 0], img[..., 1], img[..., 2]
    per_pixel_min = cv2.min(cv2.min(blue, green), red)
    window = cv2.getStructuringElement(cv2.MORPH_RECT, (size, size))
    return cv2.erode(per_pixel_min, window)

def estimate_atmospheric_light(img, dark, top_percent=0.001):
    """Estimate the atmospheric light as the mean colour of the haziest pixels.

    Args:
        img: H x W x 3 image on the same grid as ``dark``.
        dark: H x W dark-channel map.
        top_percent: fraction of pixels (those with the largest dark-channel
            values) to average; at least one pixel is always used.

    Returns:
        Length-3 array with the per-channel mean of the selected pixels, in
        the same value scale as ``img``.
    """
    height, width = dark.shape
    n_pixels = height * width
    n_selected = max(1, int(n_pixels * top_percent))

    pixels = img.reshape(n_pixels, 3)
    # argpartition puts the n_selected largest dark-channel values at the end.
    haziest = np.argpartition(dark.ravel(), -n_selected)[-n_selected:]
    return pixels[haziest].mean(axis=0)

def estimate_transmission(img, atmospheric_light, omega=0.95, size=15):
    """Coarse transmission map from the dark-channel prior: t = 1 - omega * dark(I / A).

    Args:
        img: H x W x 3 float image.
        atmospheric_light: length-3 atmospheric light, expected on the same
            value scale as ``img``.
        omega: scales the haze estimate; values below 1 keep some haze.
        size: erosion window size forwarded to ``dark_channel``.

    Returns:
        H x W transmission map.
    """
    # A zero channel in the atmospheric light (e.g. a purely red scene) would
    # otherwise divide by zero and fill the map with inf/NaN.
    safe_light = np.maximum(atmospheric_light, 1e-6)
    norm_img = img / safe_light
    dark = dark_channel(norm_img, size)
    return 1 - omega * dark

def guided_filter_gray(I, p, r, eps):
    """Guided filter with a single-channel guide image.

    This was a line-for-line copy of ``guided_filter``. It now delegates so the
    two cannot drift apart. The name is kept for existing callers.

    Args:
        I: float32 guide image.
        p: float32 image to filter.
        r: box-filter kernel size (used as ``(r, r)``).
        eps: regularisation added to the local variance of ``I``.

    Returns:
        Filtered float32 image.
    """
    return guided_filter(I, p, r, eps)

def dehaze(img_bgr, r=60, eps=1e-3, omega=0.95):
    """Remove haze using the dark-channel prior with guided-filter refinement.

    Args:
        img_bgr: uint8 BGR image.
        r: box-filter kernel size for the guided filter that refines the
            transmission map.
        eps: guided-filter regularisation.
        omega: scales the haze estimate; values below 1 keep some haze.

    Returns:
        Dehazed uint8 BGR image.
    """
    img = img_bgr.astype(np.float32) / 255
    dark = dark_channel(img_bgr)
    # The atmospheric light is estimated on the 0-255 image, so bring it onto
    # the same [0, 1] scale as ``img``. Feeding the raw 0-255 value to the
    # transmission estimate makes img / A tiny and the transmission ~1
    # everywhere, which turns dehazing into a no-op.
    A = (estimate_atmospheric_light(img_bgr, dark) / 255.0).astype(np.float32)
    transmission = estimate_transmission(img, A, omega=omega)

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255
    transmission_refined = guided_filter_gray(gray, transmission.astype(np.float32), r, eps)
    # Lower bound keeps the recovery division stable in dense-haze regions.
    transmission_refined = np.clip(transmission_refined, 0.1, 1)

    # Scene radiance recovery J = (I - A) / t + A, broadcast over the channels.
    J = (img - A) / transmission_refined[..., np.newaxis] + A
    J = np.clip(J, 0, 1)
    return (J * 255).astype(np.uint8)

def finishing_touches_pipeline(img_bgr):
    """Apply vignette removal and then dehazing to a BGR image."""
    devignetted = remove_vignette(img_bgr)
    return dehaze(devignetted)

# ---- Dull/Gentle Enhancement ----
def is_image_dull(image_path, brightness_thresh=110, contrast_thresh=50):
    """Classify an image as dull unless it is both bright and contrasty enough.

    Args:
        image_path: path (str or path-like) to an image file, or an
            already-loaded image array. The array may be 2-D grayscale,
            H x W x 1, or BGR/BGRA as returned by ``cv2.imread``.
        brightness_thresh: the mean gray level must exceed this to count as bright.
        contrast_thresh: the gray-level standard deviation must exceed this to
            count as contrasty.

    Returns:
        True when the image is dull, False otherwise.

    Raises:
        ValueError: if a path is given and the file cannot be decoded.
    """
    if isinstance(image_path, np.ndarray):
        img = image_path
    else:
        img = cv2.imread(os.fspath(image_path))
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise ValueError(f"Could not read image: {image_path!r}")

    if img.ndim == 2:
        gray = img
    elif img.shape[2] == 1:
        gray = img[..., 0]
    else:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    mean_brightness = float(np.mean(gray))
    contrast = float(np.std(gray))
    bright_and_contrasty = mean_brightness > brightness_thresh and contrast > contrast_thresh
    return not bright_and_contrasty

def gentle_enhance(pil_img):
    """Light contrast, sharpness and brightness boost for images that are not dull.

    Args:
        pil_img: PIL image.

    Returns:
        New enhanced PIL image.
    """
    steps = (
        (ImageEnhance.Contrast, 1.11),
        (ImageEnhance.Sharpness, 1.65),
        (ImageEnhance.Brightness, 1.01),
    )
    for enhancer_cls, factor in steps:
        pil_img = enhancer_cls(pil_img).enhance(factor)
    return pil_img

def dull_enhance(pil_img):
    """Stronger sharpening plus slight contrast and brightness lift for dull images.

    Args:
        pil_img: PIL image.

    Returns:
        New enhanced PIL image.
    """
    sharpened = ImageEnhance.Sharpness(pil_img).enhance(2.0)
    contrasted = ImageEnhance.Contrast(sharpened).enhance(1.01)
    return ImageEnhance.Brightness(contrasted).enhance(1.01)

# ---- White Balance / Color Adjust ----
def find_dominant_colors(img_bgr, k=4):
    """Cluster pixel colours with k-means and rank the clusters by size.

    Args:
        img_bgr: 3-channel image; reshaped into a list of (B, G, R) samples.
        k: requested number of clusters. Clamped to the number of pixels,
            because ``cv2.kmeans`` needs at least ``k`` samples.

    Returns:
        ``(colors, fractions)``. ``colors`` is an int array of cluster centres
        (BGR) and ``fractions`` is the share of pixels assigned to each; both
        are sorted by descending share. Empty clusters appear with share 0.
        An empty image yields two empty arrays.
    """
    data = img_bgr.reshape((-1, 3)).astype(np.float32)
    if len(data) == 0:
        return np.empty((0, 3), dtype=int), np.empty(0)
    k = min(k, len(data))

    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 40, 0.2)
    _, labels, palette = cv2.kmeans(data, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)

    # bincount keeps counts aligned with palette rows even when a cluster ends
    # up empty; np.unique would drop that label and shift every later count.
    counts = np.bincount(labels.ravel(), minlength=k)
    dominant_colors = palette.astype(int)
    percentages = counts / counts.sum()
    sorted_idx = np.argsort(percentages)[::-1]
    return dominant_colors[sorted_idx], percentages[sorted_idx]

def is_image_type_white_dominant(img_bgr, white_rgb_thresh=240, min_frac=0.2):
    """Report whether a near-white colour cluster covers a large part of the image.

    Args:
        img_bgr: BGR image.
        white_rgb_thresh: every channel of a cluster centre must reach this
            value for the cluster to count as white.
        min_frac: minimum share of pixels the white cluster must cover.

    Returns:
        True if any of the 4 k-means clusters is near-white and covers at
        least ``min_frac`` of the pixels.
    """
    colors, fractions = find_dominant_colors(img_bgr, k=4)
    return any(
        frac >= min_frac and np.all(color >= white_rgb_thresh)
        for color, frac in zip(colors, fractions)
    )

def white_balance_perfect_reflector(img_bgr, percentile=95):
    """White-balance with the "perfect reflector" (white-patch) assumption.

    Pixels whose grayscale value is at or above ``percentile`` are assumed to
    be white. Each of the B, G, R channels is scaled so that the mean of those
    pixels maps to 255.

    Args:
        img_bgr: uint8 BGR (or BGRA) image.
        percentile: brightness percentile that selects the reference pixels.

    Returns:
        Balanced uint8 image with the same shape. Channels beyond the first
        three (e.g. alpha) are passed through unscaled.
    """
    img = img_bgr.astype(np.float32)
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    mask = gray >= np.percentile(gray, percentile)
    if not np.any(mask):
        return img_bgr

    # Per-channel (B, G, R) mean over the reference pixels.
    ref_means = np.array([img[..., c][mask].mean() for c in range(3)], dtype=np.float64)
    # A channel that is completely dark among the reference pixels cannot be
    # balanced (255 / 0 -> inf/NaN). Leave that channel unscaled instead.
    scales = np.ones(3)
    lit = ref_means > 0
    scales[lit] = 255.0 / ref_means[lit]
    for c in range(3):
        img[..., c] *= scales[c]
    return np.clip(img, 0, 255).astype(np.uint8)

def adjust_blue_cast(img_bgr, intensity=0.1):
    """Counter a blue cast by attenuating blue and slightly boosting green.

    Args:
        img_bgr: H x W x C BGR image.
        intensity: blue is scaled by ``1 - intensity`` and green by
            ``1 + intensity / 2``; red is untouched.

    Returns:
        uint8 image clipped to [0, 255].
    """
    scaled = img_bgr.astype(np.float32)
    for channel, gain in ((0, 1 - intensity), (1, 1 + intensity / 2)):
        scaled[..., channel] *= gain
    return scaled.clip(0, 255).astype(np.uint8)

# ---- Local Tone Mapping ----
def guided_filter(I, p, r, eps):
    """Edge-preserving guided filter (He et al.) using normalised box filters.

    Args:
        I: float32 guide image.
        p: float32 image to filter.
        r: box-filter kernel size (used as ``(r, r)``).
        eps: regularisation added to the local variance of ``I``. Larger
            values give stronger smoothing.

    Returns:
        Filtered float32 image.
    """
    ksize = (r, r)

    def box(x):
        # Normalised box (mean) filter with float32 output.
        return cv2.boxFilter(x, cv2.CV_32F, ksize)

    mean_I, mean_p = box(I), box(p)
    var_I = box(I * I) - mean_I * mean_I
    cov_Ip = box(I * p) - mean_I * mean_p

    # Per-window linear model p ~ a * I + b.
    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I
    return box(a) * I + box(b)

def local_tone_mapping(img_bgr, radius=15, eps=1e-3, dodge_burn_strength=2.5):
    """Boost local contrast on the LAB lightness channel (dodge and burn).

    The lightness is split into a guided-filter base layer and a detail
    layer. The detail is amplified by ``dodge_burn_strength`` and added back.
    The A/B colour channels are left untouched.

    Args:
        img_bgr: uint8 BGR image.
        radius: passed to ``guided_filter`` as its box-filter kernel size.
        eps: guided-filter regularisation.
        dodge_burn_strength: gain applied to the detail layer.

    Returns:
        uint8 BGR image.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB))
    l_norm = lightness.astype(np.float32) / 255.0

    base = guided_filter(l_norm, l_norm, radius, eps)
    boosted = np.clip(l_norm + dodge_burn_strength * (l_norm - base), 0, 1)
    l_out = (boosted * 255).astype(np.uint8)

    return cv2.cvtColor(cv2.merge([l_out, chan_a, chan_b]), cv2.COLOR_LAB2BGR)

def adjust_saturation_vibrance(img_bgr, saturation_scale=1.3, vibrance_scale=0.5):
    """Scale HSV saturation, then add a vibrance boost that peaks at mid saturation.

    Args:
        img_bgr: uint8 BGR image.
        saturation_scale: multiplicative gain on the S channel.
        vibrance_scale: strength of the extra ``(255 - S) * S / 255`` boost.

    Returns:
        uint8 BGR image.
    """
    hue, sat, val = cv2.split(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32))

    sat = np.clip(sat * saturation_scale, 0, 255)
    sat = np.clip(sat + vibrance_scale * (255 - sat) * (sat / 255), 0, 255)

    adjusted = cv2.merge([hue, sat, val]).astype(np.uint8)
    return cv2.cvtColor(adjusted, cv2.COLOR_HSV2BGR)

# ==============================
# --- Full Pipeline (returns enhanced PIL image) ---
# ==============================
def run_full_pipeline_return_only(pil_original, img_bgr, image_path):
    """Run the full enhancement chain on one image and return the result.

    Steps: optional de-vignette plus dehaze, local tone mapping,
    saturation/vibrance, then a dull-or-gentle PIL enhancement (the choice is
    made from the file at ``image_path``). Finally, white balance and
    blue-cast correction are applied for white-dominant images.

    Args:
        pil_original: unused by this function; kept for call compatibility.
        img_bgr: BGR image to enhance.
        image_path: path of the source file, used by ``is_image_dull``.

    Returns:
        Enhanced PIL image (RGB).
    """
    base_bgr = finishing_touches_pipeline(img_bgr) if detect_vignette(img_bgr) else img_bgr

    styled_bgr = adjust_saturation_vibrance(
        local_tone_mapping(base_bgr), saturation_scale=1.3, vibrance_scale=0.5
    )
    styled_pil = Image.fromarray(cv2.cvtColor(styled_bgr, cv2.COLOR_BGR2RGB))

    enhancer = dull_enhance if is_image_dull(image_path) else gentle_enhance
    result_pil = enhancer(styled_pil.copy())

    result_bgr = cv2.cvtColor(np.array(result_pil), cv2.COLOR_RGB2BGR)
    if not is_image_type_white_dominant(result_bgr):
        return result_pil

    result_bgr = white_balance_perfect_reflector(result_bgr, percentile=95)
    result_bgr = adjust_blue_cast(result_bgr, intensity=0.01)
    return Image.fromarray(cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB))

# ==============================
# Batch Processing with Notebook Display + Webpage Export
# ==============================

def enhance_images_in_folder(folder_path, export_html_path=None, output_folder="enhanced_output"):
    """Enhance every image in a folder, save the results and build a comparison page.

    Files are processed in sorted filename order (non-recursive). Each
    enhanced image is saved to ``output_folder`` as ``enhanced_<name>``;
    saving is skipped when ``output_folder`` is falsy. An Original/Enhanced
    side-by-side HTML page is displayed in the notebook and optionally
    written to ``export_html_path``.

    Args:
        folder_path: directory to scan for images.
        export_html_path: optional path for a standalone HTML report.
        output_folder: directory for enhanced images; created if missing.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    html_blocks = []

    # Make sure output folder exists
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    # os.listdir order is arbitrary; sort so runs and reports are reproducible.
    for fname in sorted(os.listdir(folder_path)):
        if not fname.lower().endswith(image_exts):
            continue
        fpath = os.path.join(folder_path, fname)

        img_bgr = cv2.imread(fpath)
        if img_bgr is None:
            print(f"Failed to load {fname}, skipping.")
            continue

        pil_original = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        # Batch boundary: one image that breaks the pipeline should not abort
        # the whole run. Report it and move on.
        try:
            enhanced_pil = run_full_pipeline_return_only(pil_original, img_bgr, fpath)
        except Exception as exc:
            print(f"Failed to enhance {fname}: {exc!r}, skipping.")
            continue

        if output_folder:
            save_path = os.path.join(output_folder, f"enhanced_{fname}")
            enhanced_pil.save(save_path)
            print(f"Saved enhanced image: {save_path}")

        orig_html = pil_to_base64_html(pil_original)
        enh_html = pil_to_base64_html(enhanced_pil)

        # Centered layout
        block = f"""
        <div style="display:flex; flex-direction:row; justify-content:center; align-items:center; margin-bottom:20px;">
            <div style="text-align:center; margin-right:20px;">{orig_html}<br><b>Original</b></div>
            <div style="text-align:center;">{enh_html}<br><b>Enhanced</b></div>
        </div>
        """
        html_blocks.append(block)

    final_html = "<html><head><title>Auto Enhancement Results</title></head><body>" + "".join(html_blocks) + "</body></html>"

    # Display inside Jupyter Notebook
    display(HTML(final_html))

    # Export to standalone HTML if path given
    if export_html_path:
        with open(export_html_path, "w", encoding="utf-8") as f:
            f.write(final_html)
        print(f"Results exported to {export_html_path}")

# ==============================
# Example Usage
# ==============================
if __name__ == "__main__":
    # Placeholder locations: replace these with real paths before running.
    folder = r"IMAGE_FOLDER_PATH"
    enhance_images_in_folder(
        folder,
        export_html_path="OUTPUT_FOLDER_PATH_WITH_HTML_FILE_LABEL.html",
        output_folder="OUTPUT_FOLDER_PATH",
    )


Saved enhanced image: D:/AutoEnhnace/Final_Dataset/Batch1/Interior_Dataset/Output1\enhanced_Input_1.jpg
Saved enhanced image: D:/AutoEnhnace/Final_Dataset/Batch1/Interior_Dataset/Output1\enhanced_Input_10.jpg
Saved enhanced image: D:/AutoEnhnace/Final_Dataset/Batch1/Interior_Dataset/Output1\enhanced_Input_100.jpg
Saved enhanced image: D:/AutoEnhnace/Final_Dataset/Batch1/Interior_Dataset/Output1\enhanced_Input_101.jpg
Saved enhanced image: D:/AutoEnhnace/Final_Dataset/Batch1/Interior_Dataset/Output1\enhanced_Input_102.jpg
Saved enhanced image: D:/AutoEnhnace/Final_Dataset/Batch1/Interior_Dataset/Output1\enhanced_Input_103.jpg
Saved enhanced image: D:/AutoEnhnace/Final_Dataset/Batch1/Interior_Dataset/Output1\enhanced_Input_104.jpg
Saved enhanced image: D:/AutoEnhnace/Final_Dataset/Batch1/Interior_Dataset/Output1\enhanced_Input_105.jpg
Saved enhanced image: D:/AutoEnhnace/Final_Dataset/Batch1/Interior_Dataset/Output1\enhanced_Input_106.jpg
Saved enhanced image: D:/AutoEnhnace/Final_Datase