In [1]:
import fiftyone as fo
import os

# 创建数据集

In [2]:
root_dir = "../"
# 数据集的目录
dataset_dir_train = os.path.join(root_dir, "data/cchess_multi_label_layout", "train")
dataset_dir_val = os.path.join(root_dir, "data/cchess_multi_label_layout", "val")

In [3]:
# 遍历数据集
def get_all_files(dataset_dir):
    file_list = os.listdir(dataset_dir)

    sorted_file_list = sorted(file_list)
    
    target_files = []

    for file in sorted_file_list:
        if file.endswith(".jpg") or file.endswith(".png"):
            # ann
            ann_file = file.replace(".jpg", ".txt").replace(".png", ".txt")
            if not os.path.exists(os.path.join(dataset_dir, ann_file)):
                print(f"ann_file not exists: {ann_file}")
                continue
            target_files.append(os.path.join(dataset_dir, file))
            
    return target_files

In [4]:

all_files = get_all_files(dataset_dir_train)
all_files.extend(get_all_files(dataset_dir_val))
print(len(all_files))

20661


## 删除 fifityone, 重新创建

In [5]:
project_name = "chess_multi_label"

fo.delete_dataset(project_name)

DatasetNotFoundError: Dataset chess_multi_label not found

## 加载模型

In [6]:
import os
from onnx_classifier.full_classifier import FULL_CLASSIFIER_ONNX

full_classifier = FULL_CLASSIFIER_ONNX(
    model_path=os.path.join('..', "work_dirs/deploy_0307/cchess_reg.onnx"),
)

2025-03-07 22:29:20.752166476 [W:onnxruntime:, session_state.cc:1136 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2025-03-07 22:29:20.752178519 [W:onnxruntime:, session_state.cc:1138 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.


In [8]:
import torch


torch.cuda.is_available()

True

## 加载数据集

In [10]:
# all_files
import tqdm

all_samples = []

start_center_x = 50
start_center_y = 50

image_width = 450
image_height = 500

# 图片原尺寸 width = 450, height = 500

item_width = (image_width - 50) / 9
item_height = (image_height - 50) / 10


batch_size = 32  # 设置批处理大小

# 创建数据集
dataset = fo.Dataset("chess_multi_label")

# 遍历所有图片文件
for img_path in tqdm.tqdm(all_files):
    # 获取对应的标注文件路径
    ann_path = img_path.replace(".jpg", ".txt").replace(".png", ".txt")

    if not os.path.exists(ann_path):
        print(f"ann_path not exists: {ann_path}")
        continue
    
    annotation = ''
    with open(ann_path, 'r', encoding='utf-8') as f:
        annotation = f.read()

    _, short_labels, confidence_10x9, _ = full_classifier.pred(img_path)


    annotation = annotation.strip()
    # annotation_arr_10_9 为 10 行 9 列的二维数组
    # ['.C....r..', '....a....', '...kba...', 'p.p.p.pCp', '..b.....c', '.cP.n.P..', 'P.......P', 'BR..Ba...', '....KR...', '...N.a.N.']
    annotation_10_rows = [item for item in annotation.split("\n")]
    # 将 annotation_10_rows 转换成为 10 行 9 列的二维数组
    annotation_arr_10_9 = [list(item) for item in annotation_10_rows]


    detections = []


    for row_index, row in enumerate(annotation_arr_10_9):
        for col_index, col in enumerate(row):
            if col != '.':
                detections.append(fo.Detection(
                    label=col,
                    bounding_box=[
                        (start_center_x + col_index * item_width - item_width / 2) / image_width, 
                        (start_center_y + row_index * item_height - item_height / 2) / image_height, 
                        item_width / image_width, 
                        item_height / image_height
                    ]
                ))

    pred_detections = []
    for row_index, row in enumerate(short_labels):
        for col_index, col in enumerate(row):
            if col != '.':
                pred_detections.append(fo.Detection(
                    label=col,
                    confidence=confidence_10x9[row_index][col_index],
                    bounding_box=[
                        (start_center_x + col_index * item_width - item_width / 2) / image_width, 
                        (start_center_y + row_index * item_height - item_height / 2) / image_height, 
                        item_width / image_width, 
                        item_height / image_height
                    ]
                ))

    diff_detections = []
    padding = 10
    for row_index in range(10):
        for col_index in range(9):
            ann_label = annotation_arr_10_9[row_index][col_index]
            pred_label = short_labels[row_index][col_index]

            # 忽略 x 和 .
            if ann_label == 'x' and pred_label == '.':
                continue

            if ann_label == '.' and pred_label == 'x':
                continue

            # 不一致才展示
            if ann_label != pred_label:
                diff_detections.append(fo.Detection(
                    label=f"{ann_label} -> {pred_label}",
                    bounding_box=[
                        (start_center_x + col_index * item_width - item_width / 2 + padding) / image_width, 
                        (start_center_y + row_index * item_height - item_height / 2 + padding) / image_height, 
                        (item_width - padding * 2) / image_width, 
                        (item_height - padding * 2) / image_height
                    ]
                ))

    base_name = os.path.basename(img_path)

    tag="js_v2"

    # 判断 base_name 是否以 js_v2_ 开头
    if not base_name.startswith("js_v2_"):

        if "_" in base_name:
            base_name_arr = base_name.split("_")
        else:
            base_name_arr = base_name.split("-")
        
        tag = base_name_arr[0]
    
    sample = fo.Sample(filepath=img_path, tags=[tag])
        # 添加多标签分类信息
    sample['ground_truth'] = fo.Detections(
        detections=detections
    )
    sample['predictions'] = fo.Detections(
        detections=pred_detections
    )
    sample['diff'] = fo.Detections(
        detections=diff_detections
    )
    # all_samples.append(sample)
    dataset.add_sample(sample)



# dataset.default_config = 
#     draw_config=draw_config
# )

# 保存数据集
dataset.save()


KeyboardInterrupt: 

## 创建各种 view

In [8]:
# 查找 文件以 'js' 开头
from fiftyone import ViewField as F

# 修正: 含有 js_v2_ 的 文件, 且存在 diff 标签
js_diff_view = dataset.match({
    "filepath": {"$regex": ".*js_v2_.*"},
})

js_diff_view = js_diff_view.match({
    "diff.detections": {"$not": {"$size": 0}}  # 检测结果数组非空
})
# 保存 view
js_diff_view.save("js_diff_view")

In [None]:
len(js_diff_view)

1

In [10]:
all_diff_view = dataset.match({
    "diff.detections": {"$not": {"$size": 0}}  # 检测结果数组非空
})
# 保存 view
all_diff_view.save("all_diff_view")
dataset.save_view("all_diff_view", all_diff_view)

In [None]:
len(all_diff_view)

1218