In [2]:
import os
from collections import defaultdict

# Directories holding the label files of each dataset split.
train_labels_dir = "./datasets/traffic/labels/train"
val_labels_dir = "./datasets/traffic/labels/val"

# Class-id -> class-name mapping (keep in sync with your dataset yaml).
# The names are listed in id order: the position of a name in the list
# below is its class id, so never reorder or insert in the middle.
names = dict(enumerate([
    # ids 0-9
    "5kilometer",
    "15kilometer",
    "30kilometer",
    "40kilometer",
    "50kilometer",
    "60kilometer",
    "70kilometer",
    "90kilometer",
    "No Left Turn or Straight Ahead",
    "No Right Turn or Straight Ahead",
    # ids 10-19
    "No Straight Ahead",
    "No Left Turn",
    "No Left or Right Turn",
    "No Right Turn",
    "No Overtaking",
    "No U-turn",
    "No Entry for Motor Vehicles",
    "No Horn",
    "End Speed Limit 40",
    "End Speed Limit 50",
    # ids 20-29
    "Turn Right or Go Straight Ahead",
    "Ahead Only",
    "Left Turn Only",
    "Left or Right Turn Only",
    "Right Turn Only",
    "Keep Left",
    "Keep Right",
    "Roundabout",
    "Motor Vehicles Only",
    "Sound Horn",
    # ids 30-39
    "Bicycles Only",
    "U-turn Only",
    "Divided Road Ahead",
    "Traffic Signals Ahead",
    "General Warning",
    "Pedestrian Crossing Ahead",
    "Cyclists Ahead",
    "Children Crossing Ahead",
    "Right Curve Ahead",
    "Left Curve Ahead",
    # ids 40-49
    "Steep Descent",
    "Steep Ascent",
    "SLOW",
    "Side Road Junction Ahead",
    "Side Road Junction (left) Ahead",
    "Built-up Area Warning",
    "Winding Road Ahead",
    "train ahead",
    "Road Works Ahead",
    "Continuous sharp turn sign",
    # ids 50-57
    "Railway level crossing",
    "Rear End Collision",
    "STOP",
    "No Entry for Vehicles",
    "No Stopping",
    "No Entry",
    "Give Way",
    "Stop - Police",
]))

def count_classes(labels_dir):
    """Count, per class id, how many label files mention that class.

    Every entry of ``labels_dir`` whose name ends in ``.txt`` is read line
    by line. The first whitespace-separated token of each non-blank line
    is parsed as an integer class id. A file contributes at most once to a
    given class, however many of its lines carry that id, so the result
    counts files (images) rather than individual annotations.

    Args:
        labels_dir: Path of the directory containing the label files.

    Returns:
        A dict mapping each class id seen to the number of distinct label
        files that contain it. Classes that never occur are absent.

    Raises:
        OSError: If the directory or one of its label files cannot be read.
        ValueError: If the first token of a non-blank line is not an integer.
    """
    files_by_class = defaultdict(set)
    label_names = (
        entry for entry in os.listdir(labels_dir) if entry.endswith(".txt")
    )
    for label_name in label_names:
        with open(os.path.join(labels_dir, label_name), "r") as handle:
            for raw_line in handle:
                tokens = raw_line.split()
                if not tokens:
                    # Blank or whitespace-only line: nothing to record.
                    continue
                # Storing the file name in a set keeps repeated ids within
                # one file from being counted more than once.
                files_by_class[int(tokens[0])].add(label_name)
    return {
        class_id: len(file_names)
        for class_id, file_names in files_by_class.items()
    }

# Count, for each split, how many label files mention each class.
train_counts = count_classes(train_labels_dir)
val_counts = count_classes(val_labels_dir)

# Print one report per split. The headings are kept verbatim, including the
# leading newline that separates the second report from the first.
for heading, split_counts in (
    ("Train dataset class counts:", train_counts),
    ("\nVal dataset class counts:", val_counts),
):
    print(heading)
    for cls_id, count in sorted(split_counts.items()):
        # A label file may carry an id that `names` does not list; show a
        # placeholder rather than aborting the report midway with KeyError.
        label = names.get(cls_id, "<unknown>")
        print(f"{cls_id:2d} {label:35s}: {count}")


Train dataset class counts:
 0 5kilometer                         : 118
 1 15kilometer                        : 40
 2 30kilometer                        : 80
 3 40kilometer                        : 260
 4 50kilometer                        : 98
 5 60kilometer                        : 194
 6 70kilometer                        : 78
 7 90kilometer                        : 152
 8 No Left Turn or Straight Ahead     : 8
 9 No Right Turn or Straight Ahead    : 2
10 No Straight Ahead                  : 70
11 No Left Turn                       : 138
12 No Left or Right Turn              : 96
13 No Right Turn                      : 36
14 No Overtaking                      : 128
15 No U-turn                          : 22
16 No Entry for Motor Vehicles        : 142
17 No Horn                            : 130
18 End Speed Limit 40                 : 8
19 End Speed Limit 50                 : 4
20 Turn Right or Go Straight Ahead    : 18
21 Ahead Only                         : 12
22 Left Turn Only     