In [2]:
import os
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.transforms.functional import resize as tvf_resize
from models.csrnet_mbv3 import MobileCSRNet

# ==== Config ====
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_DIR = "inference/test_images"
# Previously used checkpoint (presumably the VGG-backbone CSRNet; it would need
# the matching model class, not MobileCSRNet):
# MODEL_PATH = "build/csrnet_vgg_B.pth"
MODEL_PATH = "build/csrnet_mobile_B.pt"
SAVE_HEATMAP = True  # Set to False to skip saving density-map visualizations
HEATMAP_DIR = "inference/heatmaps"

# Only create the output directory when heatmaps will actually be written,
# so disabling SAVE_HEATMAP does not leave an empty folder behind.
if SAVE_HEATMAP:
    os.makedirs(HEATMAP_DIR, exist_ok=True)

# ==== Transform ====
# Preprocessing pipeline for MobileCSRNet:
#   1. resize every image to a fixed 512x512 input,
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the standard ImageNet mean/std values.
# An earlier variant skipped the resize (a 384x384 Resize was commented out)
# and fed images at their native resolution.
m_transform = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ==== Load Model ====
model = MobileCSRNet().to(DEVICE)

# The file at MODEL_PATH is loaded directly as a state_dict.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only load checkpoints you trust. If this file is a plain
# state_dict, weights_only=True should also work and is safer — confirm first.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(state_dict)
# For checkpoints saved as a dict with 'state_dict', 'best_prec1' (MAE) and
# 'epoch' keys, load instead with:
#   checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
#   model.load_state_dict(checkpoint['state_dict'])
#   print(f"MAE: {checkpoint['best_prec1']} - epoch: {checkpoint['epoch']}")

# Switch to evaluation mode. This affects layers such as dropout and batch
# norm, if the model contains any.
model.eval()

# ==== Predict ====
# Match extensions case-insensitively (e.g. ".JPG") and sort the file list so
# results are printed in a stable order across runs and platforms.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
image_files = sorted(
    f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(IMAGE_EXTENSIONS)
)

if SAVE_HEATMAP:
    # Idempotent; ensures the output folder exists even if this cell is run alone.
    os.makedirs(HEATMAP_DIR, exist_ok=True)

for fname in image_files:
    path = os.path.join(IMAGE_DIR, fname)
    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so `image` stays usable after the file is closed.
    with Image.open(path) as img:
        image = img.convert('RGB')
    input_img = m_transform(image).unsqueeze(0).to(DEVICE)  # [1, 3, 512, 512]

    with torch.no_grad():
        output = model(input_img)  # density map, assumed [1, 1, H, W]
    density_map = output.squeeze().cpu().numpy()
    # The estimated head count is the sum over the predicted density map.
    total_count = density_map.sum()

    print(f"{fname}: Estimated Count = {total_count:.2f}")

    # Optional: save a heatmap visualization of the density map.
    if SAVE_HEATMAP:
        fig, ax = plt.subplots()
        ax.imshow(density_map, cmap='jet')
        ax.axis('off')
        ax.set_title(f'Count: {total_count:.1f}')
        # Strip whichever extension the input has, so .png/.jpeg inputs also
        # get the "_heatmap" suffix. A '.jpg' inside the name is left intact.
        stem, _ = os.path.splitext(fname)
        heatmap_path = os.path.join(HEATMAP_DIR, f"{stem}_heatmap.png")
        fig.savefig(heatmap_path, bbox_inches='tight', pad_inches=0)
        plt.close(fig)


34-peoples.jpg: Estimated Count = 54.78
271-peoples.jpg: Estimated Count = 175.04
93-peoples.jpg: Estimated Count = 103.10
