In [24]:
import torch
import torch.nn as nn
import sys
from config import input_size, proposalN, channels
from utils.read_dataset import read_dataset
from utils.auto_laod_resume import auto_load_resume
from networks.model import MainNet
from torchvision import transforms

import os

# Limit the CUDA devices visible to this process to device 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [25]:
# Device on which the model is placed; this notebook runs on the CPU.
DEVICE = torch.device('cpu')


In [26]:
# --- Evaluation settings -------------------------------------------------
# Paths
root = "./datasets/FGVC-aircraft"       # dataset root handed to read_dataset
pth_path = "./models/air_epoch146.pth"  # checkpoint file to restore

# Sizes
num_classes = 100  # number of classes the network is constructed with
batch_size = 10    # batch size handed to read_dataset

In [27]:
# read_dataset returns a pair; only its second element is kept as the test
# loader (presumably the first is the train loader, given the log lines the
# call prints -- confirm in utils.read_dataset).
_, testloader = read_dataset(
    input_size,
    batch_size,
    root,
    "Aircraft",
)

Loading Aircraft trainset
Loading Aircraft testset


In [28]:
# Build the network from the imported config values and the class count
# defined in this notebook.
model = MainNet(
    proposalN=proposalN,
    num_classes=num_classes,
    channels=channels,
)

In [29]:
# Loss object; it is not referenced by the inference cells visible in this
# part of the notebook.
criterion = nn.CrossEntropyLoss()

# Move the network onto the selected device.
model = model.to(DEVICE)

In [30]:
# Restore the trained weights before running inference.
# A missing checkpoint is reported with a descriptive exception that names the
# path, rather than sys.exit(), which is meant for terminating a script and
# gave a message that did not say which file was expected.
if not os.path.exists(pth_path):
    raise FileNotFoundError(f"Checkpoint file not found: {pth_path}")

# The return value is kept as `epoch` (presumably the epoch stored in the
# checkpoint -- confirm in utils.auto_laod_resume).
epoch = auto_load_resume(model, pth_path, status='test')

Load model from ./models/air_epoch146.pth
Resume from ./models/air_epoch146.pth


In [31]:
# Map each label index to the file stem (no directory, no extension) of a test
# image carrying that index. When several images share an index, the last one
# seen wins.
images = {
    idx: path.split("/")[-1].split(".")[0]
    for path, idx in testloader.dataset.test_img_label
}

# Read the label file: the first whitespace-separated token of every line is
# the key, the remaining tokens (re-joined with single spaces) are the label.
with open("image_label.txt", "r") as f:
    labels = {}
    for line in f:
        parts = line.split()
        if not parts:
            # Blank or whitespace-only line (e.g. a trailing newline at the end
            # of the file): nothing to parse, and parts[0] would raise.
            continue
        labels[parts[0]] = ' '.join(parts[1:])

# Label index -> human-readable label, for every index whose image stem
# appears in the label file.
real = {idx: labels[stem] for idx, stem in images.items() if stem in labels}

# Show the resulting mapping and one sample stem.
print(real)
print(images[19])

{0: '707-320', 1: '727-200', 2: '737-200', 3: '737-300', 4: '737-400', 5: '737-500', 6: '737-600', 7: '737-700', 8: '737-800', 9: '737-900', 10: '747-100', 11: '747-200', 12: '747-300', 13: '747-400', 14: '757-200', 15: '757-300', 16: '767-200', 17: '767-300', 18: '767-400', 19: '777-200', 20: '777-300', 21: 'A300B4', 22: 'A310', 23: 'A318', 24: 'A319', 25: 'A320', 26: 'A321', 27: 'A330-200', 28: 'A330-300', 29: 'A340-200', 30: 'A340-300', 31: 'A340-500', 32: 'A340-600', 33: 'A380', 34: 'ATR-42', 35: 'ATR-72', 36: 'An-12', 37: 'BAE 146-200', 38: 'BAE 146-300', 39: 'BAE-125', 40: 'Beechcraft 1900', 41: 'Boeing 717', 42: 'C-130', 43: 'C-47', 44: 'CRJ-200', 45: 'CRJ-700', 46: 'CRJ-900', 47: 'Cessna 172', 48: 'Cessna 208', 49: 'Cessna 525', 50: 'Cessna 560', 51: 'Challenger 600', 52: 'DC-10', 53: 'DC-3', 54: 'DC-6', 55: 'DC-8', 56: 'DC-9-30', 57: 'DH-82', 58: 'DHC-1', 59: 'DHC-6', 60: 'DHC-8-100', 61: 'DHC-8-300', 62: 'DR-400', 63: 'Dornier 328', 64: 'E-170', 65: 'E-190', 66: 'E-195', 67: 

In [32]:
import imageio
import numpy as np
from PIL import Image

def preprocess_image(img, size=(448, 448)):
    """Turn a decoded image array into a normalized single-image batch.

    Args:
        img: numpy array holding the decoded image, either 2-D (grayscale) or
            3-D with the channel axis last (3 channels, or 4 when an alpha
            plane is present). It is handed to PIL as 8-bit 'RGB' data.
        size: (height, width) the image is resized to. Defaults to
            (448, 448), the value that was previously hard-coded.

    Returns:
        A float tensor with a leading batch axis of length 1, produced by
        ToTensor followed by per-channel normalization.
    """
    if img.ndim == 2:
        # Grayscale: replicate the single plane into three channels.
        img = np.stack([img] * 3, 2)
    elif img.ndim == 3 and img.shape[2] == 4:
        # RGBA (common for PNG files): drop the alpha plane. Passing
        # 4-channel data to fromarray(..., mode='RGB') would misread the
        # pixel buffer.
        img = np.ascontiguousarray(img[:, :, :3])
    img = Image.fromarray(img, mode='RGB')
    img = transforms.Resize(tuple(size), Image.BILINEAR)(img)
    img = transforms.ToTensor()(img)
    # Per-channel mean / std used for standardization (the widely used
    # ImageNet statistics).
    img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
    # Add the batch dimension expected by the model call.
    img = img.unsqueeze(0)
    return img

# Decode the query image and convert it to a normalized 1-image batch.
img = preprocess_image(imageio.imread("image.png"))

# Switch the network to evaluation mode before predicting.
model.eval()

# No gradients are needed for a forward-only pass.
with torch.no_grad():
    # Keep only the last two values returned by the model; the first of them
    # is used for the prediction (presumably the class scores -- confirm in
    # networks.model).
    output = model(img, None, None, "test", "cpu")[-2:]
    # Index of the largest score along dim 1, kept with shape (batch, 1).
    pred = output[0].max(1, keepdim=True)[1]
    class_idx = pred.numpy()[0][0]
    print(class_idx, real[class_idx])

33 A380
