In [None]:
from univa.dataset.qwen2vl_dataset import Qwen2VLDataset
from univa.models.qwen2p5vl.modeling_univa_qwen2p5vl import UnivaQwen2p5VLForConditionalGeneration

from univa.utils.anyres_util import dynamic_resize
from univa.utils.prompter import Qwen2VLPrompter
from torchvision import transforms
from transformers import (
    CLIPTextModel,
    T5EncoderModel,
    CLIPTokenizer,
    T5TokenizerFast,
    AutoImageProcessor,
    PreTrainedTokenizer,
    AutoTokenizer,
    AutoProcessor, SiglipImageProcessor
)
from torch.utils.data import DataLoader
from univa.dataset.data_collator import DataCollator, pad_list_of_tensors



In [None]:
# Which dataset flavour to build and where its manifest file lives.
dataset_type = "qwen2p5vl"
data_txt = "/workspace/UniWorld-V1/training_data/uniworld_removal_dataset/data.txt"
anyres = "any_1ratio"

# Pixel budget forwarded to dynamic_resize when it picks the resize target.
anchor_pixels = 1024 * 1024


def _resize_with_anyres(img):
    """Bicubically resize ``img`` to the size returned by ``dynamic_resize``.

    ``img.shape[1]`` and ``img.shape[2]`` are forwarded to ``dynamic_resize``
    together with the notebook-level ``anyres`` and ``anchor_pixels`` settings,
    which are looked up at call time.
    NOTE(review): presumably ``img`` is a (C, H, W) tensor - TODO confirm.
    """
    target_size = dynamic_resize(img.shape[1], img.shape[2], anyres, anchor_pixels)
    resizer = transforms.Resize(
        target_size,
        interpolation=transforms.InterpolationMode.BICUBIC,
    )
    return resizer(img)


# Image pipeline: dynamic resize first, then Normalize with mean 0.5 / std 0.5.
resize_lambda = transforms.Lambda(_resize_with_anyres)
transform = transforms.Compose([resize_lambda, transforms.Normalize([0.5], [0.5])])

# Pixel bounds passed to the dataset; both are numerically equal to anchor_pixels.
min_pixels = 1048576
max_pixels = 1048576

# Remaining dataset options.
drop_condition_rate = 0.0
joint_ref_feature = False
mask_weight_type = "log"

In [None]:
# Load the manifest at data_txt: one whitespace-stripped entry per line.
with open(data_txt, "r") as manifest_file:
    datasets = [entry.strip() for entry in manifest_file]

# Show the parsed entries for a quick visual check.
datasets

In [None]:
# Local checkpoint directories.
pretrained_lvlm_name_or_path = "/workspace/UniWorld-V1/model_weight/UniWorld-V1"
pretrained_siglip_name_or_path = "/workspace/UniWorld-V1/model_weight/siglip2-so400m-patch16-512"

# The processor bundles the tokenizer and the image processor used by the dataset.
processor = AutoProcessor.from_pretrained(pretrained_lvlm_name_or_path)
lvlm_tokenizer, image_processor = processor.tokenizer, processor.image_processor

# NOTE(review): in the visible cells lvlm_model is only read for
# config.image_token_length - confirm whether the full weights are needed here.
lvlm_model = UnivaQwen2p5VLForConditionalGeneration.from_pretrained(
    pretrained_lvlm_name_or_path,
    attn_implementation="flash_attention_2",
)

# Prompt formatter handed to the dataset.
prompter = Qwen2VLPrompter()

# Separate image processor loaded from the SigLIP checkpoint.
siglip_processor = SiglipImageProcessor.from_pretrained(pretrained_siglip_name_or_path)

In [None]:
# Build the dataset from the settings and processors created in the earlier cells.
dataset = Qwen2VLDataset(
    # Data source and image transform.
    dataset_type=dataset_type,
    data_txt=data_txt,
    transform=transform,
    # Text / vision preprocessing objects.
    tokenizer=lvlm_tokenizer,
    prompter=prompter,
    image_processor=image_processor,
    processor=processor,
    # Resolution bounds and the token length taken from the model config.
    min_pixels=min_pixels,
    max_pixels=max_pixels,
    image_token_length=lvlm_model.config.image_token_length,
    # Task / conditioning options.
    only_generated_task=True,
    drop_prompt_rate=drop_condition_rate,
    joint_ref_feature=joint_ref_feature,
    anyres=anyres,
    mask_weight_type=mask_weight_type,
    siglip_processor=siglip_processor,
    ocr_enhancer=False,
    random_data=False,
)

In [None]:
# DataLoader settings.
batch_size = 2
pin_memory = True
num_workers = 16
padding_side = "left"

# Collator that batches samples using the LVLM tokenizer and the chosen padding side.
data_collator = DataCollator(
    tokenizer=lvlm_tokenizer,
    padding_side=padding_side,
)

train_dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=num_workers,
    collate_fn=data_collator,
    # Prefetching only applies with worker processes; pass None when there are none.
    prefetch_factor=4 if num_workers != 0 else None,
)

In [None]:
# Display the length of train_dataloader (number of batches it yields per pass).
len(train_dataloader)

In [None]:
# Smoke-test the data pipeline: pull every batch through the DataLoader and
# print how many entries its "prompts" field holds.
# NOTE(review): this walks the entire dataloader and `step` is unused in this
# cell - confirm whether a full pass is intended.
for step, batch in enumerate(train_dataloader):
    prompts = batch["prompts"]
    print(len(prompts))