# Kvasir-HealthAI End-to-End Segmentation Walkthrough

This notebook walks through the U-Net-based polyp segmentation pipeline: image preprocessing, the stored training artefacts, inference on a sample frame, and pre-computed explainability (Grad-CAM) outputs. Each major processing stage has its own cell with brief commentary.

In [None]:
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from imageio import v2 as imageio
from skimage import color, exposure, morphology
from skimage.restoration import inpaint_biharmonic, denoise_bilateral
from scipy import ndimage
import torch

from src.models.unet import UNet
from src.training.losses import BCEDiceLoss

# Shared figure defaults for every plot in this notebook.
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams.update({
    'figure.figsize': (10, 4),
    'axes.titlesize': 12,
    'axes.labelsize': 11,
})

# The kernel may start either in the repository root or one level below it
# (e.g. inside notebooks/); step up once when notebooks/ is not visible.
REPO_ROOT = Path.cwd().resolve()
REPO_ROOT = REPO_ROOT if (REPO_ROOT / 'notebooks').exists() else REPO_ROOT.parent
ASSET_ROOT = REPO_ROOT / 'notebooks' / 'assets'
RESULT_ROOT = ASSET_ROOT / 'results'
SAMPLE_ROOT = REPO_ROOT / 'sample'

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f'Using device: {device}')


In [None]:
def remove_specular_highlights(rgb, value_thr=0.85, saturation_thr=0.25):
    """Detect specular highlights in an RGB frame and fill them by inpainting.

    A pixel counts as a highlight when it is bright (HSV value above
    ``value_thr``) and weakly saturated (HSV saturation below
    ``saturation_thr``). Those pixels are reconstructed with biharmonic
    inpainting; every other pixel is kept.

    Parameters
    ----------
    rgb : ndarray
        RGB image, assumed uint8 in [0, 255] (it is divided by 255).
    value_thr, saturation_thr : float
        HSV thresholds in [0, 1].

    Returns
    -------
    repaired : ndarray of uint8
        Inpainted RGB image with the same shape as ``rgb``.
    mask : ndarray of uint8
        Highlight mask, 255 where a highlight was detected, 0 elsewhere.
    """
    float_img = rgb.astype(np.float32) / 255.0
    hsv = color.rgb2hsv(float_img)
    mask = (hsv[..., 2] > value_thr) & (hsv[..., 1] < saturation_thr)
    if mask.any():
        repaired = np.clip(inpaint_biharmonic(float_img, mask, channel_axis=-1), 0.0, 1.0)
    else:
        # Nothing to repair: skip the (comparatively expensive) inpainting solve.
        repaired = float_img
    # Round instead of truncating: x / 255 * 255 can land a hair below x in
    # floating point, and a plain astype would then darken untouched pixels by 1.
    repaired_u8 = np.round(repaired * 255.0).astype(np.uint8)
    return repaired_u8, (mask.astype(np.uint8) * 255)

def homomorphic_filter_channel(channel, cutoff=30.0, gamma_l=0.9, gamma_h=1.1, blend=0.85):
    """Illumination-correct one 8-bit channel with a homomorphic filter.

    Steps: take the log of the normalised channel, multiply its centred
    spectrum by a Gaussian high-emphasis transfer function (gain ``gamma_l``
    at zero frequency, rising towards ``gamma_h`` far from it), invert and
    exponentiate, rescale to the input mean, stretch the 3rd-97th percentile
    range to [0, 1], and finally blend with the original channel.

    Parameters
    ----------
    channel : ndarray
        2-D channel, assumed uint8 in [0, 255] (it is divided by 255).
    cutoff : float
        Width of the Gaussian in the shifted spectrum (frequency-index units).
    gamma_l, gamma_h : float
        Low- and high-frequency gains.
    blend : float
        Weight of the filtered result; ``1 - blend`` goes to the input.

    Returns
    -------
    ndarray of uint8 with the same shape as ``channel``.
    """
    # Small offset keeps log() finite on zero-valued pixels.
    g0 = channel.astype(np.float32) / 255.0 + 1e-6
    spectrum = np.fft.fftshift(np.fft.fft2(np.log(g0)))

    h, w = channel.shape
    yy, xx = np.ogrid[:h, :w]
    cy, cx = h // 2, w // 2  # DC position after fftshift
    dist_sq = (xx - cx) ** 2 + (yy - cy) ** 2
    H = 1.0 - np.exp(-(dist_sq / (2.0 * (cutoff ** 2))))
    H = (gamma_h - gamma_l) * H + gamma_l

    rec = np.fft.ifft2(np.fft.ifftshift(spectrum * H)).real
    out = np.exp(rec)  # strictly positive, so no lower clip is needed
    # Keep the input brightness; this only survives when the stretch below is
    # skipped (near-flat channels).
    out *= g0.mean() / (out.mean() + 1e-6)
    lo, hi = np.percentile(out, [3, 97])
    if hi - lo > 1e-6:
        out = np.clip((out - lo) / (hi - lo), 0, 1)
    out = blend * out + (1.0 - blend) * g0
    # Round to the nearest level; astype alone truncates and biases dark.
    return np.round(np.clip(out, 0, 1) * 255).astype(np.uint8)

def guided_smooth(channel, sigma_color=0.05, sigma_spatial=4):
    """Edge-preserving smoothing of a single 8-bit channel.

    Despite the name, this applies a bilateral filter
    (``skimage.restoration.denoise_bilateral``), not a guided filter.

    Parameters
    ----------
    channel : ndarray
        2-D channel, assumed uint8 in [0, 255] (it is divided by 255).
    sigma_color, sigma_spatial : float
        Forwarded unchanged to ``denoise_bilateral``.

    Returns
    -------
    ndarray of uint8 with the same shape as ``channel``.
    """
    ch = channel.astype(np.float32) / 255.0
    filt = denoise_bilateral(ch, sigma_color=sigma_color, sigma_spatial=sigma_spatial, channel_axis=None)
    # Round to the nearest level; astype alone truncates and biases dark.
    return np.round(np.clip(filt, 0, 1) * 255).astype(np.uint8)

def apply_clahe_channel(channel, clip_limit=0.03):
    """Apply CLAHE (contrast-limited adaptive histogram equalisation) to one channel.

    Parameters
    ----------
    channel : ndarray
        2-D channel, assumed uint8 in [0, 255] (it is divided by 255).
    clip_limit : float
        Forwarded as ``clip_limit`` to ``skimage.exposure.equalize_adapthist``.

    Returns
    -------
    ndarray of uint8 with the same shape as ``channel``.
    """
    equalised = exposure.equalize_adapthist(channel.astype(np.float32) / 255.0, clip_limit=clip_limit)
    # Round to the nearest level; astype alone truncates and biases dark.
    return np.round(np.clip(equalised, 0, 1) * 255).astype(np.uint8)

def retone_channel(channel, target_mean=145.0, target_std=80.0):
    """Robustly stretch a channel, then map it to a target mean and spread.

    1. Percentile stretch: the 2nd-98th percentile range is mapped to [0, 1]
       (with clipping). If that range is degenerate (near-flat channel) the
       values are simply divided by 255 instead.
    2. Affine retone so the mean becomes ``target_mean / 255`` and the
       standard deviation ``target_std / 255``, then clip to [0, 1].

    Parameters
    ----------
    channel : ndarray
        Channel values, assumed to be on a 0-255 scale.
    target_mean, target_std : float
        Desired mean and standard deviation on the 0-255 scale (before the
        final clip, which can shift both).

    Returns
    -------
    ndarray of uint8 with the same shape as ``channel``.
    """
    # float64 keeps the large gain used for near-flat channels from amplifying
    # float32 rounding noise.
    v = channel.astype(np.float64)
    lo, hi = np.percentile(v, [2, 98])
    if hi - lo > 1e-6:
        v = np.clip((v - lo) / (hi - lo), 0, 1)
    else:
        v = np.clip(v / 255.0, 0, 1)
    mu = float(v.mean())
    sd = float(v.std()) + 1e-6
    gain = (target_std / 255.0) / sd
    # Centre before scaling: for flat input gain is ~1e5, and the original
    # gain*v + (t - gain*mu) form cancelled two huge terms.
    v = np.clip(gain * (v - mu) + target_mean / 255.0, 0, 1)
    # Round to the nearest level; truncation turned e.g. 145 into 144.
    return np.round(v * 255).astype(np.uint8)

def apply_on_value(rgb, op):
    """Run an 8-bit channel operation on the HSV value channel of an RGB image.

    Hue and saturation are kept; only V is replaced by ``op(V)`` and the
    image is converted back to RGB.

    Parameters
    ----------
    rgb : ndarray
        RGB image, assumed uint8 in [0, 255] (it is divided by 255).
    op : callable
        Receives the V channel as a 2-D uint8 array. Its result must support
        ``.astype`` and is divided by 255, so it is presumably uint8 in
        [0, 255] — confirm for new operations.

    Returns
    -------
    rgb_new : ndarray of uint8
        RGB image rebuilt from the original H and S and the new V.
    v_u8 : ndarray of uint8
        V channel handed to ``op``.
    v_new : ndarray
        Whatever ``op`` returned.
    """
    hsv = color.rgb2hsv(rgb.astype(np.float32) / 255.0)
    h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
    # Round (not truncate) so an unchanged V round-trips exactly; the notebook
    # chains this helper stage after stage, so a truncation bias would compound.
    v_u8 = np.round(v * 255).astype(np.uint8)
    v_new = op(v_u8)
    hsv_new = np.stack([h, s, v_new.astype(np.float32) / 255.0], axis=-1)
    # Clip before rounding: float error slightly above 1.0 must not wrap past 255.
    rgb_new = np.round(np.clip(color.hsv2rgb(hsv_new), 0.0, 1.0) * 255).astype(np.uint8)
    return rgb_new, v_u8, v_new

def overlay_mask(rgb, mask, color=(255, 0, 0), alpha=0.55):
    """Alpha-blend a solid colour over the masked pixels of an RGB image.

    Parameters
    ----------
    rgb : ndarray
        Image of shape (H, W, 3); it is not modified.
    mask : ndarray
        (H, W) array; any non-zero entry is foreground, so both boolean and
        0/255 masks work.
    color : sequence of 3 numbers
        Overlay colour on the 0-255 scale. (The parameter name shadows the
        ``skimage.color`` module only inside this function.)
    alpha : float
        Weight of the overlay colour in the blend.

    Returns
    -------
    ndarray of uint8 with the same shape as ``rgb``.
    """
    out = rgb.astype(np.float32)  # astype already returns a copy
    mask_bool = mask.astype(bool)
    color_arr = np.asarray(color, dtype=np.float32)
    out[mask_bool] = alpha * color_arr + (1.0 - alpha) * out[mask_bool]
    # Round to the nearest level (astype alone truncates); clip guards alpha
    # values outside [0, 1].
    return np.round(np.clip(out, 0, 255)).astype(np.uint8)

def horizontal_flip(rgb, mask):
    """Mirror an image and its mask left-to-right (axis 1).

    Returns new contiguous arrays ``(flipped_rgb, flipped_mask)``; the inputs
    are left untouched.
    """
    flipped_rgb = rgb[:, ::-1].copy()
    flipped_mask = mask[:, ::-1].copy()
    return flipped_rgb, flipped_mask


## Raw colonoscopy frame

In [None]:
raw_path = SAMPLE_ROOT / 'preprocess_before.png'
# Explicit raise rather than assert: assert statements vanish under `python -O`.
if not raw_path.exists():
    raise FileNotFoundError(f'Sample image is missing: {raw_path}')
raw_rgba = imageio.imread(raw_path)
if raw_rgba.ndim == 2:
    # Grayscale file: replicate to three channels so the RGB pipeline applies.
    raw_rgba = np.stack([raw_rgba] * 3, axis=-1)
raw_rgb = raw_rgba[..., :3]  # drop an alpha channel if present
if raw_rgb.dtype != np.uint8:
    if np.issubdtype(raw_rgb.dtype, np.integer):
        # e.g. 16-bit PNG: multiplying by 255 would overflow uint8, so rescale
        # by the dtype's full range instead.
        scaled = raw_rgb.astype(np.float64) / np.iinfo(raw_rgb.dtype).max
    else:
        # Floating-point image, assumed to lie in [0, 1].
        scaled = raw_rgb.astype(np.float64)
    raw_rgb = np.round(np.clip(scaled, 0.0, 1.0) * 255).astype(np.uint8)

# Demo-only heuristic mask: low-hue (reddish), saturated, not-too-dark pixels,
# cleaned of small specks and small holes (64 px).
# NOTE: hue < 0.08 covers only the low end of the red band; reds that wrap
# around near hue 1.0 are not selected.
hsv = color.rgb2hsv(raw_rgb.astype(np.float32) / 255.0)
fallback_mask = (hsv[..., 0] < 0.08) & (hsv[..., 1] > 0.35) & (hsv[..., 2] > 0.30)
fallback_mask = morphology.remove_small_objects(fallback_mask, 64)
fallback_mask = morphology.remove_small_holes(fallback_mask, 64)
fallback_mask_u8 = (fallback_mask.astype(np.uint8) * 255)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(raw_rgb)
axes[0].set_title('Raw RGB frame')
axes[0].axis('off')
axes[1].imshow(overlay_mask(raw_rgb, fallback_mask))
axes[1].set_title('Heuristic polyp region (for demo)')
axes[1].axis('off')
plt.tight_layout()
plt.show()


## Stage 1 – Specular highlight suppression

In [None]:
spec_rgb, spec_mask = remove_specular_highlights(raw_rgb)
# (image, title, colormap) for each panel, left to right.
panels = [
    (raw_rgb, 'Before inpainting', None),
    (spec_mask, 'Detected highlights', 'gray'),
    (spec_rgb, 'After inpainting', None),
]
fig, axes = plt.subplots(1, len(panels), figsize=(14, 4))
for ax, (img, title, cmap) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()


## Stage 2 – Homomorphic illumination correction

In [None]:
homo_rgb, v_before, v_after = apply_on_value(spec_rgb, homomorphic_filter_channel)
# v_after (the corrected V channel) is returned but not plotted in this cell.
panels = [
    (spec_rgb, 'Input to homomorphic filter', None),
    (v_before, 'Value channel before', 'gray'),
    (homo_rgb, 'Value corrected output', None),
]
fig, axes = plt.subplots(1, len(panels), figsize=(14, 4))
for ax, (img, title, cmap) in zip(axes, panels):
    ax.imshow(img, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()


## Stage 3 – Edge-preserving (bilateral) smoothing

In [None]:
guided_rgb, _, guided_v = apply_on_value(homo_rgb, guided_smooth)
panels = [
    (homo_rgb, 'Before guided filter'),
    (guided_rgb, 'Edge-preserving smoothing'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(10, 4))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()


## Stage 4 – Local contrast boosting (CLAHE)

In [None]:
clahe_rgb, _, clahe_v = apply_on_value(guided_rgb, apply_clahe_channel)
panels = [
    (guided_rgb, 'Before CLAHE'),
    (clahe_rgb, 'After CLAHE'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(10, 4))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()


## Stage 5 – Retoning for consistent brightness

In [None]:
retone_rgb, _, retone_v = apply_on_value(clahe_rgb, retone_channel)
panels = [
    (clahe_rgb, 'Before retone'),
    (retone_rgb, 'After retone'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(10, 4))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()


## Preprocessing summary

In [None]:
# Side-by-side: untouched input vs. output of all five preprocessing stages.
panels = [
    (raw_rgb, 'Original frame'),
    (retone_rgb, 'Final preprocessed frame'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(12, 4))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()


## Dataset-style augmentation check

In [None]:
aug_rgb, aug_mask = horizontal_flip(retone_rgb, fallback_mask_u8)
# The mask is flipped together with the image, so the overlay should still line up.
panels = [
    (overlay_mask(retone_rgb, fallback_mask_u8), 'Preprocessed + mask'),
    (overlay_mask(aug_rgb, aug_mask), 'Horizontal flip (train mode)'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(12, 4))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()


## Training history (pre-computed run)

Stored training artefacts are reused here; the notebook does not rerun optimisation.


In [None]:
history_path = RESULT_ROOT / 'unet' / 'history.json'
summary_path = RESULT_ROOT / 'unet' / 'summary.json'
loss_curve_path = RESULT_ROOT / 'unet' / 'loss_curves.png'
metric_curve_path = RESULT_ROOT / 'unet' / 'metric_curves.png'

# JSON is UTF-8 by specification; name the encoding explicitly so decoding
# does not depend on the platform's locale default (e.g. cp1252 on Windows).
with open(history_path, encoding='utf-8') as f:
    history = json.load(f)
with open(summary_path, encoding='utf-8') as f:
    summary = json.load(f)

# presumably per-epoch records; the next cell reads the columns epoch,
# train_loss, val_loss, dice and iou — confirm against the training script.
history_df = pd.DataFrame(history)
summary


In [None]:
display(history_df.head())
fig, ax = plt.subplots(1, 2, figsize=(14, 4))
loss_ax, metric_ax = ax
epochs = history_df['epoch']

# Left: training vs. validation loss.
for column, label in (('train_loss', 'train'), ('val_loss', 'val')):
    loss_ax.plot(epochs, history_df[column], label=label)
loss_ax.set_title('BCE+Dice loss')
loss_ax.set_xlabel('Epoch')
loss_ax.legend()

# Right: overlap metrics.
for column, label in (('dice', 'Dice'), ('iou', 'IoU')):
    metric_ax.plot(epochs, history_df[column], label=label)
metric_ax.set_title('Segmentation metrics')
metric_ax.set_xlabel('Epoch')
metric_ax.legend()

plt.tight_layout()
plt.show()


In [None]:
loss_img, metric_img = (imageio.imread(p) for p in (loss_curve_path, metric_curve_path))
panels = [
    (loss_img, 'Saved loss curves'),
    (metric_img, 'Saved metric curves'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(14, 4))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()
print('These PNGs come from the previous training session; the notebook only reads them back.')


## Loss function snapshot

In [None]:
loss_fn = BCEDiceLoss(bce_weight=0.5)
# __doc__ is None when the class has no docstring; say so instead of printing "None".
print(loss_fn.__doc__ or 'BCEDiceLoss has no docstring.')
# A local seeded generator makes the printed example value reproducible across
# runs without resetting torch's global RNG state.
demo_gen = torch.Generator().manual_seed(0)
logits = torch.randn(1, 1, 4, 4, generator=demo_gen)
targets = torch.randint(0, 2, (1, 1, 4, 4), generator=demo_gen).float()
print(f'Example loss value: {loss_fn(logits, targets).item():.4f}')


## Model instantiation and checkpoint handling

In [None]:
model = UNet(in_ch=3, out_ch=1, base=32).to(device)
param_count = sum(p.numel() for p in model.parameters())
print(f'U-Net parameters: {param_count:,}')

checkpoint_path = RESULT_ROOT / 'unet' / 'best_unet.pt'
model_loaded = False
if checkpoint_path.exists():
    # torch.load unpickles the file; weights_only=True restricts it to tensors
    # and plain containers so a tampered checkpoint cannot run code
    # (requires torch >= 1.13). The result goes straight into load_state_dict,
    # so the file is expected to hold a bare state dict.
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model_loaded = True
    print(f'Loaded checkpoint: {checkpoint_path}')
else:
    print('Checkpoint not found in this repository snapshot. The demo keeps the randomly initialised weights.')
model.eval()


## Inference demo on the sample frame

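If the checkpoint file is absent, the prediction overlay falls back to the heuristic colour-threshold mask computed from the raw frame. This keeps the demonstration interpretable, but in that case the overlay does not show network output.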


In [None]:
proc_size = 256
src_h, src_w = retone_rgb.shape[:2]
zoom_y = proc_size / src_h
zoom_x = proc_size / src_w
# Bilinear (order=1) resize to proc_size x proc_size; the channel axis is untouched.
resized = ndimage.zoom(retone_rgb, (zoom_y, zoom_x, 1), order=1)
# HWC uint8 -> 1 x C x H x W float tensor in [0, 1].
input_tensor = (
    torch.from_numpy(resized.astype(np.float32) / 255.0)
    .permute(2, 0, 1)
    .unsqueeze(0)
    .to(device)
)
with torch.no_grad():
    logits = model(input_tensor)
    probs = torch.sigmoid(logits)[0, 0].cpu().numpy()
# Threshold at 0.5, then nearest-neighbour (order=0) upsample so the mask stays binary.
pred_mask = ndimage.zoom(
    (probs > 0.5).astype(np.uint8),
    (src_h / proc_size, src_w / proc_size),
    order=0,
)
if not model_loaded:
    print('Checkpoint missing: overlay uses the heuristic mask so that the visualisation stays meaningful.')
    pred_mask = fallback_mask.astype(np.uint8)
overlay_pred = overlay_mask(raw_rgb, (pred_mask * 255).astype(np.uint8))

panels = [
    (raw_rgb, 'Original frame'),
    (retone_rgb, 'Preprocessed input'),
    (overlay_pred, 'Prediction overlay'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(15, 4))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()


## Explainability artefacts (pre-computed Grad-CAM)

In [None]:
xai_dir = RESULT_ROOT / 'xai_visualizations'
cam_prefix = 'unet_cam_'
cam_roles = ('input', 'cam', 'overlay')

# Group files named unet_cam_<key>_<role>.png by <key>. Slicing off the known
# prefix/suffix (rather than str.replace) leaves keys that themselves contain
# '_cam', '_input' or '_overlay' intact.
cam_sets = {}
for role in cam_roles:
    suffix = f'_{role}'
    for path in xai_dir.glob(f'{cam_prefix}*{suffix}.png'):
        key = path.stem[len(cam_prefix):-len(suffix)]
        cam_sets.setdefault(key, {})[role] = path

# Only complete triplets can fill a row; a partial set used to raise KeyError.
complete_keys = sorted(k for k, trio in cam_sets.items() if all(r in trio for r in cam_roles))
keys = complete_keys[:3]
if not keys:
    # plt.subplots(0, 3) would raise; report instead. This also covers a
    # missing directory, for which glob yields nothing.
    print(f'No complete Grad-CAM input/cam/overlay triplets found under {xai_dir}.')
else:
    # squeeze=False keeps axes 2-D even for a single row.
    fig, axes = plt.subplots(len(keys), 3, figsize=(12, 4 * len(keys)), squeeze=False)
    for row, key in enumerate(keys):
        trio = cam_sets[key]
        titles = (f'Input #{key}', 'Grad-CAM heatmap', 'Heatmap overlay')
        for col, (role, title) in enumerate(zip(cam_roles, titles)):
            ax = axes[row, col]
            ax.imshow(imageio.imread(trio[role]))
            ax.set_title(title)
            ax.axis('off')
    plt.tight_layout()
    plt.show()
    print('These Grad-CAM renders were generated offline with the trained checkpoint and stored under notebooks/assets/results/xai_visualizations.')
