# 准备数据

## 获取标签

In [6]:
import xml.dom.minidom as xmldom
from glob import glob
from tqdm import tqdm

PATH = "./data/VOC2012/Annotations"
all_xml = glob(PATH + "/*.xml")
print(len(all_xml))
label_names = []
for xml in tqdm(all_xml):
    domobj = xmldom.parse(xml)
    elementobj = domobj.documentElement
    subElementObj = elementobj.getElementsByTagName("object")
    for i in range(len(subElementObj)):
        name = subElementObj[i].getElementsByTagName("name")[0].firstChild.data
        if name not in label_names:
            label_names.append(name)
print(label_names)
print(len(label_names))
with open("./data/VOC2012/labels.txt", "w") as f:
    for label in label_names:
        f.write(label + "\n")

17125


100%|██████████| 17125/17125 [00:08<00:00, 2094.65it/s]

['person', 'aeroplane', 'tvmonitor', 'train', 'boat', 'dog', 'chair', 'bird', 'bicycle', 'bottle', 'sheep', 'diningtable', 'horse', 'motorbike', 'sofa', 'cow', 'car', 'cat', 'bus', 'pottedplant']
20





## 计算均值和方差

In [6]:
from src.datasets import VOCDataset
import numpy as np
from tqdm import tqdm

# 获取训练集
train_set = VOCDataset("./data/VOC2012", split="train")

# 计算均值和方差
mean = 0
std = 0
for _, _, image in tqdm(train_set):
    image = np.array(image)
    mean += image.mean(axis=(0, 1))
    std += image.std(axis=(0, 1))
mean /= len(train_set)
std /= len(train_set)
display(mean, std)

100%|██████████| 5717/5717 [01:45<00:00, 54.33it/s]


array([116.54703538, 111.75323747, 103.57417823])

array([60.96688134, 59.94961114, 61.13129536])

## 生成数据列表

In [16]:
import os
import glob
import random

TRAINVAL_PATH = "./data/VOC2012"
sample_list = glob.glob(f"{TRAINVAL_PATH}/Annotations/*.xml")
n_samples = len(sample_list)
print(f"get {n_samples} samples")
n_train_samples = int(n_samples * 0.8)
n_trainval_samples = int(n_samples * 0.9)
random.shuffle(sample_list)
train_samples, val_samples, test_samples = sample_list[:n_train_samples], sample_list[n_train_samples:n_trainval_samples], sample_list[n_trainval_samples:]
train_sample_names = [os.path.basename(sample).split(".")[0] for sample in train_samples]
val_sample_names = [os.path.basename(sample).split(".")[0] for sample in val_samples]
test_sample_names = [os.path.basename(sample).split(".")[0] for sample in test_samples]
with open(f"{TRAINVAL_PATH}/train.txt", "w") as f:
    for name in train_sample_names:
        f.write(name + "\n")
with open(f"{TRAINVAL_PATH}/val.txt", "w") as f:
    for name in val_sample_names:
        f.write(name + "\n")
with open(f"{TRAINVAL_PATH}/test.txt", "w") as f:
    for name in test_sample_names:
        f.write(name + "\n")

get 17125 samples


## 抽样

从训练集、验证集、测试集中分别抽取 30、10、10 个样本，存放到 `./data/VOC2012-sample` 目录下。

In [17]:
from random import sample
from tqdm import tqdm

# 保存路径
ORIGIN_PATH = r".\data\VOC2012"
OUTPUT_PATH = r".\data\VOC2012-sample"
image_path = os.path.join(OUTPUT_PATH, "JPEGImages")
annot_path = os.path.join(OUTPUT_PATH, "Annotations")
os.makedirs(image_path, exist_ok=True)
os.makedirs(annot_path, exist_ok=True)

# 随机抽样
train_sample_ids = sample(train_sample_names, 30)
val_sample_ids = sample(val_sample_names, 10)
test_sample_ids = sample(test_sample_names, 10)
all_sample_ids = train_sample_ids + val_sample_ids + test_sample_ids

# 保存数据集列表
with open(os.path.join(OUTPUT_PATH, "train.txt"), "w") as f:
    train_sample_ids = [str(i) for i in train_sample_ids]
    f.write("\n".join(train_sample_ids))
with open(os.path.join(OUTPUT_PATH, "val.txt"), "w") as f:
    val_sample_ids = [str(i) for i in val_sample_ids]
    f.write("\n".join(val_sample_ids))
with open(os.path.join(OUTPUT_PATH, "test.txt"), "w") as f:
    test_sample_ids = [str(i) for i in test_sample_ids]
    f.write("\n".join(test_sample_ids))

# 保存图像和标注
for id in tqdm(all_sample_ids):
    jpg_origin_path = os.path.join(ORIGIN_PATH, "JPEGImages", f"{id}.jpg")
    jpg_output_path = os.path.join(OUTPUT_PATH, "JPEGImages", f"{id}.jpg")    
    xml_origin_path = os.path.join(ORIGIN_PATH, "Annotations", f"{id}.xml")
    xml_output_path = os.path.join(OUTPUT_PATH, "Annotations", f"{id}.xml")
    os.system(f"copy {jpg_origin_path} {jpg_output_path}")
    os.system(f"copy {xml_origin_path} {xml_output_path}")

100%|██████████| 50/50 [00:01<00:00, 34.21it/s]
