In [None]:
import os
import glob
import torch
import numpy as np
import nibabel as nib
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio as psnr
from sklearn.metrics import mean_absolute_error as mae
from monai.metrics import DiceMetric
from models import CycleTransMorph, SpatialTransformer

BASE_PATH = "/mnt/hot/public/Akul/exhale_pred_data"
INSP_PATH = os.path.join(BASE_PATH, "inhale")
EXP_PATH = os.path.join(BASE_PATH, "exhale")
INSP_MASK_PATH = os.path.join(BASE_PATH, "masks", "inhale")
EXP_MASK_PATH = os.path.join(BASE_PATH, "masks", "exhale")
MODEL_PATH = "./experiments/CTM/best_model.pth"

IMG_SIZE = (160, 192, 128) 
BATCH_SIZE = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")

Using device: cuda


In [5]:
class LungDataset(Dataset):
    def __init__(self, file_paths):
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        paths = self.file_paths[idx]
        return {
            'insp': torch.from_numpy(np.load(paths['insp'])).float().unsqueeze(0),
            'exp': torch.from_numpy(np.load(paths['exp'])).float().unsqueeze(0),
            'insp_mask': torch.from_numpy(np.load(paths['insp_mask'])).float().unsqueeze(0),
            'exp_mask': torch.from_numpy(np.load(paths['exp_mask'])).float().unsqueeze(0)
        }

patient_ids = [os.path.basename(p).replace('.npy', '') for p in glob.glob(os.path.join(INSP_PATH, "*.npy"))]
patient_ids.sort()
all_files = [{'insp': os.path.join(INSP_PATH, f"{pid}.npy"), 'exp': os.path.join(EXP_PATH, f"{pid}.npy"), 'insp_mask': os.path.join(INSP_MASK_PATH, f"{pid}_INSP_mask.npy"), 'exp_mask': os.path.join(EXP_MASK_PATH, f"{pid}_EXP_mask.npy")} for pid in patient_ids]

split_idx = int(0.9 * len(all_files))
val_files = all_files[split_idx:]

test_dataset = LungDataset(val_files)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Total processed samples: {len(all_files)}")
print(f"Using {len(test_dataset)} samples for testing (10% validation set).")

Total processed samples: 8702
Using 871 samples for testing (10% validation set).


In [None]:
model = CycleTransMorph(img_size=IMG_SIZE).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
model.eval()
transformer = SpatialTransformer(size=IMG_SIZE).to(DEVICE)

  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


In [4]:
from monai.metrics import DiceMetric
dice_metric = DiceMetric(include_background=True, reduction="mean")
with torch.no_grad():
    for i, batch in enumerate(test_dataloader):
        # ... (rest of the dice score loop)
        warped_mask_img, dvf_i_e, _ = model(inhale_mask.float(), exhale_mask.float())
        pred_exhale_mask_binary = (warped_mask_img > 0.5).float()
        dice_metric(y_pred=pred_exhale_mask_binary, y=exhale_mask)
mean_dice = dice_metric.aggregate().item()
print(f"Average Dice Score: {mean_dice:.4f}")

NameError: name 'inhale_mask' is not defined

In [None]:
psnr_scores = []
mae_scores = []

with torch.no_grad():
    for i, batch in enumerate(test_dataloader):
        print(f"Processing PSNR/MAE for sample {i+1}/{len(test_dataloader)}...")
        inhale_img = batch['inhale_img'].to(DEVICE)
        exhale_img = batch['exhale_img'].to(DEVICE)

        # Predict the DVF
        dvf_i_e, _ = model(inhale_img, exhale_img)
        
        # Warp the inhale image to get the predicted exhale image
        pred_exhale_img = transformer(inhale_img, dvf_i_e)

        # Move tensors to CPU and convert to numpy for skimage metrics
        pred_np = pred_exhale_img.squeeze().cpu().numpy()
        gt_np = exhale_img.squeeze().cpu().numpy()
        
        # Calculate metrics
        data_range = gt_np.max() - gt_np.min()
        psnr_scores.append(psnr(gt_np, pred_np, data_range=data_range))
        mae_scores.append(mae(gt_np, pred_np))

avg_psnr = np.mean(psnr_scores)
avg_mae = np.mean(mae_scores)

print("\n" + "="*30)
print(f"  Average PSNR: {avg_psnr:.4f} dB")
print(f"  Average MAE: {avg_mae:.4f}")
print("="*30)

Processing PSNR/MAE for sample 1/871...


RuntimeError: The size of tensor a (192) must match the size of tensor b (128) at non-singleton dimension 3

In [None]:
def get_jacobian_determinant(dvf):
    """
    Calculates the Jacobian determinant of a 3D deformation vector field.
    DVF shape: (batch, 3, H, W, D)
    """
    # Permute to (H, W, D, 3) for np.gradient
    dvf_numpy = dvf.squeeze().permute(1, 2, 3, 0).cpu().numpy()
    
    # Get gradients of each displacement component
    grad_x_du = np.gradient(dvf_numpy[..., 0], axis=0)
    grad_y_dv = np.gradient(dvf_numpy[..., 1], axis=1)
    grad_z_dw = np.gradient(dvf_numpy[..., 2], axis=2)

    # Construct the Jacobian matrix.
    J = np.zeros(dvf_numpy.shape + (3,))
    J[..., 0, 0] = 1 + grad_x_du
    J[..., 0, 1] = np.gradient(dvf_numpy[..., 0], axis=1)
    J[..., 0, 2] = np.gradient(dvf_numpy[..., 0], axis=2)

    J[..., 1, 0] = np.gradient(dvf_numpy[..., 1], axis=0)
    J[..., 1, 1] = 1 + grad_y_dv
    J[..., 1, 2] = np.gradient(dvf_numpy[..., 1], axis=2)

    J[..., 2, 0] = np.gradient(dvf_numpy[..., 2], axis=0)
    J[..., 2, 1] = np.gradient(dvf_numpy[..., 2], axis=1)
    J[..., 2, 2] = 1 + grad_z_dw

    return np.linalg.det(J)

non_positive_jacobians = []

with torch.no_grad():
    for i, batch in enumerate(test_dataloader):
        print(f"Processing Jacobian for sample {i+1}/{len(test_dataloader)}...")
        inhale_img = batch['inhale_img'].to(DEVICE)
        exhale_img = batch['exhale_img'].to(DEVICE)

        # Predict the DVF
        dvf_i_e, _ = model(inhale_img, exhale_img)
        
        # Calculate Jacobian determinant
        jacobian_det = get_jacobian_determinant(dvf_i_e)
        
        # Count non-positive values
        num_non_positive = np.sum(jacobian_det <= 0)
        total_voxels = np.prod(jacobian_det.shape)
        
        percentage = (num_non_positive / total_voxels) * 100
        non_positive_jacobians.append(percentage)

avg_non_positive = np.mean(non_positive_jacobians)

print("\n" + "="*55)
print(f"  Average Percentage of Non-Positive Jacobian Values: {avg_non_positive:.6f}%")
print("="*55)