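Abstract data interfaces \
1. Image (img) \
2. Data sample (DataSample): all annotation and prediction information for one training or test sample \
3. Data element (xxxData): a single type of prediction or annotation, usually the output of a sub-module of the model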

Note: data samples and data elements are not mutually exclusive; a data sample is a higher-level wrapper around data elements

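Base class of both data samples and data elements: BaseDataElement \
It holds two kinds of content, data and metainfo. data fields can be assigned directly as key=value arguments, while metainfo must be passed explicitly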

In [1]:
# Create instances of the data element base class
from mmengine.structures import BaseDataElement 
import torch

img_id=0
H,W=640,640
bboxes=torch.randn((5,4))
scores=torch.randn((5,))

# Empty element
data_sample=BaseDataElement()
# Assign directly to data
data_sample1=BaseDataElement(bboxes=bboxes,scores=scores)
# Pass metainfo explicitly
data_sample2=BaseDataElement(
    bboxes=bboxes,
    scores=scores,
    metainfo=dict(img_id=img_id,img_shape=(H,W))
)



In [2]:
# The new and clone methods of BaseDataElement
"""
new() creates a BaseDataElement with the same data and metainfo as an existing one.
You can also pass new data or metainfo to new(), for example new(metainfo=xx). In that
case the new BaseDataElement keeps exactly the same data as the existing one, while its
metainfo is the newly supplied content.
clone() returns a deep copy; it behaves like clone() on a PyTorch Tensor.
"""
data_element = BaseDataElement(
    bboxes=torch.rand((5, 4)),
    scores=torch.rand((5,)),
    metainfo=dict(img_id=1, img_shape=(640, 640)))

# Pass metainfo or data when creating the new `BaseDataElement`; whichever is not passed is copied from the original
data_element1 = data_element.new(metainfo=dict(img_id=2, img_shape=(320, 320)))
print('bboxes is in data_element1:', 'bboxes' in data_element1) # True
print('bboxes in data_element1 is same as bbox in data_element', (data_element1.bboxes == data_element.bboxes).all())
print('img_id in data_element1 is', data_element1.img_id == 2) # True

data_element2 = data_element.new(label=torch.rand(5,))
print('bboxes is not in data_element2', 'bboxes' not in data_element2) # True
print('img_id in data_element2 is same as img_id in data_element', data_element2.img_id == data_element.img_id)
print('label in data_element2 is', 'label' in data_element2)

# `clone` also builds a new object, with the same data, metainfo, and state as the element it is called on.
data_element2 = data_element1.clone()

bboxes is in data_element1: True
bboxes in data_element1 is same as bbox in data_element tensor(True)
img_id in data_element1 is True
bboxes is not in data_element2 True
img_id in data_element2 is same as img_id in data_element True
label in data_element2 is True


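Adding and querying attributes \
1. Add a data attribute with dot notation, like a normal class attribute; add a metainfo attribute through the set_metainfo interface \
2. Query data attributes with keys, values, and items; query metainfo attributes with metainfo_keys, metainfo_values, and metainfo_items; all_keys, all_values, and all_items return every attribute regardless of kind \
3. Read an attribute with get(), as with a dictionary, or with dot notation (dot notation is also how data attributes are added) \
4. data and metainfo cannot share an attribute name \
5. BaseDataElement does not support reading or assigning attributes by name index (element['key'])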

In [2]:
from mmengine.structures import BaseDataElement
data_element = BaseDataElement()
# Set the metainfo fields of data_element with `set_metainfo`;
# img_id and img_shape then become attributes of data_element
data_element.set_metainfo(dict(img_id=9, img_shape=(100, 100)))
# Inspect the keys, values, and items of metainfo
print("metainfo'keys are", data_element.metainfo_keys())
print("metainfo'values are", data_element.metainfo_values())
for k, v in data_element.metainfo_items():
    print(f'{k}: {v}')

print("Access img_id and img_shape as attributes")
print('img_id:', data_element.img_id)
print('img_shape:', data_element.img_shape)

metainfo'keys are ['img_shape', 'img_id']
metainfo'values are [(100, 100), 9]
img_shape: (100, 100)
img_id: 9
Access img_id and img_shape as attributes
img_id: 9
img_shape: (100, 100)


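Modifying and deleting attributes \
Modify a data attribute the same way you add one: with dot notation \
Modify a metainfo attribute through the set_metainfo interface \
Delete attributes with del or pop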

In [2]:
from mmengine.structures import BaseDataElement
import torch
data_element=BaseDataElement(
    bboxes=torch.rand((6,4)),
    scores=torch.rand((6,)),
    metainfo=dict(img_id=9, img_shape=(640, 640))
)

for k,v in data_element.all_items():
    print(f"{k}:{v}")

img_id:9
img_shape:(640, 640)
scores:tensor([0.2028, 0.2644, 0.0398, 0.1335, 0.4261, 0.8572])
bboxes:tensor([[0.5987, 0.0852, 0.6665, 0.7433],
        [0.8427, 0.5409, 0.3495, 0.0052],
        [0.6860, 0.0147, 0.5511, 0.5160],
        [0.6471, 0.9110, 0.0569, 0.1992],
        [0.4515, 0.5749, 0.6803, 0.0289],
        [0.2259, 0.8861, 0.2789, 0.5278]])


In [3]:
# Modify data
data_element.bboxes = data_element.bboxes * 2
data_element.scores = data_element.scores * -1
for k, v in data_element.items():
    print(f'{k}: {v}')

# Delete attributes from data
del data_element.bboxes
for k, v in data_element.items():
    print(f'{k}: {v}')

data_element.pop('scores', None)
print('The keys in data is', data_element.keys())

scores: tensor([-0.2028, -0.2644, -0.0398, -0.1335, -0.4261, -0.8572])
bboxes: tensor([[1.1974, 0.1704, 1.3329, 1.4867],
        [1.6854, 1.0818, 0.6990, 0.0104],
        [1.3720, 0.0294, 1.1021, 1.0320],
        [1.2942, 1.8221, 0.1138, 0.3985],
        [0.9031, 1.1498, 1.3606, 0.0578],
        [0.4518, 1.7723, 0.5578, 1.0555]])
scores: tensor([-0.2028, -0.2644, -0.0398, -0.1335, -0.4261, -0.8572])
The keys in data is []


In [4]:
# Modify metainfo
data_element.set_metainfo(dict(img_shape = (1280, 1280), img_id=10))
print(data_element.img_shape)  # (1280, 1280)
for k, v in data_element.metainfo_items():
    print(f'{k}: {v}')

# del and pop are convenient ways to delete attributes; pop also returns the removed value
del data_element.img_shape
for k, v in data_element.metainfo_items():
    print(f'{k}: {v}')

data_element.pop('img_id')
print('The keys in metainfo is', data_element.metainfo_keys())

(1280, 1280)
img_id: 10
img_shape: (1280, 1280)
img_id: 10
The keys in metainfo is []


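Tensor-like operations \
The data in a BaseDataElement can be converted much like a torch.Tensor. The supported operations currently include cuda, cpu, to, and numpy. The to function has the same interface as torch.Tensor.to(), so the wrapped tensors can be converted flexibly. Note: these interfaces only process np.ndarray, torch.Tensor, and sequences of numbers; other data, such as strings, is skipped.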

In [17]:
from mmengine.structures import BaseDataElement
import torch

data_element=BaseDataElement(
    bboxes=torch.rand((6,4)),
    scores=torch.rand((6,)),
    metainfo=dict(img_id=0, img_shape=(640, 640))
)

# Move to the GPU
cuda_element_1=data_element.cuda()
print(f"cuda_element_1's device is {cuda_element_1.bboxes.device}")
cuda_element_2=data_element.to(torch.device('cuda:0'))
print(f"cuda_element_2's device is {cuda_element_2.bboxes.device}")

# Move all data to the CPU
cpu_element_1 = cuda_element_1.cpu()
print('cpu_element_1 is on the device of', cpu_element_1.bboxes.device)  # cpu
cpu_element_2 = cuda_element_2.to('cpu')
print('cpu_element_2 is on the device of', cpu_element_2.bboxes.device)  # cpu

cuda_element_1's device is cuda:0
cuda_element_2's device is cuda:0
cpu_element_1 is on the device of cpu
cpu_element_2 is on the device of cpu


In [18]:
# Convert all data to fp16
fp16_instances=cuda_element_1.to(device=None,dtype=torch.float16,non_blocking=False, copy=False,
    memory_format=torch.preserve_format)
print(f"fp16_instances's data type is {fp16_instances.bboxes.dtype}")
# Detach from the computation graph (stop gradients)
cuda_element_3=cuda_element_1.detach()
print(f'cuda_element_3 required grad is {cuda_element_3.bboxes.requires_grad}')
# Convert data to numpy arrays
np_instances = cpu_element_1.numpy()
print('The type of cpu_element_1 is convert to', type(np_instances.bboxes))

fp16_instances's data type is torch.float16
cuda_element_3 required grad is False
The type of cpu_element_1 is convert to <class 'numpy.ndarray'>


Displaying attributes \
BaseDataElement also implements \__repr__, so print shows all of its data directly. To make debugging easier, the attributes of a BaseDataElement are also added to \__dict__, so an IDE can display its contents directly. A complete example:

In [19]:
img_info=dict(img_shape=(640, 640,3),padding_shape=(4,4,3))
instances=BaseDataElement(metainfo=img_info)
instances.bboxes=torch.randn(10,4)
instances.labels=torch.randint(0,10,(10,))
print(instances)

<BaseDataElement(

    META INFORMATION
    padding_shape: (4, 4, 3)
    img_shape: (640, 640, 3)

    DATA FIELDS
    labels: tensor([3, 9, 2, 6, 0, 3, 2, 5, 9, 0])
    bboxes: tensor([[ 0.6133, -1.4811,  1.8712, -0.3347],
                [-2.1329,  0.5925, -0.3979,  0.0459],
                [ 0.2678, -0.2856, -1.5449, -0.2113],
                [-0.4307, -1.1413,  0.2422,  0.1296],
                [ 0.8319, -0.2335,  1.4273, -1.2889],
                [ 2.9358, -0.7402, -0.6104,  0.5525],
                [ 0.3033, -1.5639,  0.1427,  0.8960],
                [-1.2516, -0.3869, -0.1031, -0.3760],
                [ 0.6205,  0.1703,  1.1321,  0.1605],
                [ 0.7103, -0.7690,  1.2041, -1.2740]])
) at 0x7efd22545cf0>


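Data elements (xxxData)
1. InstanceData
2. PixelData
3. LabelData

InstanceData builds on BaseDataElement and restricts what data can store: all values in data must have the same length. In object detection, for example, if an image contains N objects (instances), all of its bounding boxes (bbox), classes (label), and so on can be stored in one InstanceData, where bbox and label have the same length. It adds the following features:
1. Length validation for the values stored in data
2. Dictionary-style access to and assignment of data attributes
3. Basic indexing, slicing, and advanced indexing
4. Concatenation of different InstanceData objects that have the same keys

These features work with basic data structures such as torch.Tensor, numpy.ndarray, list, str, and tuple, and also with custom data structures, as long as the custom structure implements \__len__, \__getitem__, and cat.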
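The length validation described above can be pictured with a small plain-Python container. This is only a sketch under assumptions: `LengthCheckedFields` is a hypothetical toy class, not mmengine code, and it checks lengths in `__setattr__` purely for illustration.

```python
class LengthCheckedFields:
    """Toy container in which every field must have the same length.

    Hypothetical illustration only; this is not how mmengine is implemented.
    """

    def __init__(self):
        # Bypass our own __setattr__ while creating the internal storage.
        object.__setattr__(self, "_fields", {})

    def __len__(self):
        # The length is taken from any stored field; 0 when empty.
        for value in self._fields.values():
            return len(value)
        return 0

    def __setattr__(self, name, value):
        # Once a field exists, every new value must match its length.
        if self._fields:
            assert len(value) == len(self), (
                f"The length of values {len(value)} is not consistent "
                f"with the length of this container {len(self)}")
        self._fields[name] = value

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for stored fields.
        if name == "_fields":
            raise AttributeError(name)
        try:
            return self._fields[name]
        except KeyError:
            raise AttributeError(name) from None


fields = LengthCheckedFields()
fields.bboxes = [[0, 0, 1, 1], [2, 2, 3, 3]]
fields.labels = [2, 3]
print(len(fields))  # 2

try:
    fields.scores = [0.9, 0.8, 0.7]  # length 3 does not match 2
except AssertionError as err:
    print("rejected:", err)
```

Because the check only relies on `len()`, any value type that implements `__len__` can take part, which is the same reason custom structures need `__len__` to be stored.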

In [1]:
# 1. InstanceData length validation
from mmengine.structures import InstanceData
import torch
import numpy as np

img_data=dict(img_shape=(640, 640,3),padding_shape=(4,4,3))
instances=InstanceData(metainfo=img_data)
instances.bboxes=torch.randn(2,4)
instances.det_labels=torch.LongTensor([2,3])
instances.det_scores=torch.Tensor([0.9,0.8])

print(f"the length of instances is {len(instances)}")  # len() returns the common length of the data fields in instances

# Assigning a data attribute with a different length raises an AssertionError
instances.bboxes=torch.randn(3,4)





the length of instances is 2


AssertionError: The length of values 3 is not consistent with the length of this :obj:`InstanceData` 2

In [2]:
# 2. InstanceData supports dictionary-style indexing by field name
img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = InstanceData(metainfo=img_meta)
instance_data["det_labels"] = torch.LongTensor([2, 3])
instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
instance_data.bboxes = torch.rand((2, 4))
print(instance_data)

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.6656, 0.2982, 0.4252, 0.1886],
                [0.7531, 0.9877, 0.2023, 0.7168]])
    det_scores: tensor([0.8000, 0.7000])
    det_labels: tensor([2, 3])
) at 0x7f4a80453eb0>


In [3]:
# 3. Indexing and slicing InstanceData
img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = InstanceData(metainfo=img_meta)
instance_data.det_labels = torch.LongTensor([2, 3])
instance_data.det_scores = torch.Tensor([0.8, 0.7])
instance_data.bboxes = torch.rand((2, 4))
print(instance_data)

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.7027, 0.4390, 0.1586, 0.8094],
                [0.3566, 0.0361, 0.9358, 0.6374]])
    det_scores: tensor([0.8000, 0.7000])
    det_labels: tensor([2, 3])
) at 0x7f4a7c16c4c0>


In [4]:
# InstanceData supports positional indexing; the index is applied to every data attribute, as shown below
print(instance_data[0])

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.7027, 0.4390, 0.1586, 0.8094]])
    det_scores: tensor([0.8000])
    det_labels: tensor([2])
) at 0x7f4a7c16ce80>


In [5]:
# Slicing works the same way
print(instance_data[0:1])

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.7027, 0.4390, 0.1586, 0.8094]])
    det_scores: tensor([0.8000])
    det_labels: tensor([2])
) at 0x7f4a80452b30>


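Note: in numpy and torch, indexing with an integer extracts one element and drops one dimension, while slicing keeps the number of dimensions. \
With InstanceData, neither indexing nor slicing changes the number of dimensions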

In [6]:
a=np.array([[1,2,3],[4,5,6]])
print(a[0])
print(a[0:1])

[1 2 3]
[[1 2 3]]


In [7]:
b=torch.tensor([[1,2,3],[4,5,6]])
print(b[0])
print(b[0:1])

tensor([1, 2, 3])
tensor([[1, 2, 3]])


In [8]:
# Index with a sequence of indices (here a LongTensor)
sorted_results = instance_data[instance_data.det_scores.sort().indices]
print(sorted_results)

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.3566, 0.0361, 0.9358, 0.6374],
                [0.7027, 0.4390, 0.1586, 0.8094]])
    det_scores: tensor([0.7000, 0.8000])
    det_labels: tensor([3, 2])
) at 0x7f49d784e290>


In [9]:
# Boolean indexing
filter_results=instance_data[instance_data.det_scores>0.75]
print(filter_results)

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.7027, 0.4390, 0.1586, 0.8094]])
    det_scores: tensor([0.8000])
    det_labels: tensor([2])
) at 0x7f49d784ee90>


In [10]:
# The result is empty
empty_results = instance_data[instance_data.det_scores > 1]
print(empty_results)

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([], size=(0, 4))
    det_scores: tensor([])
    det_labels: tensor([], dtype=torch.int64)
) at 0x7f49d784e9b0>


In [11]:
# InstanceData objects with the same keys (both metainfo and data attributes) can be concatenated with cat
img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = InstanceData(metainfo=img_meta)
instance_data.det_labels = torch.LongTensor([2, 3])
instance_data.det_scores = torch.Tensor([0.8, 0.7])
instance_data.bboxes = torch.rand((2, 4))
print('The length of instance_data is', len(instance_data))
cat_results = InstanceData.cat([instance_data, instance_data])
print('The length of cat_results is', len(cat_results))
print(cat_results)

The length of instance_data is 2
The length of cat_results is 4
<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.1941, 0.2833, 0.1317, 0.0637],
                [0.7264, 0.8688, 0.0104, 0.8151],
                [0.1941, 0.2833, 0.1317, 0.0637],
                [0.7264, 0.8688, 0.0104, 0.8151]])
    det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
    det_labels: tensor([2, 3, 2, 3])
) at 0x7f49d784f070>


In [37]:
# A custom data structure must implement __len__, __getitem__, and a static cat method
import itertools

class TmpObject:
    def __init__(self, tmp) -> None:
        assert isinstance(tmp, list)
        self.tmp = tmp

    def __len__(self):
        return len(self.tmp)

    def __getitem__(self, item):
        if type(item) == int:
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            else:
                # keep the dimension
                item = slice(item, None, len(self))
        return TmpObject(self.tmp[item])

    @staticmethod
    def cat(tmp_objs):
        assert all(isinstance(results, TmpObject) for results in tmp_objs)
        if len(tmp_objs) == 1:
            return tmp_objs[0]
        tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        tmp_list = list(itertools.chain(*tmp_list))
        new_data = TmpObject(tmp_list)
        return new_data

    def __repr__(self):
        return str(self.tmp)

In [34]:
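# Three points in the code above are worth noting
# 1. slice(start, stop, step) returns a slice object that can be used for slicing
# slice(idx, None, len(nums)) extracts the element at position idx of nums while keeping the container type (a one-element list)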
nums=list(range(10))
s=slice(1,None,len(nums))
nums[s]

[1]

In [23]:
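# 2. all() checks whether every element of the iterable passed to it is truthy
# The fully parenthesized form is all((isinstance(results, TmpObject) for results in tmp_objs))
# A bare for clause is not valid on its own; it must be part of a generator expression (xxx for ...)
# When a generator expression is the only argument of a call, its own parentheses may be omitted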
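The cell above has no runnable code, so here is a minimal sketch of the point it makes. The variable names (`values`, `mixed`) are made up for illustration; only built-in Python is used.

```python
values = [1, 2, 3]

# Explicit parentheses around the generator expression.
explicit = all((isinstance(v, int) for v in values))
# Same call: the generator expression is the sole argument,
# so its own parentheses can be dropped.
implicit = all(isinstance(v, int) for v in values)
print(explicit, implicit)  # True True

# all() returns False as soon as one element fails the test.
mixed = [1, "two", 3]
print(all(isinstance(v, int) for v in mixed))  # False
```

If the call takes a second argument, the generator expression needs its own parentheses again.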

In [24]:
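# 3. itertools.chain joins several iterables into a single iterable, one level deep
# For nested lists, such as the two 2-D lists unpacked by *a here, the result is still 2-D: only the outermost level is joined
# This is enough to implement cat for list-like containers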
import itertools
a=[[[1, 2, 3, 4], [5, 6, 7, 8]],[[1, 2, 3, 4], [5, 6, 7, 8]]]
print(*a)
b=itertools.chain(*a)
print(list(b))

[[1, 2, 3, 4], [5, 6, 7, 8]] [[1, 2, 3, 4], [5, 6, 7, 8]]
[[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]


In [38]:
img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
instance_data = InstanceData(metainfo=img_meta)
instance_data.det_labels = torch.LongTensor([2, 3])
instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
instance_data.bboxes = torch.rand((2, 4))
instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
# print calls __repr__ on every attribute of instance_data, so polygons shows the wrapped tmp list rather than a default object representation
print(instance_data)

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.9587, 0.0455, 0.6518, 0.6603],
                [0.0439, 0.6812, 0.9202, 0.5651]])
    det_scores: tensor([0.8000, 0.7000])
    polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
    det_labels: tensor([2, 3])
) at 0x7f49d78283a0>


In [39]:
# Advanced indexing
print(instance_data[instance_data.det_scores > 0.75])

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.9587, 0.0455, 0.6518, 0.6603]])
    det_scores: tensor([0.8000])
    polygons: [[1, 2, 3, 4]]
    det_labels: tensor([2])
) at 0x7f4a7c16c5e0>


In [40]:
# Concatenation
print(InstanceData.cat([instance_data, instance_data]))

<InstanceData(

    META INFORMATION
    img_shape: (800, 1196, 3)
    pad_shape: (800, 1216, 3)

    DATA FIELDS
    bboxes: tensor([[0.9587, 0.0455, 0.6518, 0.6603],
                [0.0439, 0.6812, 0.9202, 0.5651],
                [0.9587, 0.0455, 0.6518, 0.6603],
                [0.0439, 0.6812, 0.9202, 0.5651]])
    det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
    polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
    det_labels: tensor([2, 3, 2, 3])
) at 0x7f49d7829c90>


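PixelData \
PixelData adds two constraints on data to BaseDataElement:
1. All values in data are stored as 3-dimensional, ordered as (channel, height, width); a 2-dimensional (height, width) value is expanded to (1, height, width)
2. All values in data have the same height and width

It also provides the following extensions: \
1. Size validation for the values stored in data \
2. Indexing and slicing of the data values along the spatial dimensions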

In [1]:
from mmengine.structures import PixelData
import torch
import numpy as np
import random

metainfo=dict(
    img_id=random.randint(0, 100),
    img_shape=(random.randint(400,600),random.randint(400,600))
)
image=np.random.randint(0,255,(4,20,40))
featmap=np.random.randint(0,255,(10,20,40))

pixel_data=PixelData(image=image,
                     featmap=featmap,
                     metainfo=metainfo)

print('The shape of pixel_data is', pixel_data.shape)
# Set a 2-D value; it is stored as (1, H, W)
pixel_data.map3 = torch.randint(0, 255, (20, 40))
print('The shape of pixel_data.map3 is', pixel_data.map3.shape)


The shape of pixel_data is (20, 40)
The shape of pixel_data.map3 is torch.Size([1, 20, 40])




In [2]:
pixel_data.map2 = torch.randint(0, 255, (3, 20, 30))
# AssertionError: The height and width of values (20, 30) is not consistent with the shape of this :obj:`PixelData` (20, 40)

AssertionError: The height and width of values (20, 30) is not consistent with the shape of this :obj:`PixelData` (20, 40)

In [3]:
pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40))
# AssertionError: The dim of value must be 2 or 3, but got 4

AssertionError: The dim of value must be 2 or 3, but got 4

In [4]:
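# Indexing along the spatial dimensions
# PixelData supports indexing and slicing its data values along the spatial dimensions; only the height and width indices are passed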
metainfo=dict(
    img_id=random.randint(0, 100),
    img_shape=(random.randint(400,600),random.randint(400,600))
)
image=np.random.randint(0,255,(4,20,40))
featmap=np.random.randint(0,255,(10,20,40))

pixel_data=PixelData(image=image,
                     featmap=featmap,
                     metainfo=metainfo)

# pixel_data.shape contains only the height and width
print("The shape of pixel_data is",pixel_data.shape)

The shape of pixel_data is (20, 40)


In [5]:
# Indexing
# Indexing also returns a PixelData object, so its shape again contains only the height and width
index_data=pixel_data[10,20]
print("The shape of index_data is ",index_data.shape)

The shape of index_data is  (1, 1)


In [6]:
# Slicing
# Slicing works the same way
slice_data=pixel_data[10:20,20:40]
print("The shape of slice_data is",slice_data.shape)

The shape of slice_data is (10, 20)


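LabelData \
LabelData mainly wraps label data, such as scene classification labels or text recognition labels. It places no restrictions on data and only provides two extra functions: conversion between one-hot vectors and class indices.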
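To make the two conversions concrete, here is a plain-Python sketch that works on lists instead of tensors. The helper names `indices_to_onehot` and `onehot_to_indices` are hypothetical and chosen for this illustration; they are not the LabelData methods.

```python
def indices_to_onehot(indices, num_classes):
    """Turn a list of class indices into a 0/1 list of length num_classes."""
    onehot = [0] * num_classes
    for index in indices:
        if not 0 <= index < num_classes:
            raise ValueError(f"label {index} is outside [0, {num_classes})")
        onehot[index] = 1
    return onehot


def onehot_to_indices(onehot):
    """Return the positions of the non-zero entries."""
    return [index for index, flag in enumerate(onehot) if flag]


onehot = indices_to_onehot([1], num_classes=10)
print(onehot)                     # [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
print(onehot_to_indices(onehot))  # [1]
```

The two functions are inverses of each other for sorted, duplicate-free index lists, which is why a round trip returns the original label.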

In [7]:
from mmengine.structures import LabelData
import torch

item=torch.tensor([1],dtype=torch.int64)
num_classes=10

In [8]:
onehot=LabelData.label_to_onehot(label=item,num_classes=num_classes)
print(f"{item} is converted to ",onehot)

index=LabelData.onehot_to_label(onehot=onehot)
print(f"{onehot} is converted to ",index)

tensor([1]) is converted to  tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0]) is converted to  tensor([1])


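Data samples (xxxDataSample) \
A data sample is a higher-level wrapper around data elements that stores image-level annotation information (all kinds of annotations, such as instance segmentation, detection, semantic segmentation, and classes) \
Taking the downstream library MMDet as an example, DetDataSample defines 7 fields: \
The data type of each field is given in parentheses

- Annotations:
    gt_instances(InstanceData): instance annotations, including instance classes, bounding boxes, and so on
    gt_panoptic_seg(PixelData): panoptic segmentation annotations
    gt_sem_seg(PixelData): semantic segmentation annotations
- Predictions:
    pred_instances(InstanceData): instance predictions, including instance classes, bounding boxes, and so on
    pred_panoptic_seg(PixelData): panoptic segmentation predictions
    pred_sem_seg(PixelData): semantic segmentation predictions
- Intermediate results:
    proposals(InstanceData): mainly the RPN predictions in two-stage detectors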

In [None]:
# The property setter and deleter define custom behavior for assigning to and deleting the decorated attribute

from mmengine.structures import BaseDataElement, InstanceData, PixelData
import torch

class DetDataSample(BaseDataElement):

    # Annotations
    @property
    def gt_instances(self) -> InstanceData:
        return self._gt_instances

    # The setter enforces the value type
    @gt_instances.setter
    def gt_instances(self, value: InstanceData):
        self.set_field(value, '_gt_instances', dtype=InstanceData)

    @gt_instances.deleter
    def gt_instances(self):
        del self._gt_instances

    @property
    def gt_panoptic_seg(self) -> PixelData:
        return self._gt_panoptic_seg

    @gt_panoptic_seg.setter
    def gt_panoptic_seg(self, value: PixelData):
        self.set_field(value, '_gt_panoptic_seg', dtype=PixelData)

    @gt_panoptic_seg.deleter
    def gt_panoptic_seg(self):
        del self._gt_panoptic_seg

    @property
    def gt_sem_seg(self) -> PixelData:
        return self._gt_sem_seg

    @gt_sem_seg.setter
    def gt_sem_seg(self, value: PixelData):
        self.set_field(value, '_gt_sem_seg', dtype=PixelData)

    @gt_sem_seg.deleter
    def gt_sem_seg(self):
        del self._gt_sem_seg

    # Predictions
    @property
    def pred_instances(self) -> InstanceData:
        return self._pred_instances

    @pred_instances.setter
    def pred_instances(self, value: InstanceData):
        self.set_field(value, '_pred_instances', dtype=InstanceData)

    @pred_instances.deleter
    def pred_instances(self):
        del self._pred_instances

    @property
    def pred_panoptic_seg(self) -> PixelData:
        return self._pred_panoptic_seg

    @pred_panoptic_seg.setter
    def pred_panoptic_seg(self, value: PixelData):
        self.set_field(value, '_pred_panoptic_seg', dtype=PixelData)

    @pred_panoptic_seg.deleter
    def pred_panoptic_seg(self):
        del self._pred_panoptic_seg

    @property
    def pred_sem_seg(self) -> PixelData:
        return self._pred_sem_seg

    @pred_sem_seg.setter
    def pred_sem_seg(self, value: PixelData):
        self.set_field(value, '_pred_sem_seg', dtype=PixelData)

    @pred_sem_seg.deleter
    def pred_sem_seg(self):
        del self._pred_sem_seg

    # Intermediate results
    @property
    def proposals(self) -> InstanceData:
        return self._proposals

    @proposals.setter
    def proposals(self, value: InstanceData):
        self.set_field(value, '_proposals', dtype=InstanceData)

    @proposals.deleter
    def proposals(self):
        del self._proposals


Type constraints \
The setter of every property above enforces a type constraint on assignment
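The effect of such a constraint can be sketched without any library. Everything here is a hypothetical stand-in: `Boxes` and `Mask` play the role of element types, `TypedSample` plays the role of a data sample, and raising `TypeError` is an assumption of this sketch, since the notes do not show which exception `set_field` raises.

```python
class Boxes:
    """Hypothetical stand-in for an element type such as InstanceData."""


class Mask:
    """Hypothetical stand-in for an element type such as PixelData."""


class TypedSample:
    """Toy sample whose gt_instances field only accepts Boxes."""

    @property
    def gt_instances(self):
        return self._gt_instances

    @gt_instances.setter
    def gt_instances(self, value):
        # Reject values of the wrong type before storing them.
        if not isinstance(value, Boxes):
            raise TypeError(
                f"gt_instances must be Boxes, got {type(value).__name__}")
        self._gt_instances = value

    @gt_instances.deleter
    def gt_instances(self):
        del self._gt_instances


sample = TypedSample()
sample.gt_instances = Boxes()  # accepted

try:
    sample.gt_instances = Mask()  # wrong type
except TypeError as err:
    print("rejected:", err)  # rejected: gt_instances must be Boxes, got Mask

del sample.gt_instances
print(hasattr(sample, "gt_instances"))  # False
```

Putting the check in the setter means a wrongly typed field fails at assignment time, close to the code that produced it, rather than later inside a loss or metric.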

Interface simplification

In [None]:
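# Different components do not necessarily take the same arguments. In mmdet 3.x they are all simplified to the form (img, data_samples), and each component reads the values it needs by name.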
from mmdet.models import BaseDetector
class SingleStageDetector(BaseDetector):
    ...

    def forward_train(self,
                      img,
                      data_samples):
        pass

class SingleStageInstanceSegmentor(BaseDetector):
    ...

    def forward_train(self,
                      img,
                      data_samples):
        pass
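As a rough illustration of reading fields by name from a shared argument, the sketch below uses `types.SimpleNamespace` in place of a real data sample. The functions `detection_loss` and `segmentation_loss` and the placeholder values are hypothetical; they only show that two components with the same signature can each pick the field they need.

```python
from types import SimpleNamespace


def detection_loss(img, data_samples):
    """Hypothetical component that only needs the instance annotations."""
    return [len(sample.gt_instances) for sample in data_samples]


def segmentation_loss(img, data_samples):
    """Hypothetical component that only needs the semantic mask."""
    return [sample.gt_sem_seg for sample in data_samples]


# One sample carrying several kinds of annotation under agreed field names.
sample = SimpleNamespace(gt_instances=["box_a", "box_b"], gt_sem_seg="mask_0")
batch = [sample]

print(detection_loss(img=None, data_samples=batch))     # [2]
print(segmentation_loss(img=None, data_samples=batch))  # ['mask_0']
```

Adding a new kind of annotation then means adding a field to the sample, not changing every function signature along the way.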
