In [None]:
import json
from collections import defaultdict
from itertools import product

import numpy as np
import pandas as pd
from IPython.display import clear_output

import data_handler as dh
import utils

### Data setup

In [None]:
data_path = '/home/donghyun/eye_writing_classification/v2_dataset/200_points_dataset/'


def _read_json(file_name):
    """Open ``data_path + file_name`` and return its parsed JSON content."""
    with open(data_path + file_name) as fp:
        return json.load(fp)


eog_raw_numbers = _read_json('eog_raw_numbers_200.json')
eog_katakana = _read_json('eog_katakana_200.json')
reference_data = _read_json('reference_data_200.json')

In [None]:
# Katakana class labels '1'..'12' are shifted to '10'..'21' so they sit after the
# number classes in one shared label space (presumably numbers use '0'..'9').
katakana_le = {str(n): str(n + 9) for n in range(1, 13)}

# Samples for every class: number classes first, then the relabelled katakana classes.
all_data = eog_raw_numbers.copy()
for k, samples in eog_katakana.items():
    cnvt_key = katakana_le[k]
    # Refuse to silently overwrite an existing class with a katakana one.
    assert cnvt_key not in all_data, 'label collision in all_data: {}'.format(cnvt_key)
    all_data[cnvt_key] = samples

# Reference entries for every class under the 'all' key. A plain dict is used:
# defaultdict() without a default factory behaves exactly like dict anyway.
all_ref = {'all': reference_data['numbers'].copy()}
for k, ref in reference_data['katakana'].items():
    cnvt_key = katakana_le[k]
    assert cnvt_key not in all_ref['all'], 'label collision in all_ref: {}'.format(cnvt_key)
    all_ref['all'][cnvt_key] = ref


In [None]:
save_path = '/home/donghyun/eye_writing_classification/experiments/save/'

# Tuned ViT hyperparameters; each row is later used as one configuration.
hyperparam_csv = save_path + 'experiment2_vit_hyperparams.csv'
best_perform_df = pd.read_csv(hyperparam_csv)

### Experiment

In [None]:
class ViT_Config:
    """Settings for one ViT training/evaluation run in this experiment.

    Class attributes are the defaults shared by every run. ``ViT_params`` holds the
    model hyperparameters and is a separate dict on every instance.
    """
    split_ratio = 0.3
    ref_key = 'all'            # presumably selects all_ref['all'] — TODO confirm in utils
    batch_size = 22            # fixed: must equal the number of test pairs
    n_batch = 50
    model_type = 'ViTBaseModel'
    ViT_params = {}            # class-level default, kept for backward compatibility
    epochs = 1000

    def __init__(self):
        # A fresh dict per instance, so in-place edits such as
        # cfg.ViT_params['x'] = v never leak into other configs or the class default.
        self.ViT_params = {}

In [None]:
# One config per hyperparameter row 0..4 of best_perform_df.
# NOTE(review): assumes those rows are the best-performing ones (CSV sorted best-first) — confirm.
cfg_list = []
for row_idx in range(5):
    row_params = best_perform_df.loc[row_idx]
    # NOTE(review): pandas reads a list column back from CSV as its string repr
    # (e.g. "[2048, 1024]"); int() over its characters would fail — verify the stored format.
    row_params['mlp_units'] = [int(unit) for unit in row_params['mlp_units']]
    run_cfg = ViT_Config()
    run_cfg.ViT_params = row_params
    cfg_list.append(run_cfg)

In [None]:
# Leave-one-class-out zero-shot evaluation. For every class:
#   1. train each tuned config on all the OTHER classes;
#   2. score the held-out class with the summed predictions of those models.
classes = list(all_data)  # snapshot of the class labels

result_dict = {}
for k in classes:
    zero_shot_cls = k
    zero_shot_data = all_data[zero_shot_cls].copy()
    zero_shot_ref = all_ref['all'][zero_shot_cls].copy()

    # Training set = every class except the held-out one. These are new dicts, so
    # all_data / all_ref stay complete for the following iterations.
    # NOTE(review): learn_ref keeps all_ref's {'all': {...}} layout because
    # cfg.ref_key == 'all'; confirm this is the structure utils.experiment expects.
    learn_data = {cls: samples for cls, samples in all_data.items() if cls != zero_shot_cls}
    learn_ref = {'all': {cls: ref for cls, ref in all_ref['all'].items() if cls != zero_shot_cls}}

    # train without a class for zero shot learning
    model_list = []
    for cfg in cfg_list:
        model, _, _, _ = utils.experiment(cfg, learn_data, learn_ref)
        model_list.append(model)

    # zero shot inference
    # NOTE(review): get_test_batch receives ref_key='all' together with a single-class
    # ref; confirm which structure data_handler expects here.
    zero_shot_batch, zero_shot_targets = dh.get_test_batch(zero_shot_data, zero_shot_ref, ref_key='all')

    correct = 0
    n_batches = 0
    for data, target in zip(zero_shot_batch, zero_shot_targets):
        # Ensemble: add up the outputs of EVERY trained model, not just the last one.
        probs = np.zeros((data.shape[0], 1))
        for m in model_list:
            probs += m.predict_on_batch(data)

        if np.argmax(probs) == np.argmax(target):
            correct += 1
        n_batches += 1

    # An empty test set yields NaN instead of a ZeroDivisionError.
    acc = (correct / n_batches) * 100 if n_batches else float('nan')
    result_dict[k] = acc

    print('class : {}\'s accuracy using ensemble : {:.4f}%'.format(k, acc))

# Clear the cell output (presumably dominated by training logs), then re-print
# the per-class summary so the results stay visible.
clear_output()
for k, acc in result_dict.items():
    print('class : {}\'s accuracy using ensemble : {:.4f}%'.format(k, acc))

### Save

In [None]:
save_path = '/home/donghyun/eye_writing_classification/experiments/save/'

# Persist the per-class zero-shot accuracies (percent) as a plain JSON object.
results_file = save_path + 'experiment4_zero_shot_results.json'
with open(results_file, 'w') as out_fp:
    json.dump(dict(result_dict), out_fp)

### Visualization

In [None]:
save_path = '/home/donghyun/eye_writing_classification/experiments/save/'

# Read back the per-class accuracies written by the Save cell.
# json.load parses the file's contents; json.dumps(path) would only
# serialize the path string itself.
with open(save_path + 'experiment4_zero_shot_results.json') as f:
    zero_shot_dict = json.load(f)