In [None]:
import os
import random
import numpy as np
import struct
from scipy import signal
from scipy import ndimage
from random import randrange

from sklearn import metrics
from sklearn import svm
from sklearn import preprocessing
from sklearn.kernel_approximation import *
from sklearn.preprocessing import normalize
from sklearn.metrics import f1_score
%run Utility_general.ipynb

class Utility_MEDICAL(object):
    '''
    Static helpers for 3D medical-volume processing and simple SVM evaluation.

    Every method is a @staticmethod; the class is only used as a namespace.
    Several methods use names this file does not import itself
    (FileIO_MEDICAL, Util_gen, Compute_CycleKernel); they are presumably
    provided by the "%run Utility_general.ipynb" cell -- TODO confirm.
    '''

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _list_files(folder, pattern):
        '''
        Return the sorted base names of the entries in folder matching the glob pattern.

        Sorting makes the order deterministic (glob itself returns file-system
        order) which matters because callers pair files across folders, and
        with label lists, by position. The working directory is not changed.
        '''
        import glob  # local import: glob is not imported at the top of this file

        matches = glob.glob(os.path.join(glob.escape(folder), pattern))
        return sorted(os.path.basename(match) for match in matches)

    @staticmethod
    def _paired_files(idx_arr, folder, pattern):
        '''
        Return full paths of the files in folder, paired position-wise with idx_arr.

        The subject index of a file is the second "_"-separated token of its
        name and must equal the idx_arr entry at the same (sorted) position.
        Raises ValueError if the file count or any index does not match.
        '''
        names = Utility_MEDICAL._list_files(folder, pattern)
        if len(names) != len(idx_arr):
            raise ValueError("expected %d files matching %r in %s, found %d"
                             % (len(idx_arr), pattern, folder, len(names)))
        paths = []
        for expected, name in zip(idx_arr, names):
            idx_ = name.split('_')[1]
            if idx_ != expected:
                raise ValueError("file %s has index %s, expected %s" % (name, idx_, expected))
            paths.append(os.path.join(folder, name))
        return paths

    @staticmethod
    def _centered_box(extents, unified_dim):
        '''
        Return [s0, e0, s1, e1, s2, e2] placing a box of the given extents in the
        centre of a unified_dim^3 cube (s = floor((unified_dim - extent) / 2),
        e = s + extent - 1, both inclusive).
        Raises ValueError when an extent exceeds unified_dim, since the box
        would then start at a negative index.
        '''
        box = []
        for extent in extents:
            if extent > unified_dim:
                raise ValueError("extent %d exceeds unified_dim %d" % (extent, unified_dim))
            start = (unified_dim - extent) // 2
            box.extend([start, start + extent - 1])
        return box

    @staticmethod
    def _bin_by_distance(shape, coords, dist, dist_layout):
        '''
        Split voxels into binary masks by half-open distance intervals.

        Voxel n (at coords[n]) is written to layer k when
        dist_layout[k] <= dist[n] < dist_layout[k+1]. dist_layout must be
        non-decreasing, which makes that k unique; voxels outside every
        interval are dropped.
        @shape: shape of every output mask
        @coords: (N, 3) integer voxel coordinates
        @dist: (N,) distance of each voxel
        @dist_layout: (layer_num + 1,) interval bounds
        Returns a list of layer_num int32 masks.
        '''
        layer_num = len(dist_layout) - 1
        layers = [np.zeros(shape, dtype=np.int32) for _ in range(layer_num)]
        if layer_num <= 0 or len(dist) == 0:
            return layers
        # (number of bounds <= d) - 1 is the index of the interval containing d.
        bins = np.searchsorted(dist_layout, dist, side='right') - 1
        for k in range(layer_num):
            selected = coords[bins == k]
            layers[k][tuple(selected.T)] = 1
        return layers

    # ------------------------------------------------------------------
    # Mask / volume utilities
    # ------------------------------------------------------------------

    @staticmethod
    def dilate_binary_mask(data, iterations):
        '''
        Binary-dilate data (cast to int32 first) and return an int32 0/1 array.
        @iterations: passed to scipy.ndimage.binary_dilation; note SciPy repeats
            the dilation until nothing changes when iterations < 1.
        '''
        mask = np.asarray(data).astype('int32')
        return ndimage.binary_dilation(mask, iterations=iterations).astype(mask.dtype)

    @staticmethod
    def erode_binary_mask(data, iterations):
        '''
        Binary-erode data (cast to int32 first) and return an int32 0/1 array.
        @iterations: passed to scipy.ndimage.binary_erosion; note SciPy repeats
            the erosion until nothing changes when iterations < 1, which
            usually empties the mask.
        '''
        mask = np.asarray(data).astype('int32')
        return ndimage.binary_erosion(mask, iterations=iterations).astype(mask.dtype)

    @staticmethod
    def normalize(data, scale):
        '''
        Scale every slice along axis 0 of data to L2 norm scale.

        Floating-point input is modified in place and returned; integer input
        is returned as a new floating-point array (writing fractions back
        into it would truncate them). Slices whose norm is zero stay zero
        instead of becoming NaN.
        @data: array with at least one dimension (rows along axis 0)
        @scale: target norm of every row
        '''
        data = np.asarray(data)
        num = data.shape[0]
        if num == 0:
            return data
        norms = np.linalg.norm(data.reshape(num, -1), axis=1)
        norms[norms == 0] = 1.0  # leave all-zero rows untouched
        factors = (scale / norms).reshape((num,) + (1,) * (data.ndim - 1))
        if np.issubdtype(data.dtype, np.floating):
            data *= factors
            return data
        return data * factors

    @staticmethod
    def shrink_box(data):
        '''
        Return the bounding box of the mask region of a 3D volume.

        A slice counts as occupied when the sum of its values (after casting to
        int32) is positive.
        Returns [left, right, front, back, top, bottom] = first/last occupied
        index along axis 0, then axis 1, then axis 2 (all inclusive).
        Raises ValueError if data is not 3D or no slice is occupied.
        '''
        mask = np.asarray(data).astype('int32')
        if mask.ndim != 3:
            raise ValueError("shrink_box requires 3D data, got %d dimensions" % mask.ndim)
        box = []
        for axis in range(3):
            other_axes = tuple(a for a in range(3) if a != axis)
            occupied = np.flatnonzero(mask.sum(axis=other_axes) > 0)
            if occupied.size == 0:
                raise ValueError("shrink_box: the mask is empty")
            box.extend([int(occupied[0]), int(occupied[-1])])
        return box

    @staticmethod
    def divide_layers_around_core(data, layer_num):
        '''
        Split a binary 3D segmentation into layer_num shells around a core voxel.

        The core is the mask voxel with the largest axis-0 index on the line
        (:, shape[1]//2, shape[2]//2). Each mask voxel is binned by its
        *squared* Euclidean distance to the core: [0, max] is cut into
        layer_num equal intervals, the last widened by 1 so the farthest
        voxel is included.
        @data: input segmentation, should be binary [0, 1]
        @layer_num: number of layers
        Returns a list of layer_num int32 masks with the shape of data.
        Raises ValueError if the central line does not intersect the mask.
        '''
        mask = np.asarray(data).astype('int32')
        shape = mask.shape
        c1 = int(shape[1] / 2)
        c2 = int(shape[2] / 2)
        hits = np.flatnonzero(mask[:, c1, c2] == 1)
        if hits.size == 0:
            raise ValueError("divide_layers_around_core: the central line does not hit the mask")
        core = np.array([hits[-1], c1, c2])

        coords = np.argwhere(mask == 1)
        dist = np.sum((coords - core) ** 2, axis=1).astype(np.float64)
        max_dist = np.amax(dist)
        dist_layout = np.zeros(layer_num + 1)
        for i in range(layer_num):
            dist_layout[i + 1] = max_dist / layer_num * (i + 1)
        dist_layout[layer_num] = dist_layout[layer_num] + 1
        return Utility_MEDICAL._bin_by_distance(shape, coords, dist, dist_layout)

    @staticmethod
    def divide_layers_against_plane(data, plane, layer_num):
        '''
        Split a binary 3D segmentation into layer_num layers by distance to a point.

        The distance is the L1 (Manhattan) distance from each mask voxel to the
        point plane. Interval bounds are 0, then min + (max - min) * k / layer_num
        for k = 1..layer_num, the last widened by 1 so the farthest voxel is
        included.
        @data: input segmentation, should be binary [0, 1]
        @plane: reference point, e.g. [0, 0, 0], [shape[0]-1, 0, 0], [0, shape[1]-1, 0]
        @layer_num: number of layers
        Returns a list of layer_num int32 masks with the shape of data.
        Raises ValueError if the mask is empty.
        '''
        mask = np.asarray(data).astype('int32')
        shape = mask.shape
        coords = np.argwhere(mask == 1)
        if coords.shape[0] == 0:
            raise ValueError("divide_layers_against_plane: the mask is empty")
        dist = np.abs(coords - np.asarray(plane)).sum(axis=1).astype(np.float64)
        max_dist = np.amax(dist)
        min_dist = np.amin(dist)
        dist_layout = np.zeros(layer_num + 1)
        for i in range(layer_num):
            dist_layout[i + 1] = (max_dist - min_dist) / layer_num * (i + 1) + min_dist
        dist_layout[layer_num] = dist_layout[layer_num] + 1
        return Utility_MEDICAL._bin_by_distance(shape, coords, dist, dist_layout)

    @staticmethod
    def draw_on_volume(data, shape, dim, items):
        '''
        Rasterize selected structures into a binary int32 volume.
        @data: output from function read_bnd_red_unifieddim; data[dim][item] is
            indexed as a (rows, n) array whose rows 2, 1, 0 give the indices
            along volume axes 0, 1, 2
        @shape: shape of the original volume
        @dim: dim to draw, can be 0, 1, or 2
        @items: a non-empty list of indices indicating which structures to draw
        '''
        curtain = np.zeros(shape, dtype=np.int32)
        assert(dim < len(data))
        assert(len(items) > 0)
        for item in items:
            assert(item < len(data[dim]))
            points = data[dim][item]
            if points.shape[1] == 0:
                continue
            curtain[points[2], points[1], points[0]] = 1
        return curtain

    @staticmethod
    def draw_on_volume_ori_intensity(data, vol, seg, tum, erode_iter, kernel_size, dim, items,
                                     val_range, sigma=1):
        '''
        Rasterize selected structures that lie inside seg and outside tum.
        @data: output from function read_bnd_red_unifieddim (same layout as in draw_on_volume)
        @vol: volume with original intensity
        @seg: binary segmentation; eroded by erode_iter iterations first
            (no erosion when erode_iter <= 0)
        @tum: mask of voxels to exclude (voxels where tum != 0 are not drawn)
        @kernel_size: if > 0, the drawing is smoothed with a 3D gaussian kernel and
            voxels whose smoothed value exceeds 1 take vol's intensity; otherwise
            drawn voxels get uniform random values in [val_range[0], val_range[1])
        @dim: dim to draw, can be 0, 1, or 2
        @items: a non-empty list of indices indicating which structures to draw
        @sigma: sigma of the gaussian kernel (previously hard-coded to 1)
        All other voxels are set to min(vol).
        '''
        assert(seg.shape == vol.shape)
        assert(dim < len(data))
        assert(len(items) > 0)
        curtain = np.zeros(vol.shape, dtype=np.float32)
        # With iterations < 1 SciPy erodes until convergence, which empties the mask.
        if erode_iter > 0:
            seg = Utility_MEDICAL.erode_binary_mask(seg, erode_iter)
        for item in items:
            assert(item < len(data[dim]))
            points = data[dim][item]
            if points.shape[1] == 0:
                continue
            x, y, z = points[2], points[1], points[0]
            keep = (seg[x, y, z] == 1) & (tum[x, y, z] == 0)
            curtain[x[keep], y[keep], z[keep]] = 1

        new_curtain = np.ones(vol.shape, dtype=np.float32) * np.amin(vol)
        if kernel_size > 0:
            kernel = Util_gen.generate_gaussian_kernel(kernel_size, sigma, 3)
            curtain = signal.convolve(curtain, kernel, mode='same')
            # NOTE(review): whether smoothed values can exceed 1 depends on the
            # normalization of generate_gaussian_kernel -- TODO confirm.
            selected = curtain > 1
            new_curtain[selected] = vol[selected]
        else:
            selected = curtain > 0
            # A full-size draw is kept so seeded runs reproduce earlier outputs.
            new_curtain[selected] = np.random.uniform(low=val_range[0], high=val_range[1],
                                                      size=vol.shape)[selected]
        return new_curtain

    # ------------------------------------------------------------------
    # Batch file processing
    # ------------------------------------------------------------------

    @staticmethod
    def align_tumor_segmentations(pathVol, pathSeg, pathOut, ext, unified_dim=256):
        '''
        Crop every tumor to its bounding box and centre it in a unified_dim^3 cube.

        Volume voxels outside the segmentation are set to min(volume) - 1, and
        the same value fills the rest of the output cube. Files are paired
        across folders by sorted filename; the second "_"-separated token
        (subject index) must match. The per-axis maximum tumor extent is
        printed. The working directory is not changed.
        @pathVol: path to volume folder example: "E:/Data2/BreastMass_refine/volumes"
        @pathSeg: path to tumor segmentation folder example: "E:/Data2/BreastMass_refine/tumors_mask"
        @pathOut: path to output tumor mask folder example: "E:/Data2/BreastMass_refine/tumors"
        @ext: extension of the target files example: "nii"
        @unified_dim: edge length of the output cube
        '''
        volNames = Utility_MEDICAL._list_files(pathVol, "*." + ext)
        idx_arr = [name.split('_')[1] for name in volNames]
        filesVol = [os.path.join(pathVol, name) for name in volNames]
        filesOut = [os.path.join(pathOut, name.split('_')[0] + '_' + name.split('_')[1] + '_tumor.nii')
                    for name in volNames]
        filesSeg = Utility_MEDICAL._paired_files(idx_arr, pathSeg, "*." + ext)

        srk_shapes = [Utility_MEDICAL.shrink_box(FileIO_MEDICAL.load_nii(path)) for path in filesSeg]
        dim0_arr = np.array([s[1] - s[0] + 1 for s in srk_shapes], dtype=np.int32)
        dim1_arr = np.array([s[3] - s[2] + 1 for s in srk_shapes], dtype=np.int32)
        dim2_arr = np.array([s[5] - s[4] + 1 for s in srk_shapes], dtype=np.int32)
        print(np.amax(dim0_arr), np.amax(dim1_arr), np.amax(dim2_arr))

        os.makedirs(pathOut, exist_ok=True)
        for i in range(len(filesSeg)):
            srk_shape = srk_shapes[i]
            rmd_shape = Utility_MEDICAL._centered_box((dim0_arr[i], dim1_arr[i], dim2_arr[i]), unified_dim)
            seg = FileIO_MEDICAL.load_nii(filesSeg[i])
            vol = FileIO_MEDICAL.load_nii(filesVol[i])
            min_val = np.amin(vol) - 1
            vol[seg == 0] = min_val
            tum = np.ones((unified_dim, unified_dim, unified_dim), dtype=vol.dtype) * min_val
            tum[rmd_shape[0]:rmd_shape[1] + 1, rmd_shape[2]:rmd_shape[3] + 1, rmd_shape[4]:rmd_shape[5] + 1] = \
                vol[srk_shape[0]:srk_shape[1] + 1, srk_shape[2]:srk_shape[3] + 1, srk_shape[4]:srk_shape[5] + 1]
            FileIO_MEDICAL.save_nii(tum, filesOut[i])

    @staticmethod
    def generate_tda_volumes(pathTDA, pathOut, threshold, target_dim, dil_iter, labels,
                             fromVol, kernel_size, sigma, patchMode, patch_shape=(0, 0, 0),
                             unified_dim=256):
        '''
        This function draws extracted topological features on empty volumes.
        The output volumes have the same spatial size.
        Input files (*.dat, *.bnd, *.pers.txt) are paired by sorted filename,
        and labels must follow that same sorted *.dat order.
        @pathTDA: path to TDA results folder example: "E:/Data2/BreastMass_refine/sup"
        @pathOut: path to output tumor mask folder example: "E:/Data2/BreastMass_refine/tumors"
        @threshold: threshold to filter tda structures
        @target_dim: currently unused -- structures of dimensions 1 and 2 are always drawn and merged
        @dil_iter: iterations to dilate mask, int (no dilation when <= 0)
        @labels: one int label per *.dat file
        @fromVol: if the intensity of the tda structures are actual volume intensity, bool
        @kernel_size: kernel size of the 3D gaussian kernel (no smoothing when <= 0)
        @sigma: sigma of the 3D gaussian kernel
        @patchMode: if to sample patch, bool
        @patch_shape: if patchMode is true, define the shape of the sample (all entries > 0)
        @unified_dim: edge length of the output cube when patchMode is false
        '''

        def probe_dimensions(files):
            # Read every volume once to learn its shape; prints the per-axis maximum.
            shapes = np.array([FileIO_MEDICAL.read_dat(path).shape[:3] for path in files],
                              dtype=np.int32).reshape(-1, 3)
            dim0_arr, dim1_arr, dim2_arr = shapes[:, 0], shapes[:, 1], shapes[:, 2]
            print(np.amax(dim0_arr), np.amax(dim1_arr), np.amax(dim2_arr))
            return dim0_arr, dim1_arr, dim2_arr

        def sample_patch(data, sample_shape):
            # Draw floor(volume size / patch size) random patches; [] if the patch does not fit.
            shape = data.shape
            assert(len(shape) == 3)
            assert(len(sample_shape) == 3)
            if any(s <= 0 for s in sample_shape):
                raise ValueError("patch_shape entries must be positive, got %r" % (sample_shape,))
            if any(s > d for s, d in zip(sample_shape, shape)):
                return []
            num = int(np.floor(np.prod(shape) / np.prod(sample_shape)))
            patches = []
            for _ in range(num):
                s0 = randrange(shape[0] - sample_shape[0] + 1)
                s1 = randrange(shape[1] - sample_shape[1] + 1)
                s2 = randrange(shape[2] - sample_shape[2] + 1)
                patches.append(data[s0:s0 + sample_shape[0], s1:s1 + sample_shape[1], s2:s2 + sample_shape[2]])
            return patches

        def draw_dim(bnd_filt, vol_shape, dim):
            # draw_on_volume asserts a non-empty item list; an empty dimension draws nothing.
            if len(bnd_filt[dim]) == 0:
                return np.zeros(vol_shape, dtype=np.int32)
            return Utility_MEDICAL.draw_on_volume(bnd_filt, vol_shape, dim, np.arange(0, len(bnd_filt[dim])))

        datNames = Utility_MEDICAL._list_files(pathTDA, "*.dat")
        filesVol = [os.path.join(pathTDA, name) for name in datNames]
        filesBnd = [os.path.join(pathTDA, name) for name in Utility_MEDICAL._list_files(pathTDA, "*.bnd")]
        filesPers = [os.path.join(pathTDA, name) for name in Utility_MEDICAL._list_files(pathTDA, "*.pers.txt")]
        num = len(filesVol)
        assert(num == len(filesBnd))
        assert(num == len(filesPers))
        assert(num == len(labels))
        filesOut = [os.path.join(pathOut, name.split('_')[0] + '_' + name.split('_')[1] + '_'
                                 + str(label) + '_TDA.nii')
                    for name, label in zip(datNames, labels)]

        dim0_arr, dim1_arr, dim2_arr = probe_dimensions(filesVol)
        if kernel_size > 0:
            kernel = Util_gen.generate_gaussian_kernel(kernel_size, sigma, 3)

        patch_cnt = 0
        os.makedirs(pathOut, exist_ok=True)
        for i in range(num):
            bnd = FileIO_MEDICAL.read_bnd_red_unifieddim(filesBnd[i])
            pers = FileIO_MEDICAL.read_pers_txt(filesPers[i])
            bnd_filt = Compute_CycleKernel.filter_bnd_or_red(bnd, pers, threshold)
            vol_shape = (dim0_arr[i], dim1_arr[i], dim2_arr[i])
            bnd_drw = draw_dim(bnd_filt, vol_shape, 1) | draw_dim(bnd_filt, vol_shape, 2)

            if dil_iter > 0:
                bnd_drw = Utility_MEDICAL.dilate_binary_mask(bnd_drw, dil_iter)
            if fromVol:
                # Negated intensities inside the drawing; min value everywhere else.
                vol = -FileIO_MEDICAL.read_dat(filesVol[i])
                vol[bnd_drw == 0] = np.amin(vol)
                bnd_drw = vol
            if kernel_size > 0:
                bnd_drw = signal.convolve(bnd_drw, kernel, mode='same')
            if patchMode:
                for patch in sample_patch(bnd_drw, patch_shape):
                    patch_path = os.path.join(pathOut, 'ISPY_' + str(patch_cnt) + '_' + str(labels[i]) + '_TDA.nii')
                    patch_cnt = patch_cnt + 1
                    FileIO_MEDICAL.save_nii(patch, patch_path)
            else:
                rmd_shape = Utility_MEDICAL._centered_box(bnd_drw.shape, unified_dim)
                tum = np.ones((unified_dim, unified_dim, unified_dim), dtype=np.float32) * np.amin(bnd_drw)
                tum[rmd_shape[0]:rmd_shape[1] + 1, rmd_shape[2]:rmd_shape[3] + 1, rmd_shape[4]:rmd_shape[5] + 1] = bnd_drw
                FileIO_MEDICAL.save_nii(tum, filesOut[i])

    @staticmethod
    def mass_srk_volume_segmentation(pathVol, pathSeg, pathTum, pathOutVol, pathOutSeg, pathOutTum, ext):
        '''
        Crop every volume and tumor mask to a dilated breast-segmentation bounding box.

        Per subject the segmentation is eroded by 2 iterations, then dilated by
        5 (defines the crop box) and by 3 (region kept in the volume); cropped
        volume voxels outside the 3-iteration mask are set to min(volume) - 1.
        Files are paired across folders by sorted filename; the second
        "_"-separated token (subject index) must match.
        @pathVol: path to volumes
        @pathSeg: path to breast segmentations
        @pathTum: path to tumor masks
        @pathOutVol: path to output shrinked volumes
        @pathOutSeg: currently unused (the cropped segmentation is not saved)
        @pathOutTum: path to output shrinked tumor masks
        @ext: extension of the target files example: "nii"
        '''
        volNames = Utility_MEDICAL._list_files(pathVol, "*." + ext)
        idx_arr = [name.split('_')[1] for name in volNames]
        stems = [name.split('_')[0] + '_' + name.split('_')[1] for name in volNames]
        filesVol = [os.path.join(pathVol, name) for name in volNames]
        outVol = [os.path.join(pathOutVol, stem + '_vol2_srk.nii') for stem in stems]
        outTum = [os.path.join(pathOutTum, stem + '_tum_srk.nii') for stem in stems]
        filesSeg = Utility_MEDICAL._paired_files(idx_arr, pathSeg, "*." + ext)
        filesTum = Utility_MEDICAL._paired_files(idx_arr, pathTum, "*." + ext)

        os.makedirs(pathOutVol, exist_ok=True)
        os.makedirs(pathOutTum, exist_ok=True)
        for i in range(len(filesVol)):
            vol = FileIO_MEDICAL.load_nii(filesVol[i])
            seg = FileIO_MEDICAL.load_nii(filesSeg[i])
            tum = FileIO_MEDICAL.load_nii(filesTum[i])
            seg = Utility_MEDICAL.erode_binary_mask(seg, 2)
            dil_outer = Utility_MEDICAL.dilate_binary_mask(seg, 5)
            dil_inner = Utility_MEDICAL.dilate_binary_mask(seg, 3)
            srk_shape = Utility_MEDICAL.shrink_box(dil_outer)
            box = (slice(srk_shape[0], srk_shape[1] + 1),
                   slice(srk_shape[2], srk_shape[3] + 1),
                   slice(srk_shape[4], srk_shape[5] + 1))
            seg_srk = dil_inner[box]
            vol_srk = vol[box]
            tum_srk = tum[box]
            vol_srk[seg_srk == 0] = np.amin(vol) - 1
            FileIO_MEDICAL.save_nii(vol_srk, outVol[i])
            FileIO_MEDICAL.save_nii(tum_srk, outTum[i])

    @staticmethod
    def mass_write_dat(pathIn, pathOut, ext, mode):
        '''
        Convert every "*.ext" file in pathIn to a .dat file in pathOut.
        @mode: appended to the output name; when "sup" the volume is negated first
        '''
        for name in Utility_MEDICAL._list_files(pathIn, "*." + ext):
            vol = FileIO_MEDICAL.load_nii(os.path.join(pathIn, name))
            out_path = os.path.join(pathOut, name.split('_')[0] + '_' + name.split('_')[1] + '_vol2_' + mode + '.dat')
            if mode == "sup":
                vol = -vol
            FileIO_MEDICAL.write_dat(vol, out_path)

    # ------------------------------------------------------------------
    # Evaluation metrics
    # ------------------------------------------------------------------

    @staticmethod
    def binary_balanced_evaluation(labels, preds):
        '''
        Balanced accuracy for binary labels 0 / 1:
        0.5*(correctly_predicted_0/total_num_0 + correctly_predicted_1/total_num_1)
        If only one class occurs in labels, that class's recall is returned;
        if neither occurs, NaN is returned.
        @labels: ground truth labels
        @preds: predicted labels with the same length as labels
        '''
        labels = np.asarray(labels)
        preds = np.asarray(preds)
        if labels.shape != preds.shape:
            raise ValueError("labels and preds must have the same length")
        recalls = []
        for cls in (0, 1):
            members = labels == cls
            total = np.count_nonzero(members)
            if total:
                recalls.append(np.count_nonzero(preds[members] == cls) / total)
        if not recalls:
            return float('nan')
        return sum(recalls) / len(recalls)

    @staticmethod
    def compute_accuracy(labels, preds):
        '''
        Fraction of positions where preds equals labels.
        @labels: ground truth labels
        @preds: predicted labels with the same size as labels
        '''
        labels = np.array(labels)
        preds = np.array(preds)
        assert(len(labels) == len(preds))
        return np.sum(labels == preds) / float(len(labels))

    @staticmethod
    def compute_F1(labels, preds):
        '''
        Compute F1 score which accounts for both precision and sensitivity
        (sklearn.metrics.f1_score with its default positive label 1).
        @labels: ground truth labels
        @preds: predicted labels with the same size as labels
        '''
        labels = np.array(labels)
        preds = np.array(preds)
        assert(len(labels) == len(preds))
        return f1_score(labels, preds)

    @staticmethod
    def compute_specificity_sensitivity(labels, preds, pos_label=1):
        '''
        Return (specificity, sensitivity) with pos_label as the positive class.
        sensitivity = TP / (TP + FN), specificity = TN / (TN + FP); NaN when the
        corresponding class is absent from labels. pos_label defaults to 1,
        as in binary_auc_score; pos_label=0 gives the values this function
        returned before the positive class was made explicit.
        @labels: ground truth labels
        @preds: predicted labels with the same size as labels
        '''
        labels = np.array(labels)
        preds = np.array(preds)
        assert(len(labels) == len(preds))
        actual_pos = labels == pos_label
        predicted_pos = preds == pos_label
        tp = np.count_nonzero(actual_pos & predicted_pos)
        fn = np.count_nonzero(actual_pos & ~predicted_pos)
        tn = np.count_nonzero(~actual_pos & ~predicted_pos)
        fp = np.count_nonzero(~actual_pos & predicted_pos)
        sensitivity = tp / (tp + fn) if (tp + fn) else float('nan')
        specificity = tn / (tn + fp) if (tn + fp) else float('nan')
        return specificity, sensitivity

    @staticmethod
    def binary_auc_score(labels, preds):
        '''
        ROC AUC with positive label 1. With hard 0/1 predictions the ROC curve
        has a single operating point.
        @labels: ground truth labels
        @preds: predicted labels or scores
        '''
        labels = np.array(labels)
        preds = np.array(preds)
        fpr, tpr, thresholds = metrics.roc_curve(labels, preds, pos_label=1)
        return metrics.auc(fpr, tpr)

    # ------------------------------------------------------------------
    # Classification experiments
    # ------------------------------------------------------------------

    @staticmethod
    def SVC_baseline(feat_train, label_train, feat_test, label_test, c_grid=None):
        '''
        Sample classification, run a grid search for best hyper parameters.

        Models, in output order: LinearSVC (l2 penalty), SVC (default kernel),
        SVC (poly kernel), SVC (sigmoid kernel). For each model the best balanced
        accuracy and the best AUC are taken independently over c_grid.
        NOTE: the best values are selected on the test set itself, so they are
        optimistic estimates.
        @c_grid: C values to try; defaults to 1, 2, ..., 9
        Returns (best_acc, best_auc), each a float64 array of length 4.
        '''
        if c_grid is None:
            c_grid = np.arange(1, 10, 1)
        model_factories = (
            lambda c: svm.LinearSVC(penalty='l2', C=c),
            lambda c: svm.SVC(C=c),
            lambda c: svm.SVC(C=c, kernel='poly'),
            lambda c: svm.SVC(C=c, kernel='sigmoid'),
        )
        best_acc = np.full(len(model_factories), -1.0)
        best_auc = np.full(len(model_factories), -1.0)
        for m, make_model in enumerate(model_factories):
            for c in c_grid:
                clf = make_model(c)
                clf.fit(feat_train, label_train)
                res = clf.predict(feat_test)
                acc = Utility_MEDICAL.binary_balanced_evaluation(label_test, res)
                auc = Utility_MEDICAL.binary_auc_score(label_test, res)
                if acc > best_acc[m]:
                    best_acc[m] = acc
                if auc > best_auc[m]:
                    best_auc[m] = auc
        return best_acc, best_auc

    @staticmethod
    def probe_persistence(pers, labels, bord, mode, threshold):
        '''
        This function is to be used in combination with Cyclekernel.compute_cyckernel_wBornDeathLimit
        It does not compute anything, but to probe persistence and find reasonable birth, death threshold.
        For every dimension it prints the mean (over files of each class) fraction of
        persistence pairs whose birth/death value satisfies "value <mode> threshold",
        and the class-0 / class-1 ratio. Files with no pairs in a dimension add 0.
        @pers: output from read_pers_txt, [[struct_num x 2] x dim] x file_num
        @labels: [file_num], 0 or 1
        @bord: birth or death, 0 for birth, 1 for death
        @mode: ">=" or "<="
        @threshold: double
        Raises ValueError for any other mode.
        '''
        comparators = {">=": np.greater_equal, "<=": np.less_equal}
        if mode not in comparators:
            raise ValueError("Error input mode: %r" % (mode,))
        compare = comparators[mode]
        labels = np.asarray(labels)
        file_num = len(pers)
        assert(file_num == len(labels))
        dims = len(pers[0])
        cnt0 = np.zeros(dims, dtype=np.float32)
        cnt1 = np.zeros(dims, dtype=np.float32)
        num0 = np.sum(labels == 0)
        num1 = np.sum(labels == 1)

        for i in range(file_num):
            for dim in range(dims):
                pairs = pers[i][dim]
                if pairs.shape[0] == 0:
                    continue
                fraction = np.sum(compare(pairs[:, bord], threshold)) / pairs.shape[0]
                if labels[i] == 0:
                    cnt0[dim] = cnt0[dim] + fraction
                else:
                    cnt1[dim] = cnt1[dim] + fraction
        for dim in range(dims):
            rate0 = cnt0[dim] / num0
            rate1 = cnt1[dim] / num1
            print("Dim " + str(dim), " 0/1: ", rate0, " ", rate1, " ", rate0 / rate1)

    @staticmethod
    def SVM_classifier(feat, dim, labels, train_percentage, split_num, normalize, pairwise_dist):
        '''
        Average SVC_baseline scores over split_num random train/test splits.
        @feat: list of n x m matrix, the first matrix is accessed with: feat[0]
        @dim: the dimension of features to use
        @labels: list or array of integers
        @train_percentage: percentage of samples used for training
        @split_num: number of random splits of training and test
        @normalize: if to normalize the features (a copy is normalized; feat is not modified)
        @pairwise_dist: if to delete distances to the test subjects
            (keeps only the columns of feat[dim] belonging to training subjects)
        Prints and returns (mean best accuracy, mean best AUC), each of length 4.
        '''
        assert(len(feat) > dim)
        labels = np.asarray(labels)
        file_num = len(labels)
        train_num = int(np.floor(file_num * train_percentage))

        best_accuracy = np.zeros(4, dtype=np.float64)
        best_auc = np.zeros(4, dtype=np.float64)
        for _ in range(split_num):
            train_index = np.array(random.sample(range(0, file_num), train_num))
            test_index = np.setdiff1d(np.arange(file_num), train_index)

            train_labels = labels[train_index]
            test_labels = labels[test_index]
            if pairwise_dist:
                feat_ = feat[dim][:, train_index]
            else:
                feat_ = feat[dim]
            if normalize:
                feat_ = Utility_MEDICAL.normalize(np.array(feat_), 1.0)

            feat_train = feat_[train_index, :]
            feat_test = feat_[test_index, :]
            cur_acc, cur_auc = Utility_MEDICAL.SVC_baseline(feat_train, train_labels, feat_test, test_labels)
            best_accuracy = best_accuracy + cur_acc
            best_auc = best_auc + cur_auc

        best_accuracy = best_accuracy / split_num
        best_auc = best_auc / split_num
        print(best_accuracy)
        print(best_auc)
        return best_accuracy, best_auc