In [14]:
import torch
import monai
import numpy as np
from src.constants import *
from src.model.baselines import *
from src.data.transforms import train_transform, test_transform, all_transforms
import os
import pickle
from src.utils.metrics import *
from scipy.spatial.distance import directed_hausdorff
from tqdm import tqdm
from src.model.my_model import MyModel


In [12]:

# Collect the test volumes. `os.listdir` returns bare file names in arbitrary
# order; sort them so the evaluation order is reproducible, and match on the
# file extension rather than on any occurrence of the substring "mhd".
# NOTE(review): the bare names are passed straight to CacheDataset, so
# `test_transform` presumably resolves them against DATA_PATH and builds the
# {'img', 'mask'} dict consumed below — confirm.
filenames = sorted(
    f for f in os.listdir(DATA_PATH + "/test_images") if f.lower().endswith(".mhd")
)
test_dataset = monai.data.CacheDataset(filenames, transform=test_transform, num_workers=16)
# Evaluation must be deterministic: never shuffle the test set.
test_loader = monai.data.DataLoader(test_dataset, batch_size=1, shuffle=False)



Loading dataset: 100%|██████████| 10/10 [01:49<00:00, 10.96s/it]


In [15]:
# Architecture hyper-parameters. They must match the checkpoint that gets
# loaded into this model (its file name appears to encode lower_channels=16 and
# embed_dim=512 — confirm).
model_kwargs = dict(
    in_channels=1,
    out_channels=3,
    lower_channels=16,
    big_channel=16,
    patch_size=8,
    embed_dim=512,
    mode="normal",
    old_embedder=False,
)
model = MyModel(**model_kwargs)

In [19]:
# Parameter count of the nested `model.vit.vit` sub-module (displayed as output).
# NOTE(review): `count_learnable_parameters` arrives via a star import
# (presumably `src.utils.metrics`); by its name it presumably counts only
# parameters with requires_grad=True — confirm.
count_learnable_parameters(model.vit.vit)

33928704

In [20]:
# Checkpoint to evaluate. The name presumably encodes
# <tag>_<lower_channels>_<embed_dim>_<validation score> — confirm.
CHECKPOINT_NAME = "new_16_512_0.8480931361516317.pt"
# map_location="cpu": deserialize every tensor onto the CPU regardless of the
# device it was saved from, so this cell also works without that GPU.
# `load_state_dict` then copies the values into the model's own parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (on torch >= 1.13 consider weights_only=True).
weights = torch.load(f"{MODEL_SAVE_PATH}/{CHECKPOINT_NAME}", map_location="cpu")

In [21]:
# Strict loading (the default) fails loudly on any missing or unexpected key;
# the returned key report is shown as the cell output.
load_report = model.load_state_dict(weights, strict=True)
load_report

<All keys matched successfully>

In [23]:
# Sliding-window inference over the whole test set. For every case we keep
# (image, ground-truth mask, raw argmax prediction, prediction after
# largest-connected-component filtering) as numpy arrays in `outs`.
# Put the model on the same device the inputs are sent to (a bare `.cuda()`
# disagrees with DEVICE whenever DEVICE is not the default CUDA device).
model.eval().to(DEVICE)
outs = []

largest_component = monai.transforms.KeepLargestConnectedComponent()

with torch.no_grad():
    for d in test_loader:
        img = d['img'].to(DEVICE)
        mask = d['mask'].to(DEVICE)

        # Patches are predicted on DEVICE; the full-volume logits are stitched
        # on the CPU (device="cpu"), presumably to bound GPU memory.
        out = monai.inferers.sliding_window_inference(img,
                                                      roi_size=CROP_SIZE,
                                                      sw_batch_size=BATCH_SIZE,
                                                      predictor=model,
                                                      overlap=0.75,
                                                      sw_device=DEVICE,
                                                      device="cpu",
                                                      progress=True,
                                                      mode="constant")
        # Class-index label map with shape (B, 1, *spatial).
        out = torch.argmax(out, 1, keepdim=True).to(DEVICE)
        # MONAI array transforms expect channel-first input WITHOUT a batch
        # axis, so filter each sample on its own and re-stack the batch.
        out_k = torch.stack([largest_component(sample) for sample in out])
        # Tensors created under no_grad need no .detach() before conversion.
        outs.append((img.cpu().numpy(), mask.cpu().numpy(), out.cpu().numpy(), out_k.cpu().numpy()))


100%|██████████| 160/160 [01:14<00:00,  2.16it/s]
100%|██████████| 258/258 [01:55<00:00,  2.23it/s]
100%|██████████| 258/258 [01:55<00:00,  2.24it/s]
100%|██████████| 233/233 [01:46<00:00,  2.20it/s]
100%|██████████| 282/282 [02:08<00:00,  2.20it/s]
100%|██████████| 282/282 [02:06<00:00,  2.23it/s]
100%|██████████| 184/184 [01:23<00:00,  2.21it/s]
100%|██████████| 375/375 [02:50<00:00,  2.20it/s]
100%|██████████| 196/196 [01:28<00:00,  2.22it/s]
100%|██████████| 352/352 [02:39<00:00,  2.21it/s]


In [5]:
def compute_hausdorff_distance(mask1, mask2, spacing, classes=(1, 2)):
    """
    Compute the symmetric Hausdorff distance between two label masks, per class.

    For each label ``c`` in ``classes``, the voxels where ``mask == c`` are
    converted to physical coordinates (voxel index times ``spacing``). The
    distance is ``max(directed_hausdorff(A, B), directed_hausdorff(B, A))``.

    Empty sets are handled explicitly, because ``directed_hausdorff`` silently
    returns 0 when either point set is empty:

    * class ``c`` absent from both masks  -> 0.0 (the masks agree);
    * class ``c`` present in only one mask -> ``np.inf`` (unbounded; it must not
      be reported as a perfect match).

    :param mask1: The first label mask (integer labels)
    :param mask2: The second label mask, same shape as ``mask1``
    :param spacing: Voxel spacing, one value per mask axis in array-axis order
        (a scalar is also accepted)
    :param classes: Class labels to evaluate, defaults to (1, 2)
    :return: A list of Hausdorff distances, one per entry of ``classes``
    """
    spacing = np.asarray(spacing, dtype=float)

    hausdorff_distances = []
    for c in classes:
        # (N, ndim) physical coordinates of every voxel labelled c.
        coords1 = np.array(np.nonzero(mask1 == c)).T * spacing
        coords2 = np.array(np.nonzero(mask2 == c)).T * spacing

        if len(coords1) == 0 or len(coords2) == 0:
            # Both empty -> agreement (0.0); exactly one empty -> undefined (inf).
            hausdorff_distances.append(0.0 if len(coords1) == len(coords2) else np.inf)
            continue

        hd_c = max(directed_hausdorff(coords1, coords2)[0],
                   directed_hausdorff(coords2, coords1)[0])
        hausdorff_distances.append(hd_c)

    return hausdorff_distances

In [28]:
# Collapse each per-case tuple (img, mask, pred, pred_k), whose arrays are
# (1, 1, *spatial), into one (4, *spatial) array. Later cells can then unpack
# it with `for img, mask, pred, pred_k in oo`.
oo = [np.squeeze(np.array(case), axis=(1, 2)) for case in outs]

# Cache the predictions so the metric cells can be re-run without inference.
# To reload: with open(f"{RESULTS_SAVE_PATH}/Mine/1.p", "rb") as fh: oo = pickle.load(fh)
# (only for files you produced yourself — unpickling can execute arbitrary code).
results_dir = f"{RESULTS_SAVE_PATH}/Mine"
os.makedirs(results_dir, exist_ok=True)
with open(f"{results_dir}/1.p", "wb") as fh:
    pickle.dump(oo, fh)

In [33]:
# Per-case Dice (from `dice_scores`) and per-class Hausdorff distance
# (classes 1 and 2) of the connected-component-filtered prediction vs. mask.
# NOTE(review): the spacing is hard-coded and applied in array-axis order;
# confirm it matches the .mhd headers and the axis order after `test_transform`.
VOXEL_SPACING = (0.9, 1.2, 1.2)

dss = []
hss = []
skipped = []
for case_idx, (img, mask, pred, pred_k) in enumerate(tqdm(oo)):
    d = dice_scores(torch.tensor(pred_k), torch.tensor(mask))
    # Cases with any zero Dice are excluded from BOTH metrics.
    # NOTE(review): this biases the reported means upward; the excluded case
    # indices are kept in `skipped` so the exclusion stays visible.
    if 0 in d:
        skipped.append(case_idx)
        continue
    dss.append(d)
    hss.append(compute_hausdorff_distance(pred_k, mask, VOXEL_SPACING))

# Stacked (cases x classes) arrays of the kept cases.
ds = np.array(dss)
hs = np.array(hss)

100%|██████████| 10/10 [00:23<00:00,  2.31s/it]


In [36]:
# Mean Hausdorff distance per evaluated class, averaged over the kept cases.
mean_hd = sum(hss) / len(hss)
mean_hd

array([38.73370325, 73.41376718])

In [37]:
# Mean of the per-case Dice arrays (one entry per `dice_scores` output), kept cases only.
mean_dice = sum(dss) / len(dss)
mean_dice

array([0.99948896, 0.89298643, 0.68924108])

In [31]:
# Same quantity as the mean of `hss`, computed from the stacked (cases x classes) array.
mean_hd_arr = sum(hs) / len(hs)
mean_hd_arr

array([38.73370325, 73.41376718])

In [32]:
# Same quantity as the mean of `dss`, computed from the stacked (cases x entries) array.
mean_dice_arr = sum(ds) / len(ds)
mean_dice_arr

array([0.99948896, 0.89298643, 0.68924108])

In [34]:
# Per-case Dice arrays for the cases kept by the metric loop (one array per case,
# one entry per element returned by `dice_scores`).
# NOTE(review): entry 0 is ~1.0 in every case, which suggests it is the
# background class — confirm against `dice_scores`.
dss

[array([0.99968851, 0.86870039, 0.91003674]),
 array([0.99982315, 0.89850169, 0.88186651]),
 array([0.99941486, 0.91398263, 0.81243449]),
 array([0.9997856 , 0.94104505, 0.73159832]),
 array([0.99947208, 0.89753211, 0.68503189]),
 array([0.99918985, 0.8170439 , 0.58518022]),
 array([0.99935842, 0.90222538, 0.39258048]),
 array([0.99917918, 0.90486032, 0.51519996])]