In [3]:
%pip install -q gradio gradio_imageslider scikit-image


Note: you may need to restart the kernel to use updated packages.


In [None]:
# Imports
import os, time, math, statistics
import numpy as np
from PIL import Image
import pandas as pd

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

from skimage.metrics import peak_signal_noise_ratio as sk_psnr, structural_similarity as sk_ssim
from skimage import exposure

# Pick the compute device once; Triton kernels below need CUDA.
_cuda_ok = torch.cuda.is_available()
print("CUDA available:", _cuda_ok)
device = "cuda" if _cuda_ok else "cpu"
print("Device:", device)

In [None]:
# =======================
# Triton Kernels (manual tuning only; no @triton.autotune to avoid conflicts)
# =======================
@triton.jit
def conv2d_tiled_kernel(
    img_ptr, out_ptr, ker_ptr,
    H, W, K,
    STRIDE_H, STRIDE_W, OUT_STRIDE_H, OUT_STRIDE_W,
    TILE_H: tl.constexpr, TILE_W: tl.constexpr
):
    """Direct K x K 'same'-size 2-D filter (cross-correlation: the kernel is not flipped).

    One program instance produces the TILE_H x TILE_W output tile whose top-left pixel is
    (program_id(0) * TILE_H, program_id(1) * TILE_W):

        out[y, x] = sum over (ky, kx) of ker[ky, kx] * img[y + ky - r, x + kx - r],
        with r = (K - 1) // 2

    Input pixels outside the H x W image read as 0 (zero padding). Coefficients are read
    at ker_ptr + ky * K + kx, so the K x K kernel is expected contiguous and row-major.
    Image and output are addressed through their own (row, column) strides.
    Accumulation is float32. TILE_H / TILE_W are compile-time constants (tl.arange
    requires power-of-two extents).
    """
    pid_h = tl.program_id(0)
    pid_w = tl.program_id(1)
    # Top/left reach of the kernel; for even K the extra tap falls on the bottom/right.
    r = (K - 1) // 2
    oh0 = pid_h * TILE_H
    ow0 = pid_w * TILE_W
    oh = oh0 + tl.arange(0, TILE_H)  # output rows covered by this tile
    ow = ow0 + tl.arange(0, TILE_W)  # output columns covered by this tile
    acc = tl.zeros((TILE_H, TILE_W), dtype=tl.float32)
    for ky in range(0, K):
        for kx in range(0, K):
            # Input coordinates feeding tap (ky, kx) for every output pixel of the tile.
            ih = oh[:, None] + (ky - r)
            iw = ow[None, :] + (kx - r)
            # Lanes outside the image are masked and load 0.0 -> zero padding.
            mask = (ih >= 0) & (ih < H) & (iw >= 0) & (iw < W)
            pix = tl.load(img_ptr + ih * STRIDE_H + iw * STRIDE_W, mask=mask, other=0.0)
            cof = tl.load(ker_ptr + ky * K + kx)  # scalar coefficient ker[ky, kx]
            acc += pix * cof
    # Block pointer over the full H x W output; boundary_check drops the stores of tile
    # elements that lie past the last row/column (partial edge tiles).
    out_block = tl.make_block_ptr(
        base=out_ptr, shape=(H, W), strides=(OUT_STRIDE_H, OUT_STRIDE_W),
        offsets=(oh0, ow0), block_shape=(TILE_H, TILE_W), order=(1, 0),
    )
    tl.store(out_block, acc, boundary_check=(0, 1))

@triton.jit
def conv1d_row_kernel(
    img_ptr, tmp_ptr, ker_row_ptr,
    H, W, K,
    SIH, SIW, SOH, SOW,
    TILE_H: tl.constexpr, TILE_W: tl.constexpr
):
    """Horizontal (along x) pass of a separable filter, one TILE_H x TILE_W tile per program.

        tmp[y, x] = sum over kx of ker_row[kx] * img[y, x + kx - r],  r = (K - 1) // 2

    Pixels outside the image read as 0 (zero padding). Taps are read at ker_row_ptr + kx,
    so the K coefficients are expected contiguous. (SIH, SIW) are the input row/column
    strides and (SOH, SOW) the output ones. Accumulation is float32.
    """
    pid_h = tl.program_id(0)
    pid_w = tl.program_id(1)
    r = (K - 1)//2
    oh0 = pid_h * TILE_H
    ow0 = pid_w * TILE_W
    oh = oh0 + tl.arange(0, TILE_H)
    ow = ow0 + tl.arange(0, TILE_W)
    acc = tl.zeros((TILE_H, TILE_W), dtype=tl.float32)
    for kx in range(0, K):
        # Only the column is shifted; the row checks in `mask` still guard the partial
        # last tile row-wise.
        ih = oh[:, None]
        iw = ow[None, :] + (kx - r)
        mask = (ih >= 0) & (ih < H) & (iw >= 0) & (iw < W)
        pix = tl.load(img_ptr + ih * SIH + iw * SIW, mask=mask, other=0.0)
        cof = tl.load(ker_row_ptr + kx)  # scalar tap ker_row[kx]
        acc += pix * cof
    # boundary_check drops stores of tile elements outside the H x W output.
    tmp_block = tl.make_block_ptr(
        base=tmp_ptr, shape=(H, W), strides=(SOH, SOW),
        offsets=(oh0, ow0), block_shape=(TILE_H, TILE_W), order=(1, 0),
    )
    tl.store(tmp_block, acc, boundary_check=(0, 1))

@triton.jit
def conv1d_col_kernel(
    tmp_ptr, out_ptr, ker_col_ptr,
    H, W, K,
    SIH, SIW, SOH, SOW,
    TILE_H: tl.constexpr, TILE_W: tl.constexpr
):
    """Vertical (along y) pass of a separable filter, one TILE_H x TILE_W tile per program.

        out[y, x] = sum over ky of ker_col[ky] * tmp[y + ky - r, x],  r = (K - 1) // 2

    Pixels outside the image read as 0 (zero padding). Taps are read at ker_col_ptr + ky,
    so the K coefficients are expected contiguous. (SIH, SIW) are the input row/column
    strides and (SOH, SOW) the output ones. Accumulation is float32.
    """
    pid_h = tl.program_id(0)
    pid_w = tl.program_id(1)
    r = (K - 1)//2
    oh0 = pid_h * TILE_H
    ow0 = pid_w * TILE_W
    oh = oh0 + tl.arange(0, TILE_H)
    ow = ow0 + tl.arange(0, TILE_W)
    acc = tl.zeros((TILE_H, TILE_W), dtype=tl.float32)
    for ky in range(0, K):
        # Only the row is shifted; the column checks in `mask` still guard the partial
        # last tile column-wise.
        ih = oh[:, None] + (ky - r)
        iw = ow[None, :]
        mask = (ih >= 0) & (ih < H) & (iw >= 0) & (iw < W)
        pix = tl.load(tmp_ptr + ih * SIH + iw * SIW, mask=mask, other=0.0)
        cof = tl.load(ker_col_ptr + ky)  # scalar tap ker_col[ky]
        acc += pix * cof
    # boundary_check drops stores of tile elements outside the H x W output.
    out_block = tl.make_block_ptr(
        base=out_ptr, shape=(H, W), strides=(SOH, SOW),
        offsets=(oh0, ow0), block_shape=(TILE_H, TILE_W), order=(1, 0),
    )
    tl.store(out_block, acc, boundary_check=(0, 1))

In [None]:
# =======================
# Reference & Timing Utils
# =======================
def conv2d_ref(img_2d, ker_2d):
    """PyTorch reference for the Triton filters: zero-padded, same-size cross-correlation.

    Uses the same tap convention as the Triton kernels. For a kernel of size kh x kw the
    top/left reach is (kh - 1) // 2 and (kw - 1) // 2; any extra tap of an even-sized
    kernel falls on the bottom/right. The output therefore always has the input's shape.
    For odd square kernels this equals F.conv2d(..., padding=K // 2).

    Args:
        img_2d: 2-D image tensor (H, W).
        ker_2d: 2-D kernel tensor (kh, kw), same dtype/device as img_2d.

    Returns:
        2-D tensor (H, W).
    """
    kh, kw = ker_2d.shape[-2], ker_2d.shape[-1]
    top, left = (kh - 1) // 2, (kw - 1) // 2
    # F.pad takes (left, right, top, bottom) for the last two dims; constant 0 padding.
    padded = F.pad(img_2d[None, None], (left, kw - 1 - left, top, kh - 1 - top))
    return F.conv2d(padded, ker_2d[None, None]).squeeze(0).squeeze(0)

# (TILE_H, TILE_W) shapes swept by the manual tuner; every extent is a power of two,
# as tl.arange requires.
CANDIDATE_TILES = [
    (16, 16), (32, 16), (16, 32),
    (32, 32), (64, 32), (32, 64),
    (64, 64),
]
# num_warps values tried for every tile shape.
CANDIDATE_WARPS = [4, 8]

def time_kernel(call, reps=10, warmup=2):
    """Wall-clock `call()` and summarise its latency.

    `call` is run `warmup` times untimed (e.g. to trigger JIT compilation), then `reps`
    times. Each timed run is bracketed by device synchronisation, so asynchronous GPU work
    is included.

    Args:
        call: zero-argument callable to time.
        reps: number of timed runs (>= 1).
        warmup: number of untimed runs first.

    Returns:
        (median_s, (p20_s, p80_s)): the median latency and the 20th/80th percentiles in
        seconds. With a single timed run, both percentiles equal the median.
    """
    # GPU launches are asynchronous; synchronise so the timed interval covers the work.
    # Resolved once so the check itself is not inside the timed region. No-op without CUDA.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    for _ in range(warmup):
        call()
        sync()
    times = []
    for _ in range(reps):
        sync()
        t0 = time.perf_counter()  # monotonic, high resolution (time.time() is neither)
        call()
        sync()
        times.append(time.perf_counter() - t0)
    med = statistics.median(times)
    if len(times) < 2:
        # statistics.quantiles needs at least two data points.
        return med, (med, med)
    cuts = statistics.quantiles(times, n=5)
    return med, (cuts[0], cuts[-1])

# Scratch buffers allocated once and reused by the tuning sweeps, so candidate launches
# do not allocate. Sized for images up to 4096 x 4096 and 1-D filters of up to 65 taps.
_scratch_out2d, _scratch_tmp = (
    torch.empty(4096, 4096, device=device, dtype=torch.float32) for _ in range(2)
)
_scratch_u, _scratch_v = (
    torch.ones(65, device=device, dtype=torch.float32) / 65 for _ in range(2)
)

In [None]:
# =======================
# Manual Tuning (sweep TILE/warps)
# =======================
def tune_conv2d(img, ker, H, W, K, tiles=CANDIDATE_TILES, warps=CANDIDATE_WARPS, reps=8, warmup=2):
    """Sweep tile shape and warp count for conv2d_tiled_kernel and return the fastest.

    Every (TILE_H, TILE_W, num_warps) candidate is timed with time_kernel. The one with
    the lowest median latency wins. Results go to a throwaway buffer.

    Args:
        img: 2-D input image tensor on the GPU.
        ker: K x K kernel tensor on the GPU.
        H, W, K: image height/width and kernel size.
        tiles: iterable of (TILE_H, TILE_W) candidates (power-of-two extents).
        warps: iterable of num_warps candidates.
        reps, warmup: forwarded to time_kernel.

    Returns:
        dict with "latency_s", "TILE_H", "TILE_W", "num_warps" of the fastest config.

    Raises:
        ValueError: if `tiles` or `warps` is empty.
    """
    if not tiles or not warps:
        raise ValueError("tune_conv2d needs at least one tile shape and one warp count")
    # The shared scratch is fixed-size, but the kernel's output block pointer uses
    # shape=(H, W) with the buffer's strides. An image that does not fit would make rows
    # overlap and stores run past the allocation, so use a right-sized buffer instead.
    out = _scratch_out2d
    if H > out.shape[0] or W > out.shape[1]:
        out = torch.empty((H, W), device=img.device, dtype=torch.float32)
    # Coefficients are read at ker_ptr + ky*K + kx, so the kernel must be contiguous.
    ker = ker.contiguous()
    best = None
    for th, tw in tiles:
        grid = (triton.cdiv(H, th), triton.cdiv(W, tw))
        for nw in warps:
            # Closure is invoked immediately by time_kernel, so capturing loop vars is safe.
            def launch():
                conv2d_tiled_kernel[grid](
                    img, out, ker,
                    H, W, K,
                    img.stride(0), img.stride(1),
                    out.stride(0), out.stride(1),
                    TILE_H=th, TILE_W=tw, num_warps=nw
                )
            med, _ = time_kernel(launch, reps=reps, warmup=warmup)
            if best is None or med < best[0]:
                best = (med, th, tw, nw)
    return {"latency_s": best[0], "TILE_H": best[1], "TILE_W": best[2], "num_warps": best[3]}

def tune_sep_rowcol(img, H, W, K, tiles=CANDIDATE_TILES, warps=CANDIDATE_WARPS, reps=8, warmup=2):
    """Sweep tile shape and warp count for the row+column separable pair; return the fastest.

    Both passes are launched with the same (TILE_H, TILE_W, num_warps) and timed together.
    The kernels do the same work for any coefficient values, so K-tap box filters are used.

    Args:
        img: 2-D input image tensor on the GPU.
        H, W: image height/width.
        K: number of taps of each 1-D filter.
        tiles, warps: candidate tile shapes / warp counts.
        reps, warmup: forwarded to time_kernel.

    Returns:
        dict with "latency_s", "TILE_H", "TILE_W", "num_warps" of the fastest config.

    Raises:
        ValueError: if `tiles` or `warps` is empty.
    """
    if not tiles or not warps:
        raise ValueError("tune_sep_rowcol needs at least one tile shape and one warp count")
    # Fixed-size scratch buffers cannot hold an image larger than themselves: the block
    # pointer uses shape=(H, W) with the buffer's strides, so rows would overlap and stores
    # would run past the allocation. Allocate right-sized buffers in that case.
    tmp, out = _scratch_tmp, _scratch_out2d
    if H > tmp.shape[0] or W > tmp.shape[1]:
        tmp = torch.empty((H, W), device=img.device, dtype=torch.float32)
    if H > out.shape[0] or W > out.shape[1]:
        out = torch.empty((H, W), device=img.device, dtype=torch.float32)
    if K <= min(_scratch_u.numel(), _scratch_v.numel()):
        u, v = _scratch_u[:K], _scratch_v[:K]
    else:
        # Slicing the fixed-size scratch would silently time a shorter filter.
        u = torch.ones(K, device=img.device, dtype=torch.float32) / K
        v = u
    best = None
    for th, tw in tiles:
        grid = (triton.cdiv(H, th), triton.cdiv(W, tw))
        for nw in warps:
            # Closure is invoked immediately by time_kernel, so capturing loop vars is safe.
            def launch_pair():
                conv1d_row_kernel[grid](
                    img, tmp, v, H, W, v.numel(),
                    img.stride(0), img.stride(1),
                    tmp.stride(0), tmp.stride(1),
                    TILE_H=th, TILE_W=tw, num_warps=nw
                )
                conv1d_col_kernel[grid](
                    tmp, out, u, H, W, u.numel(),
                    tmp.stride(0), tmp.stride(1),
                    out.stride(0), out.stride(1),
                    TILE_H=th, TILE_W=tw, num_warps=nw
                )
            med, _ = time_kernel(launch_pair, reps=reps, warmup=warmup)
            if best is None or med < best[0]:
                best = (med, th, tw, nw)
    return {"latency_s": best[0], "TILE_H": best[1], "TILE_W": best[2], "num_warps": best[3]}

In [None]:
# =======================
# Steady-State Runners
# =======================
def run_conv2d_steady(img, ker, reps=10, warmup=3):
    """Tune, then time, the direct tiled 2-D filter on `img`.

    A short tuning sweep (tune_conv2d) picks the tile/warp config. That config is then run
    `reps` times after `warmup` untimed runs.

    Args:
        img: 2-D image tensor on the GPU.
        ker: square K x K kernel tensor on the GPU.
        reps, warmup: timing parameters for the final measurement.

    Returns:
        (out, perf), where out is the filtered (H, W) tensor. perf has "MPix/s",
        "latency_s", "p20_s", "p80_s", "TILE_H", "TILE_W" and "num_warps".

    Raises:
        ValueError: if img is not 2-D or ker is not a square 2-D kernel.
    """
    if img.dim() != 2:
        raise ValueError(f"img must be 2-D, got shape {tuple(img.shape)}")
    if ker.dim() != 2 or ker.shape[0] != ker.shape[1]:
        raise ValueError(f"ker must be a square 2-D kernel, got shape {tuple(ker.shape)}")
    # conv2d_tiled_kernel reads coefficients at ker_ptr + ky*K + kx, i.e. it assumes a
    # contiguous row-major kernel; a transposed/strided view would be read incorrectly.
    ker = ker.contiguous()
    H, W = img.shape
    K = ker.shape[0]
    out = torch.empty_like(img)
    best = tune_conv2d(img, ker, H, W, K, reps=max(4, reps // 2), warmup=max(2, warmup // 2))
    th, tw, nw = best["TILE_H"], best["TILE_W"], best["num_warps"]
    grid = (triton.cdiv(H, th), triton.cdiv(W, tw))

    def launch():
        conv2d_tiled_kernel[grid](
            img, out, ker, H, W, K,
            img.stride(0), img.stride(1),
            out.stride(0), out.stride(1),
            TILE_H=th, TILE_W=tw, num_warps=nw
        )

    med, (p20, p80) = time_kernel(launch, reps=reps, warmup=warmup)
    mpix_s = (H * W) / (med * 1e6)
    return out, {"MPix/s": mpix_s, "latency_s": med, "p20_s": p20, "p80_s": p80,
                 "TILE_H": th, "TILE_W": tw, "num_warps": nw}

def run_separable_steady(img, u, v, reps=10, warmup=3):
    """Tune, then time, the two-pass separable filter on `img`.

    A row pass filters along x with `v` into a temporary, then a column pass filters along
    y with `u`. This equals a zero-padded, same-size 2-D filter with kernel
    outer(u, v)[ky, kx] = u[ky] * v[kx].

    Args:
        img: 2-D image tensor on the GPU.
        u: 1-D column (vertical) taps on the GPU.
        v: 1-D row (horizontal) taps on the GPU.
        reps, warmup: timing parameters for the final measurement.

    Returns:
        (out, perf), where perf has "MPix/s", "latency_s", "p20_s", "p80_s", "TILE_H",
        "TILE_W" and "num_warps" (latency covers both passes).

    Raises:
        ValueError: if img is not 2-D or u/v are empty.
    """
    if img.dim() != 2:
        raise ValueError(f"img must be 2-D, got shape {tuple(img.shape)}")
    # The 1-D kernels read taps at ptr + k, so the factors must be flat and unit-stride.
    # A column slice such as U[:, 0] of a matrix is not unit-stride.
    u = u.reshape(-1).contiguous()
    v = v.reshape(-1).contiguous()
    if u.numel() == 0 or v.numel() == 0:
        raise ValueError("u and v must be non-empty")
    H, W = img.shape
    K = v.numel()
    tmp = torch.empty_like(img)
    out = torch.empty_like(img)
    best = tune_sep_rowcol(img, H, W, K, reps=max(4, reps // 2), warmup=max(2, warmup // 2))
    th, tw, nw = best["TILE_H"], best["TILE_W"], best["num_warps"]
    grid = (triton.cdiv(H, th), triton.cdiv(W, tw))

    def launch_pair():
        conv1d_row_kernel[grid](
            img, tmp, v, H, W, v.numel(),
            img.stride(0), img.stride(1),
            tmp.stride(0), tmp.stride(1),
            TILE_H=th, TILE_W=tw, num_warps=nw
        )
        conv1d_col_kernel[grid](
            tmp, out, u, H, W, u.numel(),
            tmp.stride(0), tmp.stride(1),
            out.stride(0), out.stride(1),
            TILE_H=th, TILE_W=tw, num_warps=nw
        )

    med, (p20, p80) = time_kernel(launch_pair, reps=reps, warmup=warmup)
    mpix_s = (H * W) / (med * 1e6)
    return out, {"MPix/s": mpix_s, "latency_s": med, "p20_s": p20, "p80_s": p80,
                 "TILE_H": th, "TILE_W": tw, "num_warps": nw}

In [None]:
# =======================
# Filters & Evaluation
# =======================
def _build_filters():
    """Return the fixed demo kernels keyed by name (insertion order = dropdown order).

    All but the Laplacian are rank-1, so they are written as outer products of 1-D taps.
    """
    def outer(col, row, norm):
        # Kernel k[ky, kx] = col[ky] * row[kx] / norm, as float32.
        return np.outer(col, row).astype(np.float32) / norm

    smooth3 = [1, 2, 1]
    smooth5 = [1, 4, 6, 4, 1]
    diff3 = [-1, 0, 1]
    ones3 = [1, 1, 1]
    return {
        "gaussian3x3": outer(smooth3, smooth3, 16.0),
        "gaussian5x5": outer(smooth5, smooth5, 256.0),
        "box3x3": outer(ones3, ones3, 9.0),
        "sobel_h": outer(smooth3, diff3, 8.0),
        "sobel_v": outer(diff3, smooth3, 8.0),
        # Not separable (rank 2), so spelled out explicitly.
        "laplacian": np.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=np.float32),
    }


FILTERS = _build_filters()

def compute_gflops(H, W, K, latency_s):
    """Effective GFLOP/s of a direct K x K filter over an H x W image.

    Counts one multiply and one add per tap per output pixel (2 * K * K FLOPs per pixel),
    whatever algorithm actually produced the result.
    """
    flops_per_pixel = 2 * K * K
    total_flops = H * W * flops_per_pixel
    return total_flops / (latency_s * 1e9)

def evaluate_once(img_np, filter_name="gaussian5x5", algo="2D", reps=10, warmup=3):
    """Apply a named FILTERS kernel on the GPU, time it, and score it against PyTorch.

    Args:
        img_np: 2-D grayscale image. Presumably in [0, 1]: both result and reference are
            clipped to [0, 1] before PSNR/SSIM with data_range=1.0.
        filter_name: key of FILTERS.
        algo: "2D" runs the direct tiled kernel. Any value starting with "sep"
            (case-insensitive) runs the row/column pipeline with the kernel's own rank-1
            factors. Kernels that are not rank-1 (e.g. "laplacian") fall back to the
            2D path.
        reps, warmup: forwarded to the steady-state runner.

    Returns:
        (out_np, metrics).
        - out_np is the clipped float result.
        - metrics holds MPix/s, latency and GFLOPS. GFLOPS always counts 2*K*K FLOPs
          per pixel, i.e. it is an effective rate on the separable path.
        - metrics also holds the chosen tile/warp config, PSNR/SSIM/max-abs-error vs
          conv2d_ref, and "algo_used" ("2D" or "separable").
    """
    def _rank1_factors(k2d, rtol=1e-4):
        # (u, v) with k2d ~= outer(u, v) if k2d is numerically rank-1, else None.
        # The row pass applies v along x and the column pass u along y, so
        # outer(u, v)[ky, kx] = u[ky] * v[kx] reproduces the 2-D kernel.
        U, s, Vt = np.linalg.svd(np.asarray(k2d, dtype=np.float64))
        if s[0] == 0 or s[1:].max(initial=0.0) > rtol * s[0]:
            return None
        root = np.sqrt(s[0])
        u = torch.tensor(U[:, 0] * root, device=device, dtype=torch.float32)
        v = torch.tensor(Vt[0, :] * root, device=device, dtype=torch.float32)
        return u, v

    ker_np = FILTERS[filter_name]
    K = ker_np.shape[0]
    img = torch.from_numpy(img_np).to(device, dtype=torch.float32)
    # The Triton kernel indexes coefficients row-major, so hand it a C-contiguous copy.
    ker = torch.from_numpy(np.ascontiguousarray(ker_np, dtype=np.float32)).to(device, dtype=torch.float32)
    ref = conv2d_ref(img.cpu(), ker.cpu()).to(device)

    factors = _rank1_factors(ker_np) if algo.lower().startswith("sep") else None
    if factors is not None:
        u, v = factors
        out, perf = run_separable_steady(img, u, v, reps=reps, warmup=warmup)
        algo_used = "separable"
    else:
        out, perf = run_conv2d_steady(img, ker, reps=reps, warmup=warmup)
        algo_used = "2D"

    out_np = out.detach().cpu().numpy().clip(0, 1)
    ref_np = ref.detach().cpu().numpy().clip(0, 1)
    # PSNR is +inf when out and ref match exactly; skimage divides by zero and warns then.
    with np.errstate(divide="ignore"):
        psnr = sk_psnr(ref_np, out_np, data_range=1.0)
    ssim = sk_ssim(ref_np, out_np, data_range=1.0)
    max_err = float(np.max(np.abs(out_np - ref_np)))
    gflops = compute_gflops(img_np.shape[0], img_np.shape[1], K, perf["latency_s"])
    return out_np, {
        "MPix/s": perf["MPix/s"], "latency_s": perf["latency_s"], "GFLOPS": gflops,
        "TILE_H": perf["TILE_H"], "TILE_W": perf["TILE_W"], "num_warps": perf["num_warps"],
        "PSNR_dB_vs_ref": psnr, "SSIM_vs_ref": ssim, "max_abs_err_vs_ref": max_err,
        "algo_used": algo_used,
    }

In [None]:
# =======================
# Enhancement ops: Gaussian denoise / unsharp mask / CLAHE
# =======================
def gaussian_kernel(size, sigma):
    """Return a size x size isotropic Gaussian kernel, normalised to sum 1.

    Taps sit at offsets (i - (size - 1) / 2), so an odd size has an exact centre tap.
    The kernel is float32 and lives on the module-level `device`.
    """
    coords = torch.arange(size, device=device, dtype=torch.float32) - (size - 1)/2
    sq = coords * coords
    # r2[i, j] = coords[i]**2 + coords[j]**2 (the same grid meshgrid(..., indexing='ij') gives).
    r2 = sq[:, None] + sq[None, :]
    weights = torch.exp(-r2 / (2.0 * sigma * sigma))
    return weights / weights.sum()

def evaluate_with_kernel(img_np, ker_np, algo="2D", reps=10, warmup=3):
    """Apply an arbitrary 2-D kernel on the GPU, time it, and score it against PyTorch.

    Args:
        img_np: 2-D grayscale image. Presumably in [0, 1]: both result and reference are
            clipped to [0, 1] before PSNR/SSIM with data_range=1.0.
        ker_np: 2-D kernel array (any float dtype/layout; converted to C-contiguous float32).
        algo: "2D" runs the direct tiled kernel. Any value starting with "sep"
            (case-insensitive) runs the row/column pipeline with the kernel's own rank-1
            factors. Kernels that are not rank-1 fall back to the 2D path.
        reps, warmup: forwarded to the steady-state runner.

    Returns:
        (out_np, metrics).
        - out_np is the clipped float result.
        - metrics holds MPix/s, latency and GFLOPS. GFLOPS always counts
          2*K*K FLOPs per pixel with K = ker_np.shape[0].
        - metrics also holds the tile/warp config, PSNR/SSIM/max-abs-error vs
          conv2d_ref, and "algo_used" ("2D" or "separable").
    """
    def _rank1_factors(k2d, rtol=1e-4):
        # (u, v) with k2d ~= outer(u, v) if k2d is numerically rank-1, else None.
        # The row pass applies v along x and the column pass u along y, so
        # outer(u, v)[ky, kx] = u[ky] * v[kx] reproduces the 2-D kernel. The tolerance
        # absorbs float32 rounding in kernels such as gaussian_kernel's output.
        U, s, Vt = np.linalg.svd(np.asarray(k2d, dtype=np.float64))
        if s[0] == 0 or s[1:].max(initial=0.0) > rtol * s[0]:
            return None
        root = np.sqrt(s[0])
        u = torch.tensor(U[:, 0] * root, device=device, dtype=torch.float32)
        v = torch.tensor(Vt[0, :] * root, device=device, dtype=torch.float32)
        return u, v

    # The Triton kernel indexes coefficients row-major; a Fortran-ordered/transposed array
    # would otherwise be read in the wrong order.
    ker_np = np.ascontiguousarray(ker_np, dtype=np.float32)
    img = torch.from_numpy(img_np).to(device, dtype=torch.float32)
    ker = torch.from_numpy(ker_np).to(device, dtype=torch.float32)
    ref = conv2d_ref(img.cpu(), ker.cpu()).to(device)

    factors = _rank1_factors(ker_np) if algo.lower().startswith("sep") else None
    if factors is not None:
        u, v = factors
        out, perf = run_separable_steady(img, u, v, reps=reps, warmup=warmup)
        algo_used = "separable"
    else:
        out, perf = run_conv2d_steady(img, ker, reps=reps, warmup=warmup)
        algo_used = "2D"

    out_np = out.detach().cpu().numpy().clip(0, 1)
    ref_np = ref.detach().cpu().numpy().clip(0, 1)
    # PSNR is +inf when out and ref match exactly; skimage divides by zero and warns then.
    with np.errstate(divide="ignore"):
        psnr = sk_psnr(ref_np, out_np, data_range=1.0)
    ssim = sk_ssim(ref_np, out_np, data_range=1.0)
    max_err = float(np.max(np.abs(out_np - ref_np)))
    gflops = compute_gflops(img_np.shape[0], img_np.shape[1], ker_np.shape[0], perf["latency_s"])
    return out_np, {
        "MPix/s": perf["MPix/s"], "latency_s": perf["latency_s"], "GFLOPS": gflops,
        "TILE_H": perf["TILE_H"], "TILE_W": perf["TILE_W"], "num_warps": perf["num_warps"],
        "PSNR_dB_vs_ref": psnr, "SSIM_vs_ref": ssim, "max_abs_err_vs_ref": max_err,
        "algo_used": algo_used,
    }

def unsharp_mask_gpu(img_np, K=5, sigma=1.0, amount=1.0, algo="2D", reps=10, warmup=3):
    """Unsharp-mask sharpening: out = clip(img + amount * (img - gaussian_blur(img)), 0, 1).

    Args:
        img_np: 2-D grayscale image, presumably in [0, 1] (PSNR/SSIM use data_range=1.0).
        K: Gaussian kernel size.
        sigma: Gaussian standard deviation in pixels.
        amount: sharpening strength.
        algo: "2D" blurs with the direct K x K kernel. Any value starting with "sep"
            blurs with two 1-D passes of the same Gaussian (it is exactly separable).
        reps, warmup: forwarded to the steady-state runner.

    Returns:
        (out_np, metrics).
        - metrics "MPix/s"/"latency_s" time only the blur; the sharpen arithmetic
          is not timed.
        - metrics also holds the tile/warp config and PSNR/SSIM of the output vs
          the input.
    """
    img = torch.from_numpy(img_np).to(device, dtype=torch.float32)
    if algo.lower().startswith("sep"):
        # The normalised 2-D Gaussian equals outer(g, g) for the normalised 1-D Gaussian g,
        # so both paths blur with the same kernel.
        ax = torch.arange(K, device=device, dtype=torch.float32) - (K - 1) / 2
        g = torch.exp(-(ax * ax) / (2.0 * sigma * sigma))
        g = g / g.sum()
        blurred, perf_blur = run_separable_steady(img, g, g, reps=reps, warmup=warmup)
    else:
        ker = gaussian_kernel(K, sigma)
        blurred, perf_blur = run_conv2d_steady(img, ker, reps=reps, warmup=warmup)
    enhanced = (img + amount * (img - blurred)).clamp(0, 1)
    out_np = enhanced.detach().cpu().numpy()
    # With amount == 0 the output equals the input: PSNR is +inf and skimage would warn.
    with np.errstate(divide="ignore"):
        psnr_orig = sk_psnr(img_np, out_np, data_range=1.0)
    ssim_orig = sk_ssim(img_np, out_np, data_range=1.0)
    return out_np, {
        "MPix/s": (img.numel())/(perf_blur["latency_s"]*1e6),
        "latency_s": perf_blur["latency_s"],
        "TILE_H": perf_blur["TILE_H"], "TILE_W": perf_blur["TILE_W"], "num_warps": perf_blur["num_warps"],
        "PSNR_dB_vs_input": psnr_orig, "SSIM_vs_input": ssim_orig
    }

def apply_clahe(img_np, clip_limit=0.01):
    """Contrast-limited adaptive histogram equalisation (CPU, scikit-image).

    Thin wrapper over skimage.exposure.equalize_adapthist that returns a float32 array.
    """
    equalized = exposure.equalize_adapthist(img_np, clip_limit=clip_limit)
    return np.asarray(equalized, dtype=np.float32)

In [None]:
# =======================
# Gradio Frontend (Before/After + Metrics)
# =======================
import gradio as gr
from gradio_imageslider import ImageSlider  # in-image slider

def _pil_to_gray01(pil_img):
    """Convert a PIL image to a 2-D float32 grayscale array in [0, 1].

    Conversion rules:
    - Palette ("P") images are expanded to RGB first; their raw array holds palette
      indices, not intensities.
    - Images with >= 3 channels use luma weights on the first three channels; alpha is
      ignored.
    - Two-channel images (grayscale + alpha) use the first channel.
    - uint8 data is divided by 255. Wider integer data (e.g. 16-bit "I;16"/"I" images)
      is divided by 65535 when any value exceeds 255, otherwise by 255.
    - Bool data maps to 0/1. Float data is divided by 255.
    """
    if pil_img.mode == "P":
        pil_img = pil_img.convert("RGB")
    arr = np.asarray(pil_img)
    if arr.dtype == np.bool_:
        scale = 1.0
    elif (np.issubdtype(arr.dtype, np.integer) and arr.dtype != np.uint8
          and arr.size > 0 and arr.max() > 255):
        scale = 65535.0
    else:
        scale = 255.0
    arr = arr.astype(np.float32)
    if arr.ndim == 3:
        if arr.shape[-1] >= 3:
            arr = 0.2989*arr[..., 0] + 0.5870*arr[..., 1] + 0.1140*arr[..., 2]
        else:
            arr = arr[..., 0]
    return (arr / scale).clip(0, 1)


def _to_pil_u8(x):
    """Quantise a [0, 1] float image to an 8-bit PIL image (round to nearest, not truncate)."""
    return Image.fromarray(np.round(np.clip(x, 0.0, 1.0) * 255.0).astype(np.uint8))


def run_pipeline(pil_img, mode, K, sigma, amount, clip_limit, algo, noise_sig, fixed_filter):
    """Handle one UI request: load the image, optionally add noise, apply `mode`.

    Args:
        pil_img: uploaded PIL image, or None to use a 1024 x 1024 one-pixel checkerboard.
        mode: UI label. Dispatch is on its prefix: "Denoise", "Sharpen", "Contrast";
            anything else applies `fixed_filter`.
        K, sigma, amount, clip_limit: filter parameters from the sliders.
        algo: "2D" or "separable" GPU algorithm.
        noise_sig: std-dev of additive Gaussian noise (0 disables it).
        fixed_filter: key of FILTERS used by the fixed-filter mode.

    Returns:
        A tuple (before, after, diff, metrics_text, (before, after)).
        - before, after and diff (|after - before|) are 8-bit PIL images.
        - metrics_text is a one-line metrics string.
        - (before, after) is the pair for the ImageSlider.
    """
    K = int(K)  # sliders may deliver floats; kernel sizes are integers
    if pil_img is None:
        N = 1024
        img_np = (np.indices((N, N)).sum(axis=0) % 2).astype(np.float32)
    else:
        img_np = _pil_to_gray01(pil_img)
    clean_np = img_np.copy()
    if noise_sig > 0:
        img_np = np.clip(img_np + np.random.normal(0, noise_sig, img_np.shape).astype(np.float32), 0, 1)

    if mode.startswith("Denoise"):
        ker = gaussian_kernel(K, float(sigma)).detach().cpu().numpy()
        out_np, m = evaluate_with_kernel(img_np, ker, algo=algo, reps=8, warmup=2)
        # PSNR is +inf if the output matches the clean image exactly.
        with np.errstate(divide="ignore"):
            psnr_clean = sk_psnr(clean_np, out_np, data_range=1.0)
        ssim_clean = sk_ssim(clean_np, out_np, data_range=1.0)
        metrics_text = (
            f"Denoise | K={K} σ={sigma:.2f} | MPix/s={m['MPix/s']:.2f} | "
            f"PSNR(clean,out)={psnr_clean:.2f} dB | SSIM(clean,out)={ssim_clean:.4f} | "
            f"TILE={m['TILE_H']}x{m['TILE_W']} Warps={m['num_warps']}"
        )
    elif mode.startswith("Sharpen"):
        out_np, m = unsharp_mask_gpu(img_np, K=K, sigma=float(sigma), amount=float(amount),
                                     algo=algo, reps=8, warmup=2)
        metrics_text = (
            f"Unsharp | K={K} σ={sigma:.2f} amount={amount:.2f} | "
            f"MPix/s={m['MPix/s']:.2f} | TILE={m['TILE_H']}x{m['TILE_W']} Warps={m['num_warps']}"
        )
    elif mode.startswith("Contrast"):
        out_np = apply_clahe(img_np, clip_limit=float(clip_limit))
        metrics_text = f"CLAHE | clip_limit={clip_limit:.3f} | (CPU post-process for visualization)"
    else:
        ker_np = FILTERS[fixed_filter].astype(np.float32)
        out_np, m = evaluate_with_kernel(img_np, ker_np, algo=algo, reps=8, warmup=2)
        metrics_text = (
            f"Fixed Filter {fixed_filter} | MPix/s={m['MPix/s']:.2f} | "
            f"TILE={m['TILE_H']}x{m['TILE_W']} Warps={m['num_warps']}"
        )

    before_u8 = _to_pil_u8(img_np)
    after_u8 = _to_pil_u8(out_np)
    diff_img = _to_pil_u8(np.abs(out_np - img_np))
    return before_u8, after_u8, diff_img, metrics_text, (before_u8, after_u8)

with gr.Blocks(title="GPU-Optimized X-ray Filtering") as demo:
    gr.Markdown("### GPU Optimization")
    with gr.Row():
        input_img = gr.Image(type="pil", label="Input image")
    with gr.Row():
        mode = gr.Radio(
            ["Denoise (Gaussian)", "Sharpen (Unsharp Mask)", "Contrast (CLAHE)", "Filter (Fixed Kernel)"],
            value="Denoise (Gaussian)", label="Mode"
        )
        algo = gr.Radio(["2D", "separable"], value="2D", label="GPU algorithm")
        fixed_filter = gr.Dropdown(list(FILTERS.keys()), value="gaussian5x5", label="Fixed filter")
    with gr.Row():
        K_slider = gr.Slider(3, 15, value=5, step=2, label="Kernel size K (odd)")
        sigma_slider = gr.Slider(0.5, 3.0, value=1.0, step=0.1, label="Sigma (Gaussian/Unsharp)")
        amount_slider = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Amount (Unsharp)")
        clip_slider = gr.Slider(0.005, 0.05, value=0.01, step=0.005, label="CLAHE clip_limit")
        noise_sigma = gr.Slider(0.0, 0.1, value=0.0, step=0.01, label="Add Gaussian noise σ (0..0.1)")

    with gr.Row():
        with gr.Column():
            before_out = gr.Image(label="Before")
        with gr.Column():
            after_out = gr.Image(label="After")
    with gr.Row():
        slider = ImageSlider(label="Drag to Compare", type="pil")
    with gr.Row():
        diff_out = gr.Image(label="|Δ| Heatmap")
        metrics = gr.Textbox(label="Metrics", lines=6)

    run_btn = gr.Button("Run")

    def on_run(img, mo, K, si, am, cl, al, ns, ff):
        # Reorder run_pipeline's (before, after, diff, text, pair) to match `outputs` below.
        b, a, d, txt, pair = run_pipeline(img, mo, K, si, am, cl, al, ns, ff)
        return b, a, pair, d, txt

    run_btn.click(
        on_run,
        inputs=[input_img, mode, K_slider, sigma_slider, amount_slider, clip_slider, algo, noise_sigma, fixed_filter],
        outputs=[before_out, after_out, slider, diff_out, metrics]
    )

# NOTE(security): share=True publishes a temporary public *.gradio.live URL. Anyone with
# the link can then upload images and run them on this machine's GPU. Use demo.launch()
# for local-only access.
demo.launch(share=True)

Note: you may need to restart the kernel to use updated packages.
CUDA available: True
Device: cuda
* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://7335aac8c772584d36.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




  return 10 * np.log10((data_range**2) / err)
