# Sanity check

In [1]:
import torch.nn.functional as F
import torch
from math import exp

def compute_SSIM(img1, img2, data_range, window_size=11, channel=1, size_average=True, spatial_dims=2):
    """Structural similarity (SSIM) between two images or volumes.

    Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim.

    Args:
        img1, img2: Tensors of identical shape. A bare 2-D tensor is reshaped
            to (1, 1, H, W); any other input is handed to the convolution
            unchanged, so it must already be batched and carry ``channel``
            channels.
        data_range: Dynamic range of the inputs; scales the stabilising
            constants C1 = (0.01 * data_range)**2 and C2 = (0.03 * data_range)**2.
        window_size: Side length of the filter window (default 11).
        channel: Number of image channels; each one is filtered independently.
        size_average: If True, average the whole SSIM map into one number.
            If False, average per batch item.
        spatial_dims: 2 selects ``F.conv2d``; any other value selects ``F.conv3d``.

    Returns:
        A Python float whenever the result is a single value (always when
        ``size_average`` is True, or for a batch of one). For a larger batch
        with ``size_average=False``, a 1-D tensor holding one mean per item.
    """
    if len(img1.size()) == 2:
        shape_ = img1.shape
        img1 = img1.view(1, 1, *shape_)
        img2 = img2.view(1, 1, *shape_)
    window = create_window(window_size, channel, spatial_dims=spatial_dims)
    window = window.type_as(img1)

    conv_op = F.conv2d if spatial_dims == 2 else F.conv3d
    pad = window_size // 2

    def _local_mean(x):
        # The window has shape (channel, 1, ...), i.e. one filter per channel.
        # groups=channel applies filter i to channel i only; without it the
        # convolution rejects any input that has more than one channel.
        return conv_op(x, window, padding=pad, groups=channel)

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)
    mu1_sq, mu2_sq = mu1.pow(2), mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = _local_mean(img1 * img1) - mu1_sq
    sigma2_sq = _local_mean(img2 * img2) - mu2_sq
    sigma12 = _local_mean(img1 * img2) - mu1_mu2

    C1, C2 = (0.01 * data_range) ** 2, (0.03 * data_range) ** 2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    if size_average:
        return ssim_map.mean().item()

    # Reduce every non-batch axis, however many there are (3 for 2-D inputs,
    # 4 for 3-D inputs), instead of assuming exactly three.
    per_item = ssim_map.mean(dim=tuple(range(1, ssim_map.dim())))
    if per_item.numel() == 1:
        return per_item.item()
    return per_item
    
def create_window(window_size, channel, spatial_dims=2):
    """Build a normalised Gaussian filter bank for SSIM.

    Args:
        window_size: Side length of the window along every spatial axis.
        channel: Number of identical filters to stack (one per image channel).
        spatial_dims: 2 builds a 2-D window; any other value builds a 3-D one.

    Returns:
        A contiguous float tensor of shape (channel, 1, k, k) for 2-D or
        (channel, 1, k, k, k) for 3-D, where k == window_size. Each filter is
        the outer product of the 1-D Gaussian (sigma = 1.5) with itself along
        every spatial axis, so each filter sums to 1.
    """
    g = gaussian(window_size, 1.5).float()
    if spatial_dims == 2:
        kernel = g[:, None] * g[None, :]
    else:
        # A true separable 3-D Gaussian. Repeating the 2-D kernel along depth
        # would give a window that sums to window_size rather than 1 and is
        # flat (not Gaussian) in the depth direction.
        kernel = g[:, None, None] * g[None, :, None] * g[None, None, :]
    # expand() prepends the (channel, 1) axes without copying; contiguous()
    # then materialises the stacked filters.
    return kernel.expand(channel, 1, *kernel.shape).contiguous()

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian of length ``window_size`` that sums to 1.

    The bell is centred on index ``window_size // 2`` and has standard
    deviation ``sigma`` (in samples).
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    gauss = torch.Tensor([exp(-((x - center) ** 2) / denom) for x in range(window_size)])
    return gauss / gauss.sum()

In [6]:
import numpy as np
# In CTTransforms.normalize_hu()
def normalize_hu(hu, min_hu=-1024, max_hu=3072, do_clip=False):
    """Linearly rescale HU values so that [min_hu, max_hu] maps onto [0, 1].

    When ``do_clip`` is True the input is first clamped to
    [min_hu, max_hu], which guarantees the result lies in [0, 1]; otherwise
    values outside the window map outside that interval.
    """
    values = np.clip(hu, min_hu, max_hu) if do_clip else hu
    span = max_hu - min_hu
    return (values - min_hu) / span


In [11]:
from skimage.metrics import structural_similarity as ssim

In [20]:
import os, re, glob
import torch
from tqdm import tqdm

# -----------------------------
# CONFIG — adjust as needed
# -----------------------------
# Absolute, machine-specific folders; edit these before running elsewhere.
MA_DIR = r"E:\Briya challenge data\input"   # artifact-affected .raw
GT_DIR = r"E:\Briya challenge data\simulated"    # ground-truth .raw
# NOTE(review): MASK_DIR is not read by any active code in this cell.
MASK_DIR = r"D:\AAPM_MAR_dataset\body_validate\Mask"  # optional: metal mask .raw

# Element type of the .raw files. The files in this notebook are read as
# float32; change if yours differ (e.g., np.uint8, np.uint16).
DTYPE = np.float32
LITTLE_ENDIAN = True  # set False if your RAWs are big-endian
# If your filenames ALWAYS carry shape suffix like ..._512x512x1, keep True.
# Otherwise set to False and hardcode SHAPE_2D or SHAPE_3D below.
# NOTE(review): PARSE_SHAPE_FROM_NAME is not read by any active code in this cell.
PARSE_SHAPE_FROM_NAME = True
# Fallback shapes used by load_raw() when no shape is passed:
SHAPE_2D = (512, 512)      # (H, W)
SHAPE_3D = (512, 512, 1)   # (H, W, D)

# SSIM options
# NOTE(review): WINDOW_SIZE, CHANNELS and SIZE_AVERAGE are parameters of
# compute_SSIM(); the skimage call in the loop below does not receive them.
WINDOW_SIZE = 11
SPATIAL_DIMS = 2           # 2 for 2D images, 3 for volumes
CHANNELS = 1               # single-channel
SIZE_AVERAGE = True        # True: scalar mean SSIM; False: per-image SSIM

# -----------------------------
# Helpers
# -----------------------------
# Matches a trailing "_<int>x<int>x<int>.raw" (case-insensitive) and captures
# the three integers.
shape_pat = re.compile(r"_([0-9]+)x([0-9]+)x([0-9]+)\.raw$", re.IGNORECASE)

def parse_shape_from_name(fname):
    """
    Read the trailing ``_<a>x<b>x<c>.raw`` suffix of a filename as a shape.

    Returns None when the suffix is absent. Otherwise returns the first two
    numbers as a 2-tuple when the third is 1 and SPATIAL_DIMS == 2, and all
    three numbers as a 3-tuple in every other case (set SPATIAL_DIMS=3 for
    real volumes).
    """
    match = shape_pat.search(fname)
    if match is None:
        return None
    height, width, depth = (int(group) for group in match.groups())
    if depth == 1 and SPATIAL_DIMS == 2:
        return (height, width)
    return (height, width, depth)

def load_raw(path, dtype=DTYPE, little_endian=LITTLE_ENDIAN, shape=None):
    """
    Read a headerless .raw binary file into a NumPy array.

    Args:
        path: File to read.
        dtype: Element type; accepts anything ``np.dtype()`` accepts
            (a scalar type such as ``np.float32``, a dtype instance, or a
            string such as ``"uint16"``).
        little_endian: True to read little-endian data, False for big-endian.
        shape: Target shape, e.g. (H, W) or (H, W, D). When None, falls back
            to SHAPE_3D if SPATIAL_DIMS == 3, otherwise SHAPE_2D.

    Returns:
        The file contents reshaped to ``shape``.

    Raises:
        ValueError: If the number of elements in the file does not match
            the product of ``shape``.
    """
    # np.dtype(...) is required: scalar type classes like np.float32 (the
    # default DTYPE) cannot have newbyteorder() called on them directly.
    dt = np.dtype(dtype).newbyteorder("<" if little_endian else ">")
    data = np.fromfile(path, dtype=dt)
    if shape is None:
        # Fallback: use configured defaults
        shape = SHAPE_3D if SPATIAL_DIMS == 3 else SHAPE_2D
    expected = int(np.prod(shape))
    if data.size != expected:
        raise ValueError(f"Size mismatch for {path}: got {data.size}, expected {expected} from shape={shape}")
    return data.reshape(shape)

def to_torch(img):
    """
    Wrap a NumPy array as a float tensor with leading batch and channel axes.

    With SPATIAL_DIMS == 2 the input must be (H, W) or (H, W, 1) and the
    result is (1, 1, H, W). Otherwise the input must be 3-D (H, W, D) and the
    result is (1, 1, D, H, W), the layout conv3d expects.

    Raises ValueError when the array rank does not fit the configured mode.
    """
    if SPATIAL_DIMS == 2:
        if img.ndim == 2:
            plane = img
        elif img.ndim == 3 and img.shape[-1] == 1:
            # Drop the singleton trailing axis.
            plane = img[..., 0]
        else:
            raise ValueError(f"Expected 2D data, got shape {img.shape}")
        tensor = torch.from_numpy(plane)[None, None]
    else:
        if img.ndim != 3:
            raise ValueError(f"Expected 3D data, got shape {img.shape}")
        # Move depth to the front, (H, W, D) -> (D, H, W), then add N and C.
        tensor = torch.from_numpy(img).permute(2, 0, 1)[None, None]
    return tensor.float()

def data_range_from_dtype(dtype, sample=None):
    """
    Pick the ``data_range`` value used for the SSIM constants.

    - Integer dtype: the full representable span (max - min) as a float.
    - Floating dtype with a ``sample``: the sample's max - min, floored at 1e-6.
    - Anything else (floating without a sample, or a non-numeric dtype): 1.0.
    """
    if np.issubdtype(dtype, np.integer):
        limits = np.iinfo(dtype)
        return float(limits.max - limits.min)

    is_float = np.issubdtype(dtype, np.floating)
    if is_float and sample is not None:
        lowest = float(np.min(sample))
        highest = float(np.max(sample))
        # Guard against a zero range for constant images.
        return max(highest - lowest, 1e-6)

    return 1.0

def make_pair_key(path):
    """
    Build the key used to match an input file with its ground-truth file.

    Takes the basename of ``path`` and deletes every occurrence of
    'Input_Image', then every occurrence of 'Simulated_Image', so both files
    of a pair reduce to the same string.
    """
    key = os.path.basename(path)
    key = key.replace('Input_Image', '')
    key = key.replace('Simulated_Image', '')
    return key

# -----------------------------
# Collect files and pair MA↔GT
# -----------------------------
ma_files = sorted(glob.glob(os.path.join(MA_DIR, "*.raw")))
gt_files = sorted(glob.glob(os.path.join(GT_DIR, "*.raw")))

# Index both folders by the marker-free filename so an artifact image and its
# ground truth land on the same key.
ma_map = {make_pair_key(p): p for p in ma_files}
gt_map = {make_pair_key(p): p for p in gt_files}

common_keys = sorted(set(ma_map) & set(gt_map))
if not common_keys:
    raise RuntimeError("No matching MA/GT pairs found. Check filenames and pairing logic.")

print(f"Found {len(common_keys)} matching pairs.")

# -----------------------------
# SSIM over the first MAX_PAIRS pairs
# -----------------------------
MAX_PAIRS = 200                # sanity-check subset; set to None to score every pair
HU_MIN, HU_MAX = -1024, 3072   # HU window mapped onto [0, 1]
DATA_RANGE = 1.0               # images are clipped and scaled to [0, 1] below

# Slice up front (rather than breaking out of the loop) so the progress bar
# total reflects the number of pairs actually scored.
eval_keys = common_keys[:MAX_PAIRS]

ssim_values = []
for key in tqdm(eval_keys, desc="SSIM"):
    ma_np = np.fromfile(ma_map[key], dtype=DTYPE).reshape(SHAPE_2D)
    gt_np = np.fromfile(gt_map[key], dtype=DTYPE).reshape(SHAPE_2D)

    ma_np = normalize_hu(ma_np, min_hu=HU_MIN, max_hu=HU_MAX, do_clip=True)
    gt_np = normalize_hu(gt_np, min_hu=HU_MIN, max_hu=HU_MAX, do_clip=True)

    # Reference implementation from scikit-image. To score with the in-notebook
    # implementation instead, call compute_SSIM(to_torch(gt_np), to_torch(ma_np),
    # data_range=DATA_RANGE, window_size=WINDOW_SIZE, channel=CHANNELS,
    # size_average=True, spatial_dims=SPATIAL_DIMS).
    ssim_val = ssim(gt_np, ma_np, data_range=DATA_RANGE)
    ssim_values.append(ssim_val)

avg_ssim = float(np.mean(ssim_values))
print(f"\nAverage SSIM over {len(ssim_values)} pairs: {avg_ssim:.6f}")


Found 14000 matching pairs.


SSIM:   1%|▏         | 200/14000 [00:13<15:27, 14.89it/s]


Average SSIM over 200 pairs: 0.795919





In [13]:
# Show the first 20 per-pair SSIM scores collected by the loop above.
ssim_values[:20]

[0.8442495251349558,
 0.8567595157469338,
 0.7729889872824913,
 0.8245272018819748,
 0.8328836808695359,
 0.701032812831439,
 0.8503461224976925,
 0.9225199643742835,
 0.7312957572420979,
 0.8082519189332266,
 0.7983551847359592,
 0.7900781002607772,
 0.7814206684515527,
 0.7694909208938282,
 0.789620832927139,
 0.9117608507317206,
 0.7627296586821373,
 0.8233897798807923,
 0.771058483787407,
 0.8365776754428165]

In [1]:
import numpy as np
# Read one metal-only mask as flat float32 values, then fold it into a
# 512x512 image and display it.
path = r"D:\AAPM_MAR_dataset\body_validate\Mask\test_body_metalonlymask_img3_512x512x1.raw"
mask_flat = np.fromfile(path, dtype=np.float32)
mask_np = mask_flat.reshape((512, 512))
mask_np

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [3]:
# Largest value in the mask; the array preview shows only zeros at the
# corners, so this checks whether any non-zero pixels exist at all.
mask_np.max()

1.0

In [17]:
# Colab paths, kept for reference; uncomment these (instead of the local
# paths below) when running in a Colab notebook.
# ma_dir = "/content/CT-MAR-Training/data/data_npy/input/MA_image"
# li_dir = "/content/CT-MAR-Training/data/data_npy/input/LI"
# gt_dir = "/content/CT-MAR-Training/data/data_npy/simulated/img"
# config_path = "/content/CT-MAR-Training/swin_unet/config.yaml"
# log_dir = "/content/drive/MyDrive/MAR_project/runs"

# Local Windows paths. These are absolute and machine-specific; they are
# interpolated into the trainer command line below.
ma_dir = r"E:\Briya challenge data\input"
gt_dir = r"E:\Briya challenge data\simulated"
config_path = r"C:\Khalifa University Documents\Summer 2025\Training\swin_unet\config.yaml"
log_dir = r"C:\Khalifa University Documents\Summer 2025\Training\runs"

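# Launch training as a shell command; IPython replaces each {name} placeholder with the value of the Python variable defined above.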
!python trainer.py --model swinunet --ma_dir "{ma_dir}" --gt_dir "{gt_dir}" --config "{config_path}" --log_dir "{log_dir}"

^C
