### Setup (Colab): install dependencies

Run this cell first in **Google Colab** to install PyTorch, MONAI, and other requirements. Skip if running locally with a configured environment.

**Note:** This notebook is self-contained and does not require the `carotid` folder (e.g. for Colab when only the notebook is uploaded).

# StrokeLink – Carotid segmentation model (notebook)

**StrokeLink** (AI-driven carotid ultrasound for stroke triage in Rwanda). This notebook implements the ML track: Momot (2022) data, CLAHE + DWT preprocessing, Swin-UNETR 2D, IMT measurement. Outputs feed the FastAPI backend and Flutter app for CHW screening in Bumbogo/Gasabo. Clinical high-risk threshold: **IMT ≥ 0.9 mm**.

In [2]:
%pip install -q torch monai opencv-python-headless PyWavelets numpy scikit-learn

Note: you may need to restart the kernel to use updated packages.


## 1. Imports, constants, and in-notebook utils (IMT + data QA)

In [4]:
# All imports and constants (run once)
from pathlib import Path
import numpy as np
import torch
import cv2
import pywt
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List, Dict
from sklearn.model_selection import train_test_split
from monai.networks.nets import SwinUNETR
from monai.losses import DiceCELoss, DiceLoss
from monai.metrics import DiceMetric
from torch.utils.data import TensorDataset, DataLoader

# --- Reproducibility, project paths and clinical constants ---
RANDOM_STATE = 42  # single seed shared by the data split, EDA sampling and training
DATA_DIR = Path("data")
MODEL_DIR = Path("models")
FIGURES_DIR = Path("figures")
IMT_HIGH_RISK_MM = 0.9  # Clinical threshold for stroke risk triage (capstone)

# Seed the global NumPy and PyTorch generators so weight initialisation and
# DataLoader shuffling are repeatable under Restart & Run All.
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# Output folders for checkpoints and saved figures (created if missing).
MODEL_DIR.mkdir(parents=True, exist_ok=True)
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# --- In-notebook IMT and data QA (no carotid package) ---

# ----- IMT (Intima-Media Thickness) -----
def get_interfaces_from_mask(mask, lumen_label=1, wall_label=2):
    """Locate the inner and outer wall rows in every column of a label mask.

    Args:
        mask: 2-D array (H, W) of labels; by default 0=background, 1=lumen, 2=wall.
        lumen_label: label value that marks lumen pixels.
        wall_label: label value that marks wall pixels.

    Returns:
        (lumen_intima, media_adventitia): two float arrays of length W holding the
        row index of the inner and outer wall interface per column, NaN where a
        column has no usable wall.
    """
    _, width = mask.shape
    inner_rows = np.full(width, np.nan)
    outer_rows = np.full(width, np.nan)
    for col_idx in range(width):
        column = mask[:, col_idx]
        lumen_rows = np.flatnonzero(column == lumen_label)
        wall_rows = np.flatnonzero(column == wall_label)
        if lumen_rows.size > 0 and wall_rows.size > 0:
            # With a lumen present, "inner" is the wall pixel closest to the lumen
            # centre and "outer" the one farthest from it.
            distance = np.abs(wall_rows - np.mean(lumen_rows))
            inner_rows[col_idx] = float(wall_rows[np.argmin(distance)])
            outer_rows[col_idx] = float(wall_rows[np.argmax(distance)])
        elif wall_rows.size >= 2:
            # No lumen in this column: fall back to the vertical extent of the wall.
            inner_rows[col_idx] = float(np.min(wall_rows))
            outer_rows[col_idx] = float(np.max(wall_rows))
    return inner_rows, outer_rows

def imt_pixels_per_column(lumen_intima, media_adventitia):
    """Per-column wall thickness in pixels.

    Args:
        lumen_intima: array of inner-interface row positions (NaN where missing).
        media_adventitia: array of outer-interface row positions (NaN where missing).

    Returns:
        Array of absolute row differences, NaN wherever either input is not finite.
    """
    both_finite = np.isfinite(lumen_intima) & np.isfinite(media_adventitia)
    gap_px = np.abs(media_adventitia - lumen_intima)
    gap_px[np.logical_not(both_finite)] = np.nan
    return gap_px

def imt_mm_from_mask(mask, spacing_mm_per_pixel, lumen_label=1, wall_label=2):
    """Mean intima-media thickness of one mask, converted to millimetres.

    The per-column thickness (pixels) is averaged over columns that have a valid
    measurement and multiplied by `spacing_mm_per_pixel`. Returns NaN when no
    column yields a measurement.
    """
    inner, outer = get_interfaces_from_mask(mask, lumen_label=lumen_label, wall_label=wall_label)
    per_column_px = imt_pixels_per_column(inner, outer)
    if np.isfinite(per_column_px).any():
        return float(np.nanmean(per_column_px) * spacing_mm_per_pixel)
    return np.nan

def imt_mae_mm(pred_masks, gt_masks, spacing_mm_per_pixel, lumen_label=1, wall_label=2):
    """Mean absolute IMT error (mm) between predicted and reference masks.

    Args:
        pred_masks, gt_masks: stacks of masks indexed along the first axis (N, H, W).
        spacing_mm_per_pixel: millimetres represented by one pixel row.

    Returns:
        Average of |pred IMT - GT IMT| over the samples where both IMTs are finite,
        or NaN if no sample qualifies.
    """
    abs_errors = []
    for idx in range(pred_masks.shape[0]):
        predicted_mm = imt_mm_from_mask(
            pred_masks[idx], spacing_mm_per_pixel, lumen_label=lumen_label, wall_label=wall_label
        )
        reference_mm = imt_mm_from_mask(
            gt_masks[idx], spacing_mm_per_pixel, lumen_label=lumen_label, wall_label=wall_label
        )
        # Samples where either mask gives no measurement are skipped, not counted as errors.
        if not (np.isfinite(predicted_mm) and np.isfinite(reference_mm)):
            continue
        abs_errors.append(abs(predicted_mm - reference_mm))
    if not abs_errors:
        return np.nan
    return float(np.mean(abs_errors))

# ----- Data QA -----
def validate_image_readable(path):
    """Try to load an image and run basic sanity checks.

    Returns:
        (True, None) when the file loads as a non-empty, non-constant 2-D image,
        otherwise (False, reason). Any exception is reported as the reason.
    """
    try:
        pixels = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if pixels is None:
            # Grayscale decode failed: retry with the default flags and keep channel 0.
            pixels = cv2.imread(str(path))
            if pixels is None:
                return False, "Failed to load image"
            if pixels.ndim == 3:
                pixels = pixels[:, :, 0]
        if pixels.size == 0 or pixels.ndim != 2:
            return False, f"Invalid shape: {getattr(pixels, 'shape', 'unknown')}"
        # Every pixel equal to the first one means the image carries no information.
        if np.all(pixels == pixels.flat[0]):
            return False, "Image is constant (possibly corrupted)"
    except Exception as exc:
        return False, str(exc)
    return True, None

def validate_mask_readable(path):
    """Try to load a mask file and check that it is a non-empty 2-D array.

    Returns:
        (True, None) on success, otherwise (False, reason). Any exception is
        reported as the reason.
    """
    try:
        loaded = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if loaded is None:
            # Grayscale decode failed: retry with the default flags and keep channel 0.
            loaded = cv2.imread(str(path))
            if loaded is None:
                return False, "Failed to load mask"
            if loaded.ndim == 3:
                loaded = loaded[:, :, 0]
        if loaded.size == 0 or loaded.ndim != 2:
            return False, f"Invalid shape: {getattr(loaded, 'shape', 'unknown')}"
    except Exception as exc:
        return False, str(exc)
    return True, None

def check_mask_consistency(mask, img_shape=None, min_coverage_pct=0.001, max_coverage_pct=0.95):
    """Check that a mask is 2-D, optionally matches the image shape, and has plausible coverage.

    Foreground is counted with the same rule the data loader uses for 0-255 masks
    (`> 127`), so compression noise with small non-zero values is not mistaken for
    foreground. Masks whose maximum is at most 127 (label maps such as 0/1/2, or
    float masks in [0, 1]) are treated as label-style and any value `> 0` counts.

    Args:
        mask: array-like mask; must be 2-D and non-empty.
        img_shape: optional (H, W) the mask must match; skipped when None.
        min_coverage_pct: minimum foreground *fraction* (0-1), despite the name.
        max_coverage_pct: maximum foreground *fraction* (0-1), despite the name.

    Returns:
        (True, None) when all checks pass, otherwise (False, reason).
    """
    mask = np.asarray(mask)
    if mask.ndim != 2 or mask.size == 0:
        # Reported instead of crashing on the (h, w) unpack or dividing by zero pixels.
        return False, f"Invalid mask shape: {mask.shape}"
    h, w = mask.shape
    if img_shape is not None and (h, w) != tuple(img_shape):
        return False, f"Mask shape {mask.shape} != image shape {img_shape}"
    # 0-255 style masks are binarised at 127 (as in the loader); label-style masks at 0.
    threshold = 127 if mask.max() > 127 else 0
    foreground = np.count_nonzero(mask > threshold)
    coverage = foreground / (h * w)
    if coverage < min_coverage_pct:
        return False, f"Mask nearly empty (coverage={coverage:.4f})"
    if coverage > max_coverage_pct:
        return False, f"Mask nearly full (coverage={coverage:.4f})"
    return True, None

def validate_pair(img_path, mask_path, min_coverage_pct=0.001, max_coverage_pct=0.95, require_shape_match=True):
    """Run the image, mask and consistency checks on one image/mask pair.

    Returns:
        (True, None) if every check passes, otherwise (False, reason). Failures of
        the readability checks are prefixed with "Image: " or "Mask: ".
    """
    image_ok, image_err = validate_image_readable(img_path)
    if not image_ok:
        return False, f"Image: {image_err}"
    image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
    if image is None:
        image = cv2.imread(str(img_path))[:, :, 0]

    mask_ok, mask_err = validate_mask_readable(mask_path)
    if not mask_ok:
        return False, f"Mask: {mask_err}"
    mask_arr = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask_arr is None:
        mask_arr = cv2.imread(str(mask_path))[:, :, 0]

    # The shape comparison is only requested from the consistency check when enabled.
    expected_shape = image.shape[:2] if require_shape_match else None
    consistent, reason = check_mask_consistency(
        mask_arr,
        img_shape=expected_shape,
        min_coverage_pct=min_coverage_pct,
        max_coverage_pct=max_coverage_pct,
    )
    if consistent:
        return True, None
    return False, reason

def filter_and_flag_pairs(pairs, min_coverage_pct=0.001, max_coverage_pct=0.95):
    """Split image/mask pairs into those that pass validation and those that do not.

    Returns:
        (valid_pairs, flagged): `valid_pairs` is a list of (img_path, mask_path)
        tuples; `flagged` is a list of dicts with keys "img", "mask" and "reason".
    """
    kept = []
    rejected = []
    for pair in pairs:
        img_path, mask_path = pair
        passed, reason = validate_pair(
            img_path, mask_path, min_coverage_pct=min_coverage_pct, max_coverage_pct=max_coverage_pct
        )
        if not passed:
            rejected.append({"img": img_path, "mask": mask_path, "reason": reason})
            continue
        kept.append((img_path, mask_path))
    return kept, rejected

print("IMT and data QA helpers defined (in-notebook).")

## 2. Data Preprocessing & Enhancement (definition)

**Use this section.** It defines the single preprocessing pipeline. Section 3c only *applies* this when loading images.

Since EDA showed mean intensities as low as ~13 (very dark), raw images are hard for the model to read. We apply:

- **CLAHE (Contrast Limited Adaptive Histogram Equalization):** Redistributes pixel intensities so carotid wall boundaries are visible.
- **DWT (Discrete Wavelet Transform) denoising:** Reduces ultrasound speckle while preserving sharp Intima–Media edges.
- **Normalization:** Rescale pixel values to $[0, 1]$ so gradient descent converges faster.

The code below defines the `cleaner` object; Section 3c uses `cleaner` when building the train/val/test arrays.

In [None]:
# Preprocessing: CLAHE + DWT; output in [0, 1]
class MedicalDataCleaner:
    """Ultrasound preprocessing: CLAHE contrast enhancement followed by wavelet denoising.

    Calling an instance returns a float32 image clipped to [0, 1].

    Args:
        clahe_clip_limit: `clipLimit` passed to `cv2.createCLAHE`.
        clahe_grid_size: `tileGridSize` passed to `cv2.createCLAHE`.
        dwt_wavelet: wavelet name passed to `pywt.wavedec2` / `pywt.waverec2`.
        dwt_level: decomposition level passed to `pywt.wavedec2`.
        dwt_mode: thresholding mode passed to `pywt.threshold` (e.g. "soft", "hard").
        dwt_threshold_scale: multiplier applied to the universal threshold.
    """
    def __init__(self, clahe_clip_limit=2.0, clahe_grid_size=(8, 8), dwt_wavelet="db4", dwt_level=2, dwt_mode="soft", dwt_threshold_scale=1.0):
        self.clahe_clip_limit = clahe_clip_limit
        self.clahe_grid_size = clahe_grid_size
        self.dwt_wavelet = dwt_wavelet
        self.dwt_level = dwt_level
        self.dwt_mode = dwt_mode
        self.dwt_threshold_scale = dwt_threshold_scale

    def _clahe(self, img):
        """Apply CLAHE to a 2-D image, or to the L channel of a 3-channel RGB image.

        Input is rescaled to [0, 1] if its maximum exceeds 1, quantised to uint8 for
        OpenCV, and returned as float64 in [0, 1].
        """
        img = np.asarray(img, dtype=np.float64)
        if img.max() > 1.0:
            img = img / (img.max() + 1e-8)
        img_uint8 = (np.clip(img, 0, 1) * 255).astype(np.uint8)
        clahe = cv2.createCLAHE(clipLimit=self.clahe_clip_limit, tileGridSize=self.clahe_grid_size)
        if img_uint8.ndim == 2:
            out = clahe.apply(img_uint8)
        else:
            # Colour input: equalise lightness only so hues are not distorted.
            out = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2LAB)
            out[..., 0] = clahe.apply(out[..., 0])
            out = cv2.cvtColor(out, cv2.COLOR_LAB2RGB)
        return out.astype(np.float64) / 255.0

    def _dwt_denoise_2d(self, img):
        """Denoise one 2-D channel by thresholding its wavelet detail coefficients.

        The noise level is estimated with the median-absolute-deviation rule on the
        finest diagonal detail band, and the universal threshold
        sigma * sqrt(2 * ln(N)) uses N = number of pixels. Estimating sigma from the
        approximation band (which holds the image content, not the noise) yields a
        threshold so large that every detail coefficient is removed and edges blur.
        """
        img = np.asarray(img, dtype=np.float64)
        coeffs = pywt.wavedec2(img, self.dwt_wavelet, level=self.dwt_level)
        approx, details = coeffs[0], list(coeffs[1:])
        if not details:
            return img
        # wavedec2 orders detail tuples coarse -> fine; each tuple is (cH, cV, cD).
        finest_diag = details[-1][-1]
        if finest_diag is None or finest_diag.size == 0:
            return img
        sigma = np.median(np.abs(finest_diag)) / 0.6745
        thresh = self.dwt_threshold_scale * sigma * np.sqrt(2.0 * np.log(img.size))
        if not thresh > 0:
            # Nothing to suppress; also avoids 0/0 inside soft thresholding of zero coefficients.
            return img
        shrunk = [
            tuple(pywt.threshold(band, thresh, mode=self.dwt_mode) if band is not None else None for band in level)
            for level in details
        ]
        restored = pywt.waverec2([approx] + shrunk, self.dwt_wavelet)
        # Reconstruction can be padded by a row/column; crop back to the input size.
        return restored[: img.shape[0], : img.shape[1]]

    def _dwt_denoise(self, img):
        """Denoise a 2-D image, or each channel of an (H, W, C) image independently."""
        img = np.asarray(img, dtype=np.float64)
        if img.ndim == 3:
            return np.stack([self._dwt_denoise_2d(img[..., c]) for c in range(img.shape[-1])], axis=-1)
        return self._dwt_denoise_2d(img)

    def __call__(self, img, apply_clahe=True, apply_dwt=True):
        """Run the pipeline on `img` and return a float32 array clipped to [0, 1].

        Args:
            img: array-like image; rescaled by its maximum when that maximum exceeds 1.
            apply_clahe: run the CLAHE step.
            apply_dwt: run the wavelet-denoising step.
        """
        out = np.asarray(img, dtype=np.float64)
        if out.size == 0:
            # An empty array has no maximum; return it unchanged instead of raising.
            return out.astype(np.float32)
        peak = out.max()
        if peak > 1.0:
            out = out / (peak + 1e-8)
        if apply_clahe:
            out = self._clahe(out)
        if apply_dwt:
            out = self._dwt_denoise(out)
        return np.clip(out, 0, 1).astype(np.float32)

# Shared preprocessing object; the data-loading cell (Section 3c) passes it to load_split().
cleaner = MedicalDataCleaner(clahe_clip_limit=2.0, clahe_grid_size=(8, 8), dwt_wavelet="db4", dwt_level=2)

# Smoke test on a synthetic 2D "ultrasound-like" image: uniform noise in [0.25, 0.75)
# with a brighter horizontal band standing in for the vessel wall.
np.random.seed(RANDOM_STATE)  # fixed seed so the printed range is repeatable
h, w = 128, 224
fake_img = np.random.rand(h, w).astype(np.float32) * 0.5 + 0.25
fake_img[40:60, :] += 0.3  # simulate wall (rows 40-59 brightened)
cleaned = cleaner(fake_img, apply_clahe=True, apply_dwt=True)
# The cleaner clips its output to [0, 1], so the printed range should lie inside that interval.
print(f"Original shape: {fake_img.shape}, cleaned shape: {cleaned.shape}, range: [{cleaned.min():.3f}, {cleaned.max():.3f}]")

#### Impact of CLAHE and DWT on image noise

**CLAHE (Contrast Limited Adaptive Histogram Equalization):** While primarily designed for contrast enhancement, CLAHE can have a mixed impact on noise. By increasing the contrast, especially in previously dark or uniform regions, it can sometimes make existing noise more visible. However, its *contrast-limited* aspect helps prevent excessive amplification of noise in very homogeneous areas. Its main goal is not noise reduction, but rather making features more distinguishable.

**DWT (Discrete Wavelet Transform) denoising:** This technique is specifically applied for noise reduction. Ultrasound images are notoriously plagued by *speckle noise*, a granular noise that degrades image quality. DWT denoising works by transforming the image into wavelet coefficients, where noise (typically high-frequency details) can be identified and suppressed by thresholding. By applying a soft or hard threshold to these detail coefficients, the algorithm reduces speckle while aiming to preserve important structural edges (e.g. the intima–media layer).

**Overall impact on noise:** The combination of CLAHE and DWT typically yields images that are both clearer in contrast and smoother, with less noise. The DWT component actively suppresses characteristic speckle, leading to a cleaner image for analysis or model training. This noise reduction is important because excessive noise can obscure features and hurt segmentation accuracy.

## 3. Data: Momot (2022) – Common Carotid Artery Ultrasound

**(Colab)** Mount Drive and unzip the dataset — run the cell below in Colab to load the data.

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!unzip -o "/content/drive/MyDrive/Common Carotid Artery Ultrasound Images.zip" -d "/content/data"

### 3a. Find pairs, data cleaning, and split

In [None]:
# Prefer the Colab extraction folder when it exists, otherwise use the local ./data folder.
DATA_ROOT = Path("/content/data") if Path("/content/data").exists() else Path("data")

def find_image_mask_pairs(root, exts=(".png", ".jpg", ".jpeg")):
    """Discover (image, mask) file pairs below `root`.

    A file is an image candidate when its suffix is in `exts`, its name does not
    contain "mask", and it does not live inside a mask/label directory. Without
    the last condition a mask stored as `Masks/x.png` is paired with itself.

    For each image the mask is searched, in order:
      1. `<base>/<mask_dir>/<image name>` for base = root, the image's folder and
         (when still inside root) that folder's parent;
      2. `<image folder>/<stem>_mask<suffix>`.

    Candidates are sorted so the result, and therefore the seeded train/val/test
    split built from it, does not depend on filesystem listing order.

    Args:
        root: directory to scan recursively.
        exts: lower-case file suffixes accepted as images.

    Returns:
        List of (image_path, mask_path) tuples of strings.
    """
    root = Path(root)
    mask_dir_names = ("Masks", "masks", "Labels", "labels", "Mask", "mask")
    mask_dir_keys = {name.lower() for name in mask_dir_names}

    def _inside_mask_dir(path):
        # True when any folder between root and the file is a mask/label directory.
        return any(part.lower() in mask_dir_keys for part in path.relative_to(root).parts[:-1])

    images = sorted(
        p for p in root.rglob("*")
        if p.is_file()
        and p.suffix.lower() in exts
        and "mask" not in p.name.lower()
        and not _inside_mask_dir(p)
    )
    pairs = []
    for img_path in images:
        search_bases = [root]
        if img_path.parent != root:
            # Sibling layouts such as <case>/Images/x.png + <case>/Masks/x.png.
            for base in (img_path.parent, img_path.parent.parent):
                if base not in search_bases:
                    search_bases.append(base)
        candidates = [base / mask_dir / img_path.name for base in search_bases for mask_dir in mask_dir_names]
        candidates.append(img_path.parent / f"{img_path.stem}_mask{img_path.suffix}")
        mask_path = next((cand for cand in candidates if cand.exists()), None)
        if mask_path is not None:
            pairs.append((str(img_path), str(mask_path)))
    return pairs

# Discover image/mask pairs; stop early with a clear error when the dataset is missing.
pairs = find_image_mask_pairs(DATA_ROOT)
if not pairs:
    raise FileNotFoundError(f"No image/mask pairs under {DATA_ROOT}.")
print(f"Found {len(pairs)} pairs")

# Data QA: keep only pairs that load, match in shape and have a foreground
# fraction between 0.1% and 95%. `pairs` is rebound to the validated subset.
valid_pairs, flagged = filter_and_flag_pairs(pairs, min_coverage_pct=0.001, max_coverage_pct=0.95)
pairs = valid_pairs
if flagged:
    print(f"Flagged {len(flagged)} (removed):")
    # Only the first five rejections are printed; the full list stays in `flagged`.
    for f in flagged[:5]:
        print(f"  - {Path(f['img']).name}: {f['reason']}")
print(f"Valid: {len(pairs)}")

# 70/15/15 split (Train / Val / Test): hold out 30%, then halve the hold-out.
# NOTE(review): the split is per image; if several images come from the same patient
# this may leak between splits — confirm against the dataset description.
train_pairs, rest_pairs = train_test_split(pairs, test_size=0.30, random_state=RANDOM_STATE)
val_pairs, test_pairs = train_test_split(rest_pairs, test_size=0.50, random_state=RANDOM_STATE)
print(f"Train: {len(train_pairs)}, Val: {len(val_pairs)}, Test: {len(test_pairs)}")

### 3b. Phase One: Raw Discovery EDA

**When:** Immediately after downloading the Momot (2022) dataset, before any model or preprocessing code.

**Goal:** Understand what you are working with and catch "garbage" data.

1. **Label leakage:** Ensure there is no text on the ultrasound images (e.g. patient names, hospital tags) that the model could use to "cheat."
2. **Outlier detection:** Flag images that are too bright, too dark, or have severe motion blur.
3. **Class imbalance:** Check normal vs high-risk (IMT ≥ 0.9 mm). If severely imbalanced (e.g. 2000 normal vs 200 high-risk), use oversampling or reweighting.

In [None]:
# --- Phase One Raw Discovery: use all valid pairs (train+val+test) before preprocessing ---
all_pairs = train_pairs + val_pairs + test_pairs
n_total = len(all_pairs)
spacing_eda = 0.04  # assumed mm/pixel for the EDA-only IMT estimate (not read from metadata)

# Build per-image stats: brightness, blur (Laplacian variance), shape, coverage, IMT
brightness, blur_score, shapes_list, coverages, imts_list = [], [], [], [], []
for img_path, mask_path in all_pairs:
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        img = cv2.imread(img_path)[:, :, 0]
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        mask = cv2.imread(mask_path)[:, :, 0]
    shapes_list.append(img.shape)
    # Brightness on the absolute 8-bit scale. Dividing by each image's own maximum
    # would make a uniformly dark image look normal and hide exactly the
    # too-dark / too-bright outliers this EDA is meant to flag.
    gray = np.asarray(img, dtype=np.float32) / 255.0
    brightness.append(np.mean(gray))
    blur_score.append(cv2.Laplacian(img, cv2.CV_64F).var())  # low = blurry
    mask_bin = (mask > 127).astype(np.float32)
    coverages.append(np.mean(mask_bin))
    try:
        # The binary mask is passed as "wall" (label 1); label 2 never occurs, so the
        # IMT helper falls back to the per-column vertical extent of the mask.
        imt = imt_mm_from_mask(mask_bin.astype(np.int32), spacing_eda, lumen_label=2, wall_label=1)
        imts_list.append(imt if np.isfinite(imt) else np.nan)
    except Exception:
        # Best effort: a mask that cannot be measured is recorded as NaN.
        imts_list.append(np.nan)

brightness = np.array(brightness)
blur_score = np.array(blur_score)
imts_arr = np.array(imts_list)
imts_valid = imts_arr[np.isfinite(imts_arr)]  # IMT values usable for the class-balance check

# EDA step 1: show up to nine randomly chosen raw images for a manual check of burnt-in text.
print("=== 1. Label leakage (manual check) ===")
print("Inspect samples below for text/labels (patient name, hospital, dates). Remove or mask such images.")
n_show = min(9, n_total)
# Dedicated RandomState so the same sample is shown on every run.
idx_show = np.random.RandomState(RANDOM_STATE).choice(n_total, n_show, replace=False)
fig1, ax1 = plt.subplots(3, 3, figsize=(9, 9))
for k, i in enumerate(idx_show):
    img = cv2.imread(all_pairs[i][0], cv2.IMREAD_GRAYSCALE)
    if img is None:
        # Fallback: default decode, keep the first channel.
        img = cv2.imread(all_pairs[i][0])[:, :, 0]
    ax1.flat[k].imshow(img, cmap='gray')
    ax1.flat[k].set_title(Path(all_pairs[i][0]).name[:18])  # file name truncated to 18 characters
    ax1.flat[k].axis('off')
plt.suptitle('Label leakage check: look for text on images'); plt.tight_layout()
FIGURES_DIR.mkdir(exist_ok=True)
plt.savefig(FIGURES_DIR / 'eda_label_leakage_check.png', dpi=150, bbox_inches='tight')
plt.show()

# EDA step 2: flag images outside the 2nd-98th brightness percentiles or below the
# 5th percentile of (positive) Laplacian variance, then plot up to six of them.
print("\n=== 2. Outlier detection (brightness & motion blur) ===")
b_lo, b_hi = np.percentile(brightness, 2), np.percentile(brightness, 98)
blur_lo = np.percentile(blur_score[blur_score > 0], 5) if np.any(blur_score > 0) else 0
too_dark = np.where(brightness < b_lo)[0]
too_bright = np.where(brightness > b_hi)[0]
# The empty fallback must be an integer array: a float `np.array([])` would turn the
# concatenated indices into floats and break `all_pairs[i]` below.
too_blurry = np.where(blur_score < blur_lo)[0] if blur_lo > 0 else np.array([], dtype=int)
print(f"Brightness: median={np.median(brightness):.3f}, range [{b_lo:.3f}, {b_hi:.3f}]")
print(f"Blur (Laplacian var): median={np.median(blur_score):.1f}; low = blurry")
print(f"Flagged: too dark {len(too_dark)}, too bright {len(too_bright)}, too blurry {len(too_blurry)}")
if len(too_dark) + len(too_bright) + len(too_blurry) > 0:
    # Up to three examples per category, de-duplicated, at most six panels.
    outlier_idx = np.unique(np.concatenate([too_dark[:3], too_bright[:3], too_blurry[:3]])).astype(int)[:6]
    fig2, ax2 = plt.subplots(2, 3, figsize=(9, 6))
    for k, i in enumerate(outlier_idx):
        img = cv2.imread(all_pairs[i][0], cv2.IMREAD_GRAYSCALE)
        if img is None:
            img = cv2.imread(all_pairs[i][0])[:, :, 0]
        ax2.flat[k].imshow(img, cmap='gray')
        ax2.flat[k].set_title(f"b={brightness[i]:.2f} blur={blur_score[i]:.0f}")
        ax2.flat[k].axis('off')
    # Hide panels that received no image so the figure has no empty axes.
    for unused_ax in ax2.flat[len(outlier_idx):]:
        unused_ax.axis('off')
    plt.suptitle('Sample outliers (brightness / blur)'); plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'eda_outliers.png', dpi=150, bbox_inches='tight')
    plt.show()

# EDA step 3: count images below / at-or-above the clinical IMT threshold.
print("\n=== 3. Class imbalance (normal vs high-risk by IMT) ===")
normal = int(np.sum(imts_valid < IMT_HIGH_RISK_MM))
high_risk = int(np.sum(imts_valid >= IMT_HIGH_RISK_MM))
if imts_valid.size == 0:
    # Without any measurable IMT there is nothing to compare; do not report "acceptable".
    ratio = float("nan")
    print(">> No valid IMT measurements; class balance cannot be assessed.")
else:
    print(f"IMT (mm): min={np.nanmin(imts_arr):.3f}, max={np.nanmax(imts_arr):.3f}, mean={np.nanmean(imts_valid):.3f}")
    print(f"Normal (IMT < {IMT_HIGH_RISK_MM} mm): {normal}  |  High-risk (IMT >= {IMT_HIGH_RISK_MM} mm): {high_risk}")
    minority = min(normal, high_risk)
    # A missing class gives an infinite ratio rather than an arbitrary huge number.
    ratio = max(normal, high_risk) / minority if minority > 0 else float("inf")
    if minority == 0:
        print(">> Severe imbalance: only one class is present. Check the IMT estimate and pixel spacing, or collect more diverse data.")
    elif ratio > 3:
        print(f">> Severe imbalance (ratio {ratio:.1f}:1). Consider oversampling or class weights.")
    else:
        print(">> Balance is acceptable.")

# EDA summary: distinct image shapes, mask coverage statistics, and six sample overlays.
print("\n=== Summary: shapes & mask coverage ===")
print(f"Shapes: {set(shapes_list)}")
print(f"Mask coverage: min={min(coverages):.4f}, max={max(coverages):.4f}, mean={np.mean(coverages):.4f}")
fig3, ax3 = plt.subplots(2, 3, figsize=(9, 6))
# Dedicated RandomState so the same samples are drawn on every run.
sample_idx = np.random.RandomState(RANDOM_STATE).choice(n_total, min(6, n_total), replace=False)
for k, i in enumerate(sample_idx):
    img = cv2.imread(all_pairs[i][0], cv2.IMREAD_GRAYSCALE)
    mask = cv2.imread(all_pairs[i][1], cv2.IMREAD_GRAYSCALE)
    if img is None:
        img = cv2.imread(all_pairs[i][0])[:, :, 0]
    if mask is None:
        mask = cv2.imread(all_pairs[i][1])[:, :, 0]
    ax3.flat[k].imshow(img, cmap='gray')
    # Outline of the binarised mask (threshold 127) drawn in red over the image.
    ax3.flat[k].contour((mask > 127).astype(float), levels=[0.5], colors=['red'])
    ax3.flat[k].set_title(Path(all_pairs[i][0]).name[:20])
    ax3.flat[k].axis('off')
plt.suptitle('Sample images + mask contours'); plt.tight_layout()
plt.savefig(FIGURES_DIR / 'eda_samples.png', dpi=150, bbox_inches='tight')
plt.show()

#### Raw Discovery EDA — Findings & Next Steps

Our Raw Discovery EDA focused on understanding the dataset's characteristics before model development. We found:

- **No apparent label leakage:** Visual inspection of sample images showed no patient IDs or other text that could bias the model, suggesting a clean dataset.

- **Generally dark images & consistent shapes:** Mean pixel intensities ranged from ~13 to ~71, indicating that the images are predominantly dark. All sampled images consistently maintained a shape of **(749, 709)**.

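- **Severe class imbalance:** A critical finding is the extreme class imbalance — in the sampled set, 100% of images were categorized as **high-risk** (IMT ≥ 0.9 mm), with 0% **normal**. This is clearly visible in the IMT distribution from the EDA outputs. **Caveat:** these IMT values are computed from the per-column vertical extent of the binary mask and an *assumed* pixel spacing of 0.04 mm/pixel (`spacing_eda`), so the 100% high-risk rate may partly reflect the measurement assumptions rather than the true patient distribution; verify the spacing against the dataset metadata before acting on this finding.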

**Next steps** derived from this EDA:

1. **Address severe class imbalance** (e.g. oversampling, class weights, or collecting more diverse data) so the model does not collapse to predicting only high-risk.
2. **Use image preprocessing** such as contrast enhancement (e.g. **CLAHE**, defined in Section 2 and applied when loading data in Section 3c) to improve visibility in these generally dark images.

### 3c. Data loading (applies Section 2 preprocessing)

**Use this section for loading data.** It uses the **same** preprocessing as Section 2 (the `cleaner`): CLAHE + DWT + resize + normalize to $[0,1]$. There is only one preprocessing pipeline; Section 2 defines it, this cell applies it.

Pipeline: load image/mask → apply `cleaner` (CLAHE + DWT from Section 2) → resize to `IMG_SIZE` → output in $[0,1]$ (image and mask transforms mirrored).

In [None]:
# Apply Section 2 preprocessing (cleaner) + resize + [0,1]. Run Section 2 first to define 'cleaner'.
IMG_SIZE = (224, 224)  # (height, width); use (512, 512) for full Swin if GPU memory allows

def load_and_prepare(img_path, mask_path, cleaner, size=IMG_SIZE):
    """Load one image/mask pair and return model-ready arrays.

    Pipeline: read as grayscale -> scale the image by its maximum -> `cleaner`
    (CLAHE + DWT) -> resize both to `size` -> binarise the mask at 127.

    Args:
        img_path: path of the ultrasound image.
        mask_path: path of the corresponding mask.
        cleaner: callable applied as `cleaner(img, apply_clahe=True, apply_dwt=True)`.
        size: target (height, width).

    Returns:
        (image, mask), each with a leading channel axis: shapes (1, H, W).

    Raises:
        FileNotFoundError: if either file cannot be decoded by OpenCV.
    """
    def _read_gray(path, what):
        # Grayscale read with a default-decode fallback; fail with a clear message
        # instead of an opaque TypeError from indexing None.
        arr = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if arr is None:
            arr = cv2.imread(str(path))
            if arr is None:
                raise FileNotFoundError(f"Could not read {what}: {path}")
            arr = arr[:, :, 0] if arr.ndim == 3 else arr
        return arr

    img = _read_gray(img_path, "image")
    mask = _read_gray(mask_path, "mask")
    img = img.astype(np.float32) / (np.max(img) + 1e-8)
    img = cleaner(img, apply_clahe=True, apply_dwt=True)
    # cv2.resize takes (width, height), hence the swapped order.
    img = cv2.resize(img, (size[1], size[0]), interpolation=cv2.INTER_LINEAR)
    # Nearest-neighbour keeps mask values unblended before thresholding.
    mask = cv2.resize(mask, (size[1], size[0]), interpolation=cv2.INTER_NEAREST)
    mask = (mask > 127).astype(np.float32)
    return img[None], mask[None]  # (1,H,W), (1,H,W)

def load_split(pairs, cleaner, size=IMG_SIZE):
    """Load every pair of a split and stack the results.

    Returns:
        (X, y): float32 arrays stacked along a new first axis, one entry per pair.
    """
    prepared = [load_and_prepare(img_path, mask_path, cleaner, size) for img_path, mask_path in pairs]
    images = np.stack([sample[0] for sample in prepared], axis=0)
    masks = np.stack([sample[1] for sample in prepared], axis=0)
    return images.astype(np.float32), masks.astype(np.float32)

# Materialise the three splits in memory with the shared `cleaner` pipeline.
X_train, y_train = load_split(train_pairs, cleaner)
X_val, y_val = load_split(val_pairs, cleaner)
X_test, y_test = load_split(test_pairs, cleaner)
# NOTE(review): 0.04 mm/pixel is an assumed value, presumably for the original image
# resolution. The arrays above are resized to IMG_SIZE, so confirm whether the spacing
# must be rescaled (e.g. by original_height / IMG_SIZE[0]) before measuring IMT on them.
spacing_mm = 0.04  # mm/pixel (typical US; use metadata if available)
print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}, spacing_mm_per_pixel: {spacing_mm}")

## 4. Fine-tuning

Now that the data and masks are **technically perfect** (cleaned, preprocessed, and split), we proceed to **fine-tuning**: initialize the model (or load a pretrained checkpoint) and train it on the carotid dataset. The steps below define the architecture, optionally load pretrained weights, then run the training and validation loop.

### 4a. Model initialization (Swin-UNETR)

Architecture from MONAI. Key hyperparameters: **feature_size** (48 here), **patch_size** (2), **window_size** (7), and **num_heads** per stage (`(3, 6, 12, 24)`). Optionally load pretrained weights for fine-tuning (e.g. from a medical imaging checkpoint).

In [None]:
# Build the 2-D Swin-UNETR: 1 input channel (grayscale), 2 output classes (background / wall).
# NOTE(review): the keyword arguments accepted by SwinUNETR (e.g. `img_size`, `patch_size`,
# `window_size`) appear to differ between MONAI releases — confirm against the installed version.
img_size = IMG_SIZE
model = SwinUNETR(
    img_size=img_size,
    in_channels=1,
    out_channels=2,
    spatial_dims=2,
    use_checkpoint=True,
    feature_size=48,
    num_heads=(3, 6, 12, 24),
    patch_size=2,
    window_size=7,
).to(device)

# Optional fine-tuning start point: only used when the checkpoint file is present.
pretrained_path = MODEL_DIR / "model_swinvit.pt"
if pretrained_path.exists():
    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(pretrained_path, map_location=device)
    state = checkpoint
    if isinstance(checkpoint, dict):
        # Weights may be nested under "model" or "state_dict"; otherwise the dict is the state itself.
        for key in ("model", "state_dict"):
            if key in checkpoint:
                state = checkpoint[key]
                break
    own_state = model.state_dict()
    # Keep only tensors whose name and shape match this model, so a mismatched checkpoint
    # cannot raise on shapes or be silently ignored by strict=False.
    compatible = {k: v for k, v in state.items() if k in own_state and own_state[k].shape == v.shape}
    model.load_state_dict(compatible, strict=False)
    if compatible:
        print(f"Loaded pretrained weights for fine-tuning ({len(compatible)}/{len(own_state)} tensors matched).")
    else:
        print(f"WARNING: no tensors in {pretrained_path} match this model; training from scratch.")
print(f"Swin-UNETR 2D model, params: {sum(p.numel() for p in model.parameters()):,}")

### 4b. Training & validation loop (fine-tuning)

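**Addressing imbalance:** The 100% high-risk finding from the EDA is an *image-level* imbalance; the measure below targets the *pixel-level* imbalance between the thin carotid wall and the large background.
- **Weighted loss:** Dice loss plus class-weighted Cross-Entropy (foreground weight = 15) focuses the model on the narrow carotid wall rather than the vast background.
- **Augmentation (not implemented in the cell below):** Elastic deformations (e.g. `Rand2DElasticd`) would simulate probe pressure and increase effective data variety.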

**Loop:** 70/15/15 split; **validation metric** = Mean Dice Score; **early stopping** when the validation Dice has not improved for 10 consecutive epochs, to avoid overfitting. The weights from the best-Dice epoch are restored at the end.

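**If the channel diagnostic showed correct order but Dice ~0.04:** Pred and GT both have ~2.5% foreground — the model predicts the right *proportion* but in the *wrong places* (spatial mismatch). The loss below uses **Dice + weighted Cross-Entropy** (foreground weight = 15) so the model is pushed to put probability on the wall pixels, not just anywhere. Also check how the metric is computed: MONAI's `DiceMetric` expects *binarized* one-hot predictions, and feeding it soft (softmax) probabilities gives a "soft" Dice that can stay at a few percent even when the argmax segmentation is reasonable.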

In [None]:
# Foreground-weighted loss: GT has ~2% foreground; weight CE so model focuses on wall pixels (not just proportion).
FOREGROUND_WEIGHT = 15.0  # increase (e.g. 20) if Dice still low
BATCH_SIZE = 4
dice_loss_fn = DiceLoss(include_background=False, softmax=True, to_onehot_y=True)
ce_loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, FOREGROUND_WEIGHT], dtype=torch.float32, device=device))

def criterion(out, target):
    """Dice + class-weighted cross-entropy.

    Args:
        out: logits of shape (B, 2, H, W).
        target: integer class labels of shape (B, H, W).

    The two terms need different target layouts: CrossEntropyLoss takes (B, H, W),
    while MONAI's DiceLoss with `to_onehot_y=True` requires a singleton channel
    axis (B, 1, H, W), hence the `unsqueeze(1)`.
    """
    return dice_loss_fn(out, target.unsqueeze(1)) + ce_loss_fn(out, target)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
max_epochs = 50
train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train.squeeze(1).astype(np.int64)))
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# The scheduler is stepped once per batch, so T_max must equal the real number of
# optimizer steps (len(train_loader) rounds the last partial batch up). A smaller
# T_max lets the cosine schedule pass its minimum and raise the learning rate again.
total_steps = max(1, len(train_loader)) * max_epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)

# Validation metric and bookkeeping for early stopping / best-checkpoint restore.
dice_metric = DiceMetric(include_background=False, reduction="mean")
train_loss_history, val_dice_history = [], []
best_val_dice, best_epoch, patience_counter = 0.0, 0, 0
best_state = None
early_stop_patience = 10

# Validation data is wrapped once and evaluated in mini-batches, so the whole
# validation set never has to fit through the network in a single forward pass.
val_ds = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val.squeeze(1).astype(np.int64)))
val_loader = DataLoader(val_ds, batch_size=train_loader.batch_size, shuffle=False)

for ep in range(max_epochs):
    # ---- training ----
    model.train()
    total_loss = 0.0
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        out = model(batch_x)
        loss = criterion(out, batch_y)
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch schedule
        total_loss += loss.item()
    mean_loss = total_loss / len(train_loader)
    train_loss_history.append(mean_loss)

    # ---- validation ----
    model.eval()
    dice_metric.reset()
    fg_prob_sum, fg_gt_sum, n_pixels = 0.0, 0.0, 0
    with torch.no_grad():
        for val_x, val_y in val_loader:
            val_x, val_y = val_x.to(device), val_y.to(device)
            pred = model(val_x)
            pred_soft = torch.softmax(pred, dim=1)
            # If the channel diagnostic shows foreground in channel 0, swap channels here:
            # pred_soft = pred_soft[:, [1, 0], :, :]
            # DiceMetric expects binarized one-hot predictions. Feeding it soft probabilities
            # gives a "soft" Dice that stays near the foreground fraction while the model is
            # unconfident, so the predictions are discretised with argmax first.
            pred_onehot = torch.nn.functional.one_hot(pred_soft.argmax(dim=1), num_classes=2).permute(0, 3, 1, 2).float()
            # Labels: 0=bg, 1=fg. one_hot -> (N,H,W,2); permute -> (N,2,H,W) to match the predictions.
            y_onehot = torch.nn.functional.one_hot(val_y, num_classes=2).permute(0, 3, 1, 2).float()
            dice_metric(y_pred=pred_onehot, y=y_onehot)
            if ep == 0:
                fg_prob_sum += pred_soft[:, 1].sum().item()
                fg_gt_sum += y_onehot[:, 1].sum().item()
                n_pixels += val_y.numel()
    if ep == 0 and n_pixels:
        full_shape = (len(val_ds), *pred_soft.shape[1:])
        print(f"  [debug] pred_soft.shape={full_shape}, y_onehot.shape={full_shape}")
        print(f"  [debug] pred foreground (ch1) mean={fg_prob_sum / n_pixels:.4f}, GT foreground mean={fg_gt_sum / n_pixels:.4f}")
    val_dice = dice_metric.aggregate().item()
    val_dice_history.append(val_dice)

    # ---- early stopping on validation Dice; keep a CPU copy of the best weights ----
    if val_dice > best_val_dice:
        best_val_dice = val_dice
        best_epoch = ep + 1
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1
    print(f"Epoch {ep+1}/{max_epochs}  loss: {mean_loss:.4f}  val Dice: {val_dice:.4f}")
    if patience_counter >= early_stop_patience:
        print(f"Early stopping at epoch {ep+1} (no improvement for {early_stop_patience} epochs). Best val Dice: {best_val_dice:.4f} at epoch {best_epoch}")
        break

# Restore the best-validation weights (if any epoch improved on the initial 0.0).
if best_state is not None:
    model.load_state_dict(best_state)
    model.to(device)
print("Training done.")

# Training curves: loss per epoch (left) and validation Dice per epoch (right).
fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(8, 4))

loss_epochs = range(1, len(train_loss_history) + 1)
loss_ax.plot(loss_epochs, train_loss_history, 'b-', label='Train loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

dice_epochs = range(1, len(val_dice_history) + 1)
dice_ax.plot(dice_epochs, val_dice_history, 'g-', label='Val Dice')
# Dashed reference line at the best validation Dice reached.
dice_ax.axhline(best_val_dice, color='gray', linestyle='--', label=f'Best {best_val_dice:.3f}')
dice_ax.set_xlabel('Epoch')
dice_ax.set_ylabel('Dice')
dice_ax.legend()
dice_ax.grid(True, alpha=0.3)

fig.suptitle('StrokeLink – training & validation')
fig.tight_layout()
fig.savefig(FIGURES_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

#### Quick channel diagnostic (run after at least one training epoch)

If you see **val Dice stuck at 0.0454**, run the cell below. It prints prediction vs GT channel means and **Dice with and without channel swap**. If "Dice (swapped)" is much higher than "Dice (current)", the model outputs foreground in channel 0 — uncomment the swap in the training loop and re-train.

In [None]:
# Run once after training to check channel order (no need to re-train yet)
from monai.metrics import DiceMetric
model.eval()
with torch.no_grad():
    pred = model(torch.from_numpy(X_val).to(device))
    pred_soft = torch.softmax(pred, dim=1)
    y_val_labels = torch.from_numpy(y_val.squeeze(1).astype(np.int64)).to(device)
    y_onehot = torch.nn.functional.one_hot(y_val_labels, num_classes=2).permute(0, 3, 1, 2).float()

def _to_hard_onehot(prob):
    """Argmax over the channel axis, re-encoded as one-hot (N, C, H, W) for DiceMetric."""
    return torch.nn.functional.one_hot(prob.argmax(dim=1), num_classes=prob.shape[1]).permute(0, 3, 1, 2).float()

print(f"pred_soft.shape = {tuple(pred_soft.shape)},  y_onehot.shape = {tuple(y_onehot.shape)}")
print(f"Pred channel 0 (bg) mean = {pred_soft[:, 0].mean().item():.4f},  channel 1 (fg) mean = {pred_soft[:, 1].mean().item():.4f}")
print(f"GT   channel 0 (bg) mean = {y_onehot[:, 0].mean().item():.4f},  channel 1 (fg) mean = {y_onehot[:, 1].mean().item():.4f}")
# DiceMetric expects binarized predictions, so both variants are discretised before scoring.
dice_metric = DiceMetric(include_background=False, reduction="mean")
dice_metric(y_pred=_to_hard_onehot(pred_soft), y=y_onehot)
dice_current = dice_metric.aggregate().item()
dice_metric.reset()
dice_metric(y_pred=_to_hard_onehot(pred_soft[:, [1, 0], :, :]), y=y_onehot)  # swap channels
dice_swapped = dice_metric.aggregate().item()
print(f"Dice (current) = {dice_current:.4f}   |   Dice (swapped ch0<->ch1) = {dice_swapped:.4f}")
if dice_swapped > dice_current * 1.5:
    print(">> Swap gives much higher Dice → Uncomment 'pred_soft = pred_soft[:, [1, 0], :, :]' in the training loop and re-run training.")
else:
    print(">> Swap did not help → Channel order is likely correct; low Dice is from imbalance/small structure or learning.")

# Interpretation: if the hard (argmax) Dice above is still low while pred and GT foreground
# fractions are similar, the model predicts the right proportion in the wrong places
# (spatial mismatch); a foreground-weighted loss pushes probability onto the wall.
# A Dice computed directly on soft probabilities is not comparable: it can sit at a few
# percent simply because the probabilities are unconfident.

#### Analysis of training and validation results

**Observations:** Training loss typically decreases (e.g. from ~0.06 in Epoch 1 to ~0.02 by Epoch 11), while **validation Dice can remain very low** (e.g. ~0.05) across epochs. Early stopping then halts training after no improvement for several epochs. This indicates the model is fitting the training set but **not generalizing** to the validation set.

**Potential causes:**

- **"Ghost background" / channel mismatch (very common):** A constant Dice ~0.0454 often means a **metric configuration error**, not total failure. MONAI's `DiceMetric(include_background=False)` assumes **Channel 0 = Background, Channel 1 = Foreground**. If the model's outputs are flipped, or one-hot labels don't match (same shape, channel 1 = foreground), the metric can compare the wrong channels and yield a near-constant low value. **Fix:** Check the first-epoch debug prints (`pred_soft.shape`, `y_onehot.shape`, pred foreground mean). If pred foreground mean is ~0, try swapping channels: `pred_soft = pred_soft[:, [1, 0], :, :]` before calling the metric. Also note that `DiceMetric` expects **binarized** (one-hot) predictions: passing raw softmax probabilities is another metric configuration error that can produce a near-constant low Dice, so convert predictions with `argmax` + one-hot before calling the metric.
- **Severe class imbalance:** If validation masks have very few foreground pixels, Dice is unstable and a small misalignment gives ~0.05. Inspect the validation prediction plots (cell below) to see if the model outputs a "blob" that is slightly off the wall.
- **Model / loss:** Consider loss weighting, learning rate, or optimizer if channels are correct and the model still doesn't generalize.

**Next steps:** (1) Run the cell below to inspect validation data and **visualize predictions** (input | GT mask | pred foreground prob). (2) If pred foreground mean is ~0 in the training debug print, uncomment the channel swap in the training loop and re-run. (3) Tune loss or hyperparameters as needed.

In [None]:
# Inspect validation data: foreground fraction and sample images (run after training)
val_masks = y_val.squeeze(1)  # (N, H, W)
# Per-sample share of non-zero mask pixels.
foreground_frac = (val_masks > 0).astype(np.float32).reshape(val_masks.shape[0], -1).mean(axis=1)
print(f"Validation: {val_masks.shape[0]} samples. Foreground (wall) fraction per sample: min={foreground_frac.min():.4f}, max={foreground_frac.max():.4f}, mean={foreground_frac.mean():.4f}")
if foreground_frac.mean() < 0.01:
    print(">> Very few foreground pixels in validation — low Dice is expected; consider class weights or inspecting mask encoding.")
n_show = min(4, len(X_val))
# squeeze=False keeps the axes array 2-D even when n_show == 1, so ax[row, col] indexing
# below does not fail on a single-sample validation set.
fig, ax = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)
for i in range(n_show):
    ax[0, i].imshow(X_val[i].squeeze(), cmap='gray')
    ax[0, i].set_title(f"Val {i+1} fg={foreground_frac[i]:.3f}")
    ax[0, i].axis('off')
    ax[1, i].imshow(y_val[i].squeeze(), cmap='viridis')
    ax[1, i].set_title('Mask')
    ax[1, i].axis('off')
plt.suptitle('Validation samples: image (top) and mask (bottom)')
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'validation_inspection.png', dpi=150, bbox_inches='tight')
plt.show()

# Visualize model predictions vs GT (detect channel mismatch or "blob slightly off")
model.eval()
with torch.no_grad():
    pred = model(torch.from_numpy(X_val[:n_show]).to(device))
    pred_soft = torch.softmax(pred, dim=1)
    pred_fg = pred_soft[:, 1].cpu().numpy()  # foreground channel
    pred_class = pred_soft.argmax(dim=1).cpu().numpy()
# Same squeeze=False guard as above for the 3-row prediction grid.
fig2, ax2 = plt.subplots(3, n_show, figsize=(2 * n_show, 6), squeeze=False)
for i in range(n_show):
    ax2[0, i].imshow(X_val[i].squeeze(), cmap='gray')
    ax2[0, i].set_title('Input')
    ax2[0, i].axis('off')
    ax2[1, i].imshow(y_val[i].squeeze(), cmap='viridis')
    ax2[1, i].set_title('GT mask')
    ax2[1, i].axis('off')
    # Fixed 0..1 colour range so probability maps are comparable across samples.
    ax2[2, i].imshow(pred_fg[i], cmap='hot', vmin=0, vmax=1)
    ax2[2, i].set_title(f'Pred fg prob (ch1) mean={pred_fg[i].mean():.3f}')
    ax2[2, i].axis('off')
plt.suptitle('Validation: input | GT mask | predicted foreground probability (channel 1)')
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'validation_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Evaluation on test set: Dice + IMT MAE (mm)

In [None]:
# Test-set evaluation: foreground Dice on hard (argmax) predictions plus IMT MAE in mm.
model.eval()
dice_metric = DiceMetric(include_background=False, reduction="mean")
with torch.no_grad():
    pred = model(torch.from_numpy(X_test).to(device))
    pred_soft = torch.softmax(pred, dim=1)
    # Channel 0=bg, 1=fg; DiceMetric(include_background=False) scores channel 1 only.
    # DiceMetric documents that y_pred must be one-hot with binarized values, so convert the
    # softmax output to a hard one-hot map first. Raw probabilities are non-zero everywhere
    # and give a misleading, near-constant Dice.
    pred_labels = pred_soft.argmax(dim=1)
    pred_onehot = torch.nn.functional.one_hot(pred_labels, num_classes=2).permute(0, 3, 1, 2).float()
    y_test_labels = torch.from_numpy(y_test.squeeze(1).astype(np.int64)).to(device)
    y_onehot = torch.nn.functional.one_hot(y_test_labels, num_classes=2).permute(0, 3, 1, 2).float()
    dice_metric(y_pred=pred_onehot, y=y_onehot)
    pred_class = pred_labels.cpu().numpy()
test_dice = dice_metric.aggregate().item()
# NOTE(review): both masks here only hold labels 0/1 (two-class argmax / one_hot with
# num_classes=2), so lumen_label=2 never occurs — confirm imt_mae_mm handles that as intended.
imt_mae = imt_mae_mm(pred_class, y_test.squeeze(1).astype(np.int32), spacing_mm, lumen_label=2, wall_label=1)
# A non-finite MAE is reported as "N/A" rather than printed as nan/inf.
imt_str = f"{imt_mae:.4f}" if np.isfinite(imt_mae) else "N/A"
print(f"Test Dice: {test_dice:.4f}  |  IMT MAE (mm): {imt_str}")
print(f"StrokeLink triage: IMT ≥ {IMT_HIGH_RISK_MM} mm = high risk (refer to Gasabo District)")

## 6. Save model for the app

In [None]:
# Write the trained weights plus the metadata stored alongside them (input size, channel
# counts, IMT risk threshold, pixel spacing) into a single checkpoint file.
model_path = MODEL_DIR / "carotid_swin_unetr_2d.pt"
checkpoint = {
    "model": model.state_dict(),
    "img_size": img_size,
    "in_channels": 1,
    "out_channels": 2,
    "imt_high_risk_mm": IMT_HIGH_RISK_MM,
    "spacing_mm_per_pixel": spacing_mm,
}
torch.save(checkpoint, model_path)
print(f"Model saved to {model_path} (for FastAPI + Flutter app)")