# Homework #1: Fine-tuning InternVL3.5-1B on MVBench

## Overview
This notebook demonstrates fine-tuning of the InternVL3.5-1B multimodal model on a video understanding task from the MVBench benchmark.

## What this notebook covers:
1. ✅ Dataset loading and preprocessing (MVBench - Action Sequence task)
2. ✅ Train/test split (80/20)
3. ✅ Model loading with QLoRA (4-bit quantization + LoRA adapters)
4. ✅ Training and evaluation loops using PyTorch Lightning
5. ✅ Accuracy metric for multiple-choice questions
6. ✅ Logging with TensorBoard and checkpointing
7. ✅ Inference pipeline for video question answering

## Requirements:
- GPU with at least 16GB VRAM (A100 recommended, but RTX 4090 works)
- ~1TB storage for MVBench videos
- Python packages: transformers, peft, bitsandbytes, pytorch-lightning, datasets, decord, av

## Before running:
1. Update API keys and tokens in Cell 4 (WANDB_API_KEY, REPO_ID, HF token)
2. Ensure sufficient disk space for video downloads
3. Consider applying the processor patch if you encounter video processing errors (see note below)


## Prerequisites
Before we start, make sure you have the following:

- Access to a GPU (preferably A100 since videos require high sequence lengths).
- Familiarity with Hugging Face’s Transformers library.
- Pre-install necessary packages by running the below.

In [8]:
# !pip install -U -q pip setuptools wheel
# !pip install --upgrade --no-deps -q numpy==2.0.0 huggingface-hub==0.35.3
# !pip install -U -q huggingface-hub>=0.23.0 pyarrow>=21.0.0
# !pip install -U -q transformers==4.45.0 accelerate bitsandbytes peft datasets==4.1.0
# !pip install -q av decord
# !pip install -q lightning
# !pip install -q wandb
# !pip install -q timm deepspeed flash_attn
# !pip install -q nltk
# !pip install -q opencv-python

In [86]:
!pip install -U -q transformers accelerate bitsandbytes peft datasets
!pip install -q av decord
!pip install -q lightning
!pip install -q pyarrow==21.0.0
!pip install -q wandb
!pip install -q timm deepspeed flash_attn
!pip install -U -q tensorboard

## Implementation Choices and Justification

### Task Selection
**Selected Task:** Action Sequence from MVBench
- This task involves understanding temporal sequences of actions in videos
- It is a multiple-choice task, making it suitable for accuracy-based evaluation
- The dataset has 188 valid samples after filtering missing videos

### Fine-tuning Method: QLoRA (Quantized Low-Rank Adaptation)

**Why QLoRA?**

1. **Memory Efficiency**: 
   - 4-bit quantization reduces model memory footprint by ~75%
   - Enables fine-tuning on consumer GPUs (e.g., single A100 or even RTX 4090)
   - Only adapter parameters are trained, not the full model

2. **Performance**:
   - Research shows QLoRA achieves near-identical performance to full fine-tuning
   - NormalFloat 4-bit (NF4) quantization preserves model quality
   - LoRA rank of 16 provides good balance between capacity and efficiency

3. **Practical Considerations**:
   - Faster iteration and experimentation
   - Lower computational costs
   - Easier to deploy (smaller checkpoint files)

**Configuration:**
- Quantization: 4-bit NF4 with double quantization
- LoRA rank: 16
- LoRA alpha: 32
- Target modules: All linear layers in the language model (excluding vision encoder and projector)
- Compute dtype: bfloat16

### Metric: Accuracy
- Multiple-choice questions have 4 options (A, B, C, D)
- Accuracy is the standard metric for such tasks
- Easy to interpret and compare with baselines

### Training Configuration
- Optimizer: AdamW with weight decay (0.01)
- Learning rate: 1e-4 with linear warmup (50 steps)
- Batch size: 1 with gradient accumulation (8 steps)
- Effective batch size: 8
- Mixed precision: bfloat16
- Max epochs: 2 (to prevent overfitting on small dataset)
- Early stopping: patience of 3 epochs based on validation accuracy

### Data Processing
- Videos: 8 frames uniformly sampled from each video
- Resolution: 448x448 (dynamic tiling)
- Train/Test split: 80/20 (stratified)
- Data augmentation: None (to maintain temporal consistency)


## Fine-tune InternVL 1B. on MMBench dataset

In this notebook, you need to fine-tune the [InternVL](https://huggingface.co/OpenGVLab/InternVL3_5-1B) model on [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench) dataset which is comprised of various video-related tasks. Note that MMBench is quite small and is not made for tuning. So firstly you need to split it into training/testing parts.

The goal for the model in this notebook is to answer given multiple choice questions based on the video. The questions can be realetd to temporal aspects of the video, pose prediction and so on.
Sources:

* InternVL [documentation](https://internvl.readthedocs.io/en/latest/internvl2.0/introduction.html)
* InternVL [checkpoint on the hub](https://huggingface.co/OpenGVLab/InternVL2-1B)

## Define variables

We'll first set some variables useful througout this notebook and doo all the necessary imports.

In [1]:
import os
import av
import re
import gc
import sys
import json
import random
import bisect
import shutil
import traceback
import numpy as np
from nltk import edit_distance
from pathlib import Path
from copy import deepcopy
from typing import Dict, Any, List, Union, Tuple

from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer,
                          HfArgumentParser, Trainer, TrainingArguments,
                          set_seed, AutoProcessor)
from transformers import BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from huggingface_hub import snapshot_download, hf_hub_download
from datasets import load_dataset, concatenate_datasets
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError

import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping, Callback

from getpass import getpass
from huggingface_hub import login

In [3]:
MAX_LENGTH = 4096
MODEL_ID = "OpenGVLab/InternVL3_5-1B"
REPO_ID = "Ivan1008/internvl3.5-finetune"

# os.environ["WANDB_API_KEY"] = getpass("Input wandb api key: ")
# os.environ["WANDB_MODE"] = "online"

# access_token = getpass("Input hugging face access token: ")
# login(access_token)

USE_LORA = False
USE_QLORA = True

In [9]:
! git clone https://github.com/OpenGVLab/InternVL.git

Cloning into 'InternVL'...
remote: Enumerating objects: 3748, done.[K
remote: Counting objects: 100% (60/60), done.[K
remote: Compressing objects: 100% (35/35), done.[K
remote: Total 3748 (delta 32), reused 25 (delta 25), pack-reused 3688 (from 3)[K
Receiving objects: 100% (3748/3748), 39.63 MiB | 31.05 MiB/s, done.
Resolving deltas: 100% (2288/2288), done.


In [4]:
! ls InternVL/internvl_chat

README.md	zero_stage1_config.json
eval		zero_stage2_config.json
evaluate.sh	zero_stage3_config.json
examples	zero_stage3_config_100b.json
internvl	zero_stage3_config_100b_1e7_offload.json
pyproject.toml	zero_stage3_config_100b_1e8.json
shell		zero_stage3_config_34b.json
tools		zero_stage3_config_70b.json


In [5]:
sys.path.append('./InternVL/internvl_chat')

In [7]:
from internvl.dist_utils import init_dist
# from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from internvl.model.internvl_chat import (InternVisionConfig,
                                          InternVisionModel,
                                          InternVLChatConfig,
                                          InternVLChatModel)
# from internvl.patch import (concat_pad_data_collator,
#                             replace_llama_rmsnorm_with_fused_rmsnorm,
#                             replace_train_sampler)
from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
                                      IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
                                      IMG_START_TOKEN, QUAD_END_TOKEN,
                                      QUAD_START_TOKEN, REF_END_TOKEN,
                                      IMAGENET_MEAN, IMAGENET_STD,
                                      REF_START_TOKEN)
from internvl.train.dataset import (ConcatDataset, read_frames_decord,
                                    WeightedConcatDataset, build_transform,
                                    dynamic_preprocess, preprocess,
                                    preprocess_internlm, preprocess_mpt, preprocess_internvl2_5,
                                    preprocess_phi3,)

# MVBench benchmark

[MVBench on HF Datasets](https://huggingface.co/datasets/OpenGVLab/MVBench)

<!-- ![MVbench1.png](https://huggingface.co/datasets/OpenGVLab/MVBench/resolve/main/assert/generation.png) -->

It consists of the 20 temporal task examples as follows.

![MVbench-structure.png](https://huggingface.co/datasets/OpenGVLab/MVBench/resolve/main/assert/task_example.png)


Here we have a nice viewer for each task:

[Dataset viewer](https://huggingface.co/datasets/OpenGVLab/MVBench/viewer/action_sequence)



We will start by downloading and processing the dataset. Even though MVBench is a small dataset, it still requires **around 1000B to store the videos**, so make sure you have enough free space.

First, we will use this mapping to get the datasets because each one is a separate subset in its own folder. Then we need a few helper functions to download videos and process them to fit the model's format (8 frames each video).

In [8]:
data_list = {
    "Action Sequence": ("action_sequence.json", "star/Charades_v1_480/", "video", True), # has start & end
    "Action Prediction": ("action_prediction.json", "star/Charades_v1_480/", "video", True), # has start & end
    "Action Antonym": ("action_antonym.json", "ssv2_video/", "video", False),
    "Fine-grained Action": ("fine_grained_action.json", "Moments_in_Time_Raw/videos/", "video", False),
    "Unexpected Action": ("unexpected_action.json", "FunQA_test/test/", "video", False),
    "Object Existence": ("object_existence.json", "clevrer/video_validation/", "video", False),
    "Object Interaction": ("object_interaction.json", "star/Charades_v1_480/", "video", True), # has start & end
    "Object Shuffle": ("object_shuffle.json", "perception/videos/", "video", False),
    "Moving Direction": ("moving_direction.json", "clevrer/video_validation/", "video", False),
    "Action Localization": ("action_localization.json", "sta/sta_video/", "video", True),  # has start & end
    "Scene Transition": ("scene_transition.json", "scene_qa/video/", "video", False),
    "Action Count": ("action_count.json", "perception/videos/", "video", False),
    "Moving Count": ("moving_count.json", "clevrer/video_validation/", "video", False),
    "Moving Attribute": ("moving_attribute.json", "clevrer/video_validation/", "video", False),
    "State Change": ("state_change.json", "perception/videos/", "video", False),
    "Fine-grained Pose": ("fine_grained_pose.json", "nturgbd/", "video", False),
    "Character Order": ("character_order.json", "perception/videos/", "video", False),
    "Egocentric Navigation": ("egocentric_navigation.json", "vlnqa/", "video", False),
    "Episodic Reasoning": ("episodic_reasoning.json", "tvqa/frames_fps3_hq/", "frame", True),  # has start & end, read frame
    "Counterfactual Inference": ("counterfactual_inference.json", "clevrer/video_validation/", "video", False),
}

data_dir = "dataset"
if not os.path.exists(data_dir):
    os.mkdir("dataset")

In [9]:
def read_video_pyav(video_path, start, end, n_frames=8):
    """
    Reads a video for given start-end timestamps interval
    and uniformly samples 8 frames of it
    """
    container = av.open(video_path)
    video = container.streams.get(0)[0]

    av_timestamps = [
        int(packet.pts * video.time_base) for packet in container.demux(video) if packet.pts is not None
    ]

    av_timestamps.sort()
    start_id = bisect.bisect_left(av_timestamps, start)
    end_id = bisect.bisect_left(av_timestamps, end)

    # in case it is a very short video, lets take a longer duration and sample
    if end_id  - start_id < 10:
        end_id += 10
        start_id -= 10

    end_id = min(len(av_timestamps) - 1, end_id)
    start_id = max(1, start_id)

    # We sample n_frames frames for tuning following the original paper
    # But we can increase the number of frames for longer videos and check out if it helps performance
    # Change the below "n_frames" to any number of frames you want, and note that more frames -> more computational resources needed
    indices = np.linspace(start_id, end_id, n_frames).astype(int)

    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_id:
            break
        if i >= start_id and i in indices:
            frames.append(frame)
    assert len(frames) == n_frames, f"Got {len(frames)} frames but should be {n_frames}. Check the indices: {indices};, start_id: {start_id}, end_id: {end_id}. Len of video is {len(av_timestamps)} frames."
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])

In [10]:
def collate_read_video(example, path):
    # Some datasets have a start-end interval, so we try to get it if exists.
    # Otherwise just set a very large end timestamp
    video_path = f'{path}/{example["video"]}'
    source = av.open(video_path)
    video = source.streams.get(0)[0]
    duration_seconds = source.duration / av.time_base
    example['end'] = min(example['end'], duration_seconds)
    
    return example

In [11]:
# Download the videos from datasets repo and unzip.
# Make sure you have enough free space before downloading and unzipping

# videos = snapshot_download(repo_id="OpenGVLab/MVBench", allow_patterns="*", repo_type="dataset")
# for zip_file in os.listdir(f"{videos}/video"):
#     if zip_file.endswith(".zip"):
#         shutil.unpack_archive(f"{videos}/video/{zip_file}", f"{videos}/videos_unzipped/")

In [11]:
# Or download only selected task with appropriate files
# https://huggingface.co/docs/huggingface_hub/v0.24.7/en/package_reference/file_download#huggingface_hub.hf_hub_download

TASK_NAME = "Action Sequence"
annotation_fn, video_dir, video_type, has_clip = data_list.get(TASK_NAME)
print(f"Task: {TASK_NAME}")
print(f"Annotation file: {annotation_fn}")
print(f"Videos are stored in: {video_dir}")
print(f"Videos are represented as: {video_type}")
print(f"Videos are have start/end timestamps: {has_clip}")

annotation_fn_local = hf_hub_download(repo_id="OpenGVLab/MVBench",
                                    filename='json/' + annotation_fn,
                                    repo_type="dataset",
                                    local_dir=data_dir)
videos_zip = hf_hub_download(repo_id="OpenGVLab/MVBench",
                            filename='video/' + video_dir.split("/")[0] + ".zip",
                            repo_type="dataset",
                            local_dir=data_dir)

# Unzip
for zip_file in os.listdir(f"{data_dir}/video"):
    if zip_file.endswith(".zip"):
        print(zip_file)
        shutil.unpack_archive(f"{data_dir}/video/{zip_file}", f"{data_dir}/video/videos_unzipped/")

Task: Action Sequence
Annotation file: action_sequence.json
Videos are stored in: star/Charades_v1_480/
Videos are represented as: video
Videos are have start/end timestamps: True
star.zip


Make a following data structure:

```
dataset/
    /json
        task_name.json
    /video
        /task_name_prefix (optional)
            /task_name
                video_0.mp4
                video_1.mp4
                video_2.mp4
                ...
                video_100.mp4
```





In [12]:
ds = load_dataset("json", data_files=annotation_fn_local, split="train")
ds

Dataset({
    features: ['video', 'question', 'candidates', 'answer', 'start', 'end'],
    num_rows: 200
})

In [13]:
# Some tasks in MVBench are missing video files - keep it in mind!
has_missing = False
for sample in ds:
    if not os.path.exists(f"{data_dir}/video/videos_unzipped/{video_dir}/{sample['video']}"):
        print(f"Video `{sample['video']}` does not exists!")
        has_missing = True

Video `EDXBD.mp4` does not exists!
Video `K47J5.mp4` does not exists!
Video `9MNZ5.mp4` does not exists!
Video `QXT9W.mp4` does not exists!
Video `ABHC6.mp4` does not exists!
Video `ALXUC.mp4` does not exists!
Video `BAUQE.mp4` does not exists!
Video `PHH6B.mp4` does not exists!
Video `MNC10.mp4` does not exists!
Video `W7CR5.mp4` does not exists!
Video `Q8UJ8.mp4` does not exists!
Video `X9WTR.mp4` does not exists!


In [14]:
print(f"Dataset length = {len(ds)}")
if has_missing:
    ds = ds.filter(lambda x: os.path.exists(f"{data_dir}/video/videos_unzipped/{video_dir}/{x['video']}"))

print(f"Dataset length = {len(ds)}")

Dataset length = 200
Dataset length = 188


In [15]:
# Load videos and split them into frames
ds = ds.map(collate_read_video,
            batched=False,
            fn_kwargs={"path": f"{data_dir}/video/videos_unzipped/{video_dir}"})

# Make conversation

def make_conversation(sample):
    id2choice = {0: "A", 1: "B", 2: "C", 3: "D"}
    question, candidates = sample["question"], sample["candidates"]
    answer_i = candidates.index(sample["answer"])
    answer_choice = id2choice[answer_i]

    mult_choice = "\n"
    for i, choice in enumerate(candidates):
        mult_choice += f"{id2choice[i]}. {choice};\n"

    conversations = [
         {'from': 'human', 'value': question + mult_choice + '<video>'},
         {'from': 'gpt', 'value': answer_choice + " " + sample["answer"]}
    ]
    sample['conversations'] = conversations
    return sample

ds = ds.map(make_conversation,
            batched=False)



```json
{'id': 578004, 'video': '013901_013950/1025449886.mp4',
'conversations':
[{'from': 'human', 'value': 'Render a clear and concise summary of the video below.\n<video>'},
{'from': 'gpt', 'value': 'Aerial; drone flight around dangerous eroded steep stony slopes of cabo da roca; granite boulders, sea cliffs along the high coast; long seashore with lighthouse, overlooking the promontory, portugal'}]
}
```



In [16]:
# Load model's processor
processor = AutoProcessor.from_pretrained(MODEL_ID + "-HF", trust_remote_code=True, use_fast=False, model_max_length=MAX_LENGTH)
processor.padding_side = "right"

processor

InternVLProcessor:
- image_processor: GotOcr2ImageProcessor {
  "crop_size": null,
  "crop_to_patches": false,
  "data_format": "channels_first",
  "default_to_square": true,
  "device": null,
  "do_center_crop": null,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "GotOcr2ImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "input_data_format": null,
  "max_patches": 12,
  "min_patches": 1,
  "processor_class": "InternVLProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "return_tensors": null,
  "size": {
    "height": 448,
    "width": 448
  }
}

- tokenizer: Qwen2Tokenizer(name_or_path='OpenGVLab/InternVL3_5-1B-HF', vocab_size=151643, model_max_length=4096, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|im_end|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens':

## Custom Dataset Class

In the next step, you'll need **to define a custom dataset** class and the necessary functions to prepare our data for fine-tuning model. The VideoQADataset class extends the [PyTorch Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) class to facilitate loading and processing "MMBench". This class will handle the conversion of dataset samples into the format required for training and evaluation by preparing a prompt and making array from videos.

Next, you need **to define collate functions** to handle the batching of data during training and evaluation. These functions ensure that the input data is properly formatted and padded.

Here use the processor to turn the (video, target token sequence) into the format that the model expects (which is pixel_values, input_ids etc.). Use a dynamic padding of the batches: each batch contains ground truth sequences of varying lengths.

Also you can limit the length of the text tokens (input_ids) to a max length due to memory constraints, feel free to expand if your target token sequences are longer (I'd recommend plotting the average token length of your dataset to determine the optimal value).

The formatting of the input_ids is super important: you need to respect a so-called [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating).

Labels are created for the model by simply copying the inputs to the LLM (input_ids), but with padding tokens replaced by the ignore index of the loss function. This ensures that the model doesn't need to learn to predict padding tokens (used to batch examples together).

Why are the labels a copy of the model inputs, you may ask? The model will internally shift the labels one position to the right so that the model will learn to predict the next token. This can be seen here.

The collate function for evaluation is different, since there you only need to feed the prompt to the model, as we'll use the `generate()` method to autoregressively generate a completion.

In [35]:
from PIL import UnidentifiedImageError


class VideoQADataset(Dataset):
    """Dataset for Video QA fine-tuning."""

    def __init__(
        self,
        template_name,
        raw_data,
        video_data_dir,
        tokenizer,
        ds_name,
        num_image_token,
        image_size=224,
        is_train=True,
        pad2square=False,
        dynamic_image_size=False,
        use_thumbnail=False,
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        min_num_frame=8,  # for video data
        max_num_frame=8,  # for video data
        sampling_method='rand',  # for video data
        repeat_time=1,
        normalize_type='imagenet',
        random_seed=0,
    ):
        super(VideoQADataset, self).__init__()
        self.ds_name = ds_name
        self.tokenizer = tokenizer
        self.template_name = template_name
        self.num_image_token = num_image_token
        print(f'[Dataset] num_image_token: {num_image_token}')
        print(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
        print(f'[Dataset] use_thumbnail: {use_thumbnail}')
        print(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')

        self.image_size = image_size
        self.is_train = is_train
        self.pad2square = pad2square
        self.max_num_frame = max_num_frame
        self.min_num_frame = min_num_frame
        self.sampling_method = sampling_method

        print('Formatting inputs...Skip in lazy mode')

        self.raw_data = raw_data.shuffle(seed=random_seed) if hasattr(raw_data, 'shuffle') else raw_data

        gc.collect()
        self.root = video_data_dir
        self.cached_data_dict = {}
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.normalize_type = normalize_type

        gc.collect()

    def __len__(self):
        return len(self.raw_data)

    def get_preprocess_function(self):
        # Select the appropriate preprocessing function based on the template name
        return preprocess_internvl2_5
        
    def get_transform(self):
        # Build transformation function
        transform = build_transform(is_train=self.is_train, input_size=self.image_size,
                                    pad2square=self.pad2square, normalize_type=self.normalize_type)
        return transform

    def video_get_item(self, data_item):
        # Build transformation function
        transform = self.get_transform()

        # Ensure the first conversation contains a video placeholder
        if '<video>' not in data_item['conversations'][0]['value']:
            data_item['conversations'][0]['value'] = '<video>\n' + data_item['conversations'][0]['value']

        # Get the video file path
        video_file = data_item['video']
        video_path = os.path.join(self.root, video_file)
        
        # Use read_video_pyav instead of read_frames_decord
        # image_list = read_video_pyav(video_path,
        #                              start=data_item.get('start', 0),
        #                              end=data_item.get('end', 1000000),
        #                              n_frames=self.max_num_frame)
        image_list = read_frames_decord(video_path,
                                        num_frames=self.max_num_frame,
                                        min_num_frames = self.min_num_frame,
                                        sample=self.sampling_method,
                                        client=None,
                                        clip=(data_item.get('start', None),
                                              data_item.get('end', None)))
        
        # image_list = [Image.fromarray(frame) for frame in image_list]
        
        # Generate special tokens for each video frame
        special_tokens = '\n'.join(['Frame{}: <image>'.format(i + 1) for i in range(len(image_list))])
        data_item['conversations'][0]['value'] = data_item['conversations'][0]['value'].replace(
            '<video>', special_tokens)

        # Transform each frame image and stack them into a tensor
        pixel_values = [transform(image) for image in image_list]
        pixel_values = torch.stack(pixel_values)
        num_patches = pixel_values.size(0)

        # Select the appropriate preprocessing function based on the template name
        preprocess_function = self.get_preprocess_function()

        # Preprocess the conversations and generate the return dictionary
        num_image_tokens = [self.num_image_token] * num_patches

        if self.is_train:
            # For training, use full conversation (human + assistant)
            sources = [deepcopy(data_item['conversations'])]
        else:
            # For evaluation, use only human part (prompt)
            sources = [deepcopy(data_item['conversations'])]
        
        ret = preprocess_function(template_name=self.template_name, 
                                  sources=sources,
                                  tokenizer=self.tokenizer,
                                  num_image_token_list=num_image_tokens,
                                  text_only=False,
                                  group_by_length=False,
                                  ds_name=self.ds_name, 
                                  num_image=num_patches
        )

        conv = data_item['conversations']
        
        # Extract the correct answer choice from the GPT response
        answer_choice = ''
        if len(conv) > 1 and 'value' in conv[1]:
            gpt_response = conv[1]['value']
            # Extract just the choice letter (A, B, C, D) - usually first character
            answer_choice = gpt_response.strip()[0]  # Get first character (A, B, C, or D)
            

        # Create the final return dictionary
        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=pixel_values,
            image_flags=torch.tensor([1] * num_patches, dtype=torch.long),
            answer_choice=answer_choice
        )

        return ret

    def pure_text_get_item(self, data_item):
        """Handle text-only examples"""
        # Simple text processing
        sources = [deepcopy(data_item['conversations'])]
        
        ret = preprocess_internvl2_5(template_name=self.template_name, 
                                    sources=sources,
                                    tokenizer=self.tokenizer,
                                    num_image_token_list=[0],
                                    text_only=True,
                                    group_by_length=False,
                                    ds_name=self.ds_name, 
                                    num_image=0)
        
        conv = data_item['conversations']
        answer_choice = ''
        if len(conv) > 1 and 'value' in conv[1]:
            answer_choice = conv[1]['value'].strip()[0]

        return dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=torch.zeros(1, 3, self.image_size, self.image_size),
            image_flags=torch.tensor([0], dtype=torch.long),
            answer_choice=answer_choice
        )

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        i = i % len(self.raw_data)
        while True:
            try:
                data_item = self.raw_data[i]
                if 'image' in data_item and len(data_item['image']) != 0:
                    # For simplicity, handle images as text in this version
                    ret = self.pure_text_get_item(data_item)
                elif 'video' in data_item and data_item['video'] is not None and data_item['video'] != '':
                    ret = self.video_get_item(data_item)
                else:
                    ret = self.pure_text_get_item(data_item)
                break
            except Exception as e:
                print(e, self.ds_name, flush=True)
                if not isinstance(e, UnidentifiedImageError):
                    traceback.print_exc()
                data_item = self.raw_data[i]
                if 'image' in data_item:
                    if type(data_item['image']) == list:
                        images = [self.root + item for item in data_item['image']]
                        print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
                    else:
                        if data_item['image'].startswith('s3://'):
                            data_path = self.root + data_item['image']
                        else:
                            data_path = os.path.join(self.root, data_item['image'])
                        print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
                elif 'video' in data_item:
                    data_path = os.path.join(self.root, data_item['video'])
                    print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')
                i = random.randint(0, len(self.raw_data) - 1)
        return ret

In [36]:
def train_collate_fn(examples):
    """
    Collate function for training that batches examples together.
    Handles padding for text and stacks video frames.
    """
    # Extract all components from examples
    input_ids = [example['input_ids'] for example in examples]
    attention_mask = [example['attention_mask'] for example in examples]
    pixel_values = [example['pixel_values'] for example in examples]
    image_flags = [example['image_flags'] for example in examples]
    labels = [example['labels'] for example in examples]
    answer_choices = [example['answer_choice'] for example in examples]

    # Pad input_ids, attention_mask, and labels to the same length
    max_length = max(len(ids) for ids in input_ids)
    
    padded_input_ids = []
    padded_attention_mask = []
    padded_labels = []
    
    for i in range(len(input_ids)):
        # Pad input_ids
        pad_length = max_length - len(input_ids[i])
        padded_input_ids.append(
            torch.cat([
                input_ids[i],
                torch.full((pad_length,), tokenizer.pad_token_id, dtype=torch.long)
            ])
        )
        
        # Pad attention_mask
        padded_attention_mask.append(
            torch.cat([
                attention_mask[i],
                torch.zeros(pad_length, dtype=torch.long)
            ])
        )
        
        # Pad labels (using -100 for ignore index)
        padded_labels.append(
            torch.cat([
                labels[i],
                torch.full((pad_length,), -100, dtype=torch.long)
            ])
        )

    # Stack all tensors
    input_ids_batch = torch.stack(padded_input_ids)
    attention_mask_batch = torch.stack(padded_attention_mask)
    labels_batch = torch.stack(padded_labels)
    
    # Handle pixel_values - FLATTEN the video frames
    # Instead of padding frames, we'll process each frame as a separate image
    flattened_pixel_values = []
    flattened_image_flags = []
    
    for i, (pv, flags) in enumerate(zip(pixel_values, image_flags)):
        # pv shape: (num_frames, channels, height, width)
        # We need to flatten this to (num_frames * batch_size, channels, height, width)
        num_frames = pv.size(0)
        for frame_idx in range(num_frames):
            flattened_pixel_values.append(pv[frame_idx])
            # Note: image_flags is not defined in this scope - you need to extract it from examples
            flattened_image_flags.append(flags[frame_idx])
    
    # Stack flattened pixel values
    if flattened_pixel_values:
        pixel_values_batch = torch.stack(flattened_pixel_values)
        image_flags_batch = torch.stack(flattened_image_flags)
    else:
        pixel_values_batch = torch.tensor([])
        image_flags_batch = torch.tensor([])
    
    return {
        'input_ids': input_ids_batch,
        'attention_mask': attention_mask_batch,
        'pixel_values': pixel_values_batch,
        'labels': labels_batch,
        'image_flags': image_flags_batch,
        'answer_choices': answer_choices
    }


def eval_collate_fn(examples):
    """
    Collate function for evaluation that batches examples together.
    Similar to training but doesn't need labels for generation.
    """
    # Extract all components from examples
    input_ids = [example['input_ids'] for example in examples]
    attention_mask = [example['attention_mask'] for example in examples]
    pixel_values = [example['pixel_values'] for example in examples]
    answer_choices = [example['answer_choice'] for example in examples]

    # Pad input_ids and attention_mask to the same length
    max_length = max(len(ids) for ids in input_ids)
    
    padded_input_ids = []
    padded_attention_mask = []
    
    for i in range(len(input_ids)):
        # Pad input_ids
        pad_length = max_length - len(input_ids[i])
        padded_input_ids.append(
            torch.cat([
                input_ids[i],
                torch.full((pad_length,), tokenizer.pad_token_id, dtype=torch.long)
            ])
        )
        
        # Pad attention_mask
        padded_attention_mask.append(
            torch.cat([
                attention_mask[i],
                torch.zeros(pad_length, dtype=torch.long)
            ])
        )

    # Stack all tensors
    input_ids_batch = torch.stack(padded_input_ids)
    attention_mask_batch = torch.stack(padded_attention_mask)
    
    # Handle pixel_values - FLATTEN the video frames
    flattened_pixel_values = []
    
    for i, pv in enumerate(pixel_values):
        # pv shape: (num_frames, channels, height, width)
        num_frames = pv.size(0)
        for frame_idx in range(num_frames):
            flattened_pixel_values.append(pv[frame_idx])
    
    # Stack flattened pixel values
    if flattened_pixel_values:
        pixel_values_batch = torch.stack(flattened_pixel_values)
    else:
        pixel_values_batch = torch.tensor([])
    
    return {
        'input_ids': input_ids_batch,
        'attention_mask': attention_mask_batch,
        'pixel_values': pixel_values_batch,
        'answer_choices': answer_choices
    }

In [20]:
# def train_collate_fn(examples):
#     """
#     Collate function for training that batches examples together.
#     Handles padding for text and stacks video frames.
#     """
#     # Extract all components from examples
#     input_ids = [example['input_ids'] for example in examples]
#     attention_mask = [example['attention_mask'] for example in examples]
#     pixel_values = [example['pixel_values'] for example in examples]
#     image_flags = [example['image_flags'] for example in examples]
#     labels = [example['labels'] for example in examples]
#     answer_choices = [example['answer_choice'] for example in examples]

#     # Pad input_ids, attention_mask, and labels to the same length
#     max_length = max(len(ids) for ids in input_ids)
    
#     padded_input_ids = []
#     padded_attention_mask = []
#     padded_labels = []
    
#     for i in range(len(input_ids)):
#         # Pad input_ids
#         pad_length = max_length - len(input_ids[i])
#         padded_input_ids.append(
#             torch.cat([
#                 input_ids[i],
#                 torch.full((pad_length,), tokenizer.pad_token_id, dtype=torch.long)
#             ])
#         )
        
#         # Pad attention_mask
#         padded_attention_mask.append(
#             torch.cat([
#                 attention_mask[i],
#                 torch.zeros(pad_length, dtype=torch.long)
#             ])
#         )
        
#         # Pad labels (using -100 for ignore index)
#         padded_labels.append(
#             torch.cat([
#                 labels[i],
#                 torch.full((pad_length,), -100, dtype=torch.long)
#             ])
#         )

#     # Stack all tensors
#     input_ids_batch = torch.stack(padded_input_ids)
#     attention_mask_batch = torch.stack(padded_attention_mask)
#     labels_batch = torch.stack(padded_labels)
    
#     # Handle pixel_values - FLATTEN the video frames
#     # Instead of padding frames, we'll process each frame as a separate image
#     flattened_pixel_values = []
#     flattened_image_flags = []
    
#     for i, (pv, flags) in enumerate(zip(pixel_values, image_flags)):
#         # pv shape: (num_frames, channels, height, width)
#         # We need to flatten this to (num_frames * batch_size, channels, height, width)
#         num_frames = pv.size(0)
#         for frame_idx in range(num_frames):
#             flattened_pixel_values.append(pv[frame_idx])
#             # Note: image_flags is not defined in this scope - you need to extract it from examples
#             flattened_image_flags.append(flags[frame_idx])
    
#     # Stack flattened pixel values
#     if flattened_pixel_values:
#         pixel_values_batch = torch.stack(flattened_pixel_values)
#         image_flags_batch = torch.stack(flattened_image_flags)
#     else:
#         pixel_values_batch = torch.tensor([])
#         image_flags_batch = torch.tensor([])
    
#     return {
#         'input_ids': input_ids_batch,
#         'attention_mask': attention_mask_batch,
#         'pixel_values': pixel_values_batch,
#         'labels': labels_batch,
#         'image_flags': image_flags_batch,
#         'answer_choices': answer_choices
#     }


# def eval_collate_fn(examples):
#     """
#     Collate function for evaluation that batches examples together.
#     Similar to training but doesn't need labels for generation.
#     """
#     # Extract all components from examples
#     input_ids = [example['input_ids'] for example in examples]
#     attention_mask = [example['attention_mask'] for example in examples]
#     pixel_values = [example['pixel_values'] for example in examples]
#     answer_choices = [example['answer_choice'] for example in examples]

#     # Pad input_ids and attention_mask to the same length
#     max_length = max(len(ids) for ids in input_ids)
    
#     padded_input_ids = []
#     padded_attention_mask = []
    
    # for i in range(len(input_ids)):
    #     # Pad input_ids
    #     pad_length = max_length - len(input_ids[i])
    #     padded_input_ids.append(
    #         torch.cat([
    #             input_ids[i],
    #             torch.full((pad_length,), tokenizer.pad_token_id, dtype=torch.long)
    #         ])
    #     )
        
    #     # Pad attention_mask
    #     padded_attention_mask.append(
    #         torch.cat([
    #             attention_mask[i],
    #             torch.zeros(pad_length, dtype=torch.long)
    #         ])
    #     )

    # # Stack all tensors
    # input_ids_batch = torch.stack(padded_input_ids)
    # attention_mask_batch = torch.stack(padded_attention_mask)
    
    # # Handle pixel_values - FLATTEN the video frames
    # flattened_pixel_values = []
    
    # for i, pv in enumerate(pixel_values):
    #     # pv shape: (num_frames, channels, height, width)
    #     num_frames = pv.size(0)
    #     for frame_idx in range(num_frames):
    #         flattened_pixel_values.append(pv[frame_idx])
    
    # # Stack flattened pixel values
    # if flattened_pixel_values:
    #     pixel_values_batch = torch.stack(flattened_pixel_values)
    # else:
    #     pixel_values_batch = torch.tensor([])
    
    # return {
    #     'input_ids': input_ids_batch,
    #     'attention_mask': attention_mask_batch,
    #     'pixel_values': pixel_values_batch,
    #     'answer_choices': answer_choices
    # }

## Shuffling and Splitting the Dataset
You need to shuffle dataset, and then split it into training and test sets. This ensures that our model is trained on a diverse and representative sample of the data.

In [37]:
# Shuffle and split dataset into 80% train and 20% test
ds = ds.shuffle(seed=42)
split_idx = int(len(ds) * 0.8)

dataset = {
    'train': ds.select(range(split_idx)),
    'test': ds.select(range(split_idx, len(ds)))
}

print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['test'])}")

Train dataset size: 150
Test dataset size: 38


In [38]:
%matplotlib inline

from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML


example = dataset['train'][0]
# Read the video for visualization
video_path = f"{data_dir}/video/videos_unzipped/{video_dir}/{example['video']}"
clip = read_video_pyav(video_path, example.get('start', 1), example.get('end', 1e+10), n_frames=8)

# np array with shape (frames, height, width, channels)
video = np.array(clip)

print(f"Video shape: {video.shape}")
print(f"Question: {example['question']}")
print(f"Answer: {example['answer']}")
print(f"Candidates: {example['candidates']}")

Video shape: (8, 270, 480, 3)
Question: What happened after the person took the phone/camera?
Answer: Closed the window.
Candidates: ['Took the book.', 'Put down the pillow.', 'Closed the window.', 'Took the food.']


In [39]:
fig = plt.figure()
im = plt.imshow(video[0,:,:,:])
plt.axis("off")

plt.close() # this is required to not display the generated image

def init():
    im.set_data(video[0,:,:,:])

def animate(i):
    im.set_data(video[i,:,:,:])
    return im

anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],
                               interval=300)
HTML(anim.to_html5_video())

And now wrap it in the Pytorch Datasets class and print one example as sanity check.

In [40]:
# tokenizer = getattr(processor, "tokenizer", None)
# if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False, padding_side = "right", model_max_length=MAX_LENGTH)

train_dataset = VideoQADataset(
    template_name='internvl2_5',
    raw_data=dataset["train"],
    video_data_dir=f"{data_dir}/video/videos_unzipped/{video_dir}",
    tokenizer=tokenizer,
    ds_name=TASK_NAME,
    num_image_token=256,
    image_size=448,
    is_train=True,
    min_num_frame=8,
    max_num_frame=8
)

eval_dataset = VideoQADataset(
    template_name='internvl2_5',
    raw_data=dataset["test"],
    video_data_dir=f"{data_dir}/video/videos_unzipped/{video_dir}",
    tokenizer=tokenizer,
    ds_name=TASK_NAME,
    num_image_token=256,
    image_size=448,
    is_train=False,
    min_num_frame=8,
    max_num_frame=8
)

print(f"Train dataset: {len(train_dataset)} examples")
print(f"Eval dataset: {len(eval_dataset)} examples")

[Dataset] num_image_token: 256
[Dataset] dynamic_image_size: False
[Dataset] use_thumbnail: False
[Dataset] min_dynamic_patch: 1, max_dynamic_patch: 6
Formatting inputs...Skip in lazy mode
[Dataset] num_image_token: 256
[Dataset] dynamic_image_size: False
[Dataset] use_thumbnail: False
[Dataset] min_dynamic_patch: 1, max_dynamic_patch: 6
Formatting inputs...Skip in lazy mode
Train dataset: 150 examples
Eval dataset: 38 examples


In [41]:
# Display a sample from the training dataset
sample_data = train_dataset[0]

print("Keys in sample:", sample_data.keys())
print("Input IDs shape:", sample_data['input_ids'].shape)
print("Pixel values shape:", sample_data['pixel_values'].shape)
print("Labels shape:", sample_data['labels'].shape)

Keys in sample: dict_keys(['input_ids', 'labels', 'attention_mask', 'pixel_values', 'image_flags', 'answer_choice'])
Input IDs shape: torch.Size([4096])
Pixel values shape: torch.Size([8, 3, 448, 448])
Labels shape: torch.Size([4096])


# Model

## Load model
Next, load your InternVL model from the hub. This is a model with about 1 billion trainable parameters (as it combines a **Qwen2 1B language model** with a relatively low-parameter vision **InternViT encoder**). Do note that we load a model here which already has undergone supervised fine-tuning (SFT) instructions dataset. We can benefit from the fine-tuning that the model already has undergone.

## Full fine-tuning, LoRa and Q-LoRa

**Select the fine-tuning method.**

 For reference, fine-tuning a model using the AdamW optimizer (which is often used to optimize neural networks) with mixed precision, you need about 18 times the amount of parameters in GB of GPU RAM. So in this case, we would need 18x1 billion bytes = 18 GB of GPU RAM if we want to update all the parameters of the model. Not so huge right? But using PEFT approach it could be less.

Some clever people came up with the LoRa method (LoRa is short for low-rank adapation). It allows to just freeze the existing weights and only train a couple of adapter layers on top of the base model. Hugging Face offers the separate [PEFT library](https://huggingface.co/docs/peft/main/en/index) for easy use of LoRa, along with other Parameter-Efficient Fine-Tuning methods.

Moreover, one can not only freeze the existing base model but also quantize it (which means, shrinking down its size). A neural network's parameters are typically saved in either float32 (which means, 32 bits or 4 bytes are used to store each parameter value) or float16 (which means, 16 bits or half a byte - also called half precision). However, with some clever algorithms one can shrink each parameter to just 8 or 4 bits (half a byte!), without significant effect on final performance. Read all about it here: https://huggingface.co/blog/4bit-transformers-bitsandbytes.

This means that we're going to shrink the size of the base 1B model considerably using 4-bit quantization, and then only train a couple of adapter layers on top using LoRa (in float16). This idea of combining LoRa with quantization is called Q-LoRa and is the most memory friendly version.

There exist many forms of quantization, here we leverage the [BitsAndBytes integration](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig).

In [42]:
## Load model
# Three options for training, from the lowest precision training to the highest precision training:
# QLoRA: model uses 4-bit quantization, which helps in reducing memory usage while maintaining performance.
# Standard LoRA:  model is loaded with standard LoRA adaptations.
# Full Fine-Tuning: no memory optimization are done. In that case Flash Attention is used to speed up training, if hardware supports it.

# We use QLoRA as it's the most memory-efficient approach
# This allows fine-tuning on consumer GPUs while maintaining good performance
from transformers import AutoModel


if USE_QLORA:
    # Configure 4-bit quantization using bitsandbytes
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,  # Nested quantization for additional memory savings
        bnb_4bit_quant_type="nf4",  # NormalFloat 4-bit quantization
        bnb_4bit_compute_dtype=torch.bfloat16  # Compute in bfloat16 for better performance
    )
    
    # Load model with 4-bit quantization
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=False,
        trust_remote_code=True,
        device_map="auto"
    )
    print("Model loaded with QLoRA (4-bit quantization)")
    
elif USE_LORA:
    # Load model in bfloat16 without quantization
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        use_flash_attn=False
    )
    print("Model loaded for standard LoRA")
    
else:
    # Full fine-tuning - load model without any optimizations
    model = AutoModel.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        attn_implementation="flash_attention_2"  # Use Flash Attention if supported
    )
    print("Model loaded for full fine-tuning")

if hasattr(model.config, 'img_context_token_id'):
    model.img_context_token_id = model.config.img_context_token_id
else:
     model.img_context_token_id = tokenizer.convert_tokens_to_ids('<IMG_CONTEXT>')
    
# Print model info
print(f"Model type: {type(model)}")
print(f"Model device: {model.device}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

Model loaded with QLoRA (4-bit quantization)
Model type: <class 'transformers_modules.OpenGVLab.InternVL3_5_hyphen_1B.2f71cf52542334823e48a46ffba0e2bc9add3446.modeling_internvl_chat.InternVLChatModel'>
Model device: cuda:0
Total parameters: 687,080,448


In [43]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

def find_all_linear_names(model):
    # Only for LoRA ot QLoRA

    target_modules = [
        "q_proj",
        "k_proj", 
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]

    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['multi_modal_projector', 'vision_model']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            lin_layer_name = name.split('.')[-1]
            if lin_layer_name in target_modules:
                lora_module_names.add(lin_layer_name)

    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

original_dtype = model.dtype
# If you selected LoRA ot QLora make a choise of parameters to replace
model = prepare_model_for_kbit_training(model)

# Find target modules for LoRA
target_modules = find_all_linear_names(model)
print(f"Target modules for LoRA: {target_modules}")

# Then create LoraConfig and run prepare_model_for_kbit_training(...)
# and finally: model = get_peft_model(model, ...)

lora_config = LoraConfig(
    r=16,  # Rank
    lora_alpha=32,  # LoRA alpha
    target_modules=target_modules,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
print(model.dtype)

model.print_trainable_parameters()

Target modules for LoRA: ['o_proj', 'up_proj', 'k_proj', 'q_proj', 'gate_proj', 'v_proj', 'down_proj']
torch.float32
trainable params: 10,092,544 || all params: 1,070,990,336 || trainable%: 0.9424


In [92]:
vision_cfg = model.config.vision_config
print(vision_cfg.image_size, vision_cfg.patch_size)
tokens = (vision_cfg.image_size // vision_cfg.patch_size) ** 2
print(tokens)

448 14
1024


## Define PyTorch Lightning Module for Video-LLaVA
To streamline the training and evaluation of the Video-InternVL model, you can use [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html), which abstracts away much of the boilerplate code and provides a structured framework for model training. In this section, you need to define the InternVLModelPLModule, a custom PyTorch Lightning module that encapsulates the model, training loop, validation loop, and optimizer configuration.

### InternVLModelPLModule Class

The InternVLModelPLModule class inherits from LightningModule and includes methods for training, validation, and optimizer configuration. This setup ensures a clean and efficient training process.

Basically, PyTorch Lightning will take care of all device placements (.to(device)) for us, as well as the backward pass, putting the model in training mode, etc.

Notice the difference between a training step and an evaluation step:

- a training step only consists of a forward pass, in which we compute the cross-entropy loss between the model's next token predictions and the ground truth (in parallel for all tokens, this technique is known as "teacher forcing"). The backward pass is handled by PyTorch Lightning.
- an evaluation step consists of making the model autoregressively complete the prompt using the generate() method. After that, you compute an evaluation metric between the predicted sequences and the ground truth ones. This allows you to see how the model is improving over the course of training. The metric we use here is accuracy of answering the question.

Besides that, you define the optimizer to use (AdamW is a good default choice) and the data loaders, which use the collate functions defined above to batch together items of the PyTorch datasets. Do note that AdamW is a pretty heavy optimizer in terms of memory requirements, but as we're training with QLoRa we only need to store optimizer states for the adapter layers. For full fine-tuning, one could take a look at more memory friendly optimizers such as 8-bit Adam.

In [86]:
import pytorch_lightning as L
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
import torch.nn.functional as F

class InternVLModelPLModule(L.LightningModule):
    def __init__(self, config, tokenizer, model):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.model = model
        self.model.train()
        
        # Training parameters
        self.batch_size = config.get('batch_size', 1)
        self.learning_rate = config.get('learning_rate', 1e-4)
        self.warmup_steps = config.get('warmup_steps', 50)
        self.max_epochs = config.get('max_epochs', 10)
        
        # Save hyperparameters
        self.save_hyperparameters(ignore=['model', 'tokenizer'])
        
    def training_step(self, batch, batch_idx):
        self.model.train()
        # Extract inputs
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        image_flags = batch['image_flags']
        
        # Forward pass
        # Direct forward pass without PEFT wrapper issues
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            outputs = self.model.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                labels=labels,
                image_flags=image_flags,
                return_dict=True
            )
        
        loss = outputs.loss
        
        # Log training loss
        self.log('train_loss', loss, prog_bar=True, logger=True, sync_dist=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx, dataset_idx=0):
        self.model.eval()
        # For validation, we'll compute both loss and generation accuracy
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        image_flags = batch['image_flags']
        answer_choices = batch.get('answer_choices', [])


        # Compute validation loss
        val_loss = None
        if labels is not None:
            with torch.no_grad():
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                    outputs = self.model.base_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        pixel_values=pixel_values,
                        labels=labels,
                        image_flags=image_flags,
                        return_dict=True
                    )
                val_loss = outputs.loss
                self.log('val_loss', val_loss, prog_bar=True, logger=True, sync_dist=True)

        # Generate responses for accuracy calculation, add generation prompt
        response_ids = []
        response_masks = []
        for input_id in input_ids:
            decoded_text = self.tokenizer.decode(input_id, skip_special_tokens=False).strip()
            response = decoded_text.split('<|im_start|>assistant')[0]
            response += '\nOnly give the best option.'
            tokenized = tokenizer(
                response + '<|im_start|>assistant',
                return_tensors='pt',
                padding='max_length',
                max_length=tokenizer.model_max_length,
                truncation=False,
            )
            response_ids.append(tokenized.input_ids)
            response_masks.append(tokenized.attention_mask)

        response_ids = torch.cat(response_ids).to(self.model.device)
        response_masks = torch.cat(response_masks).to(self.model.device)
        prompt_lengths = response_masks.sum(dim=1)
            
        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                generated_ids = self.model.base_model.generate(
                    input_ids=response_ids,
                    attention_mask=response_masks,
                    pixel_values=pixel_values,
                    max_new_tokens=10,
                    do_sample=False,
                    temperature=1.0,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
                assistant_texts = []
                for i in range(len(generated_ids)):
                    start = int(prompt_lengths[i].item())
                    assistant_tokens = generated_ids[i, start:]
                    assistant_text = self.tokenizer.decode(assistant_tokens, skip_special_tokens=True).strip()
                    assistant_text = f"Best option:({assistant_text})"
                    assistant_texts.append(assistant_text)

        print(f"\n=== Epoch {self.current_epoch}, Batch {batch_idx} ===")
        for i in range(len(generated_ids)):  # First 2 samples
            input_text = self.tokenizer.decode(response_ids[i], skip_special_tokens=True).strip()
            generated_text = assistant_texts[i]
            expected_answer = answer_choices[i] if i < len(answer_choices) else ''
            print(f"Sample {i}:")
            print(f"  Expected: '{expected_answer}'")
            print(f"  Input: '{input_text}'")
            print(f"  Generated: {assistant_texts[i]!r}")
            assistant_upper = assistant_texts[i].upper().strip()
            expected_upper = (expected_answer or "").upper()
            match_value = bool(expected_upper and (assistant_upper.startswith(expected_upper) or f'({expected_upper}' in assistant_upper))
            print(f"  Match: {match_value}")
        
        # Calculate accuracy
        accuracy = self._calculate_accuracy(assistant_texts, answer_choices)
        self.log('val_accuracy', accuracy, prog_bar=True, logger=True, sync_dist=True)
        
        return {'val_loss': val_loss, 'val_accuracy': accuracy}
    
    def _calculate_accuracy(self, assistant_texts, answer_choices):
        """Calculate accuracy by comparing generated text with answer choices"""
        if not answer_choices:
            print('############# There is no answers #########')
            return 0.0
        
        correct = 0
        total = len(assistant_texts)
        
        for text, expected in zip(assistant_texts, answer_choices):
            expected_answer = (expected or "").upper()
            text_upper = text.upper().strip()
            if expected_answer and (text_upper.startswith(expected_answer) or f'({expected_answer}' in text_upper):
                correct += 1
        
        return correct / total if total > 0 else 0.0
    
    def configure_optimizers(self):
        """Configure optimizer and learning rate scheduler"""
        # Separate parameters for different weight decay
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() 
                          if p.requires_grad and not any(nd in n for nd in no_decay)],
                "weight_decay": self.config.get('weight_decay', 0.01),
            },
            {
                "params": [p for n, p in self.model.named_parameters() 
                          if p.requires_grad and any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8
        )
        
        # Calculate total training steps
        train_loader = self.train_dataloader()
        steps_per_epoch = len(train_loader)
        total_steps = steps_per_epoch * self.max_epochs
        
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=total_steps
        )
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }
    
    def train_dataloader(self):
        return DataLoader(
            train_dataset, 
            collate_fn=train_collate_fn,
            batch_size=self.batch_size, 
            shuffle=True, 
            num_workers=4
        )
    
    def val_dataloader(self):
        return DataLoader(
            eval_dataset, 
            collate_fn=train_collate_fn,
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=4
        )

In [87]:
config = {"max_epochs": 6,
          #"val_check_interval": 1.0, # how many times we want to validate during an epoch
          "check_val_every_n_epoch": 1,
          "gradient_clip_val": 1.0,
          "accumulate_grad_batches": 8,
          "lr": 1e-4,
          "batch_size": 1,
          "num_nodes": 1,
          "warmup_steps": 50,
}

model_module = InternVLModelPLModule(
    config=config,
    tokenizer=tokenizer,
    model=model  # Your loaded model
)

## Define callbacks
Optionally, Lightning allows to define so-called [callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), which are arbitrary pieces of code that can be executed during training.

You'd better use the EarlyStopping callback of Lightning, which will automatically stop training once the evaluation metric (edit distance in our case) doesn't improve after 3 epochs.

In [46]:
from huggingface_hub import HfApi
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

logger = TensorBoardLogger("tb_logs", name="InternVL3.5_finetuning_qlora")

# api = HfApi()

# class PushToHubCallback(Callback):
#     def on_train_epoch_end(self, trainer, pl_module):
#         print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
#         pl_module.model.push_to_hub(REPO_ID,
#                                     commit_message=f"Training in progress, epoch {trainer.current_epoch}")

#     def on_train_end(self, trainer, pl_module):
#         print(f"Pushing model to the hub after training")
#         pl_module.processor.push_to_hub(REPO_ID,
#                                     commit_message=f"Training done")
#         pl_module.model.push_to_hub(REPO_ID,
                                    # commit_message=f"Training done")


checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints',           # Directory to save checkpoints
    filename='internvl-epoch-{epoch:02d}-val_acc_{val_accuracy:.2f}',  # File name pattern
    save_top_k=-1,                     # Save all checkpoints (-1 means all)
    every_n_epochs=1,                  # Save every epoch
    save_last=True,                    # Also save a 'last.ckpt' file
    verbose=True
)

early_stop_callback = EarlyStopping(monitor="val_accuracy",
                                    patience=3, verbose=False, mode="max")

## Train!
 Trainer class supports many more flags. See the [docs](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#lightning.pytorch.trainer.trainer.Trainer)

In [47]:
from pytorch_lightning.loggers import CSVLogger
csv_logger = CSVLogger(save_dir="logs", name="internvl_qlora_video")

trainer = L.Trainer(
    max_epochs=config["max_epochs"],
    check_val_every_n_epoch=config["check_val_every_n_epoch"],
    gradient_clip_val=config["gradient_clip_val"],
    accumulate_grad_batches=config["accumulate_grad_batches"],
    callbacks=[early_stop_callback, checkpoint_callback],
    logger=csv_logger,
    precision="16-mixed",
    accelerator="auto",
    devices="auto",
    log_every_n_steps=10,
    deterministic=True,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [48]:
trainer.fit(model_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/opt/conda/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name  | Type                 | Params | Mode 
-------------------------------------------------------
0 | model | PeftModelForCausalLM | 697 M  | train
-------------------------------------------------------
10.1 M    Trainable params
687 M     Non-trainable params
697 M     Total params
2,788.692 Total estimated model params size (MB)
2736      Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.



=== Epoch 0, Batch 0 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person opened the door?
A. Put down the paper/notebook.;
B. Tidied up the blanket.;
C. Sat on the floor.;
D. Took the shoe.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: 'yet?'
  Match: False

=== Epoch 0, Batch 1 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person washed the clothes?
A. Put down the phone/camera.;
B. Put down the clothes.;
C. Put down the towel.;
D. Washed the window.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: ':



;:01: 1:'
  Match: False


Training: |          | 0/? [00:00<?, ?it/s]

  return fn(*args, **kwargs)


Validation: |          | 0/? [00:00<?, ?it/s]


=== Epoch 0, Batch 0 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person opened the door?
A. Put down the paper/notebook.;
B. Tidied up the blanket.;
C. Sat on the floor.;
D. Took the shoe.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: '***.'
  Match: False

=== Epoch 0, Batch 1 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person washed the clothes?
A. Put down the phone/camera.;
B. Put down the clothes.;
C. Put down the towel.;
D. Washed the window.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: 'else: 1: 1: 1'
  Match: False

=== Epoch 0, Batch 2 ===
Sample 0:
  Expected: 'D'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened after the person took the dish?
A. T

Validation: |          | 0/? [00:00<?, ?it/s]


=== Epoch 1, Batch 0 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person opened the door?
A. Put down the paper/notebook.;
B. Tidied up the blanket.;
C. Sat on the floor.;
D. Took the shoe.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: '$$ 8  Put down the paper/note'
  Match: False

=== Epoch 1, Batch 1 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person washed the clothes?
A. Put down the phone/camera.;
B. Put down the clothes.;
C. Put down the towel.;
D. Washed the window.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: 'gone 2: Put down the towel..'
  Match: False

=== Epoch 1, Batch 2 ===
Sample 0:
  Expected: 'D'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happe

Validation: |          | 0/? [00:00<?, ?it/s]


=== Epoch 2, Batch 0 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person opened the door?
A. Put down the paper/notebook.;
B. Tidied up the blanket.;
C. Sat on the floor.;
D. Took the shoe.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: '$$  A Put down the paper/notebook'
  Match: False

=== Epoch 2, Batch 1 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person washed the clothes?
A. Put down the phone/camera.;
B. Put down the clothes.;
C. Put down the towel.;
D. Washed the window.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: '.. 2  Put down the phone/c'
  Match: False

=== Epoch 2, Batch 2 ===
Sample 0:
  Expected: 'D'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What hap

Validation: |          | 0/? [00:00<?, ?it/s]


=== Epoch 3, Batch 0 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person opened the door?
A. Put down the paper/notebook.;
B. Tidied up the blanket.;
C. Sat on the floor.;
D. Took the shoe.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: '***.
.'
  Match: False

=== Epoch 3, Batch 1 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person washed the clothes?
A. Put down the phone/camera.;
B. Put down the clothes.;
C. Put down the towel.;
D. Washed the window.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 
assistant'
  Generated: '20;'
  Match: False

=== Epoch 3, Batch 2 ===
Sample 0:
  Expected: 'D'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened after the person took the dish?
A. Threw the

In [88]:
# Optional: run a continuation fine-tuning
extra_epochs = 1
target_max_epochs = model_module.current_epoch + 1 + extra_epochs
print(f'Continuing training up to epoch {target_max_epochs} (current epoch: {model_module.current_epoch})')
continue_trainer = L.Trainer(
    max_epochs=target_max_epochs,
    check_val_every_n_epoch=config['check_val_every_n_epoch'],
    gradient_clip_val=config['gradient_clip_val'],
    accumulate_grad_batches=config['accumulate_grad_batches'],
    callbacks=[early_stop_callback, checkpoint_callback],
    logger=csv_logger,
    precision='16-mixed',
    accelerator='auto',
    devices='auto',
    log_every_n_steps=10,
    deterministic=True,
)
continue_trainer.fit(
    model_module,
    ckpt_path=None,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /home/jovyan/work/genai_work/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/opt/conda/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name  | Type                 | Params | Mode 
-------------------------------------------------------
0 | model | PeftModelForCausalLM | 697 M  | train
-------------------------------------------------------
10.1 M    Trainable params
687 M     Non-trainable params
697 M     Total params
2,788.692 Total estimated model params size (MB)
2736      Modules in train mode
0       

Continuing training up to epoch 2 (current epoch: 0)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


=== Epoch 0, Batch 0 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person opened the door?
A. Put down the paper/notebook.;
B. Tidied up the blanket.;
C. Sat on the floor.;
D. Took the shoe.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 

Only give the best option.assistant'
  Generated: 'Best option:()'
  Match: False

=== Epoch 0, Batch 1 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person washed the clothes?
A. Put down the phone/camera.;
B. Put down the clothes.;
C. Put down the towel.;
D. Washed the window.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 

Only give the best option.assistant'
  Generated: 'Best option:()'
  Match: False


Training: |          | 0/? [00:00<?, ?it/s]

  return fn(*args, **kwargs)


Validation: |          | 0/? [00:00<?, ?it/s]


=== Epoch 0, Batch 0 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person opened the door?
A. Put down the paper/notebook.;
B. Tidied up the blanket.;
C. Sat on the floor.;
D. Took the shoe.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 

Only give the best option.assistant'
  Generated: 'Best option:()'
  Match: False

=== Epoch 0, Batch 1 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person washed the clothes?
A. Put down the phone/camera.;
B. Put down the clothes.;
C. Put down the towel.;
D. Washed the window.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 

Only give the best option.assistant'
  Generated: 'Best option:()'
  Match: False

=== Epoch 0, Batch 2 ===
Sample 0:
  Expected: 'D'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的

Validation: |          | 0/? [00:00<?, ?it/s]


=== Epoch 1, Batch 0 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person opened the door?
A. Put down the paper/notebook.;
B. Tidied up the blanket.;
C. Sat on the floor.;
D. Took the shoe.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 

Only give the best option.assistant'
  Generated: 'Best option:()'
  Match: False

=== Epoch 1, Batch 1 ===
Sample 0:
  Expected: 'C'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。
user
What happened before the person washed the clothes?
A. Put down the phone/camera.;
B. Put down the clothes.;
C. Put down the towel.;
D. Washed the window.;
Frame1: 
Frame2: 
Frame3: 
Frame4: 
Frame5: 
Frame6: 
Frame7: 
Frame8: 

Only give the best option.assistant'
  Generated: 'Best option:()'
  Match: False

=== Epoch 1, Batch 2 ===
Sample 0:
  Expected: 'D'
  Input: 'system
你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的

`Trainer.fit` stopped: `max_epochs=2` reached.


In [89]:
# Prepare the input for inference
sample = dataset['test'][0]
clip = read_video_pyav(
    f"{data_dir}/video/videos_unzipped/{video_dir}/{sample['video']}",
    sample.get('start', 0),
    sample.get('end', 1e10),
    n_frames=8,
)

video_frames = [
    frame if isinstance(frame, Image.Image) else Image.fromarray(frame)
    for frame in clip
]

system_prompt = (
    "Carefully watch the video and pay attention to the cause and sequence of events, "
    "the detail and movement of objects, and the action and pose of persons. Based on your observations, "
    "select the best option that accurately addresses the question."
)
conversation = [
    {
        'role': 'system',
        'content': system_prompt
    },
    {
        'role': 'user',
        'content': sample['conversations'][0]['value'] + '\nOnly give the best option.'
    }
]
prompt_text = processor.apply_chat_template(
    conversation, tokenize=False, add_generation_prompt=True
)

print(f"Prompt:\n{prompt_text}\n")

# Process the input
inputs = processor(
    text=[prompt_text],
    videos=[video_frames],
    padding=True,
    return_tensors="pt"
).to(model.device)

print(f"Input IDs shape: {inputs['input_ids'].shape}")
print(f"Pixel values shape: {inputs.get('pixel_values_videos', torch.tensor([])).shape}")

Prompt:
<|im_start|>system
Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.<|im_end|>
<|im_start|>user
What happened after the person lied on the bed?
A. Washed the dish.;
B. Opened the laptop.;
C. Put down the pillow.;
D. Took the towel.;
<video>
Only give the best option.<|im_end|>
<|im_start|>assistant


Input IDs shape: torch.Size([1, 2206])
Pixel values shape: torch.Size([0])


In [61]:
from tqdm.auto import tqdm

In [90]:
model_module = InternVLModelPLModule(config=config, tokenizer=tokenizer, model=model)

In [93]:
# Evaluate model accuracy on the validation split
from torch.utils.data import DataLoader

eval_dataloader = DataLoader(
    eval_dataset,
    collate_fn=train_collate_fn,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

model_module.eval()
model_module.model.to(device)
model_module.model.eval()

correct = 0
total = 0

for batch in tqdm(eval_dataloader):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    pixel_values = batch['pixel_values'].to(device, dtype=getattr(model_module.model, 'dtype', torch.float32))
    answer_choices = batch['answer_choices']

    response_ids = []
    response_masks = []
    for input_id in input_ids:
        decoded_text = tokenizer.decode(input_id, skip_special_tokens=False).strip()
        response_prefix = decoded_text.split('<|im_start|>assistant')[0]
        tokenized = tokenizer(
            response_prefix + '<|im_start|>assistant',
            return_tensors='pt',
            padding='max_length',
            max_length=tokenizer.model_max_length,
            truncation=False,
        )
        response_ids.append(tokenized.input_ids)
        response_masks.append(tokenized.attention_mask)

    response_ids = torch.cat(response_ids, dim=0).to(device)
    response_masks = torch.cat(response_masks, dim=0).to(device)
    prompt_lengths = response_masks.sum(dim=1)

    with torch.no_grad():
        if device.type == 'cuda':
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                generated_ids = model_module.model.base_model.generate(
                    input_ids=response_ids,
                    attention_mask=response_masks,
                    pixel_values=pixel_values,
                    max_new_tokens=10,
                    do_sample=False,
                    temperature=1.0,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
        else:
            generated_ids = model_module.model.base_model.generate(
                input_ids=response_ids,
                attention_mask=response_masks,
                pixel_values=pixel_values,
                max_new_tokens=10,
                do_sample=False,
                temperature=1.0,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

    assistant_texts = []
    for i in range(len(generated_ids)):
        start = int(prompt_lengths[i].item())
        assistant_tokens = generated_ids[i, start:]
        assistant_text = tokenizer.decode(assistant_tokens, skip_special_tokens=True).strip()
        assistant_texts.append(assistant_text)

    for predicted_text, answer in zip(assistant_texts, answer_choices):
        predicted_upper = predicted_text.upper().strip()
        expected_upper = (answer or "").upper()
        if expected_upper and predicted_upper.startswith(expected_upper):
            correct += 1
        total += 1

test_accuracy = correct / total if total else 0.0
print(f"Validation accuracy: {test_accuracy:.4f}")

Using device: cuda


  0%|          | 0/38 [00:01<?, ?it/s]

Validation accuracy: 0.0000


In [94]:
import pandas as pd

# Create a summary table with results
results = {
    "Task": ["Action Sequence"],
    "Model": ["InternVL3.5-1B"],
    "Fine-tuning Method": ["QLoRA (4-bit + LoRA)"],
    "Train Dataset Size": [len(train_dataset)],
    "Test Dataset Size": [len(eval_dataset)],
    "Metric": ["Accuracy"],
    "Test Accuracy": [f"{test_accuracy:.4f}"],
    "Training Epochs": [config["max_epochs"]],
    "Learning Rate": [config["lr"]],
    "Batch Size": [config["batch_size"]],
    "LoRA Rank": [16],
    "LoRA Alpha": [32],
}

results_df = pd.DataFrame(results)
print("\n" + "="*80)
print("FINAL EVALUATION RESULTS")
print("="*80)
print(results_df.to_string(index=False))
print("="*80)

# After training, you can update the test accuracy by running evaluation on the full test set
# Example: test_accuracy = trainer.validate(model_module)
# Then update: results["Test Accuracy"] = [f"{test_accuracy:.4f}"]


FINAL EVALUATION RESULTS
           Task          Model   Fine-tuning Method  Train Dataset Size  Test Dataset Size   Metric Test Accuracy  Training Epochs  Learning Rate  Batch Size  LoRA Rank  LoRA Alpha
Action Sequence InternVL3.5-1B QLoRA (4-bit + LoRA)                 150                 38 Accuracy        0.0000                6         0.0001           1         16          32


The poor quality may have been caused by choosing QLoRA at 4 bits, as we only retrain thin LoRA ranks.
The prompt format with special tokens, etc., may not be entirely correct.