In [3]:
# from src.data.dataset import VietnameseTrafficDataset, collate_fn
# from torch.utils.data import DataLoader

# def test_dataloader(
#     split_json: str,
#     video_root: str,
#     num_samples: int = 5,
# ):
#     dataset = VietnameseTrafficDataset(
#         json_path=split_json,
#         video_root=video_root,
#         num_frames=8,
#         use_support_frames=True,
#         max_samples=num_samples
#     )
#     dataloader = DataLoader(
#         dataset,
#         batch_size=2,
#         shuffle=False,
#         num_workers=0,
#         collate_fn=collate_fn
#     )
#     return dataset, dataloader

# dataset, dataloader = test_dataloader(
#     split_json='data/processed/train_split.json',
#     video_root='data/',
#     num_samples=5
# )
# result = dataset.__getitem__(3)
# print(result['frames'].shape)  # Expected shape: (8, H, W, 3), where H and W are the video's original height and width

In [4]:
import torch
from pathlib import Path
from src.data.dataset import VietnameseTrafficDataset, collate_fn
from src.models.video_llava import VietnameseTrafficVQAModel
from torch.utils.data import DataLoader

# Settings for a tiny smoke-test subset of the dataset.
# NOTE(review): the two paths look like placeholders — point them at the real
# split file and video directory before running.
TRAIN_JSON = "path/to/train.json"
VIDEO_ROOT = "path/to/videos"
NUM_FRAMES = 8
MAX_SAMPLES = 2  # two samples are enough to exercise the pipeline

dataset = VietnameseTrafficDataset(
    json_path=TRAIN_JSON,
    video_root=VIDEO_ROOT,
    num_frames=NUM_FRAMES,
    max_samples=MAX_SAMPLES,
)

# Wrap the dataset in a DataLoader that yields one sample per batch,
# assembled by the project's collate_fn.
loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    collate_fn=collate_fn,
)

# Construct the VQA model for the smoke test: "lite" mode, with the
# use_lora and load_in_4bit options switched on.
# NOTE(review): presumably the LoRA path needs the `peft` package — confirm it
# is installed in the kernel's environment.
model_options = {
    "mode": "lite",
    "use_lora": True,
    "load_in_4bit": True,
}
model = VietnameseTrafficVQAModel(**model_options)

# Smoke test: run a single batch through prepare_inputs and one forward pass.
# Only the first batch is fetched; an empty loader is reported explicitly
# instead of letting the test "pass" without exercising anything.
batch = next(iter(loader), None)
if batch is None:
    raise RuntimeError(
        "DataLoader produced no batches; check json_path, video_root and max_samples."
    )

frames = batch['frames']
prompts = batch['prompts']
answers = batch['answers']

print(f"Frames shape: {frames.shape}")
print(f"Prompts: {prompts}")
print(f"Answers: {answers}")

# Test prepare_inputs
inputs = model.prepare_inputs(frames, prompts, answers)
print(f"Input keys: {inputs.keys()}")
print(f"Input_ids shape: {inputs['input_ids'].shape}")

# Test forward pass
# Use the GPU when one is available; otherwise stay on CPU rather than
# failing on a hard-coded .to('cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Only tensors are moved; any non-tensor entries are passed through untouched.
inputs = {
    key: value.to(device) if isinstance(value, torch.Tensor) else value
    for key, value in inputs.items()
}
outputs = model(**inputs)
print(f"Loss: {outputs.loss.item()}")

ModuleNotFoundError: No module named 'peft'