# Enhancekit demo

This notebook demonstrates the basic workflow for using `enhancekit`: inspecting the
available models and running an enhancement pass over one or more distorted images.
Place any test inputs (PNG files) inside `examples/images/` before running the cells;
if the folder is empty, a synthetic test image is generated for you.


In [None]:
from pathlib import Path
from typing import Optional

import numpy as np
from PIL import Image, ImageFilter
from IPython.display import display

from enhancekit import list_models, load_model

# Root folder for the demo assets; distorted inputs live in its ``images`` subfolder.
EXAMPLES_DIR = Path("examples")
IMAGES_DIR = EXAMPLES_DIR.joinpath("images")
# Create the input folder (and its parent) up front so later globbing and saving succeed.
IMAGES_DIR.mkdir(exist_ok=True, parents=True)


## Explore the registry

Inspect the registered models to decide which variant you want to try. The built-in
examples include lightweight identity models that are useful for verifying the
pipeline end-to-end.


In [None]:
# Ask the registry which model identifiers are available; the bare name on the
# last line makes the notebook echo the value as the cell output.
available_models = list_models()
available_models


## Prepare an example image

If you have already placed distorted PNG images inside `examples/images`, the notebook
picks the first one it finds and, unless it is already 256x256, saves a resized
256x256 copy to keep the Uformer demo lightweight. Otherwise, it synthesizes a
256x256 noisy, blurred colour gradient so you can run the demo without any
additional assets.



In [None]:
def load_distorted_image(seed: Optional[int] = 0) -> Path:
    """Return the path of a distorted PNG to run the demo on.

    The first user-supplied ``*.png`` in ``IMAGES_DIR`` is preferred. Candidates
    are sorted by path, so the choice does not depend on filesystem listing
    order. Files this notebook writes into that folder itself
    (``*_resized.png`` copies and ``enhanced_preview.png``) are skipped, so a
    re-run never feeds a previous output back in as the "distorted" input.

    When no candidate exists, a 256x256 synthetic image is written to
    ``IMAGES_DIR / "synthetic_distorted.png"`` and its path is returned. The
    synthetic image is an RGB gradient plus Gaussian noise, then blurred.

    Args:
        seed: Seed for the synthetic noise so repeated runs produce the same
            image; pass ``None`` for fresh randomness on every call.

    Returns:
        Path to the chosen or newly created image.
    """
    # Outputs produced by later cells / helpers that land in IMAGES_DIR.
    derived_names = {"enhanced_preview.png"}
    candidates = sorted(
        candidate
        for candidate in IMAGES_DIR.glob("*.png")
        if candidate.name not in derived_names and not candidate.stem.endswith("_resized")
    )
    if candidates:
        return candidates[0]

    # Build a simple synthetic test image when no sample is provided.
    height, width = 256, 256
    base = np.zeros((height, width, 3), dtype=np.float32)
    base[..., 0] = np.linspace(0, 1, width, dtype=np.float32)  # red: horizontal gradient
    base[..., 1] = np.linspace(0, 1, height, dtype=np.float32)[:, None]  # green: vertical gradient
    base[..., 2] = 0.2  # constant blue level

    rng = np.random.default_rng(seed)
    noise = rng.normal(scale=0.05, size=base.shape).astype(np.float32)
    noisy = np.clip(base + noise, 0.0, 1.0)

    # Blur on top of the noise so the sample carries two kinds of degradation.
    image = Image.fromarray((noisy * 255).astype("uint8")).filter(ImageFilter.GaussianBlur(radius=1.5))
    out_path = IMAGES_DIR / "synthetic_distorted.png"
    image.save(out_path)
    return out_path


def resize_image_for_demo(path: Path, edge: int = 256) -> Path:
    """Create a resized copy for lightweight CPU inference.

    The image is resized to an ``edge`` x ``edge`` square with Lanczos
    resampling; the aspect ratio is not preserved. The copy is saved next to the
    original as ``<stem>_resized.png``.

    Args:
        path: Image to resize.
        edge: Target width and height in pixels; must be positive.

    Returns:
        ``path`` itself when the image is already ``edge`` x ``edge``,
        otherwise the path of the newly written resized copy.

    Raises:
        ValueError: If ``edge`` is not positive.
    """
    if edge <= 0:
        raise ValueError(f"edge must be a positive pixel count, got {edge}")

    # Open inside a context manager so the underlying file handle is released on
    # both the early-return path and the resize path.
    with Image.open(path) as image:
        if image.size == (edge, edge):
            return path
        # ``resize`` returns a new, fully loaded image that outlives the ``with``.
        resized = image.resize((edge, edge), resample=Image.LANCZOS)

    resized_path = path.with_name(f"{path.stem}_resized.png")
    resized.save(resized_path)
    return resized_path


example_path = load_distorted_image()
print(f"Using example image: {example_path}")
# ``display`` renders the image immediately, so the file can be closed right after.
with Image.open(example_path) as preview:
    display(preview)

resized_path = resize_image_for_demo(example_path, edge=256)
if resized_path != example_path:
    print(f"Resized demo image saved to: {resized_path}")
    # Only show the resized copy when it differs from the image displayed above.
    with Image.open(resized_path) as preview:
        display(preview)
else:
    print("Image already at the demo-friendly size.")
# Later cells enhance the demo-sized image.
example_path = resized_path



## Run enhancement

Load the registered Uformer model and run `enhance_image` to generate the enhanced
output. Feeding it the resized 256x256 copy (matching `img_size=256`) keeps the
forward pass quick on CPU. Change `device` to use a GPU if one is available, and
swap in any registered model name from the list above.



In [None]:
model = load_model(
    "uformer",
    pretrained=False,
    device="cpu",
    freeze=True,
    img_size=256,  # matches the 256x256 demo copy prepared above
)

# Enhance a single image
result = model.enhance_image(example_path)

# Save the preview outside IMAGES_DIR. Anything written into the input folder is
# picked up again by ``model.enhance_folder(IMAGES_DIR, ...)`` in the batch cell
# and by the image picker when the notebook is re-run.
preview_dir = EXAMPLES_DIR / "outputs"
preview_dir.mkdir(parents=True, exist_ok=True)
enhanced_path = preview_dir / "enhanced_preview.png"
result.save(enhanced_path)

print(f"Enhanced image saved to: {enhanced_path}")
display(result)



## Quick Uformer sanity check

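Initialize a lightweight Uformer instance through the registry identifier
(`"uformer"`) so the configuration stays in sync with the rest of the codebase.
Overriding `img_size`, `embed_dim`, and the other hyperparameters keeps the forward
pass small enough for CPU execution. The cell then runs a random 1x3x64x64 tensor
through the network and prints the input and output shapes.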



In [None]:
import torch

# All Uformer overrides in one place: a tiny configuration sized for a CPU smoke test.
tiny_uformer_overrides = dict(
    img_size=64,
    embed_dim=32,
    depths=[2] * 9,
    num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
    win_size=8,
    token_projection="linear",
    token_mlp="leff",
    modulator=False,
    shift_flag=True,
)
lightweight_uformer = load_model(
    "uformer", pretrained=False, device="cpu", freeze=True, **tiny_uformer_overrides
)

# Use the wrapped network when the loader exposes one as ``.backbone``;
# otherwise call the returned object directly.
backbone = getattr(lightweight_uformer, "backbone", lightweight_uformer)

# Seed before sampling so the random probe tensor is reproducible.
torch.manual_seed(0)
fake_input = torch.rand(1, 3, 64, 64)

backbone.eval()
with torch.no_grad():
    fake_output = backbone(fake_input)

input_shape = tuple(fake_input.shape)
output_shape = tuple(fake_output.shape)
print(f"Fake input shape: {input_shape} -> output shape: {output_shape}")



## Batch and folder utilities

Enhance an entire folder by reusing the same model. Outputs are written to a sibling
`examples/outputs` directory so you can compare them with the originals.


In [None]:
OUTPUT_DIR = EXAMPLES_DIR / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Run over every image in the folder. This may include the ``*_resized.png``
# copy written earlier, depending on which files ``enhance_folder`` picks up.
output_paths = model.enhance_folder(IMAGES_DIR, output_folder=OUTPUT_DIR)

print(f"Wrote {len(output_paths)} enhanced files to {OUTPUT_DIR}")
for output_path in output_paths:
    # ``display`` renders immediately, so close each file instead of leaking one
    # open handle per output image.
    with Image.open(output_path) as enhanced:
        display(enhanced)
