In [None]:
import os

import nibabel as nib
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.ensemble import AdaBoostClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, recall_score, precision_score
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.model_selection import KFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC


def extract_features(image_path, num_bins=100, value_range=None):
    """Return the normalized intensity histogram of a NIfTI image.

    Args:
        image_path: path to an image readable by ``nibabel.load``.
        num_bins: number of histogram bins (forwarded to ``np.histogram``).
        value_range: optional ``(low, high)`` intensity range forwarded to
            ``np.histogram``. With the default ``None`` the range is each
            image's own (min, max), so bin ``i`` covers a different intensity
            interval for every image. Pass a fixed range to make histograms
            comparable across images. Values outside the range are ignored.

    Returns:
        1-D float array of bin fractions summing to 1. It is all zeros when
        no voxel falls inside the histogram range.
    """
    image_data = nib.load(image_path).get_fdata()
    hist, _ = np.histogram(image_data.ravel(), bins=num_bins, range=value_range)

    total = hist.sum()
    if total == 0:
        # Avoid 0/0 -> NaN when no voxel lies inside the histogram range.
        return np.zeros_like(hist, dtype=float)
    return hist / total

def extract_stat_features(image_path):
    """Return summary intensity statistics of a NIfTI image.

    All statistics are computed over every voxel of the image:
    mean, median, maximum, standard deviation and variance (population,
    ddof=0), skewness and kurtosis (``scipy.stats`` defaults, i.e. Fisher /
    excess kurtosis).

    Args:
        image_path: path to an image readable by ``nibabel.load``.

    Returns:
        A list containing ONE row:
        ``[[mean, median, max, std, var, skew, kurtosis]]``.
        The single-row nesting is kept for compatibility. Flatten with
        ``np.ravel`` before concatenating with other features.
    """
    flat_data = nib.load(image_path).get_fdata().ravel()

    row = [
        np.mean(flat_data),
        np.median(flat_data),
        np.max(flat_data),
        np.std(flat_data),
        np.var(flat_data),
        stats.skew(flat_data, axis=None),
        stats.kurtosis(flat_data, axis=None),
    ]
    return [row]

# Input locations and histogram resolution.
input_dir = 'ATR_smoth'
excel_file = 'ATR_training.xlsx'
num_bins = 50

# Read labels.
# NOTE(review): spreadsheet rows are assumed to be in the same order as the
# sorted .nii.gz filenames below -- confirm against the data.
labels_df = pd.read_excel(excel_file)
labels = labels_df['label'].values

# Build one feature row per image: the normalized intensity histogram followed
# by the summary statistics.
feature_list = []
file_list = sorted(os.listdir(input_dir))
for filename in file_list:
    if not filename.endswith('.nii.gz'):
        continue
    input_path = os.path.join(input_dir, filename)
    hist_features = extract_features(input_path, num_bins)
    stat_features = extract_stat_features(input_path)
    # np.ravel flattens the statistics whether they come back as a flat
    # sequence or as a single nested row.
    feature_list.append(np.concatenate([hist_features, np.ravel(stat_features)]))

# Convert the list of rows to an (n_images, n_features) array.
X = np.array(feature_list)

# Labels.
y = np.array(labels)

if len(X) != len(y):
    raise ValueError(
        'Found {} .nii.gz images in {!r} but {} labels in {!r}'.format(
            len(X), input_dir, len(y), excel_file))

# Standardize the features.
# Fitting the scaler on ALL samples leaks test-fold statistics if X_scaled is
# then cross-validated; prefer scaling inside a per-fold pipeline on X.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Cross-validate one classifier and print its accuracy summary.
def cross_val(clf, X, y, clf_string, cv=5):
    """Run cross-validation for *clf* and print mean/std of the fold scores.

    Args:
        clf: scikit-learn estimator or pipeline. ``cross_val_score`` clones it
            for every fold.
        X: feature matrix.
        y: target labels.
        clf_string: display name printed in the report.
        cv: number of folds or a CV splitter, forwarded to ``cross_val_score``.
            For a classifier, an integer means stratified k-fold.

    Returns:
        The array of per-fold scores. This is the estimator's default scorer,
        i.e. accuracy for classifiers.
    """
    scores = cross_val_score(clf, X, y, cv=cv)
    print('Clf: {}\nAccuracy Mean: {:0.2f}\nStandard Deviation: {:0.2f}'.format(
        clf_string, scores.mean(), scores.std()))
    return scores


# Compare several classifiers under the same cross-validation protocol.
clfs = []
svm = SVC(kernel='linear', C=5)
clfs.append([svm, 'Support Vector Machine'])
# lbfgs already uses the multinomial loss for multi-class targets; the explicit
# multi_class argument is deprecated in recent scikit-learn releases.
lr = LogisticRegression(random_state=0, solver='lbfgs')
clfs.append([lr, 'Logistic Regression'])
ada = AdaBoostClassifier(n_estimators=100)
clfs.append([ada, 'AdaBoost'])
knn = KNeighborsClassifier(n_neighbors=3)
clfs.append([knn, 'K-Neighbors'])

# Scale inside each fold (via a pipeline on the raw features) so the test
# fold's statistics never leak into the scaler.
for clf, clf_str in clfs:
    cross_val(make_pipeline(StandardScaler(), clf), X, y, clf_str)

# Overall classification accuracy.
def overall_acc(y_true, y_pred):
    """Return the fraction of predictions that equal the ground truth."""
    fraction_correct = accuracy_score(y_true=y_true, y_pred=y_pred)
    return fraction_correct

# Probability of detection: #targets detected / #targets.
def PD(y_true, y_pred):
    """Return the detection rate (recall of the positive class, label 1/True)."""
    detection_rate = recall_score(y_true=y_true, y_pred=y_pred)
    return detection_rate

# Probability of false alarm: #false alarms / #non-targets.
def PFA(y_true, y_pred):
    """Return the false-alarm rate, #false alarms / #non-targets.

    A false alarm is a sample whose true value is "no target" (falsy) but whose
    prediction is "target" (truthy). Both inputs are treated as binary.

    Args:
        y_true: ground-truth target indicators (bool or 0/1), array-like.
        y_pred: predicted target indicators (bool or 0/1), same shape.

    Returns:
        Float in [0, 1]. It is 0.0 when there are no non-targets, since no
        false alarm is then possible.

    Raises:
        ValueError: if the two inputs have different shapes.
    """
    truth = np.asarray(y_true).astype(bool)
    pred = np.asarray(y_pred).astype(bool)
    if truth.shape != pred.shape:
        raise ValueError('y_true and y_pred must have the same shape')

    non_targets = ~truth
    n_non_targets = np.count_nonzero(non_targets)
    if n_non_targets == 0:
        return 0.0
    false_alarms = np.count_nonzero(pred & non_targets)
    return false_alarms / n_non_targets

def results(X, y, clf, n_splits=5):
    """Evaluate *clf* as a target / no-target detector with k-fold CV and print metrics.

    The classifier is trained on the binary target ``y != 0`` (any non-zero
    label counts as a target). The function reports overall accuracy,
    detection rate and false-alarm rate. It also reports the detection rate
    for labels 1, 2 and 3, printed as saline, rubber and clay.

    Args:
        X: feature matrix indexable by integer arrays (e.g. a NumPy array).
        y: class labels aligned with the rows of X.
        clf: scikit-learn estimator or pipeline; refit on every fold.
        n_splits: number of cross-validation folds.
    """
    y = np.asarray(y)
    # Stratify on the multi-class labels so each fold contains every class when
    # possible, instead of contiguous blocks of the filename-sorted data.
    skf = StratifiedKFold(n_splits=n_splits)
    acc = []
    per_det_0 = []
    per_false_alarm = []
    per_class_det = {1: [], 2: [], 3: []}

    for train_i, test_i in skf.split(X, y):
        X_train, X_test = X[train_i], X[test_i]
        y_train, y_test = y[train_i], y[test_i]
        target_test = y_test != 0

        predict = clf.fit(X_train, y_train != 0).predict(X_test)

        acc.append(overall_acc(target_test, predict))
        per_false_alarm.append(PFA(target_test, predict))
        # Recall is undefined for a fold without positives; skip such folds
        # instead of recording a spurious 0.
        if target_test.any():
            per_det_0.append(PD(target_test, predict))
        for label, rates in per_class_det.items():
            is_label = y_test == label
            if is_label.any():
                rates.append(PD(is_label, predict))

    print('Accuracy: ', np.mean(acc))
    print('Percent Detected: ', np.mean(per_det_0))
    print('Percent False Alarm: ', np.mean(per_false_alarm))
    print('Percent Saline Detected: ', np.mean(per_class_det[1]))
    print('Percent Rubber Detected: ', np.mean(per_class_det[2]))
    print('Percent Clay Detected: ', np.mean(per_class_det[3]))


# Detection metrics for every classifier. Scaling lives inside the pipeline so
# it is refit on each training fold only.
for clf, clf_str in clfs:
    print('Clf: {}'.format(clf_str))
    results(X, y, make_pipeline(StandardScaler(), clf))
    