# 预测测试集所有图像

在测试集上，使用训练好的图像分类模型预测，得到预测结果表格。

同济子豪兄：https://space.bilibili.com/1900783

[代码运行云GPU环境](https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1)：GPU RTX 3060、CUDA v11.2

## 导入工具包

In [1]:
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn.functional as F

In [2]:
# 有 GPU 就用 GPU，没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

device cuda:0


## 图像预处理

In [3]:
from torchvision import transforms

# 训练集图像预处理：缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# 测试集图像预处理-RCTN：缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

## 载入测试集（和训练代码教程相同）

In [4]:
# 数据集文件夹路径
dataset_dir = 'melon17_split'

In [5]:
test_path = os.path.join(dataset_dir, 'val')

In [6]:
from torchvision import datasets

# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)

In [7]:
print('训练集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

训练集图像数量 509
类别个数 17
各类别名称 ['丝瓜', '人参果', '佛手瓜', '冬瓜', '南瓜', '哈密瓜', '木瓜', '甜瓜-伊丽莎白', '甜瓜-白', '甜瓜-绿', '甜瓜-金', '白兰瓜', '羊角蜜', '苦瓜', '西瓜', '西葫芦', '黄瓜']


In [8]:
# 各类别名称
class_names = test_dataset.classes

In [9]:
class_names

['丝瓜',
 '人参果',
 '佛手瓜',
 '冬瓜',
 '南瓜',
 '哈密瓜',
 '木瓜',
 '甜瓜-伊丽莎白',
 '甜瓜-白',
 '甜瓜-绿',
 '甜瓜-金',
 '白兰瓜',
 '羊角蜜',
 '苦瓜',
 '西瓜',
 '西葫芦',
 '黄瓜']

In [10]:
# 映射关系：类别 到 索引号
test_dataset.class_to_idx

{'丝瓜': 0,
 '人参果': 1,
 '佛手瓜': 2,
 '冬瓜': 3,
 '南瓜': 4,
 '哈密瓜': 5,
 '木瓜': 6,
 '甜瓜-伊丽莎白': 7,
 '甜瓜-白': 8,
 '甜瓜-绿': 9,
 '甜瓜-金': 10,
 '白兰瓜': 11,
 '羊角蜜': 12,
 '苦瓜': 13,
 '西瓜': 14,
 '西葫芦': 15,
 '黄瓜': 16}

## 载入类别名称和ID

In [11]:
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()

In [12]:
idx_to_labels

{0: '丝瓜',
 1: '人参果',
 2: '佛手瓜',
 3: '冬瓜',
 4: '南瓜',
 5: '哈密瓜',
 6: '木瓜',
 7: '甜瓜-伊丽莎白',
 8: '甜瓜-白',
 9: '甜瓜-绿',
 10: '甜瓜-金',
 11: '白兰瓜',
 12: '羊角蜜',
 13: '苦瓜',
 14: '西瓜',
 15: '西葫芦',
 16: '黄瓜'}

In [13]:
# 获得类别名称
classes = list(idx_to_labels.values())
print(classes)

['丝瓜', '人参果', '佛手瓜', '冬瓜', '南瓜', '哈密瓜', '木瓜', '甜瓜-伊丽莎白', '甜瓜-白', '甜瓜-绿', '甜瓜-金', '白兰瓜', '羊角蜜', '苦瓜', '西瓜', '西葫芦', '黄瓜']


## 测试集图像路径及标注

In [14]:
test_dataset.imgs[:10]

[('melon17_split/val/丝瓜/109.jpg', 0),
 ('melon17_split/val/丝瓜/111.jpg', 0),
 ('melon17_split/val/丝瓜/113.jpg', 0),
 ('melon17_split/val/丝瓜/115.jpg', 0),
 ('melon17_split/val/丝瓜/120.jpg', 0),
 ('melon17_split/val/丝瓜/135.jpg', 0),
 ('melon17_split/val/丝瓜/141.jpg', 0),
 ('melon17_split/val/丝瓜/143.jpg', 0),
 ('melon17_split/val/丝瓜/150.jpg', 0),
 ('melon17_split/val/丝瓜/160.jpg', 0)]

In [15]:
img_paths = [each[0] for each in test_dataset.imgs]

In [16]:
df = pd.DataFrame()
df['图像路径'] = img_paths
df['标注类别ID'] = test_dataset.targets
df['标注类别名称'] = [idx_to_labels[ID] for ID in test_dataset.targets]

In [17]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称
0,melon17_split/val/丝瓜/109.jpg,0,丝瓜
1,melon17_split/val/丝瓜/111.jpg,0,丝瓜
2,melon17_split/val/丝瓜/113.jpg,0,丝瓜
3,melon17_split/val/丝瓜/115.jpg,0,丝瓜
4,melon17_split/val/丝瓜/120.jpg,0,丝瓜
...,...,...,...
504,melon17_split/val/黄瓜/85.jpeg,16,黄瓜
505,melon17_split/val/黄瓜/91.png,16,黄瓜
506,melon17_split/val/黄瓜/92.jpg,16,黄瓜
507,melon17_split/val/黄瓜/96.png,16,黄瓜


## 导入训练好的模型

In [22]:
# !wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/checkpoints/melon17_pytorch_20220812.pth -O checkpoints/melon17_pytorch_20220812.pth


In [23]:
model = torch.load('checkpoints/melon17_pytorch_20220812.pth')
model = model.eval().to(device)

## 预测单张图像的函数

In [24]:
n = 3

In [25]:
classes

['丝瓜',
 '人参果',
 '佛手瓜',
 '冬瓜',
 '南瓜',
 '哈密瓜',
 '木瓜',
 '甜瓜-伊丽莎白',
 '甜瓜-白',
 '甜瓜-绿',
 '甜瓜-金',
 '白兰瓜',
 '羊角蜜',
 '苦瓜',
 '西瓜',
 '西葫芦',
 '黄瓜']

In [26]:
df_pred = pd.DataFrame()
for idx, row in tqdm(df.iterrows()):
    img_path = row['图像路径']
    img_pil = Image.open(img_path).convert('RGB')
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测，得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

    pred_dict = {}

    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    
    # top-n 预测结果
    for i in range(1, n+1):
        pred_dict['top-{}-预测ID'.format(i)] = pred_ids[i-1]
        pred_dict['top-{}-预测名称'.format(i)] = idx_to_labels[pred_ids[i-1]]
    pred_dict['top-n预测正确'] = row['标注类别ID'] in pred_ids
    # 每个类别的预测置信度
    for idx, each in enumerate(classes):
        pred_dict['{}-预测置信度'.format(each)] = pred_softmax[0][idx].cpu().detach().numpy()
        
    df_pred = df_pred.append(pred_dict, ignore_index=True)

509it [00:16, 30.20it/s]


In [27]:
df_pred

Unnamed: 0,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,丝瓜-预测置信度,人参果-预测置信度,佛手瓜-预测置信度,...,甜瓜-伊丽莎白-预测置信度,甜瓜-白-预测置信度,甜瓜-绿-预测置信度,甜瓜-金-预测置信度,白兰瓜-预测置信度,羊角蜜-预测置信度,苦瓜-预测置信度,西瓜-预测置信度,西葫芦-预测置信度,黄瓜-预测置信度
0,0.0,丝瓜,2.0,佛手瓜,12.0,羊角蜜,1.0,0.596132,0.0049858876,0.16638866,...,0.002193116,0.014018003,0.014890861,0.0035005058,0.006259539,0.09234588,0.0014535866,0.00392526,0.0066614095,0.001151434
1,0.0,丝瓜,13.0,苦瓜,15.0,西葫芦,1.0,0.37336475,0.0021691073,0.016954591,...,0.0005665258,0.00037204658,0.000705166,0.0020635515,0.0004212223,0.07255299,0.19591753,0.00051014195,0.14407678,0.04816605
2,0.0,丝瓜,5.0,哈密瓜,11.0,白兰瓜,1.0,0.365683,0.007297802,0.00506592,...,0.022292944,0.033744015,0.0024989622,0.009714146,0.10879801,0.04475199,0.009132443,0.011166618,0.09802411,0.035006613
3,3.0,冬瓜,0.0,丝瓜,15.0,西葫芦,1.0,0.106365584,0.0012092766,0.0015654711,...,0.00088775734,0.0031546394,0.0037253152,0.0010065447,0.0067758462,0.0023898373,0.00795086,0.0017810254,0.10077254,0.0067877406
4,0.0,丝瓜,15.0,西葫芦,16.0,黄瓜,1.0,0.6201438,0.00040383684,0.0016078089,...,0.0016597046,0.00012629366,0.00037194998,0.0031683359,0.000812446,0.015667504,0.04719282,0.000765715,0.14260118,0.08020245
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
504,0.0,丝瓜,16.0,黄瓜,14.0,西瓜,1.0,0.35514027,0.008681144,0.029855315,...,0.0016557868,0.0030794123,0.0069995252,0.0032152752,0.00457061,0.07065773,0.07867981,0.09408268,0.021051414,0.25632507
505,16.0,黄瓜,13.0,苦瓜,0.0,丝瓜,1.0,0.015019714,0.0018177313,0.0023614594,...,0.00036590992,0.0005939998,0.0004931314,0.00053202955,0.0010377116,0.012991533,0.070010215,0.00011173802,0.00120984,0.87485486
506,15.0,西葫芦,16.0,黄瓜,0.0,丝瓜,1.0,0.10180287,0.000445377,0.0010038702,...,0.0002971525,0.00019661515,7.126853e-05,0.0004855906,0.0002692762,0.0069090463,0.0014733479,0.0018840615,0.55471665,0.27996975
507,16.0,黄瓜,13.0,苦瓜,0.0,丝瓜,1.0,0.032155924,0.00032865006,0.00017473535,...,0.00042318233,0.00013750575,8.360388e-05,0.00018777518,0.0006451526,0.004023972,0.054080762,0.00064940023,0.00281813,0.8985192


## 拼接表格

In [28]:
df = pd.concat([df, df_pred], axis=1)

In [29]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,...,甜瓜-伊丽莎白-预测置信度,甜瓜-白-预测置信度,甜瓜-绿-预测置信度,甜瓜-金-预测置信度,白兰瓜-预测置信度,羊角蜜-预测置信度,苦瓜-预测置信度,西瓜-预测置信度,西葫芦-预测置信度,黄瓜-预测置信度
0,melon17_split/val/丝瓜/109.jpg,0,丝瓜,0.0,丝瓜,2.0,佛手瓜,12.0,羊角蜜,1.0,...,0.002193116,0.014018003,0.014890861,0.0035005058,0.006259539,0.09234588,0.0014535866,0.00392526,0.0066614095,0.001151434
1,melon17_split/val/丝瓜/111.jpg,0,丝瓜,0.0,丝瓜,13.0,苦瓜,15.0,西葫芦,1.0,...,0.0005665258,0.00037204658,0.000705166,0.0020635515,0.0004212223,0.07255299,0.19591753,0.00051014195,0.14407678,0.04816605
2,melon17_split/val/丝瓜/113.jpg,0,丝瓜,0.0,丝瓜,5.0,哈密瓜,11.0,白兰瓜,1.0,...,0.022292944,0.033744015,0.0024989622,0.009714146,0.10879801,0.04475199,0.009132443,0.011166618,0.09802411,0.035006613
3,melon17_split/val/丝瓜/115.jpg,0,丝瓜,3.0,冬瓜,0.0,丝瓜,15.0,西葫芦,1.0,...,0.00088775734,0.0031546394,0.0037253152,0.0010065447,0.0067758462,0.0023898373,0.00795086,0.0017810254,0.10077254,0.0067877406
4,melon17_split/val/丝瓜/120.jpg,0,丝瓜,0.0,丝瓜,15.0,西葫芦,16.0,黄瓜,1.0,...,0.0016597046,0.00012629366,0.00037194998,0.0031683359,0.000812446,0.015667504,0.04719282,0.000765715,0.14260118,0.08020245
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
504,melon17_split/val/黄瓜/85.jpeg,16,黄瓜,0.0,丝瓜,16.0,黄瓜,14.0,西瓜,1.0,...,0.0016557868,0.0030794123,0.0069995252,0.0032152752,0.00457061,0.07065773,0.07867981,0.09408268,0.021051414,0.25632507
505,melon17_split/val/黄瓜/91.png,16,黄瓜,16.0,黄瓜,13.0,苦瓜,0.0,丝瓜,1.0,...,0.00036590992,0.0005939998,0.0004931314,0.00053202955,0.0010377116,0.012991533,0.070010215,0.00011173802,0.00120984,0.87485486
506,melon17_split/val/黄瓜/92.jpg,16,黄瓜,15.0,西葫芦,16.0,黄瓜,0.0,丝瓜,1.0,...,0.0002971525,0.00019661515,7.126853e-05,0.0004855906,0.0002692762,0.0069090463,0.0014733479,0.0018840615,0.55471665,0.27996975
507,melon17_split/val/黄瓜/96.png,16,黄瓜,16.0,黄瓜,13.0,苦瓜,0.0,丝瓜,1.0,...,0.00042318233,0.00013750575,8.360388e-05,0.00018777518,0.0006451526,0.004023972,0.054080762,0.00064940023,0.00281813,0.8985192


## 导出完整表格

In [30]:
df.to_csv('测试集预测结果.csv', index=False)