# In [2]:
!pip install timm
import pickle
import unicodedata

import cv2
import timm
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

class VerificationModel:
    """Decide whether an image shows a given place by embedding distance.

    A timm ``efficientnet_b0`` backbone has its classifier replaced by a
    1280 -> 512 -> 128 MLP, so each image maps to a 128-dim embedding.
    A query image "matches" a keyword when its embedding lies within
    ``threshold`` (Euclidean distance) of at least one stored reference
    embedding for that keyword.
    """

    def __init__(self, model_path, pickle_path):
        """Build the network, load its weights and the reference embeddings.

        Args:
            model_path: Path to a ``state_dict`` checkpoint for the model
                built below (backbone + replaced classifier).
            pickle_path: Path to a pickled dict mapping keyword -> reference
                embeddings (presumably an ``(n, 128)`` tensor per keyword —
                TODO confirm against the script that produced it).
        """
        # pretrained=False: the strict load_state_dict below overwrites every
        # parameter and buffer, so fetching ImageNet weights would be wasted.
        self.model = timm.create_model('efficientnet_b0', pretrained=False)
        self.model.classifier = nn.Sequential(
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
        )
        # map_location='cpu' lets a checkpoint saved on GPU load on a
        # CPU-only host; the model itself lives on CPU either way.
        checkpoint = torch.load(model_path, map_location='cpu')
        self.model.load_state_dict(checkpoint)
        self.model.eval()  # inference mode (affects BatchNorm/Dropout)

        with open(pickle_path, 'rb') as handle:
            # NOTE: unpickling can execute arbitrary code; only load a
            # trusted, locally produced file here.
            self.image_test_dict = pickle.load(handle)

        # PIL image -> float tensor, normalized with the ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def data_load(self, input_path):
        """Read an image file and return a ``(1, 3, 224, 224)`` model input.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``input_path``.
        """
        image = cv2.imread(input_path)
        if image is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {input_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        image = cv2.resize(image, (224, 224))
        tensor = self.transform(Image.fromarray(image))
        return tensor.unsqueeze(dim=0)  # add the batch dimension

    def predict(self, input_path, keyword, threshold=0.68):
        """Return 1 if the image matches any reference for ``keyword``, else 0.

        Args:
            input_path: Path of the query image.
            keyword: Key into the reference-embedding dict.
            threshold: A reference matches when the Euclidean distance to the
                query embedding is strictly below this value.

        Returns:
            ``1`` on a match, ``0`` otherwise (also ``0`` when the keyword has
            no reference embeddings).

        Raises:
            KeyError: if ``keyword`` is not in the reference dict.
        """
        data = self.data_load(input_path)
        references = self.image_test_dict[keyword]
        with torch.no_grad():  # pure inference: don't build an autograd graph
            embedding = self.model(data)
            references = torch.as_tensor(references)
            # (n, d) - (1, d) broadcasts, so no need to tile the query n times
            # (tiling via torch.cat also fails when n == 0).
            distances = (references - embedding).pow(2).sum(dim=1).sqrt()
            return int((distances < threshold).any())

# input_path: path of the query image to verify.
# keyword: the place the image should show; must be a key of the
#          reference-embedding dict stored in pickle_path.
# Both must be defined (e.g. in an earlier notebook cell) before this cell
# runs. Check up front so a missing one gives a clear message instead of a
# bare NameError, and before the model is loaded for nothing.
_missing = [name for name in ("input_path", "keyword") if name not in globals()]
if _missing:
    raise NameError(
        "Define " + " and ".join(_missing) + " before running inference."
    )

model_path = './best_model.pth'
pickle_path = "image_test_dict.pickle"

inference_model = VerificationModel(model_path, pickle_path)
result = inference_model.predict(input_path, keyword)
print(result)

# Cell output (message elided in the notebook export): NameError