# 测试集图像分类预测结果

使用训练好的图像分类模型，预测测试集的所有图像，得到预测结果表格。

同济子豪兄：https://space.bilibili.com/1900783

[代码运行云GPU环境](https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1)：GPU RTX 3060、CUDA v11.2

## 导入工具包

In [1]:
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn.functional as F

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('device', device)

device cuda:0


## 图像预处理

In [2]:
from torchvision import transforms

# # 训练集图像预处理：缩放裁剪、图像增强、转 Tensor、归一化
# train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
#                                       transforms.RandomHorizontalFlip(),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#                                      ])

# Test-set preprocessing (RCTN): Resize -> CenterCrop -> ToTensor -> Normalize.
# The per-channel mean/std are the same values used by the (commented-out)
# training transform above, so evaluation inputs are normalized identically.
test_transform = transforms.Compose([
    transforms.Resize(256),        # resize to 256 before cropping
    transforms.CenterCrop(224),    # keep the central 224 x 224 patch
    transforms.ToTensor(),         # PIL image -> tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

  warn(


## 载入测试集（和训练代码教程相同）

In [3]:
from torchvision import datasets

# Dataset root folder; its 'val' sub-folder is what this notebook evaluates.
dataset_dir = r'D:\dataset\sr'
test_path = os.path.join(dataset_dir, 'val')

# Load the test set: one sub-folder per class, preprocessed with test_transform.
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

# Load the {class ID: class name} mapping dictionary.
# NOTE(review): allow_pickle=True unpickles the file contents, so only load
# .npy files that come from a trusted source.
idx_to_labels = np.load(r'E:\MV-Code-202018010103-Lucy\main\Train_Custom_Dataset\图像分类\3-【Pytorch】迁移学习训练自己的图像分类模型\idx_to_labels.npy', allow_pickle=True).item()

# Class names ordered by class ID. Later cells index the model's softmax output
# by position in this list, so the order must follow the IDs rather than the
# insertion order of the dictionary.
classes = [idx_to_labels[class_id] for class_id in sorted(idx_to_labels)]
print(classes)

测试集图像数量 5512
类别个数 2
各类别名称 ['parasitized', 'uninfected']
['parasitized', 'uninfected']


## 导入训练好的模型

In [6]:
# Load the trained model object from the checkpoint file.
# map_location=device remaps the stored tensors onto the device selected
# earlier, so a checkpoint saved on a GPU can still be opened on a CPU-only
# machine.
# NOTE(review): torch.load unpickles the file, so only open checkpoints from a
# trusted source. Recent PyTorch releases changed the default of `weights_only`;
# loading a whole pickled model may need weights_only=False there — TODO confirm
# against the installed version.
model = torch.load(
    r'E:\MV-Code-202018010103-Lucy\main\Train_Custom_Dataset\图像分类\3-【Pytorch】迁移学习训练自己的图像分类模型\checkpoint\best-0.911.pth',
    map_location=device,
)
# Switch to evaluation mode and move the model onto the selected device.
model = model.eval().to(device)

## 表格A-测试集图像路径及标注

In [7]:
# Preview the first ten (image path, class ID) entries of the test set.
test_dataset.imgs[0:10]

[('D:\\dataset\\sr\\val\\parasitized\\C100P61ThinF_IMG_20150918_144104_cell_163_RCAN_BIX4-official.png',
  0),
 ('D:\\dataset\\sr\\val\\parasitized\\C100P61ThinF_IMG_20150918_144104_cell_166_RCAN_BIX4-official.png',
  0),
 ('D:\\dataset\\sr\\val\\parasitized\\C100P61ThinF_IMG_20150918_144104_cell_167_RCAN_BIX4-official.png',
  0),
 ('D:\\dataset\\sr\\val\\parasitized\\C100P61ThinF_IMG_20150918_144104_cell_171_RCAN_BIX4-official.png',
  0),
 ('D:\\dataset\\sr\\val\\parasitized\\C100P61ThinF_IMG_20150918_144348_cell_139_RCAN_BIX4-official.png',
  0),
 ('D:\\dataset\\sr\\val\\parasitized\\C100P61ThinF_IMG_20150918_144348_cell_141_RCAN_BIX4-official.png',
  0),
 ('D:\\dataset\\sr\\val\\parasitized\\C100P61ThinF_IMG_20150918_144348_cell_144_RCAN_BIX4-official.png',
  0),
 ('D:\\dataset\\sr\\val\\parasitized\\C100P61ThinF_IMG_20150918_144823_cell_161_RCAN_BIX4-official.png',
  0),
 ('D:\\dataset\\sr\\val\\parasitized\\C100P61ThinF_IMG_20150918_145609_cell_146_RCAN_BIX4-official.png',
  0),
 

In [8]:
# Keep only the file path from each (path, class ID) pair.
img_paths = [path for path, _ in test_dataset.imgs]

In [9]:
# Table A: one row per test image with its path, annotated class ID and the
# class name looked up from idx_to_labels.
df = pd.DataFrame({
    '图像路径': img_paths,
    '标注类别ID': test_dataset.targets,
    '标注类别名称': [idx_to_labels[ID] for ID in test_dataset.targets],
})

In [10]:
# Show table A (image path, annotated class ID, annotated class name).
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称
0,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized
1,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized
2,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized
3,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized
4,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized
...,...,...,...
5507,D:\dataset\sr\val\uninfected\C99P60ThinF_IMG_2...,1,uninfected
5508,D:\dataset\sr\val\uninfected\C99P60ThinF_IMG_2...,1,uninfected
5509,D:\dataset\sr\val\uninfected\C99P60ThinF_IMG_2...,1,uninfected
5510,D:\dataset\sr\val\uninfected\C99P60ThinF_IMG_2...,1,uninfected


## 表格B-测试集每张图像的图像分类预测结果，以及各类别置信度

In [11]:
# Number of top predictions (top-n) recorded for every image.
n = 2

In [12]:

# Table B: run the model on every image listed in `df` and record, per image,
# the top-n predicted classes, whether the annotated class is among them, and
# the softmax confidence of every class.
# Uses names defined in earlier cells: df, n, model, device, test_transform,
# idx_to_labels and classes.

# One dict per image; converted to a DataFrame in a single step afterwards.
pred_dicts = []

# Inference only: no_grad() keeps autograd from recording the forward passes.
with torch.no_grad():
    for img_path, label_id in tqdm(zip(df['图像路径'], df['标注类别ID']), total=df.shape[0]):
        # Close the file handle as soon as the image has been converted to RGB.
        with Image.open(img_path) as img_file:
            img_pil = img_file.convert('RGB')

        # Preprocess, add a leading dimension of size 1 and move to the device.
        input_img = test_transform(img_pil).unsqueeze(0).to(device)
        pred_logits = model(input_img)
        pred_softmax = F.softmax(pred_logits, dim=1)

        # Take row 0 explicitly instead of calling squeeze(): squeeze() collapses
        # the indices to a 0-d array when n == 1, which cannot be indexed.
        pred_ids = torch.topk(pred_softmax, n)[1][0].cpu().numpy().tolist()
        confidences = pred_softmax[0].cpu().numpy()

        pred_dict = {}
        for rank, pred_id in enumerate(pred_ids, start=1):
            pred_dict['top-{}-预测ID'.format(rank)] = pred_id
            pred_dict['top-{}-预测名称'.format(rank)] = idx_to_labels[pred_id]
        # True when the annotated class ID is among the top-n predicted IDs.
        pred_dict['top-n预测正确'] = int(label_id) in pred_ids

        # Confidence of every class, addressed by its position in `classes`.
        for class_idx, class_name in enumerate(classes):
            pred_dict['{}-预测置信度'.format(class_name)] = float(confidences[class_idx])

        pred_dicts.append(pred_dict)

# Build the prediction table from the collected records and show it.
df_pred = pd.DataFrame(pred_dicts)
display(df_pred)


100%|██████████| 5512/5512 [01:29<00:00, 61.83it/s]


Unnamed: 0,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-n预测正确,parasitized-预测置信度,uninfected-预测置信度
0,0,parasitized,1,uninfected,True,0.965240,0.034760
1,1,uninfected,0,parasitized,True,0.383567,0.616433
2,0,parasitized,1,uninfected,True,0.780675,0.219325
3,0,parasitized,1,uninfected,True,0.994083,0.005918
4,0,parasitized,1,uninfected,True,0.997244,0.002756
...,...,...,...,...,...,...,...
5507,1,uninfected,0,parasitized,True,0.078119,0.921881
5508,0,parasitized,1,uninfected,True,0.636197,0.363803
5509,1,uninfected,0,parasitized,True,0.453985,0.546015
5510,1,uninfected,0,parasitized,True,0.174118,0.825882


In [13]:
# Show table B (top-n predictions and per-class confidences for every image).
df_pred

Unnamed: 0,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-n预测正确,parasitized-预测置信度,uninfected-预测置信度
0,0,parasitized,1,uninfected,True,0.965240,0.034760
1,1,uninfected,0,parasitized,True,0.383567,0.616433
2,0,parasitized,1,uninfected,True,0.780675,0.219325
3,0,parasitized,1,uninfected,True,0.994083,0.005918
4,0,parasitized,1,uninfected,True,0.997244,0.002756
...,...,...,...,...,...,...,...
5507,1,uninfected,0,parasitized,True,0.078119,0.921881
5508,0,parasitized,1,uninfected,True,0.636197,0.363803
5509,1,uninfected,0,parasitized,True,0.453985,0.546015
5510,1,uninfected,0,parasitized,True,0.174118,0.825882


## 拼接AB两张表格

In [14]:
# Join table A (paths and annotations) with table B (predictions) column-wise.
# Prediction columns already present in `df` are dropped first, so re-running
# this cell replaces them instead of appending a duplicate set of columns.
df = pd.concat([df.drop(columns=df_pred.columns, errors='ignore'), df_pred], axis=1)

In [15]:
# Show the merged table (annotations followed by predictions).
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-n预测正确,parasitized-预测置信度,uninfected-预测置信度
0,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized,0,parasitized,1,uninfected,True,0.965240,0.034760
1,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized,1,uninfected,0,parasitized,True,0.383567,0.616433
2,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized,0,parasitized,1,uninfected,True,0.780675,0.219325
3,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized,0,parasitized,1,uninfected,True,0.994083,0.005918
4,D:\dataset\sr\val\parasitized\C100P61ThinF_IMG...,0,parasitized,0,parasitized,1,uninfected,True,0.997244,0.002756
...,...,...,...,...,...,...,...,...,...,...
5507,D:\dataset\sr\val\uninfected\C99P60ThinF_IMG_2...,1,uninfected,1,uninfected,0,parasitized,True,0.078119,0.921881
5508,D:\dataset\sr\val\uninfected\C99P60ThinF_IMG_2...,1,uninfected,0,parasitized,1,uninfected,True,0.636197,0.363803
5509,D:\dataset\sr\val\uninfected\C99P60ThinF_IMG_2...,1,uninfected,1,uninfected,0,parasitized,True,0.453985,0.546015
5510,D:\dataset\sr\val\uninfected\C99P60ThinF_IMG_2...,1,uninfected,1,uninfected,0,parasitized,True,0.174118,0.825882


## 导出完整表格

In [16]:
# Write the merged table to a CSV file in the working directory; index=False
# leaves the row index out of the file.
df.to_csv(path_or_buf='测试集预测结果.csv', index=False)