# 测试集图像分类预测结果

使用训练好的图像分类模型，预测测试集的所有图像，得到预测结果表格。

## 导入工具包

In [10]:
import os
from tqdm import tqdm

import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn.functional as F

# 有 GPU 就用 GPU，没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

device cuda:0


## 图像预处理

In [11]:
from torchvision import transforms
from COME15KClassDataset import set_data_loader
from torchvision import datasets

# # 训练集图像预处理：缩放裁剪、图像增强、转 Tensor、归一化
# train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
#                                       transforms.RandomHorizontalFlip(),
#                                       transforms.ToTensor(),
#                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#                                      ])

# 测试集图像预处理-RCTN：缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

## 载入测试集（和训练代码教程相同）easy

In [12]:
# 数据集文件夹路径
dataset_name = 'test_easy_classes'
dataset_dir = '../data_class_txt/'+ dataset_name + '.txt'
# 类别名称 和 ID索引号 的映射字典
class_names_dic = {0: 'covering', 1: 'device', 2: 'domestic_animal', 3: 'mater', 4: 'person', 5: 'plant',
                       6: 'structure', 7: 'vertebrate'}

test_dataset_loader_easy = set_data_loader(dataset_attr_word="test_easy", batch_size=1, size=512, shuffle=False,
                                           transforms_compose=test_transform, dataset_dir='data/SOD-SemanticDataset')
# 载入测试集
print('测试集图像数量', len(test_dataset_loader_easy))
print('类别个数', len(class_names_dic))
print('各类别名称', list(class_names_dic.values()))
# 获得类别名称
classes = list(class_names_dic.values())
print(classes)

测试集图像数量 4600
类别个数 8
各类别名称 ['covering', 'device', 'domestic_animal', 'mater', 'person', 'plant', 'structure', 'vertebrate']
['covering', 'device', 'domestic_animal', 'mater', 'person', 'plant', 'structure', 'vertebrate']


## 载入测试集（和训练代码教程相同）hard

In [4]:
# 数据集文件夹路径
dataset_name = 'test_hard_classes'
model_path = '2023-10-09-01-47_max_epoch_100/'
model_name = 'retrain_COME15K_checkpoint-best-avg-0.743-Medium.pth.tar'
dataset_dir = '../data_class_txt/'+ dataset_name + '.txt'
# 类别名称 和 ID索引号 的映射字典
class_names_dic = {0: 'covering', 1: 'device', 2: 'domestic_animal', 3: 'mater', 4: 'person', 5: 'plant',
                       6: 'structure', 7: 'vertebrate'}

test_dataset_loader_easy = set_data_loader(dataset_attr_word="test_hard", batch_size=1, size=512, shuffle=False,
                                           transforms_compose=test_transform, dataset_dir='data/SOD-SemanticDataset')
# 载入测试集
print('测试集图像数量', len(test_dataset_loader_easy))
print('类别个数', len(class_names_dic))
print('各类别名称', list(class_names_dic.values()))
# 获得类别名称
classes = list(class_names_dic.values())
print(classes)

测试集图像数量 3000
类别个数 8
各类别名称 ['covering', 'device', 'domestic_animal', 'mater', 'person', 'plant', 'structure', 'vertebrate']
['covering', 'device', 'domestic_animal', 'mater', 'person', 'plant', 'structure', 'vertebrate']


## 训练前模型

In [13]:
model_path = '2023-10-09-01-47_max_epoch_100/'
model_name = 'retrain_COME15K_checkpoint-best-avg-0.743-Medium.pth.tar'

In [18]:
after_sod_model_path = './SOD-model/'
after_sod_model_name = 'retrianed-frozen-1234-epoch-198.pth'

## 导入训练好的模型(导入参数)

In [15]:
# from network import ShuffleNetV2_Plus
# # init model
# architecture = [0, 0, 3, 1, 1, 1, 0, 0, 2, 0, 2, 1, 1, 0, 2, 0, 2, 1, 3, 2]
# model = ShuffleNetV2_Plus(architecture=architecture, n_class=class_names_dic.__len__(), model_size="Medium")
# weight_path = model_path + model_name
# trained_weight = torch.load(weight_path)
# model.load_state_dict(trained_weight['state_dict'], strict=True)
# model = model.eval().to(device)

## 导入训练好的模型(导入模型 + 参数)

In [41]:
model_and_weight_path = 'models/' + model_path + model_name
model = torch.load(model_and_weight_path)

In [42]:
import collections
after_sod_weight_path = after_sod_model_path + after_sod_model_name
after_sod_trained_weight = torch.load(after_sod_weight_path)
# print(after_sod_trained_weight)
new_orderdic = collections.OrderedDict()
for name, value in after_sod_trained_weight.items():
    if name.startswith('backbone.'):
        new_orderdic.setdefault(name.strip('backbone.'), value)
model.load_state_dict(new_orderdic, strict=False)
model = model.eval().to(device)

## 表格A-测试集图像路径及标注

In [43]:
data_path = test_dataset_loader_easy.dataset.data_path

In [44]:
img_paths = test_dataset_loader_easy.dataset.images
img_lables = test_dataset_loader_easy.dataset.labels

In [45]:
df = pd.DataFrame()
df['图像路径'] = [data_path + img_name for img_name in img_paths]
df['标注类别ID'] = img_lables
df['标注类别名称'] = [class_names_dic.get(ID) for ID in img_lables]

In [46]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称
0,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering
1,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering
2,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering
3,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering
4,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering
...,...,...,...
4595,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,7,vertebrate
4596,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,7,vertebrate
4597,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,7,vertebrate
4598,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,7,vertebrate


47_max_epoch_100## 表格B-测试集每张图像的图像分类预测结果，以及各类别置信度

In [47]:
# 记录 top-n 预测结果
n = 3

In [48]:
df_pred = pd.DataFrame()
for idx, row in tqdm(df.iterrows()):
    img_path = row['图像路径']
    img_pil = Image.open(img_path).convert('RGB')
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测，得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

    pred_dict = {}

    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
    
    # top-n 预测结果
    for i in range(1, n+1):
        pred_dict['top-{}-预测ID'.format(i)] = pred_ids[i-1]
        pred_dict['top-{}-预测名称'.format(i)] = class_names_dic.get(pred_ids[i-1])
    pred_dict['top-n预测正确'] = row['标注类别ID'] in pred_ids
    # 每个类别的预测置信度
    for idx, each in enumerate(classes):
        pred_dict['{}-预测置信度'.format(each)] = pred_softmax[0][idx].cpu().detach().numpy()
        
    df_pred = df_pred._append(pred_dict, ignore_index=True)

4600it [00:56, 81.34it/s]


In [49]:
df_pred

Unnamed: 0,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,covering-预测置信度,device-预测置信度,domestic_animal-预测置信度,mater-预测置信度,person-预测置信度,plant-预测置信度,structure-预测置信度,vertebrate-预测置信度
0,4,person,7,vertebrate,2,domestic_animal,False,1.1946931e-05,2.5092378e-10,2.6842174e-05,7.3258616e-07,0.99991345,9.284433e-12,4.4283144e-08,4.690949e-05
1,4,person,2,domestic_animal,7,vertebrate,False,9.711632e-06,3.5668707e-10,3.6198588e-05,1.925552e-06,0.99993086,4.8180637e-11,4.0123645e-08,2.1353806e-05
2,4,person,7,vertebrate,2,domestic_animal,False,3.798483e-08,1.5153873e-13,1.5963609e-07,2.6747823e-10,0.99999964,1.524078e-16,1.7499225e-12,1.969033e-07
3,4,person,2,domestic_animal,0,covering,True,6.0619975e-08,3.8730195e-14,3.3574102e-07,1.338268e-10,0.9999995,2.1752351e-16,3.4640168e-12,2.0080037e-08
4,4,person,2,domestic_animal,0,covering,True,1.0830144e-07,1.9855631e-13,1.8413326e-06,1.6820173e-09,0.999998,5.34575e-15,1.979243e-11,9.478797e-08
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4595,4,person,2,domestic_animal,0,covering,False,3.010814e-06,1.011442e-11,1.5906622e-05,5.4411586e-08,0.99998057,1.8852224e-12,1.4813116e-09,5.5865e-07
4596,4,person,2,domestic_animal,0,covering,False,2.2406213e-07,5.3714444e-13,2.7510835e-06,2.4366613e-09,0.9999969,1.0267309e-14,2.9232602e-11,1.4792947e-07
4597,4,person,2,domestic_animal,7,vertebrate,True,8.597927e-07,3.663859e-12,5.5530045e-06,9.363415e-09,0.9999925,1.8976894e-14,1.5881911e-10,1.0303363e-06
4598,4,person,2,domestic_animal,7,vertebrate,True,4.4077947e-08,1.02376915e-13,8.8837305e-07,8.831876e-10,0.99999917,1.2622912e-15,6.519891e-12,5.53733e-08


## 拼接AB两张表格

In [50]:
df = pd.concat([df, df_pred], axis=1)

In [51]:
df

Unnamed: 0,图像路径,标注类别ID,标注类别名称,top-1-预测ID,top-1-预测名称,top-2-预测ID,top-2-预测名称,top-3-预测ID,top-3-预测名称,top-n预测正确,covering-预测置信度,device-预测置信度,domestic_animal-预测置信度,mater-预测置信度,person-预测置信度,plant-预测置信度,structure-预测置信度,vertebrate-预测置信度
0,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering,4,person,7,vertebrate,2,domestic_animal,False,1.1946931e-05,2.5092378e-10,2.6842174e-05,7.3258616e-07,0.99991345,9.284433e-12,4.4283144e-08,4.690949e-05
1,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering,4,person,2,domestic_animal,7,vertebrate,False,9.711632e-06,3.5668707e-10,3.6198588e-05,1.925552e-06,0.99993086,4.8180637e-11,4.0123645e-08,2.1353806e-05
2,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering,4,person,7,vertebrate,2,domestic_animal,False,3.798483e-08,1.5153873e-13,1.5963609e-07,2.6747823e-10,0.99999964,1.524078e-16,1.7499225e-12,1.969033e-07
3,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering,4,person,2,domestic_animal,0,covering,True,6.0619975e-08,3.8730195e-14,3.3574102e-07,1.338268e-10,0.9999995,2.1752351e-16,3.4640168e-12,2.0080037e-08
4,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,0,covering,4,person,2,domestic_animal,0,covering,True,1.0830144e-07,1.9855631e-13,1.8413326e-06,1.6820173e-09,0.999998,5.34575e-15,1.979243e-11,9.478797e-08
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4595,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,7,vertebrate,4,person,2,domestic_animal,0,covering,False,3.010814e-06,1.011442e-11,1.5906622e-05,5.4411586e-08,0.99998057,1.8852224e-12,1.4813116e-09,5.5865e-07
4596,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,7,vertebrate,4,person,2,domestic_animal,0,covering,False,2.2406213e-07,5.3714444e-13,2.7510835e-06,2.4366613e-09,0.9999969,1.0267309e-14,2.9232602e-11,1.4792947e-07
4597,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,7,vertebrate,4,person,2,domestic_animal,7,vertebrate,True,8.597927e-07,3.663859e-12,5.5530045e-06,9.363415e-09,0.9999925,1.8976894e-14,1.5881911e-10,1.0303363e-06
4598,data/SOD-SemanticDataset/test/COME15K-Easy/COM...,7,vertebrate,4,person,2,domestic_animal,7,vertebrate,True,4.4077947e-08,1.02376915e-13,8.8837305e-07,8.831876e-10,0.99999917,1.2622912e-15,6.519891e-12,5.53733e-08


## 导出完整表格

In [52]:
df.to_csv(after_sod_model_path + dataset_name +'-测试集预测结果.csv', index=False)