In [None]:
import cv2
from tqdm import tqdm
import numpy as np
import random
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

In [29]:
# ============================================================
# Prepare Dataset
# ============================================================

def Preprocess_and_save_images(INPUT_DIR, OUTPUT_DIR, TARGET_SIZE):
    """Write a resized grayscale copy of every JPEG in ``INPUT_DIR`` to ``OUTPUT_DIR``.

    Args:
        INPUT_DIR: Directory scanned (non-recursively) for ``.jpg``/``.jpeg``
            files; the extension check is case-insensitive.
        OUTPUT_DIR: Destination directory, created if it does not exist.
            Output files keep the input file names.
        TARGET_SIZE: Size handed unchanged to ``cv2.resize``.

    Files that ``cv2.imread`` cannot decode are reported and skipped.
    """
    # Create the output directory if it does not exist yet.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Gather all JPEG images in the input directory.
    jpeg_suffixes = ('.jpg', '.jpeg')
    image_paths = []
    for entry in os.listdir(INPUT_DIR):
        if entry.lower().endswith(jpeg_suffixes):
            image_paths.append(os.path.join(INPUT_DIR, entry))

    print(f"Najdenih {len(image_paths)} slik za obdelavo...")

    for path in tqdm(image_paths, desc="Obdelujem slike"):
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            print(f"‚ùå Napaka pri branju slike: {path}")
            continue

        # Convert to grayscale, then resize to the requested size.
        resized = cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), TARGET_SIZE)

        # Save under the same file name inside the output directory.
        cv2.imwrite(os.path.join(OUTPUT_DIR, os.path.basename(path)), resized)

    print("‚úÖ Vse slike so uspe≈°no predobdelane in shranjene v:")
    print(f"   {OUTPUT_DIR}")


INPUT_DIR = "datasets/val2017"
PREPROCESSED_DIR = "datasets/val2017_preprocessed"

TARGET_SIZE = (320, 240)
# Preprocess_and_save_images(INPUT_DIR, PREPROCESSED_DIR, TARGET_SIZE)


In [30]:
# ============================================================
# HELPER FUNCTIONS FOR GENERATING PAIRS
# ============================================================

def sample_window(img_shape, window_size=64, margin=16):
    """Pick a random top-left corner for a square window inside an image.

    Args:
        img_shape: Shape whose first two entries are ``(height, width)``.
        window_size: Side length of the square window.
        margin: Minimum distance kept between the window and every image border.

    Returns:
        ``(x, y)`` with ``margin <= x <= width - margin - window_size`` and the
        analogous bound for ``y`` (both bounds inclusive).

    Raises:
        ValueError: From ``random.randint`` when the image is too small for
            the requested window and margin.
    """
    height, width = img_shape[:2]
    # x is drawn before y so the consumed random sequence stays the same.
    left = random.randint(margin, width - margin - window_size)
    top = random.randint(margin, height - margin - window_size)
    return left, top


def get_corners(x, y, window_size=64):
    """Return the four corners of a square window as a ``(4, 2)`` float32 array.

    The corners are ordered top-left, top-right, bottom-right, bottom-left,
    each stored as ``[x, y]``.
    """
    right = x + window_size
    bottom = y + window_size
    corner_points = [(x, y), (right, y), (right, bottom), (x, bottom)]
    return np.array(corner_points, dtype=np.float32)


def perturb_corners(corners, disp_range=(-16, 16)):
    """Shift every corner coordinate by an independent random integer.

    Args:
        corners: Array of corner coordinates; a displacement is drawn for
            every element (``size=corners.shape``).
        disp_range: ``(min, max)`` displacement, both ends inclusive.

    Returns:
        ``corners`` plus the drawn displacements (the input is not modified).
    """
    lowest, highest = disp_range
    # np.random.randint excludes the upper bound, hence the +1.
    jitter = np.random.randint(lowest, highest + 1, size=corners.shape)
    return corners + jitter.astype(np.float32)


def generate_pair(img, window_size=64, margin=16, disp_range=(-16, 16)):
    """Create one synthetic (patch, warped patch) training sample from an image.

    A square window is sampled, its corners are randomly perturbed, and the
    whole image is warped with the inverse of the homography that maps the
    original corners to the perturbed ones. The same window is then cropped
    from both the original and the warped image.

    Args:
        img: Source image; for a 2-D grayscale image the returned pair has
            shape ``(window_size, window_size, 2)``.
        window_size: Side length of the cropped patches.
        margin: Border margin forwarded to ``sample_window``.
        disp_range: Inclusive corner displacement range forwarded to
            ``perturb_corners``.

    Returns:
        Tuple ``(pair, offsets, src_corners, warped)``:
        ``pair`` is the two patches stacked on the last axis, as float32
        divided by 255; ``offsets`` is ``dst_corners - src_corners`` as
        float32; ``src_corners`` are the unperturbed window corners;
        ``warped`` is the full warped image.
    """
    height, width = img.shape[:2]
    x, y = sample_window((height, width), window_size, margin)

    src_corners = get_corners(x, y, window_size)
    dst_corners = perturb_corners(src_corners, disp_range)

    # Homography src -> dst; the image itself is warped with its inverse.
    forward_h = cv2.getPerspectiveTransform(src_corners, dst_corners)
    warped = cv2.warpPerspective(
        img, np.linalg.inv(forward_h), (width, height), flags=cv2.INTER_LINEAR
    )

    # Crop the same window from both images.
    rows = slice(y, y + window_size)
    cols = slice(x, x + window_size)
    patches = [img[rows, cols], warped[rows, cols]]

    # Stack as two channels and scale by 1/255.
    pair = np.stack(patches, axis=-1).astype(np.float32) / 255.0

    # Ground truth: per-corner displacement.
    offsets = (dst_corners - src_corners).astype(np.float32)

    return pair, offsets, src_corners, warped


In [31]:
# ============================================================
# PAIR GENERATION TEST
# ============================================================

def visualize_generate_pair(image_dir, window_size=64, margin=16, disp_range=(-16, 16)):
    """Plot every intermediate step of the pair-generation pipeline for one random image.

    A random image from ``image_dir`` is loaded as grayscale, a window is
    sampled, its corners are perturbed, and the image is warped with the
    inverse homography. Five panels are shown: source corners, perturbed
    corners with displacement arrows, the warped image, and the two patches.
    The corner offsets are printed afterwards.

    Args:
        image_dir: Directory containing ``.jpg``/``.jpeg``/``.png`` images.
        window_size: Side length of the sampled window.
        margin: Border margin used when sampling the window.
        disp_range: Inclusive ``(min, max)`` corner displacement.

    Returns:
        None. Prints a message and returns early when the directory has no
        matching images or the chosen image cannot be read.
    """
    # Load a random image.
    image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
                   if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    if not image_paths:
        # random.choice would raise IndexError on an empty list.
        print("No images found in", image_dir)
        return

    img_path = random.choice(image_paths)
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None instead of raising on unreadable files.
        print("Could not read image:", img_path)
        return

    h, w = img.shape[:2]

    # Sample window
    x, y = sample_window((h, w), window_size, margin)

    # Get corners
    src_corners = get_corners(x, y, window_size)
    dst_corners = perturb_corners(src_corners, disp_range)

    # Compute homography (src -> dst) and its inverse
    H = cv2.getPerspectiveTransform(src_corners, dst_corners)
    H_inv = np.linalg.inv(H)

    # Warp the whole image with the inverse homography
    warped = cv2.warpPerspective(img, H_inv, (w, h), flags=cv2.INTER_LINEAR)

    # Extract patches
    orig_patch = img[y:y + window_size, x:x + window_size]
    warped_patch = warped[y:y + window_size, x:x + window_size]

    # Calculate offsets
    offsets = dst_corners - src_corners

    # Plot
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # 1. Original image with source corners
    ax = axes[0, 0]
    ax.imshow(img, cmap='gray')
    for i, (cx, cy) in enumerate(src_corners):
        ax.plot(cx, cy, 'go', markersize=10)
        ax.text(cx, cy - 5, f'{i}', color='green', fontsize=12, ha='center')
    rect = plt.Rectangle((x, y), window_size, window_size, fill=False, edgecolor='green', linewidth=2)
    ax.add_patch(rect)
    ax.set_title('Original Image + Source Corners (green)')
    ax.axis('off')

    # 2. Original image with destination corners
    ax = axes[0, 1]
    ax.imshow(img, cmap='gray')
    for i, (cx, cy) in enumerate(dst_corners):
        ax.plot(cx, cy, 'ro', markersize=10)
        ax.text(cx, cy - 5, f'{i}', color='red', fontsize=12, ha='center')
    # Arrows from each source corner to its perturbed position
    for i in range(4):
        ax.arrow(src_corners[i, 0], src_corners[i, 1],
                 offsets[i, 0], offsets[i, 1],
                 head_width=3, head_length=3, fc='yellow', ec='yellow', alpha=0.7)
    rect = plt.Rectangle((x, y), window_size, window_size, fill=False, edgecolor='green', linewidth=2, linestyle='--')
    ax.add_patch(rect)
    ax.set_title(f'Perturbed Corners (red)\nAvg offset: {np.abs(offsets).mean():.1f}px')
    ax.axis('off')

    # 3. Image warped with the inverse homography
    ax = axes[0, 2]
    ax.imshow(warped, cmap='gray')
    rect = plt.Rectangle((x, y), window_size, window_size, fill=False, edgecolor='blue', linewidth=2)
    ax.add_patch(rect)
    ax.set_title('Warped Image (H‚Åª¬π applied)')
    ax.axis('off')

    # 4. Original patch (title follows the actual window size)
    ax = axes[1, 0]
    ax.imshow(orig_patch, cmap='gray')
    ax.set_title(f'Original Patch ({window_size}√ó{window_size})')
    ax.axis('off')

    # 5. Warped patch
    ax = axes[1, 1]
    ax.imshow(warped_patch, cmap='gray')
    ax.set_title(f'Warped Patch ({window_size}√ó{window_size})')
    ax.axis('off')

    # Hide the 6th subplot
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

    # Print offsets
    print("\nCorner offsets (Œîx, Œîy):")
    for i, (dx, dy) in enumerate(offsets):
        print(f"  Corner {i}: ({dx:+.1f}, {dy:+.1f}) px")


PREPROCESSED_DIR = "datasets/val2017_preprocessed"
# Run visualization
# if os.path.exists(PREPROCESSED_DIR):
#     visualize_generate_pair(PREPROCESSED_DIR)

In [32]:
def visualize_offset_sign(image_dir, window_size=64, margin=16, disp_range=(-16, 16)):
    """Compare forward warps built from ``+offsets`` and ``-offsets`` with the true warp.

    A pair is generated with ``generate_pair``; the original image is then
    warped (forward, via ``cv2.warpPerspective``) with the homography
    ``src -> src + offsets`` and with ``src -> src - offsets``. Both results
    are shown next to the warp returned by ``generate_pair`` and their mean
    absolute pixel differences to it are printed.

    NOTE(review): ``generate_pair`` warps with the *inverse* of the
    ``src -> src + offsets`` homography, while both matrices here are applied
    in the forward direction. The ``-offsets`` variant is therefore only an
    approximation of the true warp (the inverse of a four-point perturbation
    is not, in general, the negated perturbation), so the printed verdict is
    a heuristic, not an exact check.

    Args:
        image_dir: Directory containing ``.jpg``/``.jpeg``/``.png`` images.
        window_size, margin, disp_range: Forwarded to ``generate_pair``.

    Returns:
        None. Returns early (after printing a message) when no image is found
        or the chosen image cannot be read.
    """
    # Pick random image
    image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
                   if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    if not image_paths:
        print("‚ùå No images found in", image_dir)
        return

    img_path = random.choice(image_paths)
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print("‚ùå Could not read image:", img_path)
        return

    # Generate a pair with the regular pipeline
    pair, offsets, src_corners, warped_true = generate_pair(
        img, window_size=window_size, margin=margin, disp_range=disp_range
    )
    dst_corners = src_corners + offsets
    # Corners actually used for the "-offsets" reconstruction.
    minus_corners = src_corners - offsets

    # Reconstruct warps using both offset signs
    H_plus = cv2.getPerspectiveTransform(src_corners, dst_corners)
    H_minus = cv2.getPerspectiveTransform(src_corners, minus_corners)

    out_size = (img.shape[1], img.shape[0])  # warpPerspective takes (width, height)
    warped_plus = cv2.warpPerspective(img, H_plus, out_size)
    warped_minus = cv2.warpPerspective(img, H_minus, out_size)

    # --- Visualization ---
    fig, axes = plt.subplots(1, 3, figsize=(15, 6))

    # 1. True warped image
    axes[0].imshow(warped_true, cmap='gray')
    axes[0].add_patch(plt.Polygon(src_corners, fill=False, edgecolor='green', lw=2, label='src'))
    axes[0].add_patch(plt.Polygon(dst_corners, fill=False, edgecolor='blue', lw=2, label='dst (GT)'))
    axes[0].set_title("True warped image (from generate_pair)")
    axes[0].legend()
    axes[0].axis('off')

    # 2. Reconstructed with +offsets
    axes[1].imshow(warped_plus, cmap='gray')
    axes[1].add_patch(plt.Polygon(src_corners, fill=False, edgecolor='green', lw=2))
    axes[1].add_patch(plt.Polygon(dst_corners, fill=False, edgecolor='red', lw=2))
    axes[1].set_title("Reconstructed warp (+offsets)")
    axes[1].axis('off')

    # 3. Reconstructed with -offsets: overlay the corners H_minus was built
    #    from (src - offsets), not the +offsets corners.
    axes[2].imshow(warped_minus, cmap='gray')
    axes[2].add_patch(plt.Polygon(src_corners, fill=False, edgecolor='green', lw=2))
    axes[2].add_patch(plt.Polygon(minus_corners, fill=False, edgecolor='red', lw=2))
    axes[2].set_title("Reconstructed warp (-offsets)")
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    # Numeric comparison for confirmation
    true_f = warped_true.astype(np.float32)
    diff_plus = np.mean(np.abs(true_f - warped_plus.astype(np.float32)))
    diff_minus = np.mean(np.abs(true_f - warped_minus.astype(np.float32)))

    print(f"Mean pixel difference (true vs +offsets): {diff_plus:.2f}")
    print(f"Mean pixel difference (true vs -offsets): {diff_minus:.2f}")

    if diff_minus < diff_plus:
        print("‚ö†Ô∏è Offsets likely need to be NEGATED during training.")
    else:
        print("‚úÖ Offsets appear to have the correct sign.")

# visualize_offset_sign(PREPROCESSED_DIR)


In [33]:
class ResNetBlock(nn.Module):
    """Residual block: two 3x3 convolutions with dropout plus a skip connection.

    The skip path is the identity when the channel count and stride are
    unchanged, otherwise a 1x1 convolution that matches the main path's
    output shape.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output; falls back to ``in_channels``
            when omitted.
        stride: Stride of the first convolution (and of the projection).
        dropout_rate: Probability used by the shared ``nn.Dropout2d``.
    """

    def __init__(self, in_channels, out_channels=None, stride=1,
                 dropout_rate=0.1):
        super().__init__()
        # When unspecified, keep the number of channels unchanged.
        if not out_channels:
            out_channels = in_channels

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.dropout = nn.Dropout2d(p=dropout_rate)
        self.relu = nn.ReLU(inplace=True)

        # A projection is only needed when the skip tensor would not match
        # the main path in shape.
        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)

        hidden = self.relu(self.dropout(self.conv1(x)))
        hidden = self.dropout(self.conv2(hidden))

        return self.relu(hidden + skip)


In [34]:
class ResNetBody(nn.Module):
    """Convolutional feature extractor that produces a 512-dimensional vector.

    Four stages of two ``ResNetBlock`` modules each, every stage followed by
    batch normalisation and ReLU; the first three stages also halve the
    spatial resolution with max pooling. The final linear layer is sized for
    an ``N x 2 x 64 x 64`` input (``128 * 8 * 8`` flattened features).

    Args:
        in_channels: Channels of the input tensor.
        dropout_rate: Dropout probability forwarded to every ``ResNetBlock``.
    """

    def __init__(self, in_channels=2, dropout_rate=0.1):
        super().__init__()

        def build_stage(c_in, c_out, pool):
            """Two residual blocks + BatchNorm + ReLU, optionally followed by 2x2 max pooling."""
            modules = [
                ResNetBlock(c_in, c_out, dropout_rate=dropout_rate),
                ResNetBlock(c_out, c_out, dropout_rate=dropout_rate),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
            if pool:
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            return nn.Sequential(*modules)

        self.layer1 = build_stage(in_channels, 64, pool=True)   # 64x64 -> 32x32
        self.layer2 = build_stage(64, 64, pool=True)            # 32x32 -> 16x16
        self.layer3 = build_stage(64, 128, pool=True)           # 16x16 -> 8x8
        # No pooling after the last stage; the 8x8 resolution is kept.
        self.layer4 = build_stage(128, 128, pool=False)

        # Fully connected layer
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128 * 8 * 8, 512)

    def forward(self, x):  # Nx2x64x64
        """Map an ``N x 2 x 64 x 64`` batch to ``N x 512`` features."""
        # Nx64x32x32 -> Nx64x16x16 -> Nx128x8x8 -> Nx128x8x8
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        flat = self.flatten(x)  # Nx8192
        return self.fc(flat)  # Nx512


In [35]:
class RegressionHead(nn.Module):
    """Single linear layer mapping body features to regression outputs.

    Args:
        in_features: Size of the incoming feature vector.
        out_features: Number of regressed values (8 by default).
    """

    def __init__(self, in_features=512, out_features=8):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):  # Nx512
        """Return the linear projection of ``x`` (``N x out_features``)."""
        prediction = self.fc(x)  # Nx8
        return prediction

In [36]:
class ClassificationHead(nn.Module):
    """Linear layer followed by a per-output softmax over classes.

    The linear output is reshaped to ``N x num_classes x class_dim`` and a
    softmax is taken over the class axis, so that for each of the
    ``class_dim`` outputs the ``num_classes`` probabilities sum to one.

    Args:
        in_features: Size of the incoming feature vector.
        num_classes: Number of classes (bins) per output.
        class_dim: Number of independently classified outputs.
    """

    def __init__(self, in_features=512, num_classes=21, class_dim=8):
        super().__init__()
        self.num_classes = num_classes
        self.class_dim = class_dim
        self.fc = nn.Linear(in_features, num_classes * class_dim)
        # After the reshape in forward() the classes live on dim=1; dim=2
        # indexes the outputs, and normalising over it would make the
        # outputs compete with each other instead of the classes.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):  # Nx512
        """Return class probabilities of shape ``N x num_classes x class_dim``."""
        x = self.fc(x)  # N x (num_classes * class_dim)
        x = x.view(-1, self.num_classes, self.class_dim)  # Nx21x8 with the defaults
        x = self.softmax(x)  # probabilities over the class axis
        return x

In [37]:
class HomographyRegressor(nn.Module):
    """``ResNetBody`` feature extractor followed by an 8-value ``RegressionHead``.

    Args:
        dropout_rate: Dropout probability forwarded to the body.
    """

    def __init__(self, dropout_rate=0.1):
        super().__init__()
        self.body = ResNetBody(in_channels=2, dropout_rate=dropout_rate)
        self.head = RegressionHead(in_features=512, out_features=8)

    def forward(self, x):
        """Run the body, then the regression head, on ``x``."""
        features = self.body(x)
        return self.head(features)

In [38]:
class HomographyClassifier(nn.Module):
    """``ResNetBody`` feature extractor followed by a ``ClassificationHead``.

    Args:
        num_classes: Number of classes forwarded to the head.
        class_dim: Number of classified outputs forwarded to the head.
        dropout_rate: Dropout probability forwarded to the body.
    """

    def __init__(self, num_classes=21, class_dim=8, dropout_rate=0.1):
        super().__init__()
        self.body = ResNetBody(in_channels=2, dropout_rate=dropout_rate)
        self.head = ClassificationHead(
            in_features=512,
            num_classes=num_classes,
            class_dim=class_dim,
        )

    def forward(self, x):
        """Run the body, then the classification head, on ``x``."""
        features = self.body(x)
        return self.head(features)

In [39]:
import os
import re


def extract_epoch(filename):
    """Return the integer following the first ``epoch_`` in ``filename``, or -1 if absent."""
    found = re.search(r"epoch_(\d+)", filename)
    if found is None:
        return -1
    return int(found.group(1))


def nn_train(model, num_epochs, model_file_name, img, optimizer, criterion,
             checkpoint_dir="checkpoints", save_every=1000, keep_last=4):
    """Train ``model`` on synthetic homography pairs with checkpoint/resume support.

    Every "epoch" is a single optimisation step on one freshly generated
    pair (batch size 1). The regression target is the *negated* corner
    offsets returned by ``generate_pair``; evaluation code must use the same
    sign convention.

    Args:
        model: Network taking a ``1 x 2 x 64 x 64`` tensor and returning
            ``1 x 8`` predictions. Training runs on the device that holds
            the model's parameters.
        num_epochs: Total number of steps (including steps already done by a
            resumed checkpoint).
        model_file_name: Path where the final ``state_dict`` is saved.
        img: A single grayscale image, or a list of images from which one is
            picked at random for every step.
        optimizer: Optimiser bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(preds, target)``.
        checkpoint_dir: Directory for checkpoints. If it already contains
            ``.pth`` files, training resumes from the one with the highest
            epoch number.
        save_every: Write a checkpoint every this many steps (and at the end).
        keep_last: Number of most recent checkpoints kept on disk.
    """
    # Use the device the model already lives on rather than a notebook-level
    # ``device`` global that is only defined in a later cell.
    device = next(model.parameters()).device

    def list_checkpoints():
        """Return the ``.pth`` names in ``checkpoint_dir``, lowest epoch first."""
        names = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(".pth"))
        # Stable sort: names without an epoch number (-1) come first.
        return sorted(names, key=extract_epoch)

    os.makedirs(checkpoint_dir, exist_ok=True)
    start_epoch = 0

    # Resume if a checkpoint exists
    checkpoints = list_checkpoints()
    if checkpoints:
        latest_ckpt = os.path.join(checkpoint_dir, checkpoints[-1])
        checkpoint = torch.load(latest_ckpt, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print(f"‚úÖ Resuming from checkpoint: {latest_ckpt} (epoch {start_epoch})")
    else:
        print("üöÄ Starting training from scratch.")

    # Training loop with a progress bar over the steps
    progress_bar = tqdm(range(start_epoch, num_epochs), desc="Training", ncols=100)

    for epoch in progress_bar:
        model.train()

        pair, offsets, _src_corners, _warped = generate_pair(
            img=random.choice(img) if isinstance(img, list) else img,
            window_size=64,
            margin=16,
            disp_range=(-16, 16)
        )

        pair = torch.from_numpy(pair).permute(2, 0, 1).unsqueeze(0).to(device).float()  # 1x2x64x64
        offsets = torch.from_numpy(offsets.flatten()).unsqueeze(0).to(device).float()  # 1x8

        # Forward: the network is trained to predict the negated offsets.
        preds = model(pair)
        loss = criterion(preds, -offsets)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the current step and loss in the progress bar
        step = epoch + 1
        progress_bar.set_description(f"Epoch {step}/{num_epochs}")
        progress_bar.set_postfix(loss=f"{loss.item():.6f}")

        # Periodic checkpoint, plus one at the very last step
        if step % save_every == 0 or step == num_epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{step}.pth")
            torch.save({
                "epoch": step,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            }, checkpoint_path)
            print(f"\nüíæ Saved checkpoint: {checkpoint_path}")

            # Keep only the newest ``keep_last`` checkpoints
            checkpoints = list_checkpoints()
            surplus = max(len(checkpoints) - keep_last, 0)
            for stale in checkpoints[:surplus]:
                old_ckpt = os.path.join(checkpoint_dir, stale)
                os.remove(old_ckpt)
                print(f"üóëÔ∏è Removed old checkpoint: {old_ckpt}")

    progress_bar.close()

    # Save final model weights
    torch.save(model.state_dict(), model_file_name)
    print(f"‚úÖ Final model saved: {model_file_name}")

    # Release cached GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # memory_allocated only accepts CUDA devices; None selects the current one.
        cuda_device = device if device.type == "cuda" else None
        print(f"üßπ GPU memory cleared. Current allocated: {torch.cuda.memory_allocated(cuda_device) / 1e9:.2f} GB")


def check_gpu_memory():
    """Print allocated, reserved, total and free GPU memory (GB) for the global ``device``.

    Prints ``CUDA not available`` when PyTorch sees no CUDA device.
    """
    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    bytes_per_gb = 1e9
    allocated = torch.cuda.memory_allocated(device) / bytes_per_gb
    reserved = torch.cuda.memory_reserved(device) / bytes_per_gb
    total = torch.cuda.get_device_properties(device).total_memory / bytes_per_gb

    print(f"GPU Memory Status:")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved:  {reserved:.2f} GB")
    print(f"  Total:     {total:.2f} GB")
    # "Free" is measured against reserved memory, not allocated memory.
    print(f"  Free:      {total - reserved:.2f} GB")


def clear_gpu_memory():
    """Empty the CUDA cache, synchronise, and print the resulting memory status.

    Prints ``CUDA not available`` when PyTorch sees no CUDA device.
    """
    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print("‚úÖ GPU cache cleared")
    check_gpu_memory()

In [42]:
import torch.optim as optim

# Training configuration. Note that nn_train performs one optimisation step
# on a single generated pair per "epoch", so this is a step count.
num_epochs = 50000
learning_rate = 1e-4
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Regression model, mean-squared-error loss and Adam optimiser.
model = HomographyRegressor(dropout_rate=0.1).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# File names (inside PREPROCESSED_DIR) of the images used for training;
# the commented-out entries are currently disabled.
image_names = [
    "000000002299.jpg",
    # "000000000285.jpg",
    # "000000000632.jpg",
]

# Load each training image as grayscale; unreadable files are reported and skipped.
images = []
for filename in image_names:
    img_path = os.path.join(PREPROCESSED_DIR, filename)
    image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if image is not None:
        images.append(image)
    else:
        print(f"‚ö†Ô∏è Warning: Could not load {filename}")

print(f"üì∑ Loaded {len(images)} image(s) for training")
#
# nn_train(
#     model=model,
#     num_epochs=num_epochs,
#     model_file_name=f"h_regressor_ep{num_epochs}_I{len(images)}.pth",
#     img=images[0] if len(images) == 1 else images,
#     optimizer=optimizer,
#     criterion=criterion,
#     checkpoint_dir="checkpoints_homography_-O"
# )

Using device: cuda
üì∑ Loaded 1 image(s) for training
üöÄ Starting training from scratch.


Epoch 1002/50000:   2%|‚ñç                     | 1001/50000 [01:22<1:37:24,  8.38it/s, loss=85.074081]


üíæ Saved checkpoint: checkpoints_homography_-O\checkpoint_epoch_1000.pth


Epoch 2004/50000:   4%|‚ñâ                       | 2002/50000 [02:44<51:45, 15.46it/s, loss=99.486053]


üíæ Saved checkpoint: checkpoints_homography_-O\checkpoint_epoch_2000.pth


Epoch 3002/50000:   6%|‚ñà‚ñé                    | 3001/50000 [04:07<1:19:23,  9.87it/s, loss=14.871353]


üíæ Saved checkpoint: checkpoints_homography_-O\checkpoint_epoch_3000.pth


Epoch 4002/50000:   8%|‚ñà‚ñä                    | 4001/50000 [05:22<1:29:18,  8.58it/s, loss=30.407738]


üíæ Saved checkpoint: checkpoints_homography_-O\checkpoint_epoch_4000.pth


Epoch 5002/50000:  10%|‚ñà‚ñà‚ñè                   | 5001/50000 [06:48<1:13:19, 10.23it/s, loss=11.859369]


üíæ Saved checkpoint: checkpoints_homography_-O\checkpoint_epoch_5000.pth
üóëÔ∏è Removed old checkpoint: checkpoints_homography_-O\checkpoint_epoch_1000.pth


Epoch 5012/50000:  10%|‚ñà‚ñà‚ñè                   | 5012/50000 [06:49<1:01:18, 12.23it/s, loss=37.289501]


KeyboardInterrupt: 

In [None]:
def test_model(model_path, test_image_names, num_samples=10, visualize=True, negate_offsets=True):
    """Evaluate a saved ``HomographyRegressor`` on freshly generated pairs.

    Args:
        model_path: Path to either a plain ``state_dict`` or a checkpoint
            dict containing ``"model_state_dict"``.
        test_image_names: File names looked up inside the notebook-level
            ``PREPROCESSED_DIR``.
        num_samples: Number of random pairs to evaluate (must be >= 1).
        visualize: When true, plot the error histogram, the per-coordinate
            error, a prediction-vs-target scatter and one sample pair, and
            save the figure to ``test_model_results.png``.
        negate_offsets: Sign convention of the regression target. ``nn_train``
            fits the network to ``-offsets``, so the default compares the
            predictions against ``-offsets``; pass ``False`` for a model that
            was trained on ``+offsets``.

    Returns:
        ``None`` when no test image could be loaded, otherwise a dict with
        ``mean_error``, ``std_error``, ``max_error``, ``min_error`` and the
        ``errors``, ``predictions`` and ``ground_truths`` arrays (each
        ``num_samples x 8``; ``ground_truths`` holds the sign-adjusted targets).

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        # Without samples the statistics below are undefined.
        raise ValueError("num_samples must be at least 1")

    # Load model (accept both plain state_dict and checkpoint dict)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HomographyRegressor(dropout_rate=0.1).to(device)

    loaded = torch.load(model_path, map_location=device)
    if isinstance(loaded, dict) and "model_state_dict" in loaded:
        model.load_state_dict(loaded["model_state_dict"])
        epoch_info = loaded.get("epoch", None)
        print(f"‚úÖ Checkpoint loaded from: {model_path}" + (f" (epoch {epoch_info})" if epoch_info is not None else ""))
    else:
        model.load_state_dict(loaded)
        print(f"‚úÖ Model state_dict loaded from: {model_path}")

    model.eval()
    print(f"Using device: {device}")

    # Load test images
    test_images = []
    for filename in test_image_names:
        img_path = os.path.join(PREPROCESSED_DIR, filename)
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is not None:
            test_images.append(image)
        else:
            print(f"‚ö†Ô∏è Warning: Could not load {filename}")

    print(f"üì∑ Loaded {len(test_images)} test image(s)")

    if len(test_images) == 0:
        print("‚ùå No test images loaded!")
        return None

    # The targets must carry the same sign the network was trained on.
    target_sign = -1.0 if negate_offsets else 1.0

    # Test the model
    errors = []
    predictions = []
    ground_truths = []

    with torch.no_grad():
        for _ in tqdm(range(num_samples), desc="Testing"):
            # Generate test pair
            test_img = random.choice(test_images)
            pair, offsets, _src_corners, _warped = generate_pair(
                img=test_img,
                window_size=64,
                margin=16,
                disp_range=(-16, 16)
            )

            # Prepare input and the sign-adjusted target
            pair_tensor = torch.from_numpy(pair).permute(2, 0, 1).unsqueeze(0).to(device).float()
            offsets_gt = target_sign * offsets.flatten()

            # Predict
            pred_np = model(pair_tensor).cpu().numpy().flatten()

            # Absolute error per coordinate
            errors.append(np.abs(pred_np - offsets_gt))
            predictions.append(pred_np)
            ground_truths.append(offsets_gt)

    errors = np.array(errors)
    predictions = np.array(predictions)
    ground_truths = np.array(ground_truths)

    # Calculate metrics
    mean_error = errors.mean()
    std_error = errors.std()
    max_error = errors.max()
    min_error = errors.min()

    print("\n" + "=" * 50)
    print("TEST RESULTS")
    print("=" * 50)
    print(f"Number of samples: {num_samples}")
    print(f"Mean absolute error: {mean_error:.4f} pixels")
    print(f"Std deviation: {std_error:.4f} pixels")
    print(f"Min error: {min_error:.4f} pixels")
    print(f"Max error: {max_error:.4f} pixels")
    print("=" * 50)

    # Visualize if requested
    if visualize:
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # 1. Error distribution
        ax = axes[0, 0]
        ax.hist(errors.flatten(), bins=50, alpha=0.7, color='blue', edgecolor='black')
        ax.axvline(mean_error, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_error:.4f}')
        ax.set_xlabel('Absolute Error (pixels)')
        ax.set_ylabel('Frequency')
        ax.set_title('Error Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 2. Per-coordinate error (x coordinates red, y coordinates blue)
        ax = axes[0, 1]
        coord_errors = errors.mean(axis=0)
        coord_names = ['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4']
        ax.bar(coord_names, coord_errors, color=['red' if i % 2 == 0 else 'blue' for i in range(8)])
        ax.set_ylabel('Mean Absolute Error (pixels)')
        ax.set_title('Error per Corner Coordinate')
        ax.grid(True, alpha=0.3, axis='y')

        # 3. Prediction vs target scatter
        ax = axes[1, 0]
        ax.scatter(ground_truths.flatten(), predictions.flatten(), alpha=0.5, s=10)
        min_val = min(ground_truths.min(), predictions.min())
        max_val = max(ground_truths.max(), predictions.max())
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect prediction')
        ax.set_xlabel('Ground Truth Offset (pixels)')
        ax.set_ylabel('Predicted Offset (pixels)')
        ax.set_title('Predictions vs Ground Truth')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.axis('equal')

        # 4. One example input pair, shown side by side
        ax = axes[1, 1]
        sample_pair = generate_pair(
            img=random.choice(test_images),
            window_size=64,
            margin=16,
            disp_range=(-16, 16)
        )[0]
        combined = np.hstack([sample_pair[:, :, 0], sample_pair[:, :, 1]])
        ax.imshow(combined, cmap='gray')
        ax.set_title('Sample: Original (left) | Warped (right)')
        ax.axis('off')

        plt.tight_layout()
        plt.savefig('test_model_results.png', dpi=150, bbox_inches='tight')
        print(f"üìä Visualization saved to: test_model_results.png")
        plt.show()

    return {
        'mean_error': mean_error,
        'std_error': std_error,
        'max_error': max_error,
        'min_error': min_error,
        'errors': errors,
        'predictions': predictions,
        'ground_truths': ground_truths
    }

# Evaluate the epoch-5000 checkpoint from the "checkpoints_homography_-O"
# directory. The test image is the same file that is listed for training, so
# this measures how well the model fits that image, not how it generalises.
latest_model_path = "checkpoints_homography_-O/checkpoint_epoch_5000.pth"
test_image_names = [
    "000000002299.jpg",
]
# 1000 random pairs are generated from the test image(s); plots are shown and saved.
test_model(
    model_path=latest_model_path,
    test_image_names=test_image_names,
    num_samples=1000,
    visualize=True
)