In [138]:
import os
import xml.etree.ElementTree as ET
import copy
import math
import random

# 读取标准xml

In [139]:
# Open the template annotation explicitly as UTF-8 text and parse it into an ElementTree.
with open("base.xml", mode="r", encoding="utf-8") as f:
    tree = ET.parse(f)
# Bare last expression so the notebook displays the parsed tree object.
tree

<xml.etree.ElementTree.ElementTree at 0x1250918de70>

## 获取root

In [140]:
# The root element is the starting point for every lookup in the cells below.
root = tree.getroot()
root

<Element 'annotation' at 0x000001250A5E93A0>

## 查看root属性

In [141]:
# Tag name of the root element.
root.tag

'annotation'

In [142]:
# Dictionary of XML attributes set on the root element itself (not its child elements).
root.attrib

{}

## 查看root的子节点

In [143]:
# find() returns the first direct child with the given tag; .text is that child's text content.
root.find("filename").text

'000001.jpg'

## 查看多级子节点

In [144]:
# Nested elements are reached by chaining find(): look up <size> once,
# then print each of its dimension children in turn.
size = root.find("size")
for tag in ("width", "height", "depth"):
    print(size.find(tag).text)

353
500
3


## 通过 `root.findall()`（返回直接子节点的列表）或者 `root.iter()`（递归遍历所有后代节点的迭代器）遍历子节点

In [145]:
def find_childrens(root):
    """Print the name and bounding box of every direct ``object`` child of ``root``.

    One line is printed per object: the text of its ``name`` element followed by
    the text of ``xmin``, ``ymin``, ``xmax`` and ``ymax`` from its ``bndbox``
    element, separated by spaces.

    Args:
        root: Element whose direct ``object`` children are listed.

    Returns:
        None. The result is written to standard output only.
    """
    corner_tags = ("xmin", "ymin", "xmax", "ymax")
    # findall() only matches direct children of root; root.iter("object") would
    # also descend into nested elements.
    for obj in root.findall("object"):
        bndbox = obj.find("bndbox")
        label = obj.find("name").text
        corners = [bndbox.find(tag).text for tag in corner_tags]
        print(label, *corners)

In [146]:
# Print every <object> currently attached to root.
find_childrens(root)

dog 48 240 195 371


## 保存子节点

In [151]:
# Keep an independent copy of the first <object> node to use as a template;
# deepcopy detaches it from root, so later edits to root do not affect it.
base_object = copy.deepcopy(root.find("object"))

print(base_object.find("name").text)
box = base_object.find("bndbox")
for tag in ("xmin", "ymin", "xmax", "ymax"):
    print(box.find(tag).text)

dog
48
240
195
371


## 删除子节点

In [152]:
# Note: when root is reused to add objects several times, remove the existing
# <object> children first; otherwise all previously added objects stay in the
# tree and are written out again.
# findall() returns a list, so removing from root while looping over it is safe.
for o in root.findall("object"):
    root.remove(o)

## 缩进xml

In [153]:
def indent(elem, level=0):
    """Pretty-print an element tree in place, one tab per nesting level.

    Whitespace-only (or missing) ``text`` and ``tail`` values are replaced with
    a newline followed by the appropriate number of tabs; text or tail that
    contains non-whitespace characters is left untouched.

    Reference: https://www.cnblogs.com/muffled/p/3462157.html

    Args:
        elem: Element to indent; it and all of its descendants are modified.
        level: Nesting depth of ``elem``. Callers normally leave this at 0;
            the recursion passes ``level + 1`` for children.

    Returns:
        None. The tree is modified in place.
    """
    newline = "\n" + "\t" * level

    # Leaf element: only its tail may need fixing, and the top-level call
    # (level 0) leaves a childless element unchanged.
    if not len(elem):
        if level and not (elem.tail and elem.tail.strip()):
            elem.tail = newline
        return

    # Text before the first child is indented one level deeper than elem.
    if not (elem.text and elem.text.strip()):
        elem.text = newline + "\t"
    if not (elem.tail and elem.tail.strip()):
        elem.tail = newline

    child = None
    for child in elem:
        indent(child, level + 1)

    # After the loop `child` is the last child. Its tail sits right before
    # elem's closing tag, so it is pulled back to elem's own indentation.
    if not (child.tail and child.tail.strip()):
        child.tail = newline

## 插入多个相同结构的子节点，注意每次都要使用deepcopy，不然多次插入的其实是同一个对象，所有节点的内容都会等于最后一次修改的值

In [154]:
for i in range(3):
    # Start every output file from an annotation without objects; otherwise
    # the objects appended in earlier iterations would be written again.
    for o in root.findall("object"):
        root.remove(o)

    for j in range(2):
        # Important: each appended node must be its own deep copy. Appending
        # the same element repeatedly would make every entry show the values
        # assigned last.
        temp_object = copy.deepcopy(base_object)
        temp_object.find("name").text = f"stu_{i}_{j}"

        # Deterministic part of each coordinate; a random term in [0, 10) is
        # added below. These are arbitrary demo numbers, not real boxes.
        bndbox = temp_object.find("bndbox")
        offsets = {
            "xmin": i + math.pi + j + math.pi,
            "ymin": i - math.pi + j - math.pi,
            "xmax": i * math.pi + j * math.pi,
            "ymax": i / math.pi + j / math.pi,
        }
        for tag, offset in offsets.items():
            bndbox.find(tag).text = str(offset + random.random() * 10)

        root.append(temp_object)

    find_childrens(root)
    indent(root)

    # The template was opened as UTF-8, so the output is written as UTF-8 too.
    new_tree = ET.ElementTree(root)
    new_tree.write(f"test_{i}.xml", encoding="utf-8")

stu_0_0 14.860877940180245 -3.7586029538221966 7.6231615404153334 8.00178765220383
stu_0_1 16.02661609888373 -2.8661137597596893 11.822194914544196 3.748405527621662
stu_1_0 14.26848849722007 -1.8457651451677943 6.608109819598996 6.649595938427941
stu_1_1 14.420674818930102 -2.2090090325481744 10.560970399808085 7.055940078607612
stu_2_0 10.859530520468994 -2.819998182621188 10.801228099068194 2.2750233133030493
stu_2_1 12.123431800682674 3.6021312169864963 10.688701881391488 4.695713928190694
