<a href="https://colab.research.google.com/github/Yogesh914/VideoGPT-plus/blob/main/VideoGPT%2B_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VideoGPT+ Demo 🎥

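## Setup

This notebook needs a CUDA GPU runtime: the model and its inputs are moved to `cuda` during inference.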

In [None]:
# Clone the VideoGPT+ source code. The inference cell imports `videogpt_plus`
# and `eval` modules, presumably resolved from this checkout once it is the
# current directory (see the `%cd` cell).
!git clone https://github.com/mbzuai-oryx/VideoGPT-plus

In [None]:
# Work from the repository root. The relative paths used later
# (OpenGVLab/..., VideoGPT-plus_Phi3-mini-4k/...) and the repo-local imports
# assume this is the current directory.
%cd VideoGPT-plus

In [None]:
!pip install huggingface_hub

# Interactive Hugging Face login.
# NOTE(review): presumably needed so the `git clone` commands from
# huggingface.co below can access gated repositories. Confirm this, and
# confirm the token is stored as a git credential so `git clone` picks it up.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Vision encoder checkpoint (moved under OpenGVLab/ in the next cell).
# NOTE(review): Hugging Face repos keep large weight files in Git LFS. If
# `git lfs` is not available, only pointer files are cloned. Confirm it is
# installed in this runtime.
!git clone https://huggingface.co/OpenGVLab/InternVideo2-Stage2_1B-224p-f4

In [None]:
!mkdir OpenGVLab
!mv InternVideo2-Stage2_1B-224p-f4 OpenGVLab/

In [None]:
# VideoGPT+ (Phi-3-mini) checkpoint. `model_path` in the inference cell points
# at its `vcgbench` subfolder.
!git clone https://huggingface.co/MBZUAI/VideoGPT-plus_Phi3-mini-4k

In [None]:
!pip install shortuuid timm einops flash-attn decord mmengine ninja peft

## Inference

Before running the cell below, set `video_path` to a video file that exists in this runtime.

In [None]:
from tqdm import tqdm
from videogpt_plus.conversation import conv_templates
from videogpt_plus.model.builder import load_pretrained_model
from videogpt_plus.mm_utils import tokenizer_image_token, get_model_name_from_path
from eval.vcgbench.inference.ddp import *
from eval.video_encoding import _get_rawvideo_dec
import traceback
import pandas as pd
import torch
import os

# Turn `reset_parameters` into a no-op for Linear and LayerNorm, so modules
# built from now on skip their default initialisation. This patches the
# classes for the whole session. That is presumably acceptable because the
# weights are loaded from the checkpoint right after.
for _layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
    _layer_cls.reset_parameters = lambda self: None
del _layer_cls

# ---- Configuration -----------------------------------------------------
# Checkpoints: the VideoGPT+ weights cloned above, plus the base model
# identifier passed alongside them to the loader.
model_path = "VideoGPT-plus_Phi3-mini-4k/vcgbench"
model_base = "microsoft/Phi-3-mini-4k-instruct"

# Input: the video to analyse and the question asked about it.
# TODO: this is a machine-specific absolute path. Point it at a video that
# exists in your runtime before running this cell.
video_path = "/home/ayildiz/sample_videos/test.mp4"
qs = "What animal is in this video, and what is it doing?"

# Decoding: conversation template name, the marker stripped from the
# generated text, and the temperature (0.0 means greedy decoding below).
conv_mode = "phi3_instruct"
stop_str = "<|end|>"
temperature = 0.0

# Load pretrained model and tokenizer
# Expand a leading "~" in the checkpoint path, derive the model name from the
# path, and load everything through the repo's loader. The `image_processor`
# returned here is overwritten further down by the image vision tower's
# processor.
model_path = os.path.expanduser(model_path)
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base, model_name
)

# Configure model for multimodal use
# Read the multimodal token flags from the model config. When the config
# lacks them, fall back to the defaults given here. Register the matching
# tokens with the tokenizer, then resize the embedding table to the new
# tokenizer length.
# Order matters: the tokens must be added before the resize so that
# len(tokenizer) includes them.
# The DEFAULT_* names presumably come from the wildcard import of
# eval.vcgbench.inference.ddp.
# NOTE(review): the resize runs even when no token was added. It can
# therefore also shrink an embedding table that was padded beyond
# len(tokenizer). Confirm this is intended for this checkpoint.
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
    )
model.resize_token_embeddings(len(tokenizer))

# Load vision towers
# Video tower: loaded from the path stored in model.config.mm_vision_tower;
# its processor becomes `video_processor`.
# Image tower: loaded via `load_model()` with no arguments; its processor
# replaces `image_processor`.
# Both processors are passed to `_get_rawvideo_dec` when the video is decoded.
# NOTE(review): presumably loaded before the `.to("cuda")` below so the tower
# weights move with the model. Confirm the towers are submodules of `model`.
vision_tower = model.get_vision_tower()
vision_tower.load_model(model.config.mm_vision_tower)
video_processor = vision_tower.image_processor

image_vision_tower = model.get_image_vision_tower()
image_vision_tower.load_model()
image_processor = image_vision_tower.image_processor

# Move model to GPU (requires a CUDA device)
model = model.to("cuda")

# Decode the video into model inputs.
# Fail fast if the file is missing. Everything after this cell needs
# `video_frames`, `context_frames` and `slice_len`; without this check they
# would simply be undefined, giving a NameError far from the real cause.
# NUM_FRAMES / NUM_CONTEXT_IMAGES presumably come from the wildcard import
# of eval.vcgbench.inference.ddp.
if not os.path.isfile(video_path):
    raise FileNotFoundError(
        f"Video not found: {video_path!r}. "
        "Set `video_path` to an existing video file."
    )
video_frames, context_frames, slice_len = _get_rawvideo_dec(
    video_path,
    image_processor,
    video_processor,
    max_frames=NUM_FRAMES,
    image_resolution=224,
    num_video_frames=NUM_FRAMES,
    num_context_images=NUM_CONTEXT_IMAGES,
)

# Prepare query string with image tokens.
# Put one image placeholder per slice returned by the video decoder in front
# of the question. Optionally wrap the placeholders in start/end tokens.
# The flag used here is `mm_use_im_start_end`, the value read with a safe
# getattr default when the tokenizer was configured. This keeps the prompt
# consistent with the tokens that were registered, and avoids an
# AttributeError if the config lacks the field.
image_tokens = DEFAULT_IMAGE_TOKEN * slice_len
if mm_use_im_start_end:
    image_tokens = DEFAULT_IM_START_TOKEN + image_tokens + DEFAULT_IM_END_TOKEN
qs = image_tokens + "\n" + qs

# Create conversation prompt: a user turn holding the question, followed by an
# empty assistant turn for the model to fill in.
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Tokenize the prompt. The image placeholders are presumably mapped to
# IMAGE_TOKEN_INDEX by `tokenizer_image_token`. Then add a batch dimension and
# move the ids to the GPU.
input_ids = (
    tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    )
    .unsqueeze(0)
    .cuda()
)

# Decoding options. A temperature of 0 means greedy decoding; any positive
# value enables sampling. Generation length is capped at 1024 new tokens.
generation_kwargs = {
    "do_sample": temperature > 0,
    "temperature": temperature,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 1024,
    "use_cache": True,
}

# Run generation without autograd tracking. Video frames and context frames
# are stacked into half-precision GPU tensors.
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=torch.stack(video_frames).half().cuda(),
        context_images=torch.stack(context_frames).half().cuda(),
        **generation_kwargs,
    )

# Separate the newly generated tokens from the prompt.
# NOTE(review): whether `generate` echoes the prompt ids in front of the new
# tokens depends on how the model implements generation. Only strip the
# prefix when the output really starts with the prompt. Otherwise,
# unconditional slicing would cut off the beginning of the answer.
# torch.equal returns False, instead of raising, when the output is shorter
# than the prompt and the shapes differ.
input_token_len = input_ids.shape[1]
prompt_echoed = torch.equal(output_ids[:, :input_token_len], input_ids)
if prompt_echoed:
    generated_ids = output_ids[:, input_token_len:]
else:
    generated_ids = output_ids

# Decode and clean up the output.
# Remove a trailing stop string, then any remaining "<|end|>" markers.
# The `stop_str` truthiness guard matters: with an empty stop_str,
# `outputs[:-0]` would wipe the whole answer.
outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if stop_str and outputs.endswith(stop_str):
    outputs = outputs[: -len(stop_str)].strip()
outputs = outputs.replace("<|end|>", "").strip()

print(outputs)