robust_densitymap_tensor_verification.py
Run: python robust_densitymap_tensor_verification.py

In [None]:
import scipy.io as sio
import torch
from pathlib import Path
import numpy as np
import pprint

In [None]:
# ---------- EDIT PATHS ----------
# Directory with the raw ground-truth .mat annotation files; the script looks
# here for "GT_<sample>.mat" first and "<sample>.mat" second.
RAW_GT_DIR = Path(r"C:\Users\mahal\OneDrive\Desktop\DL\archive\ShanghaiTech\part_A\train_data\ground-truth")
# Directory with the preprocessed density tensors, read as "<sample>.pt".
GT_TENSOR_DIR = Path(r"C:\Users\mahal\OneDrive\Desktop\DL\torch_density_trainA")
# ---------------------------------

In [None]:
def _is_numeric_nx2(obj):
    """Return True if *obj* is a 2-D, two-column ndarray with a numeric/bool dtype."""
    return (
        isinstance(obj, np.ndarray)
        and obj.ndim == 2
        and obj.shape[1] == 2
        # Object and structured arrays can also be shaped (k, 2); casting those
        # to float32 raises, so they must not be mistaken for point arrays.
        and obj.dtype.kind in "biufc"
    )


def _as_points(obj):
    """Return the first Nx2 numeric array found in *obj* as float32, else None.

    *obj* may be the point array itself, or an object array (how loadmat
    represents MATLAB cells) whose elements are searched in flattened order.
    """
    if not isinstance(obj, np.ndarray):
        return None
    if obj.dtype == object:
        for item in obj.ravel():
            if _is_numeric_nx2(item):
                return item.astype(np.float32)
        return None
    if _is_numeric_nx2(obj):
        return obj.astype(np.float32)
    return None


def read_points_from_mat(mat_path):
    """Robustly extract Nx2 points from a ShanghaiTech-style .mat file.

    The search runs in two stages:

    1. If the file has an ``image_info`` variable, a few fixed index paths into
       it are tried in order.
    2. Otherwise (or if stage 1 finds nothing) every non-dunder variable in
       the file is scanned.

    At each location a numeric 2-D array with exactly two columns is accepted,
    either directly or as an element of an object array.

    Args:
        mat_path: Path of the .mat file, passed straight to ``scipy.io.loadmat``.

    Returns:
        On success, a ``float32`` ndarray of shape (N, 2). The columns are
        presumably (x, y) as in the ShanghaiTech annotations; this is not
        validated here.
        On failure, the 2-tuple ``(None, mat)`` where ``mat`` is the dict
        returned by ``loadmat``, so the caller can inspect the file structure.
        Callers must therefore check for the tuple case before using the result.
    """
    mat = sio.loadmat(mat_path)

    if "image_info" in mat:
        info = mat["image_info"]
        # Index paths tried in order. Each step is a plain [i] subscript, so
        # (0, 0, 0, 0, 0) means info[0][0][0][0][0].
        for index_path in ((0, 0, 0, 0), (0, 0, 0, 0, 0), (0, 0)):
            candidate = info
            try:
                for idx in index_path:
                    candidate = candidate[idx]
            except (IndexError, KeyError, TypeError, ValueError):
                # This nesting depth does not exist in this file; try the next path.
                continue
            points = _as_points(candidate)
            if points is not None:
                return points

    # Fallback: scan every user variable (loadmat's metadata keys start with "__").
    for key, value in mat.items():
        if key.startswith("__"):
            continue
        points = _as_points(value)
        if points is not None:
            return points

    # Nothing found: hand the loaded dict back for debugging.
    return None, mat

In [None]:
# pick a sample tensor: this stem is used to build both the .mat and the .pt
# file names looked up further down
sample = "IMG_10"
print("Sample:", sample)

In [None]:
# Locate the ground-truth .mat for the sample. The "GT_"-prefixed file name
# wins, the bare sample name is the fallback, and mat_path stays None when
# neither file exists.
mat1 = RAW_GT_DIR / f"GT_{sample}.mat"
mat2 = RAW_GT_DIR / f"{sample}.mat"
if mat1.exists():
    mat_path = mat1
elif mat2.exists():
    mat_path = mat2
else:
    mat_path = None

In [None]:
# Abort with exit status 1 when neither candidate .mat file was found
# (mat_path is None in that case).
if mat_path is None:
    print("ERROR: GT .mat not found for sample:", sample)
    raise SystemExit(1)

In [None]:
# Pull the annotated points out of the .mat file. On failure
# read_points_from_mat hands back a (None, mat) tuple instead of an array, so
# that case is handled first by printing a preview of the file's variables.
pts_or_none = read_points_from_mat(str(mat_path))
if isinstance(pts_or_none, tuple) and pts_or_none[0] is None:
    _, full_mat = pts_or_none
    print("Could not find Nx2 points in .mat. mat keys/structure preview:")
    pprint.pprint(
        {
            name: type(value).__name__
            for name, value in full_mat.items()
            if not name.startswith("__")
        }
    )
    print("\nYou can inspect the 'image_info' structure with:")
    print(">>> import scipy.io as sio; m = sio.loadmat(r'{}'); import pprint; pprint.pprint(m['image_info'])".format(mat_path))
    raise SystemExit(1)

# Past this point the lookup did not report failure; still reject anything
# that is not a non-empty ndarray.
pts = pts_or_none
if not isinstance(pts, np.ndarray) or pts.size == 0:
    print("No points found in .mat (empty).")
    raise SystemExit(1)

In [None]:
# One row per annotated point, so the row count is the ground-truth count.
mat_count = pts.shape[0]
print("GT headpoints (from .mat):", mat_count)
# Show a few rows so the coordinates can be eyeballed.
print("Sample of first 5 points (x,y):\n", pts[:5])

In [None]:
# Load density tensor saved by your preprocessing (.pt)
tensor_path = GT_TENSOR_DIR / f"{sample}.pt"
if not tensor_path.exists():
    # Same failure style as the missing-.mat case: message plus exit status 1.
    print("ERROR: density tensor .pt not found for sample:", sample)
    raise SystemExit(1)
# map_location="cpu" lets a tensor that was saved from a CUDA device load on a
# machine without a GPU; only the sum is needed, so CPU is sufficient.
dens_t = torch.load(tensor_path, map_location="cpu")  # expected shape [1, H_down, W_down]
# The sum of the density map is the count it encodes.
density_sum = float(dens_t.sum().item())
print("Density map sum (tensor):", density_sum)

In [None]:
# Final compare: print the count taken from the .mat next to the density-map
# sum (the sum is shown rounded to two decimals).
print("\n===== RESULT =====")
print(f"GT count = {mat_count}")
print(f"Density sum = {density_sum:.2f}")

In [None]:
# Verdict: the tensor passes when its sum is within an absolute tolerance of
# 2 of the annotated count; anything else (including NaN) is reported as a
# mismatch.
print(
    "✔ Tensor is accurate (counts match within tolerance)."
    if abs(mat_count - density_sum) <= 2
    else "⚠ Counts differ. Check preprocessing (scaling, point extraction, downsampling)."
)