In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from models.GigaCount import GigaCount
from dataset import Crowd
from utils_custom import collate_fn
from transforms import RandomResizedCrop
from torchvision.transforms.v2 import Compose
import matplotlib.pyplot as plt
import os
from torchvision.transforms.functional import normalize, to_pil_image
from utils_custom import resize_density_map
from eval_metrics import sliding_window_predict


In [2]:

# "qnrf" validation split, loaded with no transforms. return_filename=True makes
# each sample carry its file path as a fourth element.
dataset = Crowd(
    "qnrf",
    split="val",
    transforms=None,
    sigma=None,
    return_filename=True,
)

# One sample per batch, served in dataset order.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_fn,
)

# A single iterator over the loader; whoever calls next() on it advances it.
data_iter = iter(dataloader)

In [3]:
checkpoint_path = "checkpoints/best_model.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Bin edges and anchor points, keyed by dataset name and then by granularity.
original_bins = {
    "qnrf": {
        "bins": {
            "fine": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]],
        },
        "anchor_points": {
            "fine": {
                "middle": [0, 1, 2, 3, 4],
                "average": [0, 1, 2, 3, 4.21937],
            },
        },
    },
}

# float() parses the string "inf" directly, so every edge converts the same way.
bins = [
    (float(lower), float(upper))
    for lower, upper in original_bins["qnrf"]["bins"]["fine"]
]
print("BINS SHAPE: ", bins)

# The "average" anchor set is the one used; "middle" is the alternative listed above.
anchor_points = list(map(float, original_bins["qnrf"]["anchor_points"]["fine"]["average"]))

model = GigaCount(bins=bins, anchor_points=anchor_points).to(DEVICE)

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects.
# The result is passed straight to load_state_dict, so weights_only=True is
# probably sufficient - confirm against how the checkpoint was saved, and only
# load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
model.load_state_dict(checkpoint)

# Switch to inference mode; as the last expression this also echoes the model repr.
model.eval()

BINS SHAPE:  [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, inf)]


GigaCount(
  (image_encoder): Backbone(
    (stem): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
      )
      (1): Sequential(
        (0): CNBlock(
          (block): Sequential(
            (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (1): Permute()
            (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
            (3): Linear(in_features=96, out_features=384, bias=True)
            (4): GELU(approximate='none')
            (5): Linear(in_features=384, out_features=96, bias=True)
            (6): Permute()
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): CNBlock(
          (block): Sequential(
            (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (1): Permute()
            (2): LayerNorm((96,)

In [None]:
num_images = 30
# Per-channel statistics used below to invert the input normalisation before display.
# NOTE(review): presumably these match the normalisation applied upstream - confirm.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
window_size = 224
stride = 224  # None switches to a single forward pass over the whole image
# NOTE(review): `alpha` is not referenced in this cell; both overlays use a fixed alpha of 0.5.
alpha = 0.8

# One row per image, two columns: ground truth on the left, prediction on the right.
# squeeze=False keeps `axes` two-dimensional even when num_images == 1, so the
# axes[i, col] indexing below is always valid.
fig, axes = plt.subplots(
    num_images,
    2,
    figsize=(15, num_images * 2),
    dpi=200,
    tight_layout=True,
    frameon=False,
    squeeze=False,
)

# Set to a dataset index to inspect that one sample instead of walking the loader.
img_id = None

for i in range(num_images):
    if img_id is not None:
        # A raw dataset item is not batched. Pass it through collate_fn so it has
        # the same layout the DataLoader (batch_size=1) yields in the other
        # branch; the code below indexes image_path[0] and points[0].
        image, points, density, image_path = collate_fn([dataset[img_id]])
    else:
        # Advances the shared iterator, so repeated runs show later samples.
        image, points, density, image_path = next(data_iter)

    image_height, image_width = image.shape[-2:]
    image = image.to(DEVICE)
    image_name = os.path.basename(image_path[0])

    with torch.no_grad():
        if stride is not None:  # Sliding window prediction.
            pred_density = sliding_window_predict(model, image, window_size, stride)
        else:
            # NOTE(review): assumes model(image) returns a 2-tuple whose second
            # element is the density map - confirm against GigaCount.forward.
            _, pred_density = model(image)
        # Take the count before resizing the map for display.
        pred_count = pred_density.sum().item()
        resized_pred_density = resize_density_map(pred_density, (image_height, image_width)).cpu()

    # Invert the normalisation in two steps: multiply by std, then add mean.
    image = normalize(image, mean=(0., 0., 0.), std=(1. / std[0], 1. / std[1], 1. / std[2]))
    image = normalize(image, mean=(-mean[0], -mean[1], -mean[2]), std=(1., 1., 1.))
    image = to_pil_image(image.squeeze(0))

    density = density.squeeze().numpy()
    resized_pred_density = resized_pred_density.squeeze().numpy()
    points = points[0].numpy()

    # Left column: image with annotated points and ground-truth density overlay.
    axes[i, 0].imshow(image)
    axes[i, 0].axis("off")
    axes[i, 0].set_title(f"{image_name}\nGT count: {len(points)}")
    if len(points) > 0:
        # assumes points are stored as (x, y) columns - TODO confirm
        axes[i, 0].scatter(points[:, 0], points[:, 1], s=1, c="white")
    axes[i, 0].imshow(density, cmap="jet", alpha=0.5)

    # Right column: image with the predicted density overlay.
    axes[i, 1].imshow(image)
    axes[i, 1].imshow(resized_pred_density, cmap="jet", alpha=0.5)
    axes[i, 1].axis("off")
    axes[i, 1].set_title(f"Pred count: {pred_count:.2f}")

plt.show()

Reduction:  8
Before stage2 NaN: False
Before stage2 min/max: -304.03729248046875 527.873779296875
After stage2 NaN: False
torch.Size([20, 96, 56, 56]) torch.Size([20, 192, 28, 28]) torch.Size([20, 384, 14, 14])
Shallow NaN: False
Mid NaN: False
Deep NaN: False
Max of x before attention: tensor(93.2078, device='cuda:0')
Min of x before attention: tensor(-36.0024, device='cuda:0')
x has Inf: False
channel_attention NaN: False
filter_attention NaN: False
kernel_attention NaN: False
Max of x before attention: tensor(96.9354, device='cuda:0')
Min of x before attention: tensor(-37.1678, device='cuda:0')
x has Inf: False
channel_attention NaN: False
filter_attention NaN: False
kernel_attention NaN: False
Max of x before attention: tensor(28.7619, device='cuda:0')
Min of x before attention: tensor(-0.8843, device='cuda:0')
x has Inf: False
channel_attention NaN: False
filter_attention NaN: False
kernel_attention NaN: False
Max of x before attention: tensor(28.3144, device='cuda:0')
Min of x b

In [None]:
# This cell does not define `mean`, `std`, `window_size` or `stride`; they must
# already exist in the kernel (an earlier cell assigns them). It also reuses the
# shared `data_iter`, so it continues from wherever that iterator currently is.
num_images = 30

# One row per image, two columns: ground truth on the left, prediction on the right.
# squeeze=False keeps `axes` two-dimensional even when num_images == 1, so the
# axes[i, col] indexing below is always valid.
fig, axes = plt.subplots(
    num_images,
    2,
    figsize=(15, num_images * 2),
    dpi=200,
    tight_layout=True,
    frameon=False,
    squeeze=False,
)

# Set to a dataset index to inspect that one sample instead of walking the loader.
img_id = None

for i in range(num_images):
    if img_id is not None:
        # A raw dataset item is not batched. Pass it through collate_fn so it has
        # the same layout the DataLoader (batch_size=1) yields in the other
        # branch; the code below indexes image_path[0] and points[0].
        image, points, density, image_path = collate_fn([dataset[img_id]])
    else:
        image, points, density, image_path = next(data_iter)

    image_height, image_width = image.shape[-2:]
    image = image.to(DEVICE)
    image_name = os.path.basename(image_path[0])

    with torch.no_grad():
        if stride is not None:  # Sliding window prediction.
            pred_density = sliding_window_predict(model, image, window_size, stride)
        else:
            # NOTE(review): assumes model(image) returns the density map itself
            # and not a tuple - confirm against GigaCount.forward.
            pred_density = model(image)
        # Take the count before resizing the map for display.
        pred_count = pred_density.sum().item()
        resized_pred_density = resize_density_map(pred_density, (image_height, image_width)).cpu()

    # Invert the normalisation in two steps: multiply by std, then add mean.
    image = normalize(image, mean=(0., 0., 0.), std=(1. / std[0], 1. / std[1], 1. / std[2]))
    image = normalize(image, mean=(-mean[0], -mean[1], -mean[2]), std=(1., 1., 1.))
    image = to_pil_image(image.squeeze(0))

    density = density.squeeze().numpy()
    resized_pred_density = resized_pred_density.squeeze().numpy()
    points = points[0].numpy()

    # Left column: image with annotated points and ground-truth density overlay.
    axes[i, 0].imshow(image)
    axes[i, 0].axis("off")
    axes[i, 0].set_title(f"{image_name}\nGT count: {len(points)}")
    if len(points) > 0:
        # assumes points are stored as (x, y) columns - TODO confirm
        axes[i, 0].scatter(points[:, 0], points[:, 1], s=1, c="white")
    axes[i, 0].imshow(density, cmap="jet", alpha=0.5)

    # Right column: image with the predicted density overlay.
    axes[i, 1].imshow(image)
    axes[i, 1].imshow(resized_pred_density, cmap="jet", alpha=0.5)
    axes[i, 1].axis("off")
    axes[i, 1].set_title(f"Pred count: {pred_count:.2f}")

plt.show()