# 测试集图像分类预测结果

使用训练好的图像分类模型，预测测试集的所有图像，得到预测结果表格。

同济子豪兄：https://space.bilibili.com/1900783

[代码运行云GPU环境](https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1)：GPU RTX 3060、CUDA v11.2

## 导入工具包

In [1]:
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn.functional as F

# 有 GPU 就用 GPU，没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

device cuda:0


## 图像预处理

In [2]:
from torchvision import transforms

# # 训练集图像预处理：缩放裁剪、图像增强、转 Tensor、归一化
# train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
#                                       transforms.RandomHorizontalFlip(),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#                                      ])

# 测试集图像预处理-RCTN：缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

## 载入测试集（和训练代码教程相同）

In [4]:
# 数据集文件夹路径
dataset_dir = 'dataset_split'
test_path = os.path.join(dataset_dir, 'val')
from torchvision import datasets
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)
# 载入类别名称 和 ID索引号 的映射字典
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()
# 获得类别名称
classes = list(idx_to_labels.values())
print(classes)

测试集图像数量 556
类别个数 16
各类别名称 ['丝瓜', '冬瓜', '南瓜', '木瓜', '猕猴桃', '甘蔗', '甜瓜', '白心火龙果', '胡萝卜', '芒果', '苦瓜', '草莓', '菠萝', '西红柿', '青苹果', '香蕉']
['丝瓜', '冬瓜', '南瓜', '木瓜', '猕猴桃', '甘蔗', '甜瓜', '白心火龙果', '胡萝卜', '芒果', '苦瓜', '草莓', '菠萝', '西红柿', '青苹果', '香蕉']


## 导入训练好的模型

In [7]:
model = torch.load('checkpoint/best-0.840.pth')
model = model.eval().to(device)

## 表格A-测试集图像路径及标注

In [8]:
test_dataset.imgs[:10]

[('dataset_split/val/丝瓜/107.jpg', 0),
 ('dataset_split/val/丝瓜/113.png', 0),
 ('dataset_split/val/丝瓜/125.jpg', 0),
 ('dataset_split/val/丝瓜/135.jpg', 0),
 ('dataset_split/val/丝瓜/167.jpg', 0),
 ('dataset_split/val/丝瓜/173.jpg', 0),
 ('dataset_split/val/丝瓜/174.jpeg', 0),
 ('dataset_split/val/丝瓜/175.png', 0),
 ('dataset_split/val/丝瓜/176.jpg', 0),
 ('dataset_split/val/丝瓜/177.jpg', 0)]

In [9]:
img_paths = [each[0] for each in test_dataset.imgs]

In [10]:
df = pd.DataFrame()
df['图像路径'] = img_paths
df['标注类别ID'] = test_dataset.targets
df['标注类别名称'] = [idx_to_labels[ID] for ID in test_dataset.targets]

In [11]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称
0,dataset_split/val/丝瓜/107.jpg,0,丝瓜
1,dataset_split/val/丝瓜/113.png,0,丝瓜
2,dataset_split/val/丝瓜/125.jpg,0,丝瓜
3,dataset_split/val/丝瓜/135.jpg,0,丝瓜
4,dataset_split/val/丝瓜/167.jpg,0,丝瓜
...,...,...,...
551,dataset_split/val/香蕉/86.png,15,香蕉
552,dataset_split/val/香蕉/94.jpg,15,香蕉
553,dataset_split/val/香蕉/96.jpg,15,香蕉
554,dataset_split/val/香蕉/97.jpg,15,香蕉


## 表格B-测试集每张图像的图像分类预测结果，以及各类别置信度

In [12]:
# 记录 top-n 预测结果
n = 3

In [13]:
df_pred = pd.DataFrame()
for idx, row in tqdm(df.iterrows()):
    img_path = row['图像路径']
    img_pil = Image.open(img_path).convert('RGB')
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测，得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

    pred_dict = {}

    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    
    # top-n 预测结果
    for i in range(1, n+1):
        pred_dict['top-{}-预测ID'.format(i)] = pred_ids[i-1]
        pred_dict['top-{}-预测名称'.format(i)] = idx_to_labels[pred_ids[i-1]]
    pred_dict['top-n预测正确'] = row['标注类别ID'] in pred_ids
    # 每个类别的预测置信度
    for idx, each in enumerate(classes):
        pred_dict['{}-预测置信度'.format(each)] = pred_softmax[0][idx].cpu().detach().numpy()
        
    df_pred = df_pred.append(pred_dict, ignore_index=True)

556it [00:12, 43.59it/s]


In [14]:
df_pred

Unnamed: 0,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,丝瓜-预测置信度,冬瓜-预测置信度,南瓜-预测置信度,...,甜瓜-预测置信度,白心火龙果-预测置信度,胡萝卜-预测置信度,芒果-预测置信度,苦瓜-预测置信度,草莓-预测置信度,菠萝-预测置信度,西红柿-预测置信度,青苹果-预测置信度,香蕉-预测置信度
0,10.0,苦瓜,3.0,木瓜,6.0,甜瓜,0.0,0.04840641,0.07282121,0.0021479204,...,0.11867881,0.0017725051,0.0033943707,0.013316912,0.52182674,0.00014369469,0.021278393,0.010869962,0.013985876,0.002235779
1,10.0,苦瓜,0.0,丝瓜,1.0,冬瓜,1.0,0.37343544,0.053470347,0.0013646453,...,0.026696296,0.002228043,0.0048406865,0.006601241,0.48405343,0.0019523422,0.0015444165,0.0063533513,0.0013095672,0.01580053
2,0.0,丝瓜,2.0,南瓜,10.0,苦瓜,1.0,0.40001786,0.06310639,0.12782323,...,0.059697114,0.00594394,0.0047240024,0.025729949,0.1101512,0.008134677,0.044341482,0.008649456,0.0021602323,0.014768641
3,10.0,苦瓜,3.0,木瓜,4.0,猕猴桃,0.0,0.085274145,0.03315321,0.079508156,...,0.03682621,0.0071949377,0.017246274,0.03286594,0.26644373,0.040147908,0.021891003,0.08000263,0.006402887,0.0044703064
4,0.0,丝瓜,1.0,冬瓜,3.0,木瓜,1.0,0.7372114,0.1965079,0.0016882734,...,0.0054939296,0.0006053373,0.0010448286,0.00093456806,0.017406022,0.0026436672,0.00015025907,0.00047128706,9.9715675e-05,0.007284635
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
551,15.0,香蕉,0.0,丝瓜,12.0,菠萝,1.0,0.2617197,0.013565229,0.022342164,...,0.018018581,0.044590365,0.0022651025,0.0033425468,0.051030938,0.0024886683,0.17084219,0.0007328393,0.0019967197,0.31624484
552,12.0,菠萝,6.0,甜瓜,4.0,猕猴桃,0.0,0.018425575,0.023954533,0.06655343,...,0.15769675,0.02511323,0.036050174,0.048695706,0.057719395,0.0068120975,0.2821566,0.09963151,0.0070520076,0.024425346
553,15.0,香蕉,0.0,丝瓜,8.0,胡萝卜,1.0,0.0005506438,0.00016553093,0.0002731422,...,0.00035602247,2.1813368e-05,0.00047610415,0.00030610847,1.828313e-05,5.250358e-06,2.7238044e-05,2.3757432e-06,1.7083017e-05,0.9974482
554,15.0,香蕉,0.0,丝瓜,7.0,白心火龙果,1.0,0.03978421,0.0047999634,0.00058869517,...,0.013944677,0.022109129,0.0011740384,0.0059756073,0.0022295625,0.0015292587,0.008131748,8.622511e-05,0.0014160236,0.89128304


## 拼接AB两张表格

In [15]:
df = pd.concat([df, df_pred], axis=1)

In [16]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,...,甜瓜-预测置信度,白心火龙果-预测置信度,胡萝卜-预测置信度,芒果-预测置信度,苦瓜-预测置信度,草莓-预测置信度,菠萝-预测置信度,西红柿-预测置信度,青苹果-预测置信度,香蕉-预测置信度
0,dataset_split/val/丝瓜/107.jpg,0,丝瓜,10.0,苦瓜,3.0,木瓜,6.0,甜瓜,0.0,...,0.11867881,0.0017725051,0.0033943707,0.013316912,0.52182674,0.00014369469,0.021278393,0.010869962,0.013985876,0.002235779
1,dataset_split/val/丝瓜/113.png,0,丝瓜,10.0,苦瓜,0.0,丝瓜,1.0,冬瓜,1.0,...,0.026696296,0.002228043,0.0048406865,0.006601241,0.48405343,0.0019523422,0.0015444165,0.0063533513,0.0013095672,0.01580053
2,dataset_split/val/丝瓜/125.jpg,0,丝瓜,0.0,丝瓜,2.0,南瓜,10.0,苦瓜,1.0,...,0.059697114,0.00594394,0.0047240024,0.025729949,0.1101512,0.008134677,0.044341482,0.008649456,0.0021602323,0.014768641
3,dataset_split/val/丝瓜/135.jpg,0,丝瓜,10.0,苦瓜,3.0,木瓜,4.0,猕猴桃,0.0,...,0.03682621,0.0071949377,0.017246274,0.03286594,0.26644373,0.040147908,0.021891003,0.08000263,0.006402887,0.0044703064
4,dataset_split/val/丝瓜/167.jpg,0,丝瓜,0.0,丝瓜,1.0,冬瓜,3.0,木瓜,1.0,...,0.0054939296,0.0006053373,0.0010448286,0.00093456806,0.017406022,0.0026436672,0.00015025907,0.00047128706,9.9715675e-05,0.007284635
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
551,dataset_split/val/香蕉/86.png,15,香蕉,15.0,香蕉,0.0,丝瓜,12.0,菠萝,1.0,...,0.018018581,0.044590365,0.0022651025,0.0033425468,0.051030938,0.0024886683,0.17084219,0.0007328393,0.0019967197,0.31624484
552,dataset_split/val/香蕉/94.jpg,15,香蕉,12.0,菠萝,6.0,甜瓜,4.0,猕猴桃,0.0,...,0.15769675,0.02511323,0.036050174,0.048695706,0.057719395,0.0068120975,0.2821566,0.09963151,0.0070520076,0.024425346
553,dataset_split/val/香蕉/96.jpg,15,香蕉,15.0,香蕉,0.0,丝瓜,8.0,胡萝卜,1.0,...,0.00035602247,2.1813368e-05,0.00047610415,0.00030610847,1.828313e-05,5.250358e-06,2.7238044e-05,2.3757432e-06,1.7083017e-05,0.9974482
554,dataset_split/val/香蕉/97.jpg,15,香蕉,15.0,香蕉,0.0,丝瓜,7.0,白心火龙果,1.0,...,0.013944677,0.022109129,0.0011740384,0.0059756073,0.0022295625,0.0015292587,0.008131748,8.622511e-05,0.0014160236,0.89128304


## 导出完整表格

In [17]:
df.to_csv('测试集预测结果.csv', index=False)