In [3]:
import nibabel as nib
%matplotlib inline

import SimpleITK as sitk

import numpy as np 
import pandas as pd 


import os
import scipy.ndimage
import matplotlib.pyplot as plt

%matplotlib inline

In [4]:
# Open case 1's CT volume and its segmentation with nibabel for a first look.
img, img_seg = (
    nib.load('/data/datasets/test_task_cmai/Brain_CT_labeling/1.nii.gz'),
    nib.load('/data/datasets/test_task_cmai/Brain_CT_labeling/1-seg.nii.gz'),
)

In [5]:
directory = '/data/datasets/test_task_cmai/Brain_CT_labeling/'


def case_id(filename):
    """Return the case identifier prefix of a filename.

    e.g. '1-seg.nii.gz' -> '1', '5.nii' -> '5'.
    """
    return filename.split('.', 1)[0].split('-', 1)[0]


def _case_sort_key(filename):
    # Numeric ids sort numerically ('2' before '10'); non-numeric ids sort after them, by name.
    cid = case_id(filename)
    return (int(cid), filename) if cid.isdigit() else (float('inf'), filename)


segs = []
imgs = []

for filename in os.listdir(directory):
    if not filename.endswith((".gz", ".nii")):
        continue
    if "seg" in filename:
        segs.append(filename)
    else:
        imgs.append(filename)

# os.listdir returns entries in arbitrary order, while Dataset pairs imgs[i] with segs[i].
# Sort both by case id so each image is paired with its own segmentation.
imgs.sort(key=_case_sort_key)
segs.sort(key=_case_sort_key)
assert [case_id(f) for f in imgs] == [case_id(f) for f in segs], (
    "image and segmentation files do not pair up one-to-one by case id")

In [6]:
# Display the image filenames in sorted order (a new list; `imgs` itself is not reordered here).
sorted(name for name in imgs)

['1.nii.gz',
 '2.nii.gz',
 '3.nii.gz',
 '4.nii.gz',
 '5.nii',
 '6.nii.gz',
 '7.nii.gz']

In [7]:
class Dataset:
    """Index-aligned pairs of image and segmentation files loaded with nibabel.

    Item ``idx`` pairs ``images[idx]`` with ``segms[idx]``, so the caller must
    order both lists consistently.
    """

    def __init__(self, path, images, segms):
        # path: directory containing the files; images / segms: filenames inside it.
        self.path = path
        self.images = images
        self.segms = segms

    def __getitem__(self, idx):
        """Load pair ``idx`` and return its voxel data.

        Returns:
            dict with keys "image" and "seg", each the ``get_fdata()`` array
            of the corresponding file.

        Raises:
            IndexError: if ``idx >= len(self)``. This also ends ``for`` loops
            over the dataset via the sequence-iteration protocol.
        """
        if idx >= len(self.images):
            raise IndexError("list index out of range")

        # os.path.join works whether or not `path` ends with a separator.
        img = nib.load(os.path.join(self.path, self.images[idx])).get_fdata()
        img_seg = nib.load(os.path.join(self.path, self.segms[idx])).get_fdata()
        return {"image": img, "seg": img_seg}

    def __len__(self):
        """Number of pairs, i.e. the number of image filenames held by this dataset."""
        return len(self.images)

In [8]:
# Wrap the listed image/segmentation filenames in an indexable dataset.
dataset = Dataset(path=directory, images=imgs, segms=segs)

In [27]:
# Inspect which label values occur in the segmentation of dataset item 1.
seg_volume = dataset[1]["seg"]
np.unique(seg_volume)

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20.])

In [14]:
 def read_transform_txt(path):

    transform = None
    with open( path, 'r' ) as f:
        for line in f:

            if line.startswith( 'Parameters:' ):
                values = line.split( ': ' )[1].split( ' ' )
                values = [float( e ) for e in values if ( e != '' and e != '\n' )]
                transform_upper_left = np.reshape( values[0:9], ( 3, 3 ) )
                translation = values[9:]
                

            if line.startswith( 'FixedParameters:' ):
                values = line.split( ': ' )[1].split( ' ' )
                values = [float( e ) for e in values if ( e != '' and e != '\n' )]
                center = values


    offset = np.ones( 4 )
    for i in range( 0, 3 ):

        offset[i] = translation[i] + center[i];

        for j in range( 0, 3 ):

            offset[i] -= transform_upper_left[i][j] * center[i]

    transform = np.vstack((transform_upper_left, [0, 0, 0]))
    transform = np.hstack((transform, np.reshape( offset, (4, 1))))

    return transform

In [15]:
# Parse the transform for case 1 into a 4x4 homogeneous matrix.
transform = read_transform_txt(
    path='/data/datasets/test_task_cmai/Brain_CT_labeling/1.txt',
)

In [16]:
# Sanity check: the parsed transform should be a 4x4 matrix.
np.shape(transform)

(4, 4)

In [17]:
# Resample case 1 through the affine parsed from 1.txt.
# NOTE(review): sitk.Resample maps each output point through the transform to locate
# the input point to sample; whether 1.txt stores that direction or its inverse
# depends on how it was produced — confirm.
ct_image = sitk.ReadImage("/data/datasets/test_task_cmai/Brain_CT_labeling/1.nii.gz")

# Split the 4x4 homogeneous matrix into its 3x3 linear part and translation column.
rot_matrix = transform[:3, :-1]
offset = transform[:3, -1:]

# AffineTransform accepts any 3x3 matrix, whereas Euler3DTransform.SetMatrix rejects
# matrices that are not orthogonal to a tight tolerance (any scale/shear, or float
# drift in a file written by an affine registration). With the default centre at the
# origin, the translation equals the offset computed by read_transform_txt.
affine_transform = sitk.AffineTransform(3)
affine_transform.SetMatrix(rot_matrix.ravel().tolist())
affine_transform.SetTranslation(offset.ravel().tolist())
e3d_transform = affine_transform  # previous name, kept for any later cell that uses it

# NOTE(review): nearest-neighbour keeps voxel values exact (appropriate for label maps);
# for CT intensities sitk.sitkLinear is the usual choice — confirm intent.
image = sitk.Resample(ct_image, affine_transform, sitk.sitkNearestNeighbor)

# GetArrayFromImage returns axes in reversed index order (k, j, i); moving the first
# axis to the end gives (j, i, k), so image_data[:, :, k] is slice k with rows = j, cols = i.
image_data = np.moveaxis(sitk.GetArrayFromImage(image), 0, -1)
