# Checkpoint and Image Inference UI

This notebook lets you choose a model checkpoint (`.pt`) and an image, run inference, and visualize detections/masks.

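You can provide each file (checkpoint and image) in one of two ways:
- type a local path into the text box (the checkpoint path defaults to `artifacts/train_output/checkpoints/best.pt`), or
- use the upload widget; an uploaded file takes precedence over the typed path.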


In [None]:
import io
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
from PIL import Image

from apex_x.config import ApexXConfig
from apex_x.model import (
    TeacherModel,
    TeacherModelV3,
    PVModule,
    DualPathFPN,
    DetHead,
    TimmBackboneAdapter,
    post_process_detections_per_class,
)

# Run on the first CUDA device when one is visible to PyTorch, otherwise on CPU.
_DEVICE_NAME = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_DEVICE_NAME)
print(f'Using device: {device}')


def _extract_uploaded_file(upload_widget):
    """Return ``(name, content_bytes)`` for the first file in an upload widget.

    Two ``FileUpload.value`` layouts are supported:

    * ipywidgets < 8: ``{name: {'content': ...}}`` (or ``{name: content}``)
    * ipywidgets >= 8: a tuple/list of dicts with ``'name'`` / ``'content'`` keys

    Returns ``(None, None)`` when nothing is uploaded, or when the upload entry
    carries no content, so callers can fall back to a local path.
    """
    value = upload_widget.value
    if not value:
        return None, None

    if isinstance(value, dict):
        # ipywidgets <8: dict[name] -> {'content': ...}
        name = next(iter(value))
        payload = value[name]
        content = payload.get('content') if isinstance(payload, dict) else payload
    else:
        # ipywidgets >=8: tuple/list of dicts
        item = value[0]
        name = item.get('name', 'uploaded.bin')
        content = item.get('content')

    if content is None:
        # Treat a content-less entry as "no upload" instead of crashing in bytes(None).
        return None, None
    # bytes() copies bytes, bytearray and memoryview contents alike.
    return name, bytes(content)


def _safe_torch_load(source):
    """Deserialize ``source`` (path or file-like) onto the CPU.

    ``torch.load(..., weights_only=True)`` is tried first. PyTorch versions that
    do not accept ``weights_only`` raise ``TypeError``, and the load is then
    retried without it.

    NOTE(review): the fallback performs unrestricted unpickling, so on old
    PyTorch versions only load checkpoints from trusted sources.
    """
    common_kwargs = {'map_location': 'cpu'}
    try:
        return torch.load(source, weights_only=True, **common_kwargs)
    except TypeError:
        # Backward-compatible fallback for older PyTorch without weights_only.
        return torch.load(source, **common_kwargs)


def _is_state_dict(candidate):
    """Return ``True`` if ``candidate`` looks like a flat PyTorch state dict.

    A state dict here means a non-empty ``dict`` whose keys are all ``str`` and
    whose values are all ``torch.Tensor``. Anything else (including an empty
    dict or a training checkpoint mixing tensors with metadata) is ``False``.
    Always returns a real ``bool``.
    """
    if not isinstance(candidate, dict) or not candidate:
        return False
    return all(
        isinstance(key, str) and isinstance(value, torch.Tensor)
        for key, value in candidate.items()
    )


def _load_checkpoint_payload(checkpoint_path_text, checkpoint_upload_widget):
    """Load the raw checkpoint object from an upload or from a local path.

    An uploaded file, if present, takes precedence over the path text.

    Returns:
        ``(payload, source_name)``: the object returned by ``_safe_torch_load``
        and the upload's filename or the resolved path as a string.

    Raises:
        FileNotFoundError: nothing was uploaded and the path is empty, does
            not exist, or is not a regular file.
    """
    upload_name, upload_bytes = _extract_uploaded_file(checkpoint_upload_widget)
    if upload_bytes is not None:
        payload = _safe_torch_load(io.BytesIO(upload_bytes))
        return payload, upload_name

    path_text = (checkpoint_path_text or '').strip()
    if not path_text:
        # Path('') resolves to '.', which "exists" and would fail later as a directory.
        raise FileNotFoundError('No checkpoint uploaded and no checkpoint path provided.')
    checkpoint_path = Path(path_text).expanduser()
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f'Checkpoint not found: {checkpoint_path}')
    payload = _safe_torch_load(checkpoint_path)
    return payload, str(checkpoint_path)


def _extract_state_dict(payload):
    """Locate the model weights inside a loaded checkpoint payload.

    Returns ``(state_dict, format_label)``. A payload that is itself a state
    dict is labelled ``'raw_state_dict'``. Otherwise the well-known container
    fields are probed in a fixed priority order, and the first one holding a
    state dict wins.

    Raises:
        ValueError: no state dict could be found.
    """
    if _is_state_dict(payload):
        return payload, 'raw_state_dict'

    if isinstance(payload, dict):
        # Insertion order is the lookup priority.
        field_to_label = {
            'model_state_dict': 'structured_checkpoint',
            'state_dict': 'state_dict_field',
            'model': 'train_checkpoint_model',
            'teacher': 'teacher_field',
            'ema_model': 'ema_model_field',
            'ema': 'ema_field',
        }
        for field_name, format_label in field_to_label.items():
            nested = payload.get(field_name)
            if _is_state_dict(nested):
                return nested, format_label

    raise ValueError(
        'Unsupported checkpoint format. Expected raw state_dict or dict with one of: '
        'model_state_dict, state_dict, model, teacher, ema_model, ema.'
    )


def _infer_model_family(state_dict, model_hint='auto'):
    """Decide which model family a state dict belongs to.

    An explicit hint of ``'teacher'`` or ``'teacher_v3'`` is returned unchanged.
    Any other hint (e.g. ``'auto'``) triggers key-based detection: the result is
    ``'teacher_v3'`` if any key starts with one of the V3 module prefixes below,
    otherwise ``'teacher'``.
    """
    if model_hint in {'teacher', 'teacher_v3'}:
        return model_hint

    v3_prefixes = ('backbone.', 'neck.', 'mask_head.', 'quality_head.', 'rpn_objectness')
    for key in state_dict.keys():
        if key.startswith(v3_prefixes):
            return 'teacher_v3'
    return 'teacher'


def _infer_num_classes(state_dict, family, *, default=3):
    """Infer the number of classes from the classifier weight in ``state_dict``.

    The leading dimension of the family's classifier weight is used:
    ``det_head.cls_pred.weight`` for ``'teacher'`` and
    ``det_head.stages.0.cls_head.4.weight`` for any other family.

    Args:
        state_dict: mapping of parameter names to tensors.
        family: model family string, as returned by ``_infer_model_family``.
        default: value returned when the weight is absent or is a 0-d tensor.

    Returns:
        The inferred class count, or ``default``.
    """
    if family == 'teacher':
        weight_key = 'det_head.cls_pred.weight'
    else:
        weight_key = 'det_head.stages.0.cls_head.4.weight'
    weight = state_dict.get(weight_key)
    if isinstance(weight, torch.Tensor) and weight.ndim >= 1:
        return int(weight.shape[0])
    return default


def _infer_teacher_backbone_type(state_dict):
    """Return ``'timm'`` or ``'pv'`` for a teacher-family state dict.

    A checkpoint with any key under ``pv_module.backbone.blocks.`` is treated as
    using the timm adapter backbone; everything else is assumed to use the
    built-in ``PVModule``.
    """
    for key in state_dict:
        if key.startswith('pv_module.backbone.blocks.'):
            return 'timm'
    return 'pv'


def _build_teacher_model_for_state_dict(state_dict, num_classes):
    """Instantiate an untrained ``TeacherModel`` whose backbone matches ``state_dict``.

    The backbone is chosen by ``_infer_teacher_backbone_type``:

    * ``'timm'``: ``TimmBackboneAdapter('efficientnet_b0', pretrained=False,
      out_indices=(2, 3, 4))``, with pyramid widths read from the adapter.
    * ``'pv'``: a ``PVModule`` with P3/P4/P5 widths 16/24/32.

    The FPN output and detection head are fixed at 16 channels. All widths
    are hard-coded, so a checkpoint trained with different widths will not
    match this architecture.
    """
    config = ApexXConfig()

    if _infer_teacher_backbone_type(state_dict) == 'timm':
        backbone = TimmBackboneAdapter(
            model_name='efficientnet_b0',
            pretrained=False,
            out_indices=(2, 3, 4),
        )
        p3, p4, p5 = backbone.p3_channels, backbone.p4_channels, backbone.p5_channels
    else:
        p3, p4, p5 = 16, 24, 32
        backbone = PVModule(
            in_channels=3,
            p3_channels=p3,
            p4_channels=p4,
            p5_channels=p5,
            coarse_level='P4',
        )

    # Both variants feed the FPN's ff path with the P3 width
    # (for the PV backbone that is 16).
    neck = DualPathFPN(
        pv_p3_channels=p3,
        pv_p4_channels=p4,
        pv_p5_channels=p5,
        ff_channels=p3,
        out_channels=16,
    )
    head = DetHead(in_channels=16, num_classes=num_classes, hidden_channels=16, depth=1)

    return TeacherModel(
        num_classes=num_classes,
        config=config,
        pv_module=backbone,
        fpn=neck,
        det_head=head,
        feature_layers=('P3', 'P4'),
        use_ema=True,
        ema_decay=0.99,
        use_ema_for_forward=False,
    )


def _load_state_dict_non_strict(model, state_dict, strict=False):
    """Load ``state_dict`` into ``model``, optionally tolerating mismatches.

    With ``strict=True`` this delegates to ``model.load_state_dict(strict=True)``,
    which raises on any missing, unexpected or mis-shaped key.

    With ``strict=False``, entries whose tensor shape differs from the model's
    are dropped and listed in ``skipped``. Because they are not loaded, they
    also appear in ``missing_keys``. Every other entry, including keys the
    model does not have, is passed to ``load_state_dict(strict=False)``.
    Unknown keys are therefore reported in ``unexpected_keys`` instead of
    silently disappearing.

    Returns:
        ``(incompatible_keys, skipped)``: the result of ``load_state_dict``
        and the list of shape-mismatched keys (empty in strict mode).
    """
    if strict:
        return model.load_state_dict(state_dict, strict=True), []

    model_state = model.state_dict()
    filtered = {}
    skipped = []
    for key, value in state_dict.items():
        expected = model_state.get(key)
        if isinstance(expected, torch.Tensor) and tuple(expected.shape) != tuple(value.shape):
            skipped.append(key)
            continue
        # Keys unknown to the model are kept on purpose:
        # load_state_dict(strict=False) ignores them but lists them as unexpected.
        filtered[key] = value

    return model.load_state_dict(filtered, strict=False), skipped

def _load_image(image_path_text, image_upload_widget):
    """Load an RGB image from an upload or from a local path.

    An uploaded file, if present, takes precedence over the path text. The
    source file is opened in a ``with`` block, so it is closed once the RGB
    copy has been produced.

    Returns:
        ``(rgb_image, source_name)``: a PIL image in ``'RGB'`` mode and the
        upload's filename or the resolved path as a string.

    Raises:
        FileNotFoundError: nothing was uploaded and the path is empty, does
            not exist, or is not a regular file.
    """
    upload_name, upload_bytes = _extract_uploaded_file(image_upload_widget)
    if upload_bytes is not None:
        with Image.open(io.BytesIO(upload_bytes)) as uploaded:
            return uploaded.convert('RGB'), upload_name

    path_text = (image_path_text or '').strip()
    if not path_text:
        # Path('') resolves to '.', which "exists" and would fail later as a directory.
        raise FileNotFoundError('No image uploaded and no image path provided.')
    image_path = Path(path_text).expanduser()
    if not image_path.is_file():
        raise FileNotFoundError(f'Image not found: {image_path}')
    with Image.open(image_path) as opened:
        return opened.convert('RGB'), str(image_path)


def _prepare_image_tensor(pil_image, max_side=1280, align_to=32, target_device=None):
    """Resize an image for inference and convert it to a batched float tensor.

    The image is only ever downscaled: ``scale = min(1, max_side / longest_side)``.
    Each scaled side is then rounded to the nearest multiple of ``align_to``
    (at least ``align_to``). The two sides are rounded independently, so the
    aspect ratio can shift slightly. A side can exceed ``max_side`` by up to
    ``align_to / 2`` when ``max_side`` is not a multiple of ``align_to``.

    Args:
        pil_image: image exposing ``.size == (w, h)`` and ``.resize((w, h))``.
        max_side: upper bound (before alignment) for the longer side.
        align_to: size granularity required by the model; clamped to >= 1.
        target_device: device for the returned tensor; ``None`` uses the
            notebook-level ``device``.

    Returns:
        ``(image_np, tensor, (new_h, new_w))``. ``image_np`` is the resized
        image as an ``(H, W, C)`` array. ``tensor`` is ``(1, C, H, W)``
        float32 scaled by 1/255.
    """
    if target_device is None:
        target_device = device

    w, h = pil_image.size
    scale = min(1.0, float(max_side) / float(max(h, w)))
    align_to = max(1, int(align_to))
    new_h = max(align_to, int(round((h * scale) / float(align_to)) * align_to))
    new_w = max(align_to, int(round((w * scale) / float(align_to)) * align_to))
    resized = pil_image.resize((new_w, new_h))
    image_np = np.array(resized)
    tensor = torch.from_numpy(image_np).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    return image_np, tensor.to(target_device), (new_h, new_w)


def _draw_predictions(image_np, boxes, scores, classes, masks=None, max_dets=50):
    """Show ``image_np`` with detection boxes, ``class:score`` labels and mask overlays.

    ``boxes`` rows are unpacked as ``(x1, y1, x2, y2)`` and drawn in the pixel
    coordinates of ``image_np``. At most ``max_dets`` detections are drawn,
    in input order. When ``masks`` is given, mask ``i`` is binarised at 0.5
    and overlaid in translucent green. Masks whose height and width differ
    from the image's are skipped.

    NOTE(review): masks are thresholded at 0.5 with no sigmoid; if the model
    emits mask logits this is not a probability threshold, so confirm.
    """
    def _as_cpu_tensor(values):
        if isinstance(values, torch.Tensor):
            return values.detach().cpu()
        return torch.as_tensor(values)

    boxes = _as_cpu_tensor(boxes)
    scores = _as_cpu_tensor(scores)
    classes = _as_cpu_tensor(classes)

    fig, ax = plt.subplots(1, 1, figsize=(12, 12))
    ax.imshow(image_np)

    image_hw = image_np.shape[:2]
    n = min(int(boxes.shape[0]), int(max_dets))
    for idx in range(n):
        x1, y1, x2, y2 = (float(coord) for coord in boxes[idx].tolist())
        score = float(scores[idx].item())
        cls_id = int(classes[idx].item())

        box_w = max(1.0, x2 - x1)
        box_h = max(1.0, y2 - y1)
        ax.add_patch(plt.Rectangle((x1, y1), box_w, box_h, fill=False, color='lime', linewidth=1.5))
        ax.text(
            x1,
            max(0.0, y1 - 2.0),
            f'{cls_id}:{score:.3f}',
            color='yellow',
            fontsize=9,
            bbox=dict(facecolor='black', alpha=0.5, pad=1),
        )

        if masks is None or idx >= masks.shape[0]:
            continue
        mask = masks[idx]
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        binary = (mask > 0.5).astype(np.float32)
        if binary.shape[:2] != image_hw:
            continue
        rgba = np.zeros((binary.shape[0], binary.shape[1], 4), dtype=np.float32)
        rgba[..., 1] = 1.0           # green channel
        rgba[..., 3] = 0.25 * binary  # alpha only where the mask is on
        ax.imshow(rgba)

    ax.set_axis_off()
    ax.set_title(f'Detections: {n}')
    plt.show()


In [None]:
import ipywidgets as widgets
from IPython.display import display

# --- UI controls -------------------------------------------------------------
# Checkpoint source: a local path (defaulting to the training run's best.pt)
# or an uploaded .pt file. If an upload is present it is used instead of the path.
checkpoint_path = widgets.Text(
    value='artifacts/train_output/checkpoints/best.pt',
    description='CKPT path:',
    layout=widgets.Layout(width='900px')
)
checkpoint_upload = widgets.FileUpload(accept='.pt', multiple=False, description='Upload CKPT')

# Image source: empty by default, so either fill in a path or upload an image.
image_path = widgets.Text(
    value='',
    placeholder='optional local image path',
    description='Image path:',
    layout=widgets.Layout(width='900px')
)
image_upload = widgets.FileUpload(accept='image/*', multiple=False, description='Upload image')

# 'auto' detects the model family from the checkpoint's state-dict keys;
# the other options force a family.
model_hint = widgets.Dropdown(options=['auto', 'teacher', 'teacher_v3'], value='auto', description='Model:')
# Checked: load_state_dict(strict=True). Unchecked: shape-mismatched keys are skipped.
strict_load = widgets.Checkbox(value=False, description='Strict load')
# Post-processing: minimum score, NMS IoU threshold, and detection cap.
conf_threshold = widgets.FloatSlider(value=0.25, min=0.0, max=1.0, step=0.01, description='Conf:')
nms_iou = widgets.FloatSlider(value=0.5, min=0.1, max=0.9, step=0.01, description='NMS IoU:')
max_dets = widgets.IntSlider(value=100, min=1, max=500, step=1, description='Max dets:')
run_button = widgets.Button(description='Run Inference', button_style='success')
# Receives the printed diagnostics and the detection plot.
out = widgets.Output()


def _upsample_masks(masks, out_hw):
    """Bilinearly resize a stack of masks to ``out_hw == (height, width)``.

    Accepts ``(N, 1, h, w)`` or ``(N, h, w)``; returns ``(N, height, width)``.
    Any other rank yields ``None`` (no overlay is drawn).
    """
    if masks.ndim == 4 and masks.shape[1] == 1:
        masks = masks[:, 0]
    if masks.ndim != 3:
        return None
    return F.interpolate(
        masks.unsqueeze(1),
        size=out_hw,
        mode='bilinear',
        align_corners=False,
    ).squeeze(1)


def _postprocess_teacher_v3(outputs, image_hw, conf, iou, limit):
    """Turn raw ``TeacherModelV3`` outputs into ``(boxes, scores, classes, masks)``.

    Assumes ``outputs`` is a dict for a single image, with ``'boxes'`` shaped
    ``(N, 4)`` and ``'scores'`` shaped ``(N, C)`` (TODO confirm against
    TeacherModelV3). Scores are passed through a sigmoid when any value lies
    outside [0, 1]. Each box gets its best class; boxes below ``conf`` are
    dropped, then NMS is applied and the result is capped at ``limit``.
    Optional ``'masks'`` are filtered the same way and resized to ``image_hw``.

    NOTE(review): this NMS is class-agnostic (``torchvision.ops.nms``), unlike
    the per-class post-processing used for the teacher family.
    """
    boxes = outputs['boxes']
    score_matrix = outputs['scores']
    if score_matrix.ndim == 2 and score_matrix.shape[1] > 0:
        # numel() guard: min()/max() without a dim raise on empty tensors.
        if score_matrix.numel() > 0 and (
            float(score_matrix.min().item()) < 0.0 or float(score_matrix.max().item()) > 1.0
        ):
            score_matrix = torch.sigmoid(score_matrix)
        scores, classes = score_matrix.max(dim=1)
    else:
        scores = torch.zeros((boxes.shape[0],), device=boxes.device)
        classes = torch.zeros((boxes.shape[0],), dtype=torch.int64, device=boxes.device)

    keep = scores >= conf
    boxes, scores, classes = boxes[keep], scores[keep], classes[keep]

    if boxes.numel() > 0:
        keep_idx = torchvision.ops.nms(boxes, scores, iou)[:limit]
        boxes, scores, classes = boxes[keep_idx], scores[keep_idx], classes[keep_idx]
    else:
        keep_idx = torch.zeros((0,), dtype=torch.int64, device=boxes.device)

    masks = outputs.get('masks')
    masks_up = None
    if isinstance(masks, torch.Tensor) and masks.numel() > 0 and keep_idx.numel() > 0:
        masks_up = _upsample_masks(masks[keep][keep_idx], image_hw)
    return boxes, scores, classes, masks_up


def _postprocess_teacher(outputs, image_hw, conf, iou, limit):
    """Turn raw ``TeacherModel`` outputs into ``(boxes, scores, classes, masks)``.

    Detections come from ``post_process_detections_per_class`` (first image of
    the batch). Masks from ``outputs.masks[0]`` are resized to ``image_hw`` and
    truncated to the number of detections.

    NOTE(review): masks are not re-indexed to the post-processed detections;
    verify that mask ``i`` really corresponds to detection ``i``.
    """
    dets = post_process_detections_per_class(
        outputs.logits_by_level,
        outputs.boxes_by_level,
        outputs.quality_by_level,
        conf_threshold=conf,
        nms_threshold=iou,
        max_detections=limit,
    )[0]
    boxes, scores, classes = dets['boxes'], dets['scores'], dets['classes']

    masks_up = None
    if isinstance(outputs.masks, torch.Tensor) and outputs.masks.numel() > 0:
        masks_up = _upsample_masks(outputs.masks[0], image_hw)
        if masks_up is not None and masks_up.shape[0] > boxes.shape[0]:
            masks_up = masks_up[: boxes.shape[0]]
    return boxes, scores, classes, masks_up


def _run_inference(_):
    """Button callback: load checkpoint and image, run the model, plot detections.

    All output goes to the ``out`` widget. Any exception is reported there
    (type, message and traceback) rather than propagating into the widget
    event machinery.
    """
    import traceback

    with out:
        out.clear_output(wait=True)
        try:
            ckpt_payload, ckpt_name = _load_checkpoint_payload(checkpoint_path.value.strip(), checkpoint_upload)
            state_dict, ckpt_format = _extract_state_dict(ckpt_payload)

            family = _infer_model_family(state_dict, model_hint.value)
            num_classes = _infer_num_classes(state_dict, family)

            if family == 'teacher_v3':
                model = TeacherModelV3(num_classes=num_classes)
            else:
                model = _build_teacher_model_for_state_dict(state_dict, num_classes=num_classes)

            incompatible, skipped = _load_state_dict_non_strict(model, state_dict, strict=bool(strict_load.value))
            model = model.to(device).eval()

            img_pil, image_name = _load_image(image_path.value.strip(), image_upload)
            # V3 inputs are aligned to 14 px, teacher inputs to 32 px.
            align_to = 14 if family == 'teacher_v3' else 32
            image_np, image_tensor, image_size = _prepare_image_tensor(img_pil, align_to=align_to)

            print(f'Checkpoint: {ckpt_name} ({ckpt_format})')
            print(f'Model family: {family} | classes: {num_classes}')
            print(f'Image: {image_name} | resized: {image_size[1]}x{image_size[0]} | align: {align_to}')
            print(f"Missing keys: {len(getattr(incompatible, 'missing_keys', []))} | Unexpected keys: {len(getattr(incompatible, 'unexpected_keys', []))} | Shape-skipped: {len(skipped)}")

            conf = float(conf_threshold.value)
            iou = float(nms_iou.value)
            limit = int(max_dets.value)
            image_hw = (image_np.shape[0], image_np.shape[1])

            with torch.no_grad():
                if family == 'teacher_v3':
                    outputs = model(image_tensor)
                    boxes, scores, classes, masks_up = _postprocess_teacher_v3(outputs, image_hw, conf, iou, limit)
                else:
                    outputs = model(image_tensor, use_ema=False)
                    boxes, scores, classes, masks_up = _postprocess_teacher(outputs, image_hw, conf, iou, limit)

            _draw_predictions(image_np, boxes, scores, classes, masks=masks_up, max_dets=limit)

        except Exception as exc:
            # UI boundary: report instead of raising into the widget callback,
            # but keep the exception type and traceback for debugging.
            print(f'Error: {type(exc).__name__}: {exc}')
            traceback.print_exc()


run_button.on_click(_run_inference)

# Assemble the panel top-to-bottom: checkpoint inputs, image inputs, options, then run/output.
_checkpoint_section = [widgets.HTML('<b>Checkpoint</b>'), checkpoint_path, checkpoint_upload]
_image_section = [widgets.HTML('<b>Image</b>'), image_path, image_upload]
_options_section = [
    widgets.HBox([model_hint, strict_load]),
    conf_threshold,
    nms_iou,
    max_dets,
]
display(widgets.VBox(_checkpoint_section + _image_section + _options_section + [run_button, out]))
