In [26]:
import os
import xml.etree.ElementTree as ET
import copy
import math
import random

# 读取标准xml

In [27]:
# Parse the template annotation; the `with` block closes the file once parsing is done.
template_path = "base.xml"
with open(template_path, encoding="utf-8") as template_file:
    tree = ET.parse(template_file)
tree

<xml.etree.ElementTree.ElementTree at 0x2cbc13efd70>

## 获取root

In [28]:
# Document element of the parsed template; the rest of the notebook reads and
# edits the XML through `root`.
root = tree.getroot()
root

<Element 'annotation' at 0x000002CBC1593740>

## 查看root属性

In [29]:
# Tag name of the document element.
root_tag = root.tag
root_tag

'annotation'

In [30]:
# XML attributes declared on the document element itself (a dict).
root_attributes = root.attrib
root_attributes

{}

## 查看root的子属性

In [31]:
# Text of the direct <filename> child.
filename_elem = root.find("filename")
filename_elem.text

'000001.jpg'

## 查看多级子节点并修改其文本内容

In [32]:
# Print the text of each <size> sub-element, one per line.
size_elem = root.find("size")
for field in ("width", "height", "depth"):
    print(size_elem.find(field).text)

353
500
3


In [None]:
# Element text can be reassigned directly (these are element texts, not XML attributes).
# This mutates `root` in place, so every later cell, and the files written at the end, see the new size.
new_size = {"width": "500", "height": "300", "depth": "4"}
size_elem = root.find("size")
for field, value in new_size.items():
    size_elem.find(field).text = value

In [34]:
# Re-read the same <size> fields to confirm the edits took effect.
size_elem = root.find("size")
for field in ("width", "height", "depth"):
    print(size_elem.find(field).text)

500
300
4


## 通过 `root.findall()`（返回直接子节点的列表）或 `root.iter()`（递归遍历所有后代节点）查找节点

In [35]:
def find_childrens(root):
    """Print the name and bounding box of every direct ``<object>`` child of ``root``.

    Each printed line is ``name xmin ymin xmax ymax`` separated by spaces.
    A missing sub-element is printed as ``None`` instead of raising.

    Note: ``findall("object")`` matches direct children only;
    ``root.iter("object")`` would also match nested descendants.

    Args:
        root: Element whose direct ``<object>`` children are listed
            (here the ``<annotation>`` element).
    """
    for obj in root.findall("object"):
        # findtext with a path returns None for a missing element instead of
        # failing with AttributeError on a chained None.find(...).
        print(
            obj.findtext("name"),
            obj.findtext("bndbox/xmin"),
            obj.findtext("bndbox/ymin"),
            obj.findtext("bndbox/xmax"),
            obj.findtext("bndbox/ymax"),
        )

In [36]:
# List the <object> entries currently attached to the document element.
find_childrens(root=root)

dog 48 240 195 371


## 保存子节点

In [37]:
# Keep a detached deep copy of the first <object> as a template. The next cell
# removes the originals from `root`, so this copy must not share nodes with it.
base_object = copy.deepcopy(root.find("object"))

template_box = base_object.find("bndbox")
print(base_object.find("name").text)
for coord in ("xmin", "ymin", "xmax", "ymax"):
    print(template_box.find(coord).text)

dog
48
240
195
371


## 删除子节点

In [38]:
# Drop every direct <object> child of root. If root is reused to build several
# files, do this before each round; otherwise each file keeps all earlier objects.
root[:] = [child for child in root if child.tag != "object"]

## 转换为字符串

In [None]:
# Serialize the whole tree to text for inspection.
# - encoding="unicode" makes ET.tostring return a str directly (the default returns bytes).
# - Neither "unicode" nor "utf-8" emits an XML declaration (<?xml ...?>) by default;
#   pass xml_declaration=True (Python 3.8+) when one is needed.
# - The result gets its own name: assigning to `str` would shadow the builtin, and any
#   later str(...) call in this kernel would then raise TypeError.
xml_str = ET.tostring(root, encoding="unicode")
print(xml_str)

<annotation>
	<folder>VOC2007</folder>
	<filename>000001.jpg</filename>
	<source>
		<database>The VOC2007 Database</database>
		<annotation>PASCAL VOC2007</annotation>
		<image>flickr</image>
		<flickrid>341012865</flickrid>
	</source>
	<owner>
		<flickrid>Fried Camels</flickrid>
		<name>Jinky the Fruit Bat</name>
	</owner>
	<size>
		<width>500</width>
		<height>300</height>
		<depth>4</depth>
	</size>
	<segmented>0</segmented>
	<object>
		<name>stu_2_0</name>
		<pose>Left</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>14.61810348997091</xmin>
			<ymin>-0.9376723412771066</ymin>
			<xmax>12.993044280443474</xmax>
			<ymax>1.0092845816373774</ymax>
		</bndbox>
	</object>
	<object>
		<name>stu_2_1</name>
		<pose>Left</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>14.643652566934449</xmin>
			<ymin>4.026498719161514</ymin>
			<xmax>12.028360312985988</xmax>
			<ymax>3.281341393947442</ymax>
		</bndbox>
	</object>
</annotation>

## 缩进xml

In [39]:
def indent(elem, level=0, space="\t"):
    """Pretty-print an element tree in place by rewriting its whitespace.

    Only ``text``/``tail`` values that are missing or whitespace-only are
    replaced, so real text content is never modified. Adapted from
    https://www.cnblogs.com/muffled/p/3462157.html
    (Python 3.9+ also ships ``xml.etree.ElementTree.indent``.)

    Args:
        elem: Element to indent; modified in place, nothing is returned.
        level: Nesting depth of ``elem``; callers normally leave it at 0.
        space: Indentation unit for one nesting level (default: one tab).
    """
    newline_indent = "\n" + level * space
    if len(elem):
        if not elem.text or not elem.text.strip():
            # The first child starts on a new line, one level deeper.
            elem.text = newline_indent + space
        if not elem.tail or not elem.tail.strip():
            elem.tail = newline_indent
        for child in elem:
            indent(child, level + 1, space)
        # The recursion left the last child's tail at the child level. Pull it
        # back to this level so the closing tag lines up with the opening tag.
        # (len(elem) > 0, so the loop ran and `child` is bound.)
        if not child.tail or not child.tail.strip():
            child.tail = newline_indent
    else:
        if level and (not elem.tail or not elem.tail.strip()):
            elem.tail = newline_indent

## 插入多个相同结构的子节点：每次都要用 deepcopy 生成新节点，否则插入的是同一个对象，所有节点都会显示最后一次修改的值

In [40]:
# Build three demo files (test_0.xml .. test_2.xml), each holding two <object>
# entries cloned from `base_object`.
SEED = 0  # fixed seed so re-running the notebook regenerates identical files
rng = random.Random(SEED)  # private generator; leaves the global `random` state alone

for i in range(3):
    # `root` is reused for every file, so drop the <object> children left over from
    # the previous round; otherwise each file would also contain all earlier objects.
    for old_object in root.findall("object"):
        root.remove(old_object)

    for j in range(2):
        # Append a fresh deep copy each time. Appending one shared element repeatedly
        # would make every <object> the same node, all showing the last values written.
        temp_object = copy.deepcopy(base_object)
        bndbox = temp_object.find("bndbox")

        # f-strings format these ints/floats exactly like str(), but do not depend on
        # the builtin name `str`, which a notebook cell can accidentally rebind.
        temp_object.find("name").text = f"stu_{i}_{j}"
        # Placeholder coordinates for the demo only. Nothing guarantees a valid box
        # (xmin can exceed xmax, and values can be negative).
        bndbox.find("xmin").text = f"{i + math.pi + j + math.pi + rng.random() * 10}"
        bndbox.find("ymin").text = f"{i - math.pi + j - math.pi + rng.random() * 10}"
        bndbox.find("xmax").text = f"{i * math.pi + j * math.pi + rng.random() * 10}"
        bndbox.find("ymax").text = f"{i / math.pi + j / math.pi + rng.random() * 10}"
        root.append(temp_object)

    find_childrens(root)

    indent(root)

    # ElementTree.write needs a tree wrapper; write UTF-8 to match how base.xml was read.
    new_tree = ET.ElementTree(root)
    new_tree.write(f"test_{i}.xml", encoding="utf-8")

stu_0_0 12.523331826366032 -4.852619161478607 0.3188825146220997 5.801527557088595
stu_0_1 8.573508783635509 -0.6792989870273338 3.9335275517156596 1.1978187978921973
stu_1_0 13.692716292585596 1.6597139528886595 5.209821278807758 7.083936649039451
stu_1_1 13.170826909722317 2.100816446288561 13.264644727984933 7.930916167329381
stu_2_0 14.61810348997091 -0.9376723412771066 12.993044280443474 1.0092845816373774
stu_2_1 14.643652566934449 4.026498719161514 12.028360312985988 3.281341393947442
