In [None]:
import tensorflow as tf
import pydot

import matplotlib.pyplot as plt
import numpy as np

import pickle
from PIL import Image

## Plot Models

Load the pretrained network models and plot their architectures; each diagram is saved as a PNG that the figures below reuse.

In [None]:
# Load the saved CNN model from the 'CNN GAP 2' path and print its summary.
cnn = tf.keras.models.load_model(filepath='CNN GAP 2')
cnn.summary()

In [None]:
# Render the CNN architecture to 'cnn_model.png' at 600 dpi; the summary figure
# below embeds this file. Kept as the cell's last expression so the return
# value of plot_model is displayed as the cell output.
tf.keras.utils.plot_model(
    cnn,
    to_file='cnn_model.png',
    show_shapes=True,
    show_layer_names=False,
    show_layer_activations=True,
    dpi=600,
)

In [None]:
# Load the saved model from the 'FFN' path and print its summary.
ffn = tf.keras.models.load_model(filepath='FFN')
ffn.summary()

In [None]:
# Render the FFN architecture to 'ffn_model.png' at 600 dpi; the summary figure
# below embeds this file. Kept as the cell's last expression so the return
# value of plot_model is displayed as the cell output.
tf.keras.utils.plot_model(
    ffn,
    to_file='ffn_model.png',
    show_shapes=True,
    show_layer_names=False,
    show_layer_activations=True,
    dpi=600,
)

## Figures

Load the saved training histories and plot each model's validation and training curves (loss and accuracy) next to its architecture diagram.

In [None]:
with open('FFN_training_history.pkl', 'rb') as f:
    ffn_hist = pickle.load(f)
    f.close()

all_loss = np.array([np.array(ffn_hist['loss']), np.array(ffn_hist['val_loss'])])
all_acc = np.array([np.array(ffn_hist['accuracy']), np.array(ffn_hist['val_accuracy'])])
min_max_loss = (all_loss.min() - 0.1, all_loss.max() + 0.1)
min_max_acc = (all_acc.min() - 0.05, all_acc.max() + 0.05)

In [None]:
fig = plt.figure(figsize=(6, 4), dpi=600)
spec = fig.add_gridspec(2, 2)
ax1 = fig.add_subplot(spec[:, 0])
ax2 = fig.add_subplot(spec[0, 1])
ax3 = fig.add_subplot(spec[1, 1])

ax1.imshow(Image.open("ffn_model.png"))
ax1.set_axis_off()

ax2_twin = ax2.twinx()
ax2.plot(ffn_hist['val_loss'])
ax2_twin.plot(ffn_hist['val_accuracy'], color='#ff7f0e')
ax2.set_ylabel("SCC Loss", weight='bold')
ax2_twin.set_ylabel("Accuracy", weight='bold')
ax2.set(ylim=(min_max_loss[0], min_max_loss[1]))
ax2_twin.set(ylim=(min_max_acc[0], min_max_acc[1]))

ax3_twin = ax3.twinx()
ax3.plot(ffn_hist['loss'])
ax3_twin.plot(ffn_hist['accuracy'], color='#ff7f0e')
ax3.set_ylabel("SCC Loss", weight='bold')
ax3_twin.set_ylabel("Accuracy", weight='bold')
ax3.set(ylim=(min_max_loss[0], min_max_loss[1]))
ax3_twin.set(ylim=(min_max_acc[0], min_max_acc[1]))

ax3.set_xlabel("Epochs", weight='bold')

ax1.annotate('(a)', xy=(0, 1.1), xycoords='axes fraction', weight='bold')
ax2.annotate('(b)', xy=(0, 1.1), xycoords='axes fraction', weight='bold')
ax3.annotate('(c)', xy=(0, 1.1), xycoords='axes fraction', weight='bold')

ax2.tick_params(width=1.5)
ax3.tick_params(width=1.5)
for axis in ['top', 'right', 'bottom', 'left']:
    ax2.spines[axis].set_linewidth(1.5)
    ax3.spines[axis].set_linewidth(1.5)
        

fig.tight_layout()

fig.savefig("FFN SI.png")

In [None]:
with open('CNN_training_history.pkl', 'rb') as f:
    ffn_hist = pickle.load(f)
    f.close()

all_loss = np.array([np.array(ffn_hist['loss']), np.array(ffn_hist['val_loss'])])
all_acc = np.array([np.array(ffn_hist['accuracy']), np.array(ffn_hist['val_accuracy'])])
min_max_loss = (all_loss.min() - 0.1, all_loss.max() + 0.1)
min_max_acc = (all_acc.min() - 0.05, all_acc.max() + 0.05)

In [None]:
fig = plt.figure(figsize=(6, 4), dpi=600)
spec = fig.add_gridspec(2, 2)
ax1 = fig.add_subplot(spec[:, 0])
ax2 = fig.add_subplot(spec[0, 1])
ax3 = fig.add_subplot(spec[1, 1])

ax1.imshow(Image.open("cnn_model.png"))
ax1.set_axis_off()

ax2_twin = ax2.twinx()
ax2.plot(ffn_hist['val_loss'])
ax2_twin.plot(ffn_hist['val_accuracy'], color='#ff7f0e')
ax2.set_ylabel("SCC Loss", weight='bold')
ax2_twin.set_ylabel("Accuracy", weight='bold')
ax2.set(ylim=(min_max_loss[0], min_max_loss[1]))
ax2_twin.set(ylim=(min_max_acc[0], min_max_acc[1]))

ax3_twin = ax3.twinx()
ax3.plot(ffn_hist['loss'])
ax3_twin.plot(ffn_hist['accuracy'], color='#ff7f0e')
ax3.set_ylabel("SCC Loss", weight='bold')
ax3_twin.set_ylabel("Accuracy", weight='bold')
ax3.set(ylim=(min_max_loss[0], min_max_loss[1]))
ax3_twin.set(ylim=(min_max_acc[0], min_max_acc[1]))

ax3.set_xlabel("Epochs", weight='bold')

ax1.annotate('(a)', xy=(0, 1.1), xycoords='axes fraction', weight='bold')
ax2.annotate('(b)', xy=(0, 1.1), xycoords='axes fraction', weight='bold')
ax3.annotate('(c)', xy=(0, 1.1), xycoords='axes fraction', weight='bold')

ax2.tick_params(width=1.5)
ax3.tick_params(width=1.5)
for axis in ['top', 'right', 'bottom', 'left']:
    ax2.spines[axis].set_linewidth(1.5)
    ax3.spines[axis].set_linewidth(1.5)

fig.tight_layout()

fig.savefig("CNN SI.png")