# Fine-tuning InternVL3.5-1B on MVBench

This notebook demonstrates fine-tuning of the InternVL3.5-1B multimodal model on a video understanding task from the MVBench benchmark.

## Before running:
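1. Set `REPO_ID` in the configuration cell that defines `MODEL_ID`. If you log to Weights & Biases, also uncomment `WANDB_API_KEY` there. The Hugging Face token is requested interactively when that cell runs.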
2. Ensure sufficient disk space for video downloads

In [8]:
# !pip install -U -q pip setuptools wheel
# !pip install --upgrade --no-deps -q numpy==2.0.0 huggingface-hub==0.35.3
# !pip install -U -q "huggingface-hub>=0.23.0" "pyarrow>=21.0.0"
# !pip install -U -q transformers==4.45.0 accelerate bitsandbytes peft datasets==4.1.0
# !pip install -q av decord
# !pip install -q lightning
# !pip install -q wandb
# !pip install -q timm deepspeed flash_attn
# !pip install -q nltk
# !pip install -q opencv-python

In [None]:
!pip install -U -q transformers accelerate bitsandbytes peft datasets
!pip install -q av decord
!pip install -q lightning
!pip install -q pyarrow==21.0.0
!pip install -q wandb
!pip install -q timm deepspeed flash_attn
!pip install -U -q tensorboard
!pip install -q nltk  # needed for `from nltk import edit_distance`

## Implementation Choices and Justification

### Task Selection
**Selected Task:** Action Sequence from MVBench
- This task involves understanding temporal sequences of actions in videos
- It is a multiple-choice task, making it suitable for accuracy-based evaluation
- The dataset has 188 valid samples after filtering missing videos

### Fine-tuning Method: QLoRA (Quantized Low-Rank Adaptation)

**Why QLoRA?**

1. **Memory Efficiency**: 
   - 4-bit quantization shrinks the frozen base weights by ~75% compared with 16-bit storage
   - Enables fine-tuning on a single GPU (e.g., an A100, or even a consumer RTX 4090)
   - Only adapter parameters are trained, not the full model

2. **Performance**:
   - The QLoRA paper (Dettmers et al., 2023) reports results on par with 16-bit full fine-tuning on its benchmarks
   - NormalFloat 4-bit (NF4) quantization preserves model quality
   - A LoRA rank of 16 provides a good balance between capacity and efficiency

3. **Practical Considerations**:
   - Faster iteration and experimentation
   - Lower computational costs
   - Easier to deploy (smaller checkpoint files)

**Configuration:**
- Quantization: 4-bit NF4 with double quantization
- LoRA rank: 16
- LoRA alpha: 32
- Target modules: All linear layers in the language model (excluding vision encoder and projector)
- Compute dtype: bfloat16

### Metric: Accuracy
- Multiple-choice questions have 4 options (A, B, C, D)
- Accuracy is the standard metric for such tasks
- Easy to interpret and compare with baselines
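The model answers in free text, so computing accuracy means mapping each generated string back to a choice letter. The sketch below is minimal and assumes that the generated answer starts with the option letter, as the `'A Ate the medicine.'` targets built later in this notebook do. `extract_choice` and `accuracy` are hypothetical helpers, not part of InternVL.

```python
import re
from typing import List, Optional

# An option letter A-D at the very start of the answer, optionally in
# parentheses, followed by ".", ")", whitespace, or the end of the string.
CHOICE_RE = re.compile(r"^\s*\(?([A-D])(?:[.)\s]|$)")


def extract_choice(text: str) -> Optional[str]:
    """Return the predicted option letter, or None if none can be parsed."""
    match = CHOICE_RE.match(text)
    return match.group(1) if match else None


def accuracy(predictions: List[str], targets: List[str]) -> float:
    """Fraction of predictions whose parsed letter equals the target letter."""
    if not targets:
        return 0.0
    correct = sum(extract_choice(p) == t for p, t in zip(predictions, targets))
    return correct / len(targets)


preds = ["A Ate the medicine.", "(C) Put down the cup.", "The answer is B"]
print(accuracy(preds, ["A", "C", "B"]))  # unparseable answers count as wrong
```

The lookahead after the letter stops a word such as "Apple" from being read as option A.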

### Training Configuration
- Optimizer: AdamW with weight decay (0.01)
- Learning rate: 1e-4 with linear warmup (50 steps)
- Batch size: 1 with gradient accumulation (8 steps)
- Effective batch size: 8
- Mixed precision: bfloat16
- Max epochs: 2 (to prevent overfitting on a small dataset)
- Early stopping: patience of 3 epochs based on validation accuracy
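On a dataset this small, these settings interact in ways worth checking. The sketch below works through the arithmetic under two assumptions: there are about 150 training samples (80% of the 188 usable ones), and a trailing partial accumulation window still triggers one optimizer update at the end of each epoch. `optimizer_steps_per_epoch` and `linear_warmup_lr` are hypothetical helpers, not the trainer's own implementation.

```python
import math


def optimizer_steps_per_epoch(n_train: int, batch_size: int, grad_accum: int) -> int:
    """Optimizer updates per epoch; a trailing partial accumulation window
    is assumed to still trigger one update at the end of the epoch."""
    micro_batches = math.ceil(n_train / batch_size)
    return math.ceil(micro_batches / grad_accum)


def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Learning rate after `step` updates under linear warmup from 0
    (held constant after warmup; any later decay is ignored here)."""
    if warmup_steps <= 0:
        return base_lr
    return base_lr * min(1.0, step / warmup_steps)


n_train = 150  # assumed: 80% of the 188 usable samples
steps = optimizer_steps_per_epoch(n_train, batch_size=1, grad_accum=8) * 2  # 2 epochs
print(steps, linear_warmup_lr(steps - 1, base_lr=1e-4, warmup_steps=50))
```

Under these assumptions, the run has roughly 38 optimizer updates. That is fewer than the 50 warmup steps, so the learning rate would never reach 1e-4. Similarly, an early-stopping patience of 3 epochs cannot trigger within 2 epochs.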

### Data Processing
- Videos: 8 frames uniformly sampled from each video
- Resolution: 448x448 (dynamic tiling)
- Train/Test split: 80/20 (stratified)
- Data augmentation: None (to maintain temporal consistency)
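The notes above do not say what the 80/20 split is stratified on. The sketch below is a minimal pure-Python version that assumes stratification by the letter of the correct answer, so that both splits see a similar answer-letter distribution. `stratified_split` is a hypothetical helper.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def stratified_split(
    labels: Sequence[str], test_size: float = 0.2, seed: int = 42
) -> Tuple[List[int], List[int]]:
    """Split sample indices into train/test while preserving each label's share."""
    by_label: Dict[str, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    rng = random.Random(seed)
    train: List[int] = []
    test: List[int] = []
    for label in sorted(by_label):  # fixed label order for reproducibility
        idxs = by_label[label]
        rng.shuffle(idxs)
        n_test = int(round(len(idxs) * test_size))  # per-label test count
        test.extend(idxs[:n_test])
        train.extend(idxs[n_test:])
    return sorted(train), sorted(test)


# Toy labels: the letter of the correct answer for each sample
labels = ["A", "B", "C", "D"] * 10
train_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(test_idx))
```

With the `ds` built later in this notebook, you could derive the labels as `[chr(ord("A") + s["candidates"].index(s["answer"])) for s in ds]`. You could then build the splits with `ds.select(train_idx)` and `ds.select(test_idx)`.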


## Fine-tune InternVL3.5-1B on the MVBench dataset

In this notebook, you need to fine-tune the [InternVL](https://huggingface.co/OpenGVLab/InternVL3_5-1B) model on the [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench) dataset, which consists of various video-related tasks. Note that MVBench is quite small and was not designed for fine-tuning, so you first need to split it into training and testing parts.

The goal for the model in this notebook is to answer multiple-choice questions about a video. The questions can relate to temporal aspects of the video, pose prediction, and so on.
Sources:

* InternVL [documentation](https://internvl.readthedocs.io/en/latest/internvl2.0/introduction.html)
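* InternVL2-1B [checkpoint on the hub](https://huggingface.co/OpenGVLab/InternVL2-1B) (this notebook uses the newer [InternVL3.5-1B](https://huggingface.co/OpenGVLab/InternVL3_5-1B))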

## Define variables

We'll first set some variables used throughout this notebook and do all the necessary imports.

In [35]:
import os
import av
import re
import gc
import sys
import json
import random
import bisect
import shutil
import traceback
import numpy as np
from nltk import edit_distance
from pathlib import Path
from copy import deepcopy
from typing import Dict, Any, List, Union, Tuple

from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer,
                          HfArgumentParser, Trainer, TrainingArguments,
                          set_seed, AutoProcessor)
from transformers import BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from huggingface_hub import snapshot_download, hf_hub_download
from datasets import load_dataset, concatenate_datasets
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError

import lightning as L
from lightning.pytorch.callbacks import EarlyStopping, Callback

from getpass import getpass
from huggingface_hub import login

In [2]:
MAX_LENGTH = 160
MODEL_ID = "OpenGVLab/InternVL3_5-1B"
REPO_ID = "Ivan1008/internvl3.5-finetune"

# os.environ["WANDB_API_KEY"] = getpass("Input wandb api key: ")
# os.environ["WANDB_MODE"] = "online"

access_token = getpass("Input hugging face access token: ")
login(access_token)

USE_LORA = False
USE_QLORA = True

Input hugging face access token:  ········


In [9]:
! git clone https://github.com/OpenGVLab/InternVL.git

Cloning into 'InternVL'...
remote: Enumerating objects: 3748, done.
remote: Counting objects: 100% (60/60), done.
remote: Compressing objects: 100% (35/35), done.
remote: Total 3748 (delta 32), reused 25 (delta 25), pack-reused 3688 (from 3)
Receiving objects: 100% (3748/3748), 39.63 MiB | 31.05 MiB/s, done.
Resolving deltas: 100% (2288/2288), done.


In [3]:
! ls InternVL/internvl_chat

README.md	zero_stage1_config.json
eval		zero_stage2_config.json
evaluate.sh	zero_stage3_config.json
examples	zero_stage3_config_100b.json
internvl	zero_stage3_config_100b_1e7_offload.json
pyproject.toml	zero_stage3_config_100b_1e8.json
shell		zero_stage3_config_34b.json
tools		zero_stage3_config_70b.json


In [4]:
sys.path.append('./InternVL/internvl_chat')

In [5]:
from internvl.dist_utils import init_dist
# from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from internvl.model.internvl_chat import (InternVisionConfig,
                                          InternVisionModel,
                                          InternVLChatConfig,
                                          InternVLChatModel)
from internvl.patch import (concat_pad_data_collator,
                            replace_llama_rmsnorm_with_fused_rmsnorm,
                            replace_train_sampler)
from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
                                      IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
                                      IMG_START_TOKEN, QUAD_END_TOKEN,
                                      QUAD_START_TOKEN, REF_END_TOKEN,
                                      IMAGENET_MEAN, IMAGENET_STD,
                                      REF_START_TOKEN)
from internvl.train.dataset import (ConcatDataset, read_frames_decord,
                                    WeightedConcatDataset, build_transform,
                                    dynamic_preprocess, preprocess,
                                    preprocess_internlm, preprocess_mpt,
                                    preprocess_phi3,)

/opt/conda/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/opt/conda/compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status


petrel_client is not installed. If you read data locally instead of from ceph, ignore it.




# MVBench benchmark

[MVBench on HF Datasets](https://huggingface.co/datasets/OpenGVLab/MVBench)

<!-- ![MVbench1.png](https://huggingface.co/datasets/OpenGVLab/MVBench/resolve/main/assert/generation.png) -->

It consists of 20 temporal understanding tasks, illustrated below.

![MVbench-structure.png](https://huggingface.co/datasets/OpenGVLab/MVBench/resolve/main/assert/task_example.png)


Each task can be browsed in the Hugging Face dataset viewer:

[Dataset viewer](https://huggingface.co/datasets/OpenGVLab/MVBench/viewer/action_sequence)



We will start by downloading and processing the dataset. Even though MVBench is a small dataset, it still requires **around 1000B to store the videos**, so make sure you have enough free space.

First, we will use the mapping below to locate each task, because every task is a separate subset stored in its own folder. Then we need a few helper functions to download the videos and convert them into the model's format (8 frames per video).

In [6]:
data_list = {
    "Action Sequence": ("action_sequence.json", "star/Charades_v1_480/", "video", True), # has start & end
    "Action Prediction": ("action_prediction.json", "star/Charades_v1_480/", "video", True), # has start & end
    "Action Antonym": ("action_antonym.json", "ssv2_video/", "video", False),
    "Fine-grained Action": ("fine_grained_action.json", "Moments_in_Time_Raw/videos/", "video", False),
    "Unexpected Action": ("unexpected_action.json", "FunQA_test/test/", "video", False),
    "Object Existence": ("object_existence.json", "clevrer/video_validation/", "video", False),
    "Object Interaction": ("object_interaction.json", "star/Charades_v1_480/", "video", True), # has start & end
    "Object Shuffle": ("object_shuffle.json", "perception/videos/", "video", False),
    "Moving Direction": ("moving_direction.json", "clevrer/video_validation/", "video", False),
    "Action Localization": ("action_localization.json", "sta/sta_video/", "video", True),  # has start & end
    "Scene Transition": ("scene_transition.json", "scene_qa/video/", "video", False),
    "Action Count": ("action_count.json", "perception/videos/", "video", False),
    "Moving Count": ("moving_count.json", "clevrer/video_validation/", "video", False),
    "Moving Attribute": ("moving_attribute.json", "clevrer/video_validation/", "video", False),
    "State Change": ("state_change.json", "perception/videos/", "video", False),
    "Fine-grained Pose": ("fine_grained_pose.json", "nturgbd/", "video", False),
    "Character Order": ("character_order.json", "perception/videos/", "video", False),
    "Egocentric Navigation": ("egocentric_navigation.json", "vlnqa/", "video", False),
    "Episodic Reasoning": ("episodic_reasoning.json", "tvqa/frames_fps3_hq/", "frame", True),  # has start & end, read frame
    "Counterfactual Inference": ("counterfactual_inference.json", "clevrer/video_validation/", "video", False),
}

data_dir = "dataset"
os.makedirs(data_dir, exist_ok=True)
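Each `data_list` entry is a tuple `(annotation file, video subdirectory, "video" or "frame", has start/end)`. The download cell below fetches the archive named after the first path component of the video subdirectory. The sketch below reproduces that mapping, which helps check which archives a set of tasks needs before downloading anything. `archives_for` is a hypothetical helper, and the two-entry dict copies only a subset of `data_list`.

```python
from typing import Dict, Iterable, Set, Tuple

# (annotation json, video dir, media type, has start/end)
TaskSpec = Tuple[str, str, str, bool]

# A two-task subset of `data_list`, copied for a self-contained example
tasks: Dict[str, TaskSpec] = {
    "Action Sequence": ("action_sequence.json", "star/Charades_v1_480/", "video", True),
    "Action Antonym": ("action_antonym.json", "ssv2_video/", "video", False),
}


def archives_for(task_names: Iterable[str], spec: Dict[str, TaskSpec]) -> Set[str]:
    """Zip files under `video/` in the MVBench repo needed by the given tasks,
    assuming each archive is named after the first component of the video dir."""
    return {spec[name][1].split("/")[0] + ".zip" for name in task_names}


print(sorted(archives_for(tasks, tasks)))
```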

In [7]:
def read_video_pyav(video_path, start, end, n_frames=8):
    """
    Reads the frames of a video within the given start-end interval (in seconds)
    and uniformly samples `n_frames` of them.
    """
    with av.open(video_path) as container:
        video = container.streams.video[0]

        # Packet timestamps, converted to (truncated) seconds
        av_timestamps = sorted(
            int(packet.pts * video.time_base) for packet in container.demux(video) if packet.pts is not None
        )
        start_id = bisect.bisect_left(av_timestamps, start)
        end_id = bisect.bisect_left(av_timestamps, end)

        # In case it is a very short clip, widen the window before sampling
        if end_id - start_id < 10:
            end_id += 10
            start_id -= 10

        end_id = min(len(av_timestamps) - 1, end_id)
        start_id = max(1, start_id)

        # We sample n_frames frames for tuning, following the original paper.
        # More frames may help on longer videos, but they need more compute.
        indices = np.linspace(start_id, end_id, n_frames).astype(int)
        wanted = set(indices.tolist())

        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_id:
                break
            if i in wanted:
                # Convert while the container is still open
                frames.append(frame.to_ndarray(format="rgb24"))

    assert len(frames) == n_frames, (
        f"Got {len(frames)} frames but should be {n_frames}. Check the indices: {indices}, "
        f"start_id: {start_id}, end_id: {end_id}. Len of video is {len(av_timestamps)} frames."
    )
    return np.stack(frames)
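The index arithmetic in `read_video_pyav` can be tested without a video file. The sketch below reproduces it on a plain list of per-frame timestamps (in whole seconds). `sample_frame_indices` is a hypothetical helper that mirrors the function above; the function above does not call it.

```python
import bisect
from typing import List

import numpy as np


def sample_frame_indices(timestamps: List[int], start: float, end: float, n_frames: int = 8) -> np.ndarray:
    """Mirror read_video_pyav's index selection on a sorted timestamp list."""
    start_id = bisect.bisect_left(timestamps, start)
    end_id = bisect.bisect_left(timestamps, end)
    # Widen very short windows by 10 frames on each side
    if end_id - start_id < 10:
        end_id += 10
        start_id -= 10
    end_id = min(len(timestamps) - 1, end_id)
    start_id = max(1, start_id)  # as in the original, frame 0 is never sampled
    return np.linspace(start_id, end_id, n_frames).astype(int)


# A 30 fps, 20-second clip: frame i has an integer timestamp of i // 30 seconds
ts = [i // 30 for i in range(600)]
print(sample_frame_indices(ts, start=1.5, end=17.1))
```

The example exposes one subtlety: timestamps are truncated to whole seconds, so `start=1.5` snaps to the first frame at 2 s (index 60). If the widened window holds fewer than `n_frames` frames, `np.linspace` produces duplicate indices and the assertion in `read_video_pyav` fails.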

In [8]:
def collate_read_video(example, path):
    # Some datasets have a start-end interval, so we try to get it if exists.
    # Otherwise just set a very large end timestamp
    clip = read_video_pyav(f'{path}/{example["video"]}', example.get("start", 1), example.get("end", 1e+10))
    example["clip"] = clip
    return example

In [11]:
# Download the videos from datasets repo and unzip.
# Make sure you have enough free space before downloading and unzipping

# videos = snapshot_download(repo_id="OpenGVLab/MVBench", allow_patterns="*", repo_type="dataset")
# for zip_file in os.listdir(f"{videos}/video"):
#     if zip_file.endswith(".zip"):
#         shutil.unpack_archive(f"{videos}/video/{zip_file}", f"{videos}/videos_unzipped/")

In [9]:
# Or download only selected task with appropriate files
# https://huggingface.co/docs/huggingface_hub/v0.24.7/en/package_reference/file_download#huggingface_hub.hf_hub_download

TASK_NAME = "Action Sequence"
annotation_fn, video_dir, video_type, has_clip = data_list.get(TASK_NAME)
print(f"Task: {TASK_NAME}")
print(f"Annotation file: {annotation_fn}")
print(f"Videos are stored in: {video_dir}")
print(f"Videos are represented as: {video_type}")
print(f"Videos have start/end timestamps: {has_clip}")

annotation_fn_local = hf_hub_download(repo_id="OpenGVLab/MVBench",
                                    filename='json/' + annotation_fn,
                                    repo_type="dataset",
                                    local_dir=data_dir)
videos_zip = hf_hub_download(repo_id="OpenGVLab/MVBench",
                            filename='video/' + video_dir.split("/")[0] + ".zip",
                            repo_type="dataset",
                            local_dir=data_dir)

# Unzip
for zip_file in os.listdir(f"{data_dir}/video"):
    if zip_file.endswith(".zip"):
        print(zip_file)
        shutil.unpack_archive(f"{data_dir}/video/{zip_file}", f"{data_dir}/video/videos_unzipped/")

Task: Action Sequence
Annotation file: action_sequence.json
Videos are stored in: star/Charades_v1_480/
Videos are represented as: video
Videos have start/end timestamps: True
star.zip


The download and unzip steps above produce the following data structure:

```
dataset/
    /json
        task_name.json
    /video
        archive_name.zip
        /videos_unzipped
            /task_name_prefix (optional)
                /task_name
                    video_0.mp4
                    video_1.mp4
                    video_2.mp4
                    ...
                    video_100.mp4
```





In [10]:
ds = load_dataset("json", data_files=annotation_fn_local, split="train")
ds

Dataset({
    features: ['video', 'question', 'candidates', 'answer', 'start', 'end'],
    num_rows: 200
})

In [11]:
# Some tasks in MVBench are missing video files - keep it in mind!
has_missing = False
for sample in ds:
    if not os.path.exists(f"{data_dir}/video/videos_unzipped/{video_dir}/{sample['video']}"):
        print(f"Video `{sample['video']}` does not exist!")
        has_missing = True

Video `EDXBD.mp4` does not exist!
Video `K47J5.mp4` does not exist!
Video `9MNZ5.mp4` does not exist!
Video `QXT9W.mp4` does not exist!
Video `ABHC6.mp4` does not exist!
Video `ALXUC.mp4` does not exist!
Video `BAUQE.mp4` does not exist!
Video `PHH6B.mp4` does not exist!
Video `MNC10.mp4` does not exist!
Video `W7CR5.mp4` does not exist!
Video `Q8UJ8.mp4` does not exist!
Video `X9WTR.mp4` does not exist!


In [12]:
print(f"Dataset length = {len(ds)}")
if has_missing:
    ds = ds.filter(lambda x: os.path.exists(f"{data_dir}/video/videos_unzipped/{video_dir}/{x['video']}"))

print(f"Dataset length = {len(ds)}")

Dataset length = 200
Dataset length = 188


In [13]:
# Load videos and split them into frames
# ds = ds.map(collate_read_video,
#             batched=False,
#             fn_kwargs={"path": f"{data_dir}/video/videos_unzipped/{video_dir}"})

# Make conversation

def make_conversation(sample):
    id2choice = {0: "A", 1: "B", 2: "C", 3: "D"}
    question, candidates = sample["question"], sample["candidates"]
    answer_i = candidates.index(sample["answer"])
    answer_choice = id2choice[answer_i]

    mult_choice = "\n"
    for i, choice in enumerate(candidates):
        mult_choice += f"{id2choice[i]}. {choice};\n"

    conversations = [
         {'from': 'human', 'value': question + mult_choice + '<video>'},
         {'from': 'gpt', 'value': answer_choice + " " + sample["answer"]}
    ]
    sample['conversations'] = conversations
    return sample

ds = ds.map(make_conversation,
            batched=False)
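`make_conversation` hard-codes four option letters. That matches Action Sequence, but it would raise a `KeyError` for a task with more than four candidates. The variant below takes its letters from `string.ascii_uppercase` and keeps the same prompt and target formats. It is a sketch rather than a drop-in replacement, and `make_conversation_any` is a hypothetical name.

```python
import string
from typing import Any, Dict


def make_conversation_any(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Same prompt/target format as make_conversation, for any number of options."""
    letters = string.ascii_uppercase
    candidates = sample["candidates"]
    answer_choice = letters[candidates.index(sample["answer"])]

    mult_choice = "\n" + "".join(f"{letters[i]}. {c};\n" for i, c in enumerate(candidates))
    sample["conversations"] = [
        {"from": "human", "value": sample["question"] + mult_choice + "<video>"},
        {"from": "gpt", "value": answer_choice + " " + sample["answer"]},
    ]
    return sample
```

To use it, call it the same way: `ds.map(make_conversation_any, batched=False)`.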

In [14]:
ds[0]

{'video': 'ZS9XR.mp4',
 'question': 'What happened after the person took the food?',
 'candidates': ['Ate the medicine.',
  'Tidied up the blanket.',
  'Put down the cup/glass/bottle.',
  'Took the box.'],
 'answer': 'Ate the medicine.',
 'start': 1.5,
 'end': 17.1,
 'conversations': [{'from': 'human',
   'value': 'What happened after the person took the food?\nA. Ate the medicine.;\nB. Tidied up the blanket.;\nC. Put down the cup/glass/bottle.;\nD. Took the box.;\n<video>'},
  {'from': 'gpt', 'value': 'A Ate the medicine.'}]}

## Important: Apply Processor Patch for Video Support

**Note**: There is a known issue with how the InternVL processor handles video frames. A patch file (`internvl-processor.diff`) is provided in the workspace. If you encounter errors during processing, you may need to apply it to fix video frame indexing.

To apply the patch:
```bash
# Navigate to your transformers installation
cd $(python -c "import transformers; import os; print(os.path.dirname(transformers.__file__))")

# Apply the patch
patch -p2 < /path/to/internvl-processor.diff
```

This patch fixes the binding of video frames to video placeholders in the prompt, ensuring correct frame-to-placeholder mapping.




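For reference, InternVL's fine-tuning code expects video samples in a conversation format like the following (shown as a Python dict):

```python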
{'id': 578004, 'video': '013901_013950/1025449886.mp4',
'conversations':
[{'from': 'human', 'value': 'Render a clear and concise summary of the video below.\n<video>'},
{'from': 'gpt', 'value': 'Aerial; drone flight around dangerous eroded steep stony slopes of cabo da roca; granite boulders, sea cliffs along the high coast; long seashore with lighthouse, overlooking the promontory, portugal'}]
}
```
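The commented-out `VideoQADataset` below calls `json.loads` on each raw item, so it expects one JSON string per sample in the format shown above. The sketch below converts our MVBench samples into such lines. Two choices are assumptions: the sequential `id` numbering, and keeping `start`/`end` (which `video_get_item` reads with `data_item.get`). `to_internvl_lines` is a hypothetical helper.

```python
import json
from typing import Any, Dict, Iterable, List


def to_internvl_lines(samples: Iterable[Dict[str, Any]], video_prefix: str = "") -> List[str]:
    """Serialize samples that have a `conversations` field to InternVL-style JSON strings."""
    lines = []
    for idx, s in enumerate(samples):
        record = {
            "id": idx,  # assumed: sequential ids are sufficient
            "video": video_prefix + s["video"],
            "conversations": s["conversations"],
        }
        # Keep the clip interval when present so frames are read from it
        for key in ("start", "end"):
            if s.get(key) is not None:
                record[key] = s[key]
        lines.append(json.dumps(record, ensure_ascii=False))
    return lines
```

Iterating over a Hugging Face `Dataset` yields plain dicts, so `to_internvl_lines(ds)` works directly on the dataset built above.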



In [15]:
# Load the model's processor (for this checkpoint it resolves to a Qwen2 tokenizer, as the output shows)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
processor.padding_side = "right"

processor

Qwen2TokenizerFast(name_or_path='OpenGVLab/InternVL3_5-1B', vocab_size=151643, model_max_length=14588, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|im_end|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>', '<|box_end|>', '<|quad_start|>', '<|quad_end|>', '<|vision_start|>', '<|vision_end|>', '<|vision_pad|>', '<|image_pad|>', '<|video_pad|>', '<img>', '</img>', '<IMG_CONTEXT>', '<quad>', '</quad>', '<ref>', '</ref>', '<box>', '</box>']}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True)

## Custom Dataset Class

In the next step, you'll need **to define a custom dataset** class and the functions needed to prepare the data for fine-tuning the model. The VideoQADataset class extends the [PyTorch Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) class to load and process MVBench. It converts dataset samples into the format required for training and evaluation by building a prompt and turning each video into an array of frames.

Next, you need **to define collate functions** to handle the batching of data during training and evaluation. These functions ensure that the input data is properly formatted and padded.

Here, use the processor to turn each (video, target token sequence) pair into the format the model expects (`pixel_values`, `input_ids`, etc.). Use dynamic padding: each batch contains ground-truth sequences of varying lengths, so pad only up to the longest sequence in that batch.
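The sketch below shows a minimal training collate function with dynamic padding. It assumes each dataset item is a dict with 1-D `input_ids`, `labels`, and `attention_mask` tensors, a `(num_patches, 3, H, W)` `pixel_values` tensor, and a 1-D `image_flags` tensor, which is what the commented-out `VideoQADataset` below returns. `pad_collate` is a hypothetical name; InternVL's own `concat_pad_data_collator`, imported earlier, serves a similar role.

```python
from typing import Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100  # ignored by torch.nn.CrossEntropyLoss by default


def pad_collate(batch: List[Dict[str, torch.Tensor]], pad_token_id: int) -> Dict[str, torch.Tensor]:
    """Right-pad text fields to the longest item in the batch and
    concatenate the per-sample frame tensors along the patch dimension."""
    input_ids = pad_sequence([b["input_ids"] for b in batch], batch_first=True, padding_value=pad_token_id)
    labels = pad_sequence([b["labels"] for b in batch], batch_first=True, padding_value=IGNORE_INDEX)
    attention_mask = pad_sequence([b["attention_mask"] for b in batch], batch_first=True, padding_value=0)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "pixel_values": torch.cat([b["pixel_values"] for b in batch], dim=0),
        "image_flags": torch.cat([b["image_flags"] for b in batch], dim=0),
    }
```

To plug it into a `DataLoader`, use `collate_fn=lambda b: pad_collate(b, processor.pad_token_id)`.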

Because of memory constraints, you can also cap the length of the text tokens (`input_ids`) at a maximum. Feel free to raise the cap if your target token sequences are longer. I'd recommend plotting the token lengths of your dataset to determine the optimal value.
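A quick percentile summary of the token lengths can also guide the choice of `MAX_LENGTH`. The sketch below assumes you already have a list of per-sample token counts, for example `len(processor(text).input_ids)` for each formatted prompt. `suggest_max_length` and the round-up-to-a-multiple-of-8 choice are illustrative. If you measure lengths after the `<image>` placeholders are expanded into image-context tokens, they will be far larger than for the text alone.

```python
import math
from typing import Sequence

import numpy as np


def suggest_max_length(lengths: Sequence[int], pct: float = 95.0, multiple: int = 8) -> int:
    """Round the `pct`-th percentile of `lengths` up to a multiple of `multiple`."""
    cutoff = float(np.percentile(np.asarray(lengths), pct))
    return int(math.ceil(cutoff / multiple) * multiple)


lengths = [90, 110, 120, 130, 135, 140, 150, 155, 170, 400]
print(suggest_max_length(lengths), max(lengths))
```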

The formatting of the input_ids is super important: you need to respect a so-called [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating).

Labels are created for the model by simply copying the inputs to the LLM (input_ids), but with padding tokens replaced by the ignore index of the loss function. This ensures that the model doesn't need to learn to predict padding tokens (used to batch examples together).

Why are the labels a copy of the model inputs, you may ask? The model internally shifts the labels by one position, so the logits at position t are scored against the token at position t+1. In other words, the model learns to predict the next token.
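The padding mask and the internal shift take only a few lines of PyTorch to demonstrate. The sketch below uses toy tensors with an arbitrary pad id and vocabulary size. It reproduces the standard causal-LM loss rather than InternVL's exact code.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
pad_id = 0  # arbitrary pad token id for this toy example

input_ids = torch.tensor([[5, 6, 7, 8, pad_id, pad_id]])
labels = input_ids.clone()
labels[labels == pad_id] = IGNORE_INDEX  # padding never contributes to the loss

vocab_size = 16
logits = torch.randn(1, input_ids.size(1), vocab_size)

# Position t predicts token t + 1: drop the last logit and the first label
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]
loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab_size), shift_labels.reshape(-1), ignore_index=IGNORE_INDEX
)
print(shift_labels.tolist(), loss.item())
```

Supervised fine-tuning code commonly also sets the prompt tokens to `IGNORE_INDEX`, so that only the answer is supervised.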

The collate function for evaluation is different, since there you only need to feed the prompt to the model, as we'll use the `generate()` method to autoregressively generate a completion.
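For generation, the batch holds only prompts. Decoder-only models are usually left-padded so that every prompt ends exactly where generation starts; note that the processor above was set to `padding_side = "right"` for training. The sketch below is minimal, and `left_pad_prompts` is a hypothetical helper that works on already-tokenized prompts. With a Hugging Face tokenizer, setting `padding_side = "left"` before tokenizing has the same effect.

```python
from typing import Dict, List

import torch


def left_pad_prompts(prompts: List[List[int]], pad_token_id: int) -> Dict[str, torch.Tensor]:
    """Left-pad token id lists to a common length for batched generate()."""
    max_len = max(len(p) for p in prompts)
    input_ids = torch.full((len(prompts), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(prompts), max_len), dtype=torch.long)
    for row, ids in enumerate(prompts):
        if ids:  # an empty prompt stays fully padded
            input_ids[row, -len(ids):] = torch.tensor(ids, dtype=torch.long)
            attention_mask[row, -len(ids):] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```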

In [None]:
# class VideoQADataset(Dataset):
#     """Dataset for Video QA fine-tuning."""

#     def __init__(
#         self,
#         template_name,
#         raw_data,
#         video_data_dir,
#         tokenizer,
#         ds_name,
#         num_image_token,
#         image_size=224,
#         is_train=True,
#         pad2square=False,
#         dynamic_image_size=False,
#         use_thumbnail=False,
#         min_dynamic_patch=1,
#         max_dynamic_patch=6,
#         min_num_frame=4,  # for video data
#         max_num_frame=12,  # for video data
#         sampling_method='rand',  # for video data
#         repeat_time=1,
#         normalize_type='imagenet',
#         random_seed=0,
#     ):
#         super().__init__()
#         self.ds_name = ds_name
#         self.tokenizer = tokenizer
#         self.group_by_length = None
#         self.tcs_loader = None
#         self.template_name = template_name
#         self.num_image_token = num_image_token
#         print(f'[Dataset] num_image_token: {num_image_token}')
#         print(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
#         print(f'[Dataset] use_thumbnail: {use_thumbnail}')
#         print(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')

#         self.image_size = image_size
#         self.is_train = is_train
#         self.pad2square = pad2square
#         self.max_num_frame = max_num_frame
#         self.num_frames = max_num_frame
#         self.min_num_frame = min_num_frame
#         self.sampling_method = sampling_method

#         print('Formatting inputs...Skip in lazy mode')

#         self.raw_data = raw_data
#         self.rng = np.random.default_rng(seed=random_seed)
#         self.rng.shuffle(self.raw_data)

#         gc.collect()
#         self.root = video_data_dir
#         self.cached_data_dict = {}
#         self.dynamic_image_size = dynamic_image_size
#         self.use_thumbnail = use_thumbnail
#         self.min_dynamic_patch = min_dynamic_patch
#         self.max_dynamic_patch = max_dynamic_patch
#         self.normalize_type = normalize_type

#         gc.collect()

#     def __len__(self):
#         return len(self.raw_data)

#     def get_preprocess_function(self):
#         # Select the appropriate preprocessing function based on the template name
#         return preprocess_internlm

#     def load_image(self, image_path):
#         # Load the image using tcs_loader if available, otherwise use PIL
#         if self.tcs_loader is not None and 's3://' in image_path:
#             return self.tcs_loader(image_path)
#         return Image.open(image_path).convert('RGB')

#     def get_image_path(self, image_path):
#         if image_path.startswith('s3://'):  # for ceph
#             image_path = self.root + image_path
#         else:  # for local image
#             image_path = os.path.join(self.root, image_path)
#         return image_path

#     def get_transform(self):
#         # Build transformation function
#         transform = build_transform(is_train=self.is_train, input_size=self.image_size,
#                                     pad2square=self.pad2square, normalize_type=self.normalize_type)
#         return transform


#     def video_get_item(self, data_item):
#         # Build transformation function
#         transform = self.get_transform()

#         # Ensure the first conversation contains a video placeholder
#         if '<video>' not in data_item['conversations'][0]['value']:
#             data_item['conversations'][0]['value'] = '<video>\n' + data_item['conversations'][0]['value']

#         # Get the video file path
#         video_file = data_item['video']
#         video_path = os.path.join(self.root, video_file)

#         # FIX: use self.max_num_frame
#         num_frames = self.max_num_frame
#         tcs_loader = getattr(self, 'tcs_loader', None)
        
#         # Load the video frames
#         image_list = read_frames_decord(video_path,
#                                         num_frames=num_frames,
#                                         client=tcs_loader,
#                                         clip=(data_item.get('start', None),
#                                               data_item.get('end', None)))

#         # Generate special tokens for each video frame
#         special_tokens = '\n'.join(['Frame{}: <image>'.format(i + 1) for i in range(len(image_list))])
#         data_item['conversations'][0]['value'] = data_item['conversations'][0]['value'].replace(
#             '<video>', special_tokens)

#         # Transform each frame image and stack them into a tensor
#         pixel_values = [transform(image) for image in image_list]
#         pixel_values = torch.stack(pixel_values)
#         num_patches = pixel_values.size(0)

#         # Select the appropriate preprocessing function based on the template name
#         preprocess_function = self.get_preprocess_function()

#         # FIX: read group_by_length as an attribute with a fallback
#         group_by_length = getattr(self, 'group_by_length', None)
        
#         # Preprocess the conversations and generate the return dictionary
#         num_image_tokens = [self.num_image_token] * num_patches
#         ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
#                                   self.tokenizer, num_image_tokens, group_by_length=group_by_length,
#                                   ds_name=self.ds_name, num_image=num_patches)

#         # Create the final return dictionary
#         ret = dict(
#             input_ids=ret['input_ids'][0],
#             labels=ret['labels'][0],
#             attention_mask=ret['attention_mask'][0],
#             pixel_values=pixel_values,
#             image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
#         )
#         return ret

#     def pure_text_get_item(self, data_item):
#         # Build transformation function
#         transform = self.get_transform()

#         # Create a blank white image
#         image = Image.new('RGB', (224, 224), (255, 255, 255))

#         # Dynamically preprocess the image to generate patches
#         images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=1,
#                                     image_size=self.image_size, use_thumbnail=self.use_thumbnail)

#         # Apply the transformation to each image patch and stack them into a tensor
#         pixel_values = [transform(image) for image in images]
#         pixel_values = torch.stack(pixel_values)
#         num_patches = pixel_values.size(0)

#         # Ensure there is only one patch
#         assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'

#         # Select the appropriate preprocessing function based on the template name
#         preprocess_function = self.get_preprocess_function()

#         # Preprocess the conversations and generate the return dictionary
#         ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
#                                   self.tokenizer, [self.num_image_token * num_patches], text_only=True,
#                                   group_by_length=self.group_by_length, ds_name=self.ds_name)

#         # Create the final return dictionary
#         ret = dict(
#             input_ids=ret['input_ids'][0],
#             labels=ret['labels'][0],
#             attention_mask=ret['attention_mask'][0],
#             pixel_values=pixel_values,
#             image_flags=torch.tensor([0] * num_patches, dtype=torch.long)
#         )
#         return ret

#     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
#         i = i % len(self.raw_data)
#         while True:
#             try:
#                 data_item = json.loads(self.raw_data[i])
#                 if 'image' in data_item and len(data_item['image']) != 0:
#                     if type(data_item['image']) == list:
#                         ret = self.multi_modal_multi_image_get_item(data_item)
#                     else:
#                         ret = self.multi_modal_get_item(data_item)
#                 elif 'video' in data_item and data_item['video'] is not None and data_item['video'] != '':
#                     ret = self.video_get_item(data_item)
#                 else:
#                     ret = self.pure_text_get_item(data_item)
#                 break
#             except Exception as e:
#                 print(e, self.ds_name, flush=True)
#                 if not isinstance(e, UnidentifiedImageError):
#                     traceback.print_exc()
#                 data_item = json.loads(self.raw_data[i])
#                 if 'image' in data_item:
#                     if type(data_item['image']) == list:
#                         images = [self.root + item for item in data_item['image']]
#                         print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
#                     else:
#                         if data_item['image'].startswith('s3://'):
#                             data_path = self.root + data_item['image']
#                         else:
#                             data_path = os.path.join(self.root, data_item['image'])
#                         print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
#                 elif 'video' in data_item:
#                     data_path = os.path.join(self.root, data_item['video'])
#                     print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')
#                 i = random.randint(0, len(self.raw_data) - 1)
#         return ret

In [None]:
# class VideoQADatasetWithAnswers(VideoQADataset):
#     """
#     Wrapper around VideoQADataset that also stores original example data
#     for extracting ground truth answers during evaluation.
#     """
#     def __init__(self, original_data, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.original_data = original_data
    
#     def __getitem__(self, idx):
#         item = super().__getitem__(idx)
        
#         original_example = self.original_data[idx]
#         item['original_example'] = original_example
        
#         return item

In [None]:
# Simple dataset and collate functions
import json
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

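def read_video_pyav(video_path, n_frames=8):
    """Read n_frames uniformly spaced RGB frames from a video using PyAV"""
    import av
    container = av.open(video_path)
    try:
        # Select the first *video* stream explicitly (stream 0 is not guaranteed to be video)
        video_stream = container.streams.video[0]

        # Note: `frames` can be 0 for containers whose header lacks a frame count
        total_frames = video_stream.frames
        indices = set(np.linspace(0, total_frames - 1, n_frames).astype(int).tolist())

        frames = []
        for i, frame in enumerate(container.decode(video_stream)):
            if i in indices:
                frames.append(frame.to_ndarray(format="rgb24"))
            if len(frames) == n_frames:
                break
    finally:
        container.close()

    return np.stack(frames)  # (n_frames, H, W, 3)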

def build_transform(image_size=224):
    """Build image transforms"""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        normalize
    ])

def format_text_for_internvl(conversations, is_train=True):
    """Format a conversation as one prompt string (simplified Llama/Mistral-style [INST] layout, not InternVL's own chat template)"""
    if is_train:
        human_msg = conversations[0]['value']
        assistant_msg = conversations[1]['value']
        # Replace <video> with frame tokens
        frame_tokens = '\n'.join([f'Frame{i+1}: <image>' for i in range(8)])
        human_msg = human_msg.replace('<video>', frame_tokens)
        text = f"<s>[INST] {human_msg} [/INST] {assistant_msg}</s>"
    else:
        human_msg = conversations[0]['value']
        human_msg = human_msg.replace('<video>', '')
        text = f"<s>[INST] {human_msg} [/INST]"
    return text

class SimpleVideoDataset(Dataset):
    """Simple video dataset for MVBench"""
    
    def __init__(self, data_list, video_dir, tokenizer, image_size=224, is_train=True):
        self.data_list = data_list
        self.video_dir = video_dir
        self.tokenizer = tokenizer
        self.image_size = image_size
        self.is_train = is_train
        self.transform = build_transform(image_size)
        
    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        item = self.data_list[idx]
        
        # Load video frames
        video_path = f"{self.video_dir}/{item['video']}"
        video_frames = read_video_pyav(video_path)
        
        # Transform frames
        pixel_values = [self.transform(Image.fromarray(frame)) for frame in video_frames]
        pixel_values = torch.stack(pixel_values)  # (N, C, H, W)
        
        # Format text
        formatted_text = format_text_for_internvl(item['conversations'], self.is_train)
        
        # Tokenize
        encoded = self.tokenizer(
            formatted_text,
            truncation=True,
            max_length=512,
            padding=False,
            return_tensors="pt"
        )
        
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        
        # Create labels (copy of input_ids for training)
        labels = input_ids.clone() if self.is_train else torch.full_like(input_ids, -100)
        
        # Create image flags (1 for video frames)
        image_flags = torch.ones(len(pixel_values), dtype=torch.long)
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values,
            'image_flags': image_flags,
            'labels': labels
        }

def simple_train_collate_fn(batch):
    """Simple training collate function"""
    input_ids = [item['input_ids'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    pixel_values = [item['pixel_values'] for item in batch]
    image_flags = [item['image_flags'] for item in batch]
    labels = [item['labels'] for item in batch]
    
    # Pad text sequences
    max_len = max(len(ids) for ids in input_ids)
    
    padded_input_ids = []
    padded_attention_mask = []
    padded_labels = []
    
    pad_id = 0  # tokenizer.pad_token_id
    
    for i in range(len(input_ids)):
        pad_len = max_len - len(input_ids[i])
        padded_input_ids.append(torch.cat([input_ids[i], torch.full((pad_len,), pad_id, dtype=torch.long)]))
        padded_attention_mask.append(torch.cat([attention_mask[i], torch.zeros(pad_len, dtype=torch.long)]))
        padded_labels.append(torch.cat([labels[i], torch.full((pad_len,), -100, dtype=torch.long)]))
    
    # Stack text data
    input_ids_batch = torch.stack(padded_input_ids)
    attention_mask_batch = torch.stack(padded_attention_mask)
    labels_batch = torch.stack(padded_labels)
    
    # Flatten video frames
    flattened_pixel_values = []
    flattened_image_flags = []
    
    for pv, flags in zip(pixel_values, image_flags):
        for frame_idx in range(pv.size(0)):
            flattened_pixel_values.append(pv[frame_idx])
            flattened_image_flags.append(flags[frame_idx] if frame_idx < len(flags) else torch.tensor(1))
    
    pixel_values_batch = torch.stack(flattened_pixel_values)
    image_flags_batch = torch.stack(flattened_image_flags)
    
    return input_ids_batch, attention_mask_batch, pixel_values_batch, labels_batch, image_flags_batch

def simple_eval_collate_fn(batch):
    """Simple evaluation collate function"""
    input_ids = [item['input_ids'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    pixel_values = [item['pixel_values'] for item in batch]
    image_flags = [item['image_flags'] for item in batch]
    
    # Pad text sequences
    max_len = max(len(ids) for ids in input_ids)
    
    padded_input_ids = []
    padded_attention_mask = []
    
    pad_id = 0  # tokenizer.pad_token_id
    
    for i in range(len(input_ids)):
        pad_len = max_len - len(input_ids[i])
        padded_input_ids.append(torch.cat([input_ids[i], torch.full((pad_len,), pad_id, dtype=torch.long)]))
        padded_attention_mask.append(torch.cat([attention_mask[i], torch.zeros(pad_len, dtype=torch.long)]))
    
    # Stack text data
    input_ids_batch = torch.stack(padded_input_ids)
    attention_mask_batch = torch.stack(padded_attention_mask)
    
    # Flatten video frames
    flattened_pixel_values = []
    flattened_image_flags = []
    
    for pv, flags in zip(pixel_values, image_flags):
        for frame_idx in range(pv.size(0)):
            flattened_pixel_values.append(pv[frame_idx])
            flattened_image_flags.append(flags[frame_idx] if frame_idx < len(flags) else torch.tensor(1))
    
    pixel_values_batch = torch.stack(flattened_pixel_values)
    image_flags_batch = torch.stack(flattened_image_flags)
    
    return input_ids_batch, attention_mask_batch, pixel_values_batch, None, image_flags_batch

# Create datasets
train_dataset = SimpleVideoDataset(
    data_list=dataset["train"],
    video_dir=f"{data_dir}/video/videos_unzipped/{video_dir}",
    tokenizer=tokenizer,
    image_size=224,
    is_train=True
)

eval_dataset = SimpleVideoDataset(
    data_list=dataset["test"],
    video_dir=f"{data_dir}/video/videos_unzipped/{video_dir}",
    tokenizer=tokenizer,
    image_size=224,
    is_train=False
)

print(f"Train dataset: {len(train_dataset)} examples")
print(f"Eval dataset: {len(eval_dataset)} examples")

# Dataset and collate functions based on template
import json
import torch
import numpy as np
import gc
import random
import os
import traceback
from copy import deepcopy
from typing import Dict  # used in the VideoQADataset.__getitem__ type hint
from torch.utils.data import Dataset
from PIL import Image, UnidentifiedImageError
import torchvision.transforms as transforms

def read_video_pyav(video_path, start, end, n_frames=8):
    """Read video frames using PyAV"""
    import av
    import bisect
    
    container = av.open(video_path)
    try:
        # Select the first video stream explicitly (stream 0 is not guaranteed to be video)
        video = container.streams.video[0]

        # Presentation timestamps of all video packets, converted to whole seconds
        av_timestamps = [
            int(packet.pts * video.time_base) for packet in container.demux(video) if packet.pts is not None
        ]

        av_timestamps.sort()
        start_id = bisect.bisect_left(av_timestamps, start)
        end_id = bisect.bisect_left(av_timestamps, end)

        # Widen very short windows so there are enough frames to sample from
        if end_id - start_id < 10:
            end_id += 10
            start_id -= 10

        end_id = min(len(av_timestamps) - 1, end_id)
        start_id = max(1, start_id)

        indices = np.linspace(start_id, end_id, n_frames).astype(int)

        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_id:
                break
            if i >= start_id and i in indices:
                frames.append(frame)
        assert len(frames) == n_frames, f"Got {len(frames)} frames but should be {n_frames}"
        return np.stack([x.to_ndarray(format="rgb24") for x in frames])
    finally:
        container.close()

def build_transform(is_train=True, input_size=224, pad2square=False, normalize_type='imagenet'):
    """Build image transforms (bicubic resize + normalization). `is_train` and `pad2square` are currently unused."""
    if normalize_type == 'imagenet':
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    else:
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    
    return transforms.Compose([
        transforms.Resize((input_size, input_size), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize
    ])

def preprocess_internvl2_5(template_name, sources, tokenizer, num_image_token_list, text_only=False, 
                         group_by_length=False, ds_name=None, num_image=1):
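    """Simplified preprocessing for InternVL-style training.

    Note: `template_name`, `num_image_token_list`, `text_only`, `group_by_length`
    and `ds_name` are accepted for API compatibility but ignored; the text is
    formatted manually with a Llama/Mistral-style `[INST]` layout, not InternVL's
    own chat template.
    """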
    conversations = sources[0]
    
    if len(conversations) >= 2:
        human_msg = conversations[0]['value']
        assistant_msg = conversations[1]['value']
        
        # Format for InternVL
        frame_tokens = '\n'.join([f'Frame{i+1}: <image>' for i in range(num_image)])
        human_msg = human_msg.replace('<video>', frame_tokens)
        
        text = f"<s>[INST] {human_msg} [/INST] {assistant_msg}</s>"
    else:
        human_msg = conversations[0]['value']
        human_msg = human_msg.replace('<video>', '')
        text = f"<s>[INST] {human_msg} [/INST]"
    
    # Tokenize
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=512,
        padding=False,
        return_tensors="pt"
    )
    
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    
    # Create labels: a full copy of input_ids, so the loss also covers the prompt tokens
    labels = input_ids.clone()
    
    return {
        'input_ids': input_ids,
        'labels': labels,
        'attention_mask': attention_mask
    }

class VideoQADataset(Dataset):
    """Dataset for Video QA fine-tuning."""

    def __init__(
        self,
        template_name,
        raw_data,
        video_data_dir,
        tokenizer,
        ds_name,
        num_image_token,
        image_size=224,
        is_train=True,
        pad2square=False,
        dynamic_image_size=False,
        use_thumbnail=False,
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        min_num_frame=8,  # for video data
        max_num_frame=8,  # for video data
        sampling_method='rand',  # for video data
        repeat_time=1,
        normalize_type='imagenet',
        random_seed=0,
    ):
        super(VideoQADataset, self).__init__()
        self.ds_name = ds_name
        self.tokenizer = tokenizer
        self.template_name = template_name
        self.num_image_token = num_image_token
        print(f'[Dataset] num_image_token: {num_image_token}')
        print(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
        print(f'[Dataset] use_thumbnail: {use_thumbnail}')
        print(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')

        self.image_size = image_size
        self.is_train = is_train
        self.pad2square = pad2square
        self.max_num_frame = max_num_frame
        self.min_num_frame = min_num_frame
        self.sampling_method = sampling_method

        print('Formatting inputs...Skip in lazy mode')

        self.raw_data = raw_data.shuffle(seed=random_seed) if hasattr(raw_data, 'shuffle') else raw_data

        gc.collect()
        self.root = video_data_dir
        self.cached_data_dict = {}
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.normalize_type = normalize_type

        gc.collect()

    def __len__(self):
        return len(self.raw_data)

    def get_preprocess_function(self):
        # Select the appropriate preprocessing function based on the template name
        return preprocess_internvl2_5
        
    def get_transform(self):
        # Build transformation function
        transform = build_transform(is_train=self.is_train, input_size=self.image_size,
                                    pad2square=self.pad2square, normalize_type=self.normalize_type)
        return transform

    def video_get_item(self, data_item):
        # Build transformation function
        transform = self.get_transform()

        # Ensure the first conversation contains a video placeholder
        if '<video>' not in data_item['conversations'][0]['value']:
            data_item['conversations'][0]['value'] = '<video>\n' + data_item['conversations'][0]['value']

        # Get the video file path
        video_file = data_item['video']
        video_path = os.path.join(self.root, video_file)
        
        # Use read_video_pyav instead of read_frames_decord
        image_list = read_video_pyav(video_path,
                                     start=data_item.get('start', 0),
                                     end=data_item.get('end', 1000000),
                                     n_frames=self.max_num_frame)
        
        # Convert numpy arrays to PIL Images
        image_list = [Image.fromarray(frame) for frame in image_list]
        
        # Generate special tokens for each video frame
        special_tokens = '\n'.join(['Frame{}: <image>'.format(i + 1) for i in range(len(image_list))])
        data_item['conversations'][0]['value'] = data_item['conversations'][0]['value'].replace(
            '<video>', special_tokens)

        # Transform each frame image and stack them into a tensor
        pixel_values = [transform(image) for image in image_list]
        pixel_values = torch.stack(pixel_values)
        num_patches = pixel_values.size(0)

        # Select the appropriate preprocessing function based on the template name
        preprocess_function = self.get_preprocess_function()

        # Preprocess the conversations and generate the return dictionary
        num_image_tokens = [self.num_image_token] * num_patches

        if self.is_train:
            # For training, use full conversation (human + assistant)
            sources = [deepcopy(data_item['conversations'])]
        else:
            # For evaluation, use only the human turn (prompt) so the answer is not leaked
            sources = [deepcopy(data_item['conversations'][:1])]
        
        ret = preprocess_function(template_name=self.template_name, 
                                  sources=sources,
                                  tokenizer=self.tokenizer,
                                  num_image_token_list=num_image_tokens,
                                  text_only=False,
                                  group_by_length=False,
                                  ds_name=self.ds_name, 
                                  num_image=num_patches
        )

        conv = data_item['conversations']
        
        # Extract the correct answer choice from the GPT response
        answer_choice = ''
        if len(conv) > 1 and 'value' in conv[1]:
            gpt_response = conv[1]['value']
            # Extract just the choice letter (A, B, C, D), i.e. the first non-space character
            answer_choice = gpt_response.strip()[:1]  # '' if the response is empty
            

        # Create the final return dictionary
        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=pixel_values,
            image_flags=torch.tensor([1] * num_patches, dtype=torch.long),
            answer_choice=answer_choice
        )

        return ret

    def pure_text_get_item(self, data_item):
        """Handle text-only examples"""
        # Simple text processing
        sources = [deepcopy(data_item['conversations'])]
        
        ret = preprocess_internvl2_5(template_name=self.template_name, 
                                    sources=sources,
                                    tokenizer=self.tokenizer,
                                    num_image_token_list=[0],
                                    text_only=True,
                                    group_by_length=False,
                                    ds_name=self.ds_name, 
                                    num_image=0)
        
        conv = data_item['conversations']
        answer_choice = ''
        if len(conv) > 1 and 'value' in conv[1]:
            answer_choice = conv[1]['value'].strip()[:1]  # '' if the response is empty

        return dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=torch.zeros(1, 3, self.image_size, self.image_size),
            image_flags=torch.tensor([0], dtype=torch.long),
            answer_choice=answer_choice
        )

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        i = i % len(self.raw_data)
        while True:
            try:
                data_item = self.raw_data[i]
                if isinstance(data_item, str):  # raw_data may be a list of JSON strings
                    data_item = json.loads(data_item)
                if 'image' in data_item and len(data_item['image']) != 0:
                    # For simplicity, handle images as text in this version
                    ret = self.pure_text_get_item(data_item)
                elif 'video' in data_item and data_item['video'] is not None and data_item['video'] != '':
                    ret = self.video_get_item(data_item)
                else:
                    ret = self.pure_text_get_item(data_item)
                break
            except Exception as e:
                print(e, self.ds_name, flush=True)
                if not isinstance(e, UnidentifiedImageError):
                    traceback.print_exc()
                data_item = self.raw_data[i]
                if isinstance(data_item, str):
                    data_item = json.loads(data_item)
                if 'image' in data_item:
                    if isinstance(data_item['image'], list):
                        images = [self.root + item for item in data_item['image']]
                        print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
                    else:
                        if data_item['image'].startswith('s3://'):
                            data_path = self.root + data_item['image']
                        else:
                            data_path = os.path.join(self.root, data_item['image'])
                        print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
                elif 'video' in data_item:
                    data_path = os.path.join(self.root, data_item['video'])
                    print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')
                i = random.randint(0, len(self.raw_data) - 1)
        return ret

# Create datasets
train_dataset = VideoQADataset(
    template_name='internlm2-chat',
    raw_data=dataset["train"],
    video_data_dir=f"{data_dir}/video/videos_unzipped/{video_dir}",
    tokenizer=tokenizer,
    ds_name="mvbench_train",
    num_image_token=256,
    image_size=224,
    is_train=True,
    min_num_frame=8,
    max_num_frame=8
)

eval_dataset = VideoQADataset(
    template_name='internlm2-chat',
    raw_data=dataset["test"],
    video_data_dir=f"{data_dir}/video/videos_unzipped/{video_dir}",
    tokenizer=tokenizer,
    ds_name="mvbench_eval",
    num_image_token=256,
    image_size=224,
    is_train=False,
    min_num_frame=8,
    max_num_frame=8
)

print(f"Train dataset: {len(train_dataset)} examples")
print(f"Eval dataset: {len(eval_dataset)} examples")
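The datasets above pair `num_image_token=256` with `image_size=224`. In InternVL's chat model the number of visual tokens per image tile follows from the ViT patch size and the pixel-shuffle downsample ratio. The sketch below assumes patch size 14 and ratio 0.5 (the values we believe InternVL3.5 uses for its InternViT-300M encoder); check them against `model.config` before relying on the numbers. `internvl_num_image_tokens` is a hypothetical helper, not part of any library.

```python
def internvl_num_image_tokens(image_size: int, patch_size: int = 14, downsample_ratio: float = 0.5) -> int:
    """Visual tokens per image tile, assuming InternVL's formula
    (image_size // patch_size) ** 2 * downsample_ratio ** 2 (pixel shuffle merges 2x2 patches)."""
    patches_per_side = image_size // patch_size
    return int(patches_per_side ** 2 * downsample_ratio ** 2)


for size in (224, 448):
    print(size, internvl_num_image_tokens(size))  # 224 -> 64, 448 -> 256
print(8 * internvl_num_image_tokens(448))  # 2048 tokens for 8 frames at 448 px
```

With 8 frames at 448 px the visual tokens alone come to 8 × 256 = 2048, which matters when choosing the tokenizer `max_length` (512 above) once image placeholders are expanded into real tokens.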

In [None]:
# Collate functions based on template
def train_collate_fn(examples):
    """
    Collate function for training that batches examples together.
    Handles padding for text and stacks video frames.
    """
    # Extract all components from examples
    input_ids = [example['input_ids'] for example in examples]
    attention_mask = [example['attention_mask'] for example in examples]
    pixel_values = [example['pixel_values'] for example in examples]
    image_flags = [example['image_flags'] for example in examples]
    labels = [example['labels'] for example in examples]
    answer_choices = [example['answer_choice'] for example in examples]

    # Pad input_ids, attention_mask, and labels to the same length
    max_length = max(len(ids) for ids in input_ids)
    
    padded_input_ids = []
    padded_attention_mask = []
    padded_labels = []
    
    for i in range(len(input_ids)):
        # Pad input_ids
        pad_length = max_length - len(input_ids[i])
        padded_input_ids.append(
            torch.cat([
                input_ids[i],
                torch.full((pad_length,), tokenizer.pad_token_id, dtype=torch.long)
            ])
        )
        
        # Pad attention_mask
        padded_attention_mask.append(
            torch.cat([
                attention_mask[i],
                torch.zeros(pad_length, dtype=torch.long)
            ])
        )
        
        # Pad labels (using -100 for ignore index)
        padded_labels.append(
            torch.cat([
                labels[i],
                torch.full((pad_length,), -100, dtype=torch.long)
            ])
        )

    # Stack all tensors
    input_ids_batch = torch.stack(padded_input_ids)
    attention_mask_batch = torch.stack(padded_attention_mask)
    labels_batch = torch.stack(padded_labels)
    
    # Handle pixel_values - FLATTEN the video frames
    # Instead of padding frames, we'll process each frame as a separate image
    flattened_pixel_values = []
    flattened_image_flags = []
    
    for i, (pv, flags) in enumerate(zip(pixel_values, image_flags)):
        # pv shape: (num_frames, channels, height, width)
        # We need to flatten this to (total frames in the batch, channels, height, width)
        num_frames = pv.size(0)
        for frame_idx in range(num_frames):
            flattened_pixel_values.append(pv[frame_idx])
            # Handle image_flags properly
            if frame_idx < len(flags):
                flattened_image_flags.append(flags[frame_idx])
            else:
                flattened_image_flags.append(torch.tensor(1))
    
    # Stack flattened pixel values
    if flattened_pixel_values:
        pixel_values_batch = torch.stack(flattened_pixel_values)
        image_flags_batch = torch.stack(flattened_image_flags)
    else:
        pixel_values_batch = torch.tensor([])
        image_flags_batch = torch.tensor([])
    
    return {
        'input_ids': input_ids_batch,
        'attention_mask': attention_mask_batch,
        'pixel_values': pixel_values_batch,
        'labels': labels_batch,
        'image_flags': image_flags_batch,
        'answer_choices': answer_choices
    }


def eval_collate_fn(examples):
    """
    Collate function for evaluation that batches examples together.
    Similar to training but doesn't need labels for generation.
    Prompts are LEFT-padded so that every prompt in the batch ends right
    before the first generated token (decoder-only generation).
    """
    # Extract all components from examples
    input_ids = [example['input_ids'] for example in examples]
    attention_mask = [example['attention_mask'] for example in examples]
    pixel_values = [example['pixel_values'] for example in examples]
    answer_choices = [example['answer_choice'] for example in examples]

    # Left-pad input_ids and attention_mask to the same length
    max_length = max(len(ids) for ids in input_ids)
    
    padded_input_ids = []
    padded_attention_mask = []
    
    for i in range(len(input_ids)):
        pad_length = max_length - len(input_ids[i])
        # Pad input_ids on the left
        padded_input_ids.append(
            torch.cat([
                torch.full((pad_length,), tokenizer.pad_token_id, dtype=torch.long),
                input_ids[i]
            ])
        )
        
        # Pad attention_mask on the left (0 marks padding)
        padded_attention_mask.append(
            torch.cat([
                torch.zeros(pad_length, dtype=torch.long),
                attention_mask[i]
            ])
        )

    # Stack all tensors
    input_ids_batch = torch.stack(padded_input_ids)
    attention_mask_batch = torch.stack(padded_attention_mask)
    
    # Handle pixel_values - FLATTEN the video frames
    flattened_pixel_values = []
    
    for i, pv in enumerate(pixel_values):
        # pv shape: (num_frames, channels, height, width)
        num_frames = pv.size(0)
        for frame_idx in range(num_frames):
            flattened_pixel_values.append(pv[frame_idx])
    
    # Stack flattened pixel values
    if flattened_pixel_values:
        pixel_values_batch = torch.stack(flattened_pixel_values)
    else:
        pixel_values_batch = torch.tensor([])
    
    return {
        'input_ids': input_ids_batch,
        'attention_mask': attention_mask_batch,
        'pixel_values': pixel_values_batch,
        'answer_choices': answer_choices
    }

# Simple collate functions for backward compatibility
def simple_train_collate_fn(batch):
    """Wrapper for train_collate_fn to maintain compatibility"""
    result = train_collate_fn(batch)
    return result['input_ids'], result['attention_mask'], result['pixel_values'], result['labels'], result['image_flags']

def simple_eval_collate_fn(batch):
    """Wrapper for eval_collate_fn to maintain compatibility"""
    result = eval_collate_fn(batch)
    return result['input_ids'], result['attention_mask'], result['pixel_values'], None, None

print("Collate functions created")
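Both dataset classes above build `labels` as a full copy of `input_ids`, so the training loss also covers the question, the answer options and the frame placeholders. A common alternative in instruction tuning is to supervise only the answer tokens. A minimal sketch, assuming you already know how many leading tokens belong to the prompt (`mask_prompt_labels` is a hypothetical helper):

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_prompt_labels(input_ids: Sequence[int], prompt_len: int) -> List[int]:
    """Return labels equal to input_ids, except the first prompt_len positions are ignored."""
    if not 0 <= prompt_len <= len(input_ids):
        raise ValueError("prompt_len must be between 0 and len(input_ids)")
    return [IGNORE_INDEX] * prompt_len + list(input_ids[prompt_len:])


# Hypothetical token ids: 5 prompt tokens followed by 2 answer tokens
ids = [101, 7, 8, 9, 102, 42, 2]
print(mask_prompt_labels(ids, prompt_len=5))  # [-100, -100, -100, -100, -100, 42, 2]
```

In practice `prompt_len` could come from tokenizing the prompt on its own with the same tokenizer settings; check that the tokenizer does not merge tokens across the prompt/answer boundary.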

In [20]:
%matplotlib inline

from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML


example = dataset['train'][0]
# Read the video for visualization
video_path = f"{data_dir}/video/videos_unzipped/{video_dir}/{example['video']}"
clip = read_video_pyav(video_path, example.get('start', 1), example.get('end', 1e+10), n_frames=8)

# np array with shape (frames, height, width, channels)
video = np.array(clip)

print(f"Video shape: {video.shape}")
print(f"Question: {example['question']}")
print(f"Answer: {example['answer']}")
print(f"Candidates: {example['candidates']}")

Video shape: (8, 270, 480, 3)
Question: What happened after the person held the clothes?
Answer: Threw the pillow.
Candidates: ['Took the broom.', 'Threw the pillow.', 'Sat on the sofa/couch.', 'Put down the pillow.']


In [21]:
fig = plt.figure()
im = plt.imshow(video[0,:,:,:])
plt.axis("off")

plt.close() # prevents the static figure from being displayed in addition to the animation

def init():
    im.set_data(video[0,:,:,:])

def animate(i):
    im.set_data(video[i,:,:,:])
    return im

anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],
                               interval=300)
HTML(anim.to_html5_video())

Now wrap the data in PyTorch `Dataset` objects and print one example as a sanity check.

In [22]:
train_raw_data = [json.dumps(example) for example in dataset["train"]]
eval_raw_data = [json.dumps(example) for example in dataset["test"]]

tokenizer = getattr(processor, "tokenizer", None)
if tokenizer is None:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=True)

train_dataset = VideoQADataset(
    template_name='internlm2-chat',
    raw_data=train_raw_data,
    video_data_dir=f"{data_dir}/video/videos_unzipped/{video_dir}",
    tokenizer=tokenizer,
    ds_name='mvbench_action_sequence',
    num_image_token=256,
    image_size=448,
    is_train=True,
    pad2square=False,
    dynamic_image_size=True,
    use_thumbnail=False,
    min_dynamic_patch=1,
    max_dynamic_patch=6,
    min_num_frame=8,
    max_num_frame=8,
    sampling_method='uniform',
    normalize_type='imagenet',
    random_seed=42,
)

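# VideoQADatasetWithAnswers is only defined in a commented-out cell;
# the base VideoQADataset already returns the ground-truth `answer_choice`.
eval_dataset = VideoQADataset(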
    template_name='internlm2-chat',
    raw_data=eval_raw_data,
    video_data_dir=f"{data_dir}/video/videos_unzipped/{video_dir}",
    tokenizer=tokenizer,
    ds_name='mvbench_action_sequence',
    num_image_token=256,
    image_size=448,
    is_train=False,
    pad2square=False,
    dynamic_image_size=True,
    use_thumbnail=False,
    min_dynamic_patch=1,
    max_dynamic_patch=6,
    min_num_frame=8,
    max_num_frame=8,
    sampling_method='uniform',
    normalize_type='imagenet',
    random_seed=42,
)

print(f"Train dataset: {len(train_dataset)} examples")
print(f"Test dataset: {len(eval_dataset)} examples")

[Dataset] num_image_token: 256
[Dataset] dynamic_image_size: True
[Dataset] use_thumbnail: False
[Dataset] min_dynamic_patch: 1, max_dynamic_patch: 6
Formatting inputs...Skip in lazy mode
[Dataset] num_image_token: 256
[Dataset] dynamic_image_size: True
[Dataset] use_thumbnail: False
[Dataset] min_dynamic_patch: 1, max_dynamic_patch: 6
Formatting inputs...Skip in lazy mode
Train dataset: 150 examples
Test dataset: 38 examples


In [23]:
# Display a sample from the training dataset
sample_data = train_dataset[0]

print("Keys in sample:", sample_data.keys())
print("Input IDs shape:", sample_data['input_ids'].shape)
print("Pixel values shape:", sample_data['pixel_values'].shape)
print("Labels shape:", sample_data['labels'].shape)

raw_example = json.loads(train_dataset.raw_data[0])
print("\nOriginal conversation:", raw_example['conversations'])
print("Video file:", raw_example['video'])

Keys in sample: dict_keys(['input_ids', 'labels', 'attention_mask', 'pixel_values', 'image_flags'])
Input IDs shape: torch.Size([14588])
Pixel values shape: torch.Size([4, 3, 448, 448])
Labels shape: torch.Size([14588])

Original conversation: [{'from': 'human', 'value': 'What happened before the person tidied up the clothes?\nA. Opened the bag.;\nB. Took the sandwich.;\nC. Took the blanket.;\nD. Sat on the floor.;\n<video>'}, {'from': 'gpt', 'value': 'C Took the blanket.'}]
Video file: D1NT7.mp4
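The printed conversation still contains the literal `<video>` placeholder, and the dataset code replaces it with `Frame{i}: <image>` lines inside `[INST]` markers. As far as we can tell from InternVL's remote modeling code, its own `chat()` instead expands every `<image>` into `<img>`, then `num_image_token` copies of `<IMG_CONTEXT>`, then `</img>`, and InternVL3.5's Qwen3 backbone uses ChatML-style turns. The sketch below reproduces that layout with plain strings; the helper names are hypothetical and the chat layout is an assumption to verify against the tokenizer's chat template.

```python
# Special-token strings as used in InternVL's modeling code
IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"


def expand_video_placeholder(question: str, num_frames: int, num_image_token: int) -> str:
    """Replace the first <video> with one 'FrameN: <img>...</img>' line per frame."""
    image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token + IMG_END_TOKEN
    frame_lines = [f"Frame{i + 1}: {image_tokens}" for i in range(num_frames)]
    return question.replace("<video>", "\n".join(frame_lines), 1)


def build_chat_prompt(user_message: str, system_message: str = "") -> str:
    """ChatML-style turns (<|im_start|> / <|im_end|>), assumed here for the Qwen3 backbone."""
    prompt = ""
    if system_message:
        prompt += f"<|im_start|>system\n{system_message}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n"
    return prompt


question = "<video>\nWhat happened before the person tidied up the clothes?"
user_msg = expand_video_placeholder(question, num_frames=8, num_image_token=256)
prompt = build_chat_prompt(user_msg)
print(prompt.count(IMG_CONTEXT_TOKEN))  # 2048 = 8 frames x 256 tokens
```

For training, the answer and a closing `<|im_end|>` would follow the assistant marker. InternVL's `forward` locates visual positions by comparing `input_ids` with `model.img_context_token_id`, which `chat()` sets via `tokenizer.convert_tokens_to_ids('<IMG_CONTEXT>')`; in a custom training loop you would likely set it yourself, and the number of `<IMG_CONTEXT>` tokens must equal frames × `num_image_token`.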


# Model

## Load model
Next, load your InternVL model from the Hub. InternVL3.5-1B has roughly 1 billion parameters in total: it combines a **Qwen3-0.6B language model** with the relatively small **InternViT-300M** vision encoder. Note that this checkpoint has already undergone supervised fine-tuning (SFT) on instruction data, so we start from a model that already follows chat-style instructions and can build on that.

## Full fine-tuning, LoRA and QLoRA

**Select the fine-tuning method.**

For reference, fine-tuning a model with the AdamW optimizer (a common choice for neural networks) in mixed precision needs roughly 18 bytes of GPU memory per parameter, before counting activations. For a ~1 billion parameter model that is 18 × 10^9 bytes ≈ 18 GB of GPU RAM to update every parameter. Not huge, but a parameter-efficient fine-tuning (PEFT) approach can bring it down a lot further.
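The 18 bytes per parameter follow the breakdown for mixed-precision AdamW in the Hugging Face "Model training anatomy" guide. A quick back-of-the-envelope check (activations, CUDA context and temporary buffers are not included; `full_finetune_gb` is a hypothetical helper):

```python
# Bytes per parameter for mixed-precision AdamW training
# (per the Hugging Face "Model training anatomy" guide; activations excluded)
BYTES_PER_PARAM = {
    "fp32 weights": 4,
    "fp16/bf16 weight copy": 2,
    "fp32 gradients": 4,
    "AdamW states (2 x fp32)": 8,
}


def full_finetune_gb(num_params: float) -> float:
    """Approximate GPU memory (decimal GB) to fully fine-tune num_params parameters."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


print(sum(BYTES_PER_PARAM.values()))  # 18
print(full_finetune_gb(1e9))          # 18.0
```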

Some clever people came up with the LoRA method (short for Low-Rank Adaptation). It freezes the existing weights and trains only a small number of low-rank adapter matrices added alongside selected layers of the base model. Hugging Face offers the separate [PEFT library](https://huggingface.co/docs/peft/main/en/index) for easy use of LoRA, along with other Parameter-Efficient Fine-Tuning methods.
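A minimal NumPy sketch of the LoRA idea for a single linear layer: the frozen weight `W` is left untouched and a scaled low-rank product `B @ A` is added to its output. `B` starts at zero, so the adapted layer initially behaves exactly like the base layer. The class name and the normal initialization of `A` are illustrative choices, not PEFT's implementation.

```python
import numpy as np


class LoRALinear:
    """Frozen linear layer W plus a trainable low-rank update: y = x W^T + (alpha / r) * x A^T B^T."""

    def __init__(self, weight: np.ndarray, r: int = 16, alpha: float = 32.0, seed: int = 0):
        d_out, d_in = weight.shape
        rng = np.random.default_rng(seed)
        self.weight = weight                              # frozen base weight, shape (d_out, d_in)
        self.A = rng.normal(0.0, 0.01, size=(r, d_in))   # trainable down-projection, shape (r, d_in)
        self.B = np.zeros((d_out, r))                     # trainable up-projection, zero-initialized
        self.scaling = alpha / r

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, d_in) -> (batch, d_out)
        return x @ self.weight.T + self.scaling * (x @ self.A.T) @ self.B.T

    def merged_weight(self) -> np.ndarray:
        # Fold the low-rank update into a single dense weight for inference
        return self.weight + self.scaling * self.B @ self.A


rng = np.random.default_rng(1)
W = rng.normal(size=(64, 128))
layer = LoRALinear(W, r=16, alpha=32.0)
x = rng.normal(size=(4, 128))
print(np.allclose(layer(x), x @ W.T))                      # True: B = 0 at initialization
layer.B = rng.normal(size=(64, 16))                        # pretend training updated B
print(np.allclose(layer(x), x @ layer.merged_weight().T))  # True
print(layer.A.size + layer.B.size, W.size)                 # 3072 8192
```

Only `A` and `B` are trained, i.e. `r * (d_in + d_out)` parameters instead of `d_in * d_out`, and after training the update can be merged back into `W`, as `merged_weight` shows.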

Moreover, we can not only freeze the existing base model but also quantize it (that is, store its weights with fewer bits). A neural network's parameters are typically stored in float32 (32 bits, or 4 bytes, per parameter) or float16/bfloat16 (16 bits, or 2 bytes, also called half precision). With careful quantization algorithms, each parameter can be shrunk to just 8 or 4 bits (half a byte!) with only a small effect on final performance. Read all about it here: https://huggingface.co/blog/4bit-transformers-bitsandbytes.
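To make the idea concrete, here is a toy blockwise absmax quantizer in NumPy. It maps each block of 64 weights to signed 4-bit integer codes in [-7, 7] with one float32 scale per block. This is a deliberately simplified stand-in: bitsandbytes' NF4 uses non-uniform NormalFloat levels and packs two 4-bit values per byte, whereas this sketch uses a uniform grid and stores codes in `int8` for readability. The function names are hypothetical.

```python
import numpy as np


def quantize_absmax_4bit(w: np.ndarray, block_size: int = 64):
    """Toy blockwise absmax quantization to signed 4-bit codes in [-7, 7].

    Returns the integer codes, one scale per block, and what is needed to undo padding.
    """
    flat = w.reshape(-1)
    pad = (-flat.size) % block_size                       # pad so the length divides block_size
    flat = np.concatenate([flat, np.zeros(pad, dtype=flat.dtype)])
    blocks = flat.reshape(-1, block_size)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 7.0
    scales[scales == 0] = 1.0                             # all-zero block: any scale works
    codes = np.clip(np.round(blocks / scales), -7, 7).astype(np.int8)
    return codes, scales, w.shape, pad


def dequantize_absmax_4bit(codes, scales, shape, pad):
    flat = (codes.astype(np.float32) * scales).reshape(-1)
    if pad:
        flat = flat[:-pad]                                # drop the padding added above
    return flat.reshape(shape)


w = np.random.default_rng(0).normal(size=(256, 64)).astype(np.float32)
codes, scales, shape, pad = quantize_absmax_4bit(w)
w_hat = dequantize_absmax_4bit(codes, scales, shape, pad)
print(codes.min(), codes.max())          # codes stay within [-7, 7]
print(float(np.abs(w - w_hat).max()))    # worst-case rounding error
print(4 + 32 / 64)                       # 4.5 bits per weight incl. fp32 scales
```

With one float32 scale per 64 weights, the scales add 32 / 64 = 0.5 bits per weight, giving about 4.5 bits per parameter in this sketch.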

This means we are going to shrink the base 1B model considerably using 4-bit quantization, and then train only a few adapter layers on top using LoRA in higher precision (computation here runs in bfloat16). Combining LoRA with quantization is called QLoRA, and it is the most memory-friendly of the three options.
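For a sense of scale, the sketch below computes how much memory the weights alone occupy at each precision for a nominal 1B-parameter model. It counts weights only; quantization constants and any layers left unquantized add a little on top, and the helper name is illustrative.

```python
def weight_memory_gb(num_params: int, bits_per_param: float) -> float:
    """Memory for storing the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

n = 1_000_000_000  # nominal 1B parameters
for name, bits in [("float32", 32), ("bfloat16/float16", 16), ("int8", 8), ("4-bit (NF4)", 4)]:
    print(f"{name:>18}: {weight_memory_gb(n, bits):.2f} GB")
```

Going from 16-bit to 4-bit storage cuts the frozen base weights by a factor of four (about 2 GB down to about 0.5 GB), while the small LoRA adapters stay in higher precision.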

Many forms of quantization exist; here we use the [BitsAndBytes integration](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig).

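Loading this checkpoint with the installed `transformers` version fails, because the model's remote code tries to import `Qwen3Config`, which that version does not provide: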
```
--> 526 config, kwargs = AutoConfig.from_pretrained(
    527     pretrained_model_name_or_path,
    528     return_unused_kwargs=True,
    529     trust_remote_code=trust_remote_code,
    530     code_revision=code_revision,
    531     _commit_hash=commit_hash,
    532     **hub_kwargs,
    533     **kwargs,
    534 )
    536 # if torch_dtype=auto was passed here, ensure to pass it on
    537 if kwargs_orig.get("torch_dtype", None) == "auto":

File /opt/conda/lib/python3.11/site-packages/transformers/models/auto/configuration_auto.py:1035, in AutoConfig.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
   1033     if os.path.isdir(pretrained_model_name_or_path):
   1034         config_class.register_for_auto_class()
-> 1035     return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
   1036 elif "model_type" in config_dict:
   1037     try:

File /opt/conda/lib/python3.11/site-packages/transformers/configuration_utils.py:568, in PretrainedConfig.from_pretrained(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)
    562     if config_dict["model_type"] != cls.model_type:
    563         logger.warning(
    564             f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
    565             f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
    566         )
--> 568 return cls.from_dict(config_dict, **kwargs)

File /opt/conda/lib/python3.11/site-packages/transformers/configuration_utils.py:734, in PretrainedConfig.from_dict(cls, config_dict, **kwargs)
    731 # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
    732 config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
--> 734 config = cls(**config_dict)
    736 if hasattr(config, "pruned_heads"):
    737     config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

File ~/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL3_5-1B/2f71cf52542334823e48a46ffba0e2bc9add3446/configuration_internvl_chat.py:67, in InternVLChatConfig.__init__(self, vision_config, llm_config, use_backbone_lora, use_llm_lora, select_layer, force_image_size, downsample_ratio, template, dynamic_image_size, use_thumbnail, ps_version, min_dynamic_patch, max_dynamic_patch, **kwargs)
     65     self.llm_config = Qwen3MoeConfig(**llm_config)
     66 elif architecture == 'Qwen3ForCausalLM':
---> 67     from transformers import Qwen3Config
     68     self.llm_config = Qwen3Config(**llm_config)
     69 else:

ImportError: cannot import name 'Qwen3Config' from 'transformers' (/opt/conda/lib/python3.11/site-packages/transformers/__init__.py)
```
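As a workaround, register a stand-in `Qwen3Config` that subclasses `Qwen2Config`, so that the remote configuration code can be imported. Keep in mind that this only makes loading possible: together with the architecture swap in the loading cell, the language model is built as a Qwen2 model, which has no `q_norm`/`k_norm` layers, so those checkpoint weights are discarded (see the "Some weights ... were not used" warning when the model is loaded). Installing a `transformers` release with native Qwen3 support avoids the workaround altogether.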

In [24]:
import transformers
from transformers import PreTrainedModel, PretrainedConfig

if not hasattr(transformers, "Qwen3Config"):
    try:
        from transformers import Qwen2Config as _Base
    except ImportError:
        from transformers import PretrainedConfig as _Base

    class Qwen3Config(_Base):
        model_type = "qwen3"
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

    transformers.Qwen3Config = Qwen3Config

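Similarly, register placeholder `Qwen3ForCausalLM` and `Qwen3MoeForCausalLM` classes. They only need to exist so that the remote modeling code's imports succeed; the architecture swap in the next cell makes the model build its language model from `Qwen2ForCausalLM` instead: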

In [25]:
if not hasattr(transformers, "Qwen3ForCausalLM"):
    class Qwen3ForCausalLM(PreTrainedModel):
        def __init__(self, config): super().__init__(config)
    transformers.Qwen3ForCausalLM = Qwen3ForCausalLM

if not hasattr(transformers, "Qwen3MoeForCausalLM"):
    class Qwen3MoeForCausalLM(PreTrainedModel):
        def __init__(self, config): super().__init__(config)
    transformers.Qwen3MoeForCausalLM = Qwen3MoeForCausalLM

In [26]:
## Load model
# Three options for training, from lowest to highest precision:
# QLoRA: the base model is loaded with 4-bit quantization, which greatly reduces memory usage while largely preserving performance.
# Standard LoRA: the base model is loaded in bfloat16 and LoRA adapters are trained on top.
# Full fine-tuning: no memory optimizations; Flash Attention is used to speed up training if the hardware supports it.

# We use QLoRA as it is the most memory-efficient approach.
# This allows fine-tuning on consumer GPUs while maintaining good performance.
import torch
from transformers import AutoConfig, AutoModel, BitsAndBytesConfig

cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)

# Workaround: build the LLM as Qwen2ForCausalLM, since this transformers version has no Qwen3 implementation
if hasattr(cfg, "llm_config"):
    llm_arch = (getattr(cfg.llm_config, "architectures", None) or [""])[0]
    if llm_arch == "Qwen3ForCausalLM":
        cfg.llm_config.architectures[0] = "Qwen2ForCausalLM"

if USE_QLORA:
    # Configure 4-bit quantization using bitsandbytes
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,  # Nested quantization for additional memory savings
        bnb_4bit_quant_type="nf4",  # NormalFloat 4-bit quantization
        bnb_4bit_compute_dtype=torch.bfloat16  # Compute in bfloat16 for better performance
    )
    
    # Load model with 4-bit quantization
    model = AutoModel.from_pretrained(
        MODEL_ID,
        config=cfg,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="cuda:0"
    )
    print("Model loaded with QLoRA (4-bit quantization)")
    
elif USE_LORA:
    # Load model in bfloat16 without quantization
    model = AutoModel.from_pretrained(
        MODEL_ID,
        config=cfg,  # patched config (Qwen3 -> Qwen2 workaround)
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        use_flash_attn=False
    )
    print("Model loaded for standard LoRA")
    
else:
    # Full fine-tuning - load model without any optimizations
    model = AutoModel.from_pretrained(
        MODEL_ID,
        config=cfg,  # patched config (Qwen3 -> Qwen2 workaround)
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        attn_implementation="flash_attention_2"  # Use Flash Attention if supported
    )
    print("Model loaded for full fine-tuning")

# Print model info
print(f"Model type: {type(model)}")
print(f"Model device: {model.device}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

Some weights of the model checkpoint at OpenGVLab/InternVL3_5-1B were not used when initializing InternVLChatModel: ['language_model.model.layers.0.self_attn.k_norm.weight', 'language_model.model.layers.0.self_attn.q_norm.weight', 'language_model.model.layers.1.self_attn.k_norm.weight', 'language_model.model.layers.1.self_attn.q_norm.weight', 'language_model.model.layers.10.self_attn.k_norm.weight', 'language_model.model.layers.10.self_attn.q_norm.weight', 'language_model.model.layers.11.self_attn.k_norm.weight', 'language_model.model.layers.11.self_attn.q_norm.weight', 'language_model.model.layers.12.self_attn.k_norm.weight', 'language_model.model.layers.12.self_attn.q_norm.weight', 'language_model.model.layers.13.self_attn.k_norm.weight', 'language_model.model.layers.13.self_attn.q_norm.weight', 'language_model.model.layers.14.self_attn.k_norm.weight', 'language_model.model.layers.14.self_attn.q_norm.weight', 'language_model.model.layers.15.self_attn.k_norm.weight', 'language_model.m

Model loaded with QLoRA (4-bit quantization)
Model type: <class 'transformers_modules.OpenGVLab.InternVL3_5-1B.2f71cf52542334823e48a46ffba0e2bc9add3446.modeling_internvl_chat.InternVLChatModel'>
Model device: cuda:0
Total parameters: 687,130,624


In [27]:
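from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

def find_all_linear_names(model):
    # Only used for LoRA or QLoRA: collect the leaf names of all linear layers as LoRA targets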

    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['multi_modal_projector', 'vision_model']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
    

target_modules = find_all_linear_names(model.language_model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=target_modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

if USE_QLORA:
    model = prepare_model_for_kbit_training(model)

model = get_peft_model(model, lora_config)

model.print_trainable_parameters()

trainable params: 8,716,288 || all params: 1,069,664,256 || trainable%: 0.8149
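PEFT matches each entry of a `target_modules` list against the end of a module's full name, which is why `find_all_linear_names` only needs leaf names such as `q_proj`. The pure-Python sketch below mimics that filtering and a simplified suffix match on a hypothetical list of `(name, is_linear)` pairs, so it runs without loading a model; `leaf_linear_names` and `matches_target` are illustrative helpers, not PEFT functions.

```python
def leaf_linear_names(named_modules, skip_keywords=("multi_modal_projector", "vision_model")):
    """Collect unique leaf names of linear layers, skipping multimodal modules and lm_head."""
    names = set()
    for path, is_linear in named_modules:
        if not is_linear or any(keyword in path for keyword in skip_keywords):
            continue
        names.add(path.split(".")[-1])
    names.discard("lm_head")
    return sorted(names)

def matches_target(path, target_modules):
    """Simplified suffix match: 'q_proj' selects 'model.layers.3.self_attn.q_proj'."""
    return any(path == t or path.endswith("." + t) for t in target_modules)

# Hypothetical (path, is_linear) pairs standing in for model.named_modules()
modules = [
    ("model.layers.0.self_attn.q_proj", True),
    ("model.layers.0.self_attn.k_proj", True),
    ("model.layers.0.mlp.down_proj", True),
    ("model.layers.0.input_layernorm", False),
    ("model.layers.1.self_attn.q_proj", True),
    ("vision_model.encoder.layers.0.attn.qkv", True),
    ("lm_head", True),
]
targets = leaf_linear_names(modules)
print(targets)  # ['down_proj', 'k_proj', 'q_proj']
print([p for p, _ in modules if matches_target(p, targets)])
```

Because matching is by suffix, a single leaf name selects the same projection in every decoder layer, while the vision encoder and `lm_head` stay frozen.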


## Define PyTorch Lightning Module for InternVL
To streamline the training and evaluation of the InternVL model on video, we use a [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html), which abstracts away much of the boilerplate code and provides a structured framework for model training. In this section, we define `InternVLModelPLModule`, a custom PyTorch Lightning module that encapsulates the model, the training loop, the validation loop, and the optimizer configuration.

### InternVLModelPLModule Class

The `InternVLModelPLModule` class inherits from `LightningModule` and implements methods for training, validation, and optimizer configuration, which keeps the training process clean and efficient.

PyTorch Lightning takes care of device placement (`.to(device)`), the backward pass, switching the model between training and evaluation mode, and so on.

Notice the difference between a training step and an evaluation step:

- a training step only consists of a forward pass, in which we compute the cross-entropy loss between the model's next token predictions and the ground truth (in parallel for all tokens, this technique is known as "teacher forcing"). The backward pass is handled by PyTorch Lightning.
- an evaluation step consists of making the model autoregressively complete the prompt using the generate() method. After that, you compute an evaluation metric between the predicted sequences and the ground truth ones. This allows you to see how the model is improving over the course of training. The metric we use here is accuracy of answering the question.
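A minimal sketch of the accuracy metric, assuming the ground truth is a single letter A to D and the model answers in a form like `"C Took the blanket."` (as in the sample conversation above). It takes the first standalone capital letter A to D; scanning characters case-insensitively would wrongly match the "a" in words like "answer". The helper names are illustrative, and the rule can still misfire on a leading article "A".

```python
import re

CHOICE_PATTERN = re.compile(r"\b([A-D])\b")

def extract_choice(answer: str) -> str:
    """Return the first standalone capital letter A-D, or '' if none is found."""
    match = CHOICE_PATTERN.search(answer)
    return match.group(1) if match else ""

def accuracy(generated_answers, ground_truth):
    """Fraction of answers whose extracted letter equals the ground-truth letter."""
    predictions = [extract_choice(a) for a in generated_answers]
    correct = sum(p == gt for p, gt in zip(predictions, ground_truth))
    return correct / len(ground_truth) if ground_truth else 0.0

print(extract_choice("C Took the blanket."))  # C
print(extract_choice("The answer is (B)."))   # B
print(accuracy(["C Took the blanket.", "The answer is (B).", "and so on"], ["C", "B", "A"]))
```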

Besides that, we define the optimizer (AdamW is a good default choice) and the data loaders, which use the collate functions defined above to batch together items of the PyTorch datasets. Note that AdamW is fairly heavy in terms of memory, since it keeps two extra values per trainable parameter, but as we're training with QLoRA we only need to store optimizer states for the adapter layers. For full fine-tuning, consider more memory-friendly optimizers such as 8-bit Adam.
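To quantify the saving, the sketch below estimates AdamW's state memory (two moment buffers per trainable parameter), assuming the adapter weights are kept in float32 so each buffer takes 4 bytes per parameter. The trainable-parameter count is the one reported by `print_trainable_parameters()` above; the 1B figure is a nominal full-model comparison.

```python
ADAMW_STATE_BYTES = 8  # exp_avg + exp_avg_sq, 4 bytes each (assumes float32 parameters)

def adamw_state_mib(trainable_params: int) -> float:
    """Optimizer state memory in MiB for the given number of trainable parameters."""
    return trainable_params * ADAMW_STATE_BYTES / 2**20

lora_params = 8_716_288      # trainable params reported for the LoRA adapters
full_params = 1_000_000_000  # nominal full model, for comparison
print(f"LoRA adapters: {adamw_state_mib(lora_params):.1f} MiB")  # LoRA adapters: 66.5 MiB
print(f"Full model:    {adamw_state_mib(full_params) / 1024:.2f} GiB")
```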

In [None]:
class InternVLModelPLModule(L.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        
        self.lr = config.get("lr", 1e-4)
        self.batch_size = config.get("batch_size", 1)
        self.max_epochs = config.get("max_epochs", 2)
        self.warmup_steps = config.get("warmup_steps", 50)
        
        self.validation_outputs = []
        
    def training_step(self, batch, batch_idx):
        """
        Training step: forward pass and loss computation
        """
        # FIX: Updated to handle image_flags
        input_ids, attention_mask, pixel_values_videos, labels, image_flags = batch
        
        # Reshape pixel_values for model [batch*frames, C, H, W]
        if pixel_values_videos.ndim == 5:
            batch_size, num_frames, c, h, w = pixel_values_videos.shape
            pixel_values_videos = pixel_values_videos.reshape(batch_size * num_frames, c, h, w)
            image_flags = image_flags.reshape(batch_size * num_frames)
        
        # Forward pass
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values_videos,
            labels=labels,
            image_flags=image_flags  # FIX: Pass image_flags to model
        )
        
        loss = outputs.loss
        
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        """
        Validation step: generate answers and compute accuracy
        """
        if batch is None:
            return
            
        # FIX: Updated to handle image_flags
        input_ids, attention_mask, pixel_values_videos, answer_choices, image_flags = batch
        
        # Reshape pixel_values for model
        if pixel_values_videos.ndim == 5:
            batch_size, num_frames, c, h, w = pixel_values_videos.shape
            pixel_values_videos = pixel_values_videos.reshape(batch_size * num_frames, c, h, w)
            image_flags = image_flags.reshape(batch_size * num_frames)
        
        # Generate answers
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values_videos,
            image_flags=image_flags,  # FIX: Pass image_flags to generate
            max_new_tokens=20,
            do_sample=False,
            pad_token_id=self.processor.tokenizer.pad_token_id if self.processor.tokenizer.pad_token_id is not None else self.processor.tokenizer.eos_token_id
        )
        
        # Decode predictions
        generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        
        # Extract only generated part (after prompt)
        prompts = self.processor.batch_decode(input_ids, skip_special_tokens=True)
        generated_answers = []
        
        for gen_text, prompt in zip(generated_texts, prompts):
            if gen_text.startswith(prompt):
                answer = gen_text[len(prompt):].strip()
            else:
                answer = gen_text.strip()
            generated_answers.append(answer)
        
        # Extract predicted choices (A, B, C, D)
        import re
        predicted_choices = []
        for answer in generated_answers:
            # Take the first standalone capital letter A-D; a case-insensitive character scan
            # would wrongly pick up the "a" in words such as "answer" or "and"
            match = re.search(r"\b([A-D])\b", answer)
            predicted_choices.append(match.group(1) if match else "")
        
        # Compute accuracy
        correct = sum(1 for pred, gt in zip(predicted_choices, answer_choices) if pred == gt)
        total = len(answer_choices)
        accuracy = correct / total if total > 0 else 0.0
        
        # Store for epoch metrics
        self.validation_outputs.append({
            'correct': correct,
            'total': total,
            'accuracy': accuracy
        })
        
        self.log("val_accuracy", accuracy, prog_bar=True, on_step=False, on_epoch=True)
        
        return accuracy
    
    def on_validation_epoch_end(self):
        """
        Compute and log epoch metrics
        """
        if len(self.validation_outputs) > 0:
            total_correct = sum(x['correct'] for x in self.validation_outputs)
            total_samples = sum(x['total'] for x in self.validation_outputs)
            epoch_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
            
            self.log("val_epoch_accuracy", epoch_accuracy, prog_bar=True, logger=True)
            
            print(f"\n{'='*50}")
            print(f"Validation Epoch Results:")
            print(f"Accuracy: {epoch_accuracy:.4f} ({total_correct}/{total_samples})")
            print(f"{'='*50}")
            
            # Clear for next epoch
            self.validation_outputs.clear()

    def configure_optimizers(self):
        """
        Configure optimizer and scheduler
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0.01
        )
        
        # Linear warmup scheduler
        def lr_lambda(current_step):
            if current_step < self.warmup_steps:
                return float(current_step) / float(max(1, self.warmup_steps))
            return 1.0
        
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def train_dataloader(self):
        return DataLoader(
            train_dataset, 
            collate_fn=simple_train_collate_fn,
            batch_size=self.batch_size, 
            shuffle=True, 
            num_workers=2
        )

    def val_dataloader(self):
        return DataLoader(
            eval_dataset, 
            collate_fn=simple_eval_collate_fn,
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=2
        )

In [None]:
# Simple Lightning module
import lightning as L
import torch

class SimpleInternVLPLModule(L.LightningModule):
    """Simple Lightning module for InternVL fine-tuning"""
    
    def __init__(self, model, learning_rate=1e-4):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        
    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, pixel_values, labels, image_flags = batch
        
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_flags=image_flags,
            labels=labels
        )
        
        loss = outputs.loss
        self.log('train_loss', loss, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, pixel_values, labels, image_flags = batch
        
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_flags=image_flags,
            labels=labels
        )
        
        loss = outputs.loss
        self.log('val_loss', loss, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

# Create model module
simple_model_module = SimpleInternVLPLModule(model, learning_rate=1e-4)
print("Simple Lightning module created")

In [None]:
# Lightning module adapted from Yurii's notebook
import lightning as L
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
import torch.nn.functional as F
import torch

class InternVLModelPLModule(L.LightningModule):
    def __init__(self, config, tokenizer, model):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.model = model
        self.model.train()
        
        # Training parameters
        self.batch_size = config.get('batch_size', 1)
        self.learning_rate = config.get('learning_rate', 1e-4)
        self.warmup_steps = config.get('warmup_steps', 50)
        self.max_epochs = config.get('max_epochs', 10)
        
        # Save hyperparameters
        self.save_hyperparameters(ignore=['model', 'tokenizer'])
        
    def training_step(self, batch, batch_idx):
        self.model.train()
        # Extract inputs
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        image_flags = batch['image_flags']
        
        # Forward pass
        # Direct forward pass without PEFT wrapper issues
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            outputs = self.model.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                labels=labels,
                image_flags=image_flags,
                return_dict=True
            )
        
        loss = outputs.loss
        
        # Log training loss
        self.log('train_loss', loss, prog_bar=True, logger=True, sync_dist=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx, dataset_idx=0):
        self.model.eval()
        # For validation, we'll compute both loss and generation accuracy
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        image_flags = batch['image_flags']
        answer_choices = batch.get('answer_choices', [])

        # Compute validation loss
        val_loss = None
        if labels is not None:
            with torch.no_grad():
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                    outputs = self.model.base_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        pixel_values=pixel_values,
                        labels=labels,
                        image_flags=image_flags,
                        return_dict=True
                    )
                val_loss = outputs.loss
                self.log('val_loss', val_loss, prog_bar=True, logger=True, sync_dist=True)

        # Build generation prompts: keep the text up to and including the assistant marker, drop the answer.
        # Assumption: prompts follow InternVL's ChatML-style template ('<|im_start|>assistant\n');
        # '[/INST]' is kept as a fallback for Llama/Mistral-style templates.
        prompts = []
        for input_id in input_ids:
            decoded_text = self.tokenizer.decode(input_id, skip_special_tokens=False)
            for marker in ('<|im_start|>assistant\n', '[/INST]'):
                if marker in decoded_text:
                    decoded_text = decoded_text.split(marker)[0] + marker
                    break
            prompts.append(decoded_text)

        # Left-pad to the longest prompt in the batch: padding to model_max_length wastes memory,
        # and right padding would put pad tokens between the prompt and the generated tokens.
        original_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = 'left'
        encoded = self.tokenizer(prompts, return_tensors='pt', padding=True, add_special_tokens=False)
        self.tokenizer.padding_side = original_padding_side
        response_ids = encoded.input_ids.to(self.model.device)
        response_attention_mask = encoded.attention_mask.to(self.model.device)

        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                generated_ids = self.model.base_model.generate(
                    input_ids=response_ids,
                    attention_mask=response_attention_mask,  # ignore left padding
                    pixel_values=pixel_values,
                    image_flags=image_flags,
                    max_new_tokens=10,
                    do_sample=False,  # greedy decoding, so no temperature is needed
                    pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

        print(f"\n=== Epoch {self.current_epoch}, Batch {batch_idx} ===")
        for i in range(min(len(generated_ids), 2)):  # First 2 samples
            input_text = self.tokenizer.decode(response_ids[i], skip_special_tokens=True).strip()
            generated_text = self.tokenizer.decode(generated_ids[i], skip_special_tokens=True).strip()
            expected_answer = answer_choices[i] if i < len(answer_choices) else ''
            print(f"Sample {i}:")
            print(f"  Expected: '{expected_answer}'")
            print(f"  Input: '{input_text}'")
            print(f"  Generated: '{generated_text}'")
            print(f"  Match: {generated_text.startswith(expected_answer) if expected_answer else 'N/A'}")
        
        # Calculate accuracy
        accuracy = self._calculate_accuracy(generated_ids, answer_choices, response_ids)
        self.log('val_accuracy', accuracy, prog_bar=True, logger=True, sync_dist=True)
        
        return {'val_loss': val_loss, 'val_accuracy': accuracy}
    
    def _calculate_accuracy(self, generated_ids, answer_choices, response_ids):
        """Calculate accuracy by comparing generated text with answer choices"""
        if not answer_choices:
            print('############# There is no answers #########')
            return 0.0
        
        correct = 0
        total = len(generated_ids)
        
        for i in range(total):
            generated_tokens = generated_ids[i]
            # Decode generated text
            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
            
            # Get expected answer (first character: A, B, C, or D)
            expected_answer = answer_choices[i] if i < len(answer_choices) else ''
            
            # Check if generated text starts with the expected answer
            if expected_answer and generated_text.startswith(expected_answer):
                correct += 1
        
        return correct / total if total > 0 else 0.0
    
    def configure_optimizers(self):
        """Configure optimizer and learning rate scheduler"""
        # Separate parameters for different weight decay
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() 
                          if p.requires_grad and not any(nd in n for nd in no_decay)],
                "weight_decay": self.config.get('weight_decay', 0.01),
            },
            {
                "params": [p for n, p in self.model.named_parameters() 
                          if p.requires_grad and any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8
        )
        
        # Calculate total training steps
        train_loader = self.train_dataloader()
        steps_per_epoch = len(train_loader)
        total_steps = steps_per_epoch * self.max_epochs
        
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=total_steps
        )
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }
    
    def train_dataloader(self):
        return DataLoader(
            train_dataset, 
            collate_fn=train_collate_fn,
            batch_size=self.batch_size, 
            shuffle=True, 
            num_workers=4
        )
    
    def val_dataloader(self):
        return DataLoader(
            eval_dataset, 
            collate_fn=eval_collate_fn,
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=4
        )

# Create model module with Yurii's configuration
config = {
    "learning_rate": 1e-4,
    "batch_size": 1,
    "max_epochs": 1,
    "warmup_steps": 50,
    "weight_decay": 0.01,
}

model_module = InternVLModelPLModule(config, tokenizer, model)

print("Lightning module adapted from Yurii's notebook")

## Define callbacks
Optionally, Lightning lets you define so-called [callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), which are arbitrary pieces of code executed at given points during training.

Here we use Lightning's `EarlyStopping` callback, which automatically stops training once the monitored evaluation metric (validation accuracy in our case) has not improved for 3 validation epochs in a row.

In [None]:
from huggingface_hub import HfApi
from lightning.pytorch.callbacks import Callback, EarlyStopping
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="InternVL3.5_finetuning_qlora")

api = HfApi()

class PushToHubCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
        pl_module.model.push_to_hub(REPO_ID,
                                    commit_message=f"Training in progress, epoch {trainer.current_epoch}")

    def on_train_end(self, trainer, pl_module):
        print("Pushing model to the hub after training")
        # InternVLModelPLModule stores a tokenizer (not a processor), so push it alongside the adapters
        pl_module.tokenizer.push_to_hub(REPO_ID,
                                        commit_message="Training done")
        pl_module.model.push_to_hub(REPO_ID,
                                    commit_message="Training done")

# model_module logs its epoch-level accuracy as "val_accuracy"
early_stop_callback = EarlyStopping(monitor="val_accuracy",
                                    patience=3, verbose=False, mode="max")

## Train!
The `Trainer` class supports many more flags; see the [docs](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#lightning.pytorch.trainer.trainer.Trainer).

In [None]:
import os
import torch

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_float32_matmul_precision('high')

# Test the complete pipeline before training
print("Testing complete pipeline...")

# Test data loading
print("\n1. Testing data loading...")
try:
    sample = train_dataset[0]
    print(f"✅ Train sample loaded successfully")
    print(f"   Keys: {list(sample.keys())}")
    print(f"   Input IDs shape: {sample['input_ids'].shape}")
    print(f"   Pixel values shape: {sample['pixel_values'].shape}")
    print(f"   Image flags shape: {sample['image_flags'].shape}")
except Exception as e:
    print(f"❌ Error loading train sample: {e}")

# Test collate functions  
print("\n2. Testing collate functions...")
try:
    # Create small batches - get multiple samples to test batching
    train_samples = [train_dataset[0], train_dataset[1]]
    train_batch = simple_train_collate_fn(train_samples)
    print(f"✅ Train collate function works")
    print(f"   Batch shapes: input_ids={train_batch[0].shape}, attention_mask={train_batch[1].shape}")
    print(f"   Pixel values shape: {train_batch[2].shape}, labels shape: {train_batch[3].shape}")
    print(f"   Image flags shape: {train_batch[4].shape}")
    
    eval_samples = [eval_dataset[0], eval_dataset[1]]
    eval_batch = simple_eval_collate_fn(eval_samples)
    print(f"✅ Eval collate function works")
    print(f"   Answer choices: {len(eval_batch[3])} answers")
    print(f"   Image flags shape: {eval_batch[4].shape}")
    
except Exception as e:
    print(f"❌ Error in collate functions: {e}")
    import traceback
    traceback.print_exc()

# Test model forward pass
print("\n3. Testing model forward pass...")
try:
    # Move batch to device
    device = next(model.parameters()).device
    input_ids, attention_mask, pixel_values, labels, image_flags = train_batch
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    labels = labels.to(device)
    image_flags = image_flags.to(device)
    
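    # Reshape pixel values and image_flags for the model: [batch*frames, C, H, W]
    if pixel_values.ndim == 5:
        batch_size, num_frames, c, h, w = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * num_frames, c, h, w)
        image_flags = image_flags.reshape(batch_size * num_frames)
    # Move to the device in all cases, not only when a reshape was needed
    pixel_values = pixel_values.to(device)
    image_flags = image_flags.to(device)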
    
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            labels=labels,
            image_flags=image_flags
        )
    
    print(f"✅ Model forward pass works")
    print(f"   Loss: {outputs.loss.item():.4f}")
    print(f"   Logits shape: {outputs.logits.shape}")
    
except Exception as e:
    print(f"❌ Error in model forward pass: {e}")
    import traceback
    traceback.print_exc()

# Test model generation
print("\n4. Testing model generation...")
try:
    with torch.no_grad():
        # Use smaller batch for generation test
        gen_input_ids = input_ids[:1]  # Use just first sample
        gen_attention_mask = attention_mask[:1]
        
        # Calculate frames for first video
        first_video_frames = train_samples[0]['pixel_values'].shape[0]
        gen_pixel_values = pixel_values[:first_video_frames]  # First video only
        gen_image_flags = image_flags[:first_video_frames]
        
        generated_ids = model.generate(
            input_ids=gen_input_ids,
            attention_mask=gen_attention_mask,
            pixel_values=gen_pixel_values,
            image_flags=gen_image_flags,
            max_new_tokens=10,
            do_sample=False
        )
    
    generated_text = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"✅ Model generation works")
    print(f"   Generated text length: {len(generated_text)} chars")
    print(f"   Generated: {generated_text[:100]}...")
    
except Exception as e:
    print(f"❌ Error in model generation: {e}")
    import traceback
    traceback.print_exc()

# Test Lightning module
print("\n5. Testing Lightning module...")
try:
    # Test training step
    loss = model_module.training_step(train_batch, 0)
    print(f"✅ Lightning training step works")
    print(f"   Training loss: {loss.item():.4f}")
    
    # Test validation step
    val_accuracy = model_module.validation_step(eval_batch, 0)
    print(f"✅ Lightning validation step works")
    print(f"   Validation accuracy: {val_accuracy:.4f}")
    
except Exception as e:
    print(f"❌ Error in Lightning module: {e}")
    import traceback
    traceback.print_exc()

print(f"\n{'='*60}")
print("PIPELINE TEST COMPLETE")
print(f"{'='*60}")
print("If all tests passed ✅, you can proceed with training!")
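
The collate functions exercised above are defined earlier in the notebook and are not shown here. As a rough illustration of what a training collate step for a causal LM commonly does, the sketch below right-pads token sequences to the longest one in the batch, builds the attention mask, and sets padded (and, optionally, prompt) positions in the labels to -100, the default `ignore_index` of PyTorch's cross-entropy loss. The helper `pad_and_mask` and its `prompt_lengths` argument are hypothetical, and NumPy stands in for PyTorch to keep the example self-contained.

```python
import numpy as np

IGNORE_INDEX = -100  # label value skipped by torch.nn.CrossEntropyLoss (default ignore_index)


def pad_and_mask(sequences, pad_token_id, prompt_lengths=None):
    """Right-pad token id lists and build attention masks and labels.

    sequences: list of lists of token ids (prompt + answer).
    prompt_lengths: optional number of prompt tokens per sequence; those
        positions are excluded from the loss (hypothetical convention).
    """
    max_len = max(len(seq) for seq in sequences)
    batch = len(sequences)
    input_ids = np.full((batch, max_len), pad_token_id, dtype=np.int64)
    attention_mask = np.zeros((batch, max_len), dtype=np.int64)
    labels = np.full((batch, max_len), IGNORE_INDEX, dtype=np.int64)

    for i, seq in enumerate(sequences):
        n = len(seq)
        input_ids[i, :n] = seq
        attention_mask[i, :n] = 1
        labels[i, :n] = seq
        if prompt_lengths is not None:
            # Only the answer tokens contribute to the loss
            labels[i, :prompt_lengths[i]] = IGNORE_INDEX

    return input_ids, attention_mask, labels
```

For example, `pad_and_mask([[5, 6, 7], [8, 9]], pad_token_id=0, prompt_lengths=[2, 1])` pads the second sequence with one `0` and yields labels `[[-100, -100, 7], [-100, 9, -100]]`, so the loss only sees the answer tokens.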

In [None]:
# Test the pipeline
print("Testing pipeline...")
try:
    # Test dataset
    sample = train_dataset[0]
    print(f"Sample keys: {sample.keys()}")
    print(f"Input IDs shape: {sample['input_ids'].shape}")
    print(f"Pixel values shape: {sample['pixel_values'].shape}")
    print(f"Image flags shape: {sample['image_flags'].shape}")
    
    # Test collate function
    train_samples = [train_dataset[0], train_dataset[1]]
    batch = simple_train_collate_fn(train_samples)
    
    input_ids, attention_mask, pixel_values, labels, image_flags = batch
    print(f"\nBatch shapes:")
    print(f"  input_ids: {input_ids.shape}")
    print(f"  attention_mask: {attention_mask.shape}")
    print(f"  pixel_values: {pixel_values.shape}")
    print(f"  labels: {labels.shape}")
    print(f"  image_flags: {image_flags.shape}")
    
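    # Test forward pass (tensors moved to the model's device first)
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
            pixel_values=pixel_values.to(device),
            image_flags=image_flags.to(device),
            labels=labels.to(device)
        )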
    
    print(f"\n✅ Forward pass successful!")
    print(f"   Loss: {outputs.loss:.4f}")
    
except Exception as e:
    print(f"❌ Error: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# Test the pipeline with new template structure
print("Testing pipeline with template structure...")
try:
    # Test dataset
    sample = train_dataset[0]
    print(f"Sample keys: {list(sample.keys())}")
    print(f"Input IDs shape: {sample['input_ids'].shape}")
    print(f"Pixel values shape: {sample['pixel_values'].shape}")
    print(f"Image flags shape: {sample['image_flags'].shape}")
    print(f"Answer choice: {sample['answer_choice']}")
    
    # Test collate function
    train_samples = [train_dataset[0]]
    batch = train_collate_fn(train_samples)
    
    print(f"\nBatch format: dict")
    print(f"Batch keys: {list(batch.keys())}")
    print(f"  input_ids: {batch['input_ids'].shape}")
    print(f"  attention_mask: {batch['attention_mask'].shape}")
    print(f"  pixel_values: {batch['pixel_values'].shape}")
    print(f"  image_flags: {batch['image_flags'].shape}")
    print(f"  labels: {batch['labels'].shape}")
    print(f"  answer_choices: {batch['answer_choices']}")
    
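    # Move tensors to the model's device; a LightningModule does not do this outside a Trainer
    device = next(model.parameters()).device
    batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
    
    # Test a forward pass on the underlying model
    model.eval()
    with torch.no_grad():
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            pixel_values=batch['pixel_values'],
            image_flags=batch['image_flags'],
            labels=batch['labels']
        )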
    
    print(f"\n✅ Forward pass successful!")
    print(f"   Loss: {outputs.loss:.4f}")
    
    # Test Lightning module training step
    model_module.train()
    loss = model_module.training_step(batch, 0)
    print(f"✅ Lightning training step successful!")
    print(f"   Training loss: {loss:.4f}")
    
    print("\n✅ All tests passed! Template structure working correctly.")
    
except Exception as e:
    print(f"❌ Error: {e}")
    import traceback
    traceback.print_exc()
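
The test cells above mix two batch formats: the `simple_*_collate_fn` helpers return tuples, while `train_collate_fn` returns a dict. A small adapter such as the hypothetical `batch_to_dict` below can normalize both. It assumes the tuple order used in these cells (`input_ids, attention_mask, pixel_values, labels_or_answers, image_flags`), where position 3 holds labels for training batches and answer choices for evaluation batches.

```python
from typing import Any, Dict, Sequence, Union

# Tuple order assumed from the test cells: position 3 is labels (train) or answers (eval)
TRAIN_KEYS = ("input_ids", "attention_mask", "pixel_values", "labels", "image_flags")
EVAL_KEYS = ("input_ids", "attention_mask", "pixel_values", "answer_choices", "image_flags")


def batch_to_dict(batch: Union[Dict[str, Any], Sequence[Any]], is_eval: bool = False) -> Dict[str, Any]:
    """Return the batch as a dict, converting 5-tuples using the assumed key order."""
    if isinstance(batch, dict):
        return dict(batch)
    keys = EVAL_KEYS if is_eval else TRAIN_KEYS
    if len(batch) != len(keys):
        raise ValueError(f"Expected {len(keys)} elements, got {len(batch)}")
    return dict(zip(keys, batch))
```

With this in place, code such as `training_step` can call `batch_to_dict(batch)` once and then use named keys regardless of which collate function produced the batch.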

In [None]:
# DataLoaders are defined inside the Lightning module, which provides
# train_dataloader() and val_dataloader(); DataLoader is imported for the quick checks below
from torch.utils.data import DataLoader

print("DataLoaders are handled internally by the Lightning module")
print(f"Training configuration: {config}")

# Test the adapted module
print("\nTesting adapted Lightning module...")
try:
    # Test training step
    train_batch = next(iter(DataLoader(train_dataset, collate_fn=train_collate_fn, batch_size=1)))
    loss = model_module.training_step(train_batch, 0)
    print(f"✅ Training step successful: {loss:.4f}")
    
    # Test validation step
    eval_batch = next(iter(DataLoader(eval_dataset, collate_fn=eval_collate_fn, batch_size=1)))
    val_results = model_module.validation_step(eval_batch, 0)
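    # Format accuracy only when it is numeric; ':.3f' on the 'N/A' fallback would raise
    val_acc = val_results.get('val_accuracy')
    val_acc_str = f"{val_acc:.3f}" if val_acc is not None else "N/A"
    print(f"✅ Validation step successful: loss={val_results.get('val_loss', 'N/A')}, accuracy={val_acc_str}")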
    
    print("✅ Lightning module adapted successfully from Yurii's notebook!")
    
except Exception as e:
    print(f"❌ Error: {e}")
    import traceback
    traceback.print_exc()

# Training with adapted Lightning module (Yurii's version)
print("Starting training with Yurii's Lightning module...")

try:
    trainer = L.Trainer(
        max_epochs=config["max_epochs"],
        accelerator="auto",
        precision="16-mixed",
        enable_checkpointing=False,
        logger=False,
        gradient_clip_val=1.0,
        accumulate_grad_batches=1
    )
    
    trainer.fit(model_module)
    
    print("✅ Training completed successfully!")
    
    # Test final model
    model_module.eval()
    with torch.no_grad():
        test_batch = next(iter(DataLoader(eval_dataset, collate_fn=eval_collate_fn, batch_size=1)))
        
        # The eval collate function yields prompt-only input_ids (question without answer),
        # so they can go straight to generate() once moved to the model's device
        device = next(model_module.model.parameters()).device
        response_ids = test_batch['input_ids'].to(device)
        pixel_values = test_batch['pixel_values'].to(device)
        image_flags = test_batch['image_flags'].to(device)
        
        # Generate response
        generated_ids = model_module.model.base_model.generate(
            input_ids=response_ids,
            pixel_values=pixel_values,
            image_flags=image_flags,
            max_new_tokens=15,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        )
        
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        expected_answer = test_batch.get('answer_choices', [''])[0]
        
        print(f"\nFinal test:")
        print(f"  Expected: '{expected_answer}'")
        print(f"  Generated: '{generated_text}'")
        print(f"  Match: {generated_text.startswith(expected_answer) if expected_answer else 'N/A'}")
    
except Exception as e:
    print(f"❌ Training error: {e}")
    import traceback
    traceback.print_exc()

In [None]:
from transformers import AutoProcessor, BitsAndBytesConfig, AutoModelForCausalLM
import torch
from peft import PeftModel

# Load the fine-tuned model for inference
# Option 1: Load from HuggingFace Hub (if you pushed it)
# processor_inference = AutoProcessor.from_pretrained(REPO_ID, trust_remote_code=True)
# model_inference = AutoModelForCausalLM.from_pretrained(REPO_ID, trust_remote_code=True, device_map="auto")

# Option 2: Load from local checkpoint
# First load the base model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model_inference = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto"
)

# If using PEFT/LoRA, load the adapter weights
# model_inference = PeftModel.from_pretrained(model_inference, "path/to/checkpoint")

processor_inference = processor  # Use the same processor

print("Model loaded for inference")
print(f"Model device: {model_inference.device}")
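
The `BitsAndBytesConfig` above stores weights as 4-bit NF4 with double quantization. A back-of-the-envelope estimate of weight memory shows why this helps. The sketch assumes the QLoRA paper's overhead figures for quantization constants (0.5 extra bits per parameter with block size 64 and FP32 constants, reduced to about 0.127 bits with double quantization). It ignores activations, layers left unquantized, LoRA adapters, and optimizer state, and the parameter count in the usage line is a round placeholder, not the model's exact size.

```python
def weight_memory_gb(num_params: float, bits_per_param: float, overhead_bits: float = 0.0) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    total_bits = num_params * (bits_per_param + overhead_bits)
    return total_bits / 8 / 1e9


# Quantization-constant overheads from the QLoRA paper (block size 64, FP32 constants)
NF4_OVERHEAD_BITS = 32 / 64                                 # 0.5 bits/param
NF4_DOUBLE_QUANT_OVERHEAD_BITS = 8 / 64 + 32 / (64 * 256)   # ~0.127 bits/param
```

For a placeholder 1e9 parameters, `weight_memory_gb(1e9, 16)` gives 2.0 GB in bf16, while `weight_memory_gb(1e9, 4, NF4_DOUBLE_QUANT_OVERHEAD_BITS)` gives about 0.52 GB, roughly a 74% reduction for the quantized weights.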

Look at one example from the validation set and plot 8 of its frames to see what happens in the video.
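
One way to choose the 8 frames is to sample evenly spaced indices across the whole clip. The sketch below assumes `clip` is a NumPy array of shape `(num_frames, height, width, 3)` with `uint8` values, as the later `Image.fromarray(frame)` call suggests. `sample_frame_indices` and `plot_frames` are hypothetical helpers, and matplotlib is imported only inside the plotting function.

```python
import numpy as np


def sample_frame_indices(num_frames: int, num_samples: int = 8) -> np.ndarray:
    """Evenly spaced frame indices covering the whole clip (first and last included)."""
    if num_frames <= 0:
        raise ValueError("clip has no frames")
    count = min(num_samples, num_frames)
    return np.linspace(0, num_frames - 1, num=count).round().astype(int)


def plot_frames(clip: np.ndarray, num_samples: int = 8, cols: int = 4):
    """Show evenly sampled frames from a (T, H, W, 3) uint8 clip in a grid."""
    import matplotlib.pyplot as plt  # imported lazily so index sampling works without matplotlib

    indices = sample_frame_indices(len(clip), num_samples)
    rows = int(np.ceil(len(indices) / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")  # hide axes, including any unused grid cells
    for ax, idx in zip(axes.flat, indices):
        ax.imshow(clip[idx])
        ax.set_title(f"frame {idx}")
    fig.tight_layout()
    return fig
```

In a notebook, `plot_frames(clip)` displays a 2 × 4 grid; sampling evenly across the clip keeps the start and end of the action visible, which matters for a task about action order.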

In [None]:
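# Set up Trainer with custom DataLoaders
# Import from the `lightning` package installed above so the Trainer and the
# LightningModule come from the same namespace
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger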

print("🚀 Setting up Lightning Trainer...")

# Set up Wandb logger (if API key is provided)
if WANDB_API_KEY:
    wandb_logger = WandbLogger(
        project="internvl-mvbench-finetune",
        name=f"internvl3.5-1b-qlora-mvbench",
        log_model=True
    )
    print("✅ Wandb logger initialized")
else:
    wandb_logger = None
    print("⚠️ Wandb API key not provided, using console logging only")

# Set up callbacks
checkpoint_callback = ModelCheckpoint(
    dirpath="./checkpoints",
    filename="internvl-mvbench-{epoch:02d}-{val_accuracy:.3f}",
    save_top_k=2,
    monitor="val_accuracy",
    mode="max",
    save_last=True,
    every_n_epochs=1
)

lr_monitor = LearningRateMonitor(logging_interval='step')

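# Optional callback that pushes the model to the Hugging Face Hub (REPO_ID), not to Wandb.
# It is not registered in `callbacks` below; add it there to enable uploads.
class WandbModelUpload(L.Callback):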
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Pushing model to the hub after epoch {trainer.current_epoch}")
        pl_module.model.push_to_hub(REPO_ID,
                                    commit_message=f"Training in progress, epoch {trainer.current_epoch}")

    def on_train_end(self, trainer, pl_module):
        print(f"Pushing model to the hub after training")
        pl_module.processor.push_to_hub(REPO_ID,
                                    commit_message=f"Training done")
        pl_module.model.push_to_hub(REPO_ID,
                                    commit_message=f"Training done")

# Create trainer
trainer = L.Trainer(
    max_epochs=config["max_epochs"],
    precision=config["precision"],
    accelerator="auto",
    devices=1,
    accumulate_grad_batches=2,  # Effective batch size = batch_size * 2
    gradient_clip_val=1.0,
    val_check_interval=0.5,  # Validate twice per epoch
    logger=wandb_logger,
    callbacks=[checkpoint_callback, lr_monitor],
    enable_checkpointing=True,
    enable_progress_bar=True,
    log_every_n_steps=1,
    deterministic=True,
    benchmark=False
)

# The dataloaders are passed to trainer.fit() in the training cell; Trainer attributes
# such as train_dataloader are managed by Lightning and are not assigned manually

print("✅ Lightning Trainer configured successfully")
print(f"  Max epochs: {config['max_epochs']}")
print(f"  Precision: {config['precision']}")
print(f"  Effective batch size: {config['batch_size'] * 2} (with accumulation)")
print(f"  Validation check interval: 0.5 epoch")
print(f"  Checkpoint directory: ./checkpoints")
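
With gradient accumulation, the optimizer updates once every `accumulate_grad_batches` batches, so the effective batch size is `batch_size × accumulate_grad_batches × devices`. The sketch below does that arithmetic and counts optimizer steps per epoch, assuming Lightning also steps on a trailing partial accumulation window at the end of an epoch. The sample count in the usage note is a placeholder, not this notebook's actual split.

```python
import math


def effective_batch_size(batch_size: int, accumulate_grad_batches: int, devices: int = 1) -> int:
    """Samples contributing to each optimizer update."""
    return batch_size * accumulate_grad_batches * devices


def optimizer_steps_per_epoch(num_samples: int, batch_size: int, accumulate_grad_batches: int,
                              devices: int = 1, drop_last: bool = False) -> int:
    """Optimizer updates per epoch, counting a trailing partial accumulation window."""
    per_device = math.ceil(num_samples / devices)
    if drop_last:
        batches = per_device // batch_size
    else:
        batches = math.ceil(per_device / batch_size)
    return math.ceil(batches / accumulate_grad_batches)
```

With a placeholder 150 training samples and `batch_size=1`, `optimizer_steps_per_epoch(150, 1, 2)` is 75, and `val_check_interval=0.5` triggers validation after every 75 training batches.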

Next, prepare the video for the model together with the text prompt used during training. Apply the chat template so that the prompt follows the same format the model saw during fine-tuning.

Notice that this is the same processing as in the evaluation data collate function.
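
In InternVL-style video prompts, each frame is typically introduced by a `FrameN: <image>` line, and each `<image>` placeholder is expanded into a run of image-context tokens before tokenization. The sketch below shows that expansion in plain Python. The token strings (`<img>`, `<IMG_CONTEXT>`, `</img>`), the 256 context tokens per 448×448 frame, and the ChatML-style `<|im_start|>`/`<|im_end|>` wrapping follow OpenGVLab's InternVL chat code as far as we know, but treat them as assumptions and check them against the tokenizer and `processor_inference` before relying on this. Any system message is omitted.

```python
# Assumed InternVL special tokens and per-frame token budget (verify against the tokenizer)
IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
NUM_IMAGE_TOKEN = 256  # (448 // 14) ** 2 * 0.5 ** 2 for 448x448 input with 0.5 pixel shuffle


def build_video_prompt(question: str, num_frames: int, patches_per_frame: int = 1) -> str:
    """Build a single-turn, ChatML-style prompt with expanded frame placeholders."""
    # One "FrameN: <image>" line per frame, prepended to the question
    video_prefix = "".join(f"Frame{i + 1}: <image>\n" for i in range(num_frames))
    query = video_prefix + question

    # Replace each placeholder, in order, with the image-context token run
    image_tokens = (IMG_START_TOKEN
                    + IMG_CONTEXT_TOKEN * NUM_IMAGE_TOKEN * patches_per_frame
                    + IMG_END_TOKEN)
    for _ in range(num_frames):
        query = query.replace("<image>", image_tokens, 1)

    # Question only, no answer: end with the assistant header so the model writes the reply
    return f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
```

The prompt stops right after the assistant header, so generation starts where the answer would begin during training; the context-token runs are later filled with the vision encoder's frame embeddings.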

In [None]:
# Prepare the input for inference
from PIL import Image

video_frames = [Image.fromarray(frame) for frame in clip]

# Format the conversation for inference (question only, without answer)
conversation = [sample['conversations'][0]]  # Only the human question
prompt_text = processor_inference.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

print(f"Prompt:\n{prompt_text}\n")

# Process the input
inputs = processor_inference(
    text=[prompt_text],
    videos=[video_frames],
    padding=True,
    return_tensors="pt"
).to(model_inference.device)

print(f"Input IDs shape: {inputs['input_ids'].shape}")
print(f"Pixel values shape: {inputs.get('pixel_values_videos', torch.tensor([])).shape}")

# Test pipeline
try:
    train_samples = [train_dataset[0], train_dataset[1]]
    test_batch = simple_train_collate_fn(train_samples)
    
    input_ids, attention_mask, pixel_values_videos, labels, image_flags = test_batch
    
    print(f"Batch created:")
    print(f"  input_ids: {input_ids.shape}")
    print(f"  pixel_values: {pixel_values_videos.shape} (dtype: {pixel_values_videos.dtype})")
    print(f"  image_flags: {image_flags.shape}")
    
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
            pixel_values=pixel_values_videos.to(device),
            image_flags=image_flags.to(device),
            labels=labels.to(device)
        )
    
    print(f"Forward pass successful: loss {outputs.loss:.4f}")
    
    eval_samples = [eval_dataset[0]]
    eval_batch = simple_eval_collate_fn(eval_samples)
    gen_input_ids, gen_attention_mask, gen_pixel_values, _, gen_image_flags = eval_batch
    
    # Generation needs no gradients
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        generated_ids = model.generate(
            input_ids=gen_input_ids.to(device),
            attention_mask=gen_attention_mask.to(device),
            pixel_values=gen_pixel_values.to(device),
            image_flags=gen_image_flags.to(device),
            max_new_tokens=10,
            do_sample=False
        )
    
    generated_text = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"Generation works: {generated_text[:50]}...")
    print("Pipeline test completed successfully!")
    
except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()

print("🔧 Testing complete pipeline with FIXED PLModule and autocast...")

try:
    # Create test batch using our FIXED collators
    print("Creating test batch...")
    train_samples = [train_dataset[0], train_dataset[1]]
    test_batch = simple_train_collate_fn_fixed(train_samples)
    
    input_ids, attention_mask, pixel_values_videos, labels, image_flags = test_batch
    
    print(f"✅ Batch created successfully:")
    print(f"  input_ids: {input_ids.shape} (device: {input_ids.device})")
    print(f"  attention_mask: {attention_mask.shape} (device: {attention_mask.device})")

In [None]:
import pandas as pd

# Create a summary table with results
results = {
    "Task": ["Action Sequence"],
    "Model": ["InternVL3.5-1B"],
    "Fine-tuning Method": ["QLoRA (4-bit + LoRA)"],
    "Train Dataset Size": [len(train_dataset)],
    "Test Dataset Size": [len(eval_dataset)],
    "Metric": ["Accuracy"],
    "Test Accuracy": ["[TO BE FILLED AFTER TRAINING]"],
    "Training Epochs": [config["max_epochs"]],
    "Learning Rate": [config["lr"]],
    "Batch Size": [config["batch_size"]],
    "LoRA Rank": [16],
    "LoRA Alpha": [32],
}

results_df = pd.DataFrame(results)
print("\n" + "="*80)
print("FINAL EVALUATION RESULTS")
print("="*80)
print(results_df.to_string(index=False))
print("="*80)

# After training, update the test accuracy by evaluating on the full test set, e.g.:
# metrics = trainer.validate(model_module, dataloaders=val_dataloader)  # returns a list of dicts
# results["Test Accuracy"] = [f"{metrics[0]['val_accuracy']:.4f}"]
# results_df = pd.DataFrame(results)  # rebuild the table with the new value
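
The table records a LoRA rank of 16 and an alpha of 32. For a linear layer with input size `d_in` and output size `d_out`, LoRA adds matrices `A` (`r × d_in`) and `B` (`d_out × r`), so it trains `r · (d_in + d_out)` parameters instead of `d_in · d_out`, and scales the update `B @ A` by `alpha / r` (here 32 / 16 = 2, the PEFT default without rsLoRA). The sketch computes both; the 1024 × 1024 layer in the usage note is illustrative, not InternVL3.5-1B's actual dimensions.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in LoRA's A (rank x d_in) and B (d_out x rank) matrices."""
    return rank * (d_in + d_out)


def lora_scaling(alpha: float, rank: int) -> float:
    """Standard LoRA scaling applied to the low-rank update B @ A."""
    return alpha / rank
```

For an illustrative 1024 × 1024 projection, `lora_trainable_params(1024, 1024, 16)` is 32,768 trainable parameters versus 1,048,576 frozen ones (about 3.1%), and `lora_scaling(32, 16)` is 2.0.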


In [None]:
# Start training
print("Starting InternVL3.5-1B fine-tuning...")

try:
    trainer.fit(
        model=model_module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader
    )
    
    print("Training completed successfully!")
    
    # Test trained model on one validation batch (tensors moved to the model's device)
    model_module.eval()
    device = next(model_module.model.parameters()).device
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        test_batch = next(iter(val_dataloader))
        input_ids, attention_mask, pixel_values, answer_choices, image_flags = test_batch
        
        generated_ids = model_module.model.generate(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
            pixel_values=pixel_values.to(device),
            image_flags=image_flags.to(device),
            max_new_tokens=15,
            do_sample=False
        )
        
        generated_text = model_module.processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        print(f"Generated: {generated_text[:100]}...")
        print(f"Expected: {answer_choices[0]}")
    
    print("Fine-tuning completed! Model ready for use.")
    
except KeyboardInterrupt:
    print("Training interrupted by user")
    
except Exception as e:
    print(f"Training failed: {e}")
    import traceback
    traceback.print_exc()
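
The test cells compare generated text with the expected answer using `startswith`, which fails if the model adds leading whitespace or phrases the option text differently. A slightly more tolerant multiple-choice scorer extracts the option letter at the start of each string and compares letters. This sketch assumes options written as `(B) ...`, `B) ...`, `B. ...`, `B: ...` or a bare `B`; the regex and the helpers `extract_choice` and `multiple_choice_accuracy` are hypothetical.

```python
import re
from typing import Iterable, Optional

# Option letter at the start, written as "(B)", "B)", "B.", "B:" or just "B"
_CHOICE_RE = re.compile(r"\s*\(?([A-Z])(?:[).:]|\s*$)")


def extract_choice(text: str) -> Optional[str]:
    """Return the leading option letter, or None if the text does not start with one."""
    match = _CHOICE_RE.match(text)
    return match.group(1) if match else None


def multiple_choice_accuracy(predictions: Iterable[str], answers: Iterable[str]) -> float:
    """Fraction of predictions whose option letter matches the reference letter."""
    pairs = list(zip(predictions, answers))
    if not pairs:
        return 0.0
    correct = 0
    for pred, ans in pairs:
        pred_letter = extract_choice(pred)
        # An unparseable prediction counts as wrong
        if pred_letter is not None and pred_letter == extract_choice(ans):
            correct += 1
    return correct / len(pairs)
```

Requiring punctuation or end-of-string after the letter avoids reading a sentence such as "A man opens the door" as option A. If the generated ids include the prompt tokens as well as the new ones, decode only the tokens after the prompt before scoring.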