In [1]:
import pandas as pd, numpy as np, os, sys, glob, nibabel as nib
from scipy import ndimage


In [2]:
# Read the score table from disk and preview its 'ID' column as the cell output.
csv_path = './NIHSS_score223.csv'
table_1_5t3_0t = pd.read_csv(filepath_or_buffer=csv_path)
table_1_5t3_0t.loc[:, 'ID']

0      is0002
1      is0003
2      is0004
3      is0005
4      is0007
        ...  
217    is0345
218    is0346
219    is0347
220    is0348
221    is0349
Name: ID, Length: 222, dtype: object

In [3]:
class nii_process:
    """Pre-process paired NIfTI image/mask volumes and write the result to disk.

    For each (image, mask) pair the pipeline is: z-score normalise the image,
    resample both volumes in-plane to a square grid, zero every voxel outside
    the mask, zero-pad the last (slice) axis up to a fixed depth and save the
    masked image as ``.nii.gz`` under ``base_``.

    Attributes:
        base_: Output directory for the processed volumes.
        volume: Placeholder array set in ``__init__``; not used by any method
            of this class.
        slice_n: Number of slices of the most recently processed image
            *before* depth padding (set by ``process_scan``).
    """

    def __init__(self, base_):
        self.base_ = base_
        self.volume = np.array([])

    def normalize(self, volume):
        """Return ``volume`` z-score normalised as a new float32 array.

        Args:
            volume: Array-like of numeric intensities.

        Returns:
            ``(volume - mean) / std`` as float32. A constant volume (std == 0)
            yields all zeros instead of the NaNs a 0/0 division would produce.
        """
        img = np.asarray(volume, dtype=np.float32)
        mean = np.mean(img)
        std = np.std(img)
        if std == 0:
            # Every voxel equals the mean; dividing would give NaN everywhere.
            return np.zeros_like(img)
        return ((img - mean) / std).astype("float32")

    def resize_volume(self, img, size, depth):
        """Resample the first two axes of ``img`` to ``size`` x ``size``.

        Uses nearest-neighbour interpolation (``order=0``), so label values in
        masks are preserved. All axes after the first two keep their length.

        Args:
            img: Array with at least two dimensions.
            size: Target length for axis 0 and axis 1.
            depth: Unused; accepted only for backward compatibility. The
                slice axis is not resampled here.

        Returns:
            The resampled array.
        """
        # One zoom factor per axis, in axis order: axis 0 must be scaled by
        # its own length (shape[0]) and axis 1 by shape[1].
        factors = (size / img.shape[0], size / img.shape[1]) + (1,) * (img.ndim - 2)
        return ndimage.zoom(img, factors, order=0)

    @staticmethod
    def _pad_depth(volume, depth):
        """Zero-pad the last axis of ``volume`` up to ``depth`` (dtype preserved)."""
        missing = depth - volume.shape[-1]
        if missing < 0:
            raise ValueError(
                f"volume has {volume.shape[-1]} slices, more than the target depth {depth}"
            )
        if missing == 0:
            return volume
        pad_width = [(0, 0)] * (volume.ndim - 1) + [(0, missing)]
        return np.pad(volume, pad_width, mode="constant", constant_values=0)

    def process_scan(self, path_img, path_msg, size=384, depth=28):
        """Normalise, resize, mask and depth-pad one scan, then save it.

        Args:
            path_img: Path of the image NIfTI file.
            path_msg: Path of the matching mask NIfTI file.
            size: In-plane target size (default 384).
            depth: Target number of slices (default 28).

        Returns:
            Path of the written ``.nii.gz`` file.

        Raises:
            ValueError: If the image has more than ``depth`` slices. The slice
                count and path are printed first so the scan can be identified.
        """
        image_o = nib.load(path_img)
        masks_o = nib.load(path_msg)
        affine = image_o.header.get_best_affine()

        image = image_o.get_fdata()
        masks = masks_o.get_fdata()
        if len(image_o.shape) == 4:
            width, height, queue, _ = image.shape
            # Keep only the volume at index 1 of the 4th axis.
            # NOTE(review): presumably the series of interest - confirm.
            image = image[:, :, :, 1]
            masks = np.reshape(masks, (width, height, queue))

        image = self.normalize(image)
        image = self.resize_volume(image, size, depth)
        masks = self.resize_volume(masks, size, depth)
        # Zero every voxel that lies outside the mask.
        image = np.where(masks, image, image * 0)

        self.slice_n = image.shape[-1]
        nii_name_slices = os.path.split(path_img)[1].split('.')[0]

        if image.shape[-1] > depth:
            print(image.shape[-1], path_img)
        # Scans that already have exactly `depth` slices pass through unchanged
        # and are saved as well; shallower scans get trailing all-zero slices.
        new_image = self._pad_depth(image, depth)

        adjusted_seg = nib.Nifti1Image(new_image, affine)
        # NOTE(review): affine and pixdim are copied from the un-resized input
        # although the in-plane grid was resampled - confirm nothing downstream
        # relies on the physical voxel spacing.
        adjusted_seg.header['pixdim'] = image_o.header['pixdim']

        os.makedirs(self.base_, exist_ok=True)
        # The output name is the input file stem minus its last character.
        out_path = os.path.join(self.base_, f'{nii_name_slices[0:-1]}.nii.gz')
        adjusted_seg.to_filename(out_path)
        return out_path

In [4]:
process_stack = ['train','valid','test']
# process_stack = ['train']

# Source directory holding the raw volumes, and the pre-processor that writes
# its results to the output directory.
SRC_DIR = './dataset/original_data1.5&3.0/'
prepare_data = nii_process('./dataset/S2_data1.5&3.0_seg/')

# Files are paired purely by position in the sorted listing: even index is the
# image, the following entry is its mask.
# NOTE(review): any stray file in SRC_DIR shifts every later pair - confirm the
# directory contains only image/mask pairs.
_list = sorted(os.listdir(SRC_DIR))
_len = len(_list)
if _len % 2:
    print(f'Warning: {_len} files in {SRC_DIR}; the last one has no partner and is skipped')

# Drop missing IDs and force strings so the substring test cannot raise TypeError.
known_ids = [str(x) for x in table_1_5t3_0t['ID'].dropna()]

for img_name, msk_name in zip(_list[0::2], _list[1::2]):
    # Process each pair at most once, even if several IDs match the file name.
    if any(x in img_name for x in known_ids):
        prepare_data.process_scan(os.path.join(SRC_DIR, img_name), os.path.join(SRC_DIR, msk_name))
        

In [5]:
            # Disabled debug visualisation: this cell contains no executable code.
            # It plots, slice by slice, `new_image`, `image`, `images` and `masks`
            # side by side. `new_image`, `image` and `masks` are local variables of
            # nii_process.process_scan, and `images` is not defined anywhere in this
            # notebook, so the snippet cannot run as-is at cell level.
            # NOTE(review): presumably lifted from inside process_scan (hence the
            # indentation) - confirm whether it can be deleted.
            # import matplotlib.pyplot as plt

            # for i in range (new_image.shape[-1]):
            #     print(new_image[...,i].shape)
            #     fig = plt.figure(figsize=(10,10))
            #     ax1 = fig.add_subplot(1,4,1)
            #     ax1.imshow(new_image[...,i], cmap='bone')
            #     ax2 = fig.add_subplot(1,4,2)
            #     ax2.imshow(image[...,i], cmap='bone')
            #     ax3 = fig.add_subplot(1,4,3)
            #     ax3.imshow(images[...,i], cmap='bone')
            #     ax4 = fig.add_subplot(1,4,4)
            #     ax4.imshow(masks[...,i], cmap='bone')
            #     plt.show()  