In [1]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os

In [3]:
# Read the training-history CSV of every grid-searched model and record its
# final-epoch validation accuracy. Histories are read from
# ./<architecture>/History/<model>.csv and keyed by the file name without
# its ".csv" extension.
dir_list = {'CNN2', 'CNN3', 'CNN4'}

history_dict = {}
for arch_dir in sorted(dir_list):  # sorted -> deterministic order on every run
    arch_path = os.path.join('.', arch_dir)
    if not os.path.isdir(arch_path):
        # Architecture folder not present in the working directory: nothing to read.
        continue
    history_path = os.path.join(arch_path, 'History')
    for file_name in sorted(os.listdir(history_path)):
        stem, ext = os.path.splitext(file_name)
        # Only CSV files are histories; skip stray files such as .DS_Store.
        if ext.lower() != '.csv':
            continue
        history = pd.read_csv(os.path.join(history_path, file_name))
        # The last row holds the accuracy after the final training epoch.
        history_dict[stem] = float(history['val_accuracy'].iloc[-1])

In [6]:
# Split the dictionary into three, one per architecture (2-layer, 3-layer and
# 4-layer CNN). Only models with validation accuracy >= MIN_VAL_ACCURACY are
# kept, to reduce clutter in the graphs.
MIN_VAL_ACCURACY = 0.7

CNN2_history = {}
CNN3_history = {}
CNN4_history = {}

for model_name, val_acc in history_dict.items():
    # Written as `not >=` so that NaN accuracies are dropped too.
    if not val_acc >= MIN_VAL_ACCURACY:
        continue
    if 'CNN2' in model_name:
        CNN2_history[model_name] = val_acc
    elif 'CNN3' in model_name:
        CNN3_history[model_name] = val_acc
    elif 'CNN4' in model_name:
        CNN4_history[model_name] = val_acc
    # Names matching none of the three architectures are ignored rather than
    # being lumped into the 4-layer group.

{'History_CNN4_history_200_100_60_30_10': 0.8920000195503235,
 'History_CNN4_history_200_100_60_30_20': 0.8892999887466431,
 'History_CNN4_history_200_100_60_50_20': 0.8560000061988831,
 'History_CNN4_history_200_100_90_30_10': 0.8716999888420105,
 'History_CNN4_history_200_100_90_30_20': 0.8180000185966492,
 'History_CNN4_history_200_100_90_50_10': 0.8305000066757202,
 'History_CNN4_history_200_100_90_50_20': 0.8919000029563904,
 'History_CNN4_history_200_200_60_30_10': 0.8935999870300293,
 'History_CNN4_history_200_200_60_30_20': 0.8950999975204468,
 'History_CNN4_history_200_200_60_50_20': 0.7825999855995178,
 'History_CNN4_history_200_200_90_30_10': 0.8752999901771545,
 'History_CNN4_history_200_200_90_30_20': 0.8677999973297119,
 'History_CNN4_history_200_200_90_50_10': 0.8822000026702881,
 'History_CNN4_history_200_200_90_50_20': 0.8715000152587891,
 'History_CNN4_history_20_100_60_30_10': 0.843999981880188,
 'History_CNN4_history_20_100_60_30_20': 0.7784000039100647,
 'History_C

In [10]:
# Get the top three validation accuracies and the model names (hyper-parameters)
# for each of the three architectures, i.e. up to 9 (name, accuracy) tuples in total.
# Computed unconditionally so that an empty architecture yields [] instead of
# leaving the variable unbound (or stale from a previous run).

from collections import Counter

TOP_N = 3

CNN2_top3 = Counter(CNN2_history).most_common(TOP_N)
CNN3_top3 = Counter(CNN3_history).most_common(TOP_N)
CNN4_top3 = Counter(CNN4_history).most_common(TOP_N)


In [51]:
# most_common() returned (name, accuracy) tuples. Split them into parallel lists of
# short display labels and accuracies for plotting, e.g.
# 'History_CNN4_history_200_200_60_30_20' -> 'CNN4_200_200_60_30_20'.

def _clean_model_name(raw_name):
    """Turn a history-file stem into a short plot label.

    Strips a leading 'History_' (only if present) and removes '_history'.
    """
    prefix = 'History_'
    name = raw_name[len(prefix):] if raw_name.startswith(prefix) else raw_name
    return name.replace('_history', '')


def _labels_and_accuracies(top_items):
    """Split (name, accuracy) pairs into (cleaned labels, accuracies) lists."""
    labels = [_clean_model_name(name) for name, _ in top_items]
    accuracies = [acc for _, acc in top_items]
    return labels, accuracies


CNN2_list, CNN2_list_acc = _labels_and_accuracies(CNN2_top3)
CNN3_list, CNN3_list_acc = _labels_and_accuracies(CNN3_top3)
CNN4_list, CNN4_list_acc = _labels_and_accuracies(CNN4_top3)

In [52]:
# Bar plot of the (up to) 9 models selected above, comparing their validation
# accuracy. Bars are sorted in descending order (highest accuracy on the left).

bar_series = [
    ('2-Layer CNN', CNN2_list, CNN2_list_acc),
    ('3-Layer CNN', CNN3_list, CNN3_list_acc),
    ('4-Layer CNN', CNN4_list, CNN4_list_acc),
]

fig = go.Figure(data=[
    go.Bar(name=label, x=names, y=accs,
           texttemplate='%{y:.4f}',  # 4 decimals instead of the raw float repr
           textposition='auto')
    for label, names, accs in bar_series
])
fig.update_layout(
    xaxis={'categoryorder': 'total descending'},
    xaxis_title='Model (architecture and hyper-parameters)',
    yaxis_title='Validation Accuracy',
    title="CNN Grid Search Optimal Hyper-Parameters and Architecture",
)

fig.show()