# 1. Import Packages

In [None]:
import tensorflow as tf 
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from sklearn.model_selection import train_test_split

import pickle
import numpy as np
import re
import os

from osAdvanced import File_Control
from ProgressBar import Progress_Bar

import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# 2. 테스트 데이터 로드
train 데이터와 validation 데이터는 평가에 필요 없으므로 분할 후 메모리에서 삭제한다.

## 2.1 테스트 데이터 로드 및 분할

In [None]:
# Load the preprocessed dataset and recreate the train/validation/test split.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles from a trusted source.
with open('/raid/korean_food_pkl/preprocessed_data_0603_ResNet50V2_cropped.pkl', 'rb') as pkl_file:
    data = pickle.load(pkl_file)  # deserializes the whole pickled object in one call

# The first two entries of the pickled object are used as inputs and labels.
x, y = data[0], data[1]

# 70 % train; the remaining 30 % is halved into validation and test.
# The fixed random_state makes both splits reproducible.
x_train, x_valtest, y_train, y_valtest = train_test_split(
    x, y, test_size=0.3, random_state=1
)
x_val, x_test, y_val, y_test = train_test_split(
    x_valtest, y_valtest, test_size=0.5, random_state=1
)

print("train size : ", y_train.shape[0])
print("test size : ", y_test.shape[0])
print("validation size : ", y_val.shape[0])

# Only the test split is used afterwards; drop the other split variables.
# NOTE(review): `data`, `x`, `y`, `x_valtest` and `y_valtest` are still
# referenced after this cell, so the full dataset stays in memory.
del x_train, y_train
del x_val, y_val

## 2.2 라벨 정보 불러오기

In [None]:
# Build `label_dict`: class index (as a string) -> label text taken from the
# directory part of a file path.
# NOTE(review): assumes every entry of `dataset` is a sequence whose first
# item is a file path, and that the entry order matches the class indices
# encoded in the label arrays — confirm against File_Control.
dataset = File_Control.searchAllFilesInDirectoryByDir("/raid/korean_food_cropped/", "jpg")

# Raw string, and no backslash before "/": the original "\/" is an invalid
# string escape (a warning today, SyntaxWarning on Python 3.12+), and "/"
# needs no escaping in a Python regex. The compiled pattern is equivalent.
label_pattern = re.compile(r"/[가-힣]*/.*/")

label_dict = {}
for class_index, entry in enumerate(dataset):
    # First "/<Hangul>/<...>/" span in the path. A path that does not match
    # makes .search() return None and this line raise AttributeError.
    matched = label_pattern.search(entry[0]).group()
    # Turn the path separators into "|" and strip the leading/trailing one.
    label_dict[str(class_index)] = matched.replace("/", "|")[1:-1]

## 2.3 데이터 확인
데이터를 제대로 불러왔는지 확인

In [None]:
import random

# Show a grid of randomly chosen test images, each titled with its label.

# Font file used for the subplot titles.
path = '/usr/share/fonts/truetype/nanum/NanumGothic.ttf'
prop = fm.FontProperties(fname=path, size=18)

columns = 4
rows = 2
fig = plt.figure(figsize=(20, 10))

ax = []
for i in range(columns * rows):
    # randrange excludes the upper bound. randint(0, len(y_test)) includes it
    # and would occasionally index one past the end (IndexError).
    img_index = random.randrange(len(y_test))
    # Adds the constant 100/255 (~0.39) to every pixel before display.
    # NOTE(review): presumably meant to undo input preprocessing so the image
    # is viewable — confirm the offset matches the preprocessing actually used.
    img = x_test[img_index] + 100 / 255
    ax.append(fig.add_subplot(rows, columns, i + 1))
    # Labels are indexed by argmax, i.e. treated as one-hot/score vectors.
    y_str = str(np.argmax(y_test[img_index]))
    ax[-1].set_title(label_dict[y_str], fontproperties=prop)
    plt.imshow(img)  # draws into the subplot that was just added
plt.show()


# 3. 모델 로드
2번 GPU 사용.

In [None]:
# Expose only the GPU with index 2 to CUDA.
# NOTE(review): this takes effect only if it runs before TensorFlow first
# initializes its GPU devices — confirm nothing touches the GPU earlier.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Input size fed to the network: height, width, channels.
IMAGE_SHAPE = (224, 224, 3)

# ResNet50V2 with its classification head, 150 output classes, and no
# pretrained weights (weights=None).
model = ResNet50V2(
    weights=None,
    include_top=True,
    classes=150,
    input_shape=IMAGE_SHAPE,
)

model.summary()

# 4. 모델 가중치 로드

In [None]:
# Checkpoint prefix holding the trained weights for the model built above.
WEIGHTS_PATH = "./checkpoints_0609_ResNetV2_Original_cropped/ckpt"

# Restore the weights first, then compile so evaluate() can report
# loss and accuracy.
model.load_weights(WEIGHTS_PATH)

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# 5. 모델 평가
테스트 데이터셋으로 모델을 평가함.

In [None]:
# Evaluate the compiled model on the held-out test split and keep the result.
# NOTE(review): the name `eval` shadows Python's builtin eval() for the rest
# of the session; consider renaming once all uses of this variable are known.
eval = model.evaluate(x_test, y_test)

# 6. 테스트 결과 확인

In [None]:
# Run the model on every test sample; later cells index the result per sample
# (y_predicted[img_index]) to rank the predicted classes.
y_predicted = model.predict(x_test)

## 6.1 Array 내의 값 중 가장 큰 값 n개의 인덱스를 추출하는 함수

In [None]:
def argmax_top_n(array, n):
    """Return the indices of the ``n`` largest values, best first.

    Args:
        array: 1-D sequence or NumPy array of comparable values
            (e.g. the class scores predicted for one sample).
        n: Number of indices to return. Values larger than the array
            length are clamped to the array length.

    Returns:
        A 1-D NumPy integer array of distinct indices into ``array``,
        ordered from the largest value to the smallest. Empty when
        ``n <= 0`` or ``array`` is empty.
    """
    values = np.asarray(array)
    k = min(n, values.size)
    if k <= 0:
        return np.array([], dtype=np.intp)

    # argpartition finds the k largest entries in O(len) but leaves them in
    # no particular order; the indices it returns are already distinct, so
    # no value -> index lookup (which breaks on ties) is needed.
    candidates = np.argpartition(values, -k)[-k:]

    # Sort only those k candidates, then reverse to get descending order.
    descending = np.argsort(values[candidates])[::-1]
    return candidates[descending]

## 6.2 테스트 결과 출력
예측 결과 중 확률이 가장 높은 클래스 2개만 출력.

In [None]:
import random

# Show random test images titled with the true label and the top-n predicted
# labels; the title is green when the true class is among them, red otherwise.

# Font file used for the subplot titles.
path = '/usr/share/fonts/truetype/nanum/NanumGothic.ttf'
prop = fm.FontProperties(fname=path, size=18)

print(y_predicted.shape)

columns = 5
rows = 2
fig = plt.figure(figsize=(20, 10))

n = 2  # number of predicted classes listed per image

ax = []
plt.subplots_adjust(left=0.125, bottom=0.1, right=0.9, top=0.9, wspace=0.2, hspace=0.3)
for i in range(columns * rows):
    # randrange excludes the upper bound. randint(0, len(x_test)) includes it
    # and would occasionally index one past the end (IndexError).
    img_index = random.randrange(len(x_test))
    # Adds the constant 100/255 (~0.39) to every pixel before display.
    # NOTE(review): presumably meant to undo input preprocessing — confirm.
    img = x_test[img_index] + 100 / 255

    ax.append(fig.add_subplot(rows, columns, i + 1))

    true_index = np.argmax(y_test[img_index])
    result_topn = argmax_top_n(y_predicted[img_index], n)[:n]

    # First line is the ground truth ("정답"), then one numbered line per
    # returned class index, in the order argmax_top_n returned them.
    title_lines = ["정답 : " + label_dict[str(true_index)]]
    for rank, class_index in enumerate(result_topn, start=1):
        title_lines.append(str(rank) + " : " + label_dict[str(class_index)])
    result_str = "\n".join(title_lines)

    # Compare class indices directly; counting occurrences of the label text
    # in the title would misfire if one label were a substring of another.
    c = 'g' if true_index in result_topn else 'r'
    ax[-1].set_title(result_str, fontproperties=prop, color=c)
    plt.imshow(img)  # draws into the subplot that was just added
plt.show()