In [None]:
import numpy as np, torch, os, glob
import torch.backends.cudnn as cudnn
import segmentation_models_pytorch as smp
import nibabel as nib
import matplotlib.pyplot as plt
import albumentations as A
from sklearn.metrics import *
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from scipy import ndimage
from tqdm import tqdm   
# Expose only CUDA device index 1 to this process. This must happen before CUDA
# is initialised, so it precedes the torch.cuda.is_available() call below.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Let cuDNN autotune convolution algorithms for the (fixed) input sizes.
cudnn.benchmark = True
# NOTE(review): the inference cells below hard-code 'cpu', so `device` is unused there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class DataLoaderSegmentation(object):
    """Map-style dataset pairing NIfTI image files with NIfTI mask files.

    Files are discovered with ``<base_path>/image/*.*`` and
    ``<base_path>/masks/*.*`` and paired purely by sorted filename order, so
    both folders must hold the same number of files in matching order.

    Args:
        base_path: directory containing the ``image`` and ``masks`` subfolders.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` keys
            (e.g. an albumentations ``Compose``).

    Raises:
        ValueError: if the two folders contain a different number of files.
    """

    def __init__(self, base_path, transform=None):
        super(DataLoaderSegmentation, self).__init__()
        self.data_image_path = sorted(glob.glob(os.path.join(base_path, 'image', '*.*')))
        self.data_masks_path = sorted(glob.glob(os.path.join(base_path, 'masks', '*.*')))
        # Pairing is by index only: a count mismatch would silently misalign
        # images and masks (or raise IndexError mid-epoch), so fail early.
        if len(self.data_image_path) != len(self.data_masks_path):
            raise ValueError(
                "image/mask count mismatch under %r: %d images vs %d masks"
                % (base_path, len(self.data_image_path), len(self.data_masks_path)))
        self.transform = transform
        # Path of the most recently loaded file; None until the first load.
        self.path = None

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair at ``index``.

        With a transform, whatever the transform returns is passed through;
        otherwise both arrays are returned as torch tensors.
        """
        # The image is loaded before the mask, so afterwards self.path is the mask path.
        image = self.__nii_load__(self.data_image_path[index])
        masks = self.__nii_load__(self.data_masks_path[index])
        if self.transform is not None:
            # Copies give the transform contiguous arrays (np.fliplr in
            # _reorient returns a negative-stride view).
            transformed = self.transform(image=image.copy(), mask=masks.copy())
            return transformed["image"], transformed["mask"]
        # torch.from_numpy rejects negative strides, hence the contiguous copies.
        return torch.from_numpy(image.copy()), torch.from_numpy(masks.copy())

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.data_image_path)

    def __getpath__(self):
        """Return the path most recently read by ``__nii_load__`` (None before any load).

        After an item fetch this is the *mask* path, because ``__getitem__``
        loads the image first and the mask second.
        """
        return self.path

    def __nii_load__(self, nii_path):
        """Load a NIfTI file as a float32 array, re-oriented via ``_reorient``.

        Records ``nii_path`` in ``self.path`` as a side effect.
        """
        self.path = nii_path
        nii = nib.load(nii_path)
        affine = nii.header.get_best_affine()
        volume = nii.get_fdata().astype(np.float32)
        return self._reorient(volume, affine[1, 1])

    @staticmethod
    def _reorient(volume, axis_scale):
        """Rotate/flip ``volume`` in-plane according to the sign of ``axis_scale``.

        ``axis_scale`` is ``affine[1, 1]`` of the source file:
        * positive -> rotate +90 degrees;
        * negative -> rotate -90 degrees, then flip left-right;
        * zero/NaN -> returned unchanged.
        NOTE(review): ndimage.rotate uses its default spline order (3) for masks
        too, and reshape=False crops non-square slices; confirm inputs are square.
        """
        if axis_scale > 0:
            return ndimage.rotate(volume, 90, reshape=False, mode="nearest")
        if axis_scale < 0:
            rotated = ndimage.rotate(volume, -90, reshape=False, mode="nearest")
            return np.fliplr(rotated)
        return volume

# Inference-time transform: tensor conversion only (ToTensorV2), no augmentation.
transformv = A.Compose([ToTensorV2()])

In [None]:
def model_create(checkpoint_path=None):
    """Build a 1-channel, 1-class ResNet18 U-Net and load trained weights on the CPU.

    Args:
        checkpoint_path: checkpoint file to load. When omitted, the
            notebook-level ``weight_path`` global is read at call time, so
            existing ``model_create()`` calls behave as before.

    Returns:
        ``smp.Unet`` with ``checkpoint['model_state_dict']`` loaded, on the CPU.
    """
    if checkpoint_path is None:
        checkpoint_path = weight_path
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model = smp.Unet(encoder_name='resnet18', encoder_weights=None, in_channels=1, classes=1)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to('cpu')
    return model

In [None]:
def validate(valid_loader, model, threshold=0.3):
    """Run ``model`` over ``valid_loader`` and return thresholded predictions.

    Args:
        valid_loader: iterable of ``(images, target)`` batches (e.g. a
            DataLoader); ``target`` is not used.
        model: torch module; put into eval mode as a side effect.
        threshold: values of the model output strictly greater than this are
            True. NOTE(review): the comparison is on the raw output with no
            sigmoid; if the model emits logits, 0.3 is not a probability cut-off.
            Confirm this is intended.

    Returns:
        list with one boolean numpy array per batch: ``model(images).squeeze(1) > threshold``.
    """
    predict_array = []
    model.eval()
    # Send inputs to wherever the model lives instead of hard-coding the CPU,
    # so a GPU-resident model no longer fails with a device mismatch.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    with torch.no_grad():
        for images, _target in tqdm(valid_loader):
            output = model(images.to(model_device)).squeeze(1)
            # .cpu() is a no-op for CPU tensors and required before .numpy() on GPU.
            predict_array.append(output.cpu().numpy() > threshold)
    return predict_array

def train_valid_process_main(model, valid_dataset, batch_size):
    """Predict masks for every sample of ``valid_dataset``.

    Despite the name, no training happens: the dataset is wrapped in an
    unshuffled DataLoader, ``validate`` is run on it, and the per-batch
    predictions are stacked with ``np.array``.

    NOTE(review): stacking assumes every batch has the same shape, i.e. that
    ``len(valid_dataset)`` is divisible by ``batch_size`` (always true for 1).
    """
    ordered_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    per_batch_predictions = validate(ordered_loader, model)
    stacked = np.array(per_batch_predictions)
    return stacked

In [None]:
# Trained checkpoint; model_create() reads this global when called without arguments.
weight_path = './checkpoint/2021.11.03.t1 - 2DRes18Unet/best-2DRes18Unet - lr_0.003 - DCEL.pt'
# Validation split (image/ and masks/ subfolders) with tensor-conversion-only transform.
valid_dataset = DataLoaderSegmentation('./dataset/normalized_zscore/valid/', transform=transformv)

# batch_size=1 yields one prediction entry per dataset sample, which the
# plotting cell indexes per slice.
model = model_create()
predict_array = train_valid_process_main(model, valid_dataset, batch_size=1)

In [None]:
# A single unshuffled batch spanning the whole validation set, so batch position
# k lines up with predict_array[k] in the plotting cell. Sized from the data
# instead of a hard-coded 782; max(1, ...) keeps DataLoader valid for an empty set.
valid_loader = DataLoader(valid_dataset, batch_size=max(1, len(valid_dataset)), shuffle=False)

In [None]:
# For each slice of the first batch show input, ground-truth mask and prediction
# side by side. predict_array came from an unshuffled batch_size=1 pass, so
# entry k corresponds to position k in this (also unshuffled) batch.
first_batch = next(iter(valid_loader), None)
if first_batch is not None:
    img, msk = first_batch
    for idx2, slice_img in enumerate(img):
        panels = (
            (np.squeeze(slice_img.numpy(), axis=0), "DWI Slice"),
            (np.squeeze(msk[idx2].numpy(), axis=0), "Masks"),
            (np.squeeze(predict_array[idx2], axis=0), "Predict"),
        )
        fig, axes = plt.subplots(1, 3)
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel, cmap='bone')
            ax.set_title(title)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.show()
        # Release each figure once shown; hundreds of open figures exhaust memory.
        plt.close(fig)