In [3]:
import nibabel as nib
import numpy as np
from monai.metrics import DiceMetric, HausdorffDistanceMetric
#from monai.inferers import SliceInferer
import torch
from segmentation_models_pytorch import Unet
from non_monai_metrics import NonMONAIMetrics
from monai.networks import one_hot

from pathlib import Path



In [56]:
ckpt_path = Path("/home/jesse/BRICIA/MVH_JPhitidis_PhD/canon_placement_y2/SLAug/train_stroke_val_stroke/logs/2024-02-13T12-46-06_seed23_NEW_efficientUnet_T1_to_DWI/checkpoints/latest.pth")
out_channels = 2

modality = Path("flair")
pred_dir = Path("/home/jesse/BRICIA/MVH_JPhitidis_PhD/canon_placement_y2/SLAug/test")
save_preds = False

image_paths = sorted(Path(f"/home/jesse/BRICIA/MVH_JPhitidis_PhD/canon_placement_y2/jan2024/data/jan2024new/test/{modality}").glob("*.nii.gz"))
label_paths = sorted(Path(f"/home/jesse/BRICIA/MVH_JPhitidis_PhD/canon_placement_y2/jan2024/data/jan2024new/test/labels_pathology").glob("*.nii.gz"))

In [57]:
def dice_score(pred, gt):
    num = 2 * torch.sum(pred * gt)
    denom = torch.sum(pred) + torch.sum(gt)
    return num / denom

In [58]:
# Load model

model = Unet(
    encoder_name="efficientnet-b2",
    encoder_weights=None,
    in_channels=1,
    classes=out_channels,
    activation=None
)

ckpt = torch.load(ckpt_path)

model.load_state_dict(ckpt["model"])
model.cuda()
model.eval()

def predictor(inp):
    
    out_array = torch.empty_like(inp)
    out_array_soft = torch.empty_like(inp)
    
    for slc in range(inp.shape[-1]):
    
        out = torch.permute(model(torch.permute(inp[..., slc], dims=(0,1,3,2))), dims=(0,1,3,2)) # SLAug z slices are permuted in x/y because theyu load with sitk not nib
        if out_channels == 20:
            out = out[:, (0, 19), ...]
        if out.shape[1] == 2:
            pred_soft = torch.softmax(out, dim=1)[:, 1, ...]
        else:
            pred_soft = None
        
        pred = torch.argmax(out, dim=1, keepdim=True)
        
        out_array[..., slc] = pred
        if pred_soft is not None:
            out_array_soft[..., slc] = pred_soft
        else:
            out_array_soft = None
    
    return out_array, out_array_soft

In [59]:
DiceMetricInstance = DiceMetric(include_background=True, reduction="mean") # ignore_empty not available with this version of monai but no empty masks for path anyway
HausdorffDistanceMetricInstance = HausdorffDistanceMetric(include_background=True, reduction="mean", percentile=95)
NonMONAIMetricInstance = NonMONAIMetrics(include_background=True, reduction="mean")
total=len(image_paths)

check_dice = []
for i,  (image_path, label_path) in enumerate(zip(image_paths, label_paths)):
    
    # Load image and label nii
    im = nib.load(image_path)
    lab = nib.load(label_path)
    
    # Convert to tensors and normalise image
    im_data = torch.from_numpy(im.get_fdata().astype(np.float32)).to(device="cuda").unsqueeze(0).unsqueeze(0)
    im_data = (im_data - im_data.mean()) / im_data.std()
    lab_data = torch.from_numpy(lab.get_fdata().astype(np.uint8)).unsqueeze(0).unsqueeze(0)
    
    # Do pred on volume
    pred, pred_soft = predictor(im_data)
    
    ### for anat ###
    # pred = one_hot(pred, num_classes=19, dim=1)
    # lab_data = one_hot(lab_data, num_classes=19, dim=1)
    ################
    
    # Get metrics
    dice = DiceMetricInstance(pred.cpu(), lab_data)  
    hd95 = HausdorffDistanceMetricInstance(pred.cpu(), lab_data)
    if pred_soft is not None:
        pre, rec, lf1, lpre, lrec, ap = NonMONAIMetricInstance(pred.cpu(), lab_data.cpu(), pred_soft.cpu())
    else:
        pre, rec, lf1, lpre, lrec = NonMONAIMetricInstance(pred.cpu(), lab_data.cpu())
    
    # Save
    if save_preds:
        pred = nib.Nifti1Image(pred[0,0,...].cpu().numpy().astype(np.uint8), affine=lab.affine, header=lab.header)
        nib.save(pred, pred_dir / label_path.name)
        
    print(f"{i+1}/{total}")
    print(dice)

1/127
tensor([[0.]])
2/127
tensor([[0.0198]])
3/127
tensor([[0.]])
4/127
tensor([[0.]])
5/127
tensor([[0.]])
6/127
tensor([[0.0082]])
7/127
tensor([[0.0485]])
8/127
tensor([[0.]])
9/127
tensor([[0.]])
10/127
tensor([[0.]])
11/127
tensor([[0.1045]])
12/127
tensor([[0.0008]])
13/127
tensor([[0.]])
14/127
tensor([[0.0157]])
15/127
tensor([[0.]])
16/127
tensor([[0.]])
17/127
tensor([[0.0012]])
18/127
tensor([[0.0041]])
19/127
tensor([[0.0402]])
20/127
tensor([[0.0175]])
21/127
tensor([[0.0003]])
22/127
tensor([[0.0094]])
23/127
tensor([[0.0041]])
24/127
tensor([[0.]])
25/127
tensor([[0.1371]])
26/127
tensor([[0.0071]])
27/127
tensor([[0.3701]])
28/127
tensor([[0.0203]])
29/127
tensor([[0.]])
30/127
tensor([[0.]])
31/127
tensor([[0.]])
32/127
tensor([[0.0002]])
33/127
tensor([[0.1987]])
34/127
tensor([[0.]])
35/127
tensor([[0.0266]])
36/127
tensor([[0.3254]])
37/127
tensor([[0.]])
38/127
tensor([[0.0184]])
39/127
tensor([[0.0545]])
40/127
tensor([[0.0001]])
41/127
tensor([[0.0080]])
42/127


In [60]:
dice = DiceMetricInstance.aggregate()
hd95 = HausdorffDistanceMetricInstance.aggregate()
pre, rec, lf1, lpre, lrec, ap = NonMONAIMetricInstance.aggregate()

In [61]:
print("dice:", dice.item() * 100)
print("HD95:", hd95.item() * 2)
print("Pre:", pre.item() * 100)
print("Rec:", rec.item() * 100)
print("LF1:", lf1.item() * 100)
print("LPre:", lpre.item() * 100)
print("LRec:", lrec.item() * 100)
print("Ap:", ap.item() * 100)

dice: 2.941836230456829
HD95: 67.31951191668948
Pre: 3.4171215485057256
Rec: 6.273188270221579
LF1: 6.160284857672405
LPre: 4.953289108325355
LRec: 26.893266565526304
Ap: 3.3533755104959297
