## Experiment 2
<br>
-- ViT base model: hyperparameter exploration


In [None]:
import json
import re
import random
import os
from collections import defaultdict
from itertools import product

import tensorflow as tf
import pandas as pd
import numpy as np

import utils

In [None]:
def seed_everything(seed: int = 42):
    """Seed the random-number sources used by this notebook.

    Seeds Python's ``random`` module and NumPy's global RNG, exports the
    seed as the ``PYTHONHASHSEED`` environment variable, then sets
    TensorFlow's global seed.

    NOTE(review): ``PYTHONHASHSEED`` is read when an interpreter starts, so
    setting it here presumably only affects child processes, not hash
    randomisation in the running kernel -- confirm if that matters.

    Args:
        seed: Value passed to every generator (default 42).
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    tf.random.set_seed(seed)


seed_everything()

### Data Load

In [None]:
# Directory holding the 200-point dataset files (absolute, machine-specific path).
data_path = '/home/donghyun/eye_writing_classification/v2_dataset/200_points_dataset/'

with open(os.path.join(data_path, 'eog_raw_numbers_200.json')) as f:
    eog_raw_numbers = json.load(f)

with open(os.path.join(data_path, 'reference_data_200.json')) as f:
    reference_data = json.load(f)

### Experiment

In [None]:
# hyperparameters
# Candidate values for the ViT base model grid search. The full Cartesian
# product has 3 * 2 * 2 * 2 * 2 * 1 * 1 = 48 combinations.

vit_hidden_size = [128, 256, 512]
vit_patch_size = [5, 10]
vit_heads = [4, 8]
vit_n_layers = [8, 12]
vit_mlp_units = [
    [128, 64],
    [64, 32],
]
vit_dropout = [0]        # single value: dropout is not searched
vit_mlp_dropout = [0]    # single value: MLP dropout is not searched

In [None]:
# Config class

class Config:
    """Experiment settings handed to ``utils.experiment`` during the grid search.

    The notebook passes the class object itself (not an instance) as the
    config, so ``ViT_params`` is a class-level dict that the grid-search
    loop fills in for each trial.
    """

    split_ratio = 0.3
    ref_key = 'numbers'
    batch_size = 10            # fixed value; per the original note it must equal the number of test pairs
    n_batch = 180
    lr = 5e-4                  # presumably the learning rate -- confirm in utils.experiment
    model_type = 'ViTBaseModel'
    ViT_params = {}            # per-trial model parameters, assigned by the caller
    epochs = 100

In [None]:
# grid search for hyperparameters
#
# Trains one model per combination of the hyperparameter grid and records
# each configuration together with its score in `best_perform_df`,
# sorted best-first.

cols = ['hidden_size', 'batch_size', 'patch_size', 'heads', 'n_layers', 'mlp_units', 'dropout', 'mlp_dropout', 'score']
best_perform_df = pd.DataFrame(columns=cols)

search_space = product(vit_hidden_size,
                       vit_patch_size,
                       vit_heads,
                       vit_n_layers,
                       vit_mlp_units,
                       vit_dropout,
                       vit_mlp_dropout)

# i is the 1-based trial number; it becomes the DataFrame index.
for i, (hs, ps, heads, n_layers, mlp_units, dropout, mlp_dropout) in enumerate(search_space, start=1):
    print('index : ', i)

    # The Config class object itself serves as the config. Give it a fresh
    # params dict each trial instead of mutating the shared class-level dict
    # in place, so anything that kept a reference to an earlier trial's dict
    # does not see it silently change.
    cfg = Config
    cfg.ViT_params = {
        'hidden_size': hs,
        'batch_size': cfg.batch_size,
        'patch_size': ps,
        'heads': heads,
        'n_layers': n_layers,
        'mlp_units': mlp_units,
        'dropout': dropout,
        'mlp_dropout': mlp_dropout,
    }

    _, _, _, test_acc_list = utils.experiment(cfg, eog_raw_numbers, reference_data)
    # Mean of the last three test-accuracy values (presumably the last three
    # epochs), which damps single-epoch noise when ranking configurations.
    score = np.mean(test_acc_list[-3:])

    # mlp_units is stored as its string form so the table round-trips via CSV.
    best_perform_df.loc[i] = [hs, cfg.batch_size, ps, heads, n_layers, str(mlp_units), dropout, mlp_dropout, score]

# Rows were added to an initially empty frame, so the columns start out as
# object dtype; infer_objects() restores numeric dtypes before sorting.
best_perform_df = best_perform_df.infer_objects().sort_values(by='score', ascending=False)

In [None]:
# hyperparameters save
# Persist the sorted grid-search table. index=True writes the DataFrame
# index (the trial number) as the first CSV column.

save_path = '/home/donghyun/eye_writing_classification/experiments/save/'
best_perform_df.to_csv(os.path.join(save_path, 'experiment2_vit_hyperparams.csv'), index=True)

In [None]:
# load the hyperparameters
#
# The CSV is written with index=True, so its first column is the saved
# DataFrame index. index_col=0 restores it as the index. Without it, the
# column comes back as an extra "Unnamed: 0" data column that would be
# picked up with the best row's parameters.

save_path = '/home/donghyun/eye_writing_classification/experiments/save/'
best_perform_df = pd.read_csv(os.path.join(save_path, 'experiment2_vit_hyperparams.csv'), index_col=0)

best_perform_df.head(10)

In [None]:
class Config:
    """Settings for the repeated runs with the best grid-search parameters.

    ``ViT_params`` is replaced below with the parameters of the best row of
    ``best_perform_df``.
    """
    split_ratio = 0.3
    ref_key = 'numbers'
    # Fixed value. NOTE(review): the original note here says this does NOT
    # have to equal the number of test pairs, whereas the grid-search cell's
    # note says it must -- confirm which is correct.
    batch_size = 10
    n_batch = 180
    lr = 0.0005
    model_type = 'ViTBaseModel'
    ViT_params = {}
    epochs = 100

# The parameter keys the grid search put into ViT_params. Other columns of a
# results row ('score', or 'Unnamed: 0' when the CSV was read without
# index_col) are bookkeeping rather than model parameters, so they are left out.
VIT_PARAM_KEYS = ('hidden_size', 'batch_size', 'patch_size', 'heads',
                  'n_layers', 'mlp_units', 'dropout', 'mlp_dropout')

# The first row holds the best configuration, because the table is sorted by
# score in descending order before it is saved.
best_row = best_perform_df.iloc[0].to_dict()
best_params = {key: best_row[key] for key in VIT_PARAM_KEYS}

# mlp_units is stored as str(list), e.g. "[128, 64]", which is valid JSON;
# parse it back into a list of ints.
if isinstance(best_params['mlp_units'], str):
    best_params['mlp_units'] = json.loads(best_params['mlp_units'])

cfg = Config
cfg.ViT_params = best_params

# Retrain the best configuration `times` times to gauge run-to-run variability.
times = 10
raw_numbers_dict = defaultdict(list)
for t in range(times):
    _, raw_train_acc, raw_train_loss, raw_test_acc = utils.experiment(cfg, eog_raw_numbers, reference_data)
    # run index -> [train accuracy, train loss, test accuracy]
    raw_numbers_dict[t] = [raw_train_acc, raw_train_loss, raw_test_acc]


### Save

In [None]:
save_path = '/home/donghyun/eye_writing_classification/experiments/save/'

# JSON object keys are always strings, so the integer run indices are
# written as "0", "1", ...; the visualization cell reads them back that way.
with open(os.path.join(save_path, 'experiment2_raw_numbers_results.json'), 'w') as f:
    json.dump(dict(raw_numbers_dict), f)

### Visualization

In [None]:
save_path = '/home/donghyun/eye_writing_classification/experiments/save/'

# Repeated-run results of the hybrid model (experiment 1) and of the ViT
# model (this notebook), keyed by run index as a string.
with open(os.path.join(save_path, 'experiment1_raw_numbers_results.json')) as f:
    hybrid_raw_numbers_results = json.load(f)

with open(os.path.join(save_path, 'experiment2_raw_numbers_results.json')) as f:
    vit_raw_numbers_results = json.load(f)

In [None]:
# Collect the test-accuracy history (element 2 of each stored run) for
# runs "0".."9" of both models.
hybrid_test_acc, vit_test_acc = [], []
for run in map(str, range(10)):
    hybrid_test_acc.append(hybrid_raw_numbers_results[run][2])
    vit_test_acc.append(vit_raw_numbers_results[run][2])

# Rows are runs, so averaging over axis 0 gives the mean value at each
# position (presumably each epoch -- the plot labels its x-axis 'Epoch').
hybrid_avg_results = np.array(hybrid_test_acc).mean(axis=0)
vit_avg_results = np.array(vit_test_acc).mean(axis=0)

In [None]:
def analysis(data_list):
    """Summarise a sequence of scores.

    Args:
        data_list: Non-empty sequence of numbers, e.g. the final test
            accuracies of repeated runs.

    Returns:
        ``(mean, max, min, std)`` as plain Python floats. ``std`` is the
        population standard deviation (NumPy's default ``ddof=0``).

    Raises:
        ValueError: if ``data_list`` is empty.
    """
    values = np.asarray(data_list, dtype=float)
    if values.size == 0:
        raise ValueError('analysis() requires at least one value')
    # Return plain floats: the result is printed inside a tuple, and with
    # NumPy >= 2 scalar reprs would show up as np.float64(...).
    return float(values.mean()), float(values.max()), float(values.min()), float(values.std())

# Last recorded test accuracy of every run (presumably the final epoch).
hybrid_numbers_test_performance = [history[-1] for history in hybrid_test_acc]
vit_numbers_test_performance = [history[-1] for history in vit_test_acc]

# Hand-aligned table: one column per run, then the analysis() summary.
print('Accuracy base on raw numbers with 10 repetitions')
print(' '*29 +'1,     2,    3,      4,      5,      6,     7,     8,     9,     10,       Avg.   Best.   Worst.  Std.')
print(f'hybrid model performance : {hybrid_numbers_test_performance}, {analysis(hybrid_numbers_test_performance)}')
print(f'ViT model performance    : {vit_numbers_test_performance}, {analysis(vit_numbers_test_performance)}')

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(20, 8))
acc_ax = axes[0]

# Run-averaged test-accuracy curve of each model.
acc_ax.plot(hybrid_avg_results, c='b', linestyle='solid', linewidth=3)
acc_ax.plot(vit_avg_results, c='r', linestyle='solid', linewidth=3)
acc_ax.set_ylim(20, 100)

acc_ax.set_title("Evaluation", fontsize=20)
acc_ax.set_xlabel('Epoch', fontsize=20)
acc_ax.set_ylabel('Accuracy', fontsize=20)
acc_ax.legend(['Hybrid base model', 'ViT base model'], fontsize=15)

# NOTE(review): the right-hand subplot (axes[1]) is created but never drawn
# on, so it renders as an empty frame -- remove it or add the intended plot.
plt.show()