# 使用模型进行"测试集-薄片"的预测与评价，即无标签数据的预测
## Use the trained model to predict on and evaluate the "test set – thin sections", i.e. to predict unlabeled data

# 薄片标签较少，因此"测试集-薄片"采用手动划分，防止出现某些类别没有出现在测试中
## Thin-section labels are scarce, so the "test set – thin sections" split is made by hand to make sure every class appears in the test set


In [1]:
from Mineral_segmentation.LocalModelPredictor import LocalModelPredictor
from scrips.Configs import Config
import json

# Area limit passed to extract_instances_unlabels (presumably a minimum
# instance area — confirm against that function).
area_limit = 150
# Transparency threshold passed to extract_instances_unlabels.
transparency_threshold = 0.15

# Class names and colour palette for the rock dataset, from the project config.
classes = Config.Classes_rocks

palette = Config.Palette_rocks

# Map each class name to its position in `classes`.
label_to_value = {name: idx for idx, name in enumerate(classes)}

# Read the class-name -> label-value mapping from JSON. Its insertion order is
# what map_prediction_to_real_class uses to turn a model index into a label.
# Read explicitly as UTF-8 so the result does not depend on the OS locale
# (e.g. GBK on a Chinese Windows install).
mapping_path = 'sorted_classname_mapping.json'
with open(mapping_path, 'r', encoding='utf-8') as json_file:
    sorted_classname_mapping = json.load(json_file)

# Report the file that was actually read (the old message named "data.json").
print(f"从 {mapping_path} 文件读取的字典：", sorted_classname_mapping)


从 data.json 文件读取的字典： {'Muscovite': 8, 'Quartz': 5, 'um': 12}


In [2]:
import os

# Directory of test-set images to predict on.
input_dir_unlabel = '../data/testdata'

# Ground-truth label masks for the test set, used to evaluate the predictions
# (evaluate if present, skip otherwise).
true_masks_folder = '../data/testdata/labels'

# Backbone architecture and the checkpoint file holding the trained weights.
model_name = 'resnet18'
# Built with os.path.join instead of a hard-coded Windows raw string
# (r'models_result\resnet18\...') so the path resolves on every OS;
# on Windows the resulting string is identical to the old one.
model_checkpoint_path = os.path.join(
    'models_result', 'resnet18', 'resnet18_20250108_163634.pth'
)


In [3]:
from Mineral_segmentation.Load_data import load_images_masks_with_measuring_scale

print("Load unlabeled data...")
# Accumulators that collect the data from every batch the loader yields.
images = []
masks = []
image_names = []
measuring_scales = []

# Each batch is (images, masks, image names, measuring scales).
batches = load_images_masks_with_measuring_scale(input_dir_unlabel, batch_size=32)
for batch_images, batch_masks, batch_names, batch_scales in batches:
    images.extend(batch_images)
    masks.extend(batch_masks)
    image_names.extend(batch_names)
    measuring_scales.extend(batch_scales)

# images, masks, image_names and measuring_scales now hold every batch's data.
print(f"Total images loaded: {len(images)}")
print(f"Total masks loaded: {len(masks)}")



Load unlabeled data...


加载图像和掩膜数据: 100%|██████████| 1/1 [00:00<00:00, 30.41it/s]

Total images loaded: 1
Total masks loaded: 1





In [4]:

from Mineral_segmentation.Extract_instances import extract_instances_unlabels

print("Particle extraction...")

# Split the images/masks into particle instances, using the area_limit and
# transparency_threshold settings defined earlier in the notebook.
instances_without_label = extract_instances_unlabels(
    images,
    masks,
    measuring_scales=measuring_scales,
    area_limit=area_limit,
    image_names=image_names,
    transparency_threshold=transparency_threshold,
)

# Sanity check: show the source image of the first extracted instance.
# Guarded because indexing [0] raises IndexError when nothing was extracted.
if instances_without_label.instances:
    print(instances_without_label.instances[0].original_image_name)
else:
    print("No instances were extracted.")
   

Particle extraction...


Extracting instances: 100%|██████████| 1/1 [00:03<00:00,  3.17s/it]

A total of  267 instances are extracted, covering the total image area of  91.90%.
1-28-1-6.jpg





In [5]:
def map_prediction_to_real_class(sorted_classname_mapping, predicted_class):
    """Translate a model output index into the dataset's label value.

    Assumes the model's class indices follow the insertion order of
    ``sorted_classname_mapping`` (class name -> label value), so the
    ``predicted_class``-th entry holds the label value to return.

    Args:
        sorted_classname_mapping: Mapping of class name to label value, in the
            order the model's output indices were assigned.
        predicted_class: Integer index produced by the model.

    Returns:
        The label value stored at position ``predicted_class``.

    Raises:
        IndexError: If ``predicted_class`` is outside ``[0, len(mapping))``.
            Negative indices are rejected explicitly because plain list
            indexing would silently wrap them to the end of the mapping.
    """
    # Only the values are needed; the class names are not used here.
    real_classes = list(sorted_classname_mapping.values())
    if not 0 <= predicted_class < len(real_classes):
        raise IndexError(
            f"predicted_class {predicted_class} is out of range for "
            f"{len(real_classes)} classes"
        )
    return real_classes[predicted_class]


In [6]:
# Load the trained model as a predictor. Num_classes must equal the number of
# classes the checkpoint was trained with, i.e. the size of the saved
# name -> label mapping (not len(classes)).
num_trained_classes = len(sorted_classname_mapping)
predictor = LocalModelPredictor(
    model_path=model_checkpoint_path,
    model_name=model_name,
    Num_classes=num_trained_classes,
)

  self.model.load_state_dict(torch.load(model_path, map_location=self.device))


In [7]:
# from Mineral_segmentation.Map_prediction_to_real_class import map_prediction_to_real_class
from tqdm import tqdm

# Minimum `prob` (as returned by predictor.predict) needed to accept the
# model's class; below it the instance falls back to label value 0.
Threshold = 0

# Classify every extracted instance and store the resulting label value on it.
# Iterating the list directly lets tqdm infer the total from its length.
for instance in tqdm(instances_without_label.instances, desc="Predict"):
    # Each instance is classified from its saved image file.
    path = instance.image_path
    predicted_class, prob = predictor.predict(path)
    if prob >= Threshold:
        # Translate the model's output index into the dataset's label value.
        real_class = map_prediction_to_real_class(sorted_classname_mapping, predicted_class)
    else:
        real_class = 0
    # cluster_id currently holds the numeric label value; a class <-> id
    # mapping is meant to replace it later.
    instance.cluster_id = int(real_class)



Predict: 100%|██████████| 267/267 [00:01<00:00, 146.16it/s]


In [8]:
from Mineral_segmentation.Recolor_and_Remask import remask_instance_masks

# Directory where the updated instance mask images are saved.
output_masked_unlabels_instances = "./results/masked_unlabels_instances"

# Apply each instance's assigned class to its mask, save the new mask images,
# and keep the returned (updated) instance collection.
instances_without_label = remask_instance_masks(
    instances_without_label,
    outpath=output_masked_unlabels_instances,
)


Remasking instances: 100%|██████████| 9/9 [00:00<00:00, 63.55it/s]


In [9]:

import os
from Mineral_segmentation.Reassemble import reassemble_image_masks

# Reassembled masks go into a folder inside predictor.model_dir named
# "<model dir name>_ressembled_unlabels_masks".
model_dir_name = os.path.basename(predictor.model_dir)
out_ressembled_unlabels_masks = os.path.join(
    predictor.model_dir, model_dir_name + '_ressembled_unlabels_masks'
)

# Combine the per-instance masks back into per-image masks; the return value
# is presumably the list of written mask paths (see reassemble_image_masks).
reassembled_image_paths = reassemble_image_masks(
    instances_without_label, images, out_ressembled_unlabels_masks
)
                              
# display_random_images(reassembled_image_paths, num_images=3)                                     

Reassembling masks: 100%|██████████| 267/267 [00:00<00:00, 26768.95it/s]


In [10]:
from Dataset.fill_zero import process_masks

#填补空值 并重新上色
# Fill empty values in the reassembled masks and recolour them with `palette`,
# writing into "<model dir name>_ressembled_unlabels_images" under model_dir.
images_dir_suffix = '_ressembled_unlabels_images'
out_ressembled_unlabels_images = os.path.join(
    predictor.model_dir,
    os.path.basename(predictor.model_dir) + images_dir_suffix,
)
process_masks(out_ressembled_unlabels_masks, out_ressembled_unlabels_images, palette)
print(out_ressembled_unlabels_images)

paint : 100%|██████████| 1/1 [00:00<00:00,  6.43it/s]

models_result\resnet18\resnet18_ressembled_unlabels_images





# 对预测结果进行评价（若无标签数据，则不进行评价）
# Evaluate the prediction results (skipped when no label data is available)

In [11]:
import pandas as pd

from Mineral_segmentation.Evaluate_pix_floder import evaluate_pix

pred_masks_folder = out_ressembled_unlabels_masks
# Output directory for the pixel-level evaluation results.
output_dir_pix = os.path.join(predictor.model_dir, 'evaluate_pix')

# Evaluate only when ground-truth labels exist, as stated where
# true_masks_folder is defined; otherwise skip instead of failing.
if os.path.isdir(true_masks_folder):
    overall_accuracy, report_df, conf_matrix = evaluate_pix(
        true_masks_folder, pred_masks_folder, sorted_classname_mapping, output_dir_pix
    )

    # Label the confusion matrix rows/columns with the class names, in the
    # mapping's insertion order.
    sorted_classname = list(sorted_classname_mapping.keys())
    conf_matrix_df = pd.DataFrame(conf_matrix, index=sorted_classname, columns=sorted_classname)
    # Ensure the directory exists and join with os.path.join so the separator
    # matches the rest of the path on every OS (was `+ '/conf_matrix.txt'`).
    os.makedirs(output_dir_pix, exist_ok=True)
    conf_matrix_df.to_csv(os.path.join(output_dir_pix, 'conf_matrix.txt'), sep="\t")
else:
    print(f"No label folder found at {true_masks_folder}; skipping evaluation.")

Found 1 matching files for evaluation.


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'precision': 0.9817852916314455, 'recall': 0.9294813415815041, 'f1-score': 0.9549176385468866, 'support': 1249570.0}
{'Muscovite': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0.0}, 'Quartz': {'precision': 0.9817852916314455, 'recall': 0.9294813415815041, 'f1-score': 0.9549176385468867, 'support': 1249570.0}, 'um': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0.0}, 'micro avg': {'precision': 0.9757452388973413, 'recall': 0.9294813415815041, 'f1-score': 0.9520515858687247, 'support': 1249570.0}, 'macro avg': {'precision': 0.3272617638771485, 'recall': 0.3098271138605014, 'f1-score': 0.3183058795156289, 'support': 1249570.0}, 'weighted avg': {'precision': 0.9817852916314455, 'recall': 0.9294813415815041, 'f1-score': 0.9549176385468866, 'support': 1249570.0}}
Overall Accuracy: 0.9818
Classification report saved to models_result\resnet18\evaluate_pix\classification_report_pix.csv
Confusion matrix saved to models_result\resnet18\evaluate_pix\confusion_matr