# Traffic Video MCQ Dataset Preparation

This notebook prepares the Zalo AI Traffic dataset for fine-tuning vision-language models (VLMs).
The dataset layout follows the HuggingFaceM4/ChartQA format, adapted for multiple-choice questions over video frames.

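**Dataset Structure:**
- Questions: Vietnamese traffic-safety multiple-choice questions
- Videos: dashcam clips, each with support-frame timestamps (in seconds) marking the moments relevant to the question
- Answers: the full text of the correct choice, prefixed by its letter (e.g. `B. Sai`); questions have up to four choices (A–D), and some have only two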

## 1. Install Dependencies

In [1]:
# Install required packages for Qwen2.5-VL with 4-bit quantization
# !pip install datasets opencv-python pillow numpy tqdm
# !pip install transformers accelerate bitsandbytes peft
# !pip install qwen-vl-utils  # For Qwen2.5-VL video/image processing

## 2. Import Libraries

In [6]:
import os
import json
import cv2
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm
from datasets import Dataset, DatasetDict
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
import warnings
warnings.filterwarnings('ignore')

# Report the PyTorch build and, when a GPU is visible, the device name and CUDA runtime version.
print(f"PyTorch version: {torch.__version__}\nCUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print("\n".join([
        f"GPU: {torch.cuda.get_device_name(0)}",
        f"CUDA version: {torch.version.cuda}",
    ]))

PyTorch version: 2.7.0+cu128
CUDA available: True
GPU: NVIDIA GeForce RTX 5060 Ti
CUDA version: 12.8


## 3. Configuration

In [7]:
# Dataset paths
TRAIN_JSON_PATH = r"d:\ZALO_AI\trainining\train\train.json"
VIDEOS_PATH = r"d:\ZALO_AI\trainining\train\videos"

# Model configuration for Qwen2.5-VL
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

# Processing parameters
MAX_FRAMES = 8  # Maximum frames to extract per video (Qwen2.5-VL optimized)
TRAIN_SPLIT = 0.9  # 90% for training, 10% for validation
RANDOM_SEED = 42

# 4-bit Quantization config
USE_4BIT = True
BNB_4BIT_COMPUTE_DTYPE = torch.bfloat16
BNB_4BIT_QUANT_TYPE = "nf4"
BNB_4BIT_USE_DOUBLE_QUANT = True

# Output path
OUTPUT_DIR = r"d:\ZALO_AI\trainining\processed_dataset"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"🎯 Model: {MODEL_ID}")
print(f"📊 Train JSON: {TRAIN_JSON_PATH}")
print(f"📹 Videos Path: {VIDEOS_PATH}")
print(f"💾 Output Directory: {OUTPUT_DIR}")
print(f"⚙️ 4-bit Quantization: {USE_4BIT}")
print(f"🖼️ Max frames per video: {MAX_FRAMES}")

🎯 Model: Qwen/Qwen2.5-VL-3B-Instruct
📊 Train JSON: d:\ZALO_AI\trainining\train\train.json
📹 Videos Path: d:\ZALO_AI\trainining\train\videos
💾 Output Directory: d:\ZALO_AI\trainining\processed_dataset
⚙️ 4-bit Quantization: True
🖼️ Max frames per video: 8


## 4. Load Raw Data

In [8]:
# Load train.json
with open(TRAIN_JSON_PATH, 'r', encoding='utf-8') as f:
    raw_data = json.load(f)

train_samples = raw_data['data']
total_samples = len(train_samples)

print(f"✅ Loaded {total_samples} training samples")
print(f"\n📊 Sample structure:")
sample = train_samples[0]
for key, value in sample.items():
    if key not in ['_unused_']:
        print(f"  {key}: {value if not isinstance(value, list) or len(str(value)) < 100 else str(value)[:100] + '...'}")

✅ Loaded 1490 training samples

📊 Sample structure:
  id: train_0001
  question: Nếu xe ô tô đang chạy ở làn ngoài cùng bên phải trong video này thì xe đó chỉ được phép rẽ phải?
  choices: ['A. Đúng', 'B. Sai']
  answer: B. Sai
  support_frames: [4.427402]
  video_path: train/videos/2b840c67_386_clip_002_0008_0018_Y.mp4


## 5. Video Frame Extraction Functions

In [9]:
def _read_frame(cap, frame_idx, max_side=None):
    """Seek `cap` to `frame_idx` and return that frame as an RGB PIL image, or None if the read fails.

    When `max_side` is given, the image is shrunk (aspect ratio kept) so that neither
    side exceeds it; frames already smaller than that are left untouched.
    """
    cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
    ret, frame = cap.read()
    if not ret:
        return None
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    if max_side:
        image.thumbnail((max_side, max_side))
    return image


def extract_frames_from_video(video_path, support_frames=None, max_frames=12, max_side=None):
    """
    Extract frames from a video at support-frame timestamps, or uniformly.

    With support frames, the frames at those timestamps come first, in the given
    order. If fewer than `max_frames` were read, uniformly spaced frames across the
    whole clip are appended. The result is therefore not guaranteed to be in
    temporal order and may repeat a support frame. Without support frames, clips
    shorter than 10 s are sampled every 0.5 s from the start; longer clips (or
    clips with unknown fps) are sampled uniformly.

    Args:
        video_path: Path to video file
        support_frames: List of timestamps (in seconds) to extract frames at.
            Negative timestamps are ignored. Timestamps are also ignored when the
            container reports no fps, since they cannot be mapped to frame indices.
        max_frames: Maximum number of frames to extract
        max_side: Optional cap on the longest image side, in pixels. None keeps
            native resolution, which can use a lot of RAM when many clips are
            held in memory.

    Returns:
        List of PIL Images. The list is empty if the file is missing, cannot be
        opened, or reports no frames.
    """
    if not os.path.exists(video_path):
        print(f"⚠️ Video not found: {video_path}")
        return []

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"⚠️ Cannot open video: {video_path}")
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if total_frames <= 0:
            return []
        duration = total_frames / fps if fps > 0 else 0

        frames = []

        def collect(indices):
            # Read each index and keep only the frames that decoded successfully.
            for idx in indices:
                image = _read_frame(cap, idx, max_side)
                if image is not None:
                    frames.append(image)

        if support_frames:
            # Map timestamps (seconds) to frame indices; drop ones outside the clip.
            support_indices = []
            if fps > 0:
                for timestamp in support_frames[:max_frames]:
                    if timestamp < 0:
                        continue
                    frame_idx = int(timestamp * fps)
                    if frame_idx < total_frames:
                        support_indices.append(frame_idx)
            collect(support_indices)

            # Top up with uniform samples if the support frames were not enough.
            if len(frames) < max_frames and total_frames > len(frames):
                remaining = max_frames - len(frames)
                collect(np.linspace(0, total_frames - 1, remaining, dtype=int))
        else:
            if fps > 0 and duration < 10:  # Short videos: sample every 0.5s
                frame_interval = max(1, int(fps * 0.5))
                frame_indices = list(range(0, total_frames, frame_interval))[:max_frames]
            else:
                frame_indices = np.linspace(0, total_frames - 1, max_frames, dtype=int)
            collect(frame_indices)

        return frames
    finally:
        # Always release the capture handle, even if decoding raised.
        cap.release()


def format_question_with_choices(question, choices):
    """
    Build the prompt text for one multiple-choice question. This is the analogue
    of ChartQA's 'query' field.

    Layout, in order:
    - the question;
    - a blank line;
    - the header "Các lựa chọn:" ("Choices:");
    - each choice on its own line;
    - a blank line;
    - the instruction "Hãy chọn đáp án đúng:" ("Choose the correct answer:").
    """
    choice_block = "".join(f"{option}\n" for option in choices)
    return f"{question}\n\nCác lựa chọn:\n{choice_block}\nHãy chọn đáp án đúng:"

## 6. Data Processing Pipeline

In [None]:
def process_sample(sample, videos_base_path, max_frames=12):
    """
    Turn one raw train.json entry into a ChartQA-style record.

    The clip is looked up by file name inside `videos_base_path`; any directory
    prefix in sample['video_path'] is discarded. Frames come from
    extract_frames_from_video, and the question plus its choices are merged into
    a single prompt string.

    Returns:
        A dict with these keys:
        - 'id';
        - 'query';
        - 'label': a one-element list holding the answer string;
        - 'image': the first extracted frame;
        - 'frames': all extracted frames;
        - 'video_path', 'support_frames', 'question', 'choices' and 'answer',
          copied from the raw entry.
        Returns None when no frame could be read or when any exception occurs;
        the exception is printed.
    """
    try:
        relative_video = sample['video_path']
        clip_path = os.path.join(videos_base_path, os.path.basename(relative_video))

        timestamps = sample.get('support_frames', [])
        extracted = extract_frames_from_video(clip_path, timestamps, max_frames)
        if not extracted:
            return None

        # Prompt text and label are built before the record so lookup order is unchanged.
        prompt = format_question_with_choices(sample['question'], sample['choices'])
        answer_as_list = [sample['answer']]  # ChartQA stores labels as lists

        record = {
            'id': sample['id'],
            'query': prompt,
            'label': answer_as_list,
            'image': extracted[0],
            'frames': extracted,
            'video_path': relative_video,
            'support_frames': timestamps,
            'question': sample['question'],
            'choices': sample['choices'],
            'answer': sample['answer'],
        }
        return record

    except Exception as err:
        print(f"⚠️ Error processing sample {sample.get('id', 'unknown')}: {err}")
        return None


print("✅ Processing functions defined")

✅ Processing functions defined


: 

## 7. Process All Samples

In [None]:
print(f"Processing {total_samples} samples...\n")

# NOTE: process_sample keeps every extracted frame in memory as a PIL image at the
# video's native resolution, so RAM use grows with total_samples * MAX_FRAMES.
processed_data = []
failed_ids = []  # keep ids (not just a count) so failures can be inspected

for sample in tqdm(train_samples, desc="Processing videos"):
    processed = process_sample(sample, VIDEOS_PATH, max_frames=MAX_FRAMES)
    if processed is not None:
        processed_data.append(processed)
    else:
        failed_ids.append(sample.get('id', 'unknown'))

failed_count = len(failed_ids)

print(f"\n✅ Successfully processed: {len(processed_data)} samples")
print(f"❌ Failed to process: {failed_count} samples")
success_rate = len(processed_data) / total_samples * 100 if total_samples else 0.0
print(f"Success rate: {success_rate:.2f}%")
if failed_ids:
    print(f"   First failed ids: {failed_ids[:10]}")

Processing 1490 samples...



Processing videos:  37%|███▋      | 553/1490 [07:51<13:42,  1.14it/s]

## 8. Verify Processed Data

In [None]:
# Display a sample
if processed_data:
    sample = processed_data[0]
    print("📋 Processed Sample Structure:")
    print(f"ID: {sample['id']}")
    print(f"\nQuery (formatted question):\n{sample['query']}")
    print(f"\nLabel (answer): {sample['label']}")
    print(f"\nNumber of frames: {len(sample['frames'])}")
    print(f"Image size: {sample['image'].size}")

    # Display first frame
    import matplotlib.pyplot as plt
    plt.figure(figsize=(8, 6))
    plt.imshow(sample['image'])
    plt.title(f"Sample: {sample['id']}")
    plt.axis('off')
    plt.show()

## 9. Format Data for Fine-tuning (ChartQA-style)

In [None]:
def format_data_for_training(sample, system_message=None, use_all_frames=False):
    """
    Turn a processed sample into a system/user/assistant chat conversation for
    Qwen2.5-VL fine-tuning. The intended setup is 4-bit quantization
    (BitsAndBytes) with LoRA.

    Args:
        sample: processed record with 'image', 'query' and 'label'; 'frames' is
            also used when `use_all_frames` is True.
        system_message: system prompt text. When None, a Vietnamese default is
            used: "You are an AI assistant specialising in Vietnamese traffic
            safety. Your task is to analyse dashcam video and answer questions
            about traffic law, road signs and traffic situations. Answer
            accurately based on the video content."
        use_all_frames: when False (default), the user turn carries only the main
            image, sample['image']. When True and the sample has a non-empty
            'frames' list, every frame becomes an image entry, in stored order,
            before the question text.

    Returns:
        List of three message dicts (system, user, assistant). The assistant
        content is the first element of sample['label'].
    """
    if system_message is None:
        system_message = (
            "Bạn là trợ lý AI chuyên về an toàn giao thông Việt Nam. "
            "Nhiệm vụ của bạn là phân tích video dashcam và trả lời các câu hỏi "
            "về luật giao thông, biển báo, và tình huống giao thông. "
            "Hãy trả lời chính xác dựa trên nội dung video."
        )

    if use_all_frames and sample.get("frames"):
        images = list(sample["frames"])
    else:
        images = [sample["image"]]  # Main image (first frame)

    user_content = [{"type": "image", "image": image} for image in images]
    user_content.append({"type": "text", "text": sample["query"]})  # Formatted question

    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": user_content,
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": sample["label"][0]}],  # Answer
        },
    ]


# Format all processed data. The result is index-aligned with processed_data,
# which lets the split cell reuse the same indices for both lists.
formatted_data = [format_data_for_training(sample) for sample in processed_data]

print(f"✅ Formatted {len(formatted_data)} samples for Qwen2.5-VL training")
print(f"🔧 Compatible with 4-bit quantization + LoRA")
if formatted_data:
    print("\n📋 Sample formatted data:")
    print(f"Roles: {[msg['role'] for msg in formatted_data[0]]}")
    print(f"\nSystem message: {formatted_data[0][0]['content'][0]['text'][:100]}...")
    print(f"\nUser content types: {[c['type'] for c in formatted_data[0][1]['content']]}")
    print(f"\nAssistant response: {formatted_data[0][2]['content'][0]['text']}")
else:
    print("⚠️ No formatted samples to preview (processed_data is empty)")

## 10. Split Dataset (Train/Val)

In [None]:
import random

# Seed the global RNG as before, for any code that relies on it. The split itself
# uses a dedicated RNG, so it stays reproducible whatever else consumes `random`.
random.seed(RANDOM_SEED)
split_rng = random.Random(RANDOM_SEED)

# Split by video clip rather than by question. If several questions share a clip,
# a per-question split would place the same footage in both train and validation
# and inflate validation scores. When every question has its own clip, this is
# an ordinary random split.
# NOTE: TRAIN_SPLIT is applied to clips, so the per-sample ratio is approximate.
video_keys = sorted({s['video_path'] for s in processed_data})
split_rng.shuffle(video_keys)
n_train_videos = int(len(video_keys) * TRAIN_SPLIT)
train_videos = set(video_keys[:n_train_videos])

train_indices = [i for i, s in enumerate(processed_data) if s['video_path'] in train_videos]
val_indices = [i for i, s in enumerate(processed_data) if s['video_path'] not in train_videos]
split_rng.shuffle(train_indices)
split_rng.shuffle(val_indices)

# Create train and validation sets
train_data = [processed_data[i] for i in train_indices]
val_data = [processed_data[i] for i in val_indices]

# Formatted versions (formatted_data is index-aligned with processed_data)
train_formatted = [formatted_data[i] for i in train_indices]
val_formatted = [formatted_data[i] for i in val_indices]

# Sanity checks: nothing lost, and no clip shared between the splits
assert len(train_data) + len(val_data) == len(processed_data)
assert train_videos.isdisjoint(s['video_path'] for s in val_data), "video clip leaked across splits"

n_total = max(len(processed_data), 1)  # avoid ZeroDivisionError when nothing was processed
print(f"✅ Dataset split:")
print(f"   Training samples: {len(train_data)} ({len(train_data)/n_total*100:.1f}%)")
print(f"   Validation samples: {len(val_data)} ({len(val_data)/n_total*100:.1f}%)")
print(f"   Distinct video clips: {len(video_keys)} ({len(train_videos)} train / {len(video_keys) - len(train_videos)} val)")

## 11. Create HuggingFace Dataset

In [None]:
# Create Dataset objects
train_dataset = Dataset.from_list(train_data)
val_dataset = Dataset.from_list(val_data)

# Create DatasetDict
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset
})

print("✅ HuggingFace Dataset created:")
print(dataset_dict)
print("\n📊 Dataset features:")
print(train_dataset.features)

## 12. Save Datasets

In [None]:
# Save as JSON for easy inspection
output_json = os.path.join(OUTPUT_DIR, "traffic_mcq_dataset.json")
with open(output_json, 'w', encoding='utf-8') as f:
    json.dump({
        'train': [{
            'id': s['id'],
            'query': s['query'],
            'label': s['label'],
            'video_path': s['video_path'],
            'question': s['question'],
            'choices': s['choices'],
            'answer': s['answer']
        } for s in train_data],
        'validation': [{
            'id': s['id'],
            'query': s['query'],
            'label': s['label'],
            'video_path': s['video_path'],
            'question': s['question'],
            'choices': s['choices'],
            'answer': s['answer']
        } for s in val_data]
    }, f, ensure_ascii=False, indent=2)

print(f"✅ Saved JSON dataset to: {output_json}")

# Save formatted data for training (pickle for preserving PIL Images)
import pickle

formatted_output = os.path.join(OUTPUT_DIR, "formatted_training_data.pkl")
with open(formatted_output, 'wb') as f:
    pickle.dump({
        'train': train_formatted,
        'validation': val_formatted
    }, f)

print(f"✅ Saved formatted training data to: {formatted_output}")

# Save dataset statistics
stats = {
    'total_samples': len(processed_data),
    'train_samples': len(train_data),
    'val_samples': len(val_data),
    'failed_samples': failed_count,
    'train_split': TRAIN_SPLIT,
    'max_frames': MAX_FRAMES,
    'random_seed': RANDOM_SEED
}

stats_file = os.path.join(OUTPUT_DIR, "dataset_stats.json")
with open(stats_file, 'w', encoding='utf-8') as f:
    json.dump(stats, f, indent=2)

print(f"✅ Saved dataset statistics to: {stats_file}")

## 13. Dataset Analysis

In [None]:
# Analyze answer, choice-count and frame-count distributions
from collections import Counter

n_processed = len(processed_data)

if n_processed == 0:
    print("⚠️ No processed samples to analyze")
else:
    # Answers look like "<letter>. <text>" (e.g. "B. Sai"), so the first
    # non-blank character is the option letter (A, B, C, D).
    answer_options = [s['answer'].strip()[:1] for s in processed_data]
    answer_counts = Counter(answer_options)

    print("📊 Answer Distribution:")
    for option, count in sorted(answer_counts.items()):
        print(f"   {option}: {count} ({count/n_processed*100:.1f}%)")

    # Analyze number of choices
    num_choices = [len(s['choices']) for s in processed_data]
    choice_counts = Counter(num_choices)

    print("\n📊 Number of Choices Distribution:")
    for num, count in sorted(choice_counts.items()):
        print(f"   {num} choices: {count} samples ({count/n_processed*100:.1f}%)")

    # Analyze frames extracted
    num_frames = [len(s['frames']) for s in processed_data]
    avg_frames = sum(num_frames) / len(num_frames)
    print(f"\n📊 Average frames per video: {avg_frames:.2f}")
    print(f"   Min frames: {min(num_frames)}")
    print(f"   Max frames: {max(num_frames)}")

## 14. Example: Load Dataset for Training

In [None]:
# Example: how to load the formatted data for training.
# pickle.load can execute arbitrary code, so only load files this notebook produced.
import pickle

with open(formatted_output, 'rb') as f:
    loaded_data = pickle.load(f)

loaded_train = loaded_data['train']
loaded_val = loaded_data['validation']

print(f"✅ Loaded training data:")
print(f"   Train: {len(loaded_train)} samples")
print(f"   Validation: {len(loaded_val)} samples")

if loaded_train:
    first_conversation = loaded_train[0]
    print("\n📋 Sample loaded data structure:")
    print(f"Type: {type(first_conversation)}")
    # Each conversation is a list of message dicts, so report their roles (not dict keys)
    print(f"Roles: {[msg['role'] for msg in first_conversation]}")
    print(f"\nFirst message:")
    print(first_conversation[0])  # System message
    print(f"\nUser message content types:")
    print([c['type'] for c in first_conversation[1]['content']])
else:
    print("⚠️ Training split is empty")

## 14.5. Test Dataset with Qwen2.5-VL (4-bit Quantized)

In [None]:
# Optional smoke test. It loads the processor, checks that one formatted training
# conversation is accepted by the chat template and processor, and then (if
# enabled) loads the 4-bit model.
import importlib.util

RUN_MODEL_LOAD = True  # set False to skip the slow, GPU-heavy model load

print("🧪 Testing Qwen2.5-VL model loading with 4-bit quantization...")

try:
    # Build a quantization config only when 4-bit loading is enabled. With
    # USE_4BIT=False the model is loaded unquantized instead of being handed a
    # config that quantizes nothing.
    if USE_4BIT:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=BNB_4BIT_USE_DOUBLE_QUANT,
            bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE,
            bnb_4bit_compute_dtype=BNB_4BIT_COMPUTE_DTYPE,
        )
    else:
        bnb_config = None

    # flash-attn is an optional extra; fall back to PyTorch SDPA when it is not installed.
    attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") is not None else "sdpa"

    print(f"📦 Loading model: {MODEL_ID}")
    print(f"⚙️ Quantization config:")
    print(f"   - 4-bit: {USE_4BIT}")
    print(f"   - Compute dtype: {BNB_4BIT_COMPUTE_DTYPE}")
    print(f"   - Quant type: {BNB_4BIT_QUANT_TYPE}")
    print(f"   - Double quant: {BNB_4BIT_USE_DOUBLE_QUANT}")
    print(f"   - Attention: {attn_impl}")

    # Load processor
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print(f"✅ Processor loaded")

    # Dataset compatibility check. Render one formatted conversation through the
    # chat template, then run the processor on its text plus images.
    if train_formatted:
        example_messages = train_formatted[0]
        example_text = processor.apply_chat_template(example_messages, tokenize=False)
        example_images = [
            part["image"]
            for message in example_messages
            for part in message["content"]
            if part["type"] == "image"
        ]
        example_inputs = processor(text=[example_text], images=example_images or None, return_tensors="pt")
        print(f"✅ Processor accepted a formatted sample ({example_inputs['input_ids'].shape[1]} tokens)")

    if RUN_MODEL_LOAD:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            quantization_config=bnb_config,
            attn_implementation=attn_impl,
        )
        print(f"✅ Model loaded successfully!")
        print(f"📊 Model parameters: {model.num_parameters():,}")
        if torch.cuda.is_available():
            print(f"💾 Memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    else:
        print("\n💡 Model load skipped: set RUN_MODEL_LOAD = True to load it")

    print("⚠️ Note: Model loading requires ~3-4GB VRAM with 4-bit quantization")

except Exception as e:
    print(f"❌ Error: {e}")
    print("💡 Make sure you have installed: transformers, bitsandbytes, accelerate")

## 15. Summary and Next Steps

In [None]:
# Final report: dataset sizes, output files, model configuration and next steps.
summary_lines = [
    "=" * 80,
    "📊 DATASET PREPARATION COMPLETE - Qwen2.5-VL Ready",
    "=" * 80,
    f"\n✅ Total processed samples: {len(processed_data)}",
    f"✅ Training samples: {len(train_data)}",
    f"✅ Validation samples: {len(val_data)}",
    "\n📁 Output files:",
    f"   - JSON dataset: {output_json}",
    f"   - Formatted training data: {formatted_output}",
    f"   - Dataset statistics: {stats_file}",
    "\n🎯 Model Configuration:",
    f"   - Model: {MODEL_ID}",
    f"   - 4-bit Quantization: {USE_4BIT}",
    f"   - Compute dtype: {BNB_4BIT_COMPUTE_DTYPE}",
    f"   - Frames per video: {MAX_FRAMES}",
    "\n💡 Next steps:",
    "   1. Load 'formatted_training_data.pkl' for fine-tuning",
    "   2. Use BitsAndBytesConfig for 4-bit quantization",
    "   3. Apply LoRA adapters (recommended for 4-bit training)",
    "   4. Fine-tune with SFTTrainer from TRL",
    "\n🔗 Compatible with:",
    "   - Qwen2.5-VL-3B-Instruct (4-bit quantized)",
    "   - LoRA fine-tuning with PEFT",
    "   - TRL SFTTrainer",
    "   - Flash Attention 2 (optional)",
    "\n💾 Estimated VRAM usage:",
    "   - Model (4-bit): ~3-4GB",
    "   - Training (batch_size=1): ~6-8GB",
    "   - Inference: ~4-5GB",
    "=" * 80,
]
print("\n".join(summary_lines))