In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

# Single target GPU; `device` is reused later when moving processor outputs.
device = "cuda:0"
model_path = "DAMO-NLP-SG/VideoLLaMA3-2B"

# Keyword arguments for loading the checkpoint: the whole model is mapped onto
# `device`, weights are requested as bfloat16, and the FlashAttention-2
# attention implementation is selected. trust_remote_code=True allows the
# custom modelling code shipped with the checkpoint to be executed.
model_load_kwargs = dict(
    trust_remote_code=True,
    device_map={"": device},
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model = AutoModelForCausalLM.from_pretrained(model_path, **model_load_kwargs)

# The matching processor builds the model inputs from a conversation.
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)



In [None]:
from datasets import load_from_disk

# Directory holding the dataset that is loaded from local disk.
dataset_dir = "/home/ubuntu/temp/large_sports/large_dataset"
dataset = load_from_disk(dataset_dir)

In [None]:
from tqdm import tqdm  # also needed by the inference loop in a later cell

# Peek at one record to see which columns are available. Direct indexing is
# used instead of a `for ...: break` loop so that no stray progress bar is
# drawn and an empty dataset fails with a clear message rather than a
# NameError on `sample`.
if len(dataset) == 0:
    raise ValueError("dataset is empty; nothing to inspect")
sample = dataset[0]
print(sample.keys())

In [None]:
def describe_video(
    video_path,
    prompt="Describe the video.",
    fps=1,
    max_frames=20,
    start_time=0,
    end_time=10,
    max_new_tokens=128,
):
    """Generate a text description for one video clip with the loaded model.

    Relies on the notebook-level `processor`, `model` and `device`.

    Args:
        video_path: Path of the video file handed to the processor.
        prompt: User instruction sent alongside the video.
        fps: Frame sampling rate forwarded to the processor.
        max_frames: Frame-count limit forwarded to the processor.
        start_time: Clip start forwarded to the processor (presumably seconds).
        end_time: Clip end forwarded to the processor (presumably seconds).
        max_new_tokens: Generation length limit passed to `model.generate`.

    Returns:
        The first decoded sequence with surrounding whitespace stripped.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": {
                        "video_path": video_path,
                        "fps": fps,
                        "max_frames": max_frames,
                        "start_time": start_time,
                        "end_time": end_time,
                    },
                },
                {"type": "text", "text": prompt},
            ],
        },
    ]

    inputs = processor(
        conversation=conversation,
        add_system_prompt=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Move tensor entries to the model's device; leave anything else untouched.
    inputs = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in inputs.items()
    }
    # Cast the visual input to bfloat16, the dtype the model was loaded with.
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()


# Describe every video in the dataset, collecting one record per video.
# NOTE(review): the default clip window here is 0-10, while the collate cell
# later records the interval as 0-15 — confirm which window is intended.
data_list = []
for sample in tqdm(dataset, total=len(dataset)):
    video_path = sample["mp4_path"]
    response = describe_video(video_path)

    data_list.append({"mp4_path": video_path, "description": response})
    # tqdm.write keeps per-sample output from breaking up the progress bar.
    tqdm.write(response)

In [None]:
# Number of descriptions collected by the inference loop.
len(data_list)

In [None]:
# Reshape the generated descriptions into one record per video, each with a
# single entry under "chunks" (presumably mirroring the original dataset's
# schema — verify against that dataset).
# NOTE(review): the interval recorded here is 0-15, but the inference cell
# requests end_time=10 from the processor — confirm which window is intended.
collate_list = [
    {
        "mp4_path": entry["mp4_path"],
        "chunks": [
            {
                "activity": {"start": 0, "end": 15, "description": entry["description"]},
                "interval": (0, 15),
            }
        ],
    }
    for entry in data_list
]

In [None]:
# Spot-check the third collated record (needs at least three entries).
collate_list[2]

In [None]:
from datasets import Dataset

# Build a Dataset from the collated records and write it to disk.
output_dir = "/home/ubuntu/temp/large_sports/sports_description_dataset"
filtered_dataset = Dataset.from_list(collate_list)
filtered_dataset.save_to_disk(output_dir)

In [None]:
# Inspect the last record of the dataset that was just built.
filtered_dataset[-1]