In [1]:
# Load autoreload extension
%load_ext autoreload

# Set autoreload behavior
%autoreload 2
import os
import time
import argparse
import random
import numpy as np
from models.ProTACT import build_ProTACT
import tensorflow as tf
from configs.configs import Configs
from utils.read_data_pr import read_pos_vocab, read_word_vocab, read_prompts_we, read_essays_prompts, read_prompts_pos
from utils.general_utils import get_scaled_down_scores, pad_hierarchical_text_sequences, get_attribute_masks, load_word_embedding_dict, build_embedd_table
from evaluators.multitask_evaluator_all_attributes import Evaluator as AllAttEvaluator
from tensorflow import keras
import matplotlib.pyplot as plt
    

[nltk_data] Downloading package punkt to /Users/joohwan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /Users/joohwan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [2]:
# parser = argparse.ArgumentParser(description="ProTACT model")
# parser.add_argument('--test_prompt_id', type=int, default=1, help='prompt id of test essay set')
# parser.add_argument('--seed', type=int, default=12, help='set random seed')
# parser.add_argument('--model_name', type=str,
#                     choices=['ProTACT'],
#                     help='name of model')
# parser.add_argument('--num_heads', type=int, default=2, help='set the number of heads in Multihead Attention')
# parser.add_argument('--features_path', type=str, default='data/hand_crafted_v3.csv')

test_prompt_id = 1
seed = 1
num_heads = 2
features_path = '../data/hand_crafted_v3.csv'

np.random.seed(seed)
tf.random.set_seed(seed)
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)

print("Test prompt id is {} of type {}".format(test_prompt_id, type(test_prompt_id)))
print("Seed: {}".format(seed))

configs = Configs()

data_path = configs.DATA_PATH
train_path = data_path + str(test_prompt_id) + '/train.pk'
dev_path = data_path + str(test_prompt_id) + '/dev.pk'
test_path = data_path + str(test_prompt_id) + '/test.pk'
pretrained_embedding = configs.PRETRAINED_EMBEDDING
embedding_path = configs.EMBEDDING_PATH
readability_path = configs.READABILITY_PATH
prompt_path = configs.PROMPT_PATH
vocab_size = configs.VOCAB_SIZE
epochs = configs.EPOCHS
batch_size = configs.BATCH_SIZE
print("Numhead : ", num_heads, " | Features : ", features_path, " | Pos_emb : ", configs.EMBEDDING_DIM)

read_configs = {
    'train_path': train_path,
    'dev_path': dev_path,
    'test_path': test_path,
    'features_path': features_path,
    'readability_path': readability_path,
    'vocab_size': vocab_size
}

Test prompt id is 1 of type <class 'int'>
Seed: 1
Numhead :  2  | Features :  ../data/hand_crafted_v3.csv  | Pos_emb :  50


In [3]:
pos_vocab = read_pos_vocab(read_configs)
# read POS for prompts
prompt_pos_data = read_prompts_pos(prompt_path, pos_vocab) # for prompt POS embedding 

word_vocab = read_word_vocab(read_configs)
# read words for prompts 
prompt_data = read_prompts_we(prompt_path, word_vocab) # for prompt word embedding 

# read essays and prompts
train_data, dev_data, test_data = read_essays_prompts(read_configs, prompt_data, prompt_pos_data, pos_vocab) 

if pretrained_embedding:
    embedd_dict, embedd_dim, _ = load_word_embedding_dict(embedding_path)
    embedd_matrix = build_embedd_table(word_vocab, embedd_dict, embedd_dim, caseless=True)
    embed_table = [embedd_matrix]
else:
    embed_table = None

 prompt_pos size: 8
 prompt_words size: 8
 pos_x size: 9513
 readability_x size: 9513
 pos_x size: 1680
 readability_x size: 1680
 pos_x size: 1783
 readability_x size: 1783
Loading GloVe ...
OOV number =189, OOV ratio = 0.047262


In [4]:
import pandas as pd
print(len(train_data['features_x'][0]))
pd.DataFrame(train_data)
# embedd_dict

51


Unnamed: 0,essay_ids,pos_x,prompt_words,prompt_pos,readability_x,features_x,data_y,prompt_ids,max_sentnum,max_sentlen
0,7532,"[[2, 3, 4, 2, 5, 6, 2, 5, 4, 7, 3, 8], [9, 5, ...","[[662, 2552, 736, 281, 165, 319, 106, 4], [255...","[[7, 5, 10, 3, 12, 7, 5, 8], [5, 20, 3, 20, 5,...","[0.49193152567610715, 0.29664566400513215, 0.5...","[0.4506668631640999, 0.3840642800992836, 0.062...","[1, 1, -1, -1, -1, -1, 1, 1, 1]",3,97,50
1,7229,"[[2, 3, 4, 2, 5, 20, 2, 11, 7, 5, 4, 2, 5, 8],...","[[662, 2552, 736, 281, 165, 319, 106, 4], [255...","[[7, 5, 10, 3, 12, 7, 5, 8], [5, 20, 3, 20, 5,...","[0.5579662498866749, 0.37479701404401994, 0.59...","[0.4513708636511923, 0.3586686128111479, 0.085...","[3, 3, -1, -1, -1, -1, 3, 3, 2]",3,97,50
2,4154,"[[9, 4, 2, 3, 4, 2, 29, 3, 4, 29, 3, 10, 5, 4,...","[[218, 125, 4], [122, 72, 73, 340, 1007, 124, ...","[[5, 3, 8], [15, 20, 5, 20, 5, 3, 3, 16, 5, 8]...","[0.8612956765106549, 0.7423848792139746, 0.619...","[0.5640345883998679, 0.2858415150037477, 0.066...","[4, 4, 4, 4, 5, 4, -1, -1, -1]",2,97,50
3,17950,"[[12, 5, 10, 7, 13, 18, 24, 18, 3, 18, 20, 18,...","[[662, 248, 4], [133, 405, 1090, 2011, 4], [13...","[[7, 5, 8], [5, 10, 7, 5, 8], [7, 5, 5, 3, 4, ...","[0.20688685950461458, 0.1460481802533566, 0.35...","[0.21863912823111997, 0.1030009415010439, 0.04...","[11, 2, 2, -1, -1, 4, -1, -1, -1]",7,97,50
4,9785,"[[2, 5, 10, 2, 5, 4, 2, 5, 4, 14, 2, 11, 7, 12...","[[90, 271, 131, 84, 4], [190, 108, 150, 814, 8...","[[16, 7, 7, 5, 8], [18, 11, 5, 6, 11, 12, 3, 6...","[0.7090757475733677, 0.48555603009134607, 0.79...","[0.5980343980343981, 0.5082625502524476, 0.090...","[1, 1, -1, -1, -1, -1, 1, 1, 1]",4,97,50
...,...,...,...,...,...,...,...,...,...,...
9508,6877,"[[4, 2, 5, 4, 5, 11, 20, 11, 18, 16, 7, 4, 29,...","[[662, 2552, 736, 281, 165, 319, 106, 4], [255...","[[7, 5, 10, 3, 12, 7, 5, 8], [5, 20, 3, 20, 5,...","[0.6420710849952196, 0.4709075862939792, 0.767...","[0.5665003864259118, 0.4089167978068259, 0.093...","[2, 2, -1, -1, -1, -1, 2, 1, 1]",3,97,50
9509,3661,"[[5, 28, 2, 5, 4, 12, 7, 3, 4, 2, 7, 5, 8], [1...","[[218, 125, 4], [122, 72, 73, 340, 1007, 124, ...","[[5, 3, 8], [15, 20, 5, 20, 5, 3, 3, 16, 5, 8]...","[0.7470330548911367, 0.5322450678768516, 0.688...","[0.7336936048109188, 0.3952062286100636, 0.031...","[4, 5, 4, 4, 5, 4, -1, -1, -1]",2,97,50
9510,10364,"[[4, 15, 20, 11, 24, 5, 6, 11, 13, 18, 24, 4, ...","[[90, 271, 131, 84, 4], [190, 108, 150, 814, 8...","[[16, 7, 7, 5, 8], [18, 11, 5, 6, 11, 12, 3, 6...","[0.6775704345317742, 0.5015464633073149, 0.707...","[0.44895975111802455, 0.27287771704373115, 0.1...","[1, 1, -1, -1, -1, -1, 1, 1, 1]",4,97,50
9511,18367,"[[9, 5, 14, 5, 6, 14, 3, 22, 5, 6, 15, 8], [14...","[[662, 248, 4], [133, 405, 1090, 2011, 4], [13...","[[7, 5, 8], [5, 10, 7, 5, 8], [7, 5, 5, 3, 4, ...","[0.4465909693384573, 0.375902607097603, 0.5005...","[0.2542676650242186, 0.08406393382610401, 0.09...","[18, 5, 5, -1, -1, 4, -1, -1, -1]",7,97,50


In [5]:
max_sentlen = max(train_data['max_sentlen'], dev_data['max_sentlen'], test_data['max_sentlen'])
max_sentnum = max(train_data['max_sentnum'], dev_data['max_sentnum'], test_data['max_sentnum'])
prompt_max_sentlen = prompt_data['max_sentlen']
prompt_max_sentnum = prompt_data['max_sentnum']

print('max sent length: {}'.format(max_sentlen))
print('max sent num: {}'.format(max_sentnum))
print('max prompt sent length: {}'.format(prompt_max_sentlen))
print('max prompt sent num: {}'.format(prompt_max_sentnum))

train_data['y_scaled'] = get_scaled_down_scores(train_data['data_y'], train_data['prompt_ids'])
dev_data['y_scaled'] = get_scaled_down_scores(dev_data['data_y'], dev_data['prompt_ids'])
test_data['y_scaled'] = get_scaled_down_scores(test_data['data_y'], test_data['prompt_ids'])

X_train_pos = pad_hierarchical_text_sequences(train_data['pos_x'], max_sentnum, max_sentlen)
X_dev_pos = pad_hierarchical_text_sequences(dev_data['pos_x'], max_sentnum, max_sentlen)
X_test_pos = pad_hierarchical_text_sequences(test_data['pos_x'], max_sentnum, max_sentlen)

X_train_pos = X_train_pos.reshape((X_train_pos.shape[0], X_train_pos.shape[1] * X_train_pos.shape[2]))
X_dev_pos = X_dev_pos.reshape((X_dev_pos.shape[0], X_dev_pos.shape[1] * X_dev_pos.shape[2]))
X_test_pos = X_test_pos.reshape((X_test_pos.shape[0], X_test_pos.shape[1] * X_test_pos.shape[2]))

X_train_prompt = pad_hierarchical_text_sequences(train_data['prompt_words'], max_sentnum, max_sentlen)
X_dev_prompt = pad_hierarchical_text_sequences(dev_data['prompt_words'], max_sentnum, max_sentlen)
X_test_prompt = pad_hierarchical_text_sequences(test_data['prompt_words'], max_sentnum, max_sentlen)

X_train_prompt = X_train_prompt.reshape((X_train_prompt.shape[0], X_train_prompt.shape[1] * X_train_prompt.shape[2]))
X_dev_prompt = X_dev_prompt.reshape((X_dev_prompt.shape[0], X_dev_prompt.shape[1] * X_dev_prompt.shape[2]))
X_test_prompt = X_test_prompt.reshape((X_test_prompt.shape[0], X_test_prompt.shape[1] * X_test_prompt.shape[2]))

X_train_prompt_pos = pad_hierarchical_text_sequences(train_data['prompt_pos'], max_sentnum, max_sentlen)
X_dev_prompt_pos = pad_hierarchical_text_sequences(dev_data['prompt_pos'], max_sentnum, max_sentlen)
X_test_prompt_pos = pad_hierarchical_text_sequences(test_data['prompt_pos'], max_sentnum, max_sentlen)

X_train_prompt_pos = X_train_prompt_pos.reshape((X_train_prompt_pos.shape[0], X_train_prompt_pos.shape[1] * X_train_prompt_pos.shape[2]))
X_dev_prompt_pos = X_dev_prompt_pos.reshape((X_dev_prompt_pos.shape[0], X_dev_prompt_pos.shape[1] * X_dev_prompt_pos.shape[2]))
X_test_prompt_pos = X_test_prompt_pos.reshape((X_test_prompt_pos.shape[0], X_test_prompt_pos.shape[1] * X_test_prompt_pos.shape[2]))

X_train_linguistic_features = np.array(train_data['features_x'])
X_dev_linguistic_features = np.array(dev_data['features_x'])
X_test_linguistic_features = np.array(test_data['features_x'])

X_train_readability = np.array(train_data['readability_x'])
X_dev_readability = np.array(dev_data['readability_x'])
X_test_readability = np.array(test_data['readability_x'])

Y_train = np.array(train_data['y_scaled'])
Y_dev = np.array(dev_data['y_scaled'])
Y_test = np.array(test_data['y_scaled'])

X_train_attribute_rel = get_attribute_masks(Y_train)
X_dev_attribute_rel = get_attribute_masks(Y_dev)
X_test_attribute_rel = get_attribute_masks(Y_test)

print('================================')
print('X_train_pos: ', X_train_pos.shape)
print('X_train_prompt_words: ', X_train_prompt.shape)
print('X_train_prompt_pos: ', X_train_prompt_pos.shape)
print('X_train_readability: ', X_train_readability.shape)
print('X_train_ling: ', X_train_linguistic_features.shape)
print('X_train_attribute_rel: ', X_train_attribute_rel.shape)
print('Y_train: ', Y_train.shape)

print('================================')
print('X_dev_pos: ', X_dev_pos.shape)
print('X_dev_prompt_words: ', X_dev_prompt.shape)
print('X_dev_prompt_pos: ', X_dev_prompt_pos.shape)
print('X_dev_readability: ', X_dev_readability.shape)
print('X_dev_ling: ', X_dev_linguistic_features.shape)
print('X_dev_attribute_rel: ', X_dev_attribute_rel.shape)
print('Y_dev: ', Y_dev.shape)

print('================================')
print('X_test_pos: ', X_test_pos.shape)
print('X_test_prompt_words: ', X_test_prompt.shape)
print('X_test_prompt_pos: ', X_test_prompt_pos.shape)
print('X_test_readability: ', X_test_readability.shape)
print('X_test_ling: ', X_test_linguistic_features.shape)
print('X_test_attribute_rel: ', X_test_attribute_rel.shape)
print('Y_test: ', Y_test.shape)
print('================================')

max sent length: 50
max sent num: 97
max prompt sent length: 18
max prompt sent num: 8
X_train_pos:  (9513, 4850)
X_train_prompt_words:  (9513, 4850)
X_train_prompt_pos:  (9513, 4850)
X_train_readability:  (9513, 35)
X_train_ling:  (9513, 51)
X_train_attribute_rel:  (9513, 9)
Y_train:  (9513, 9)
X_dev_pos:  (1680, 4850)
X_dev_prompt_words:  (1680, 4850)
X_dev_prompt_pos:  (1680, 4850)
X_dev_readability:  (1680, 35)
X_dev_ling:  (1680, 51)
X_dev_attribute_rel:  (1680, 9)
Y_dev:  (1680, 9)
X_test_pos:  (1783, 4850)
X_test_prompt_words:  (1783, 4850)
X_test_prompt_pos:  (1783, 4850)
X_test_readability:  (1783, 35)
X_test_ling:  (1783, 51)
X_test_attribute_rel:  (1783, 9)
Y_test:  (1783, 9)


In [6]:
pd.DataFrame(train_data)

Unnamed: 0,essay_ids,pos_x,prompt_words,prompt_pos,readability_x,features_x,data_y,prompt_ids,max_sentnum,max_sentlen,y_scaled
0,7532,"[[2, 3, 4, 2, 5, 6, 2, 5, 4, 7, 3, 8], [9, 5, ...","[[662, 2552, 736, 281, 165, 319, 106, 4], [255...","[[7, 5, 10, 3, 12, 7, 5, 8], [5, 20, 3, 20, 5,...","[0.49193152567610715, 0.29664566400513215, 0.5...","[0.4506668631640999, 0.3840642800992836, 0.062...","[1, 1, -1, -1, -1, -1, 1, 1, 1]",3,97,50,"[0.3333333333333333, 0.3333333333333333, -1, -..."
1,7229,"[[2, 3, 4, 2, 5, 20, 2, 11, 7, 5, 4, 2, 5, 8],...","[[662, 2552, 736, 281, 165, 319, 106, 4], [255...","[[7, 5, 10, 3, 12, 7, 5, 8], [5, 20, 3, 20, 5,...","[0.5579662498866749, 0.37479701404401994, 0.59...","[0.4513708636511923, 0.3586686128111479, 0.085...","[3, 3, -1, -1, -1, -1, 3, 3, 2]",3,97,50,"[1.0, 1.0, -1, -1, -1, -1, 1.0, 1.0, 0.6666666..."
2,4154,"[[9, 4, 2, 3, 4, 2, 29, 3, 4, 29, 3, 10, 5, 4,...","[[218, 125, 4], [122, 72, 73, 340, 1007, 124, ...","[[5, 3, 8], [15, 20, 5, 20, 5, 3, 3, 16, 5, 8]...","[0.8612956765106549, 0.7423848792139746, 0.619...","[0.5640345883998679, 0.2858415150037477, 0.066...","[4, 4, 4, 4, 5, 4, -1, -1, -1]",2,97,50,"[0.6, 0.6, 0.6, 0.6, 0.8, 0.6, -1, -1, -1]"
3,17950,"[[12, 5, 10, 7, 13, 18, 24, 18, 3, 18, 20, 18,...","[[662, 248, 4], [133, 405, 1090, 2011, 4], [13...","[[7, 5, 8], [5, 10, 7, 5, 8], [7, 5, 5, 3, 4, ...","[0.20688685950461458, 0.1460481802533566, 0.35...","[0.21863912823111997, 0.1030009415010439, 0.04...","[11, 2, 2, -1, -1, 4, -1, -1, -1]",7,97,50,"[0.36666666666666664, 0.3333333333333333, 0.33..."
4,9785,"[[2, 5, 10, 2, 5, 4, 2, 5, 4, 14, 2, 11, 7, 12...","[[90, 271, 131, 84, 4], [190, 108, 150, 814, 8...","[[16, 7, 7, 5, 8], [18, 11, 5, 6, 11, 12, 3, 6...","[0.7090757475733677, 0.48555603009134607, 0.79...","[0.5980343980343981, 0.5082625502524476, 0.090...","[1, 1, -1, -1, -1, -1, 1, 1, 1]",4,97,50,"[0.3333333333333333, 0.3333333333333333, -1, -..."
...,...,...,...,...,...,...,...,...,...,...,...
9508,6877,"[[4, 2, 5, 4, 5, 11, 20, 11, 18, 16, 7, 4, 29,...","[[662, 2552, 736, 281, 165, 319, 106, 4], [255...","[[7, 5, 10, 3, 12, 7, 5, 8], [5, 20, 3, 20, 5,...","[0.6420710849952196, 0.4709075862939792, 0.767...","[0.5665003864259118, 0.4089167978068259, 0.093...","[2, 2, -1, -1, -1, -1, 2, 1, 1]",3,97,50,"[0.6666666666666666, 0.6666666666666666, -1, -..."
9509,3661,"[[5, 28, 2, 5, 4, 12, 7, 3, 4, 2, 7, 5, 8], [1...","[[218, 125, 4], [122, 72, 73, 340, 1007, 124, ...","[[5, 3, 8], [15, 20, 5, 20, 5, 3, 3, 16, 5, 8]...","[0.7470330548911367, 0.5322450678768516, 0.688...","[0.7336936048109188, 0.3952062286100636, 0.031...","[4, 5, 4, 4, 5, 4, -1, -1, -1]",2,97,50,"[0.6, 0.8, 0.6, 0.6, 0.8, 0.6, -1, -1, -1]"
9510,10364,"[[4, 15, 20, 11, 24, 5, 6, 11, 13, 18, 24, 4, ...","[[90, 271, 131, 84, 4], [190, 108, 150, 814, 8...","[[16, 7, 7, 5, 8], [18, 11, 5, 6, 11, 12, 3, 6...","[0.6775704345317742, 0.5015464633073149, 0.707...","[0.44895975111802455, 0.27287771704373115, 0.1...","[1, 1, -1, -1, -1, -1, 1, 1, 1]",4,97,50,"[0.3333333333333333, 0.3333333333333333, -1, -..."
9511,18367,"[[9, 5, 14, 5, 6, 14, 3, 22, 5, 6, 15, 8], [14...","[[662, 248, 4], [133, 405, 1090, 2011, 4], [13...","[[7, 5, 8], [5, 10, 7, 5, 8], [7, 5, 5, 3, 4, ...","[0.4465909693384573, 0.375902607097603, 0.5005...","[0.2542676650242186, 0.08406393382610401, 0.09...","[18, 5, 5, -1, -1, 4, -1, -1, -1]",7,97,50,"[0.6, 0.8333333333333334, 0.8333333333333334, ..."


In [None]:
# train_features_list = [X_train_pos, X_train_prompt, X_train_prompt_pos, X_train_linguistic_features, X_train_readability]
X_train_pos

In [None]:
Y_train[0]

In [None]:
import keras
print(keras.__version__)

In [10]:
train_features_list = [X_train_pos, X_train_prompt, X_train_prompt_pos, X_train_linguistic_features, X_train_readability]
dev_features_list = [X_dev_pos, X_dev_prompt, X_dev_prompt_pos, X_dev_linguistic_features, X_dev_readability]
test_features_list = [X_test_pos, X_test_prompt, X_test_prompt_pos, X_test_linguistic_features, X_test_readability]

model = build_ProTACT(len(pos_vocab), len(word_vocab), max_sentnum, max_sentlen, 
                  X_train_readability.shape[1],
                  X_train_linguistic_features.shape[1],
                  configs, Y_train.shape[1], num_heads, embed_table)

evaluator = AllAttEvaluator(test_prompt_id, dev_data['prompt_ids'], test_data['prompt_ids'], dev_features_list,
                            test_features_list, Y_dev, Y_test, seed)

evaluator.evaluate(model, -1, print_info=True)



[1m53/53[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 496ms/step
[1m56/56[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 230ms/step
CURRENT EPOCH: -1
[DEV] AVG QWK: 0.004
[DEV] score QWK: 0.083
[DEV] content QWK: 0.109
[DEV] organization QWK: -0.128
[DEV] word_choice QWK: -0.086
[DEV] sentence_fluency QWK: 0.036
[DEV] conventions QWK: -0.081
[DEV] prompt_adherence QWK: 0.063
[DEV] language QWK: 0.018
[DEV] narrativity QWK: 0.021
------------------------
[TEST] AVG QWK: -0.0
[TEST] score QWK: 0.185
[TEST] content QWK: 0.021
[TEST] organization QWK: -0.096
[TEST] word_choice QWK: -0.019
[TEST] sentence_fluency QWK: -0.02
[TEST] conventions QWK: -0.072
------------------------
[BEST TEST] AVG QWK: -0.0, {epoch}: -1
[BEST TEST] score QWK: 0.185
[BEST TEST] content QWK: 0.021
[BEST TEST] organization QWK: -0.096
[BEST TEST] word_choice QWK: -0.019
[BEST TEST] sentence_fluency QWK: -0.02
[BEST TEST] conventions QWK: -0.072
-------------------------------------------------

In [None]:
model.summary()

In [None]:
# class CustomHistory(keras.callbacks.Callback):
#     def init(self):
#         self.train_loss = []
#         self.val_loss = []
#         self.train_acc = []
#         self.val_acc = []        
        
#     def on_epoch_end(self, batch, logs={}):
#         self.train_loss.append(logs.get('loss'))
#         self.val_loss.append(logs.get('val_loss'))
#         self.train_acc.append(logs.get('acc'))
#         self.val_acc.append(logs.get('val_acc'))
# custom_hist = CustomHistory()
# custom_hist.init() 

#  for ii in range(epochs):
#     print('Epoch %s/%s' % (str(ii + 1), epochs))
#     start_time = time.time()
#     model.fit(
#         train_features_list,
#         Y_train, batch_size=batch_size, epochs=5, verbose=0, shuffle=True, validation_data=(dev_features_list,Y_dev),callbacks=[custom_hist,checkpoint])
#     tt_time = time.time() - start_time
#     print("Training one epoch in %.3f s" % tt_time)
#     evaluator.evaluate(model, ii + 1)
#     print("Train Loss: ", custom_hist.train_loss[-1], "|| Val Loss: ", custom_hist.val_loss[-1])

# evaluator.print_final_info()

'''# show the loss as the graph
fig, loss_graph = plt.subplots()
loss_graph.plot(custom_hist.train_loss,'y',label='train loss')
loss_graph.plot(custom_hist.val_loss,'r',label='val loss')
loss_graph.set_xlabel('epoch')
loss_graph.set_ylabel('loss')
plt.savefig(str('images/protact/test_prompt_'+ str(test_prompt_id) + '_seed_' + str(seed) + '_loss.png'))'''

In [None]:
# 저장한 체크포인트 있다면: 이어서 학습 시작할때 
# Checkpoint 폴더 안에 있는 .h5 파일 지울것
# model.load_weights('Checkpoint/tensor{epoch}')

checkpoint = tf.keras.callbacks.ModelCheckpoint(
    # epoch 마다 파일명 다르게 저장
    filepath='Checkpoint/bestmodel{epoch}.weights.h5',

    # epoch 마다 weights 들만 저장
    save_freq='epoch',
    save_weights_only = True,

    # validation accruary 가 최대일때만 저장 
    monitor='val_loss',
    mode='min'
)

In [None]:
class CustomHistory(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.train_loss = []
        self.val_loss = []
        self.epoch_times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.train_loss.append(logs.get('loss'))
        self.val_loss.append(logs.get('val_loss'))
        epoch_time = time.time() - self.start_time
        self.epoch_times.append(epoch_time)
        print(f"Epoch {epoch + 1}: Train Loss: {logs.get('loss')} || Val Loss: {logs.get('val_loss')}")
        print(f"Epoch {epoch + 1} completed in {epoch_time:.3f} seconds")

        # Evaluate the model (you might need to adjust this to your specific evaluation function)
        evaluator.evaluate(self.model, epoch + 1)

custom_hist = CustomHistory()
model.fit(
    train_features_list,
    Y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    shuffle=True,
    validation_data=(dev_features_list, Y_dev),
    callbacks=[custom_hist, checkpoint]
)

In [None]:
# 실행 X
# TEST: 위에서 진행된곳 까지 결과가 같다.
model.load_weights('Checkpoint/bestmodel1.h5')
evaluator.evaluate(model,1)