In [5]:
import torch
import numpy as np
import torch.nn as nn
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import matplotlib.image as mpimg
import os
import cv2
from tqdm import tqdm
from functools import partial
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torchvision.transforms import v2
# Project-local modules: import by module name, never with the ".py" suffix
# ("fly_dataset.py" would be parsed as submodule "py" of a package "fly_dataset").
from fly_dataset import FLYDataset
from fly_cnn import CNN_Fly

ModuleNotFoundError: No module named 'fly_dataset'

In [2]:
# Visualize
def _scale_keypoints_to_pixels(keypoints, height, width):
    """Return a copy of ``keypoints`` with column 0 scaled by ``height`` and column 1 by ``width``."""
    keypoints_px = keypoints.clone()  # clone so the caller's tensor is left untouched
    keypoints_px[:, 0] *= height
    keypoints_px[:, 1] *= width
    return keypoints_px


def visualize_predictions(model, dataset, device="cuda", num_samples=5):
    """Plot predicted and ground-truth keypoints for the first ``num_samples`` dataset items.

    For each sample the model is run without gradients, both keypoint sets are
    rescaled with ``dataset.H`` / ``dataset.W``, the MSE between them is
    computed on the rescaled values, and both sets are scattered over the image.
    The average MSE over all plotted samples is printed at the end.

    Args:
        model: Module whose forward pass returns a 4-tuple; only the first
            element (the predicted keypoints) is used.
        dataset: Indexable object yielding ``(image, keypoints)`` pairs and
            exposing ``H`` and ``W`` attributes.
            NOTE(review): keypoints are presumably normalized to [0, 1] and
            stored as (row, col) pairs, since column 0 is scaled by ``H`` and
            drawn on the y-axis -- confirm against the dataset implementation.
        device: Device the input image is moved to before the forward pass.
        num_samples: Number of leading samples to visualize (must be >= 1).

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    model.eval()
    # squeeze=False keeps `axes` a 2-D array even for a single column, so
    # indexing works for num_samples == 1 as well (a bare Axes is returned otherwise).
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 4), squeeze=False)
    axes = axes[0]
    loss_fn = torch.nn.MSELoss()
    losses = []

    for i in range(num_samples):
        img, true_kp = dataset[i]

        # Add a batch dimension; presumably yields [1, 1, H, W] for a grayscale image.
        input_img = img.unsqueeze(0).to(device)

        with torch.no_grad():
            pred_kp, _, _, _ = model(input_img)

        pred_kp = pred_kp.squeeze(0).cpu()
        true_kp = true_kp.cpu()

        # Rescale: column 0 by image height, column 1 by image width.
        true_kp_px = _scale_keypoints_to_pixels(true_kp, dataset.H, dataset.W)
        pred_kp_px = _scale_keypoints_to_pixels(pred_kp, dataset.H, dataset.W)

        # MSE is reported on the rescaled (pixel) coordinates.
        loss = loss_fn(pred_kp_px, true_kp_px).item()
        losses.append(loss)

        # Move to CPU before converting, mirroring how the keypoints are handled.
        img_np = img.squeeze(0).cpu().numpy()

        ax = axes[i]
        ax.imshow(img_np, cmap="gray")
        # Column 1 goes on the horizontal axis, column 0 on the vertical axis.
        ax.scatter(pred_kp_px[:, 1], pred_kp_px[:, 0], c="r", label="Pred", s=10)
        ax.scatter(true_kp_px[:, 1], true_kp_px[:, 0], c="g", label="GT", s=10, alpha=0.6)
        ax.set_title(f"Sample {i}\nMSE: {loss:.4f}")
        ax.axis("off")

    fig.tight_layout()
    # One legend is enough; attach it explicitly to the last panel.
    axes[-1].legend()
    plt.show()
    avg_loss = sum(losses) / len(losses)
    print(f"\nAverage MSE over {num_samples} samples: {avg_loss:.4f}")

In [None]:
# Paths used by this cell; adjust here rather than in the calls below.
DATA_DIR = "/scratch/cv-course2025/group2/data"
CHECKPOINT_PATH = "./fly-test.pt"

model = CNN_Fly(input_size=(480, 960), embedding_size=32)
train_dataset = FLYDataset(DATA_DIR)
test_dataset = FLYDataset(DATA_DIR, mode="test")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Use CPU for debug so you do not use too much GPU time
print(f"Used Device: {device}")

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved from a GPU run can still be loaded when falling back to CPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.to(device)
visualize_predictions(model, test_dataset, device=device)