In [2]:
import os
import h5py
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
from torchvision import transforms
from models.csrnet_mbv3 import MobileCSRNet
from models.csrnet_vgg import CSRNet

# Dataset part to evaluate; resolves to dataset/part_<PART>/test_data in main().
PART = "A"
# Checkpoint to evaluate. Set to "" to score precomputed maps from PREDDIR instead.
MODEL = "build/csrnet_mobile_A.pt"
# Directory of precomputed prediction .h5 files (used only when MODEL is empty).
PREDDIR = ""

def load_model(model_path, device):
    """Build a MobileCSRNet, load trained weights from ``model_path`` and switch to eval mode.

    Args:
        model_path: path to a file saved with ``torch.save``. It may hold either a bare
            ``state_dict`` or a checkpoint dict that wraps it under ``'state_dict'``.
        device: device the model is moved to and the weights are mapped onto.

    Returns:
        The ``MobileCSRNet`` instance in eval mode.
    """
    model = MobileCSRNet().to(device)
    # NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects, which
    # can execute code -- only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    # Accept training checkpoints that wrap the weights under 'state_dict' as well as
    # plain state_dicts.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    model.load_state_dict(checkpoint)
    model.eval()
    return model

def get_transform(size=(512, 512)):
    """Build the preprocessing pipeline applied to every input image.

    The image is resized to ``size``, ignoring aspect ratio. It is then converted to a
    tensor and normalized with per-channel mean/std (the usual ImageNet statistics).

    Args:
        size: target ``(height, width)`` for ``transforms.Resize``. The default keeps
            the 512x512 resolution used throughout this notebook.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalized tensor.
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

def predict_density(model, image_path, device, transform=None):
    """Run ``model`` on one image and return its predicted density map as a NumPy array.

    Args:
        model: network called as ``model(batch)``; it is used as-is, so put it in eval
            mode first.
        image_path: path of the image to read. The image is converted to RGB.
        device: device the input tensor is moved to (should match the model's device).
        transform: optional preprocessing pipeline. Defaults to ``get_transform()``.
            Pass one explicitly to avoid rebuilding it for every image.

    Returns:
        The model output moved to CPU, with leading size-1 batch and channel dims
        squeezed. Presumably an (H, W) density map -- TODO confirm against the model's
        output shape.
    """
    if transform is None:
        transform = get_transform()
    # Context manager guarantees the file handle is released even if decoding fails.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor).cpu().squeeze(0).squeeze(0).numpy()
    return output

def _density_count(h5_path):
    """Return the sum of the 'density' dataset in an HDF5 file as a Python float (float64 accumulation)."""
    with h5py.File(h5_path, 'r') as f:
        return float(np.asarray(f['density']).sum(dtype=np.float64))


def evaluate(gt_dir, img_dir=None, model=None, pred_dir=None, device='cpu'):
    """Score predicted crowd counts against ground-truth density maps with MAE and RMSE.

    Each ``*.h5`` file in ``gt_dir`` must contain a ``'density'`` dataset. Its sum is
    the ground-truth count. Predictions come from one of two sources; the model source
    wins when both are configured:

    * ``model`` + ``img_dir``: run ``predict_density`` on the matching image. Its name
      is the ground-truth file name with the ``.h5`` extension replaced by ``.jpg`` and
      ``'GT_'`` removed.
    * ``pred_dir``: read the same-named ``.h5`` file's ``'density'`` dataset. Missing
      prediction files are reported and skipped.

    Args:
        gt_dir: directory containing the ground-truth ``.h5`` files.
        img_dir: directory with the test images (model mode).
        model: network passed to ``predict_density`` (model mode).
        pred_dir: directory of precomputed prediction ``.h5`` files.
        device: device used for inference in model mode.

    Returns:
        ``(mae, rmse)`` as floats. Both values are also printed.

    Raises:
        ValueError: if no prediction source is configured, or if no sample could be
            evaluated (which previously surfaced as a ZeroDivisionError).
    """
    use_model = model is not None and bool(img_dir)
    # Validate the mode up front so misconfiguration fails even when gt_dir is empty.
    if not use_model and not pred_dir:
        raise ValueError("Either model+img_dir or pred_dir must be provided.")

    gt_files = sorted(f for f in os.listdir(gt_dir) if f.endswith('.h5'))
    abs_err_sum, sq_err_sum, total = 0.0, 0.0, 0

    for fname in tqdm(gt_files, desc="Evaluating"):
        gt_count = _density_count(os.path.join(gt_dir, fname))

        if use_model:
            stem = os.path.splitext(fname)[0].replace('GT_', '')
            img_path = os.path.join(img_dir, stem + '.jpg')
            pred_density = predict_density(model, img_path, device)
            pred_count = float(np.asarray(pred_density).sum(dtype=np.float64))
        else:
            pred_path = os.path.join(pred_dir, fname)
            if not os.path.exists(pred_path):
                print(f"Missing prediction for {fname}, skipping.")
                continue
            pred_count = _density_count(pred_path)

        error = abs(gt_count - pred_count)
        abs_err_sum += error
        sq_err_sum += error ** 2
        total += 1

    if total == 0:
        raise ValueError(
            f"No samples were evaluated from {gt_dir}; check the ground-truth and prediction paths."
        )

    mae = abs_err_sum / total
    rmse = (sq_err_sum / total) ** 0.5

    print("\nEvaluation Results:")
    print(f"  MAE  = {mae:.2f}")
    print(f"  RMSE = {rmse:.2f}")
    return mae, rmse

def main(part=None, model_path=None, pred_dir=None, dataset_root="dataset"):
    """Evaluate on ``<dataset_root>/part_<part>/test_data`` and return ``(mae, rmse)``.

    Arguments left as ``None`` fall back to the module-level ``PART``, ``MODEL`` and
    ``PREDDIR`` values read at call time, so ``main()`` behaves like the original
    script. A non-empty model path takes priority over ``pred_dir``. Pass
    ``model_path=""`` to force prediction-directory mode.

    Raises:
        FileNotFoundError: if the ground-truth directory, or the image directory in
            model mode, does not exist.
        ValueError: if neither a model path nor a prediction directory is configured.
    """
    part = (PART if part is None else part).upper()
    model_path = MODEL if model_path is None else model_path
    pred_dir = PREDDIR if pred_dir is None else pred_dir

    test_dir = os.path.join(os.path.abspath(dataset_root), f"part_{part}", "test_data")
    gt_dir = os.path.join(test_dir, "ground-truth")
    img_dir = os.path.join(test_dir, "images")

    if not os.path.isdir(gt_dir):
        raise FileNotFoundError(f"Ground truth directory not found: {gt_dir}")

    if model_path:
        # Fail fast instead of crashing on the first missing image mid-evaluation.
        if not os.path.isdir(img_dir):
            raise FileNotFoundError(f"Image directory not found: {img_dir}")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = load_model(model_path, device)
        return evaluate(gt_dir, img_dir=img_dir, model=model, device=device)
    if pred_dir:
        return evaluate(gt_dir, pred_dir=os.path.abspath(pred_dir))
    # This notebook has no CLI flags; point the user at the settings that actually exist.
    raise ValueError("You must set either MODEL (model_path) or PREDDIR (pred_dir)")


if __name__ == '__main__':
    main()


Evaluating: 100%|██████████| 182/182 [00:07<00:00, 23.17it/s]


Evaluation Results:
  MAE  = 172.29
  RMSE = 258.47





In [3]:
import torch
from models.csrnet_mbv3 import MobileCSRNet
from fvcore.nn import FlopCountAnalysis, parameter_count

MODEL = "build/csrnet_mobile_A.pt"
INPUT_SIZE = (512, 512)  # (height, width); matches the evaluation transform's resize

# Initialize the model and load the trained weights on CPU.
# NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
model = MobileCSRNet()
checkpoint = torch.load(MODEL, map_location='cpu', weights_only=False)
# Accept both a bare state_dict and a checkpoint wrapping it under 'state_dict'.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    checkpoint = checkpoint['state_dict']
model.load_state_dict(checkpoint)
model.eval()

# Dummy input matching the training-time resolution.
dummy_input = torch.randn(1, 3, *INPUT_SIZE)

# Count FLOPs and the number of parameters.
flops = FlopCountAnalysis(model, dummy_input)
params = parameter_count(model)

# fvcore counts one fused multiply-add as one "flop", so this figure is effectively a
# multiply-accumulate count (roughly half the conventional FLOP number). Operators that
# fvcore reports as unsupported (e.g. hardswish, hardsigmoid) are excluded from the total.
print(f"FLOPs: {flops.total() / 1e9:.2f} GFLOPs (fvcore: 1 multiply-add = 1 flop)")
print(f"Params: {params[''] / 1e6:.2f} M")

Unsupported operator aten::hardswish_ encountered 20 time(s)
Unsupported operator aten::add_ encountered 10 time(s)
Unsupported operator aten::hardsigmoid encountered 8 time(s)
Unsupported operator aten::mul encountered 8 time(s)


FLOPs: 3.90 GFLOPs
Params: 13.66 M
