# Checkpoint and Image Inference UI

This notebook lets you choose a model checkpoint (`.pt`) and an image, run inference, and visualize detections/masks.

You can provide each file in one of two ways:
- a text path (for `best.pt` or any other local file), or
- the upload widget (an uploaded file takes precedence over the path).


In [None]:
import io
import math
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
from PIL import Image

from apex_x.config import ApexXConfig
from apex_x.model import (
    TeacherModel,
    TeacherModelV3,
    PVModule,
    DualPathFPN,
    DetHead,
    TimmBackboneAdapter,
    post_process_detections_per_class,
)
from apex_x.model.image_enhancer import LearnableImageEnhancer
from apex_x.model.pv_dinov2 import PVModuleDINOv2
from apex_x.train.checkpoint import extract_model_state_dict, is_tensor_state_dict, safe_torch_load
from apex_x.model.worldclass_deps import missing_worldclass_dependencies, worldclass_install_hint


def _robust_reshape_vit_output(self, x, image_hw=None):
    """Reshape ViT tokens ``[B, N_tokens, D]`` into a feature map ``[B, D, H, W]``.

    Leading special tokens are stripped first. The number to strip is chosen
    from, in order: 1, ``1 + num_register_tokens`` (only when the DINOv2 config
    reports registers) and 0. When ``image_hw`` is given, the first option that
    leaves exactly ``(H // patch_size) * (W // patch_size)`` tokens wins;
    otherwise, or if none matches, the first option leaving at least one token
    is used.

    The remaining tokens are factored into the ``rows x cols`` grid whose aspect
    ratio is closest to ``image_hw`` (the most square grid when ``image_hw`` is
    missing or its width is not positive).

    Raises:
        ValueError: if ``x`` is not 3-D or no valid token grid can be derived.
    """
    if x.ndim != 3:
        raise ValueError('ViT hidden-state must be [B, N_tokens, D]')

    batch, total_tokens, channels = x.shape
    patch = max(1, int(getattr(self, 'patch_size', 14)))

    # Any failure to read the register count is treated as "no registers".
    try:
        registers = int(getattr(self.dinov2.config, 'num_register_tokens', 0) or 0)
    except Exception:
        registers = 0

    prefix_options = [1, 1 + registers, 0] if registers > 0 else [1, 0]

    prefix = None
    if image_hw is not None:
        rows_expected = max(1, int(image_hw[0]) // patch)
        cols_expected = max(1, int(image_hw[1]) // patch)
        wanted = rows_expected * cols_expected
        prefix = next((p for p in prefix_options if total_tokens - p == wanted), None)
    if prefix is None:
        # No exact match (or no image size): take the first option that leaves tokens.
        prefix = next((p for p in prefix_options if total_tokens - p > 0), None)
    if prefix is None:
        raise ValueError(f'Invalid token count {total_tokens}: cannot strip special tokens')

    x = x[:, prefix:, :]
    num_tokens = int(x.shape[1])

    target_ratio = None
    if image_hw is not None and int(image_hw[1]) > 0:
        target_ratio = float(image_hw[0]) / float(image_hw[1])

    def _mismatch(rows, cols):
        # Lower is better: squareness without a target, aspect error with one.
        if target_ratio is None:
            return abs(rows - cols)
        return abs((float(rows) / float(max(1, cols))) - target_ratio)

    best = None
    for short_side in range(int(math.isqrt(num_tokens)), 0, -1):
        long_side, remainder = divmod(num_tokens, short_side)
        if remainder:
            continue
        # Try both orientations of every factor pair.
        for rows, cols in ((short_side, long_side), (long_side, short_side)):
            penalty = _mismatch(rows, cols)
            if best is None or penalty < best[0]:
                best = (penalty, rows, cols)

    if best is None:
        raise ValueError(f'Cannot factor token grid for num_tokens={num_tokens}')

    _, h_patches, w_patches = best
    if h_patches * w_patches != num_tokens:
        raise ValueError(
            f'Cannot reshape {num_tokens} tokens into a 2-D grid '
            f'({h_patches}x{w_patches} != {num_tokens})'
        )

    grid = x.reshape(batch, h_patches, w_patches, channels)
    return grid.permute(0, 3, 1, 2)


def _patch_pvmodule_dinov2_rect_tokens():
    """Replace ``PVModuleDINOv2._reshape_vit_output`` with the robust version above.

    Runtime compatibility patch for environments with older pv_dinov2 reshape logic.
    The assignment is class-level, so every instance picks it up.
    """
    patched_cls = PVModuleDINOv2
    patched_cls._reshape_vit_output = _robust_reshape_vit_output


# Install the patched reshape as soon as this cell runs, before any model is built.
_patch_pvmodule_dinov2_rect_tokens()


# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Report whatever the project helper lists as missing, followed by its install hint.
missing_deps = missing_worldclass_dependencies()
if missing_deps:
    print(f"Worldclass deps missing: {missing_deps}")
    print(worldclass_install_hint())


def _extract_uploaded_file(upload_widget):
    """Return ``(name, content_bytes)`` for the first file held by an upload widget.

    Two ``value`` layouts are supported:

    * ipywidgets < 8: ``{filename: {'content': ...}}`` (or ``{filename: raw_bytes}``)
    * ipywidgets >= 8: a tuple/list of dicts with ``'name'`` and ``'content'`` keys

    Returns:
        ``(None, None)`` when nothing has been uploaded, otherwise the file name
        and its content as ``bytes``.

    Raises:
        ValueError: if the upload entry carries no content.
    """
    value = upload_widget.value
    if not value:
        return None, None

    if isinstance(value, dict):
        # ipywidgets <8: dict[name] -> {'content': ...}
        name, payload = next(iter(value.items()))
        content = payload.get('content') if isinstance(payload, dict) else payload
    else:
        # ipywidgets >=8: tuple/list of dicts
        item = value[0]
        name = item.get('name', 'uploaded.bin')
        content = item.get('content')

    # Fail with a readable message instead of the opaque error bytes(None) gives.
    if content is None:
        raise ValueError(f'Uploaded file {name!r} has no content')

    if isinstance(content, memoryview):
        content = content.tobytes()
    return name, bytes(content)


def _safe_torch_load(source):
    """Load ``source`` through the project's ``safe_torch_load`` with ``map_location='cpu'``."""
    cpu_payload = safe_torch_load(source, map_location='cpu')
    return cpu_payload


def _is_state_dict(candidate):
    """Return the project's ``is_tensor_state_dict`` verdict for ``candidate``."""
    verdict = is_tensor_state_dict(candidate)
    return verdict


def _load_checkpoint_payload(checkpoint_path_text, checkpoint_upload_widget):
    """Load a checkpoint from the upload widget or, failing that, from a path.

    An uploaded file takes precedence over the typed path.

    Args:
        checkpoint_path_text: path typed by the user; ``~`` is expanded.
        checkpoint_upload_widget: upload widget that may hold checkpoint bytes.

    Returns:
        ``(payload, source_name)`` where ``source_name`` is the uploaded file
        name or the resolved path as a string.

    Raises:
        FileNotFoundError: if nothing was uploaded and the path is empty or
            does not point to a regular file.
    """
    upload_name, upload_bytes = _extract_uploaded_file(checkpoint_upload_widget)
    if upload_bytes is not None:
        payload = _safe_torch_load(io.BytesIO(upload_bytes))
        return payload, upload_name

    # Path('') is '.', which exists, so an empty field must be rejected explicitly.
    path_text = (checkpoint_path_text or '').strip()
    if not path_text:
        raise FileNotFoundError('Checkpoint not found: no path given and no file uploaded')

    checkpoint_path = Path(path_text).expanduser()
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f'Checkpoint not found: {checkpoint_path}')
    payload = _safe_torch_load(checkpoint_path)
    return payload, str(checkpoint_path)


def _extract_state_dict(payload):
    """Pull the model state dict out of ``payload`` and label where it came from.

    Delegates to ``extract_model_state_dict`` and maps the format tag it returns
    to a display label; unknown tags are passed through unchanged.

    Returns:
        ``(state_dict, format_label)``.
    """
    display_labels = dict(
        model_state_dict='structured_checkpoint',
        state_dict='state_dict_field',
        model='train_checkpoint_model',
        teacher='teacher_field',
        ema_model='ema_model_field',
        ema='ema_field',
        raw_state_dict='raw_state_dict',
    )
    state_dict, source_tag = extract_model_state_dict(payload)
    label = display_labels.get(source_tag, source_tag)
    return state_dict, label


def _infer_model_family(state_dict, model_hint='auto'):
    """Decide whether a checkpoint is a ``'teacher'`` or a ``'teacher_v3'`` model.

    An explicit ``model_hint`` of either family is returned as-is. Any other
    hint triggers detection: the checkpoint is ``'teacher_v3'`` when at least
    one key starts with a V3-specific prefix, and ``'teacher'`` otherwise.
    """
    if model_hint in {'teacher', 'teacher_v3'}:
        return model_hint

    v3_prefixes = ('backbone.', 'neck.', 'mask_head.', 'quality_head.', 'rpn_objectness')
    for key in state_dict.keys():
        # str.startswith accepts a tuple and matches any of its entries.
        if key.startswith(v3_prefixes):
            return 'teacher_v3'
    return 'teacher'


def _infer_num_classes(state_dict, family):
    """Read the class count from the classifier weight stored in the checkpoint.

    The first dimension of the family-specific classification weight is used.
    Falls back to 3 when that entry is missing, is not a tensor, or is 0-dim.
    """
    weight_key = (
        'det_head.cls_pred.weight'
        if family == 'teacher'
        else 'det_head.stages.0.cls_head.4.weight'
    )
    weight = state_dict.get(weight_key)
    if not isinstance(weight, torch.Tensor) or weight.ndim < 1:
        return 3
    return int(weight.shape[0])


def _infer_teacher_backbone_type(state_dict):
    """Return ``'timm'`` if any key starts with ``pv_module.backbone.blocks.``, else ``'pv'``."""
    timm_prefix = 'pv_module.backbone.blocks.'
    for name in state_dict:
        if name.startswith(timm_prefix):
            return 'timm'
    return 'pv'


def _build_teacher_model_for_state_dict(state_dict, num_classes):
    """Build a ``TeacherModel`` whose backbone matches the checkpoint layout.

    The backbone is a ``TimmBackboneAdapter`` (``efficientnet_b0``, built with
    ``pretrained=False``) when ``_infer_teacher_backbone_type`` reports
    ``'timm'``, and a small ``PVModule`` with 16/24/32 pyramid channels
    otherwise. The FPN and detection head are 16 channels wide in both cases.

    Args:
        state_dict: checkpoint weights, inspected only for key names.
        num_classes: class count for the detection head and the model.

    Returns:
        A freshly constructed ``TeacherModel``; no weights are loaded here.
    """
    cfg = ApexXConfig()

    if _infer_teacher_backbone_type(state_dict) == 'timm':
        pv_module = TimmBackboneAdapter(
            model_name='efficientnet_b0',
            pretrained=False,
            out_indices=(2, 3, 4),
        )
        # The adapter reports its own pyramid widths.
        p3_ch, p4_ch, p5_ch = (
            pv_module.p3_channels,
            pv_module.p4_channels,
            pv_module.p5_channels,
        )
    else:
        p3_ch, p4_ch, p5_ch = 16, 24, 32
        pv_module = PVModule(
            in_channels=3,
            p3_channels=16,
            p4_channels=24,
            p5_channels=32,
            coarse_level='P4',
        )

    # In both branches the FF path width equals the P3 width.
    fpn = DualPathFPN(
        pv_p3_channels=p3_ch,
        pv_p4_channels=p4_ch,
        pv_p5_channels=p5_ch,
        ff_channels=p3_ch,
        out_channels=16,
    )
    det_head = DetHead(in_channels=16, num_classes=num_classes, hidden_channels=16, depth=1)

    return TeacherModel(
        num_classes=num_classes,
        config=cfg,
        pv_module=pv_module,
        fpn=fpn,
        det_head=det_head,
        feature_layers=('P3', 'P4'),
        use_ema=True,
        ema_decay=0.99,
        use_ema_for_forward=False,
    )


def _load_state_dict_non_strict(model, state_dict, strict=False):
    """Load ``state_dict`` into ``model``, tolerating shape mismatches unless ``strict``.

    In non-strict mode, entries whose shape differs from the model's parameter
    or buffer of the same name are withheld and reported separately. Entries the
    model does not know are still handed to ``load_state_dict(strict=False)``,
    so they appear in ``unexpected_keys`` instead of silently disappearing.

    Args:
        model: module to load into.
        state_dict: mapping of names to tensors.
        strict: when true, defer entirely to ``load_state_dict(strict=True)``.

    Returns:
        ``(incompatible_keys, shape_skipped)``: the result of ``load_state_dict``
        and the list of keys withheld for a shape mismatch (always empty in
        strict mode). Withheld keys also show up in ``missing_keys``.
    """
    if strict:
        incompatible = model.load_state_dict(state_dict, strict=True)
        return incompatible, []

    model_state = model.state_dict()
    loadable = {}
    shape_skipped = []
    for key, value in state_dict.items():
        target = model_state.get(key)
        if target is not None and tuple(target.shape) != tuple(value.shape):
            shape_skipped.append(key)
            continue
        # Keep unknown keys: load_state_dict(strict=False) ignores them but
        # lists them, which keeps the reported "unexpected" count honest.
        loadable[key] = value

    incompatible = model.load_state_dict(loadable, strict=False)
    return incompatible, shape_skipped


def _load_image(image_path_text, image_upload_widget):
    """Load an RGB image from the upload widget or, failing that, from a path.

    An uploaded file takes precedence over the typed path.

    Args:
        image_path_text: path typed by the user; ``~`` is expanded.
        image_upload_widget: upload widget that may hold image bytes.

    Returns:
        ``(image, source_name)`` where ``image`` is a PIL image in RGB mode and
        ``source_name`` is the uploaded file name or the resolved path.

    Raises:
        FileNotFoundError: if nothing was uploaded and the path is empty or
            does not point to a regular file.
    """
    upload_name, upload_bytes = _extract_uploaded_file(image_upload_widget)
    if upload_bytes is not None:
        # convert() returns an independent image, so the source can be closed.
        with Image.open(io.BytesIO(upload_bytes)) as source:
            image = source.convert('RGB')
        return image, upload_name

    # Path('') is '.', which exists, so an empty field must be rejected explicitly.
    path_text = (image_path_text or '').strip()
    if not path_text:
        raise FileNotFoundError('Image not found: no path given and no file uploaded')

    image_path = Path(path_text).expanduser()
    if not image_path.is_file():
        raise FileNotFoundError(f'Image not found: {image_path}')
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    return image, str(image_path)


def _prepare_image_tensor(
    pil_image,
    target_size=1024,
    keep_aspect=False,
    align_to=32,
):
    """Resize a PIL image and convert it to a ``[1, C, H, W]`` float tensor in ``[0, 1]``.

    ``target_size`` is clamped to at least 64. Without ``keep_aspect`` the image
    is resized to a ``target_size`` square. With it, the longer side is scaled
    to ``target_size`` and both sides are rounded to the nearest multiple of
    ``align_to`` (never below one multiple).

    Returns:
        ``(image_np, tensor, (new_h, new_w))``: the resized image as a NumPy
        array, the batched tensor moved to the module-level ``device``, and the
        resized height/width.
    """
    src_w, src_h = pil_image.size
    target_size = int(max(64, target_size))

    if not keep_aspect:
        new_h = new_w = target_size
    else:
        scale = float(target_size) / float(max(src_h, src_w))
        align_to = max(1, int(align_to))

        def _snap(length):
            # Nearest multiple of align_to, but at least one full multiple.
            return max(align_to, int(round((length * scale) / float(align_to)) * align_to))

        new_h = _snap(src_h)
        new_w = _snap(src_w)

    image_np = np.array(pil_image.resize((new_w, new_h)))
    channels_first = torch.from_numpy(image_np).permute(2, 0, 1)
    batched = channels_first.float().unsqueeze(0) / 255.0
    return image_np, batched.to(device), (new_h, new_w)


def _clip_boxes_to_image(boxes, image_h, image_w):
    """Return a copy of ``boxes`` clamped to the image bounds.

    Even columns are limited to ``[0, image_w - 1]`` and odd columns to
    ``[0, image_h - 1]``; the input tensor is left untouched.
    """
    clipped = boxes.clone()
    x_limit = max(0, image_w - 1)
    y_limit = max(0, image_h - 1)
    # The strided slices are views of the clone, so in-place clamps update it.
    clipped[:, 0::2].clamp_(0, x_limit)
    clipped[:, 1::2].clamp_(0, y_limit)
    return clipped


def _postprocess_teacher_v3_outputs(outputs, conf_threshold, nms_iou, max_dets, image_hw):
    """Turn raw ``TeacherModelV3`` outputs into thresholded, NMS-filtered detections.

    Steps: validate tensors, align row counts, map scores to probabilities if
    they look like logits, take the best class per box, clip boxes to the image,
    apply the confidence threshold, run class-wise NMS, keep the top
    ``max_dets`` and upsample the matching masks to the image size.

    Args:
        outputs: mapping with ``'boxes'`` (``[N, 4]``), ``'scores'`` (``[N, C]``)
            and optionally ``'masks'``.
        conf_threshold: minimum per-box score to keep.
        nms_iou: IoU threshold for ``torchvision.ops.batched_nms``.
        max_dets: maximum number of detections returned.
        image_hw: ``(height, width)`` used for clipping and mask upsampling.

    Returns:
        ``(boxes, scores, classes, masks_up)``; ``masks_up`` is ``None`` when
        there are no detections or no usable masks.

    Raises:
        ValueError: if the boxes or scores tensor is missing or malformed.
    """
    image_h, image_w = int(image_hw[0]), int(image_hw[1])

    boxes = outputs.get('boxes')
    score_matrix = outputs.get('scores')
    if not isinstance(boxes, torch.Tensor) or boxes.ndim != 2 or boxes.shape[1] != 4:
        raise ValueError('TeacherModelV3 output has invalid boxes tensor')
    if not isinstance(score_matrix, torch.Tensor) or score_matrix.ndim != 2:
        raise ValueError('TeacherModelV3 output has invalid scores tensor')

    # Tolerate a row-count mismatch by keeping the common prefix.
    num_candidates = min(int(score_matrix.shape[0]), int(boxes.shape[0]))
    boxes = boxes[:num_candidates]
    score_matrix = score_matrix[:num_candidates]

    # Nothing to score: min()/max() raise on empty tensors, so return early.
    if score_matrix.numel() == 0:
        empty_scores = score_matrix.new_zeros((0,))
        empty_classes = torch.zeros((0,), dtype=torch.int64, device=boxes.device)
        return boxes[:0], empty_scores, empty_classes, None

    # If scores look like logits, map to probabilities.
    if float(score_matrix.min().item()) < 0.0 or float(score_matrix.max().item()) > 1.0:
        score_matrix = torch.sigmoid(score_matrix)

    scores, classes = score_matrix.max(dim=1)
    boxes = _clip_boxes_to_image(boxes, image_h=image_h, image_w=image_w)

    keep = scores >= float(conf_threshold)
    boxes = boxes[keep]
    scores = scores[keep]
    classes = classes[keep]

    if boxes.numel() == 0:
        return boxes, scores, classes, None

    keep_idx = torchvision.ops.batched_nms(
        boxes,
        scores,
        classes,
        iou_threshold=float(nms_iou),
    )
    keep_idx = keep_idx[: int(max_dets)]
    boxes = boxes[keep_idx]
    scores = scores[keep_idx]
    classes = classes[keep_idx]

    masks = outputs.get('masks')
    masks_up = None
    has_masks = (
        isinstance(masks, torch.Tensor)
        and masks.numel() > 0
        and keep_idx.numel() > 0
        # Masks must cover every candidate row, or the boolean index below fails.
        and masks.shape[0] >= num_candidates
    )
    if has_masks:
        # Apply the same prefix truncation as boxes/scores before filtering.
        masks_sel = masks[:num_candidates][keep][keep_idx]
        if masks_sel.ndim == 4 and masks_sel.shape[1] == 1:
            masks_sel = masks_sel[:, 0]
        if masks_sel.ndim == 3:
            masks_up = F.interpolate(
                masks_sel.unsqueeze(1),
                size=(image_h, image_w),
                mode='bilinear',
                align_corners=False,
            ).squeeze(1)

    return boxes, scores, classes, masks_up


def _draw_predictions(image_np, boxes, scores, classes, masks=None, max_dets=50, mask_threshold=0.5):
    """Plot the image with boxes, ``class:score`` labels and optional mask overlays.

    At most ``max_dets`` detections are drawn, in the order given. Colors come
    from a fixed-seed generator, so repeated calls color detections alike. A
    mask is thresholded at ``mask_threshold`` and drawn semi-transparently; it
    is skipped when its height/width differ from the image.
    """
    def _as_cpu_tensor(value):
        if isinstance(value, torch.Tensor):
            return value.detach().cpu()
        return torch.as_tensor(value)

    boxes = _as_cpu_tensor(boxes)
    scores = _as_cpu_tensor(scores)
    classes = _as_cpu_tensor(classes)

    fig, ax = plt.subplots(1, 1, figsize=(12, 12))
    ax.imshow(image_np)

    palette = np.random.RandomState(42)
    n = min(int(boxes.shape[0]), int(max_dets))
    for det in range(n):
        left, top, right, bottom = (float(v) for v in boxes[det].tolist())
        confidence = float(scores[det].item())
        label = int(classes[det].item())

        # One draw per detection keeps the color sequence reproducible.
        color = tuple((palette.rand(3) * 0.8 + 0.2).tolist())
        outline = plt.Rectangle(
            (left, top),
            max(1.0, right - left),
            max(1.0, bottom - top),
            fill=False,
            color=color,
            linewidth=1.5,
        )
        ax.add_patch(outline)
        ax.text(
            left,
            max(0.0, top - 2.0),
            f'{label}:{confidence:.3f}',
            color='white',
            fontsize=9,
            bbox={'facecolor': 'black', 'alpha': 0.6, 'pad': 1},
        )

        if masks is None or det >= masks.shape[0]:
            continue
        mask = masks[det]
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        binary = (mask > float(mask_threshold)).astype(np.float32)
        if binary.shape[:2] != image_np.shape[:2]:
            continue
        # RGBA overlay: solid detection color, alpha only where the mask is on.
        overlay = np.zeros((binary.shape[0], binary.shape[1], 4), dtype=np.float32)
        overlay[..., :3] = color
        overlay[..., 3] = 0.25 * binary
        ax.imshow(overlay)

    ax.set_axis_off()
    ax.set_title(f'Detections: {n}')
    plt.show()





In [None]:
import ipywidgets as widgets
from IPython.display import display

# Checkpoint source: a typed path or an uploaded .pt file (the upload wins when both are set).
checkpoint_path = widgets.Text(
    value='outputs/a100_v3_1024px/best_1024.pt',
    description='CKPT path:',
    layout=widgets.Layout(width='900px')
)
checkpoint_upload = widgets.FileUpload(accept='.pt', multiple=False, description='Upload CKPT')

# Image source: same two options as the checkpoint; the path is empty by default.
image_path = widgets.Text(
    value='',
    placeholder='optional local image path',
    description='Image path:',
    layout=widgets.Layout(width='900px')
)
image_upload = widgets.FileUpload(accept='image/*', multiple=False, description='Upload image')

# Model selection and loading options. 'auto' detects the family from checkpoint keys.
model_hint = widgets.Dropdown(options=['auto', 'teacher', 'teacher_v3'], value='auto', description='Model:')
strict_load = widgets.Checkbox(value=False, description='Strict load')
use_ckpt_enhancer = widgets.Checkbox(value=True, description='Use enhancer')
keep_aspect = widgets.Checkbox(value=False, description='Keep aspect')
# 'auto' resolves to 14 for teacher_v3 and 32 otherwise inside _run_inference.
align_mode = widgets.Dropdown(
    options=[('auto', 'auto'), ('14', '14'), ('16', '16'), ('28', '28'), ('32', '32')],
    value='auto',
    description='Align:',
)
# Resize target and post-processing thresholds.
inference_size = widgets.IntSlider(value=1024, min=256, max=2048, step=32, description='Infer size:')
conf_threshold = widgets.FloatSlider(value=0.25, min=0.0, max=1.0, step=0.01, description='Conf:')
nms_iou = widgets.FloatSlider(value=0.5, min=0.1, max=0.9, step=0.01, description='NMS IoU:')
mask_threshold = widgets.FloatSlider(value=0.5, min=0.05, max=0.95, step=0.01, description='Mask thr:')
max_dets = widgets.IntSlider(value=100, min=1, max=500, step=1, description='Max dets:')
run_button = widgets.Button(description='Run Inference', button_style='success')
# All status text, figures and error messages are rendered into this output area.
out = widgets.Output()


def _run_inference(_):
    """Button callback: load checkpoint and image, run the model, and plot detections.

    All inputs are read from the module-level widgets. Status lines, the figure
    and any error message are written into the ``out`` widget. The single
    positional argument supplied by ``run_button.on_click`` is ignored.

    Any exception is caught and printed as ``Error: ...`` instead of propagating.
    """
    with out:
        # Replace the previous run's output only once new output arrives.
        out.clear_output(wait=True)
        try:
            ckpt_payload, ckpt_name = _load_checkpoint_payload(checkpoint_path.value.strip(), checkpoint_upload)
            state_dict, ckpt_format = _extract_state_dict(ckpt_payload)

            family = _infer_model_family(state_dict, model_hint.value)
            num_classes = _infer_num_classes(state_dict, family)

            # The model is rebuilt from scratch on every click.
            if family == 'teacher_v3':
                model = TeacherModelV3(num_classes=num_classes)
            else:
                model = _build_teacher_model_for_state_dict(state_dict, num_classes=num_classes)

            incompatible, skipped = _load_state_dict_non_strict(model, state_dict, strict=bool(strict_load.value))
            model = model.to(device).eval()

            # Optional image enhancer, used only when the checkpoint dict carries
            # an 'enhancer' entry that passes the state-dict check.
            enhancer = None
            enhancer_status = 'not used'
            if bool(use_ckpt_enhancer.value) and isinstance(ckpt_payload, dict):
                enhancer_state = ckpt_payload.get('enhancer')
                if _is_state_dict(enhancer_state):
                    enhancer = LearnableImageEnhancer()
                    enh_incompat, enh_skipped = _load_state_dict_non_strict(enhancer, enhancer_state, strict=False)
                    enhancer = enhancer.to(device).eval()
                    enhancer_status = (
                        f'loaded (missing={len(getattr(enh_incompat, "missing_keys", []))}, '
                        f'unexpected={len(getattr(enh_incompat, "unexpected_keys", []))}, '
                        f'shape_skipped={len(enh_skipped)})'
                    )
                else:
                    enhancer_status = 'checkpoint has no enhancer state'

            img_pil, image_name = _load_image(image_path.value.strip(), image_upload)
            # Default alignment depends on the family; an explicit dropdown choice overrides it.
            align_to = 14 if family == 'teacher_v3' else 32
            if align_mode.value != 'auto':
                align_to = int(align_mode.value)
            image_np, image_tensor, image_size = _prepare_image_tensor(
                img_pil,
                target_size=int(inference_size.value),
                keep_aspect=bool(keep_aspect.value),
                align_to=align_to,
            )

            # Summary of what was loaded; image_size is (height, width), printed as WxH.
            print(f'Checkpoint: {ckpt_name} ({ckpt_format})')
            print(f'Model family: {family} | classes: {num_classes}')
            print(f'Image: {image_name} | resized: {image_size[1]}x{image_size[0]} | align: {align_to} | keep_aspect: {bool(keep_aspect.value)}')
            print(f"Missing keys: {len(getattr(incompatible, 'missing_keys', []))} | Unexpected keys: {len(getattr(incompatible, 'unexpected_keys', []))} | Shape-skipped: {len(skipped)}")
            print(f'Enhancer: {enhancer_status}')

            def _forward_once(model_input, image_hw):
                """Run one forward pass and return ``(boxes, scores, classes, masks_up)``."""
                if family == 'teacher_v3':
                    outputs = model(model_input)
                    return _postprocess_teacher_v3_outputs(
                        outputs,
                        conf_threshold=float(conf_threshold.value),
                        nms_iou=float(nms_iou.value),
                        max_dets=int(max_dets.value),
                        image_hw=image_hw,
                    )

                # Teacher path: per-level outputs go through the project's
                # per-class post-processing; only the first batch item is used.
                outputs = model(model_input, use_ema=False)
                dets = post_process_detections_per_class(
                    outputs.logits_by_level,
                    outputs.boxes_by_level,
                    outputs.quality_by_level,
                    conf_threshold=float(conf_threshold.value),
                    nms_threshold=float(nms_iou.value),
                    max_detections=int(max_dets.value),
                )[0]
                boxes = dets['boxes']
                scores = dets['scores']
                classes = dets['classes']

                masks_up = None
                if isinstance(outputs.masks, torch.Tensor) and outputs.masks.numel() > 0:
                    masks = outputs.masks[0]
                    if masks.ndim == 4 and masks.shape[1] == 1:
                        masks = masks[:, 0]
                    if masks.ndim == 3:
                        masks_up = F.interpolate(
                            masks.unsqueeze(1),
                            size=(image_hw[0], image_hw[1]),
                            mode='bilinear',
                            align_corners=False,
                        ).squeeze(1)
                        # NOTE(review): masks are cut to the detection count by position;
                        # presumably mask i belongs to detection i — confirm against the model.
                        if masks_up.shape[0] > boxes.shape[0]:
                            masks_up = masks_up[: boxes.shape[0]]

                return boxes, scores, classes, masks_up

            with torch.no_grad():
                model_input = image_tensor
                if enhancer is not None:
                    model_input = enhancer(model_input)

                try:
                    boxes, scores, classes, masks_up = _forward_once(
                        model_input,
                        image_hw=(image_np.shape[0], image_np.shape[1]),
                    )
                except ValueError as exc:
                    # Only a teacher_v3 ValueError mentioning 'Cannot reshape' is retried,
                    # using a square resize; every other error is re-raised.
                    message = str(exc)
                    can_retry = family == 'teacher_v3' and 'Cannot reshape' in message
                    if not can_retry:
                        raise
                    print(f'Warning: {message}')
                    print('Retrying with square resize to stabilize token grid...')
                    image_np, image_tensor, image_size = _prepare_image_tensor(
                        img_pil,
                        target_size=int(inference_size.value),
                        keep_aspect=False,
                        align_to=14,
                    )
                    model_input = image_tensor
                    if enhancer is not None:
                        model_input = enhancer(model_input)
                    boxes, scores, classes, masks_up = _forward_once(
                        model_input,
                        image_hw=(image_np.shape[0], image_np.shape[1]),
                    )
                    print(
                        f'Retry image resized to: {image_size[1]}x{image_size[0]} '
                        '| align: 14 | keep_aspect: false'
                    )

            # Draw on the resized image, which is the frame the boxes were clipped to.
            _draw_predictions(
                image_np,
                boxes,
                scores,
                classes,
                masks=masks_up,
                max_dets=int(max_dets.value),
                mask_threshold=float(mask_threshold.value),
            )

        except Exception as exc:
            # UI boundary: show the message in the output area rather than raising.
            print(f'Error: {exc}')


# Wire the button to the inference callback.
run_button.on_click(_run_inference)

# Stack the controls vertically: sources, option row, sliders, run button, output area.
display(widgets.VBox([
    widgets.HTML('<b>Checkpoint</b>'),
    checkpoint_path,
    checkpoint_upload,
    widgets.HTML('<b>Image</b>'),
    image_path,
    image_upload,
    widgets.HBox([model_hint, strict_load, use_ckpt_enhancer, keep_aspect, align_mode]),
    inference_size,
    conf_threshold,
    nms_iou,
    mask_threshold,
    max_dets,
    run_button,
    out,
]))
