In [42]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

# Agent (threshold detection + router merged into one class)
# NOTE(review): the recorded ImportError ("... from 'agent' (unknown location)")
# suggests `agent` resolved as a namespace package (directory without an
# __init__.py exporting AdaptiveRestorationAgent, or wrong working directory)
# — confirm the package layout and run the notebook from the repo root.
from agent import AdaptiveRestorationAgent

# Module A (traditional baselines used in the comparison figure)
from module_A.traditional import apply_clahe, gray_world, retinex

# Module A (visualization)
from module_A.visualize import (
    visualize_pipeline,
    plot_histogram_before_after
)

# Zero-DCE model
from module_B.model import DCENet


ImportError: cannot import name 'AdaptiveRestorationAgent' from 'agent' (unknown location)

In [None]:
# Load the trained Zero-DCE network once; later cells reuse `model` and `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint_path = (
    "module_B/real_synthetic_charbonnier_perceptual_color_exposure_best.pth"
)
# NOTE(review): presumably must match the value used at training time, since
# it is a constructor argument of DCENet — verify against the training config.
NUM_ITERATIONS = 8

model = DCENet(num_iterations=NUM_ITERATIONS).to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so
# the checkpoint cannot execute arbitrary code. It also pins the behaviour
# across torch versions (the default changed to True in torch 2.6). If this
# raises an UnpicklingError, the checkpoint stores extra non-tensor metadata
# that must be allow-listed explicitly.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

# Training checkpoints wrap the weights under "model_state_dict"; fall back to
# treating the file as a bare state_dict (torch.save(model.state_dict(), ...)).
state_dict = checkpoint.get("model_state_dict", checkpoint)
model.load_state_dict(state_dict)
model.eval()

print("Zero-DCE model loaded.")


In [None]:
IMG_PATH = "test.jpg"

# cv2.imread signals failure by returning None rather than raising, so check
# explicitly. An `assert` would be stripped when Python runs with -O.
img_bgr = cv2.imread(IMG_PATH)
if img_bgr is None:
    raise FileNotFoundError(
        f"Failed to load image: {IMG_PATH!r} (missing or not a readable image)"
    )

# OpenCV decodes to BGR channel order; the cells below work on the RGB copy.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

fig, ax = plt.subplots()
ax.imshow(img_rgb)
ax.set_title("Input Image")
ax.axis("off")
plt.show()


In [None]:
# Thresholds handed to AdaptiveRestorationAgent. Their scale/units are defined
# inside the agent and are not visible in this notebook.
AGENT_THRESHOLDS = dict(
    low_light_thresh=25.0,
    low_contrast_thresh=18.0,
    color_cast_thresh=10.0,
    # Taken from the LOL evaluation (rationale explained in the report).
    high_noise_thresh=52.3,
)

agent = AdaptiveRestorationAgent(**AGENT_THRESHOLDS)


In [None]:
# Run the agent on the RGB image and pull out the three fields the later
# cells consume.
result = agent.run(img_rgb)
output, analysis, decision = (
    result[field] for field in ("output", "analysis", "decision")
)

print("=== Agent Decision ===")
for key, value in decision.items():
    print(f"{key}: {value}")


In [None]:
# When the agent routes to the deep-learning stage, recompute `output` with
# the locally loaded Zero-DCE model; otherwise keep the agent's own output.
if decision["stage"] == "deep_learning":
    # HWC uint8 [0, 255] -> 1xCxHxW float32 [0, 1] on the model's device.
    img_norm = img_rgb.astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model returns a pair; the first element is used as the enhanced
        # image (the second is discarded — presumably curve maps, unverified).
        enhanced, _ = model(img_tensor)
        enhanced = enhanced.clamp(0, 1)

    # 1xCxHxW tensor -> HxWxC float numpy array in [0, 1].
    output = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()


In [None]:
# Hand the input, the agent's analysis and decision, and the final output to
# Module A's pipeline visualizer.
visualize_pipeline(img=img_rgb, analysis=analysis, decision=decision, output=output)

In [None]:
def _to_uint8(image) -> np.ndarray:
    """Return ``image`` as a uint8 array in [0, 255] for histogram plotting.

    uint8 input is returned unchanged. Any other dtype is treated as a float
    image in [0, 1], then clipped, scaled and rounded. A plain
    ``(x * 255).astype(np.uint8)`` wraps around on uint8 input and truncates
    (rather than rounds) float input.
    """
    image = np.asarray(image)
    if image.dtype == np.uint8:
        return image
    return np.round(np.clip(image, 0.0, 1.0) * 255.0).astype(np.uint8)


# NOTE(review): `output` is float32 in [0, 1] when the Zero-DCE cell replaced
# it. The dtype of the agent's own output is not visible here, so normalise
# either way.
plot_histogram_before_after(
    cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR),
    cv2.cvtColor(_to_uint8(output), cv2.COLOR_RGB2BGR),
)

In [None]:
# Side-by-side comparison of the traditional Module A baselines with the
# agent's (possibly Zero-DCE-replaced) output. Each baseline's result is
# indexed with [0] to get the image.
# NOTE(review): confirm the helpers return (image, ...) tuples.
methods = {
    "Original": img_rgb,
    "CLAHE": apply_clahe(img_rgb)[0],
    "Gray World": gray_world(img_rgb)[0],
    "Retinex": retinex(img_rgb)[0],
    "Agent Output": output,
}

# squeeze=False always yields a 2-D axes array, so axes[0] is the single row
# regardless of how many methods are listed.
fig, axes = plt.subplots(1, len(methods), figsize=(14, 5), squeeze=False)
for ax, (name, im) in zip(axes[0], methods.items()):
    ax.imshow(im)
    ax.set_title(name)
    ax.axis("off")

fig.tight_layout()
plt.show()
