In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
import random as rn
from my_utils import Workout_dataset, class_weight_dict
from my_model import make_CNN_RNN_model

import os

In [2]:
# Fix every random seed up front so training runs are reproducible.
SEED = 42

# NOTE: PYTHONHASHSEED is read when an interpreter starts, so this only affects
# processes launched afterwards (e.g. worker subprocesses), not the running kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)

# Ask TensorFlow / cuDNN for deterministic kernels (honoured by TF versions
# that read these variables; newer TF also offers an API switch for this).
os.environ['TF_DETERMINISTIC_OPS'] = '1'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

# Seed each RNG the pipeline may draw from: TensorFlow, NumPy, and Python's random.
tf.random.set_seed(SEED)
np.random.seed(SEED)
rn.seed(SEED)

In [3]:
def scheduler(epoch, lr, start_epoch=20, decay=0.9, min_lr=0.00001):
    """Learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    The learning rate is left unchanged up to and including ``start_epoch``.
    After that, ``lr`` is multiplied by ``decay`` once per epoch for as long
    as it is still above ``min_lr``.

    Args:
        epoch: Epoch index passed by the callback. Keras counts from 0.
        lr: Learning rate currently set on the optimizer.
        start_epoch: Decay applies only when ``epoch > start_epoch``.
        decay: Multiplicative factor applied on each decaying epoch.
        min_lr: Decay stops once ``lr`` is at or below this value. There is
            no clamping, so the final decay step may land slightly below it.

    Returns:
        The learning rate to use for this epoch.
    """
    if epoch > start_epoch and lr > min_lr:
        return lr * decay
    return lr

lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)

# Data locations. The *_label_dir variables point at CSV label files, not directories.
train_dir = './data/train'
label_dir = './data/data_y_train.csv'
test_dir = './data/test'
test_label_dir = './data/data_y_test.csv'
checkpoint_filepath = "./save/cnn_gru_best.hdf5"

# The checkpoint callback writes this file directly. Make sure its folder
# exists so the first save does not fail on a fresh checkout.
os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)

BATCH_SIZE = 64
VALID_BATCH_SIZE = 16
TEST_BATCH_SIZE = 625

# Train and validation loaders both read train_dir with fold=0. Presumably
# Workout_dataset splits that fold by `mode` -- TODO confirm in my_utils.
train_loader = Workout_dataset(
    train_dir, label_dir, mode='Train',
    fold=0, batch_size=BATCH_SIZE, augment=True, shuffle=True)

valid_loader = Workout_dataset(
    train_dir, label_dir, mode='Valid',
    fold=0, batch_size=VALID_BATCH_SIZE, shuffle=True)

# Test batches are not shuffled, so evaluation order is fixed.
test_loader = Workout_dataset(
    test_dir, test_label_dir, mode='Test',
    batch_size=TEST_BATCH_SIZE, shuffle=False)


In [4]:
# Hyperparameters for the CNN + GRU model built by my_model.make_CNN_RNN_model.
# Alternative values noted during tuning: res_regularize_coeff=0.2, res_num=5.
MODEL_PARAMS = dict(
    lr=0.001,
    leakyrelu_alpha=0.2,
    input_kernels=10,
    input_kernel_width=3,
    res_kernels=60,
    res_kernel_width=3,
    res_regularize_coeff=0.1,
    res_num=7,
)

model = make_CNN_RNN_model(**MODEL_PARAMS)
model.summary()

__________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 150, 60)      240         leaky_re_lu_14[0][0]             
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 150, 60)      240         leaky_re_lu_18[0][0]             
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 150, 60)      240         leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
gru_1 (GRU)                     (None, 150, 60)      21960       batch_normalization_2[0][0]      
__________________________________________________________________________________________________
gru_3 (GRU)                     (None, 150, 60)      21960       batch_normalization_6[0][0]      
__________________________

In [5]:

# Keep only the weights from the epoch with the lowest validation loss.
save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath, monitor='val_loss', verbose=1, save_best_only=True,
    save_weights_only=True, mode='auto', save_freq='epoch', options=None)

# Stop after 20 epochs without a val_loss improvement of at least 1e-4.
# The large epoch cap below effectively means "train until early stopping".
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0.0001,
    patience=20, verbose=1)

# `fit_generator` is deprecated in TF 2.x. `fit` accepts a Sequence/generator
# directly (assumes Workout_dataset is a keras Sequence -- confirm in my_utils).
history = model.fit(
    train_loader,
    validation_data=valid_loader,
    epochs=2000,
    callbacks=[save_best, early_stop, lr_scheduler],
    class_weight=class_weight_dict)


Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Epoch 1/2000

Epoch 00001: val_loss improved from inf to 175.48611, saving model to ./save/cnn_gru_best.hdf5
Epoch 2/2000

Epoch 00002: val_loss improved from 175.48611 to 89.83839, saving model to ./save/cnn_gru_best.hdf5
Epoch 3/2000

Epoch 00003: val_loss improved from 89.83839 to 48.18031, saving model to ./save/cnn_gru_best.hdf5
Epoch 4/2000

Epoch 00004: val_loss improved from 48.18031 to 27.18104, saving model to ./save/cnn_gru_best.hdf5
Epoch 5/2000

Epoch 00005: val_loss improved from 27.18104 to 16.52025, saving model to ./save/cnn_gru_best.hdf5
Epoch 6/2000

Epoch 00006: val_loss improved from 16.52025 to 11.36938, saving model to ./save/cnn_gru_best.hdf5
Epoch 7/2000

Epoch 00007: val_loss improved from 11.36938 to 8.15333, saving model to ./save/cnn_gru_best.hdf5
Epoch 8/2000

Epoch 00008: val_loss improved from 8.15333 to 6.59477, savin

In [6]:
# Restore the best checkpoint (lowest val_loss) before scoring on the test set.
model.load_weights(checkpoint_filepath)

# `evaluate_generator` is deprecated in TF 2.x; `evaluate` takes the loader directly.
test_scores = model.evaluate(test_loader, verbose=1)
# With no compiled metrics, evaluate returns a bare scalar loss; normalise to a list.
if not isinstance(test_scores, list):
    test_scores = [test_scores]
# Label each value with its metric name (loss first) instead of showing a bare list.
dict(zip(model.metrics_names, test_scores))



[0.7734295129776001, 0.8335999846458435]