In [None]:
# Inclusive, 1-based range of cross-validation repeats that the training loop runs;
# splits whose repeat number falls outside [REPEAT_START, REPEAT_END] are skipped.
REPEAT_START = 1
REPEAT_END = 1

In [None]:
import sys
sys.path.append('..')

In [None]:
import os
import os.path as pth
from itertools import product, combinations

import numpy as np
from numba import jit, njit
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy.stats import pearsonr
import nilearn as nl
import nibabel as nib
import h5py
import pandas as pd
from multiprocessing import Pool
from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split, KFold, RepeatedKFold, \
                                    GroupKFold, RepeatedStratifiedKFold
from sklearn.utils import class_weight

import json
import shutil

import gc

from numba import jit

In [None]:
from IPython.display import clear_output

In [None]:
# Default font size for every matplotlib figure drawn by this notebook.
plt.rcParams.update({'font.size': 16})

### Settings

In [None]:
# Experiment configuration shared by data loading, model building and training.
# Several entries ('loss', 'batch_size', 'is_batchnorm', 'activation', the 'conv' / 'pool' / 'fc'
# sub-dicts, 'missing_value') are only defaults: the training loop overwrites them for every
# combination of the hyper-parameter grid.
config = {
    # Exported as CUDA_VISIBLE_DEVICES before TensorFlow is imported.
    'gpu_num': 4,
    
    # How rows with missing target scores are handled.
    'missing_value': 'exclude', ### 'exclude', 'mean', 'median'
    
    # When True, the non-zero voxels of every subject volume are z-scored.
    'is_zscore':True,
    
    # Filled in from the columns of train_scores.csv once the data is loaded.
    'output_label_list': None,
    # Activation of each Dense(1) output head.
    'output_activation': 'linear',
    
    'conv':{
        'conv_num': (3,3,3,3),
        # Passed to the Xception backbone as `base_channel`.
        'base_channel_num': 32,
        'kernel_size':(3, 3, 3),
        'padding':'same',
        'stride':1
    },
    'pool':{
        # Passed to the Xception backbone as `pooling`; None makes build_cnn add a Flatten layer.
        'type':'NP',
        'size':(2, 2, 2),
        'stride':2,
        'padding':'same'
    },
    'fc':{
        'fc_num': 128,
     },
    
    'activation':'relu',
    
    'is_batchnorm': True,
    # Dropout is applied once, right before the output heads.
    'is_dropout': True,
    'dropout_rate': 0.50,
    
    'batch_size': 32,

    'loss': 'mse',
    # Per-target weights; CustomHistory uses them to combine the per-output
    # normalized MAE values into a single 'trends_score'.
    'loss_weights': {
        'age':0.30,
        'domain1_var1':0.175, 
        'domain1_var2':0.175,
        'domain2_var1':0.175,
        'domain2_var2':0.175,
    },
    
    # NOTE(review): the training cell currently passes a hard-coded `epochs=3` to model.fit
    # instead of this value.
    'num_epoch':1000,
    'learning_rate': 1e-4,
    
    # Arguments of the RepeatedStratifiedKFold splitter (n_splits / n_repeats).
    'num_fold': 5,
    'num_repeat': 30,
    
    # Seed for the cross-validation splitter and for DataGenerator's shuffling.
    'random_state': 7777
}

In [None]:
# Used as the Keras model name and as the prefix of the checkpoint / analysis directory names.
BASE_MODEL_NAME = '3D_CNN_regression_average_GICA_xception-custom-2'

In [None]:
# Hyper-parameter grid: the training loop runs every combination of the lists below.

# Only 'mse' and 'mae' have a dedicated branch in the training loop; any other name
# (such as 'weighted_mse') is trained with plain MSE and only changes the run name.
loss_list = ['weighted_mse']

conv_comb_list = [(None,)]

fc_list = [0]

base_channel_list = [4, 8]

# Backbone pooling: None (build_cnn flattens the feature map) or 'avg'.
pool_type_list = [None, 'avg']

### 'same', 'valid'
conv_padding_list = ['same']
pool_padding_list = ['same']

activation_list = ['relu']

is_batchnorm_list = [True, False]

batch_size_list = [32]

# The training loop also iterates over the two lists below. They were not defined anywhere,
# so running the notebook on a fresh kernel stopped with a NameError.
# NOTE(review): the values are conservative defaults - confirm them against the intended experiment.
### 'exclude', 'mean', 'median'
missing_value_list = ['exclude']
### None (uniform weights), 'sample', 'log'
sample_weight_list = [None]

In [None]:
# Select the GPU from the config; this cell runs before the TensorFlow import below.
os.environ['CUDA_VISIBLE_DEVICES'] = str(config['gpu_num'])
# Turn off HDF5 file locking for the h5py reads performed by the loader processes.
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'

In [None]:
import tensorflow as tf
import tensorflow.keras as keras

from tensorflow.keras.utils import to_categorical, Sequence
from tensorflow.keras.layers import Input, Dense, Activation, BatchNormalization, \
                                    Flatten, Conv3D, AveragePooling3D, MaxPooling3D, Dropout, \
                                    Concatenate, GlobalMaxPool1D, GlobalAvgPool1D
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.callbacks import ModelCheckpoint,LearningRateScheduler, \
                                        EarlyStopping, BaseLogger, History
from tensorflow.keras.losses import mean_squared_error, mean_absolute_error
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import max_norm

In [None]:
# Enable on-demand GPU memory growth for every visible GPU.
# Looping over the list (instead of indexing gpus[0]) keeps this cell from raising
# IndexError on a machine where no GPU is visible.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
from collections import Counter

### Data

In [None]:
# Root directory of the competition data; the listing is shown as the cell output.
BASE_PATH = '/users/lww/data/trend_competition'
os.listdir(BASE_PATH)

['fMRI_mask.nii',
 'fMRI_test',
 'fMRI_train',
 'fnc.csv',
 'ICN_numbers.csv',
 'loading.csv',
 'reveal_ID_site2.csv',
 'sample_submission.csv',
 'train_scores.csv',
 'model',
 'ttest',
 'two_sample_ttest',
 'average_GICA',
 'transposed_GICA']

In [None]:
# image and mask directories
train_data_dir = pth.join(BASE_PATH, 'average_GICA', 'fMRI_train')
test_data_dir = pth.join(BASE_PATH, 'average_GICA', 'fMRI_test')

train_data = pd.read_csv(pth.join(BASE_PATH, 'train_scores.csv'))
# loading_data = pd.read_csv(f'{BASE_PATH}/loading.csv')
# fnc_data = pd.read_csv(f'{BASE_PATH}/fnc.csv')

In [None]:
# Handle subjects with missing target scores according to config['missing_value'].
if config['missing_value'] == 'exclude':
    # Drop every subject that has at least one missing score. The index is renumbered so
    # that index labels equal row positions: X_train is filled in row order from this
    # frame, and the training loop uses the index to address rows of X_train.
    train_data = train_data.dropna(axis=0, how='any').reset_index(drop=True)
elif config['missing_value'] == 'mean':
    # Replace each missing score with the mean of its column (the id column is untouched).
    train_data = train_data.fillna(train_data.iloc[:, 1:].mean())
elif config['missing_value'] == 'median':
    # Replace each missing score with the median of its column.
    train_data = train_data.fillna(train_data.iloc[:, 1:].median())
else:
    raise ValueError(
        "config['missing_value'] must be 'exclude', 'mean' or 'median', got {!r}".format(
            config['missing_value']))
len(train_data)

5434

In [None]:
# Load the fMRI mask image shipped with the data set and keep its affine matrix.
mask_filename = pth.join(BASE_PATH, 'fMRI_mask.nii')
mask_niimg = nib.load(mask_filename)
affine_array = mask_niimg.affine

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Serve batches (or single samples) from an in-memory array of subject volumes.

    Parameters
    ----------
    all_x : ndarray
        All volumes; the first axis indexes subjects.
    data_index : array-like of int
        Row positions of `all_x` / `all_y` that this generator serves. It is copied,
        so shuffling never touches the caller's array.
    config : dict
        Needs 'batch_size', 'random_state', 'output_label_list', 'input_shape'
        (for the tf.data helpers) and 'is_zscore' (for `sample_generator`).
    all_y : ndarray, optional
        Targets, one row per subject and one column per entry of
        config['output_label_list']. A 1-D array is treated as a single column.
        When omitted, only inputs are produced.
    shuffle : bool
        Shuffle `data_index` on construction and after every epoch.
    """

    def __init__(self, all_x, data_index, config, all_y=None, shuffle=True):
        super().__init__()
        self.all_x = all_x
        if all_y is not None:
            all_y = np.asarray(all_y)
            if all_y.ndim == 1:
                # Without this, `all_y[batch].T` stays 1-D and zipping it with the label
                # list pairs the single label with ONE scalar instead of the whole batch.
                all_y = all_y[:, np.newaxis]
        self.all_y = all_y
        self.data_index = np.array(data_index)
        self.config = config
        self.batch_size = self.config['batch_size']
        self.shuffle = shuffle
        # Private RNG: same shuffle order as seeding the global NumPy RNG with this seed,
        # but constructing a generator no longer resets the global random state.
        self._rng = np.random.RandomState(config['random_state'])

        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last batch may be smaller)."""
        return int(np.ceil(len(self.data_index)/self.batch_size))

    def __getitem__(self, generator_index):
        """Return batch number `generator_index` as `X` or `(X, {'output_<label>': y})`."""
        start = generator_index*self.batch_size
        indexes = self.data_index[start:start+self.batch_size]

        X = self.all_x[indexes]
        if self.all_y is None:
            return X
        return X, self._label_dict(self.all_y[indexes].T)

    def _label_dict(self, label_columns):
        """Map each configured output name to its entry of `label_columns`."""
        return {
            'output_{}'.format(label): each_y
                for label, each_y in zip(self.config['output_label_list'], label_columns)
        }

    def _tf_signature(self, x_shape, y_shape):
        """Return `(output_types, output_shapes)` for `tf.data.Dataset.from_generator`."""
        if self.all_y is None:
            return tf.float32, x_shape
        output_names = ['output_{}'.format(label)
                        for label in self.config['output_label_list']]
        y_types = {name: tf.float32 for name in output_names}
        y_shapes = {name: y_shape for name in output_names}
        return (tf.float32, y_types), (x_shape, y_shapes)

    def sample_generator(self, num_samples):
        """Yield the first `num_samples` samples of the current order, one at a time.

        When config['is_zscore'] is set, every channel of the sample is z-scored on a
        copy, so the shared `all_x` array is left untouched.
        """
        num_samples = min(int(num_samples), len(self.data_index))
        for i_sample in range(num_samples):
            target_index = self.data_index[i_sample]
            X_sample = self.all_x[target_index]

            if self.config['is_zscore']:
                # Integer indexing returns a view; copy before the in-place z-score.
                X_sample = X_sample.copy()
                for i in range(X_sample.shape[-1]):
                    X_sample[...,i] = self.zscore(X_sample[...,i])

            if self.all_y is None:
                yield X_sample
            else:
                yield X_sample, self._label_dict(self.all_y[target_index].T)

    def data_generator_for_tfdata(self, num_samples=32):
        """Wrap `sample_generator` in a batched, prefetching `tf.data.Dataset`."""
        output_types, output_shapes = self._tf_signature(
            [*self.config['input_shape']], ())
        dataset = tf.data.Dataset.from_generator(
            self.sample_generator, args=[num_samples],
            output_types=output_types, output_shapes=output_shapes)
        # Batch first and prefetch last so that whole batches are prepared ahead of time.
        return dataset.batch(self.batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    def batch_generator(self):
        """Yield every batch of one epoch, then reshuffle for the next epoch."""
        for i in range(self.__len__()):
            yield self.__getitem__(i)
        self.on_epoch_end()

    def batch_generator_for_tfdata(self):
        """Wrap `batch_generator` in a prefetching `tf.data.Dataset` of whole batches."""
        output_types, output_shapes = self._tf_signature(
            [None, *self.config['input_shape']], [None,])
        dataset = tf.data.Dataset.from_generator(
            self.batch_generator, args=[],
            output_types=output_types, output_shapes=output_shapes)
        return dataset.prefetch(tf.data.experimental.AUTOTUNE)

    def on_epoch_end(self):
        """Reshuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            self._rng.shuffle(self.data_index)

    @staticmethod
    def load_subject(filename):
        """Read the 'SM_feature' dataset of an HDF5 file and append a trailing axis."""
        with h5py.File(filename, 'r') as f:
            subject_data = f['SM_feature'][()]
        subject_data = subject_data[...,np.newaxis]
        return subject_data

    @staticmethod
    def zscore(data):
        """Z-score the non-zero entries of `data` in place and return `data`.

        Entries equal to zero are left untouched. An all-zero array is returned
        unchanged, and constant non-zero values are only centred (they become 0)
        instead of turning into NaN through a division by a zero standard deviation.
        """
        subject_mask = data != 0
        if not subject_mask.any():
            return data
        values = data[subject_mask]
        centered = values - values.mean()
        std = np.std(values)
        data[subject_mask] = centered / std if std > 0 else centered
        return data

In [None]:
# One (53, 63, 52, 1) volume per training subject, filled in the row order of train_data.
X_train = np.zeros((len(train_data), 53, 63, 52, 1), dtype='<f8') ### '<f8' -> '<f4' Downcasting?
# The first column of train_scores.csv holds the subject ids that name the .mat files.
subject_index_list = train_data.iloc[:, 0].to_numpy().astype(int)
subject_filename_list = [pth.join(train_data_dir, '{}.mat'.format(index))
                         for index in subject_index_list]

# imap keeps the input order, so result i belongs to row i.
with Pool(4) as pool:
    loaded_subjects = pool.imap(DataGenerator.load_subject, subject_filename_list, chunksize=4)
    for i, subject_data in tqdm(enumerate(loaded_subjects), total=len(train_data)):
        X_train[i] = subject_data

HBox(children=(FloatProgress(value=0.0, max=5434.0), HTML(value='')))




In [None]:
# Z-score every channel of every training volume in place.
if config['is_zscore']:
    num_subjects, num_channels = X_train.shape[0], X_train.shape[-1]
    for subject_idx in tqdm(range(num_subjects), total=num_subjects):
        for channel_idx in range(num_channels):
            channel = X_train[subject_idx, ..., channel_idx]
            X_train[subject_idx, ..., channel_idx] = DataGenerator.zscore(channel)

HBox(children=(FloatProgress(value=0.0, max=5434.0), HTML(value='')))




In [None]:
# Regression targets = every column of train_scores.csv after the subject id.
config['output_label_list'] = train_data.columns.to_list()[1:]
# The training loop iterates over config['label_list'] while it overwrites
# config['output_label_list'] with a single target, so keep an untouched copy here;
# without this key the loop stops with KeyError: 'label_list'.
config['label_list'] = list(config['output_label_list'])
config['output_label_list']

['age', 'domain1_var1', 'domain1_var2', 'domain2_var1', 'domain2_var2']

In [None]:
# Target matrix: every column except the leading subject id.
y_train = train_data.iloc[:, 1:].to_numpy()

y_train.shape

(5434, 5)

In [None]:
# Shape of a single input volume; it equals the per-subject shape used to allocate X_train.
# config['input_shape'] = X_train[0].shape
config['input_shape'] = (53, 63, 52, 1)
config['input_shape']

(53, 63, 52, 1)

In [None]:
from keras_application_3D import keras_applications
from keras_application_3D.keras_applications import xception

import tempfile

Using TensorFlow backend.


In [None]:
# Point keras_applications at the tf.keras modules.
# NOTE(review): presumably the 3D Xception builder reads these module-level attributes to
# pick its backend / layers / models / utils implementation - confirm in keras_application_3D.
keras_applications._KERAS_BACKEND = tf.keras.backend
keras_applications._KERAS_LAYERS = tf.keras.layers
keras_applications._KERAS_MODELS = tf.keras.models
keras_applications._KERAS_UTILS = tf.keras.utils

### Model

In [None]:
def build_cnn(config):
    """Build the 3D Xception regressor with one scalar output head per target label.

    Parameters
    ----------
    config : dict
        Reads 'input_shape', config['pool']['type'] (backbone pooling mode),
        config['conv']['base_channel_num'], 'is_batchnorm', 'is_dropout',
        'dropout_rate', 'output_activation' and 'output_label_list'.

    Returns
    -------
    Model
        Uncompiled Keras model named BASE_MODEL_NAME whose outputs are
        'output_<label>', in the order of config['output_label_list'].
    """
    input_layer = Input(shape=config['input_shape'], name='input_layer')
    backbone = xception.CustomXception3D_2(
        include_top=False, weights=None,
        input_tensor=input_layer, input_shape=config['input_shape'],
        pooling=config['pool']['type'], classes=None,
        base_channel=config['conv']['base_channel_num'],
        use_batchnorm=config['is_batchnorm']
    )
    x = backbone.output

    # With no pooling requested the backbone output is flattened before the heads.
    if config['pool']['type'] is None:
        x = Flatten(name='flatten_layer')(x)
    if config['is_dropout']:
        x = Dropout(config['dropout_rate'], name='output_dropout')(x)

    # One independent Dense(1) head per target, all fed by the same features.
    output_list = [
        Dense(1, activation=config['output_activation'],
              name='output_{}'.format(label))(x)
        for label in config['output_label_list']
    ]
    return Model(inputs=input_layer, outputs=output_list, name=BASE_MODEL_NAME)

In [None]:
# Sanity check: build the model once with the default config, print its layer table,
# then drop the reference again.
model = build_cnn(config)
model.summary(line_length=150)
print()
del model

Model: "3D_CNN_regression_average_GICA_xception-custom-2"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_layer (InputLayer)                         [(None, 53, 63, 52, 1)]          0                                                                   
______________________________________________________________________________________________________________________________________________________
block1_conv1 (Conv3D)                            (None, 26, 31, 25, 32)           864               input_layer[0][0]                                 
______________________________________________________________________________________________________________________________________________________
block1_conv1_bn (BatchNormalization)

### Define custom functions

In [None]:
def normalized_mae_loss(y_true, y_pred):
    """Keras metric: sum of absolute errors divided by the sum of absolute targets.

    Both sums run over every element of the batch, so the result is one scalar.
    K.epsilon() is added to the denominator so that a batch whose targets are all
    zero produces a large finite value instead of inf / NaN.
    """
    sae = K.sum(K.abs(y_pred - y_true))
    norm_val = K.sum(K.abs(y_true))
    return sae / (norm_val + K.epsilon())

def make_trends_score(y_true, y_pred, weights=None):
    """Weighted, per-target normalized absolute error.

    Parameters
    ----------
    y_true, y_pred : array-like, shape (n_samples, n_targets)
        Ground truth and predictions; lists are accepted as well as arrays.
    weights : array-like of shape (n_targets,), optional
        Weight of each target column. Defaults to (0.3, 0.175, 0.175, 0.175, 0.175),
        which requires exactly five target columns.

    Returns
    -------
    (score, per_target)
        `per_target[j] = weights[j] * sum_i |y_pred[i, j] - y_true[i, j]| / sum_i |y_true[i, j]|`
        and `score` is the sum of `per_target`.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if weights is None:
        weights = (0.3, 0.175, 0.175, 0.175, 0.175)
    weights = np.asarray(weights, dtype=float)

    sae = np.sum(np.abs(y_pred - y_true), axis=0)
    norm_val = np.sum(np.abs(y_true), axis=0)
    norm_sae = (sae / norm_val) * weights
    return np.sum(norm_sae, axis=-1), norm_sae

In [None]:
class CustomHistory(History):
    """History callback that also records a weighted 'trends_score' after every epoch.

    The score is the sum over all targets of
    `config['loss_weights'][label] * logs['output_<label>_normalized_mae_loss']`;
    'val_trends_score' is the same sum over the 'val_'-prefixed metrics.

    Parameters
    ----------
    config : dict
        Must provide 'loss_weights' and 'output_label_list'.
    """
    def __init__(self, config=None, **kargs):
        super().__init__(**kargs)
        if config is None:
            # The default only exists for signature compatibility; the weights below
            # cannot be derived without a config.
            raise ValueError("CustomHistory requires a config with 'loss_weights' "
                             "and 'output_label_list'")
        self.config = config
        self.loss_weights = {'output_{}_normalized_mae_loss'.format(label):self.config['loss_weights'][label]
                                for label in self.config['output_label_list']}

    def on_train_begin(self, logs=None):
        super().on_train_begin(logs=logs)
        # Start every fit() call with an empty record.
        self.epoch = []
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch=epoch, logs=logs)
        logs = logs or {}
        for prefix in ('', 'val_'):
            metric_names = [prefix + name for name in self.loss_weights]
            # A score is only recorded when every metric it needs was logged, e.g.
            # 'val_trends_score' is skipped when fit() runs without validation data.
            if not all(name in logs for name in metric_names):
                continue
            weighted = np.array([logs[prefix + name]*weight
                                 for name, weight in self.loss_weights.items()])
            trends_score = weighted.sum(axis=-1)
            score_name = prefix + 'trends_score'
            self.history.setdefault(score_name, []).append(trends_score)
            print('{}: {}'.format(score_name, trends_score))
        print()

In [None]:
def calculate_sample_weight(df, num_bins=5):
    """Bin a continuous target into equal-width classes and compute balanced weights.

    Parameters
    ----------
    df : pandas.Series or ndarray
        Values of one target. The input is not modified.
    num_bins : int, default 5
        Number of equal-width bins; bin ids run from 0 to num_bins - 1.

    Returns
    -------
    (binned, sample_weight)
        `binned` holds the bin id of every value (as floats) and `sample_weight` is
        scikit-learn's "balanced" weight of every sample with respect to those bins.
    """
    binned = df.astype(float)
    binned = binned - binned.min()
    value_range = binned.max()
    if value_range > 0:
        # The (1 + 1e-12) factor keeps the maximum strictly below 1 so that it lands in
        # the last bin instead of opening an extra one.
        binned = binned / (value_range*(1+1e-12))
    # else: a constant column stays all-zero (one bin) instead of becoming 0/0 = NaN.
    binned = np.trunc(binned * num_bins)
    # compute_sample_weight was never imported by name; reach it through the
    # `class_weight` module that the notebook imports from sklearn.utils.
    sample_weight = class_weight.compute_sample_weight(class_weight="balanced", y=binned)
    return binned, sample_weight

def make_smooth_weight(weight):
    """Return `log(weight)` with every value below 1.0 raised to 1.0.

    The input is left unmodified; the result has one entry per input weight.
    """
    log_weight = np.log(weight.copy())
    return np.maximum(log_weight, 1.0)

### Training

In [None]:
# Root directory under which the training loop creates its 'checkpoint' and 'analysis' trees.
model_base_path = '/users/lww/data/trend_competition/model'
os.makedirs(model_base_path, exist_ok=True)

In [None]:
# Grid search over the hyper-parameter lists. For every combination and every target one
# single-output model is trained per split of a repeated stratified K-fold.

def _build_model_name(config):
    """Encode the hyper-parameters in `config` into the run / directory name."""
    parts = [
        BASE_MODEL_NAME,
        'missing-value_{}'.format(config['missing_value']),
        'sample-weight_{}'.format(config['sample_weight']),
        'split-method_{}fold'.format(config['num_fold']),
        'zscore_{}'.format(config['is_zscore']),
        'loss_{}'.format(config['loss'].replace('_', '-')),
        'basech_{}'.format(str(config['conv']['base_channel_num']).zfill(2)),
        'conv_{}'.format('-'.join(map(str, config['conv']['conv_num']))),
        'conv-pad_{}'.format(config['conv']['padding']),
        'pool-type_{}'.format(config['pool']['type']),
        'pool-pad_{}'.format(config['pool']['padding']),
        'fc_{}'.format(config['fc']['fc_num']),
        'act_{}'.format(config['activation']),
    ]
    if config['is_dropout']:
        parts.append('DO_' + str(config['dropout_rate']).replace('.', ''))
    parts.append('BN_O' if config['is_batchnorm'] else 'BN_X')
    return '_'.join(parts)


def _run_subdirs(config, repeat_num, fold_num):
    """Directory chain (batch / repeats / repeat / folds / fold / labels) of one run."""
    return (
        'batch_{}'.format(str(config['batch_size']).zfill(3)),
        str(config['num_repeat']).zfill(2) + '_repeat',
        str(repeat_num).zfill(2) + 'th_repeat',
        str(config['num_fold']).zfill(2) + '_fold',
        str(fold_num).zfill(2) + 'th_fold',
        '_'.join(config['output_label_list']),
    )


# Targets to train on. config['output_label_list'] is overwritten with a single target
# inside the loop, so the full list comes from config['label_list'] when that key is set
# and otherwise from the keys of config['loss_weights'].
label_list = config.get('label_list') or list(config['loss_weights'])

# product() iterates with the right-most list varying fastest, i.e. in the same order
# as thirteen nested for-loops over these lists would.
hyper_grid = product(loss_list, batch_size_list, is_batchnorm_list, activation_list,
                     conv_padding_list, pool_padding_list, pool_type_list,
                     base_channel_list, conv_comb_list, fc_list,
                     missing_value_list, sample_weight_list, label_list)

for (loss_str, batch_size, is_batchnorm, activation_func, conv_padding_str,
     pool_padding_str, pool_type, base_channel_num, conv_comb, fc_num,
     missing_value, sample_weight_str, target) in hyper_grid:
    config['loss'] = loss_str
    config['batch_size'] = batch_size
    config['is_batchnorm'] = is_batchnorm
    config['activation'] = activation_func
    config['conv']['padding'] = conv_padding_str
    config['pool']['padding'] = pool_padding_str
    config['pool']['type'] = pool_type
    config['conv']['base_channel_num'] = base_channel_num
    config['conv']['conv_num'] = conv_comb
    config['fc']['fc_num'] = fc_num
    config['missing_value'] = missing_value
    config['sample_weight'] = sample_weight_str
    config['output_label_list'] = [target]
    config['output_activation'] = 'linear'

    target_df = train_data.copy()
    # Row POSITIONS, not index labels: the values are used to index X_train and
    # train_data[...].values, which are positional even when train_data was filtered.
    target_df['original_index'] = np.arange(len(train_data))
    if config['missing_value'] == 'exclude':
        target_df = target_df[target_df[target].notnull()]
    # NOTE(review): 'mean' / 'median' are accepted but apply no per-target imputation here.

    # The binned target stratifies the folds for every weighting mode.
    target_y_df, balanced_weight = calculate_sample_weight(target_df[target])
    if config['sample_weight'] == 'log':
        sample_weight = make_smooth_weight(balanced_weight)
    elif config['sample_weight'] == 'sample':
        sample_weight = balanced_weight
    else:
        sample_weight = np.ones(len(target_df))
    # NOTE(review): sample_weight is computed but not passed to the generators or to
    # model.fit, so it currently only affects the run name.

    # Shape (n_subjects, 1): one column per output, which is what DataGenerator zips
    # against config['output_label_list'].
    target_values = train_data[[target]].values

    # Only 'mse' and 'mae' are implemented; every other loss name trains with plain MSE.
    loss_name = config['loss'] if config['loss'] in ('mse', 'mae') else 'mse'
    loss_funcs = {'output_{}'.format(label): loss_name
                  for label in config['output_label_list']}
    loss_weights = {'output_{}'.format(label): 1.0
                    for label in config['output_label_list']}

    model_name = _build_model_name(config)

    rkf = RepeatedStratifiedKFold(n_splits=config['num_fold'],
                                  n_repeats=config['num_repeat'],
                                  random_state=config['random_state'])
    for i, (train_index_temp, val_index_temp) in enumerate(rkf.split(target_df, target_y_df)):
        fold_num = (i%config['num_fold']) + 1
        repeat_num = (i//config['num_fold']) + 1

        if repeat_num > REPEAT_END:
            # Splits come in repeat order, so nothing after this point can be selected.
            break
        if not (REPEAT_START <= repeat_num <= REPEAT_END and 1 <= fold_num <= 5):
            continue

        # Map positions inside target_df back to row positions of X_train.
        train_index = target_df['original_index'].values[train_index_temp]
        val_index = target_df['original_index'].values[val_index_temp]

        train_generator = DataGenerator(all_x=X_train, all_y=target_values,
                                        data_index=train_index,
                                        config=config, shuffle=True)
        train_generator = train_generator.batch_generator_for_tfdata()

        val_generator = DataGenerator(all_x=X_train, all_y=target_values,
                                      data_index=val_index,
                                      config=config, shuffle=True)
        val_generator = val_generator.batch_generator_for_tfdata()

        print(model_name, 'Batch:', config['batch_size'])
        print('{}th_repeat, {}th_fold'.format(repeat_num, fold_num))

        model = build_cnn(config)

        model.compile(loss=loss_funcs, loss_weights=loss_weights,
                      optimizer=Adam(learning_rate=config['learning_rate']),
                      metrics=['mse', 'mae', normalized_mae_loss])

        run_subdirs = _run_subdirs(config, repeat_num, fold_num)

        # Start from an empty checkpoint directory so that stale files of an earlier run
        # cannot be mistaken for the best checkpoint of this one.
        model_path = pth.join(model_base_path, 'checkpoint', model_name, *run_subdirs)
        if pth.isdir(model_path):
            shutil.rmtree(model_path)
        os.makedirs(model_path, exist_ok=True)
        model_filename = pth.join(model_path, '{epoch:06d}-{val_loss:.6f}.hdf5')

        checkpointer = ModelCheckpoint(filepath=model_filename, monitor="val_loss",
                                       verbose=1, save_best_only=True)
        early_stopping = EarlyStopping(monitor='val_loss', patience=10)

        # NOTE(review): epochs is hard-coded to 3 (config['num_epoch'] is not used), which
        # also means the patience=10 early stopping can never trigger - confirm intent.
        hist = model.fit(x=train_generator, epochs=3, #config['num_epoch'],
                         validation_data=val_generator, shuffle=True,
                         callbacks=[checkpointer, early_stopping])

        analysis_root = pth.join(model_base_path, 'analysis', model_name)
        each_repeat_path = pth.join(analysis_root, *run_subdirs[:2])
        target_label_path = pth.join(analysis_root, *run_subdirs)
        visualization_path = pth.join(target_label_path, 'visualization')
        os.makedirs(visualization_path, exist_ok=True)

        # Learning curves come from the History object returned by fit().
        for each_label in ['loss']:
            fig, ax = plt.subplots()
            ax.plot(hist.history[each_label], 'g', label='train_{}'.format(each_label))
            ax.plot(hist.history['val_{}'.format(each_label)], 'r', label='val_{}'.format(each_label))
            ax.set_xlabel('epoch')
            ax.set_ylabel('loss')
            ax.legend(loc='upper left')
            filename = 'learning_curve_{}'.format(each_label)
            fig.savefig(pth.join(visualization_path, filename), transparent=True)
            plt.close(fig)

        np.savez_compressed(pth.join(visualization_path, 'learning_curve'),
                            history=hist.history)
        np.savez_compressed(pth.join(target_label_path, 'used_index'),
                            train_index=train_index, val_index=val_index)

        # Release the model and its graph before the next split is built.
        K.clear_session()
        del model
        gc.collect()
        with open(pth.join(each_repeat_path, 'config.json'), 'w') as f:
            json.dump(config, f)

        # With save_best_only=True every saved file improved on the previous one, and the
        # names start with the zero-padded epoch, so the last name after sorting is the
        # best checkpoint; all earlier ones are deleted.
        chk_name_list = sorted([name for name in os.listdir(model_path) if name != '000_last.hdf5'])
        for chk_name in chk_name_list[:-1]:
            os.remove(pth.join(model_path, chk_name))
        clear_output()

### Inference

In [None]:
# NOTE(review): presumably the directory that receives the submission files - it is not
# used in the cells visible here.
submission_base = '/users/lww/code/Research/trend_competition/submissions'
# Same model root directory as the one used by the training section.
model_base_path = '/users/lww/data/trend_competition/model'

In [None]:
from itertools import cycle

In [None]:
# Every entry of the test directory, in the (arbitrary) order os.listdir
# returns them. The rows of X_test are filled in this same order, so any
# code that maps predictions back to subjects must use this list.
test_filename_list = os.listdir(test_data_dir)
# Preview of the first few file names.
test_filename_list[:5]

['10043.mat', '10029.mat', '10003.mat', '10012.mat', '10023.mat']

In [None]:
# test_filename_list = test_filename_list[:128]

In [None]:
# Full path of every test file, in the same order as test_filename_list.
subject_filename_list = [pth.join(test_data_dir, test_name) for test_name in test_filename_list]

# One pre-allocated (53, 63, 52, 1) little-endian float64 slot per test file.
num_test_subjects = len(test_filename_list)
X_test = np.zeros((num_test_subjects, 53, 63, 52, 1), dtype='<f8')

# Load the files with 4 worker processes; imap yields results in input
# order, so row k of X_test corresponds to test_filename_list[k].
with Pool(4) as pool:
    loaded_subjects = pool.imap(DataGenerator.load_subject, subject_filename_list, chunksize=4)
    for row, loaded_volume in tqdm(enumerate(loaded_subjects), total=num_test_subjects):
        X_test[row] = loaded_volume

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))




In [None]:
# Optionally standardise the test data in place: every (subject, channel)
# slice is passed through DataGenerator.zscore and written back.
if config['is_zscore']:
    subject_count = X_test.shape[0]
    channel_count = X_test.shape[-1]
    for subject_idx in tqdm(range(subject_count), total=subject_count):
        for channel_idx in range(channel_count):
            channel_volume = X_test[subject_idx, ..., channel_idx]
            X_test[subject_idx, ..., channel_idx] = DataGenerator.zscore(channel_volume)

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))




In [None]:
from keras_application_3D.keras_applications.custom_layers import SeparableConv3D

In [None]:
def _inference_model_name(cfg):
    """Rebuild the model directory name from the hyper-parameters in ``cfg``.

    The name is ``BASE_MODEL_NAME`` followed by one ``_<key>_<value>`` segment
    per hyper-parameter; it is used below to locate the checkpoint and
    analysis directories and to name the submission file.
    """
    name_parts = [
        BASE_MODEL_NAME,
        'missing-value_{}'.format(cfg['missing_value']),
        'sample-weight_{}'.format(cfg['sample_weight']),
        'split-method_{}fold'.format(cfg['num_fold']),
        'zscore_{}'.format(cfg['is_zscore']),
        'loss_{}'.format(cfg['loss'].replace('_', '-')),
        'basech_{}'.format(str(cfg['conv']['base_channel_num']).zfill(2)),
        'conv_{}'.format('-'.join(map(str, cfg['conv']['conv_num']))),
        'conv-pad_{}'.format(cfg['conv']['padding']),
        'pool-type_{}'.format(cfg['pool']['type']),
        'pool-pad_{}'.format(cfg['pool']['padding']),
        'fc_{}'.format(cfg['fc']['fc_num']),
        'act_{}'.format(cfg['activation']),
    ]
    if cfg['is_dropout']:
        name_parts.append('DO_' + str(cfg['dropout_rate']).replace('.', ''))
    name_parts.append('BN_O' if cfg['is_batchnorm'] else 'BN_X')
    return '_'.join(name_parts)


def _predict_with_checkpoint(checkpoint_path, dataset, num_steps):
    """Load one checkpoint, predict ``num_steps`` batches and release the model.

    Returns the predictions as an array with singleton axes removed (always
    at least 1-D). The model is deleted and the Keras session cleared even
    when prediction fails, so models do not pile up in memory while iterating
    over every target / repeat / fold.
    """
    model = load_model(checkpoint_path,
                       custom_objects={'normalized_mae_loss': normalized_mae_loss})
    try:
        pred = model.predict(dataset, steps=num_steps, verbose=1)
    finally:
        del model
        K.clear_session()
        gc.collect()
    return np.atleast_1d(np.array(pred).squeeze())


# Subject ids in exactly the order used to fill X_test. Re-listing the
# directory here would rely on os.listdir returning the same order twice.
test_subject_id_list = [name.split('.')[0] for name in test_filename_list]

for is_last in [False]:
    # product() iterates with the right-most list varying fastest, i.e. in
    # the same order as the equivalent nested for-loops.
    for loss_str, batch_size in product(loss_list, batch_size_list):
        config['loss'] = loss_str
        config['batch_size'] = batch_size

        # Predictions are matched to subject ids by position and averaged
        # across folds, so the test data must be served in a fixed order.
        test_generator = DataGenerator(all_x=X_test,
                                       data_index=np.arange(len(X_test)),
                                       config=config, shuffle=False)
        test_data = test_generator.batch_generator_for_tfdata()

        hyper_param_grid = product(
            is_batchnorm_list, activation_list, conv_padding_list,
            pool_padding_list, pool_type_list, base_channel_list,
            conv_comb_list, fc_list, missing_value_list, sample_weight_list,
        )
        for (is_batchnorm, activation_func, conv_padding_str, pool_padding_str,
             pool_type, base_channel_num, conv_comb, fc_num,
             missing_value, sample_weight_str) in hyper_param_grid:
            config['is_batchnorm'] = is_batchnorm
            config['activation'] = activation_func
            config['conv']['padding'] = conv_padding_str
            config['pool']['padding'] = pool_padding_str
            config['pool']['type'] = pool_type
            config['conv']['base_channel_num'] = base_channel_num
            config['conv']['conv_num'] = conv_comb
            config['fc']['fc_num'] = fc_num
            config['missing_value'] = missing_value
            config['sample_weight'] = sample_weight_str

            # The name depends only on the hyper-parameters, not on the target.
            model_name = _inference_model_name(config)
            batch_dir_name = 'batch_{}'.format(str(config['batch_size']).zfill(3))
            model_analysis_path = pth.join(model_base_path, 'analysis',
                                           model_name, batch_dir_name)
            model_temp_path = pth.join(model_base_path, 'checkpoint',
                                       model_name, batch_dir_name,
                                       str(config['num_repeat']).zfill(2)+'_repeat')

            # One averaged prediction vector per target, in label_list order.
            all_pred_list = []
            for target in config['label_list']:
                config['output_label_list'] = [target]
                # Every supported loss uses a linear output unit.
                config['output_activation'] = 'linear'

                print(model_name, 'Batch:', config['batch_size'])

                fold_pred_list = []

                log_name = 'log.tsv' if not is_last else 'log_last.tsv'
                log_path = pth.join(model_analysis_path, log_name)
                print(log_path)
                # NOTE(review): mode 'w' truncates any existing log at this
                # path (once per target), and only the repeat/fold columns are
                # filled during inference — confirm this file is not needed
                # elsewhere.
                with open(log_path, 'w') as log_file:
                    log_file.write('\t'.join(['repeat_num', 'fold_num',
                                              'train_loss', 'train_mse', 'train_mae',
                                              'test_loss', 'test_mse', 'test_mae'])+'\n')

                    repeat_num_list = sorted(name for name in os.listdir(model_temp_path)
                                             if name.endswith('th_repeat'))
                    for repeat_num_str in repeat_num_list:
                        model_fold_temp_path = pth.join(model_temp_path, repeat_num_str,
                                                        str(config['num_fold']).zfill(2)+'_fold')
                        fold_num_list = sorted(name for name in os.listdir(model_fold_temp_path)
                                               if name.endswith('th_fold'))
                        for fold_num_str in fold_num_list:
                            print('Batch:', config['batch_size'])
                            print(model_name, repeat_num_str, fold_num_str)
                            # One newline-terminated row per fold so the log
                            # stays a parseable TSV.
                            log_file.write('{}\t{}\n'.format(
                                repeat_num_str.replace('th_repeat', ''),
                                fold_num_str.replace('th_fold', '')))

                            model_path = pth.join(model_fold_temp_path,
                                                  fold_num_str,
                                                  '_'.join(config['output_label_list']))
                            if is_last:
                                model_chk_name = '000_last.hdf5'
                            else:
                                # Last entry in lexicographic order.
                                model_chk_name = sorted(os.listdir(model_path))[-1]

                            print('===', 'test', 'evaluate', '===')
                            fold_pred_list.append(
                                _predict_with_checkpoint(pth.join(model_path, model_chk_name),
                                                         test_data,
                                                         len(test_generator)))

                if not fold_pred_list:
                    raise FileNotFoundError(
                        'No fold checkpoints found for target {!r} under {}'.format(
                            target, model_temp_path))
                # Ensemble: average the predictions of all repeats and folds.
                all_pred_list.append(np.mean(fold_pred_list, axis=0))

            # Shape (num_subjects, num_targets); stacking on the last axis keeps
            # the array 2-D even when there is a single target.
            pred_matrix = np.stack(all_pred_list, axis=-1)

            with open(pth.join(submission_base, '{}.csv'.format(model_name)), 'w') as f:
                f.write('Id,Predicted\n')
                for subject_id, subject_pred in tqdm(zip(test_subject_id_list, pred_matrix),
                                                     total=len(test_subject_id_list)):
                    # Iterate over *all* targets: output_label_list only holds
                    # the last target processed at this point.
                    for each_label, each_value in zip(config['label_list'], subject_pred):
                        f.write('{}_{},{}\n'.format(subject_id, each_label, each_value))

            print()

3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_O Batch: 32
/users/lww/data/trend_competition/model/analysis/3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_O/batch_032/log.tsv
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_O 01th_repeat
=== test evaluate ===
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_O 01th_repeat
=== test evaluate ===
Bat

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))



3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_O Batch: 32
/users/lww/data/trend_competition/model/analysis/3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_O/batch_032/log.tsv
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_O 01th_repeat
=== test evaluate ===
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_O 01th_repeat
=== test evaluate ===
B

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))



3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_O Batch: 32
/users/lww/data/trend_competition/model/analysis/3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_O/batch_032/log.tsv
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_O 01th_repeat
=== test evaluate ===
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_O 01th_repeat
=== test evaluate ===
Batch

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))



3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_O Batch: 32
/users/lww/data/trend_competition/model/analysis/3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_O/batch_032/log.tsv
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_O 01th_repeat
=== test evaluate ===
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_O 01th_repeat
=== test evaluate ===
Batch

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))



3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_X Batch: 32
/users/lww/data/trend_competition/model/analysis/3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_X/batch_032/log.tsv
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_X 01th_repeat
=== test evaluate ===
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_X 01th_repeat
=== test evaluate ===
B

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))



3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_X Batch: 32
/users/lww/data/trend_competition/model/analysis/3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_X/batch_032/log.tsv
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_X 01th_repeat
=== test evaluate ===
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_None_pool-pad_same_fc_0_act_relu_DO_05_BN_X 01th_repeat
=== test evaluate ===
B

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))



3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_X Batch: 32
/users/lww/data/trend_competition/model/analysis/3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_X/batch_032/log.tsv
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_X 01th_repeat
=== test evaluate ===
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_04_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_X 01th_repeat
=== test evaluate ===
Batch

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))



3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_X Batch: 32
/users/lww/data/trend_competition/model/analysis/3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_X/batch_032/log.tsv
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_X 01th_repeat
=== test evaluate ===
Batch: 32
3D_CNN_regression_average_GICA_xception-custom-2_missing-value_exclude_split-method_5fold_zscore_True_loss_weighted-mse_basech_08_conv_None_conv-pad_same_pool-type_avg_pool-pad_same_fc_0_act_relu_DO_05_BN_X 01th_repeat
=== test evaluate ===
Batch

HBox(children=(FloatProgress(value=0.0, max=5877.0), HTML(value='')))





In [None]:
from IPython.display import display_html
def restartkernel():
    """Ask the notebook front-end to restart the kernel.

    Displays a raw HTML ``<script>`` element that calls
    ``Jupyter.notebook.kernel.restart()`` in the browser.
    NOTE(review): presumably this only works in front-ends that expose the
    ``Jupyter`` JavaScript global — confirm for the environment in use.
    """
    restart_script = "<script>Jupyter.notebook.kernel.restart()</script>"
    display_html(restart_script, raw=True)


# Restart immediately once the inference above has finished.
restartkernel()