In [1]:
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv

In [2]:
def load_model(head_type, ckpt_path, bins=BINS, device=device):
    model = HomographyNet(head_type=head_type, bins=bins).to(device)
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model"])
    model.eval()
    print(f"Loaded {head_type} model from: {ckpt_path}")
    return model

def visualize_alignment(x2, pred_offsets, patch_size=64, title=""):
    """
    x2: (2,64,64) numpy, channels [orig, warped]
    pred_offsets: (8,) numpy (dx1,dy1,...,dx4,dy4)
    Warps channel-0 to channel-1 using predicted homography, shows diff.
    """
    a = (x2[0] * 255).astype(np.uint8)  # original patch
    b = (x2[1] * 255).astype(np.uint8)  # warped patch (target)
    H, W = patch_size, patch_size

    # local patch coordinates
    src = np.array([[0,0],[W,0],[W,H],[0,H]], dtype=np.float32)
    dst = src + pred_offsets.reshape(4,2).astype(np.float32)

    Hmat = cv.getPerspectiveTransform(src, dst)
    Hinv = np.linalg.inv(Hmat)

    a_aligned = cv.warpPerspective(a, Hinv, (W, H), flags=cv.INTER_LINEAR, borderMode=cv.BORDER_REPLICATE)
    diff = cv.absdiff(a_aligned, b)

    fig, axs = plt.subplots(1, 4, figsize=(12,3))
    axs[0].imshow(a, cmap="gray"); axs[0].set_title("orig (ch0)"); axs[0].axis("off")
    axs[1].imshow(b, cmap="gray"); axs[1].set_title("warped (ch1)"); axs[1].axis("off")
    axs[2].imshow(a_aligned, cmap="gray"); axs[2].set_title("orig aligned"); axs[2].axis("off")
    axs[3].imshow(diff, cmap="gray"); axs[3].set_title("abs diff"); axs[3].axis("off")
    fig.suptitle(title)
    plt.show()


NameError: name 'BINS' is not defined