In [4]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import RobustScaler
from keras.utils import to_categorical
from keras.layers import Input, Dense, Dropout, Conv1D, Flatten, BatchNormalization, GlobalMaxPool1D, concatenate
from keras.models import Model


# Forward slashes keep this path portable: Windows accepts them as separators,
# whereas a backslash separator only works on Windows.
df_train = pd.read_csv("wdata/train_feature_bin_30_slice.csv")

def get_training_data(df_train, ignore_cols=("fold",)):
    """Turn a feature frame into a scaled Conv1D input tensor and one-hot labels.

    After dropping ``ignore_cols``, the frame is read positionally: the first
    column is discarded (presumably a row id -- TODO confirm), the last column
    is the class label, and every column in between is a feature.

    Parameters
    ----------
    df_train : pd.DataFrame
        Input frame. It is not modified.
    ignore_cols : iterable of column names, default ("fold",)
        Bookkeeping columns removed before the positional split; names that
        are not present are skipped. This notebook appends a ``fold`` column
        after the label column; if it were kept, ``fold`` would be read as the
        label and the real label would leak into the features.

    Returns
    -------
    X : np.ndarray, shape (n_rows, n_features, 1)
        Features with NaN and +/-inf replaced by 0, then RobustScaler-scaled.
    y : np.ndarray, shape (n_rows, max_label + 1)
        One-hot encoded labels.

    NOTE(review): the RobustScaler is fit on this data and then discarded, so
    a held-out/test set cannot be scaled with the same statistics.
    """
    data = df_train.drop(columns=list(ignore_cols or ()), errors="ignore")

    # Labels are assumed to be non-negative integer class ids (TODO confirm).
    # Size the one-hot matrix by the largest label, which is what
    # to_categorical requires, not by the number of distinct labels: that
    # count is too small whenever some class id is absent.
    target = data.iloc[:, -1]
    y = to_categorical(target, num_classes=int(target.max()) + 1)

    # copy=True: without it the array may share memory with df_train, and the
    # in-place cleaning below would write back into the caller's frame.
    features = data.iloc[:, 1:-1].to_numpy(copy=True)
    # Replace NaN and +/-inf with 0 before scaling.
    features[~np.isfinite(features)] = 0

    scaled = RobustScaler().fit_transform(features)
    # Add a trailing channel axis for Conv1D: (rows, features) -> (rows, features, 1).
    X = scaled.reshape(scaled.shape[0], scaled.shape[1], 1)

    return (X, y)

def init_model(num_features, num_classes=7):
    """Build (but do not compile) a two-branch 1-D CNN classifier.

    Branch 1 applies two Conv1D(64 filters, kernel 5, ReLU) + BatchNormalization
    stages followed by a global max pool. Branch 2 is the raw input, flattened.
    Both branches are concatenated and passed through Dense(128, ReLU) ->
    Dropout(0.5) -> Dense(128, ReLU) -> Dense(num_classes, softmax).

    Parameters
    ----------
    num_features : int
        Sequence length; the model input has shape (num_features, 1). The two
        convolutions use Keras' default 'valid' padding, so each shortens the
        sequence by 4 and num_features must be at least 9.
    num_classes : int, default 7
        Width of the softmax output. Should equal the number of one-hot
        columns in the training labels.

    Returns
    -------
    keras.models.Model
        The uncompiled model; call ``compile`` before ``fit``.
    """
    inp = Input(shape=(num_features, 1))

    # Convolutional branch.
    conv = Conv1D(64, 5, activation="relu", kernel_initializer="uniform")(inp)
    conv = BatchNormalization()(conv)
    conv = Conv1D(64, 5, activation="relu", kernel_initializer="uniform")(conv)
    conv = BatchNormalization()(conv)
    max_pool = GlobalMaxPool1D()(conv)

    # Raw-feature branch, concatenated with the pooled conv features.
    raw = Flatten()(inp)
    merged = concatenate([max_pool, raw])

    hidden = Dense(128, activation="relu", kernel_initializer="uniform")(merged)
    hidden = Dropout(0.5)(hidden)
    hidden = Dense(128, activation="relu", kernel_initializer="uniform")(hidden)

    output = Dense(num_classes, activation="softmax", kernel_initializer="uniform")(hidden)
    return Model(inp, output)




In [11]:
# Index labels of every row whose column '152' equals 1. ('152' is the last
# column read from the CSV -- presumably the class label; TODO confirm.)
_ids = df_train.index[df_train['152'].eq(1)].tolist()
all_length = len(_ids)
# One fifth of those rows, rounded down; any remainder is not assigned here.
fold_len = all_length // 5

# Mark the first fifth as fold 0. When 'fold' is created here, pandas appends
# it after '152' (making it the last column) and fills every other row with NaN.
init_idx = 0
_train_idx = _ids[init_idx:init_idx + fold_len]
df_train.loc[_train_idx, 'fold'] = 0

In [13]:
# Inspect the column layout: 'fold' (assigned in the previous cell) is now the
# last column, after '152'. Any code that takes the last column as the label,
# e.g. `df_train.iloc[:, -1]`, must drop or skip 'fold' first.
df_train.columns

Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
       ...
       '144', '145', '146', '147', '148', '149', '150', '151', '152', 'fold'],
      dtype='object', length=154)