# Model for classification of triangles vs. squares

### Settings

In [None]:
# Imports
import warnings
warnings.simplefilter('ignore')

import numpy as np
from numpy import load
import os.path
import matplotlib.pyplot as plt

import keras.backend as K
# Check the active backend by name. The previous test evaluated the
# `K.tensorflow_backend` attribute, which does not say which backend is in use.
# NOTE(review): `tf.logging` is the TensorFlow 1.x API; confirm the TF version in use.
if K.backend() == 'tensorflow':
    import tensorflow as tf
    tf.logging.set_verbosity(tf.logging.ERROR)
from keras.utils import np_utils
from keras.models import load_model

# NOTE(review): the project package is named `code`, which shadows the
# standard-library `code` module; this relies on the project being first on sys.path.
from code import shape_images as shi
from code import model as mod

### Loading the pre-generated data

In [None]:
# Root folder of the pre-generated, already split data set.
# NOTE: hard-coded to the author's machine; adjust before running elsewhere.
original_data_path = "/home/elena/eStep/XAI/Data/TrianglesAndSquaresRotationScale"
split_dir = os.path.join(original_data_path, 'split_npz')

# One .npz archive per data split.
train_data_fname = os.path.join(split_dir, 'train_data.npz')
test_data_fname = os.path.join(split_dir, 'test_data.npz')
val_data_fname = os.path.join(split_dir, 'validation_data.npz')

# Open the archives (train, test, validation, in that order). They are not
# closed here because the following cells index into them by key.
train_data, test_data, val_data = (
    np.load(fname)
    for fname in (train_data_fname, test_data_fname, val_data_fname)
)


In [None]:
# Pull the image and label arrays out of each archive. The test labels are kept
# under the separate name `labels_test_or`, because later cells use both the
# original labels and their one-hot encoded version.
images_train, labels_train = train_data['images_train'], train_data['labels_train']
images_test, labels_test_or = test_data['images_test'], test_data['labels_test']
images_val, labels_val = val_data['images_val'], val_data['labels_val']

for caption, imgs, lbls in (
    ("Size of training data: ", images_train, labels_train),
    ("Size of validation data: ", images_val, labels_val),
    ("Size of testing data: ", images_test, labels_test_or),
):
    print(caption, np.shape(imgs), "and labels: ", np.shape(lbls))

#### Image formatting

In [None]:
img_rows = 64
img_cols = 64

# Per-sample shape with an explicit single channel axis. Keras decides whether
# the channel axis comes first or last.
if K.image_data_format() != 'channels_first':
    input_shape = (img_rows, img_cols, 1)
else:
    input_shape = (1, img_rows, img_cols)

# Keep the sample axis and append the per-sample shape.
images_train = images_train.reshape((images_train.shape[0],) + input_shape)
images_test = images_test.reshape((images_test.shape[0],) + input_shape)
images_val = images_val.reshape((images_val.shape[0],) + input_shape)

print("Size of training data: ", np.shape(images_train))
print("Size of validation data: ", np.shape(images_val))
print("Size of testing data: ", np.shape(images_test))

In [None]:
# plot 12 random train images
# Visual sanity check of the formatted training data via the project helper
# shi.plot_12images. The random selection comes from the helper itself; confirm in code.shape_images.
shi.plot_12images(images_train, labels_train)

#### Label formatting

In [None]:
# One-hot encode the class labels (2 classes). The original test labels remain
# available as `labels_test_or`.
labels_train, labels_val = (
    np_utils.to_categorical(lbls, num_classes=2)
    for lbls in (labels_train, labels_val)
)
labels_test = np_utils.to_categorical(labels_test_or, num_classes=2)

print(labels_train)
for caption, encoded in (
    ('labels_train shape:', labels_train),
    ('labels_test shape:', labels_test),
    ('labels_val shape:', labels_val),
):
    print(caption, encoded.shape)

## CNN model

In [None]:
# Training hyper-parameters.
batch_size, epochs = 200, 5  # samples per batch, passes over the training set
num_classes = 2  # triangles vs. squares

In [None]:
# Build the CNN with the project helper mod.generate_model, using the input
# shape and class count defined in earlier cells.
model = mod.generate_model(input_shape, num_classes)

# summary() prints the layer table itself and returns None. Wrapping it in
# print() only added a stray "None" line.
model.summary()

In [None]:
# train
# Train via the project helper mod.train_model: training data, validation data,
# batch size and epoch count are passed in. Its return value is discarded.
# Presumably `model` is fitted in place; confirm in code.model.
mod.train_model(model, images_train, labels_train, images_val, labels_val, batch_size, epochs)

### Save the model

In [None]:
# The trained model is stored next to the data, in a 'Models' sub-folder.
models_dir = os.path.join(original_data_path, 'Models')
model_fname = os.path.join(models_dir, 'model.h5')

In [None]:
# save the trained model
# Create the target folder first: it may not exist yet on a fresh data
# checkout, and the save would then fail. An empty dirname means the
# current directory, which needs no creation.
model_dir = os.path.dirname(model_fname)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)
model.save(model_fname)
print("Saved model to disk")

## Test the model

In [None]:
# load the trained model
# Reload the model from `model_fname` (the file written by the save cell), so
# the evaluation below uses the persisted model instead of the in-memory object.
model = load_model(model_fname) 
print("Loaded model from disk")

In [None]:
# Evaluate on the held-out test set against the one-hot test labels.
# Index 0 is the loss; index 1 is the first compiled metric, reported here as accuracy.
score = model.evaluate(images_test, labels_test, verbose=0)
test_loss, test_accuracy = score[0], score[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)

In [None]:
# Pick 10 random test images (with replacement) and classify them with the
# trained model. Each panel is titled with the true label (n) and the prediction (n̂).
n_samples = 10
nim = len(images_test)  # number of available test images (was previously undefined)
figsize = (8, 6)
plt.figure(figsize=figsize)
for j in range(1, n_samples + 1):  # subplot indices are 1-based
    # randint's upper bound is exclusive, so [0, nim) covers every test image.
    # The old lower bound of 1 could never select image 0.
    ind = int(np.random.randint(0, nim))
    # Slice rather than index, so the batch and channel axes the model was
    # trained with are kept. Predicting on a bare 2-D image fails.
    batch = images_test[ind:ind + 1]
    label = labels_test_or[ind]

    predictions = model.predict(batch)
    pred = np.argmax(predictions)

    # Drop the batch and channel axes for display; this works for either
    # channel ordering without hard-coding the image size.
    img = np.squeeze(batch)
    plt.subplot(5, 2, j)
    # NOTE(review): the *255 scaling assumes pixel values in [0, 1]; confirm.
    plt.imshow(img * 255, cmap='gray', vmin=0, vmax=255)
    plt.xticks([])
    plt.yticks([])
    plt.title('n=%d n̂=%d' % (label, pred))

plt.show()