# 测试集图像分类预测结果

使用训练好的图像分类模型，预测测试集的所有图像，得到预测结果表格（每张图像的标注类别、top-n 预测类别以及各类别置信度），并导出为 CSV 文件。注意：这里的"测试集"读取的是 `Particle Figures_split/val` 目录。


## 导入工具包

In [1]:
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn.functional as F

# Pick the compute device: the first CUDA GPU if one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('device', device)

device cuda:0


## 图像预处理

In [2]:
from torchvision import transforms

# Test-time preprocessing ("RCTN"): Resize -> CenterCrop -> ToTensor -> Normalize.
# No random augmentation here, so every run sees identical inputs.
# For reference, the training pipeline used RandomResizedCrop(224) +
# RandomHorizontalFlip() followed by the same ToTensor / Normalize steps.
test_transform = transforms.Compose([
    transforms.Resize(256),                        # int size: the shorter edge becomes 256 px
    transforms.CenterCrop(224),                    # keep the central 224x224 patch
    transforms.ToTensor(),                         # PIL image -> tensor
    transforms.Normalize([0.485, 0.456, 0.406],    # per-channel mean
                         [0.229, 0.224, 0.225]),   # per-channel std
])

## 载入测试集（和训练代码教程相同）

In [3]:
from torchvision import datasets

# Dataset root folder. Note: the "test set" images are read from the 'val' split.
dataset_dir = 'Particle Figures_split'
test_path = os.path.join(dataset_dir, 'val')

# Load the test images; ImageFolder derives class IDs from the sub-folder names.
test_dataset = datasets.ImageFolder(test_path, test_transform)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

# {class ID: class name} mapping, presumably saved by the training notebook.
# NOTE: allow_pickle=True unpickles the file and can execute arbitrary code --
# only load trusted files.
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()
# Class names in the mapping's order.
classes = list(idx_to_labels.values())
print(classes)

# Ground-truth IDs come from ImageFolder, but they (and the model's output
# indices) are translated to names via idx_to_labels. If the two mappings
# disagree, every name in the result table would be silently wrong.
assert test_dataset.class_to_idx == {name: class_id for class_id, name in idx_to_labels.items()}, \
    'idx_to_labels.npy does not match the class folders of the test set'

测试集图像数量 248
类别个数 6
各类别名称 ['Angular', 'Rounded', 'Subangular', 'Subrounded', 'Very angular', 'Well rounded']
['Angular', 'Rounded', 'Subangular', 'Subrounded', 'Very angular', 'Well rounded']


## 导入训练好的模型

In [4]:
# Load the whole pickled model object (we call .eval() on it directly, so it is
# an nn.Module, not a state_dict).
# - map_location=device remaps tensors that were saved on a GPU, so the notebook
#   also works when the device cell above fell back to the CPU.
# - weights_only=False is needed to unpickle a full nn.Module on torch >= 2.6,
#   where the default became True.
# SECURITY: unpickling can execute arbitrary code -- only load trusted checkpoints.
model = torch.load('Particle Figures_pytorch_C1.pth', map_location=device, weights_only=False)
model = model.eval().to(device)  # eval mode: disables dropout / uses running BN stats

## 表格A-测试集图像路径及标注

In [5]:
# Preview the first 10 (image path, class ID) pairs discovered by ImageFolder.
test_dataset.imgs[:10]

[('Particle Figures_split/val/Angular/Angular (10).jpg', 0),
 ('Particle Figures_split/val/Angular/Angular (106).jpg', 0),
 ('Particle Figures_split/val/Angular/Angular (111).jpg', 0),
 ('Particle Figures_split/val/Angular/Angular (13).jpg', 0),
 ('Particle Figures_split/val/Angular/Angular (130).jpg', 0),
 ('Particle Figures_split/val/Angular/Angular (131).jpg', 0),
 ('Particle Figures_split/val/Angular/Angular (133).jpg', 0),
 ('Particle Figures_split/val/Angular/Angular (143).jpg', 0),
 ('Particle Figures_split/val/Angular/Angular (146).jpg', 0),
 ('Particle Figures_split/val/Angular/Angular (149).jpg', 0)]

In [6]:
# Keep only the file path from each (path, class ID) pair.
img_paths = [path for path, _label in test_dataset.imgs]

In [7]:
# Table A: one row per test image -- its file path, the ground-truth class ID
# assigned by ImageFolder, and the matching class name from idx_to_labels.
df = pd.DataFrame({
    'Image path': img_paths,
    'Labeling categories ID': test_dataset.targets,
    'Labeling categories name': [idx_to_labels[label_id] for label_id in test_dataset.targets],
})

In [8]:
# Table A: image path plus ground-truth class ID and name for each test image.
df

Unnamed: 0,Image path,Labeling categories ID,Labeling categories name
0,Particle Figures_split/val/Angular/Angular (10...,0,Angular
1,Particle Figures_split/val/Angular/Angular (10...,0,Angular
2,Particle Figures_split/val/Angular/Angular (11...,0,Angular
3,Particle Figures_split/val/Angular/Angular (13...,0,Angular
4,Particle Figures_split/val/Angular/Angular (13...,0,Angular
...,...,...,...
243,Particle Figures_split/val/Well rounded/Well r...,5,Well rounded
244,Particle Figures_split/val/Well rounded/Well r...,5,Well rounded
245,Particle Figures_split/val/Well rounded/Well r...,5,Well rounded
246,Particle Figures_split/val/Well rounded/Well r...,5,Well rounded


## 表格B-测试集每张图像的图像分类预测结果，以及各类别置信度

In [9]:
# Number of highest-confidence predictions to record per image (top-n).
# Must not exceed the number of classes, otherwise torch.topk raises.
n = 3

In [12]:
# Table B: for every image in table A, record the top-n predicted class IDs and
# names, whether the ground-truth ID is among them, and the softmax confidence
# of every class.
pred_records = []  # one dict per image; converted to a DataFrame once at the end
with torch.no_grad():  # inference only: don't build an autograd graph per image
    for img_path, true_id in tqdm(zip(df['Image path'], df['Labeling categories ID']), total=len(df)):
        # The context manager closes the file handle; convert() returns a new, loaded image.
        with Image.open(img_path) as img:
            img_pil = img.convert('RGB')
        input_img = test_transform(img_pil).unsqueeze(0).to(device)  # add a batch dimension of size 1
        pred_logits = model(input_img)                # raw logit score for every class
        pred_softmax = F.softmax(pred_logits, dim=1)  # logits -> probabilities over classes

        # Index the single batch row explicitly (squeeze() would produce a 0-d,
        # non-indexable array when n == 1); tolist() yields plain Python ints/floats.
        pred_ids = torch.topk(pred_softmax, n).indices[0].tolist()
        probs = pred_softmax[0].tolist()

        pred_dict = {}
        for rank, pred_id in enumerate(pred_ids, start=1):
            pred_dict['top-{}-Predictions ID'.format(rank)] = pred_id
            pred_dict['top-{}-Predictions name'.format(rank)] = idx_to_labels[pred_id]
        pred_dict['top-n Predictions correction'] = true_id in pred_ids
        # Look up each class's probability by its own ID rather than by list position.
        for class_id, class_name in idx_to_labels.items():
            pred_dict['{}-Predictions Confidence Level'.format(class_name)] = probs[class_id]
        pred_records.append(pred_dict)

# Build the table once. Appending row by row is quadratic and relied on the
# private DataFrame._append. Reuse df's index so the rows line up with table A.
df_pred = pd.DataFrame(pred_records, index=df.index)

248it [00:04, 52.06it/s]


In [13]:
# Table B: top-n predictions and per-class confidences, one row per test image.
df_pred

Unnamed: 0,top-1-Predictions ID,top-1-Predictions name,top-2-Predictions ID,top-2-Predictions name,top-3-Predictions ID,top-3-Predictions name,top-n Predictions correction,Angular-Predictions Confidence Level,Rounded-Predictions Confidence Level,Subangular-Predictions Confidence Level,Subrounded-Predictions Confidence Level,Very angular-Predictions Confidence Level,Well rounded-Predictions Confidence Level
0,0,Angular,4,Very angular,2,Subangular,True,0.5740181,0.0034985272,0.111644566,0.016418498,0.29402992,0.00039027288
1,3,Subrounded,2,Subangular,0,Angular,True,0.020437263,0.01073399,0.2844617,0.67212415,0.012036975,0.00020592667
2,0,Angular,2,Subangular,4,Very angular,True,0.51057106,0.08709357,0.16634618,0.10468967,0.122171775,0.009127764
3,0,Angular,2,Subangular,4,Very angular,True,0.42534515,0.0047413907,0.4243061,0.027566545,0.11761896,0.00042182644
4,0,Angular,2,Subangular,4,Very angular,True,0.682034,0.01076125,0.17730393,0.050382897,0.07471112,0.004806696
...,...,...,...,...,...,...,...,...,...,...,...,...,...
243,3,Subrounded,1,Rounded,2,Subangular,False,0.03188204,0.3015919,0.23369747,0.36432058,0.050350558,0.018157482
244,5,Well rounded,3,Subrounded,2,Subangular,True,0.06267947,0.0921578,0.1355485,0.27929887,0.009975068,0.4203403
245,2,Subangular,3,Subrounded,4,Very angular,False,0.037356008,0.034041755,0.70084286,0.17599387,0.043497056,0.008268428
246,2,Subangular,3,Subrounded,1,Rounded,False,0.07757834,0.11785209,0.44920433,0.23566626,0.090628035,0.029070955


## 拼接AB两张表格

In [14]:
# Join table A (paths + labels) and table B (predictions) side by side.
# pd.concat(axis=1) aligns rows by index label, so both tables must share one index.
assert df_pred.index.equals(df.index), 'prediction table and label table are not aligned'
# Drop prediction columns left over from an earlier run of this cell, so that
# re-executing it does not duplicate them.
df = pd.concat([df.drop(columns=df_pred.columns, errors='ignore'), df_pred], axis=1)

In [16]:
# Combined table: ground-truth columns (A) followed by prediction columns (B).
df

Unnamed: 0,Image path,Labeling categories ID,Labeling categories name,top-1-Predictions ID,top-1-Predictions name,top-2-Predictions ID,top-2-Predictions name,top-3-Predictions ID,top-3-Predictions name,top-n Predictions correction,Angular-Predictions Confidence Level,Rounded-Predictions Confidence Level,Subangular-Predictions Confidence Level,Subrounded-Predictions Confidence Level,Very angular-Predictions Confidence Level,Well rounded-Predictions Confidence Level
0,Particle Figures_split/val/Angular/Angular (10...,0,Angular,0,Angular,4,Very angular,2,Subangular,True,0.5740181,0.0034985272,0.111644566,0.016418498,0.29402992,0.00039027288
1,Particle Figures_split/val/Angular/Angular (10...,0,Angular,3,Subrounded,2,Subangular,0,Angular,True,0.020437263,0.01073399,0.2844617,0.67212415,0.012036975,0.00020592667
2,Particle Figures_split/val/Angular/Angular (11...,0,Angular,0,Angular,2,Subangular,4,Very angular,True,0.51057106,0.08709357,0.16634618,0.10468967,0.122171775,0.009127764
3,Particle Figures_split/val/Angular/Angular (13...,0,Angular,0,Angular,2,Subangular,4,Very angular,True,0.42534515,0.0047413907,0.4243061,0.027566545,0.11761896,0.00042182644
4,Particle Figures_split/val/Angular/Angular (13...,0,Angular,0,Angular,2,Subangular,4,Very angular,True,0.682034,0.01076125,0.17730393,0.050382897,0.07471112,0.004806696
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
243,Particle Figures_split/val/Well rounded/Well r...,5,Well rounded,3,Subrounded,1,Rounded,2,Subangular,False,0.03188204,0.3015919,0.23369747,0.36432058,0.050350558,0.018157482
244,Particle Figures_split/val/Well rounded/Well r...,5,Well rounded,5,Well rounded,3,Subrounded,2,Subangular,True,0.06267947,0.0921578,0.1355485,0.27929887,0.009975068,0.4203403
245,Particle Figures_split/val/Well rounded/Well r...,5,Well rounded,2,Subangular,3,Subrounded,4,Very angular,False,0.037356008,0.034041755,0.70084286,0.17599387,0.043497056,0.008268428
246,Particle Figures_split/val/Well rounded/Well r...,5,Well rounded,2,Subangular,3,Subrounded,1,Rounded,False,0.07757834,0.11785209,0.44920433,0.23566626,0.090628035,0.029070955


## 导出完整表格

In [17]:
# Export the combined table to CSV, omitting the pandas row index.
df.to_csv(path_or_buf='测试集预测结果.csv', index=False)