In [None]:
import nibabel as nib
import numpy as np
from monai.metrics import DiceMetric, HausdorffDistanceMetric
#from monai.inferers import SliceInferer
import torch
from segmentation_models_pytorch import Unet

from pathlib import Path

In [None]:
# Experiment configuration: checkpoint to evaluate, input modality, output location and test-set files.
ckpt_path = Path("/home/jesse/projects/canon_placement_y2/SLAug/logs/2024-02-12T23-56-50_seed23_NEW_efficientUnet_T1_to_DWI/checkpoints/val_best_epoch_83.pth")
out_channels = 2  # classes in the checkpoint's segmentation head; `predictor` handles 2 or 20

modality = Path("dwi")  # sub-directory of the test split that holds the input volumes
pred_dir = Path("/home/jesse/BRICIA/MVH_JPhitidis_PhD/canon_placement_y2/SLAug/test")  # where predicted masks are written
save_preds = True

test_root = Path("/home/jesse/BRICIA/MVH_JPhitidis_PhD/canon_placement_y2/jan2024/data/jan2024new/test")
image_paths = sorted((test_root / modality).glob("*.nii.gz"))
label_paths = sorted((test_root / "labels_pathology").glob("*.nii.gz"))

# glob() on a wrong/missing directory silently returns [], and a count mismatch would make the
# later zip() silently drop and mis-pair cases -- fail fast here instead.
assert image_paths, f"no images found under {test_root / modality}"
assert len(image_paths) == len(label_paths), (
    f"{len(image_paths)} images vs {len(label_paths)} labels; cases would be mis-paired"
)

In [None]:
# Load model

# The architecture must match the checkpoint exactly (load_state_dict is strict by default):
# single input channel, `out_channels` classes, raw logits (activation=None) so that
# `predictor` can take an argmax over classes.
model = Unet(
    encoder_name="efficientnet-b2",
    encoder_weights=None,  # no pretrained encoder weights; everything comes from the checkpoint
    in_channels=1,
    classes=out_channels,
    activation=None,
)

# Deserialize onto the CPU first: without map_location, tensors are restored to the device
# they were saved from, which fails if that GPU is not available on this machine.
ckpt = torch.load(ckpt_path, map_location="cpu")

model.load_state_dict(ckpt["model"])
model.cuda()
model.eval()  # put dropout / normalisation layers into evaluation behaviour

def predictor(inp, net=None, n_out=None):
    """Segment a volume slice by slice along its last axis with a 2-D network.

    Args:
        inp: 5-D tensor whose last axis indexes slices (the evaluation loop passes
            ``(1, 1, ., ., n_slices)``); it must live on the same device as ``net``.
        net: 2-D network mapping ``(batch, 1, a, b)`` to per-class logits
            ``(batch, classes, a, b)``. Defaults to the notebook-level ``model``.
        n_out: number of classes ``net`` outputs. With 2, both channels are used; with
            20, only channels 0 and 19 are compared. Defaults to the notebook-level
            ``out_channels``.

    Returns:
        Tensor with the same shape, dtype and device as ``inp`` holding the argmax
        class index (0 or 1) of every voxel.
    """
    if net is None:
        net = model
    if n_out is None:
        n_out = out_channels

    # Resolve the channel layout once instead of re-checking it for every slice.
    if n_out == 20:
        keep = [0, 19]  # NOTE(review): presumably background vs. the target class -- confirm
    else:
        assert n_out == 2
        keep = None

    out_array = torch.empty_like(inp)

    # Pure inference: without no_grad every slice's forward pass would record an autograd
    # graph, wasting GPU memory and time.
    with torch.no_grad():
        for slc in range(inp.shape[-1]):
            # SLAug has z slices permuted in the x/y direction so we do the same
            logits = torch.permute(net(torch.permute(inp[..., slc], dims=(0, 1, 3, 2))), dims=(0, 1, 3, 2))
            if keep is not None:
                logits = logits[:, keep, ...]
            # keepdim=True yields (batch, 1, a, b), matching out_array[..., slc] for any batch size.
            out_array[..., slc] = torch.argmax(logits, dim=1, keepdim=True)

    return out_array

In [None]:
# Per-case evaluation. Each metric instance also buffers its per-case results for the later
# aggregate() call.
DiceMetricInstance = DiceMetric(include_background=False, reduction="mean")
HausdorffDistanceMetricInstance = HausdorffDistanceMetric(include_background=False, reduction="mean", percentile=95)

# zip() silently stops at the shorter list, so a count mismatch would drop or mis-pair cases.
assert len(image_paths) == len(label_paths), f"{len(image_paths)} images vs {len(label_paths)} labels"
total = len(image_paths)

if save_preds:
    # nib.save does not create missing parent directories.
    pred_dir.mkdir(parents=True, exist_ok=True)

for i, (image_path, label_path) in enumerate(zip(image_paths, label_paths)):

    # Load image and label nii
    im = nib.load(image_path)
    lab = nib.load(label_path)

    # Image: add batch and channel axes, move to GPU, z-score normalise over the whole volume.
    im_data = torch.from_numpy(im.get_fdata()).to(dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0)
    im_data = (im_data - im_data.mean()) / im_data.std()
    # Ground truth must come from the label file (`lab`), not the image (`im`); otherwise the
    # metrics compare the prediction against image intensities.
    lab_data = torch.from_numpy(lab.get_fdata()).to(dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    # Do pred on volume; metrics and saving both work on the CPU copy.
    pred = predictor(im_data).cpu()

    # Get metrics (the return value is this case's score)
    dice = DiceMetricInstance(pred, lab_data)
    # NOTE(review): no voxel spacing is given, so HD95 is in voxel units -- the reporting cell
    # scales it; confirm that factor matches the data's spacing.
    hd95 = HausdorffDistanceMetricInstance(pred, lab_data)

    # Save, reusing the label's affine/header so the mask aligns with the ground truth.
    if save_preds:
        pred_img = nib.Nifti1Image(pred[0, 0, ...].numpy().astype(np.uint8), affine=lab.affine, header=lab.header)
        nib.save(pred_img, pred_dir / label_path.name)

    print(f"{i+1}/{total}")
    print("dice:", dice)
    print("hd95", hd95)
    print()

In [None]:
# Collapse the per-case scores buffered during the evaluation loop into a single value per
# metric (both instances were created with reduction="mean").
dice, hd95 = (
    DiceMetricInstance.aggregate(),
    HausdorffDistanceMetricInstance.aggregate(),
)

In [None]:
# NOTE(review): HD95 was computed without voxel spacing (voxel units); this factor presumably
# converts to mm for 2 mm isotropic voxels -- confirm against the image headers' zooms.
hd95_scale = 2
print(dice)
print(hd95 * hd95_scale)