# 测试集图像分类结果

使用训练好的图像分类模型，预测测试集的所有图像，得到预测结果表格。

同济子豪兄：https://space.bilibili.com/1900783

[代码运行云GPU环境](https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1)：GPU RTX 3060、CUDA v11.2

## 导入工具包

In [1]:
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn.functional as F

# 有 GPU 就用 GPU，没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

device cuda:0


## 图像预处理

In [2]:
from torchvision import transforms

# 训练集图像预处理：缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# 测试集图像预处理-RCTN：缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

## 载入测试集（和训练代码教程相同）

In [3]:
# 数据集文件夹路径
dataset_dir = 'melon17_split'
test_path = os.path.join(dataset_dir, 'val')
from torchvision import datasets
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('训练集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)
# 载入类别名称 和 ID索引号 的映射字典
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()
# 获得类别名称
classes = list(idx_to_labels.values())
print(classes)

训练集图像数量 509
类别个数 17
各类别名称 ['丝瓜', '人参果', '佛手瓜', '冬瓜', '南瓜', '哈密瓜', '木瓜', '甜瓜-伊丽莎白', '甜瓜-白', '甜瓜-绿', '甜瓜-金', '白兰瓜', '羊角蜜', '苦瓜', '西瓜', '西葫芦', '黄瓜']
['丝瓜', '人参果', '佛手瓜', '冬瓜', '南瓜', '哈密瓜', '木瓜', '甜瓜-伊丽莎白', '甜瓜-白', '甜瓜-绿', '甜瓜-金', '白兰瓜', '羊角蜜', '苦瓜', '西瓜', '西葫芦', '黄瓜']


## 导入训练好的模型

In [4]:
# !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/checkpoints/melon17_pytorch_20220812.pth -O checkpoints/melon17_pytorch_20220812.pth


In [5]:
model = torch.load('checkpoints/melon17_pytorch_20220813.pth')
model = model.eval().to(device)

# 测试集预测结果表格

## 表格-测试集图像路径及标注

In [76]:
test_dataset.imgs[:10]

[('melon17_split/val/丝瓜/109.jpg', 0),
 ('melon17_split/val/丝瓜/111.jpg', 0),
 ('melon17_split/val/丝瓜/113.jpg', 0),
 ('melon17_split/val/丝瓜/115.jpg', 0),
 ('melon17_split/val/丝瓜/120.jpg', 0),
 ('melon17_split/val/丝瓜/135.jpg', 0),
 ('melon17_split/val/丝瓜/141.jpg', 0),
 ('melon17_split/val/丝瓜/143.jpg', 0),
 ('melon17_split/val/丝瓜/150.jpg', 0),
 ('melon17_split/val/丝瓜/160.jpg', 0)]

In [77]:
img_paths = [each[0] for each in test_dataset.imgs]

In [78]:
df = pd.DataFrame()
df['图像路径'] = img_paths
df['标注类别ID'] = test_dataset.targets
df['标注类别名称'] = [idx_to_labels[ID] for ID in test_dataset.targets]

In [79]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称
0,melon17_split/val/丝瓜/109.jpg,0,丝瓜
1,melon17_split/val/丝瓜/111.jpg,0,丝瓜
2,melon17_split/val/丝瓜/113.jpg,0,丝瓜
3,melon17_split/val/丝瓜/115.jpg,0,丝瓜
4,melon17_split/val/丝瓜/120.jpg,0,丝瓜
...,...,...,...
504,melon17_split/val/黄瓜/85.jpeg,16,黄瓜
505,melon17_split/val/黄瓜/91.png,16,黄瓜
506,melon17_split/val/黄瓜/92.jpg,16,黄瓜
507,melon17_split/val/黄瓜/96.png,16,黄瓜


## 表格-测试集每张图像的图像分类预测结果，以及各类别置信度

In [80]:
# 记录 top-n 预测结果
n = 3

In [81]:
df_pred = pd.DataFrame()
for idx, row in tqdm(df.iterrows()):
    img_path = row['图像路径']
    img_pil = Image.open(img_path).convert('RGB')
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测，得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

    pred_dict = {}

    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    
    # top-n 预测结果
    for i in range(1, n+1):
        pred_dict['top-{}-预测ID'.format(i)] = pred_ids[i-1]
        pred_dict['top-{}-预测名称'.format(i)] = idx_to_labels[pred_ids[i-1]]
    pred_dict['top-n预测正确'] = row['标注类别ID'] in pred_ids
    # 每个类别的预测置信度
    for idx, each in enumerate(classes):
        pred_dict['{}-预测置信度'.format(each)] = pred_softmax[0][idx].cpu().detach().numpy()
        
    df_pred = df_pred.append(pred_dict, ignore_index=True)

509it [00:10, 47.96it/s]


In [82]:
df_pred

Unnamed: 0,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,丝瓜-预测置信度,人参果-预测置信度,佛手瓜-预测置信度,...,甜瓜-伊丽莎白-预测置信度,甜瓜-白-预测置信度,甜瓜-绿-预测置信度,甜瓜-金-预测置信度,白兰瓜-预测置信度,羊角蜜-预测置信度,苦瓜-预测置信度,西瓜-预测置信度,西葫芦-预测置信度,黄瓜-预测置信度
0,0.0,丝瓜,2.0,佛手瓜,3.0,冬瓜,1.0,0.6025479,0.0051945965,0.18266502,...,0.0009523736,0.030350773,0.002757894,0.0011861285,0.0069286656,0.067796834,0.00073874864,0.00035390825,0.00038738217,0.00041518346
1,0.0,丝瓜,13.0,苦瓜,12.0,羊角蜜,1.0,0.45952642,0.00083556655,0.0043829586,...,5.072762e-06,3.538376e-05,3.8564405e-05,1.7009173e-05,1.929875e-05,0.053255055,0.39012507,1.580059e-05,0.022745842,0.04728109
2,0.0,丝瓜,3.0,冬瓜,5.0,哈密瓜,1.0,0.7159912,0.00085558224,0.003953781,...,0.0038605933,0.02665418,0.00048441,0.001215179,0.018728737,0.044135172,0.0038077678,0.0024780761,0.02845215,0.026642872
3,3.0,冬瓜,0.0,丝瓜,15.0,西葫芦,1.0,0.15648821,8.364851e-05,0.001339487,...,4.3033728e-05,0.00044631434,0.001275594,1.0221728e-05,0.0009085963,0.0018855697,0.01761911,0.000474287,0.048291694,0.018402666
4,0.0,丝瓜,15.0,西葫芦,13.0,苦瓜,1.0,0.7972671,5.001999e-05,0.00033321528,...,8.444517e-05,9.845294e-06,2.3616867e-05,0.00034483196,5.4924898e-05,0.0048265187,0.05549101,9.747511e-05,0.08496752,0.029771086
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
504,0.0,丝瓜,16.0,黄瓜,13.0,苦瓜,1.0,0.39302263,0.003081525,0.020131085,...,0.0002438154,0.0004994096,0.0010968797,0.0008781446,0.0014920258,0.050604977,0.15442532,0.02937849,0.0061266352,0.29520804
505,16.0,黄瓜,13.0,苦瓜,0.0,丝瓜,1.0,0.015567795,0.0008430542,0.002900118,...,1.512831e-05,0.00034122684,2.605401e-05,4.0408617e-05,0.00021403635,0.0030988124,0.10705263,7.25777e-06,0.00013544042,0.8656235
506,15.0,西葫芦,16.0,黄瓜,0.0,丝瓜,1.0,0.09589289,0.000109568646,0.0006377687,...,9.7345865e-06,5.2412793e-06,6.0724307e-07,8.978602e-06,1.217463e-05,0.0024328867,0.00076153403,9.7708e-05,0.5582861,0.28152248
507,16.0,黄瓜,13.0,苦瓜,0.0,丝瓜,1.0,0.023885977,0.00015046919,2.098638e-05,...,1.9152138e-05,2.4887358e-05,2.2076024e-06,3.9186667e-05,8.076126e-05,0.0027375026,0.042325806,7.93395e-05,0.0002736452,0.92884755


## 拼接两张表格

In [83]:
df = pd.concat([df, df_pred], axis=1)

In [84]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,...,甜瓜-伊丽莎白-预测置信度,甜瓜-白-预测置信度,甜瓜-绿-预测置信度,甜瓜-金-预测置信度,白兰瓜-预测置信度,羊角蜜-预测置信度,苦瓜-预测置信度,西瓜-预测置信度,西葫芦-预测置信度,黄瓜-预测置信度
0,melon17_split/val/丝瓜/109.jpg,0,丝瓜,0.0,丝瓜,2.0,佛手瓜,3.0,冬瓜,1.0,...,0.0009523736,0.030350773,0.002757894,0.0011861285,0.0069286656,0.067796834,0.00073874864,0.00035390825,0.00038738217,0.00041518346
1,melon17_split/val/丝瓜/111.jpg,0,丝瓜,0.0,丝瓜,13.0,苦瓜,12.0,羊角蜜,1.0,...,5.072762e-06,3.538376e-05,3.8564405e-05,1.7009173e-05,1.929875e-05,0.053255055,0.39012507,1.580059e-05,0.022745842,0.04728109
2,melon17_split/val/丝瓜/113.jpg,0,丝瓜,0.0,丝瓜,3.0,冬瓜,5.0,哈密瓜,1.0,...,0.0038605933,0.02665418,0.00048441,0.001215179,0.018728737,0.044135172,0.0038077678,0.0024780761,0.02845215,0.026642872
3,melon17_split/val/丝瓜/115.jpg,0,丝瓜,3.0,冬瓜,0.0,丝瓜,15.0,西葫芦,1.0,...,4.3033728e-05,0.00044631434,0.001275594,1.0221728e-05,0.0009085963,0.0018855697,0.01761911,0.000474287,0.048291694,0.018402666
4,melon17_split/val/丝瓜/120.jpg,0,丝瓜,0.0,丝瓜,15.0,西葫芦,13.0,苦瓜,1.0,...,8.444517e-05,9.845294e-06,2.3616867e-05,0.00034483196,5.4924898e-05,0.0048265187,0.05549101,9.747511e-05,0.08496752,0.029771086
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
504,melon17_split/val/黄瓜/85.jpeg,16,黄瓜,0.0,丝瓜,16.0,黄瓜,13.0,苦瓜,1.0,...,0.0002438154,0.0004994096,0.0010968797,0.0008781446,0.0014920258,0.050604977,0.15442532,0.02937849,0.0061266352,0.29520804
505,melon17_split/val/黄瓜/91.png,16,黄瓜,16.0,黄瓜,13.0,苦瓜,0.0,丝瓜,1.0,...,1.512831e-05,0.00034122684,2.605401e-05,4.0408617e-05,0.00021403635,0.0030988124,0.10705263,7.25777e-06,0.00013544042,0.8656235
506,melon17_split/val/黄瓜/92.jpg,16,黄瓜,15.0,西葫芦,16.0,黄瓜,0.0,丝瓜,1.0,...,9.7345865e-06,5.2412793e-06,6.0724307e-07,8.978602e-06,1.217463e-05,0.0024328867,0.00076153403,9.7708e-05,0.5582861,0.28152248
507,melon17_split/val/黄瓜/96.png,16,黄瓜,16.0,黄瓜,13.0,苦瓜,0.0,丝瓜,1.0,...,1.9152138e-05,2.4887358e-05,2.2076024e-06,3.9186667e-05,8.076126e-05,0.0027375026,0.042325806,7.93395e-05,0.0002736452,0.92884755


## 导出完整表格

In [16]:
df.to_csv('测试集预测结果.csv', index=False)