## 推理

In [None]:
import pandas as pd
from torch4keras.snippets import YamlConfig
import os

# Resolve every input/output location from the YAML config in the working directory.
config = YamlConfig('./config.yaml')
data_dir, root_dir = config['data_dir'], config['root_dir']

# Directory that batch_infer loads the per-fold `best_model_{fold}.pt` weights from.
model_dir = os.path.join(root_dir, 'ckpt')

# Source spreadsheets live under data_dir; predictions are written under <root_dir>/prediction.
train_data_path = os.path.join(data_dir, 'cls.xlsx')
test_data_path = os.path.join(data_dir, '20250319.xlsx')
train_data_save_path = os.path.join(root_dir, 'prediction/cls.xlsx')
test_data_save_path = os.path.join(root_dir, 'prediction/20250319.xlsx')

In [None]:
from train import inference, model

def batch_infer(input_path, output_path, n_folds=5):
    """Predict every row of a spreadsheet with each cross-validation fold model.

    The ``content`` column is passed to ``inference`` once per fold, after
    loading ``best_model_{fold}.pt`` from ``model_dir``. For each fold three
    columns are added: ``pred_{fold}_class_id``, ``pred_{fold}_logit`` and
    ``pred_{fold}_class_name`` (the id mapped through ``category_map.xlsx``).

    If ``class_id`` is missing or entirely empty, it is rebuilt from
    ``class_name`` using ``category_map.xlsx``. This is an inner merge, so rows
    whose ``class_name`` does not appear in the map are dropped.

    Args:
        input_path: Excel file containing ``content`` plus ``class_id`` and/or
            ``class_name``.
        output_path: Excel file to write. Its parent directory is created if
            it does not exist.
        n_folds: number of fold checkpoints to run, i.e. ``best_model_0.pt``
            through ``best_model_{n_folds - 1}.pt`` (default 5).

    Raises:
        KeyError: if a predicted class id has no entry in ``category_map.xlsx``.

    NOTE(review): ``inference(model, texts)`` is assumed to return
    ``(predicted_ids, logits)``, each aligned with ``texts``. Confirm in train.py.
    """
    frame = pd.read_excel(input_path)
    map_data = pd.read_excel(os.path.join(data_dir, 'category_map.xlsx'))
    id_name = map_data[['class_id', 'class_name']]

    # No usable class_id: recover it from class_name via the category map.
    if 'class_id' not in frame.columns or frame['class_id'].isna().all():
        frame = frame.drop(columns='class_id', errors='ignore')
        frame = pd.merge(frame, id_name, on='class_name')

    map_dict = id_name.set_index('class_id')['class_name'].to_dict()
    texts = frame['content'].to_list()
    # Placeholder column; statistics_maybe_wrong later overwrites it.
    frame['maybe_wrong'] = False
    for fold in range(n_folds):
        # os.path.join supplies the separator that `model_dir + name` omitted.
        model.load_weights(os.path.join(model_dir, f'best_model_{fold}.pt'))
        preds, logits = inference(model, texts)
        frame[f'pred_{fold}_class_id'] = preds
        frame[f'pred_{fold}_logit'] = logits
        frame[f'pred_{fold}_class_name'] = [map_dict[i] for i in preds]

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # index=False: a written index reappears as an 'Unnamed: 0' column on re-read.
    frame.to_excel(output_path, index=False)

# Run fold inference for the train split first, then the test split.
for _src, _dst in ((train_data_path, train_data_save_path),
                   (test_data_path, test_data_save_path)):
    batch_infer(_src, _dst)

In [18]:
# Count samples that may be mislabeled
def statistics_maybe_wrong(input_path, output_path, n_folds=5):
    """Score how much the gold label and the per-fold predictions disagree.

    Sets ``maybe_wrong`` to the number of distinct values among ``class_id``
    and ``pred_{fold}_class_id`` for ``fold`` in ``range(n_folds)``. A value of
    1 means the label and every fold agree; larger values mean more
    disagreement. A missing value (NaN) counts as one distinct value.

    Args:
        input_path: Excel file holding ``class_id`` and the per-fold
            prediction columns (as written by batch_infer).
        output_path: Excel file to write; may be the same as ``input_path``.
        n_folds: number of ``pred_{fold}_class_id`` columns to compare
            (default 5).
    """
    predict_data = pd.read_excel(input_path)
    cols = ['class_id'] + [f'pred_{fold}_class_id' for fold in range(n_folds)]
    predict_data['maybe_wrong'] = predict_data[cols].nunique(axis=1, dropna=False)
    # index=False so repeated read/write cycles do not pile up 'Unnamed: N' columns.
    predict_data.to_excel(output_path, index=False)

# Annotate each prediction file in place (input and output are the same path).
for _path in (train_data_save_path, test_data_save_path):
    statistics_maybe_wrong(_path, _path)

In [None]:
def statistics(input_path):
    """Print a banner followed by the frequency of each ``maybe_wrong`` value.

    Args:
        input_path: Excel file that contains a ``maybe_wrong`` column.
    """
    frame = pd.read_excel(input_path)
    banner = '预测不一致的'.center(60, '-')
    print(banner)
    counts = frame['maybe_wrong'].value_counts()
    print(counts)

# Report disagreement counts for the train split, then the test split.
for _path in (train_data_save_path, test_data_save_path):
    statistics(_path)