In [1]:
import argparse
from omegaconf import OmegaConf
from GOT.utils.arguments import *


parser = argparse.ArgumentParser()
parser.add_argument(
    "--configs",
    nargs="*",
    default=["/apps/GOT-OCR2.0/configs/got_test.yaml"],
    help="Path to the config file",
)
parser.add_argument('--local_rank', type=int, default=-1,
                    help='Used for distributed training')  # ✅ 添加这一行
args = parser.parse_args([])

config_list = [OmegaConf.load(c) for c in args.configs]
config = OmegaConf.merge(*config_list)
# model_args, data_args, training_args = parser.parse_yaml_file(
#     "configs/got.yaml")
# config = OmegaConf.load("configs/got.yaml")
# 分别提取字段构造 dataclass
model_args = ModelArguments(
    **{k: v for k, v in config.items() if k in ModelArguments.__dataclass_fields__}
)
data_args = DataArguments(
    **{k: v for k, v in config.items() if k in DataArguments.__dataclass_fields__}
)
training_args = TrainingArguments(
    **{k: v for k, v in config.items() if k in TrainingArguments.__dataclass_fields__}
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from GOT.model import *
# 假设模型已实例化并在cuda
model = GOTQwenForCausalLM.from_pretrained(
    "/data_8t_1/qby/GOT-OCR2_0", use_safetensors=True)
model.to("cuda")
# model.eval()

GOTQwenForCausalLM(
  (model): GOTQwenModel(
    (embed_tokens): Embedding(151860, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1024,), eps=1e-06)
    (rotary_em

In [3]:
vision_tower_dict = model.get_model().initialize_vision_modules(
    vision_tower=model_args.vision_tower,
    pretrained_stage1_model=model_args.pretrained_stage1_model,
    freeze_vision_tower=model_args.freeze_vision_tower,
    use_im_start_end=model_args.use_im_start_end,
    vision_select_layer=model_args.vision_select_layer,
    device=training_args.device
)

In [4]:

from GOT.data.conversation_dataset_qwen import ConversationDataset
from transformers import AutoTokenizer
dataset_cls = ConversationDataset
tokenizer=AutoTokenizer.from_pretrained(
    "/data_8t_1/qby/GOT-OCR2_0",
    trust_remote_code=True,
    # use_fast=False,
    # revision="v1.0.0"
)
data_args.image_token_len = 256
data_args.image_processor = vision_tower_dict['image_processor']
data_args.image_processor_high = vision_tower_dict['image_processor_high']
data_args.use_im_start_end = model_args.use_im_start_end
train_dataset = dataset_cls(
    tokenizer=tokenizer,
    datasets=data_args.train_datasets,
    multimodal_cfg=dict(
        sep_image_conv_front=data_args.sep_image_conv_front,
        image_token_len=data_args.image_token_len,
        image_aspect_ratio=data_args.image_aspect_ratio,
        use_im_start_end=data_args.use_im_start_end,
        image_processor=data_args.image_processor,
        image_processor_high=data_args.image_processor_high,
        box_limit=data_args.box_limit,
    )
)



In [5]:
print("train_dataset:", train_dataset)

train_dataset: <GOT.data.conversation_dataset_qwen.ConversationDataset object at 0x7f7a4af3fdc0>


In [6]:
# print("train_dataset[0]:",train_dataset[0])
for k, v in train_dataset[0].items():
    print(f"{k}: {v.shape if isinstance(v, torch.Tensor) else len(v)}")

data_dict_ori torch.Size([417])
data_dict torch.Size([417])
input_ids: torch.Size([417])
labels: torch.Size([417])
image: 1
image_high: 1


In [7]:
train_dataset

<GOT.data.conversation_dataset_qwen.ConversationDataset at 0x7f7a4af3fdc0>

In [8]:
# from torch.utils.data import Dataset, DataLoader
# from GOT.data import DataCollatorForSupervisedDataset
# data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
# dataloader = DataLoader(train_dataset, batch_size=16, collate_fn=data_collator)

In [9]:
# dict(train_dataset)

In [10]:
# dataloader

In [11]:
# for batch in dataloader:
#     print(batch)
#     break

In [12]:
# # 假设你已经定义好了 dataloader
# batch = next(iter(dataloader))

# # 打印内容
# print(batch.keys())  # 查看有哪些字段
# print(batch['input_ids'].shape)
# print(batch['labels'].shape)
# print(batch['attention_mask'].shape)
# print(len(batch['images']))  # 查看图片数量
# print(len(batch['images'][0]))  # 查看第一张图片的数量
# print(batch['images'][0][0].shape)  # 查看第一张图片的形状
# print(batch['images'][0][1].shape)  # 查看第一张图片的
# # print((batch['images'][0][0]))

In [None]:
from collections import defaultdict
import torch
from tqdm import tqdm

# 初始化记录字典
field_shapes = defaultdict(list)

# 遍历整个数据集
for example in tqdm(train_dataset):
    for k, v in example.items():
        if isinstance(v, torch.Tensor):
            field_shapes[k].append(v.shape[0])
        elif isinstance(v, list) or isinstance(v, str):
            field_shapes[k].append(len(v))
        else:
            field_shapes[k].append(type(v))

# 打印每个字段的统计结果
for k, v_list in field_shapes.items():
    unique_vals = set(v_list)
    print(f"{k}:")
    print(f"  Unique shapes/lengths: {unique_vals}")
    print(f"  Max: {max(v_list)}, Min: {min(v_list)}, Avg: {sum(v_list)/len(v_list):.2f}" if all(
        isinstance(x, int) for x in v_list) else "")

  1%|          | 5/992 [00:00<00:20, 47.13it/s]

data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([483])
data_dict torch.Size([483])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([392])
data_dict torch.Size([392])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])


  1%|          | 11/992 [00:00<00:19, 50.64it/s]

data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([531])
data_dict torch.Size([531])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])


  2%|▏         | 17/992 [00:00<00:18, 51.48it/s]

data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])


  2%|▏         | 23/992 [00:00<00:18, 51.36it/s]

data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([394])
data_dict torch.Size([394])
data_dict_ori torch.Size([376])
data_dict torch.Size([376])
data_dict_ori torch.Size([385])
data_dict torch.Size([385])


  3%|▎         | 29/992 [00:00<00:18, 52.18it/s]

data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([425])
data_dict torch.Size([425])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([463])
data_dict torch.Size([463])


  4%|▎         | 35/992 [00:00<00:18, 52.86it/s]

data_dict_ori torch.Size([464])
data_dict torch.Size([464])
data_dict_ori torch.Size([386])
data_dict torch.Size([386])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([436])
data_dict torch.Size([436])


  4%|▍         | 41/992 [00:00<00:18, 52.61it/s]

data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([454])
data_dict torch.Size([454])


  5%|▍         | 47/992 [00:00<00:18, 50.15it/s]

data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([387])
data_dict torch.Size([387])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([443])
data_dict torch.Size([443])
data_dict_ori torch.Size([381])
data_dict torch.Size([381])
data_dict_ori torch.Size([418])
data_dict torch.Size([418])


  5%|▌         | 53/992 [00:01<00:18, 51.13it/s]

data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])


  6%|▌         | 59/992 [00:01<00:18, 51.34it/s]

data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])


  7%|▋         | 65/992 [00:01<00:18, 50.44it/s]

data_dict_ori torch.Size([426])
data_dict torch.Size([426])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([453])
data_dict torch.Size([453])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([429])
data_dict torch.Size([429])


  7%|▋         | 71/992 [00:01<00:18, 50.93it/s]

data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])


  8%|▊         | 77/992 [00:01<00:17, 51.31it/s]

data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([482])
data_dict torch.Size([482])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([426])
data_dict torch.Size([426])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])


  8%|▊         | 83/992 [00:01<00:17, 51.38it/s]

data_dict_ori torch.Size([454])
data_dict torch.Size([454])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([464])
data_dict torch.Size([464])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([384])
data_dict torch.Size([384])


  9%|▉         | 89/992 [00:01<00:17, 50.56it/s]

data_dict_ori torch.Size([425])
data_dict torch.Size([425])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([444])
data_dict torch.Size([444])
data_dict_ori torch.Size([456])
data_dict torch.Size([456])
data_dict_ori torch.Size([392])
data_dict torch.Size([392])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])


 10%|▉         | 95/992 [00:01<00:17, 50.62it/s]

data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([487])
data_dict torch.Size([487])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])


 10%|█         | 101/992 [00:01<00:17, 50.43it/s]

data_dict_ori torch.Size([429])
data_dict torch.Size([429])
data_dict_ori torch.Size([429])
data_dict torch.Size([429])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([430])
data_dict torch.Size([430])
data_dict_ori torch.Size([437])
data_dict torch.Size([437])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])


 11%|█         | 107/992 [00:02<00:17, 51.04it/s]

data_dict_ori torch.Size([449])
data_dict torch.Size([449])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([478])
data_dict torch.Size([478])


 11%|█▏        | 113/992 [00:02<00:17, 51.22it/s]

data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([382])
data_dict torch.Size([382])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])
data_dict_ori torch.Size([439])
data_dict torch.Size([439])


 12%|█▏        | 119/992 [00:02<00:16, 51.99it/s]

data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([446])
data_dict torch.Size([446])
data_dict_ori torch.Size([390])
data_dict torch.Size([390])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])


 13%|█▎        | 125/992 [00:02<00:16, 52.25it/s]

data_dict_ori torch.Size([446])
data_dict torch.Size([446])
data_dict_ori torch.Size([457])
data_dict torch.Size([457])
data_dict_ori torch.Size([486])
data_dict torch.Size([486])
data_dict_ori torch.Size([385])
data_dict torch.Size([385])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])


 13%|█▎        | 131/992 [00:02<00:16, 52.20it/s]

data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([436])
data_dict torch.Size([436])
data_dict_ori torch.Size([450])
data_dict torch.Size([450])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([392])
data_dict torch.Size([392])
data_dict_ori torch.Size([382])
data_dict torch.Size([382])


 14%|█▍        | 137/992 [00:02<00:16, 51.85it/s]

data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])


 14%|█▍        | 143/992 [00:02<00:16, 51.79it/s]

data_dict_ori torch.Size([459])
data_dict torch.Size([459])
data_dict_ori torch.Size([445])
data_dict torch.Size([445])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([444])
data_dict torch.Size([444])
data_dict_ori torch.Size([450])
data_dict torch.Size([450])


 15%|█▌        | 149/992 [00:02<00:17, 49.34it/s]

data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([442])
data_dict torch.Size([442])
data_dict_ori torch.Size([394])
data_dict torch.Size([394])


 16%|█▌        | 154/992 [00:03<00:16, 49.30it/s]

data_dict_ori torch.Size([450])
data_dict torch.Size([450])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([418])
data_dict torch.Size([418])
data_dict_ori torch.Size([445])
data_dict torch.Size([445])
data_dict_ori torch.Size([390])
data_dict torch.Size([390])


 16%|█▌        | 159/992 [00:03<00:16, 49.06it/s]

data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([393])
data_dict torch.Size([393])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([551])
data_dict torch.Size([551])


 17%|█▋        | 165/992 [00:03<00:16, 49.46it/s]

data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])
data_dict_ori torch.Size([444])
data_dict torch.Size([444])


 17%|█▋        | 171/992 [00:03<00:16, 49.45it/s]

data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([451])
data_dict torch.Size([451])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])


 18%|█▊        | 177/992 [00:03<00:16, 49.83it/s]

data_dict_ori torch.Size([418])
data_dict torch.Size([418])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([521])
data_dict torch.Size([521])
data_dict_ori torch.Size([442])
data_dict torch.Size([442])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])


 18%|█▊        | 183/992 [00:03<00:16, 50.10it/s]

data_dict_ori torch.Size([445])
data_dict torch.Size([445])
data_dict_ori torch.Size([373])
data_dict torch.Size([373])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([529])
data_dict torch.Size([529])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([474])
data_dict torch.Size([474])


 19%|█▉        | 189/992 [00:03<00:16, 49.82it/s]

data_dict_ori torch.Size([425])
data_dict torch.Size([425])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])


 20%|█▉        | 195/992 [00:03<00:15, 50.50it/s]

data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([445])
data_dict torch.Size([445])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([429])
data_dict torch.Size([429])
data_dict_ori torch.Size([445])
data_dict torch.Size([445])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([548])
data_dict torch.Size([548])


 20%|██        | 201/992 [00:03<00:15, 50.21it/s]

data_dict_ori torch.Size([429])
data_dict torch.Size([429])
data_dict_ori torch.Size([437])
data_dict torch.Size([437])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([384])
data_dict torch.Size([384])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])


 21%|██        | 207/992 [00:04<00:15, 49.34it/s]

data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([384])
data_dict torch.Size([384])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([458])
data_dict torch.Size([458])


 21%|██▏       | 212/992 [00:04<00:15, 49.27it/s]

data_dict_ori torch.Size([442])
data_dict torch.Size([442])
data_dict_ori torch.Size([439])
data_dict torch.Size([439])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])


 22%|██▏       | 218/992 [00:04<00:15, 49.67it/s]

data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([434])
data_dict torch.Size([434])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])


 22%|██▏       | 223/992 [00:04<00:15, 49.36it/s]

data_dict_ori torch.Size([506])
data_dict torch.Size([506])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([394])
data_dict torch.Size([394])
data_dict_ori torch.Size([430])
data_dict torch.Size([430])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])


 23%|██▎       | 229/992 [00:04<00:15, 50.06it/s]

data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([418])
data_dict torch.Size([418])
data_dict_ori torch.Size([445])
data_dict torch.Size([445])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([376])
data_dict torch.Size([376])


 24%|██▎       | 235/992 [00:04<00:14, 50.84it/s]

data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([450])
data_dict torch.Size([450])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])


 24%|██▍       | 241/992 [00:04<00:14, 50.16it/s]

data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([455])
data_dict torch.Size([455])
data_dict_ori torch.Size([482])
data_dict torch.Size([482])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])


 25%|██▍       | 247/992 [00:04<00:14, 51.48it/s]

data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([425])
data_dict torch.Size([425])
data_dict_ori torch.Size([439])
data_dict torch.Size([439])
data_dict_ori torch.Size([430])
data_dict torch.Size([430])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([495])
data_dict torch.Size([495])


 26%|██▌       | 253/992 [00:05<00:15, 48.74it/s]

data_dict_ori torch.Size([451])
data_dict torch.Size([451])
data_dict_ori torch.Size([527])
data_dict torch.Size([527])


 26%|██▌       | 259/992 [00:05<00:14, 49.92it/s]

data_dict_ori torch.Size([451])
data_dict torch.Size([451])
data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])
data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([384])
data_dict torch.Size([384])
data_dict_ori torch.Size([382])
data_dict torch.Size([382])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([483])
data_dict torch.Size([483])
data_dict_ori torch.Size([471])
data_dict torch.Size([471])


 27%|██▋       | 265/992 [00:05<00:14, 50.24it/s]

data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([372])
data_dict torch.Size([372])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([422])
data_dict torch.Size([422])


 27%|██▋       | 271/992 [00:05<00:14, 50.53it/s]

data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([439])
data_dict torch.Size([439])


 28%|██▊       | 277/992 [00:05<00:14, 50.20it/s]

data_dict_ori torch.Size([438])
data_dict torch.Size([438])
data_dict_ori torch.Size([460])
data_dict torch.Size([460])
data_dict_ori torch.Size([479])
data_dict torch.Size([479])
data_dict_ori torch.Size([457])
data_dict torch.Size([457])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])


 29%|██▊       | 283/992 [00:05<00:14, 49.78it/s]

data_dict_ori torch.Size([393])
data_dict torch.Size([393])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([425])
data_dict torch.Size([425])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])


 29%|██▉       | 289/992 [00:05<00:13, 50.62it/s]

data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([441])
data_dict torch.Size([441])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])
data_dict_ori torch.Size([386])
data_dict torch.Size([386])
data_dict_ori torch.Size([390])
data_dict torch.Size([390])
data_dict_ori torch.Size([445])
data_dict torch.Size([445])


 30%|██▉       | 295/992 [00:05<00:13, 51.50it/s]

data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([452])
data_dict torch.Size([452])


 30%|███       | 301/992 [00:05<00:13, 51.14it/s]

data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([441])
data_dict torch.Size([441])
data_dict_ori torch.Size([437])
data_dict torch.Size([437])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([469])
data_dict torch.Size([469])
data_dict_ori torch.Size([387])
data_dict torch.Size([387])


 31%|███▏      | 312/992 [00:06<00:13, 49.64it/s]

data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([434])
data_dict torch.Size([434])
data_dict_ori torch.Size([456])
data_dict torch.Size([456])
data_dict_ori torch.Size([440])
data_dict torch.Size([440])
data_dict_ori torch.Size([451])
data_dict torch.Size([451])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([480])
data_dict torch.Size([480])
data_dict_ori torch.Size([445])
data_dict torch.Size([445])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])
data_dict_ori torch.Size([442])
data_dict torch.Size([442])


 32%|███▏      | 322/992 [00:06<00:14, 47.64it/s]

data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([469])
data_dict torch.Size([469])
data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([462])
data_dict torch.Size([462])
data_dict_ori torch.Size([472])
data_dict torch.Size([472])
data_dict_ori torch.Size([475])
data_dict torch.Size([475])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])


 33%|███▎      | 332/992 [00:06<00:13, 47.38it/s]

data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([467])
data_dict torch.Size([467])
data_dict_ori torch.Size([487])
data_dict torch.Size([487])
data_dict_ori torch.Size([531])
data_dict torch.Size([531])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])


 34%|███▍      | 342/992 [00:06<00:13, 47.33it/s]

data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([496])
data_dict torch.Size([496])
data_dict_ori torch.Size([467])
data_dict torch.Size([467])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([472])
data_dict torch.Size([472])
data_dict_ori torch.Size([477])
data_dict torch.Size([477])
data_dict_ori torch.Size([454])
data_dict torch.Size([454])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([450])
data_dict torch.Size([450])
data_dict_ori torch.Size([382])
data_dict torch.Size([382])


 35%|███▌      | 352/992 [00:07<00:13, 47.90it/s]

data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([554])
data_dict torch.Size([554])
data_dict_ori torch.Size([429])
data_dict torch.Size([429])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])
data_dict_ori torch.Size([456])
data_dict torch.Size([456])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([446])
data_dict torch.Size([446])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])


 36%|███▌      | 357/992 [00:07<00:13, 47.82it/s]

data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([393])
data_dict torch.Size([393])
data_dict_ori torch.Size([490])
data_dict torch.Size([490])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])


 37%|███▋      | 363/992 [00:07<00:12, 48.44it/s]

data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([429])
data_dict torch.Size([429])
data_dict_ori torch.Size([377])
data_dict torch.Size([377])
data_dict_ori torch.Size([491])
data_dict torch.Size([491])


 37%|███▋      | 368/992 [00:07<00:13, 47.95it/s]

data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([430])
data_dict torch.Size([430])
data_dict_ori torch.Size([515])
data_dict torch.Size([515])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])


 38%|███▊      | 373/992 [00:07<00:12, 48.08it/s]

data_dict_ori torch.Size([392])
data_dict torch.Size([392])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([425])
data_dict torch.Size([425])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])


 38%|███▊      | 379/992 [00:07<00:12, 48.87it/s]

data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])
data_dict_ori torch.Size([418])
data_dict torch.Size([418])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])


 39%|███▉      | 385/992 [00:07<00:12, 49.71it/s]

data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([459])
data_dict torch.Size([459])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])


 39%|███▉      | 391/992 [00:07<00:11, 50.76it/s]

data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([369])
data_dict torch.Size([369])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([511])
data_dict torch.Size([511])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])


 40%|████      | 397/992 [00:07<00:11, 51.10it/s]

data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([487])
data_dict torch.Size([487])


 41%|████      | 403/992 [00:08<00:11, 51.15it/s]

data_dict_ori torch.Size([452])
data_dict torch.Size([452])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([386])
data_dict torch.Size([386])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])
data_dict_ori torch.Size([426])
data_dict torch.Size([426])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])


 41%|████      | 409/992 [00:08<00:11, 50.34it/s]

data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([383])
data_dict torch.Size([383])


 42%|████▏     | 415/992 [00:08<00:11, 49.47it/s]

data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([518])
data_dict torch.Size([518])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([418])
data_dict torch.Size([418])
data_dict_ori torch.Size([436])
data_dict torch.Size([436])
data_dict_ori torch.Size([461])
data_dict torch.Size([461])
data_dict_ori torch.Size([473])
data_dict torch.Size([473])
data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])


 42%|████▏     | 420/992 [00:08<00:11, 48.65it/s]

data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([454])
data_dict torch.Size([454])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([434])
data_dict torch.Size([434])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])


 43%|████▎     | 426/992 [00:08<00:11, 49.53it/s]

data_dict_ori torch.Size([503])
data_dict torch.Size([503])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([439])
data_dict torch.Size([439])
data_dict_ori torch.Size([376])
data_dict torch.Size([376])
data_dict_ori torch.Size([444])
data_dict torch.Size([444])


 43%|████▎     | 431/992 [00:08<00:11, 49.03it/s]

data_dict_ori torch.Size([449])
data_dict torch.Size([449])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([462])
data_dict torch.Size([462])
data_dict_ori torch.Size([425])
data_dict torch.Size([425])


 44%|████▍     | 437/992 [00:08<00:11, 50.04it/s]

data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([466])
data_dict torch.Size([466])


 45%|████▍     | 443/992 [00:08<00:10, 50.32it/s]

data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([478])
data_dict torch.Size([478])
data_dict_ori torch.Size([375])
data_dict torch.Size([375])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([453])
data_dict torch.Size([453])
data_dict_ori torch.Size([426])
data_dict torch.Size([426])


 45%|████▌     | 449/992 [00:08<00:10, 49.90it/s]

data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([394])
data_dict torch.Size([394])
data_dict_ori torch.Size([447])
data_dict torch.Size([447])


 46%|████▌     | 454/992 [00:09<00:10, 49.65it/s]

data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([443])
data_dict torch.Size([443])
data_dict_ori torch.Size([384])
data_dict torch.Size([384])
data_dict_ori torch.Size([538])
data_dict torch.Size([538])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])


 46%|████▋     | 460/992 [00:09<00:10, 49.90it/s]

data_dict_ori torch.Size([587])
data_dict torch.Size([587])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])


 47%|████▋     | 466/992 [00:09<00:10, 50.84it/s]

data_dict_ori torch.Size([398])
data_dict torch.Size([398])
data_dict_ori torch.Size([430])
data_dict torch.Size([430])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([436])
data_dict torch.Size([436])
data_dict_ori torch.Size([482])
data_dict torch.Size([482])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])


 48%|████▊     | 472/992 [00:09<00:10, 49.06it/s]

data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])


 48%|████▊     | 477/992 [00:09<00:10, 49.09it/s]

data_dict_ori torch.Size([426])
data_dict torch.Size([426])
data_dict_ori torch.Size([468])
data_dict torch.Size([468])
data_dict_ori torch.Size([505])
data_dict torch.Size([505])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([390])
data_dict torch.Size([390])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([476])
data_dict torch.Size([476])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])


 49%|████▉     | 489/992 [00:09<00:10, 50.16it/s]

data_dict_ori torch.Size([472])
data_dict torch.Size([472])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([446])
data_dict torch.Size([446])
data_dict_ori torch.Size([464])
data_dict torch.Size([464])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([447])
data_dict torch.Size([447])
data_dict_ori torch.Size([475])
data_dict torch.Size([475])
data_dict_ori torch.Size([454])
data_dict torch.Size([454])
data_dict_ori torch.Size([426])
data_dict torch.Size([426])


 50%|████▉     | 495/992 [00:09<00:10, 49.15it/s]

data_dict_ori torch.Size([434])
data_dict torch.Size([434])
data_dict_ori torch.Size([441])
data_dict torch.Size([441])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])
data_dict_ori torch.Size([392])
data_dict torch.Size([392])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])


 51%|█████     | 501/992 [00:10<00:09, 50.13it/s]

data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([382])
data_dict torch.Size([382])


 51%|█████     | 507/992 [00:10<00:10, 48.05it/s]

data_dict_ori torch.Size([596])
data_dict torch.Size([596])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])
data_dict_ori torch.Size([426])
data_dict torch.Size([426])
data_dict_ori torch.Size([430])
data_dict torch.Size([430])
data_dict_ori torch.Size([501])
data_dict torch.Size([501])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])


 52%|█████▏    | 513/992 [00:10<00:09, 48.94it/s]

data_dict_ori torch.Size([504])
data_dict torch.Size([504])


 52%|█████▏    | 519/992 [00:10<00:09, 49.53it/s]

data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([447])
data_dict torch.Size([447])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([440])
data_dict torch.Size([440])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([381])
data_dict torch.Size([381])
data_dict_ori torch.Size([491])
data_dict torch.Size([491])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])


 53%|█████▎    | 524/992 [00:10<00:09, 49.64it/s]

data_dict_ori torch.Size([417])
data_dict torch.Size([417])


 53%|█████▎    | 530/992 [00:10<00:09, 51.15it/s]

data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([482])
data_dict torch.Size([482])
data_dict_ori torch.Size([434])
data_dict torch.Size([434])
data_dict_ori torch.Size([387])
data_dict torch.Size([387])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([496])
data_dict torch.Size([496])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([429])
data_dict torch.Size([429])
data_dict_ori torch.Size([425])
data_dict torch.Size([425])


 55%|█████▍    | 542/992 [00:10<00:09, 49.57it/s]

data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([448])
data_dict torch.Size([448])
data_dict_ori torch.Size([383])
data_dict torch.Size([383])
data_dict_ori torch.Size([390])
data_dict torch.Size([390])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([454])
data_dict torch.Size([454])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])


 56%|█████▌    | 552/992 [00:11<00:08, 49.21it/s]

data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([509])
data_dict torch.Size([509])
data_dict_ori torch.Size([455])
data_dict torch.Size([455])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([514])
data_dict torch.Size([514])
data_dict_ori torch.Size([392])
data_dict torch.Size([392])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])
data_dict_ori torch.Size([469])
data_dict torch.Size([469])


 57%|█████▋    | 563/992 [00:11<00:08, 49.68it/s]

data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([506])
data_dict torch.Size([506])
data_dict_ori torch.Size([439])
data_dict torch.Size([439])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([462])
data_dict torch.Size([462])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])


 57%|█████▋    | 568/992 [00:11<00:08, 48.68it/s]

data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([385])
data_dict torch.Size([385])
data_dict_ori torch.Size([393])
data_dict torch.Size([393])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([392])
data_dict torch.Size([392])
data_dict_ori torch.Size([436])
data_dict torch.Size([436])
data_dict_ori torch.Size([434])
data_dict torch.Size([434])


 58%|█████▊    | 574/992 [00:11<00:08, 49.41it/s]

data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori

 58%|█████▊    | 580/992 [00:11<00:08, 48.18it/s]

 torch.Size([394])
data_dict torch.Size([394])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])
image 000031174.png are broken or grayscale! we thus select 0-th sample instead!
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([448])
data_dict torch.Size([448])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([469])
data_dict torch.Size([469])


 60%|█████▉    | 591/992 [00:11<00:08, 49.42it/s]

data_dict_ori torch.Size([437])
data_dict torch.Size([437])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([365])
data_dict torch.Size([365])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])
data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([482])
data_dict torch.Size([482])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])


 61%|██████    | 603/992 [00:12<00:07, 50.86it/s]

data_dict_ori torch.Size([443])
data_dict torch.Size([443])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([531])
data_dict torch.Size([531])
data_dict_ori torch.Size([394])
data_dict torch.Size([394])
data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([476])
data_dict torch.Size([476])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])


 62%|██████▏   | 615/992 [00:12<00:07, 50.23it/s]

data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([476])
data_dict torch.Size([476])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])
data_dict_ori torch.Size([474])
data_dict torch.Size([474])
data_dict_ori torch.Size([442])
data_dict torch.Size([442])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])


 63%|██████▎   | 621/992 [00:12<00:07, 50.52it/s]

data_dict_ori torch.Size([443])
data_dict torch.Size([443])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([479])
data_dict torch.Size([479])
data_dict_ori torch.Size([478])
data_dict torch.Size([478])
data_dict_ori torch.Size([448])
data_dict torch.Size([448])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])


 63%|██████▎   | 627/992 [00:12<00:07, 50.34it/s]

data_dict_ori torch.Size([430])
data_dict torch.Size([430])
data_dict_ori torch.Size([488])
data_dict torch.Size([488])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])


 64%|██████▍   | 633/992 [00:12<00:07, 49.67it/s]

data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([386])
data_dict torch.Size([386])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])


 64%|██████▍   | 639/992 [00:12<00:07, 50.18it/s]

data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])


 65%|██████▌   | 645/992 [00:12<00:07, 48.47it/s]

data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([458])
data_dict torch.Size([458])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([446])
data_dict torch.Size([446])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])


 66%|██████▌   | 657/992 [00:13<00:06, 50.63it/s]

data_dict_ori torch.Size([441])
data_dict torch.Size([441])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([444])
data_dict torch.Size([444])
data_dict_ori torch.Size([451])
data_dict torch.Size([451])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([448])
data_dict torch.Size([448])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])


 67%|██████▋   | 663/992 [00:13<00:06, 50.74it/s]

data_dict_ori torch.Size([466])
data_dict torch.Size([466])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([471])
data_dict torch.Size([471])
data_dict_ori torch.Size([518])
data_dict torch.Size([518])
data_dict_ori torch.Size([394])
data_dict torch.Size([394])


 67%|██████▋   | 669/992 [00:13<00:06, 50.50it/s]

data_dict_ori torch.Size([475])
data_dict torch.Size([475])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([434])
data_dict torch.Size([434])
data_dict_ori torch.Size([442])
data_dict torch.Size([442])


 68%|██████▊   | 675/992 [00:13<00:06, 49.48it/s]

data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([471])
data_dict torch.Size([471])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([456])
data_dict torch.Size([456])
data_dict_ori torch.Size([583])
data_dict torch.Size([583])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])


 69%|██████▊   | 681/992 [00:13<00:06, 50.41it/s]

data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([454])
data_dict torch.Size([454])


 69%|██████▉   | 687/992 [00:13<00:05, 51.02it/s]

data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([421])
data_dict torch.Size([421])


 70%|██████▉   | 693/992 [00:13<00:06, 49.46it/s]

data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([501])
data_dict torch.Size([501])


 70%|███████   | 698/992 [00:13<00:05, 49.31it/s]

data_dict_ori torch.Size([428])
data_dict torch.Size([428])
data_dict_ori torch.Size([450])
data_dict torch.Size([450])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])


 71%|███████   | 704/992 [00:14<00:05, 50.34it/s]

data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])


 72%|███████▏  | 710/992 [00:14<00:05, 50.46it/s]

data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([437])
data_dict torch.Size([437])
data_dict_ori torch.Size([451])
data_dict torch.Size([451])
data_dict_ori torch.Size([378])
data_dict torch.Size([378])
data_dict_ori torch.Size([387])
data_dict torch.Size([387])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([467])
data_dict torch.Size([467])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])


 72%|███████▏  | 716/992 [00:14<00:05, 50.40it/s]

data_dict_ori torch.Size([503])
data_dict torch.Size([503])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([494])
data_dict torch.Size([494])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([402])
data_dict torch.Size([402])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])


 73%|███████▎  | 722/992 [00:14<00:05, 51.76it/s]

data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([393])
data_dict torch.Size([393])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([366])
data_dict torch.Size([366])
data_dict_ori torch.Size([449])
data_dict torch.Size([449])


 73%|███████▎  | 728/992 [00:14<00:05, 51.90it/s]

data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([379])
data_dict torch.Size([379])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])


 74%|███████▍  | 734/992 [00:14<00:04, 52.25it/s]

data_dict_ori torch.Size([425])
data_dict torch.Size([425])
data_dict_ori torch.Size([383])
data_dict torch.Size([383])
data_dict_ori torch.Size([441])
data_dict torch.Size([441])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])


 75%|███████▍  | 740/992 [00:14<00:04, 52.38it/s]

data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([552])
data_dict torch.Size([552])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])


 75%|███████▌  | 746/992 [00:14<00:04, 50.62it/s]

data_dict_ori torch.Size([483])
data_dict torch.Size([483])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([511])
data_dict torch.Size([511])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([448])
data_dict torch.Size([448])


 76%|███████▌  | 752/992 [00:15<00:04, 49.18it/s]

data_dict_ori torch.Size([482])
data_dict torch.Size([482])
data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([418])
data_dict torch.Size([418])


 76%|███████▋  | 757/992 [00:15<00:04, 49.04it/s]

data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([454])
data_dict torch.Size([454])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])


 77%|███████▋  | 762/992 [00:15<00:04, 48.96it/s]

data_dict_ori torch.Size([452])
data_dict torch.Size([452])
data_dict_ori torch.Size([470])
data_dict torch.Size([470])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])


 77%|███████▋  | 767/992 [00:15<00:04, 47.56it/s]

data_dict_ori torch.Size([377])
data_dict torch.Size([377])
data_dict_ori torch.Size([480])
data_dict torch.Size([480])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])
data_dict_ori torch.Size([438])
data_dict torch.Size([438])
data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([439])
data_dict torch.Size([439])


 78%|███████▊  | 778/992 [00:15<00:04, 48.68it/s]

data_dict_ori torch.Size([519])
data_dict torch.Size([519])
data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([532])
data_dict torch.Size([532])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([507])
data_dict torch.Size([507])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([384])
data_dict torch.Size([384])


 79%|███████▉  | 784/992 [00:15<00:04, 50.10it/s]

data_dict_ori torch.Size([593])
data_dict torch.Size([593])
data_dict_ori torch.Size([370])
data_dict torch.Size([370])
data_dict_ori torch.Size([462])
data_dict torch.Size([462])
data_dict_ori torch.Size([448])
data_dict torch.Size([448])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])
data_dict_ori torch.Size([446])
data_dict torch.Size([446])


 80%|███████▉  | 790/992 [00:15<00:04, 49.64it/s]

data_dict_ori torch.Size([503])
data_dict torch.Size([503])
data_dict_ori torch.Size([452])
data_dict torch.Size([452])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([383])
data_dict torch.Size([383])


 80%|████████  | 795/992 [00:15<00:03, 49.25it/s]

data_dict_ori torch.Size([484])
data_dict torch.Size([484])
data_dict_ori torch.Size([386])
data_dict torch.Size([386])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([418])
data_dict torch.Size([418])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])


 81%|████████  | 800/992 [00:16<00:03, 48.28it/s]

data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([425])
data_dict torch.Size([425])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])


 81%|████████▏ | 806/992 [00:16<00:03, 48.34it/s]

data_dict_ori torch.Size([390])
data_dict torch.Size([390])
data_dict_ori torch.Size([516])
data_dict torch.Size([516])
data_dict_ori torch.Size([429])
data_dict torch.Size([429])
data_dict_ori torch.Size([393])
data_dict torch.Size([393])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])


 82%|████████▏ | 811/992 [00:16<00:03, 48.37it/s]

data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([371])
data_dict torch.Size([371])
data_dict_ori torch.Size([439])
data_dict torch.Size([439])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])


 82%|████████▏ | 816/992 [00:16<00:03, 47.82it/s]

data_dict_ori torch.Size([373])
data_dict torch.Size([373])
data_dict_ori torch.Size([437])
data_dict torch.Size([437])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([441])
data_dict torch.Size([441])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])


 83%|████████▎ | 821/992 [00:16<00:03, 47.06it/s]

data_dict_ori torch.Size([387])
data_dict torch.Size([387])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([427])
data_dict torch.Size([427])


 83%|████████▎ | 826/992 [00:16<00:03, 47.40it/s]

data_dict_ori torch.Size([480])
data_dict torch.Size([480])
data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([392])
data_dict torch.Size([392])
data_dict_ori torch.Size([430])
data_dict torch.Size([430])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])


 84%|████████▍ | 831/992 [00:16<00:03, 47.44it/s]

data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([448])
data_dict torch.Size([448])
data_dict_ori torch.Size([493])
data_dict torch.Size([493])


 84%|████████▍ | 837/992 [00:16<00:03, 49.18it/s]

data_dict_ori torch.Size([394])
data_dict torch.Size([394])
data_dict_ori torch.Size([387])
data_dict torch.Size([387])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([596])
data_dict torch.Size([596])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])


 85%|████████▍ | 842/992 [00:16<00:03, 48.53it/s]

data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([550])
data_dict torch.Size([550])


 85%|████████▌ | 847/992 [00:17<00:03, 47.85it/s]

data_dict_ori torch.Size([478])
data_dict torch.Size([478])
data_dict_ori torch.Size([512])
data_dict torch.Size([512])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([452])
data_dict torch.Size([452])
data_dict_ori torch.Size([451])
data_dict torch.Size([451])
data_dict_ori torch.Size([488])
data_dict torch.Size([488])


 86%|████████▌ | 852/992 [00:17<00:02, 48.01it/s]

data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([385])
data_dict torch.Size([385])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])


 86%|████████▋ | 857/992 [00:17<00:02, 47.30it/s]

data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([380])
data_dict torch.Size([380])
data_dict_ori torch.Size([384])
data_dict torch.Size([384])
data_dict_ori torch.Size([434])
data_dict torch.Size([434])
data_dict_ori torch.Size([432])
data_dict torch.Size([432])
data_dict_ori torch.Size([410])
data_dict torch.Size([410])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])


 87%|████████▋ | 862/992 [00:17<00:02, 47.94it/s]

data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([456])
data_dict torch.Size([456])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])


 87%|████████▋ | 867/992 [00:17<00:02, 47.44it/s]

data_dict_ori torch.Size([440])
data_dict torch.Size([440])
data_dict_ori torch.Size([425])
data_dict torch.Size([425])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([434])
data_dict torch.Size([434])
data_dict_ori torch.Size([417])
data_dict torch.Size([417])
data_dict_ori torch.Size([460])
data_dict torch.Size([460])


 88%|████████▊ | 873/992 [00:17<00:02, 48.64it/s]

data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])


 89%|████████▊ | 879/992 [00:17<00:02, 49.70it/s]

data_dict_ori torch.Size([457])
data_dict torch.Size([457])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([440])
data_dict torch.Size([440])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])


 89%|████████▉ | 884/992 [00:17<00:02, 49.08it/s]

data_dict_ori torch.Size([454])
data_dict torch.Size([454])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])


 90%|████████▉ | 890/992 [00:17<00:02, 50.35it/s]

data_dict_ori torch.Size([435])
data_dict torch.Size([435])
data_dict_ori torch.Size([444])
data_dict torch.Size([444])
data_dict_ori torch.Size([419])
data_dict torch.Size([419])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([471])
data_dict torch.Size([471])
data_dict_ori torch.Size([446])
data_dict torch.Size([446])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([383])
data_dict torch.Size([383])
data_dict_ori torch.Size([441])
data_dict torch.Size([441])


 91%|█████████ | 902/992 [00:18<00:01, 50.77it/s]

data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([478])
data_dict torch.Size([478])
data_dict_ori torch.Size([399])
data_dict torch.Size([399])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([435])
data_dict torch.Size([435])
data_dict_ori torch.Size([479])
data_dict torch.Size([479])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])


 92%|█████████▏| 914/992 [00:18<00:01, 50.95it/s]

data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([390])
data_dict torch.Size([390])
data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([460])
data_dict torch.Size([460])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([411])
data_dict torch.Size([411])


 93%|█████████▎| 920/992 [00:18<00:01, 50.31it/s]

data_dict_ori torch.Size([446])
data_dict torch.Size([446])
data_dict_ori torch.Size([394])
data_dict torch.Size([394])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([420])
data_dict torch.Size([420])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])
data_dict_ori torch.Size([385])
data_dict torch.Size([385])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([533])
data_dict torch.Size([533])
data_dict_ori torch.Size([394])
data_dict torch.Size([394])


 93%|█████████▎| 926/992 [00:18<00:01, 50.59it/s]

data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])


 94%|█████████▍| 932/992 [00:18<00:01, 50.62it/s]

data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([471])
data_dict torch.Size([471])
data_dict_ori torch.Size([448])
data_dict torch.Size([448])
data_dict_ori torch.Size([409])
data_dict torch.Size([409])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([451])
data_dict torch.Size([451])
data_dict_ori torch.Size([408])
data_dict torch.Size([408])


 95%|█████████▍| 938/992 [00:18<00:01, 50.13it/s]

data_dict_ori torch.Size([423])
data_dict torch.Size([423])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])


 95%|█████████▌| 944/992 [00:18<00:00, 50.53it/s]

data_dict_ori torch.Size([406])
data_dict torch.Size([406])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])
data_dict_ori torch.Size([388])
data_dict torch.Size([388])
data_dict_ori torch.Size([391])
data_dict torch.Size([391])
data_dict_ori torch.Size([396])
data_dict torch.Size([396])
data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([457])
data_dict torch.Size([457])
data_dict_ori torch.Size([461])
data_dict torch.Size([461])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])


 96%|█████████▋| 956/992 [00:19<00:00, 50.44it/s]

data_dict_ori torch.Size([435])
data_dict torch.Size([435])
data_dict_ori torch.Size([434])
data_dict torch.Size([434])
data_dict_ori torch.Size([381])
data_dict torch.Size([381])
data_dict_ori torch.Size([443])
data_dict torch.Size([443])
data_dict_ori torch.Size([444])
data_dict torch.Size([444])
data_dict_ori torch.Size([404])
data_dict torch.Size([404])
data_dict_ori torch.Size([412])
data_dict torch.Size([412])
data_dict_ori torch.Size([418])
data_dict torch.Size([418])
data_dict_ori torch.Size([466])
data_dict torch.Size([466])
data_dict_ori torch.Size([398])
data_dict torch.Size([398])


 97%|█████████▋| 962/992 [00:19<00:00, 48.82it/s]

data_dict_ori torch.Size([424])
data_dict torch.Size([424])
data_dict_ori torch.Size([401])
data_dict torch.Size([401])
data_dict_ori torch.Size([541])
data_dict torch.Size([541])
data_dict_ori torch.Size([446])
data_dict torch.Size([446])
data_dict_ori torch.Size([403])
data_dict torch.Size([403])
data_dict_ori torch.Size([414])
data_dict torch.Size([414])
data_dict_ori torch.Size([468])
data_dict torch.Size([468])


 97%|█████████▋| 967/992 [00:19<00:00, 47.70it/s]

data_dict_ori torch.Size([407])
data_dict torch.Size([407])
data_dict_ori torch.Size([523])
data_dict torch.Size([523])


 98%|█████████▊| 972/992 [00:19<00:00, 48.16it/s]

data_dict_ori torch.Size([422])
data_dict torch.Size([422])
data_dict_ori torch.Size([397])
data_dict torch.Size([397])
data_dict_ori torch.Size([431])
data_dict torch.Size([431])
data_dict_ori torch.Size([416])
data_dict torch.Size([416])
data_dict_ori torch.Size([386])
data_dict torch.Size([386])
data_dict_ori torch.Size([389])
data_dict torch.Size([389])
data_dict_ori torch.Size([415])
data_dict torch.Size([415])
data_dict_ori torch.Size([433])
data_dict torch.Size([433])
data_dict_ori torch.Size([425])
data_dict torch.Size([425])


 99%|█████████▊| 978/992 [00:19<00:00, 49.58it/s]

data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([465])
data_dict torch.Size([465])


 99%|█████████▉| 984/992 [00:19<00:00, 49.93it/s]

data_dict_ori torch.Size([413])
data_dict torch.Size([413])
data_dict_ori torch.Size([504])
data_dict torch.Size([504])
data_dict_ori torch.Size([405])
data_dict torch.Size([405])
data_dict_ori torch.Size([400])
data_dict torch.Size([400])
data_dict_ori torch.Size([448])
data_dict torch.Size([448])
data_dict_ori torch.Size([395])
data_dict torch.Size([395])
data_dict_ori torch.Size([418])
data_dict torch.Size([418])
data_dict_ori torch.Size([392])
data_dict torch.Size([392])


100%|█████████▉| 989/992 [00:19<00:00, 48.58it/s]

data_dict_ori torch.Size([500])
data_dict torch.Size([500])
data_dict_ori torch.Size([407])
data_dict torch.Size([407])


100%|██████████| 992/992 [00:19<00:00, 49.77it/s]

data_dict_ori torch.Size([436])
data_dict torch.Size([436])
data_dict_ori torch.Size([390])
data_dict torch.Size([390])
data_dict_ori torch.Size([428])
data_dict torch.Size([428])
input_ids:
  Unique shapes/lengths: {512, 514, 515, 516, 518, 519, 521, 523, 527, 529, 531, 532, 533, 538, 541, 548, 550, 551, 552, 554, 583, 587, 593, 596, 365, 366, 369, 370, 371, 372, 373, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 486, 487, 488, 490, 491, 493, 494, 495, 496, 500, 501, 503, 504, 505, 506, 507, 509




: 

In [None]:
# from collections import defaultdict
# import torch
# from tqdm import tqdm

# # 初始化记录字典，保存长度和对应index
# field_max = defaultdict(lambda: {'max_len': -1, 'index': -1})

# for idx, example in enumerate(tqdm(train_dataset)):
#     for k, v in example.items():
#         length = None
#         if isinstance(v, torch.Tensor):
#             length = v.shape[0]
#         elif isinstance(v, list) or isinstance(v, str):
#             length = len(v)
#         else:
#             continue

#         if length > field_max[k]['max_len']:
#             field_max[k]['max_len'] = length
#             field_max[k]['index'] = idx

# print("字段最大长度及对应样本index:")
# for k, info in field_max.items():
#     print(f"{k}: max_len={info['max_len']}, index={info['index']}")

# # 以某个字段为例，保存最大长度样本
# key_of_interest = 'input_ids'  # 比如你想要最长input_ids的样本
# max_index = field_max[key_of_interest]['index']
# longest_sample = train_dataset[max_index]

# # longest_sample 就是你想保存的最大样本
# # 你可以保存为json，或torch保存，根据需求处理

In [None]:
# import torch

# torch.save(longest_sample, "longest_sample.pt")
longest_sample = torch.load("longest_sample.pt")

In [None]:
# for k, v in longest_sample.items():
#     if isinstance(v, torch.Tensor):
#         # 确保v至少有一维且长度大于600才截断
#         if v.size(0) > 600:
#             longest_sample[k] = v[:600]
#         print(f"{k}: {longest_sample[k].shape}")
#     elif isinstance(v, list) or isinstance(v, str):
#         print(f"{k}: {len(v)}")
#     else:
#         print(f"{k}: {type(v)}")

In [None]:
longest_sample

{'input_ids': tensor([151644,   8948,    198,   2610,   1265,   1795,    279,  11221,  15516,
            323,  10339,    697,  11253,    304,   7716,     13, 151645, 151644,
            872,    198, 151857, 151859, 151859, 151859, 151859, 151859, 151859,
         151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
         151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
         151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
         151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
         151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
         151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
         151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
         151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
         151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859, 151859,
         151859

In [None]:
for k,v in longest_sample.items():
    if isinstance(v, torch.Tensor):
        print(f"{k}: {v.shape}")
    elif isinstance(v, list) or isinstance(v, str):
        print(f"{k}: {len(v)}")
    else:
        print(f"{k}: {type(v)}")  # 打印其他类型的字段

input_ids: torch.Size([919])
labels: torch.Size([919])
image: 1
image_high: 1


In [None]:
# import matplotlib.pyplot as plt
# from collections import Counter

# for k, v_list in field_shapes.items():
#     print(f"Plotting field: {k}")

#     # 判断v_list中元素类型
#     if all(isinstance(x, int) for x in v_list):
#         # 连续数值，用直方图
#         plt.figure(figsize=(8, 4))
#         plt.hist(v_list, bins=30, color='skyblue', edgecolor='black')
#         plt.title(f"{k} length distribution")
#         plt.xlabel("Length")
#         plt.ylabel("Frequency")
#         plt.grid(True, linestyle='--', alpha=0.5)
#         plt.show()

#     elif all(isinstance(x, tuple) for x in v_list):
#         # 形状分布，统计频次并画条形图
#         shape_counts = Counter(v_list)
#         shapes = list(shape_counts.keys())
#         counts = list(shape_counts.values())

#         # 把tuple转成字符串方便显示
#         shapes_str = [str(s) for s in shapes]

#         plt.figure(figsize=(10, 5))
#         plt.bar(range(len(counts)), counts,
#                 color='lightcoral', edgecolor='black')
#         plt.xticks(range(len(counts)), shapes_str, rotation=45, ha='right')
#         plt.title(f"{k} shape distribution")
#         plt.xlabel("Shape")
#         plt.ylabel("Frequency")
#         plt.tight_layout()
#         plt.show()

#     else:
#         print(f"Skipped field {k} with unsupported type for plotting.")

In [None]:
import torch
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# 假设 longest_sample 是字典，包含键 'input_ids', 'attention_mask', 'labels', 'images'
# input_ids, attention_mask, labels是tensor，images是list或tensor
# batch_size = 16


def make_batch(sample, batch_size):
    batch = {}
    for k, v in sample.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.unsqueeze(0).repeat(batch_size, *([1] * (v.dim())))
        elif isinstance(v, list):
            # 假设 images 是 list of tensors，简单重复list内容
            batch[k] = v * batch_size  # 复制列表，batch_size倍
        else:
            # 其他情况，简单复制
            batch[k] = [v] * batch_size
    return batch

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import torch
from transformers import default_data_collator
from GOT.data import DataCollatorForSupervisedDataset
# 模拟 batch_size = 16
batch_size = 6
# 构造 batch 的输入样本列表
features = [longest_sample] * batch_size  # 复制样本

# 调用默认 collator（会处理 input_ids, attention_mask, labels）
collator= DataCollatorForSupervisedDataset(tokenizer=tokenizer)
batch = collator(features)


In [None]:
# 取一个 batch
# batch = next(iter(dataloader))  # 或者: for batch in dataloader:

# 把数据移动到 GPU（确保模型也在 CUDA 上）
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)

# 图像可能需要特殊处理（比如 list of tensors 或 Tensor）
images = batch["images"]
images = [(item[0].to(device), item[1].to(device))
          for item in images]
# if isinstance(images, list):
#     images = torch.stack(images)
# images = images.to(device)

Using device: cuda


In [None]:
# model.float()

In [None]:
print(input_ids.device)
print(attention_mask.device)
print(labels.device)
print(images[0][1].device)
print(next(model.parameters()).device)

cuda:0
cuda:0
cuda:0
cuda:0
cuda:0


In [None]:
# 清空优化器梯度
optimizer.zero_grad()

# 显存前
print("显存占用（训练前）:", torch.cuda.memory_reserved() / 1024**2, "MB")
print(torch.cuda.memory_summary())

显存占用（训练前）: 2268.0 MB
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   2183 MiB |   2214 MiB |   2551 MiB | 377423 KiB |
|       from large pool |   2181 MiB |   2212 MiB |   2548 MiB | 375808 KiB |
|       from small pool |      1 MiB |      2 MiB |      2 MiB |   1615 KiB |
|---------------------------------------------------------------------------|
| Active memory         |   2183 MiB |   2214 MiB |   2551 MiB | 377423 KiB |
|       from large pool |   2181 MiB |   2212 MiB |   2548 MiB | 375808 KiB |
|       from small pool |      1 MiB |      2 MiB |      2 MiB |   1615 KiB |
|------------------------------------------

In [None]:




with torch.cuda.amp.autocast():
    outputs = model(input_ids=input_ids,
                    attention_mask=attention_mask, labels=labels, images=images)
    print("outputs:", outputs)
    loss = outputs.loss
# loss = outputs.loss
print("Loss:", loss.item())

# 反向传播 + 更新参数
loss.backward()
optimizer.step()

# 显存后
print("显存占用（训练后）:", torch.cuda.memory_allocated() / 1024**2, "MB")


  with torch.cuda.amp.autocast():


outputs: CausalLMOutputWithPast(loss=tensor(0.4377, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[[ 5.2305,  4.3984,  5.8164,  ..., -0.9219, -1.6182, -1.7695],
         [ 5.4609,  4.6758,  6.1484,  ..., -0.7266, -1.4258, -1.7080],
         [ 4.6328,  3.7129,  5.3047,  ..., -1.3828, -2.0488, -1.7646],
         ...,
         [20.1719, 16.2500, 17.0000,  ...,  7.8086, 14.9688, 11.0938],
         [17.0625, 18.3438, 21.6875,  ..., 11.2422, 18.5312, 12.0625],
         [16.7812, 16.2656, 19.4688,  ...,  6.0039, 14.3750,  7.7695]],

        [[ 5.2305,  4.3984,  5.8164,  ..., -0.9219, -1.6182, -1.7695],
         [ 5.4609,  4.6758,  6.1484,  ..., -0.7266, -1.4258, -1.7080],
         [ 4.6328,  3.7129,  5.3047,  ..., -1.3828, -2.0488, -1.7646],
         ...,
         [20.1719, 16.2500, 17.0000,  ...,  7.8086, 14.9688, 11.0938],
         [17.0625, 18.3438, 21.6875,  ..., 11.2422, 18.5312, 12.0625],
         [16.7812, 16.2656, 19.4688,  ...,  6.0039, 14.3750,  7.7695]],

        [[ 

OutOfMemoryError: CUDA out of memory. Tried to allocate 3.12 GiB. GPU 0 has a total capacity of 23.57 GiB of which 2.37 GiB is free. Including non-PyTorch memory, this process has 21.17 GiB memory in use. Of the allocated memory 20.58 GiB is allocated by PyTorch, and 293.29 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 1         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  11122 MiB |  21514 MiB | 182150 MiB | 171027 MiB |
|       from large pool |  11120 MiB |  21501 MiB | 181596 MiB | 170476 MiB |
|       from small pool |      2 MiB |     19 MiB |    554 MiB |    551 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  11122 MiB |  21514 MiB | 182150 MiB | 171027 MiB |
|       from large pool |  11120 MiB |  21501 MiB | 181596 MiB | 170476 MiB |
|       from small pool |      2 MiB |     19 MiB |    554 MiB |    551 MiB |
|---------------------------------------------------------------

In [None]:
# batch_size = 16
# seq_len = 200

# # 模拟input_ids：随机生成在词表大小范围内
# vocab_size = model.config.vocab_size
# input_ids = torch.randint(low=0, high=vocab_size, size=(
#     batch_size, seq_len), device=device)

# # 模拟attention_mask：全1
# attention_mask = torch.ones_like(input_ids, device=device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# # 模拟labels（可以和input_ids相同，也可部分设-100忽略loss）
# labels = input_ids.clone()
# labels[:, :10] = -100

# # 模拟 images 输入（batch内每个样本含1张图像，1个patch，3通道1024x1024）
# # GOT代码里图像输入是一个list，列表长度为batch_size，每个元素形如(image_count, 3, H, W)
# # 这里image_count=1, 高度宽度为1024
# images = []

In [None]:




# for _ in range(batch_size):
#     img_tensor = torch.randn(1, 3, 1024, 1024, device=device, dtype=torch.float32)  # 模拟单张图像
#     images.append((None, img_tensor))  # 你的代码中传入的images元素形如 tuple，第二项是Tensor


In [None]:
# len(images)

In [None]:
# len(images[0])
    

In [None]:
# # 记录显存占用
# print("显存占用（训练前）:", torch.cuda.memory_allocated() / 1024**2, "MB")

# optimizer.zero_grad()

# outputs = model(
#     input_ids=input_ids,
#     attention_mask=attention_mask,
#     labels=labels,
#     images=images,
#     return_dict=True,
# )

# loss = outputs.loss
# print("Loss:", loss.item())

# loss.backward()

# optimizer.step()

# print("显存占用（训练后）:", torch.cuda.memory_allocated() / 1024**2, "MB")