# Convolutional Neural Network Training

In [None]:
%load_ext tensorboard

### Load code for this project

In [None]:
import h5py
import numpy as np
import os
import tensorflow as tf
import datetime
import json
import sys

import numpy as np
import os
import warnings
from sklearn.utils.class_weight import compute_class_weight
import sklearn.preprocessing

# Positions inside a (features, targets) pair, e.g. train[FEATURES], train[TARGETS].
FEATURES, TARGETS = 0, 1

In [None]:
ATTPC_DATA_ORIGIN = ('https://github.com/CompPhysics/MachineLearningMSU/raw/master/'
                     'Day2_materials/data/real-attpc-events.h5')


def load_image_h5_colab(data_origin=ATTPC_DATA_ORIGIN, fname='xyimages.h5'):
    """Load the AT-TPC event data from an HDF5 file fetched with tf.keras.utils.get_file.

    Args:
        data_origin: URL of the HDF5 file to download.
        fname: Local file name passed to ``tf.keras.utils.get_file``.

    Returns:
        A tuple of the form ((train_features, train_targets), (test_features, test_targets)),
        each element read in full from the HDF5 dataset of the same name.
    """
    path = tf.keras.utils.get_file(fname, origin=data_origin)

    # Slicing with [:] copies each dataset into memory, so the file can be closed
    # right away; previously the h5py.File handle was never closed.
    with h5py.File(path, 'r') as h5:
        train_features = h5['train_features'][:]
        train_targets = h5['train_targets'][:]
        test_features = h5['test_features'][:]
        test_targets = h5['test_targets'][:]

    return (train_features, train_targets), (test_features, test_targets)

In [None]:
def _build_vgg16_classifier(input_shape, num_categories, use_dropout, freeze):
    """Return an uncompiled VGG16 (ImageNet weights) model with a small dense head.

    Head: Flatten -> Dense(256, relu) -> [Dropout(0.5)] -> Dense(num_categories, sigmoid).
    When ``freeze`` is True, every layer of the VGG16 base is made non-trainable so
    only the head is updated during training.
    """
    vgg16_base = tf.keras.applications.VGG16(include_top=False, input_shape=input_shape,
                                             weights='imagenet')
    net = tf.keras.layers.Flatten()(vgg16_base.output)
    net = tf.keras.layers.Dense(256, activation=tf.nn.relu)(net)
    if use_dropout:
        net = tf.keras.layers.Dropout(0.5)(net)
    preds = tf.keras.layers.Dense(num_categories, activation=tf.nn.sigmoid)(net)
    model = tf.keras.Model(vgg16_base.input, preds)

    if freeze:
        # Freeze the base explicitly rather than slicing model.layers[:-4]: the head
        # has 3 or 4 layers depending on use_dropout, so that slice was off by one
        # when dropout was disabled.
        for layer in vgg16_base.layers:
            layer.trainable = False
    return model


def _make_optimizer(lr, decay):
    """Return Adam with learning rate ``lr`` and optional legacy-style decay.

    Legacy Keras optimizers scaled the rate as lr / (1 + decay * iterations);
    InverseTimeDecay with decay_steps=1 reproduces that schedule. Newer Keras
    optimizers raise on a ``decay=`` argument and do not honor ``lr=``, so the
    rate is passed as ``learning_rate``.
    """
    learning_rate = lr
    if decay:
        learning_rate = tf.keras.optimizers.schedules.InverseTimeDecay(
            initial_learning_rate=lr, decay_steps=1, decay_rate=decay)
    return tf.keras.optimizers.Adam(learning_rate=learning_rate)


def _balanced_class_weights(targets):
    """Return a {label: weight} dict using scikit-learn's 'balanced' heuristic.

    Multi-column targets are treated as one-hot and reduced with argmax. 1-D or
    single-column targets are treated as integer labels (e.g. 0/1 for a binary,
    single-sigmoid model).
    """
    targets = np.asarray(targets)
    if targets.ndim > 1 and targets.shape[-1] > 1:
        labels = np.argmax(targets, axis=-1)
    else:
        labels = targets.reshape(-1).astype(int)
    classes = np.unique(labels)
    # Keyword arguments are required by current scikit-learn releases.
    weights = compute_class_weight(class_weight='balanced', classes=classes, y=labels)
    # Key by the actual label value, not by position, so labels need not be 0..k-1.
    return {int(c): float(w) for c, w in zip(classes, weights)}


def _write_run_info(info_filename, train_start_time, train_end_time, params):
    """Write start/end times, training parameters and sys.argv to ``info_filename``."""
    with open(info_filename, 'w') as file:
        file.write('***Training Info***\n')
        file.write('Training Start: {}\n'.format(train_start_time))
        file.write('Training End: {}\n'.format(train_end_time))
        # Inside a notebook sys.argv only holds the kernel launcher's arguments,
        # so also record the hyperparameters this run was called with.
        file.write('Parameters:\n')
        for name, value in params.items():
            file.write('\t{}: {}\n'.format(name, value))
        file.write('Arguments:\n')
        for arg in sys.argv:
            file.write('\t{}\n'.format(arg))


def train(train, log_dir, epochs=10, batch_size=32, data_combine=False, rebalance=False,
          binary=False, lr=0.00001, decay=0., validation_split=0.15, freeze=False,
          examples_limit=-1, seed=71, reverse_labels=True, validation_size=None,
          use_dropout=True):
    """Train a binary CNN classifier: a VGG16 base (ImageNet weights) plus a sigmoid head.

    Args:
        train: ``(features, targets)`` pair, indexed with FEATURES / TARGETS. The
            per-sample feature shape (``features.shape[1:]``) is the model's input
            shape. Presumably (H, W, 3) images, since VGG16 ImageNet weights expect
            3 channels (confirm against the data).
        log_dir: Root output directory. Each run writes checkpoints, the final model,
            ``history.json``, ``info.txt`` and TensorBoard logs to
            ``log_dir/freeze{..}_dropout{..}_lr{..}_decay{..}_samples{..}/<timestamp>/``.
        epochs, batch_size: Forwarded to ``model.fit``.
        data_combine, binary, reverse_labels: Unused; kept for signature compatibility
            with the original data-loading code.
        rebalance: If True, weight classes with the 'balanced' heuristic, computed on
            the examples actually passed to ``model.fit``.
        lr: Adam learning rate.
        decay: Legacy-style inverse-time decay applied per optimizer step (0 disables it).
        validation_split: Fraction of the fitted examples Keras holds out for
            validation. Not used when ``validation_size`` supplies an explicit
            hold-out set.
        freeze: If True, only the dense head is trained.
        examples_limit: Number of training examples to fit on. A negative value
            (default -1) means all of them.
        seed: Seed for NumPy and TensorFlow.
        validation_size: If a positive int, the last ``validation_size`` examples are
            held out as validation data before ``examples_limit`` is applied.
            0 means no explicit hold-out.
        use_dropout: Add Dropout(0.5) after the 256-unit dense layer.

    Raises:
        ValueError: If ``validation_size`` is negative or not smaller than the
            number of examples.
    """
    # Snapshot the hyperparameters as passed, for info.txt.
    params = {
        'epochs': epochs, 'batch_size': batch_size, 'rebalance': rebalance, 'lr': lr,
        'decay': decay, 'validation_split': validation_split, 'freeze': freeze,
        'examples_limit': examples_limit, 'seed': seed,
        'validation_size': validation_size, 'use_dropout': use_dropout,
    }

    os.makedirs(log_dir, exist_ok=True)

    np.random.seed(seed)
    tf.random.set_seed(seed)

    features, targets = train[FEATURES], train[TARGETS]
    print("FEATURES shape:", features.shape)
    print("TARGETS shape:", targets.shape)

    # A single sigmoid output unit, so this is binary classification.
    num_categories = 1
    model = _build_vgg16_classifier(features.shape[1:], num_categories, use_dropout, freeze)
    model.compile(loss='binary_crossentropy',
                  optimizer=_make_optimizer(lr, decay),
                  metrics=['accuracy'])
    model.summary()  # summary() prints by itself and returns None

    log_run = "freeze{}_dropout{}_lr{}_decay{}_samples{}".format(
        freeze, use_dropout, lr, decay, examples_limit)
    run_dir = os.path.join(log_dir, log_run, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    os.makedirs(run_dir, exist_ok=True)
    print("\nWriting fits to:", run_dir)
    ckpt_path = os.path.join(run_dir, 'epoch-{epoch:02d}.h5')
    print("Checkpoint path:", ckpt_path, "\n")

    # Optional explicit hold-out taken from the end of the data.
    val = None
    if validation_size is not None:
        num_examples = targets.shape[0]
        if validation_size < 0:
            raise ValueError('validation_size must be non-negative, got {}.'.format(validation_size))
        if validation_size >= num_examples:
            raise ValueError('The given validation size must be smaller than the size of the training set ({}).'.format(
                num_examples))
        # Guard against 0: x[-0:] is the whole array and x[:-0] is empty.
        if validation_size > 0:
            val = features[-validation_size:], targets[-validation_size:]
            features, targets = features[:-validation_size], targets[:-validation_size]

    num_train = targets.shape[0]
    if examples_limit is None or examples_limit < 0:
        examples_limit = num_train
    elif examples_limit > num_train:
        warnings.warn('`examples_limit` is larger than the number of examples in the training set. The entire training '
                      'set will be used.')
        examples_limit = num_train
    features, targets = features[:examples_limit], targets[:examples_limit]

    # Class weights come from the examples that are actually fitted.
    class_weight = _balanced_class_weights(targets) if rebalance else None

    ckpt_callback = tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                                       save_weights_only=False,
                                                       save_freq='epoch',
                                                       save_best_only=False,
                                                       monitor='val_loss')
    tb_callback = tf.keras.callbacks.TensorBoard(run_dir)

    train_start_time = datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    history = model.fit(features,
                        targets,
                        epochs=epochs,
                        batch_size=batch_size,
                        # Keras ignores validation_split when validation_data is
                        # given; this just makes that explicit.
                        validation_split=validation_split if val is None else 0.0,
                        validation_data=val,
                        verbose=1,
                        class_weight=class_weight,
                        callbacks=[tb_callback, ckpt_callback])
    train_end_time = datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")

    model.save(os.path.join(run_dir, 'saved_model.h5'))

    # Cast to built-in floats so NumPy scalars in the history cannot break json.dump.
    history_data = {key: [float(v) for v in values] for key, values in history.history.items()}
    with open(os.path.join(run_dir, 'history.json'), 'w') as file:
        json.dump(history_data, file)

    _write_run_info(os.path.join(run_dir, 'info.txt'), train_start_time, train_end_time, params)


In [None]:
train, _ = load_image_h5_colab()

In [None]:
!rm -rf logs/*

In [None]:
%%time
cnn.train.train(train=train, 
                log_dir='logs/', 
                validation_split=0.15,
                lr=1e-3, 
                freeze=True, 
                examples_limit=160,
                epochs=20, 
                batch_size=32,
                seed=71,
                decay=0.,
                use_dropout=False,
               )

In [None]:
%tensorboard --logdir logs/ --port 6007

In [None]:
cnn.eval.eval(model_file='cnn/logs/saved_model.h5', data=train, name='CNN')