In [61]:
from swin_model import Swin_model
from unet_model import Unet_model
import torch
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import monai

In [2]:
def load_model(model_dir="/scratch/scratch6/akansh12/challenges/parse2022/temp/selected_models"):
    """Load the six checkpoints used for test-time inference / ensembling.

    Every variable is named after the score suffix of the checkpoint file it
    loads (``swin_8687`` <- ``..._8687.pth`` and so on). Previously ``swin_8687``
    and ``swin_8655`` were loaded from each other's files. That mislabelled the
    outputs written under ``./test_outputs/swin_8655_out/`` and
    ``./test_outputs/swin_8687_out/``.

    Args:
        model_dir: directory containing the ``.pth`` checkpoints. Defaults to the
            original cluster path, so existing calls are unchanged.

    Returns:
        Tuple ``(swin_8687, swin_8675, swin_8655, unet_8530, unet_8550, unet_8551)``.
    """
    # Local import: this cell (In[2]) runs before the notebook's later `import os` cell.
    import os

    def ckpt(file_name):
        # Absolute path of one checkpoint inside model_dir.
        return os.path.join(model_dir, file_name)

    swin_8687 = Swin_model(ckpt("swin-again_no_back_1000hu_8687.pth"))
    swin_8675 = Swin_model(ckpt("swin-again_no_back_1000hu_8675.pth"))
    swin_8655 = Swin_model(ckpt("swin-again_no_back_1000hu_8655.pth"))
    # NOTE(review): file suffix is "0853" while the variable says 8530 -- presumably 0.853; confirm.
    unet_8530 = Unet_model(ckpt("unet_1000_hu_160_0853.pth"))
    unet_8550 = Unet_model(ckpt("unet_1000_hu_160_8550.pth"))
    unet_8551 = Unet_model(ckpt("unet_1000_hu_160_w_augmentations_8551.pth"))

    return swin_8687, swin_8675, swin_8655, unet_8530, unet_8550, unet_8551

In [3]:
swin_8687, swin_8675, swin_8655, unet_8530, unet_8550, unet_8551 = load_model()
# One weight per model, in load_model()'s return order. Each weight is the score encoded in
# the model's name (presumably validation Dice x 100 -- confirm).
# NOTE(review): the first two entries were 88 and 87, which breaks that pattern (86.87 / 86.75
# per the names). They are treated here as typos; restore them if the up-weighting was deliberate.
ensemble_weights = [86.87, 86.75, 86.55, 85.30, 85.50, 85.51]

### Dataset

In [50]:
import glob
import os
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd, 
    EnsureTyped,
    EnsureType,
    Invertd,
    KeepLargestConnectedComponent,
    AddChanneld,
    ToTensord

)
# Inference-time preprocessing (no augmentation). Every step acts on the "images" key.
# predict_and_save passes this same pipeline to Invertd to map predictions back, so keep them in sync.
test_transforms = Compose([
    LoadImaged(keys=["images"]),
    EnsureChannelFirstd(keys=["images"]),
    # Reorient every volume to LPS axis codes.
    Orientationd(keys=["images"], axcodes="LPS"),
    # Clip intensities to [-1000, 1000] (presumably HU; the model names mention "1000hu") and rescale to [0, 1].
    ScaleIntensityRanged(keys=["images"], a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True),
    # Crop to the foreground bounding box, computed from the image itself.
    CropForegroundd(keys=["images"], source_key="images"),
    EnsureTyped(keys=["images"]),
])

def test_dataloader(path2input, test_transforms=test_transforms):
    """Build a DataLoader over every ``*.nii.gz`` file found (recursively) under ``path2input``.

    Files are visited in sorted path order. Each item is a dict ``{"images": path}``
    that ``test_transforms`` turns into a preprocessed tensor. Batches hold one
    volume, and the order is not shuffled.
    """
    search_pattern = os.path.join(path2input, "**/*.nii.gz")
    records = []
    for nifti_path in sorted(glob.glob(search_pattern, recursive=True)):
        records.append({"images": nifti_path})
    dataset = Dataset(data=records, transform=test_transforms)
    return DataLoader(dataset, batch_size=1, shuffle=False)

In [51]:
loader = test_dataloader("./test_inputs/")

In [63]:
from monai.inferers import sliding_window_inference


In [64]:
def predict_and_save(loader, model, path2save, test_transforms, roi_size=(288, 288, 288), sw_batch_size=8, overlap=0.25):
    """Run sliding-window inference with ``model`` over ``loader`` and save one segmentation per volume.

    For each prediction:

    - ``Invertd`` maps it back to the original image space, replaying the
      invertible steps of ``test_transforms`` in reverse.
    - ``AsDiscreted(argmax=True)`` turns it into a label map.
    - ``SaveImaged`` writes it under ``path2save`` with the ``seg`` postfix.

    Args:
        loader: iterable of dict batches with an ``"images"`` tensor (e.g. from ``test_dataloader``).
        model: callable network with ``.eval()``. It receives tensors on the global ``device``.
        path2save: output directory; created if missing.
        test_transforms: the preprocessing ``Compose`` that produced ``loader``'s items (needed for inversion).
        roi_size: sliding-window patch size.
        sw_batch_size: number of windows evaluated per forward pass.
        overlap: fractional overlap between neighbouring windows (0.25 is MONAI's default).
    """
    os.makedirs(path2save, exist_ok=True)
    post_transforms = Compose([
        Invertd(
            keys="pred",
            transform=test_transforms,
            orig_keys="images",
            meta_keys=None,
            orig_meta_keys=None,
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
        ),
        AsDiscreted(keys="pred", argmax=True),
        SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir=path2save, output_postfix='seg', resample=False),
    ])

    model.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                batch_inputs = batch["images"].to(device)
                batch["pred"] = sliding_window_inference(batch_inputs, roi_size, sw_batch_size, model, overlap=overlap)
                # Split the batch into per-image dicts; post_transforms writes each one to disk.
                for item in decollate_batch(batch):
                    post_transforms(item)
    finally:
        # Release cached GPU blocks even if inference fails part-way through.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


In [None]:
## Unet
# predict_and_save(loader, unet_8530, "./test_outputs/unet_8530_out/", test_transforms)
# predict_and_save(loader, unet_8550, "./test_outputs/unet_8550_out/", test_transforms)
# predict_and_save(loader, unet_8551, "./test_outputs/unet_8551_out/", test_transforms)

#Swin
%time
predict_and_save(loader, swin_8655, "./test_outputs/swin_8655_out/", test_transforms, roi_size = (96, 96, 96), sw_batch_size = 4)
predict_and_save(loader, swin_8675, "./test_outputs/swin_8675_out/", test_transforms, roi_size = (96, 96, 96), sw_batch_size = 4)
predict_and_save(loader, swin_8687, "./test_outputs/swin_8687_out/", test_transforms, roi_size = (96, 96, 96), sw_batch_size = 4)



CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 11.9 µs


None of the inputs have requires_grad=True. Gradients will be None


2022-08-02 13:33:13,097 INFO image_writer.py:190 - writing: test_outputs/swin_8655_out/PA000005/PA000005_seg.nii.gz
2022-08-02 13:46:40,340 INFO image_writer.py:190 - writing: test_outputs/swin_8655_out/PA000016/PA000016_seg.nii.gz
2022-08-02 14:00:06,800 INFO image_writer.py:190 - writing: test_outputs/swin_8675_out/PA000005/PA000005_seg.nii.gz
2022-08-02 14:13:27,728 INFO image_writer.py:190 - writing: test_outputs/swin_8675_out/PA000016/PA000016_seg.nii.gz
2022-08-02 14:26:53,304 INFO image_writer.py:190 - writing: test_outputs/swin_8687_out/PA000005/PA000005_seg.nii.gz


In [None]:


def predict_and_save(loader, path2model, path2save, test_transforms, roi_size=(96, 96, 96), sw_batch_size=4, model_loader=None, overlap=0.25):
    """Load a checkpoint (or take a ready model), run sliding-window inference over ``loader`` and save segmentations.

    NOTE: this re-definition shadows the earlier ``predict_and_save`` in this notebook.
    It therefore also accepts an already-constructed model as ``path2model``, so calls
    written for the earlier definition keep working.

    Previously the body used an undefined ``model`` (NameError) and never loaded ``path2model``.

    Args:
        loader: iterable of dict batches with an ``"images"`` tensor (e.g. from ``test_dataloader``).
        path2model: checkpoint path (str / PathLike), or an already-built model.
        path2save: output directory; created if missing.
        test_transforms: the preprocessing ``Compose`` that produced ``loader``'s items (needed for inversion).
        roi_size: sliding-window patch size.
        sw_batch_size: number of windows evaluated per forward pass.
        model_loader: callable ``path -> model`` used when ``path2model`` is a path.
            Defaults to ``Swin_model``, matching this function's Swin-style defaults.
            Pass ``Unet_model`` for UNet checkpoints.
        overlap: fractional overlap between neighbouring windows (0.25 is MONAI's default).
    """
    if isinstance(path2model, (str, os.PathLike)):
        build_model = Swin_model if model_loader is None else model_loader
        model = build_model(path2model)
        print("model-loaded")
    else:
        model = path2model

    os.makedirs(path2save, exist_ok=True)
    post_transforms = Compose([
        Invertd(
            keys="pred",
            transform=test_transforms,
            orig_keys="images",
            meta_keys=None,
            orig_meta_keys=None,
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
        ),
        AsDiscreted(keys="pred", argmax=True),
        SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir=path2save, output_postfix='seg', resample=False),
    ])

    try:
        model.eval()
        with torch.no_grad():
            for batch in loader:
                batch_inputs = batch["images"].to(device)
                batch["pred"] = sliding_window_inference(batch_inputs, roi_size, sw_batch_size, model, overlap=overlap)
                # Split the batch into per-image dicts; post_transforms writes each one to disk.
                for item in decollate_batch(batch):
                    post_transforms(item)
    finally:
        # Drop our reference (the only one when the model was loaded here) before clearing the CUDA cache.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
