In [99]:
import torch
from monai.networks.nets import UNet
from trainValTestGreyscale import getData
from torch.utils.data import DataLoader
import monai
import cv2
import numpy as np
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandRotate90d,
    RandFlipd,
    Rand2DElasticd,
    ScaleIntensityd,
    EnsureTyped,
    EnsureType,
    RandGaussianNoised
)

In [100]:
#load desired model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256), #16, 32, 64....
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

model.load_state_dict(torch.load(
    "ALL_4l_rotFlip_1000epochs.pth"))
threshold = 0.5
model.eval()

UNet(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (adn): ADN(
            (N): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (unit1): Convolution(
          (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (adn): ADN(
            (N): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): ResidualUnit(
          (conv): Sequential(
            (unit

In [101]:
# setup dataset and dataloaders with the test data
_, _, _, _, test_images, test_segs = getData.getImageSegTrainValTest("ALL")

test_files = [{"img": img, "seg": seg}
                for img, seg in zip(test_images, test_segs)]

test_transforms = Compose(
    [
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        EnsureTyped(keys=["img", "seg"]),
    ]
)

test_ds = monai.data.Dataset(data=test_files, transform=test_transforms)
test_loader = DataLoader(test_ds, batch_size=1,
                        num_workers=4, collate_fn=list_data_collate)
test_iter = iter(test_loader)
post_trans = Compose([EnsureType(), Activations(
        sigmoid=True), AsDiscrete(threshold=0.5)])

In [102]:
# save test set segmentations
for i, data in enumerate(test_iter):
  img, label = data["img"].to(
                device), data["seg"].to(device) 
  output = model(img)
  test_outputs = [post_trans(i) for i in decollate_batch(output)]
  test_outputs = test_outputs[0][0].cpu().numpy()
  test_outputs = np.where(test_outputs ==1, 255, 0)
  test_outputs = np.fliplr(np.rot90(test_outputs, k = 3))
  test = cv2.imwrite(r"testSegmentations\testSeg_" + str(i) +".png", test_outputs)
