In [2]:
import torch
import torch.nn.functional as F

def process_batch_data(batch_data, target_size=16):
    """
    Strip the singleton axis and the leading classification token, then
    bilinearly resize the square grid of patch tokens.

    Args:
        batch_data: tensor of shape [B, T, 1, 1 + N, C]. The example below uses
            [10, 10, 1, 1025, 3200]: 1 classification token plus a 32x32 grid
            of patch tokens with 3200 features each. A 4-D tensor
            [B, T, 1 + N, C] is also accepted. N must be a positive perfect
            square. Patch tokens are assumed to be stored row-major over the
            grid (TODO confirm against the producer of this tensor).
        target_size: side length of the resized patch grid. Defaults to 16,
            i.e. 32x32 -> 16x16 for the example input.

    Returns:
        Tensor of shape [B, T * target_size**2, C] ([10, 2560, 3200] for the
        example input). Tokens are ordered frame by frame (axis T), and within
        a frame row-major over the resized grid.

    Raises:
        ValueError: if the input is not 4-D after squeezing axis 2, or if the
            number of patch tokens is not a positive perfect square.
    """
    # 1. Drop the singleton axis 2: [B, T, 1 + N, C].
    tokens = batch_data.squeeze(2)
    if tokens.dim() != 4:
        raise ValueError(
            f"expected input of shape [B, T, 1, 1 + N, C], got {tuple(batch_data.shape)}"
        )

    # 2. Drop the classification token at sequence index 0: [B, T, N, C].
    patches = tokens[:, :, 1:, :]
    batch_size, num_frames, num_patches, num_features = patches.shape

    side = int(round(num_patches ** 0.5))
    if num_patches == 0 or side * side != num_patches:
        raise ValueError(
            f"number of patch tokens ({num_patches}) is not a positive perfect square"
        )

    # 3. Tokens are stored token-major ([N, C]); interpolate expects
    #    [batch, channels, H, W] with features on the channel axis. Reshape the
    #    token axis into the grid and *permute* features to the front -- a
    #    plain .view to [..., C, H, W] would reinterpret memory and mix
    #    features across patches.
    grid = patches.reshape(batch_size * num_frames, side, side, num_features)
    grid = grid.permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(target_size, target_size), mode='bilinear', align_corners=False)

    # 4. Move features back to the last axis and flatten frames x patches into
    #    one sequence axis: [B, T * target_size**2, C].
    grid = grid.permute(0, 2, 3, 1)
    return grid.reshape(batch_size, num_frames * target_size * target_size, num_features)

# Example usage. The random input alone is ~1.3 GB of float32
# (10*10*1*1025*3200 values), so this cell needs plenty of memory.
torch.manual_seed(0)  # reproducible example input
batch_data = torch.randn(10, 10, 1, 1025, 3200)  # assumed input shape [10, 10, 1, 1025, 3200]
processed_data = process_batch_data(batch_data)
# Pin the documented output shape so a stale recorded output cannot go unnoticed.
assert processed_data.shape == (10, 2560, 3200), processed_data.shape
print(f"处理后的数据形状: {processed_data.shape}")

处理后的数据形状: torch.Size([10, 2560, 3200])


In [None]:
# 100 identical rows, each four 3s followed by six 1s.
# The comprehension builds a fresh list per row (like the original literal did),
# so mutating one row never affects another.
m = [[3, 3, 3, 3, 1, 1, 1, 1, 1, 1] for _ in range(100)]

In [None]:
# Number of rows in `m`; as the cell's last expression, Jupyter displays it as the output.
len(m)