In [22]:
import os, os.path
import numpy as np
import cv2

from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import DataLoader

from efficientnet_pytorch import EfficientNet
from tqdm.auto import tqdm

import ttach as tta

import time

In [23]:
# Directory holding the test images and the ground-truth file `hash_table.txt`,
# relative to this notebook's working directory.
test_data_path = "../test_dataset"

In [24]:
class EffNetModel(nn.Module):
    """Thin ``nn.Module`` wrapper around an ``efficientnet_pytorch`` EfficientNet.

    Args:
        model_name: EfficientNet variant name, e.g. ``'efficientnet-b0'``.
        num_classes: number of output units of the classifier head
            (150 by default, the original hard-coded value).
        pretrained: when True (default, original behaviour) the backbone is
            built with ``EfficientNet.from_pretrained``; when False it is built
            with ``EfficientNet.from_name`` (random init), which skips fetching
            pretrained weights when a fine-tuned ``state_dict`` is loaded
            immediately afterwards anyway.
    """

    def __init__(self, model_name, num_classes=150, pretrained=True):
        super().__init__()
        if pretrained:
            self.backbone = EfficientNet.from_pretrained(model_name, num_classes=num_classes)
        else:
            self.backbone = EfficientNet.from_name(model_name, num_classes=num_classes)

    def forward(self, x):
        """Return the backbone's raw class scores (one row per sample in ``x``)."""
        return self.backbone(x)

In [25]:
# Deterministic preprocessing for every test image: resize to 256x256, keep the
# central 224x224 crop, normalise with the standard ImageNet channel mean/std,
# then convert the HWC numpy array into a CHW torch tensor.
test_transforms = A.Compose(transforms=[
    A.Resize(256, 256),
    A.CenterCrop(224, 224),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Test-time augmentation (TTA) grid. Iterating a ttach Compose yields one
# transformer per combination of parameters, so each image produces
# 2 (rotation 0/90) x 2 (flip off/on) x 2 (factor 0.7/1) = 8 variants.
# NOTE(review): Multiply is applied to the already-normalised tensor, so the
# 0.7 factor pulls values toward the normalisation mean (a contrast change)
# rather than darkening the raw image — confirm this is intended.
tta_transforms = tta.Compose(transforms=[
    tta.Rotate90(angles=[0, 90]),
    tta.HorizontalFlip(),
    tta.Multiply(factors=[0.7, 1]),
])

In [26]:
# Class-index -> class-name lookup. predict() maps the model's arg-max output
# index i to num_to_class[i], so this order must match the label encoding the
# checkpoints were trained with. Entries are only split across lines (with
# their index ranges) for readability — do not reorder.
num_to_class = [
    # 0-13
    '갈비구이', '갈치구이', '고등어구이', '곱창구이', '닭갈비', '더덕구이', '떡갈비',
    '불고기', '삼겹살', '장어구이', '조개구이', '조기구이', '황태구이', '훈제오리',
    # 14-21
    '계란국', '떡국_만두국', '무국', '미역국', '북엇국', '시래기국', '육개장', '콩나물국',
    # 22-28
    '과메기', '양념치킨', '젓갈', '콩자반', '편육', '피자', '후라이드치킨',
    # 29-39
    '갓김치', '깍두기', '나박김치', '무생채', '배추김치', '백김치', '부추김치', '열무김치',
    '오이소박이', '총각김치', '파김치',
    # 40-45
    '가지볶음', '고사리나물', '미역줄기볶음', '숙주나물', '시금치나물', '애호박볶음',
    # 46-49
    '경단', '꿀떡', '송편', '만두',
    # 50-61
    '라면', '막국수', '물냉면', '비빔냉면', '수제비', '열무국수', '잔치국수', '짜장면',
    '짬뽕', '쫄면', '칼국수', '콩국수',
    # 62-68
    '꽈리고추무침', '도라지무침', '도토리묵', '잡채', '콩나물무침', '홍어무침', '회무침',
    # 69-77
    '김밥', '김치볶음밥', '누룽지', '비빔밥', '새우볶음밥', '알밥', '유부초밥', '잡곡밥', '주먹밥',
    # 78-89
    '감자채볶음', '건새우볶음', '고추장진미채볶음', '두부김치', '떡볶이', '라볶이',
    '멸치볶음', '소세지볶음', '어묵볶음', '오징어채볶음', '제육볶음', '주꾸미볶음',
    # 90-96
    '보쌈', '수정과', '식혜', '간장게장', '양념게장', '깻잎장아찌', '떡꼬치',
    # 97-105
    '감자전', '계란말이', '계란후라이', '김치전', '동그랑땡', '생선전', '파전', '호박전',
    '곱창전골',
    # 106-116
    '갈치조림', '감자조림', '고등어조림', '꽁치조림', '두부조림', '땅콩조림',
    '메추리알장조림', '연근조림', '우엉조림', '장조림', '코다리조림',
    # 117-123
    '전복죽', '호박죽', '김치찌개', '닭계장', '동태찌개', '된장찌개', '순두부찌개',
    # 124-133
    '갈비찜', '계란찜', '김치찜', '꼬막찜', '닭볶음탕', '수육', '순대', '족발', '찜닭', '해물찜',
    # 134-139
    '갈비탕', '감자탕', '곰탕_설렁탕', '매운탕', '삼계탕', '추어탕',
    # 140-149
    '고추튀김', '새우튀김', '오징어튀김', '약과', '약식', '한과', '멍게', '산낙지', '물회', '육회',
]

def _load_effnet(model_name, weights_path, device):
    """Build an ``EffNetModel``, load a saved ``state_dict`` and ready it for inference.

    Args:
        model_name: EfficientNet variant passed to ``EffNetModel``.
        weights_path: path of the ``state_dict`` file produced by ``torch.save``.
        device: device the model and its loaded tensors should live on.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = EffNetModel(model_name)
    # map_location remaps the stored tensors onto `device`; without it a
    # checkpoint saved from a GPU cannot be loaded on a CPU-only machine,
    # which defeats the CPU fallback chosen in predict().
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def predict():
    """Predict a class name for every ``.jpg``/``.png`` image in ``test_data_path``.

    Each image is preprocessed with ``test_transforms`` and expanded into every
    variant produced by ``tta_transforms``. Each variant is scored by an
    ensemble of two EfficientNets (b1 and b0) whose outputs are averaged. Every
    variant votes for its arg-max class; the class with the most votes wins,
    and ties go to the tied class that was voted for by the earliest variant.

    Returns:
        dict[int, str]: ``int(file_name[4:9])`` -> entry of ``num_to_class``.

    Raises:
        ValueError: if an image cannot be decoded by OpenCV, or if
            ``file_name[4:9]`` is not an integer.
    """
    valid_images = [".jpg", ".png"]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    modelA = _load_effnet(
        'efficientnet-b1',
        '../models/1_efficientnet-b1_0.8407_0.7333_epoch_25.pth',
        device,
    )
    modelB = _load_effnet(
        'efficientnet-b0',
        '../models/1_efficientnet-b0_0.8429_0.6997_epoch_21.pth',
        device,
    )

    result = {}
    # Inference only: disable autograd bookkeeping for the whole loop.
    with torch.no_grad():
        for f in os.listdir(test_data_path):
            # Skip files whose extension is not a supported image type.
            if os.path.splitext(f)[1].lower() not in valid_images:
                continue

            # The result key is the fixed character slice [4:9] of the file name.
            # NOTE(review): assumes every image name carries its numeric id at
            # characters 4-8 (e.g. a 4-character prefix + 5 digits) — confirm.
            hash_num = int(os.path.basename(f)[4:9])

            img_path = os.path.join(test_data_path, f)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise ValueError(f"Could not read image: {img_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # [C, H, W] tensor -> [1, C, H, W] batch, as ttach transforms expect.
            batch_img = test_transforms(image=img)['image'].unsqueeze(0)

            # Stack all TTA variants into one [N, C, H, W] batch so each model
            # runs a single forward pass per image instead of one per variant.
            tta_batch = torch.cat(
                [transformer.augment_image(batch_img) for transformer in tta_transforms]
            ).to(device)
            outputs = (modelA(tta_batch) + modelB(tta_batch)) / 2
            preds_list = outputs.argmax(dim=1).tolist()

            # Majority vote. max() over the list itself (not a set) returns the
            # first element with the highest count, so ties are broken by TTA
            # order rather than by arbitrary set/hash ordering.
            prediction = max(preds_list, key=preds_list.count)
            result[hash_num] = num_to_class[prediction]

    return result

In [27]:
def print_accuracy():
    """Evaluate ``predict()`` against the ground-truth table and print the score.

    Ground truth is read from ``hash_table.txt`` inside ``test_data_path``
    (cp949-encoded, one ``"<number>, <label>"`` pair per line). ``predict()`` is
    timed once. The function prints the elapsed seconds and then the fraction of
    table entries whose predicted label matches.

    Returns:
        float | None: the accuracy, or ``None`` if the table is empty or
        ``predict()`` did not return a mapping covering every table entry.
    """
    true_labels = {}
    label_path = os.path.join(test_data_path, 'hash_table.txt')
    with open(label_path, 'rt', encoding='cp949') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) would otherwise fail to unpack.
                continue
            # maxsplit=1 so a label that itself contains ", " stays intact.
            hash_num, label = line.split(', ', 1)
            true_labels[int(hash_num)] = label

    if not true_labels:
        # Checked before the slow predict() call so an empty table is reported
        # clearly instead of surfacing as a ZeroDivisionError.
        print(f"No labels found in {label_path}")
        return None

    start = time.perf_counter()
    predicted_labels = predict()
    print(f"Elapsed: {time.perf_counter() - start:.2f}s")

    try:
        cnt = sum(
            predicted_labels[hash_num] == label
            for hash_num, label in true_labels.items()
        )
    except (KeyError, TypeError) as e:
        # A missing number or a non-mapping return value means predict() did
        # not produce the expected {number: label} dict.
        print("predict()의 반환 양식이 올바르지 않습니다.")
        print(f"  {type(e).__name__}: {e}")
        return None

    accuracy = cnt / len(true_labels)
    print(f"Accuracy: {accuracy}")
    return accuracy

In [28]:
# Run the full evaluation (prints elapsed time and accuracy). The trailing ';'
# keeps Jupyter from echoing the call's return value as an Out[] cell.
print_accuracy();

Loaded pretrained weights for efficientnet-b1
Loaded pretrained weights for efficientnet-b0
34.86095380783081
Accuracy: 0.9098360655737705
