In [43]:
import torch
from typing import List
from torch import Tensor
import numpy as np

def _max_by_axis_pad(the_list):
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)

    block = 128

    for i in range(2):
        maxes[i+1] = ((maxes[i+1] - 1) // block + 1) * block
    return maxes


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:

        # TODO make it support different-sized images
        max_size = _max_by_axis_pad([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        for img, pad_img in zip(tensor_list, tensor):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    else:
        raise ValueError('not supported')
    return tensor

def collate_fn_crowd(batch):
    # re-organize the batch
    batch_new = []
    for b in batch:
        imgs, points = b
        if imgs.ndim == 3:
            imgs = imgs.unsqueeze(0) # 将单个图像转换为形状为 [1, C, H, W] 
        # 遍历处理每个图像（如果imgs是多图像张量，则会遍历每个图像）
        print(f"imgs.shape = {imgs.shape}")
        for i in range(len(imgs)):
            batch_new.append((imgs[i, :, :, :], points[i]))
    print(f"batch_new = {batch_new}")
    batch = batch_new
    # print(f"batch_new = {batch_new}")
    batch = list(zip(*batch))
    print(f"len = {len(batch)}")
    # print(f"batch = {batch}")
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    print(f"batch[0] = {batch[0].shape}")
    
    print(f"batch[1] = {batch[1]}")
    

    return tuple(batch)

In [44]:
# 假设图像的高度和宽度
c1, height1, width1 = 3, 100, 200
c2, height2, width2 = 3, 150, 250

# 创建两个随机图像张量
img1 = torch.rand(c1, height1, width1)
img2 = torch.rand(c2, height2, width2)

# 创建对应的随机标注点
# 假设每个图像有5个标注点
points1 = torch.rand(5, 2).tolist()
points2 = torch.rand(5, 2).tolist()

print(f"points1 = {points1}")
print(f"points2 = {points2}")

# 创建批次数据
batch = [
    (img1, points1),  # img1: [c1, H1, W1], points1: [N1, 2]
    (img2, points2)   # img2: [c2, H2, W2], points2: [N2, 2]
]

result = collate_fn_crowd(batch)
print(f"result = {result[1]}")

points1 = [[0.6638845801353455, 0.8458588719367981], [0.4588415026664734, 0.9813203811645508], [0.3919997215270996, 0.18197685480117798], [0.047447264194488525, 0.9441867470741272], [0.8190175294876099, 0.08473968505859375]]
points2 = [[0.8430103659629822, 0.939226508140564], [0.005113065242767334, 0.5371531248092651], [0.044196128845214844, 0.5609201192855835], [0.22545641660690308, 0.8165781497955322], [0.019254744052886963, 0.25466984510421753]]
imgs.shape = torch.Size([1, 3, 100, 200])
imgs.shape = torch.Size([1, 3, 150, 250])
batch_new = [(tensor([[[0.0143, 0.2860, 0.4216,  ..., 0.6058, 0.7872, 0.6927],
         [0.3401, 0.9436, 0.9291,  ..., 0.8825, 0.4819, 0.6070],
         [0.2620, 0.0966, 0.5824,  ..., 0.6841, 0.3647, 0.4097],
         ...,
         [0.8482, 0.7748, 0.1538,  ..., 0.5402, 0.0044, 0.3309],
         [0.3356, 0.7458, 0.9334,  ..., 0.2713, 0.4069, 0.0588],
         [0.0311, 0.8886, 0.8130,  ..., 0.3560, 0.3384, 0.7228]],

        [[0.9368, 0.9747, 0.5108,  ..., 0.5

In [47]:
a = [[]]*4
a[0] = 1
a

[1, [], [], []]