# 测试集图像分类预测结果

## 导入工具包

In [1]:
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn.functional as F

# 有 GPU 就用 GPU，没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

device cuda:0


## 图像预处理

In [2]:
from torchvision import transforms

# # 训练集图像预处理：缩放裁剪、图像增强、转 Tensor、归一化
# train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
#                                       transforms.RandomHorizontalFlip(),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#                                      ])

# 测试集图像预处理-RCTN：缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

## 载入测试集（和训练代码教程相同）

In [3]:
# 数据集文件夹路径
dataset_dir = 'dataset_split'
test_path = os.path.join(dataset_dir, 'val')
from torchvision import datasets
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)
# 载入类别名称 和 ID索引号 的映射字典
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()
# 获得类别名称
classes = list(idx_to_labels.values())
print(classes)

测试集图像数量 576
类别个数 7
各类别名称 ['乌龟', '仓鼠', '兔子', '狗', '猫', '金鱼', '鹦鹉']
['乌龟', '仓鼠', '兔子', '狗', '猫', '金鱼', '鹦鹉']


## 导入训练好的模型

In [5]:
model = torch.load('checkpoint/best-0.972.pth')
model = model.eval().to(device)

## 表格A-测试集图像路径及标注

In [6]:
test_dataset.imgs[:10]

[('dataset_split/val/乌龟/0.jpg', 0),
 ('dataset_split/val/乌龟/106.jpg', 0),
 ('dataset_split/val/乌龟/115.jpg', 0),
 ('dataset_split/val/乌龟/123.jpg', 0),
 ('dataset_split/val/乌龟/164.jpg', 0),
 ('dataset_split/val/乌龟/165.jpg', 0),
 ('dataset_split/val/乌龟/166.jpg', 0),
 ('dataset_split/val/乌龟/177.jpg', 0),
 ('dataset_split/val/乌龟/183.jpg', 0),
 ('dataset_split/val/乌龟/186.jpg', 0)]

In [7]:
img_paths = [each[0] for each in test_dataset.imgs]

In [8]:
df = pd.DataFrame()
df['图像路径'] = img_paths
df['标注类别ID'] = test_dataset.targets
df['标注类别名称'] = [idx_to_labels[ID] for ID in test_dataset.targets]

In [9]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称
0,dataset_split/val/乌龟/0.jpg,0,乌龟
1,dataset_split/val/乌龟/106.jpg,0,乌龟
2,dataset_split/val/乌龟/115.jpg,0,乌龟
3,dataset_split/val/乌龟/123.jpg,0,乌龟
4,dataset_split/val/乌龟/164.jpg,0,乌龟
...,...,...,...
571,dataset_split/val/鹦鹉/77.jpg,6,鹦鹉
572,dataset_split/val/鹦鹉/85.jpeg,6,鹦鹉
573,dataset_split/val/鹦鹉/88.jpeg,6,鹦鹉
574,dataset_split/val/鹦鹉/90.jpg,6,鹦鹉


## 表格B-测试集每张图像的图像分类预测结果，以及各类别置信度

In [10]:
# 记录 top-n 预测结果
n = 3

In [11]:
df_pred = pd.DataFrame()
for idx, row in tqdm(df.iterrows()):
    img_path = row['图像路径']
    img_pil = Image.open(img_path).convert('RGB')
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测，得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

    pred_dict = {}

    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    
    # top-n 预测结果
    for i in range(1, n+1):
        pred_dict['top-{}-预测ID'.format(i)] = pred_ids[i-1]
        pred_dict['top-{}-预测名称'.format(i)] = idx_to_labels[pred_ids[i-1]]
    pred_dict['top-n预测正确'] = row['标注类别ID'] in pred_ids
    # 每个类别的预测置信度
    for idx, each in enumerate(classes):
        pred_dict['{}-预测置信度'.format(each)] = pred_softmax[0][idx].cpu().detach().numpy()
        
    df_pred = df_pred.append(pred_dict, ignore_index=True)

576it [00:20, 28.21it/s]


In [13]:
df_pred

Unnamed: 0,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,乌龟-预测置信度,仓鼠-预测置信度,兔子-预测置信度,狗-预测置信度,猫-预测置信度,金鱼-预测置信度,鹦鹉-预测置信度
0,0.0,乌龟,5.0,金鱼,3.0,狗,1.0,0.9997904,1.3496485e-05,5.6886364e-07,1.8777657e-05,1.4501727e-05,0.00015052377,1.1673086e-05
1,0.0,乌龟,6.0,鹦鹉,5.0,金鱼,1.0,0.9705659,0.0003541137,0.00090848963,0.00064750965,0.0015962358,0.0054022856,0.020525517
2,0.0,乌龟,5.0,金鱼,1.0,仓鼠,1.0,0.9631217,0.0047347657,0.000522849,0.0021564714,0.0009090698,0.027109139,0.0014460153
3,0.0,乌龟,5.0,金鱼,6.0,鹦鹉,1.0,0.9915172,0.0008954164,0.00026855883,0.00027924875,8.785519e-05,0.003941732,0.0030100478
4,0.0,乌龟,5.0,金鱼,3.0,狗,1.0,0.9995976,9.74761e-06,6.1473984e-05,8.786799e-05,2.2297016e-05,0.00019073513,3.0321597e-05
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
571,6.0,鹦鹉,1.0,仓鼠,2.0,兔子,1.0,0.00021789361,0.1331616,0.07737693,0.0019512477,0.0072641987,0.002794568,0.7772335
572,6.0,鹦鹉,1.0,仓鼠,5.0,金鱼,1.0,0.00029225982,0.00074982306,0.00041521664,3.5584948e-05,8.747625e-05,0.0006452718,0.9977743
573,6.0,鹦鹉,1.0,仓鼠,2.0,兔子,1.0,0.00056746375,0.083318576,0.020077255,0.0020920448,0.013484335,0.011416547,0.8690438
574,6.0,鹦鹉,1.0,仓鼠,2.0,兔子,1.0,0.00022838438,0.0077728555,0.005680962,0.0010858163,0.0006157943,0.00090340787,0.9837128


## 拼接AB两张表格

In [14]:
df = pd.concat([df, df_pred], axis=1)

In [15]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,乌龟-预测置信度,仓鼠-预测置信度,兔子-预测置信度,狗-预测置信度,猫-预测置信度,金鱼-预测置信度,鹦鹉-预测置信度
0,dataset_split/val/乌龟/0.jpg,0,乌龟,0.0,乌龟,5.0,金鱼,3.0,狗,1.0,0.9997904,1.3496485e-05,5.6886364e-07,1.8777657e-05,1.4501727e-05,0.00015052377,1.1673086e-05
1,dataset_split/val/乌龟/106.jpg,0,乌龟,0.0,乌龟,6.0,鹦鹉,5.0,金鱼,1.0,0.9705659,0.0003541137,0.00090848963,0.00064750965,0.0015962358,0.0054022856,0.020525517
2,dataset_split/val/乌龟/115.jpg,0,乌龟,0.0,乌龟,5.0,金鱼,1.0,仓鼠,1.0,0.9631217,0.0047347657,0.000522849,0.0021564714,0.0009090698,0.027109139,0.0014460153
3,dataset_split/val/乌龟/123.jpg,0,乌龟,0.0,乌龟,5.0,金鱼,6.0,鹦鹉,1.0,0.9915172,0.0008954164,0.00026855883,0.00027924875,8.785519e-05,0.003941732,0.0030100478
4,dataset_split/val/乌龟/164.jpg,0,乌龟,0.0,乌龟,5.0,金鱼,3.0,狗,1.0,0.9995976,9.74761e-06,6.1473984e-05,8.786799e-05,2.2297016e-05,0.00019073513,3.0321597e-05
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
571,dataset_split/val/鹦鹉/77.jpg,6,鹦鹉,6.0,鹦鹉,1.0,仓鼠,2.0,兔子,1.0,0.00021789361,0.1331616,0.07737693,0.0019512477,0.0072641987,0.002794568,0.7772335
572,dataset_split/val/鹦鹉/85.jpeg,6,鹦鹉,6.0,鹦鹉,1.0,仓鼠,5.0,金鱼,1.0,0.00029225982,0.00074982306,0.00041521664,3.5584948e-05,8.747625e-05,0.0006452718,0.9977743
573,dataset_split/val/鹦鹉/88.jpeg,6,鹦鹉,6.0,鹦鹉,1.0,仓鼠,2.0,兔子,1.0,0.00056746375,0.083318576,0.020077255,0.0020920448,0.013484335,0.011416547,0.8690438
574,dataset_split/val/鹦鹉/90.jpg,6,鹦鹉,6.0,鹦鹉,1.0,仓鼠,2.0,兔子,1.0,0.00022838438,0.0077728555,0.005680962,0.0010858163,0.0006157943,0.00090340787,0.9837128


## 导出完整表格

In [16]:
df.to_csv('测试集预测结果.csv', index=False)