In [4]:
import glob
import os

import nibabel as nib
import numpy as np

from core.postprocess_service import PostprocessService

from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

import graphviz

from IPython.display import IFrame, display

In [5]:
# load global vars
postproc_srv = PostprocessService()
basedir = "/home/ymerel/storage/private/ymerel/test"
mean_path = os.path.join(basedir, 'mean_result.nii')
df_path = os.path.join(basedir, 'dataset.csv')
ids = []
results = []
paths = glob.glob(os.path.join(basedir, '*/'), recursive=True)
for path in paths:
    ids.append(os.path.basename(os.path.dirname(path)))
    results.append(os.path.join(path, '_subject_id_01', 'result.nii'))

In [11]:
# Write mean image from result
mean_nifti_image = postproc_srv.get_mean_image(results, 10)
nib.save(mean_nifti_image, mean_path)

In [3]:
# Load dataset as dataframe and write it as CSV
dataframe = postproc_srv.get_dataframe(basedir, ids)
dataframe.to_csv(df_path, index=False, sep=';')

In [10]:
# Prepare dataset
X = dataframe.drop(columns=["id", "from_mean", "from_ref"])
y = dataframe["from_ref"]

test_size=0.9
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)

In [11]:
# Train model
reg = tree.DecisionTreeRegressor(max_depth=4)
reg.fit(X_train, y_train)

In [14]:
# Test model
def mean_absolute_percentage_error(y_true, y_pred):
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100
print(f"Mean absolute error (MAE) : {mean_absolute_error(y_test, reg.predict(X_test))}" )
print(f"Mean squared error (MSE) : {mean_squared_error(y_test, reg.predict(X_test))}" )
print(f"Mean absolute percentage error (MAPE) : {mean_absolute_percentage_error(y_test, reg.predict(X_test))}" )

Mean absolute error (MAE) : 0.12354707857755146
Mean squared error (MSE) : 0.025259883653750335
Mean absolute percentage error (MAPE) : 41.306364962874156


In [15]:
# Display tree
def print_tree(clf, f_names, name):
    
    dot_data = tree.export_graphviz(clf, out_file=None, 
                         feature_names=f_names,  
                         filled=True, rounded=True,
                         special_characters=True)  
    graph = graphviz.Source(dot_data)  
    graph.render(name)
    
print_tree(reg, X_train.columns.values, "tree")

filepath = "tree.pdf"
IFrame(filepath, width=700, height=500)