In [None]:
import os
import torch
import numpy as np
import warnings
from settings import OUT_DIR, IMAGE_HEIGHT, IMAGE_WIDTH
from core.models.nts_net import NTSModel
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.datasets import FGVCAircraft
from PIL import Image

# Silence every Python warning for the rest of the session.
# NOTE(review): this is a blanket filter, so deprecation and numerical
# warnings are hidden too — consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

### Define the test function

In [None]:
def test(model_file, test_data, batch_size):
    """Evaluate a saved NTS-Net checkpoint on a dataset and print its accuracy.

    Args:
        model_file: Path to a checkpoint readable by ``torch.load``. The
            checkpoint must provide the keys ``"proposal_num"``,
            ``"n_classes"`` and ``"state_dict"``.
        test_data: Dataset yielding ``(inputs, labels)`` pairs that a
            ``DataLoader`` can batch.
        batch_size: Number of samples per evaluation batch.

    Returns:
        Tuple ``(y_true, y_pred)`` of 1-D integer NumPy arrays holding the
        ground-truth and predicted class indices, in dataset order.
    """
    # Identify device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model. map_location allows a checkpoint that was saved on a GPU
    # to be restored on a CPU-only machine.
    ckpt = torch.load(model_file, map_location=device)
    model = NTSModel(
        top_n=ckpt["proposal_num"],
        n_classes=ckpt["n_classes"],
        image_height=IMAGE_HEIGHT,
        image_width=IMAGE_WIDTH,
    ).to(device)
    model.load_state_dict(ckpt["state_dict"])
    model = nn.DataParallel(model)
    # Switch to inference mode so that train-time-only behaviour (e.g. dropout,
    # batch-norm statistics updates) is disabled while measuring accuracy.
    model.eval()

    # Setup dataloader. Shuffling has no benefit for evaluation and would make
    # the order of the returned arrays differ from run to run.
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

    y_true = []
    y_pred = []
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            # Labels are only compared on the host, so they stay off the device.
            y_true.extend(labels.tolist())
            _, concat_logits, _, _, _ = model(inputs.to(device))
            y_pred.extend(concat_logits.argmax(dim=1).tolist())

    y_pred = np.asarray(y_pred, dtype=int)
    y_true = np.asarray(y_true, dtype=int)

    accuracy = np.mean(y_pred == y_true)
    print("Accuracy:", accuracy)

    return y_true, y_pred


### Load the test data

In [None]:
# Preprocessing applied to every test image: resize to the resolution the
# model is built for (bilinear resampling), then convert the image to a tensor.
resize_to_model_input = Resize(size=(IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=Image.BILINEAR)
transform = Compose([resize_to_model_input, ToTensor()])

# "test" split of FGVC-Aircraft, stored under ./data (download=True requests
# a download through torchvision).
test_data = FGVCAircraft(
    root="data",
    split="test",
    transform=transform,
    download=True,
)

print("Test data size:", len(test_data))

### Test the model

In [None]:
# Checkpoint to evaluate: "latest_model.ckpt" inside OUT_DIR.
model_file = os.path.join(OUT_DIR, "latest_model.ckpt")

# Number of images per evaluation batch.
BATCH_SIZE = 8

# Run the evaluation; prints the accuracy and returns the label arrays.
y_true, y_pred = test(model_file=model_file, test_data=test_data, batch_size=BATCH_SIZE)

### Create classification report

In [None]:
import pandas as pd
from sklearn.metrics import classification_report
from sklearn.preprocessing import OneHotEncoder

# Per-class precision / recall / F1 / support, computed directly from the
# integer class indices. Passing `labels` explicitly keeps every entry of
# `classes` aligned with its index even if a class never occurs in y_true or
# y_pred, so no one-hot encoding step (which would reject predictions of
# classes unseen in y_true) is needed.
classes = test_data.classes
report = classification_report(
    y_true,
    y_pred,
    labels=list(range(len(classes))),
    target_names=classes,
    output_dict=True,
)

# Keep only the per-class entries. Selecting them by name (instead of slicing
# off a fixed number of trailing rows) does not depend on how many aggregate
# rows (accuracy / averages) the report happens to contain.
df_report = pd.DataFrame({name: report[name] for name in classes}).transpose()

#### Top 10 performers (highest recall)

In [None]:
# Ten classes with the highest recall, best first.
df_report.sort_values(by="recall", ascending=False).iloc[:10]

#### Bottom 10 performers (lowest recall)

In [None]:
# Ten classes with the lowest recall, worst first.
df_report.sort_values(by="recall").iloc[:10]