In [1]:
import os
import glob
import torch
import numpy as np
import scipy.ndimage
import nibabel as nib
from torch.utils.data import Dataset
from torchvision import transforms
import napari

In [2]:
# Directory holding the training NIfTI volumes (*.nii.gz). Set the TRAIN_DIR
# environment variable to point elsewhere instead of editing this local path.
path_dir_train = os.environ.get("TRAIN_DIR", "E:/train_dir/images")
# glob returns files in arbitrary order; sorting makes the volume order (and hence
# the indices into train_data) deterministic across runs and machines.
images_train = sorted(glob.glob(os.path.join(path_dir_train, "*.nii.gz")))

In [3]:
class MyDataset(Dataset):
    """Dataset of NIfTI volumes, one item per file path.

    Every path is opened with ``nib.load`` when the dataset is built; an item's
    voxel data is converted to a NumPy array only when that item is indexed.
    """

    def __init__(self, path):
        # ``path`` is an iterable of file paths (one per volume), not a single path.
        self.images_list = list(map(nib.load, path))

    def __len__(self):
        """Return how many volumes the dataset holds."""
        return len(self.images_list)

    def __getitem__(self, idx):
        """Return the voxel data of volume ``idx`` as a NumPy array.

        An out-of-range ``idx`` raises ``IndexError`` from the underlying list,
        which is what lets ``list(dataset)`` stop iterating.
        """
        return np.asarray(self.images_list[idx].dataobj)

In [4]:
# Resampling of voxels to a new spacing.
def resample(data, old_spacing, new_spacing=(3, 3, 3), order=1):
    """Resample an array from ``old_spacing`` to approximately ``new_spacing``.

    The target shape is ``round(data.shape * old_spacing / new_spacing)`` (with
    every axis kept at >= 1 voxel) and ``data`` is interpolated to that shape
    with ``scipy.ndimage.zoom``. Because the shape is rounded to whole voxels,
    the spacing actually achieved can differ slightly from ``new_spacing``.

    Parameters
    ----------
    data : np.ndarray
        Array to resample.
    old_spacing : sequence of float
        Current spacing, one value per axis of ``data`` (broadcastable);
        rounded to 2 decimals before use.
    new_spacing : sequence of float, optional
        Desired spacing, one value per axis (broadcastable). Default ``(3, 3, 3)``.
    order : int, optional
        Spline interpolation order passed to ``scipy.ndimage.zoom``; the
        default ``1`` is linear interpolation.

    Returns
    -------
    np.ndarray
        The resampled array.

    Raises
    ------
    ValueError
        If a spacing cannot be matched to the number of axes of ``data``.
    """
    n_axes = np.ndim(data)
    try:
        old_spacing = np.broadcast_to(
            np.round(np.asarray(old_spacing, dtype=float), 2), (n_axes,)
        )
        new_spacing = np.broadcast_to(np.asarray(new_spacing, dtype=float), (n_axes,))
    except ValueError as exc:
        raise ValueError(
            f"old_spacing and new_spacing need one value per axis of the {n_axes}-D input"
        ) from exc

    input_shape = np.asarray(np.shape(data), dtype=float)
    scale_factor = old_spacing / new_spacing
    # Never let an axis collapse to zero voxels when downsampling a thin axis.
    target_shape = np.maximum(np.round(input_shape * scale_factor), 1)
    zoom_factor = target_shape / input_shape

    # prefilter only has an effect for order > 1; kept True as before.
    return scipy.ndimage.zoom(data, zoom_factor, order=order, prefilter=True)

In [5]:
# Wrap the sorted list of training files in the dataset (headers opened here).
train_data = MyDataset(path=images_train)

In [6]:
# Read every volume's voxel data into memory as a plain list of NumPy arrays.
train_data = [train_data[k] for k in range(len(train_data))]

In [7]:
# Preprocess each raw volume so all volumes share one shape and can be averaged:
#   1. resample from spacing [1, 1, 3] to [1, 1, 1] (third axis becomes ~3x longer),
#   2. resize the last two axes to (512, 388) with torchvision,
#   3. flatten to a 1-D vector.
# NOTE(review): the [1, 1, 3] spacing is assumed for every scan rather than read
# from each NIfTI header (nibabel `header.get_zooms()`); confirm it holds here.
# This cell overwrites `train_data` element by element, so it can only run once per
# load: re-run the cells that build `train_data` before running it again.
SOURCE_SPACING = [1, 1, 3]
TARGET_SPACING = [1, 1, 1]
RESIZE_HW = (512, 388)

reference_shape = None
for i in range(len(train_data)):
    volume = train_data[i]
    if np.ndim(volume) != 3:
        raise ValueError(
            f"train_data[{i}] has {np.ndim(volume)} dimensions, expected a raw 3-D volume; "
            "re-run the data-loading cells before this one"
        )
    volume = resample(volume, SOURCE_SPACING, TARGET_SPACING)
    # torchvision resizes the last two dimensions; the first axis is left unchanged.
    volume = transforms.functional.resize(torch.from_numpy(volume), size=RESIZE_HW)
    # Every flattened vector must have the same length for the stacking/averaging below.
    if reference_shape is None:
        reference_shape = tuple(volume.shape)
    elif tuple(volume.shape) != reference_shape:
        raise ValueError(
            f"train_data[{i}] has shape {tuple(volume.shape)} after preprocessing, "
            f"expected {reference_shape}"
        )
    train_data[i] = volume.reshape(-1).numpy()

In [8]:
# Voxel-wise mean over all preprocessed volumes, folded back into a 3-D image.
# NOTE(review): np.asarray stacks every flattened volume into one (N, voxels)
# array in memory; a running float64 sum should give the same mean with far less RAM.
train_data = np.asarray(train_data)
vec_mean = train_data.mean(axis=0)
mat = vec_mean.reshape((512, 512, 388))

In [9]:
np.shape(mat)

(512, 512, 388)

In [10]:
print(mat.min(), mat.max())  # check the intensity range of the average image (observed: min ~ -1922, max ~ 428)

-1922.341463414634 428.7560975609756


In [11]:
# Open an interactive napari window and show the mean volume as an image layer.
V = napari.Viewer()
V.add_image(data=mat)

<Image layer 'mat' at 0x1e906d5a280>