# Imports

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, \
GlobalMaxPooling2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Model
from sklearn.metrics import confusion_matrix
from percolation import read_percolation_data
from sklearn.model_selection import train_test_split
import sys, os

In [None]:
from datetime import datetime

# Timestamp taken once, when this cell runs; it is reused later as the
# prefix of the exported file names.
start_time_string = f"{datetime.now():%Y.%m.%d-%H.%M}"

# Loading and preprocessing the data

In [None]:
# Base name of the exported files.
export_model_name = 'cnn-percolation'

# Parameters forwarded to the data reader.
L = 32           # requested system size (later used as the side of the network input)
round_digit = 2  # number of decimals kept in p_arr
pc = 0.59275     # presumably the critical occupation probability -- confirm against read_percolation_data

# Occupation probabilities 0.56, 0.57, ..., 0.62 (the stop value 0.63 is excluded).
# NOTE: with a floating-point step the number of elements produced by
# np.arange depends on rounding error; rounding afterwards only cleans up the
# values, not the length.
p_arr = np.arange(0.56, 0.63, 0.01).round(round_digit)

In [None]:
# Load the samples, their labels and the list of distinct labels.
# At most 1000 configurations are requested per value in p_arr.
X, y, unique_labels = read_percolation_data(
    L,
    p_arr,
    pc,
    max_configs_per_p=1000,
)

In [None]:
# Quick sanity check of what was loaded.
print(f'X.shape={X.shape}\ty.shape={y.shape}')
print(f'p_arr: {p_arr}')
print(f'labels: {unique_labels}')

In [None]:
# Number of samples and sample side length, read back from the loaded array
# (this overrides the L configured earlier).
N, L = X.shape[0], X.shape[1]

# Hold out 20% of the samples; stratify on y so both parts keep the same
# class proportions, and fix the seed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.20,
    stratify=y,
    random_state=42,
)

In [None]:
# Shapes of the two parts of the split.
print(f'X_train.shape={X_train.shape}\ty_train.shape={y_train.shape}')
print(f'X_test.shape={X_test.shape}\ty_test.shape={y_test.shape}')

In [None]:
# Number of distinct classes actually present in the training labels.
K = len(set(y_train))

# The output layer is sized with K, so it must agree with the label list
# returned by the data reader. On mismatch, abort with the message passed to
# sys.exit(): it is written to stderr and the exit status is non-zero, so the
# failure is not mistaken for a successful run.
if K != len(unique_labels):
    sys.exit(
        '# ERROR: mismatch between K ({}) and len(unique_labels) ({})'.format(
            K, len(unique_labels)
        )
    )

print("number_of_classes:", K)

# Definition of the network and training for classification

In [None]:
def CNN_net(L, K):
    """Build the (uncompiled) CNN classifier for L x L single-channel inputs.

    Architecture: three convolutional stages with 32, 64 and 128 filters.
    Each stage is two 3x3 'same'-padded ReLU convolutions, each followed by
    batch normalization, and ends with a 2x2 max pooling. The flattened
    features then go through Dense(512) and Dense(1024) ReLU layers, with a
    Dropout(0.2) in front of every dense layer, ending in a K-way softmax.

    Args:
        L: side length of the square input; the input shape is (L, L, 1).
        K: number of output classes (width of the softmax layer).

    Returns:
        A Keras Model mapping the input to K class probabilities.
    """
    # input layer
    inputs = Input(shape=(L, L, 1))

    # Convolution block: (conv -> batch norm) x 2, then pooling, per stage
    features = inputs
    for n_filters in (32, 64, 128):
        for _ in range(2):
            features = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(features)
            features = BatchNormalization()(features)
        features = MaxPooling2D((2, 2))(features)

    # Classification block: dropout precedes each dense layer
    features = Flatten()(features)
    for n_units in (512, 1024):
        features = Dropout(0.2)(features)
        features = Dense(n_units, activation='relu')(features)
    features = Dropout(0.2)(features)
    outputs = Dense(K, activation='softmax')(features)

    return Model(inputs, outputs)

In [None]:
# Instantiate the classifier for the loaded sample size and class count.
model_phase = CNN_net(L=L, K=K)

In [None]:
# Compiling the model

# Inverse-time-decay learning-rate schedule. staircase is left at its default
# (False), so the rate at a given optimizer step is
#   initial_learning_rate / (1 + decay_rate * step / decay_steps)
# The floor(step / decay_steps) variant applies only with staircase=True.
initial_learning_rate = 0.01
decay_steps = 1.0
decay_rate = 0.5
learning_rate_fn = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=decay_steps,
    decay_rate=decay_rate,
)

# Integer class labels -> sparse categorical cross-entropy.
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)
model_phase.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=opt,
    metrics=['accuracy'],
)

# Callbacks
# Stops training once val_loss has not improved for 7 epochs; it only takes
# effect when handed to fit() through its callbacks argument.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=7)

## Training

In [None]:
# training the model
# The held-out split doubles as validation data, which provides the val_loss
# monitored by early_stop. A callback has no effect unless it is passed to
# fit(), so early_stop is supplied explicitly here.
r_phase = model_phase.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    epochs=10,
    callbacks=[early_stop],
)

## Exporting the model

In [None]:
# Export directory: <home directory of the current user>/models,
# created if it does not exist yet.
home_pwd = os.path.expanduser("~")
out_pwd = os.path.join(home_pwd, 'models')
os.makedirs(out_pwd, exist_ok=True)

# File name: <notebook start time>--<base name>.h5
fname = f"{start_time_string}--{export_model_name}.h5"
fpath = os.path.join(out_pwd, fname)

model_phase.save(fpath)

## Loading a model

In [None]:
# to load it:
#model_phase = tf.keras.models.load_model('~/ml/research/criticality_trained_models/cnn_ising_phase.h5')

## Plots

In [None]:
import pandas as pd

# Training history as a table: one column per recorded metric.
dframe = pd.DataFrame(r_phase.history)

# Accuracy on the training and validation data
dfrme_accu = dframe.loc[:, ['accuracy', 'val_accuracy']]
dfrme_accu.plot()
plt.grid(True)

# Loss on the training and validation data
dfrme_loss = dframe.loc[:, ['loss', 'val_loss']]
dfrme_loss.plot()
plt.grid(True)


In [None]:
# Let us plot the confusion matrix
# The predicted class of each test sample is the index of its largest
# softmax output.
class_probs = model_phase.predict(X_test)
p_test = np.argmax(class_probs, axis=1)
conf_matrix = confusion_matrix(y_true=y_test, y_pred=p_test)

In [None]:
plt.figure(figsize=(7,7))
# fmt='d' prints the integer counts in full; the default annotation format
# ('.2g') would show counts of 100 or more in scientific notation (e.g. 2e+02).
# NOTE(review): the tick labels assume that class index i corresponds to
# p_arr[i] and that there are as many classes as entries in p_arr -- confirm
# against the labels returned by read_percolation_data.
sns.heatmap(conf_matrix, cmap='Blues', annot=True, fmt='d',
           xticklabels=p_arr, yticklabels=p_arr)
plt.xlabel('predicted')
plt.ylabel('true label')
#plt.savefig('./saved_images/cnn_percolation([0.57, 0.58, 0.61, 0.62, pc], L=256).jpg')
plt.show()