In [1]:
import os
import ants
import numpy as np
import pandas
from sklearn.decomposition import PCA

class PCASkullAverager:
    """Reconstruct skulls by projecting registered volumes onto a PCA basis.

    Training volumes listed in a CSV file (first column = image path) are
    flattened and used to fit a scikit-learn ``PCA``. Test volumes are then
    projected onto that basis, reconstructed, mapped back with
    ``ants.apply_transforms`` and written to disk together with an implant
    estimate.

    Typical use::

        averager = PCASkullAverager('./train.csv')
        averager.fit()
        averager.test_path = './test.csv'
        averager.load_test_imgs()
        averager.predict_test()
    """

    def __init__(self, train_path=None, test_path=None):
        """Create the averager and, when a training CSV is given, load it.

        Args:
            train_path: CSV file whose first column holds training image paths.
                When ``None``, loading is deferred until ``train_path`` is set
                and ``load_training_imgs()`` is called.
            test_path: CSV file whose first column holds test image paths. It is
                only read by ``load_test_imgs()``.
        """
        self.pca = PCA()
        self.train_path = train_path
        self.test_path = test_path
        # Image dimensions of the training volumes; None until they are loaded.
        self.im_dims = None
        if train_path is not None:
            self.im_dims = self.load_training_imgs()

    @staticmethod
    def _stack_volumes(ants_images):
        """Stack ANTs images into one array indexed by image along axis 0."""
        return np.array([image.numpy() for image in ants_images])

    @staticmethod
    def _flatten_volumes(volumes):
        """Flatten stacked volumes to the 2D (image, voxel) layout scikit-learn needs."""
        n_voxels = int(np.prod(volumes.shape[1:]))
        return volumes.reshape(volumes.shape[0], n_voxels)

    @staticmethod
    def _transform_paths(directory, index):
        """Return the [warp '.nii.gz', affine '.mat'] transform paths for one case.

        ``os.path.join`` is used so that an empty ``directory`` yields a
        relative path instead of one anchored at the filesystem root.
        """
        return [os.path.join(directory, index + '.nii.gz'),
                os.path.join(directory, index + '.mat')]

    def load_training_imgs(self):
        """Read the training CSV, load every listed image and flatten them.

        Sets ``df_train``, ``ants_train``, ``np_train_r`` (2D array, one row per
        image) and ``im_dims``.

        Returns:
            tuple: Dimensions of a single training image (eg 512, 512, 237).

        Raises:
            ValueError: If ``train_path`` has not been set.
        """
        if self.train_path is None:
            raise ValueError("train_path must be set before loading training images.")

        self.df_train = pandas.read_csv(self.train_path)
        self.ants_train = [ants.image_read(img_path) for img_path in self.df_train.iloc[:, 0]]
        volumes = self._stack_volumes(self.ants_train)
        self.im_dims = volumes.shape[1:]  # Image dimensions (eg 512, 512, 237)
        # scikit-learn only accepts 2D input (index, image), so flatten each volume.
        self.np_train_r = self._flatten_volumes(volumes)
        print("Training images: ", self.df_train.shape[0])
        return self.im_dims

    def load_test_imgs(self):
        """Read the test CSV, load every listed image and flatten them.

        Sets ``df_test``, ``ants_test`` and ``np_test_r`` (2D array, one row per
        image, each row sized to match the training images).

        Raises:
            ValueError: If ``test_path`` has not been set, if the training
                images have not been loaded yet, or if a test image does not
                have as many voxels as the training images.
        """
        if self.test_path is None:
            raise ValueError("test_path must be set before loading test images.")
        if self.im_dims is None:
            raise ValueError("Training images must be loaded before test images.")

        self.df_test = pandas.read_csv(self.test_path)
        self.ants_test = [ants.image_read(img_path) for img_path in self.df_test.iloc[:, 0]]
        volumes = self._stack_volumes(self.ants_test)
        # Rows must have the same length as the training rows the PCA was fitted on.
        self.np_test_r = volumes.reshape(volumes.shape[0], int(np.prod(self.im_dims)))
        print("Test images: ", self.df_test.shape[0])

    def fit(self):
        """Fit the PCA on the flattened training images."""
        self.pca.fit(self.np_train_r)

    def predict(self):
        """Convenience alias for ``predict_test()``."""
        return self.predict_test()

    def predict_test(self, j=None):
        """Reconstruct every loaded test image and save skull and implant volumes.

        For the test image at position ``i`` (0-based row of the test CSV) the
        following files are read, with ``index`` being ``i`` zero-padded to three
        digits and ``dir`` the directory of the test image:

        * ``dir/<index>.nii.gz`` and ``dir/<index>.mat`` - transforms passed to
          ``ants.apply_transforms``; they must already exist.
        * ``./OriginalSkull/sub<index>.nrrd`` - the original skull.

        Outputs are written to ``dir/pred_skull`` and ``dir/pred_implant``, with
        ``warped.nrrd`` in the test file name replaced by ``averaged_skull.nrrd``
        and ``averaged_implant.nrrd`` respectively.

        Args:
            j: Unused; kept only for backward compatibility with existing callers.
        """
        # PCA transform: coefficients of each test image in the projected space.
        coefficients = self.pca.transform(self.np_test_r)

        # Inverse PCA: go back to the original space from the projected one.
        reconstructed = self.pca.inverse_transform(coefficients)
        reconstructed = reconstructed.reshape((reconstructed.shape[0],) + tuple(self.im_dims))

        for i, pred in enumerate(reconstructed):
            # Give the reconstruction the same physical-space metadata as its test image.
            registered_pred = ants.from_numpy(pred)
            registered_pred.set_origin(self.ants_test[i].origin)
            registered_pred.set_direction(self.ants_test[i].direction)
            registered_pred.set_spacing(self.ants_test[i].spacing)

            path, name = os.path.split(self.df_test.iloc[i, 0])
            index = str(i).zfill(3)

            # The transformation matrices are produced in advance. See
            # "ants.apply_transforms" in the ANTs library documentation.
            transformer_path = self._transform_paths(path, index)
            original_skull_path = './OriginalSkull/sub' + index + '.nrrd'

            # Inverse registration: resample the reconstruction onto the original
            # skull, inverting only the second (.mat) transform.
            original_skull = ants.image_read(original_skull_path)
            averaged_skull = ants.apply_transforms(original_skull, registered_pred,
                                                   transformlist=transformer_path,
                                                   whichtoinvert=[False, True])

            # Save Skull
            output_skull_path = os.path.join(path, 'pred_skull')
            os.makedirs(output_skull_path, exist_ok=True)
            name_out = name.replace("warped.nrrd", "averaged_skull.nrrd")
            ants.image_write(averaged_skull, os.path.join(output_skull_path, name_out))  # Save Skull prediction
            print("  Saved: {}.".format(name_out))

            # Save Implant: voxels where the reconstruction exceeds the original
            # skull by more than 0.5.
            averaged_implant = (averaged_skull - original_skull) > 0.5
            output_implant_path = os.path.join(path, 'pred_implant')
            os.makedirs(output_implant_path, exist_ok=True)
            name_out = name.replace("warped.nrrd", "averaged_implant.nrrd")
            ants.image_write(averaged_implant, os.path.join(output_implant_path, name_out))  # Save Implant prediction

In [2]:
# CSV whose first column lists the training image files.
train_file = './train.csv'
# Constructing the averager with a training CSV loads the training images right away.
test = PCASkullAverager(train_path=train_file)

Training images:  25


In [3]:
# Fit the PCA on the flattened training images held by the averager.
test.fit()

In [8]:
test.test_path = './test.csv'  # Test files
test.load_test_imgs()
# Reconstruct every test skull and write the skull/implant predictions to disk.
# `predict_test` does not use its `j` argument, so None is passed.
test.predict_test(j=None)

Test images:  11
  Saved: 01_averaged_skull.nrrd.
  Saved: 02_averaged_skull.nrrd.
  Saved: 03_averaged_skull.nrrd.
  Saved: 04_averaged_skull.nrrd.
  Saved: 05_averaged_skull.nrrd.
  Saved: 06_averaged_skull.nrrd.
  Saved: 07_averaged_skull.nrrd.
  Saved: 08_averaged_skull.nrrd.
  Saved: 09_averaged_skull.nrrd.
  Saved: 10_averaged_skull.nrrd.
  Saved: 11_averaged_skull.nrrd.
