In [35]:
import SimpleITK as sitk
import torch
from scripts.datasetloader import testing_ds,val_transforms
from models.Basic3DUnet import *
from monai.data import DataLoader, decollate_batch
from monai.transforms import AsDiscrete, Compose, Invertd
from collections import OrderedDict
from monai.inferers import sliding_window_inference

from monai.transforms import (
    AsDiscreted,
    Compose,
    Invertd,
    SaveImaged,
    EnsureTyped
)

In [None]:
# Post-processing pipeline applied to ONE decollated sample (a dict holding "image", "pred", ...).
#   1. EnsureTyped -> make sure "pred" is a tensor.
#   2. Invertd     -> replay the inverse of val_transforms, as recorded on "image", onto "pred",
#                     returning the prediction to the original image space. At this stage "pred"
#                     still holds the per-channel network output (presumably logits), not labels.
#                     Nearest-neighbour resampling reads every channel from the same source voxel,
#                     so taking the argmax afterwards should yield the same labels as resampling an
#                     already-argmax'd label map.
#   3. AsDiscreted -> argmax across channels, turning the network output into a label map.
post_transforms = Compose(
    [
        EnsureTyped(keys="pred"),
        Invertd(
            keys="pred",
            transform=val_transforms,
            orig_keys="image",
            orig_meta_keys="image_meta_dict",
            meta_keys="pred_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=True,
            to_tensor=True,
        ),
        AsDiscreted(keys="pred", argmax=True),
    ]
)

: 

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): positional args are presumably (in_channels=1, out_channels=2, <third arg, maybe dropout>);
# two output channels would match the argmax over classes in post_transforms — confirm against Basic3DUnet.
model = Basic3DUnet(1, 2, 0).to(device)
model_weightpath = '/home/normansmith/blue_storage/projects/3DSegmentationLearning/final_models/basicUnet.pth'

# map_location lets a checkpoint that was saved on a GPU be loaded onto whatever `device` is.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(model_weightpath, map_location=device)

# Checkpoints saved from an nn.DataParallel / DistributedDataParallel wrapper presumably carry a
# leading "module." on every key. Strip ONLY that leading prefix: str.replace would also mangle
# keys that merely contain the substring (e.g. "encoder.submodule.conv.weight" -> "encoder.subconv.weight").
_dp_prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(_dp_prefix):] if k.startswith(_dp_prefix) else k, v)
    for k, v in state_dict.items()
)

# Default (strict) loading raises if any key is missing or unexpected.
model.load_state_dict(new_state_dict)

# Inference mode: changes the behaviour of layers such as dropout / batch norm.
model.eval()
test_loader = DataLoader(testing_ds, batch_size=1, shuffle=False, num_workers=1)
outputdir = "/home/normansmith/blue_storage/projects/3DSegmentationLearning/visualizationexamples"
# Writes d["pred"] of each post-processed sample to `outputdir`.
saver = SaveImaged(
    keys="pred",
    meta_keys="pred_meta_dict", # Use the metadata dictionary created by Invertd
    output_dir=outputdir,
    output_postfix="seg", # Appends "_seg" to the original filename
    resample=False # Do not resample, use the original affine
)
with torch.no_grad():
    for test_data in test_loader:
        test_inputs = test_data["image"].to(device)

        # Sliding-window inference over 96^3 windows, 4 windows per forward pass.
        # `sw_device` is where each window batch is computed (the model's device);
        # `device` is where the stitched full-volume prediction is accumulated. Stitching
        # on the CPU keeps GPU memory bounded and means post-processing below runs on the CPU.
        test_data["pred"] = sliding_window_inference(
            inputs=test_inputs,
            roi_size=(96, 96, 96),
            sw_batch_size=4,
            predictor=model,
            sw_device=device,
            device="cpu",
            progress=True
        )

        # `decollate_batch` splits the batched dict into a list with one dict per sample.
        test_data = decollate_batch(test_data)

        # Post-process and save EVERY sample, not just the first, so a batch_size > 1
        # loader does not silently drop predictions. `d` holds "image", "pred" and metadata.
        for sample in test_data:
            d = post_transforms(sample)
            saver(d)

print(f"\nInference complete. Segmentations saved to: {outputdir}")


  win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
