In [None]:
import os
import sys
import cv2
import torch
import glob
import shutil
import numpy as np
import nibabel as nib
from PIL import Image
import SimpleITK as sitk
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [None]:
def read_nifti_file(filepath):
    """Load a NIfTI volume from `filepath` and return its voxel data.

    The returned array is exactly what nibabel's ``get_fdata()`` yields for the
    loaded image; no reordering or rescaling is applied here.
    """
    volume = nib.load(filepath)
    return volume.get_fdata()

def normalize(image, eps=0):
    """Min-max scale `image` to the 0-255 range and return it as uint8.

    Args:
        image: numeric array (any shape), e.g. one 2D slice of a volume.
        eps: added to the (max - min) denominator.

    Returns:
        A uint8 array with the same shape as `image`. An empty or constant input
        (e.g. an all-background mask slice when eps == 0) yields all zeros
        instead of computing 0/0 = NaN and casting NaN to uint8, which is undefined.
    """
    image = np.asarray(image)
    if image.size == 0:
        # min()/max() raise on empty arrays.
        return np.zeros(image.shape, dtype=np.uint8)

    lo = image.min()
    denom = image.max() - lo + eps
    if denom == 0:
        return np.zeros(image.shape, dtype=np.uint8)

    # Same operation order as before so non-degenerate inputs give identical bytes.
    image = (image - lo) / denom
    image = image * 255
    return image.astype(np.uint8)

In [None]:
## Resample voxel spacing

out_spacing = [1.0, 1.0, 1.0]  # default target spacing (mm), in SimpleITK (x, y, z) order

def resampling(file_path, spacing=None, is_label=False, default_value=0.0):
    """Read a volume with SimpleITK and resample it to a new voxel spacing.

    Args:
        file_path: path to an image readable by ``sitk.ReadImage``.
        spacing: target spacing per axis (same order as ``GetSpacing()``);
            defaults to the module-level ``out_spacing``.
        is_label: if True use nearest-neighbour interpolation (for masks, so
            label values are never blended); otherwise linear interpolation.
        default_value: intensity written to output voxels that fall outside
            the input image.

    Returns:
        The resampled ``sitk.Image`` with the input's origin and direction.
    """
    if spacing is None:
        spacing = out_spacing
    spacing = [float(s) for s in spacing]

    image = sitk.ReadImage(file_path)
    ori_spacing = image.GetSpacing()
    ori_size = image.GetSize()
    # round() rather than int(): truncation turns e.g. 3 * (0.3 / 0.1) =
    # 8.999999999999998 into 8 and silently drops a slice. max(1, ...) keeps
    # every axis non-empty.
    out_size = [max(1, int(round(n * (old_sp / new_sp))))
                for n, old_sp, new_sp in zip(ori_size, ori_spacing, spacing)]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(spacing)
    resample.SetSize(out_size)
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    # Previously image.GetPixelIDValue() was passed here; that is the pixel-type
    # enum code, not an intensity, so out-of-bounds voxels got an arbitrary value.
    resample.SetDefaultPixelValue(default_value)
    resample.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear)
    return resample.Execute(image)

In [None]:
# Slice 3D to 2D: turn each training volume and its mask into resized 2D slices.

train_img_dir = 'directory of train image'
train_msk_dir = 'directory of train mask'

# Output folders for the PNG slices. They default to the source folders (where the
# original commented-out cv2.imwrite calls wrote); point them elsewhere to keep the
# generated PNGs out of the NIfTI input folders.
train_img_out_dir = train_img_dir
train_msk_out_dir = train_msk_dir

SLICE_SIZE = (256, 256)  # (width, height), as cv2.resize expects
SAVE_SLICES = False      # writing was disabled in the original notebook; set True to export
# Only these files are processed; extend if volumes use another SimpleITK-readable format.
NIFTI_SUFFIXES = ('.nii', '.nii.gz')


def _nifti_stem(name):
    """Strip the NIfTI suffix but keep dots inside the case name ('a.01.nii.gz' -> 'a.01')."""
    lowered = name.lower()
    for suffix in NIFTI_SUFFIXES:
        if lowered.endswith(suffix):
            return name[:-len(suffix)]
    return name.split('.')[0]


def _list_nifti(directory):
    """Return the NIfTI file names in `directory`, sorted so runs are deterministic."""
    return sorted(name for name in os.listdir(directory)
                  if name.lower().endswith(NIFTI_SUFFIXES))


def slice_volumes(src_dir, names, dst_dir, interpolation, desc=None):
    """Min-max normalize and resize every slice along axis 0 of each volume.

    SimpleITK arrays are laid out (z, y, x), so each iteration handles one z slice.
    Slices are written as '<stem>_NN.png' (1-based NN) into `dst_dir` only when
    SAVE_SLICES is True.
    """
    for name in tqdm(names, desc=desc):
        volume = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(src_dir, name)))
        stem = _nifti_stem(name)
        for index in range(volume.shape[0]):
            # NOTE(review): per-slice min-max scaling is fine for binary masks but would
            # make multi-label values inconsistent across slices — confirm masks are binary.
            slice_2d = normalize(volume[index])
            slice_2d = cv2.resize(slice_2d, SLICE_SIZE, interpolation=interpolation)
            if SAVE_SLICES:
                cv2.imwrite(os.path.join(dst_dir, stem + '_%02d.png' % (index + 1)), slice_2d)


train_img_name = _list_nifti(train_img_dir)
train_msk_name = _list_nifti(train_msk_dir)

# Linear interpolation for intensities; nearest-neighbour for masks so labels are never blended.
slice_volumes(train_img_dir, train_img_name, train_img_out_dir, cv2.INTER_LINEAR, desc='images')
slice_volumes(train_msk_dir, train_msk_name, train_msk_out_dir, cv2.INTER_NEAREST, desc='masks')

In [None]:
# Calculate the mean and std of the training images (after scaling to [0, 1]).
# NOTE(review): `dataset` is not defined in any cell above; it is assumed to be an
# indexable collection yielding (image, mask) pairs in the 0-255 range — define it first.

images = []  # kept for any later cells that use them
masks = []

pixel_sum = 0.0
pixel_sq_sum = 0.0
pixel_count = 0

for image, mask in tqdm(dataset, total=len(dataset), desc='mean/std'):
    images.append(image)
    masks.append(mask)

    pixels = np.asarray(image, dtype=np.float64) / 255.
    pixel_sum += pixels.sum()
    pixel_sq_sum += np.square(pixels).sum()
    pixel_count += pixels.size

if pixel_count == 0:
    raise ValueError('dataset is empty; cannot compute mean/std')

# Statistics over every pixel of every image (previously only images[0] was used).
# Population std (ddof=0), matching np.std's default. Plain Python floats so that
# float32 arrays normalized with them are not promoted to float64.
mean_values = float(pixel_sum / pixel_count)
std_values = float(np.sqrt(max(pixel_sq_sum / pixel_count - mean_values ** 2, 0.0)))

print('mean : ', mean_values)
print('std : ', std_values)

In [None]:
class Normalize:
    """Turn an (image, label) pair into standardized float32 CHW torch tensors.

    The image is divided by 255, standardized with (mean, std) and moved to
    channel-first layout. The label is divided by 255 and gets the same layout
    (values are not standardized).

    Args:
        mean: value(s) subtracted from the [0, 1]-scaled image.
        std: value(s) the centered image is divided by.
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, label):
        image = image.astype(np.float32) / 255
        if image.ndim == 2:
            image = image[..., np.newaxis]      # (H, W) -> (H, W, 1)
        label = label.astype(np.float32) / 255
        if label.ndim == 2:
            label = label[..., np.newaxis]      # (H, W) -> (H, W, 1)

        image = (image - self.mean) / self.std
        # A float64 mean/std would otherwise promote the image to float64 under
        # NumPy >= 2 promotion rules; keep tensors float32.
        image = image.astype(np.float32, copy=False)

        image = image.transpose((2, 0, 1))      # (H, W, C) -> (C, H, W)
        label = label.transpose((2, 0, 1))      # (H, W, 1) -> (1, H, W)

        return torch.from_numpy(image), torch.from_numpy(label)



class Normalize_For_NII:
    """Turn a single image (no label) into a standardized float32 CHW torch tensor.

    Args:
        mean: value(s) subtracted from the [0, 1]-scaled image.
        std: value(s) the centered image is divided by.
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image = image.astype(np.float32) / 255
        if image.ndim == 2:
            # Previously a 2D image crashed in transpose; give it a channel axis.
            image = image[..., np.newaxis]      # (H, W) -> (H, W, 1)
        image = (image - self.mean) / self.std
        # Keep float32 even when mean/std are float64 (NumPy >= 2 promotion).
        image = image.astype(np.float32, copy=False)
        image = image.transpose((2, 0, 1))      # (H, W, C) -> (C, H, W)

        return torch.from_numpy(image)