## Set-up

Necessary imports and helper functions

In [None]:
import json
import os
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from PIL import Image
from typing import Dict, Optional, Sequence, List

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from mask_vis_tools import save_mask, save_box, save_points

# Ask PyTorch to release its cached, currently unused CUDA memory up front,
# before the model is loaded further down in the notebook.
torch.cuda.empty_cache()

Define Task Modes

In [None]:
# Prompt templates keyed by task mode. Each template starts with the "<image>"
# placeholder and ends with a "{history}" slot that the decoding loop fills in
# (with "None" for the first segment, or with the previous caption afterwards).
task_modes = dict(
    # English video-captioning prompt.
    video_cap_en="<image>\nProvide a caption for this subject in this video.\nPrevious caption: {history}",
    # Chinese video-captioning prompt.
    video_cap_ch="<image>\n对该视频中的目标对象提供一个描述。\n历史描述：{history}",
)

Load PAM Model

In [None]:
# Checkpoint location and the device the loaded model is moved to.
model_path = "path/to/PAM-ckpt"
device = "cuda:0"

# Sampling settings that the generation cell reads later on.
temperature, top_p, num_beams = 0.1, None, 1

# Build the tokenizer, model and image processor from the checkpoint.
# The model name is derived from the checkpoint path.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    multimodal=True,
    torch_dtype="bfloat16",  # the original author lists "float16" as the alternative
    device_map="cuda",
)
model.to(device)

Example Video

Run `python videos/extract_mp4_frames.py` first to extract the video frames as `.jpg` files.

In [None]:
video_dir = "videos/02_juggle"
sample_num = 96

# Scan all the JPEG frame names in this directory. The extension test is
# case-insensitive so that mixed-case names such as "3.Jpg" are not skipped.
frames = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1].lower() in (".jpg", ".jpeg")
]
if not frames:
    # Fail early with an actionable message instead of an IndexError below.
    raise FileNotFoundError(
        f"No JPEG frames found in '{video_dir}'. "
        "Run 'python videos/extract_mp4_frames.py' first to extract them."
    )

# The sort key parses each file stem as an integer, so frame files must be
# named by their frame number (e.g. "12.jpg").
frames.sort(key=lambda p: int(os.path.splitext(p)[0]))

# Pick `sample_num` evenly spaced frame indices across the whole clip and turn
# the selected names into paths.
indices = np.linspace(0, len(frames) - 1, num=sample_num, dtype=int)
frames = [os.path.join(video_dir, frames[i]) for i in indices]
print(frames)

# Take a look at the first video frame.
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
# Use a context manager so the image file handle is closed once it is drawn.
with Image.open(frames[frame_idx]) as first_frame:
    plt.imshow(first_frame)

Specify the target object with a box. The model takes the box as input in xyxy format, so the xywh box below is converted to xyxy first.

In [None]:
def get_bounding_box(mask_array, type="xyxy"):
    """Return the tight bounding box around the non-zero pixels of a mask.

    The box is computed directly from the pixel extents with NumPy, so the
    result does not depend on the installed OpenCV version.

    Args:
        mask_array: Array-like mask of shape (H, W). Leading singleton axes
            (e.g. (1, H, W)) and a trailing singleton channel (H, W, 1) are
            also accepted. Inputs that are not already uint8 are cast with
            ``astype(np.uint8)``; every pixel that is non-zero after the cast
            counts as foreground.
        type: Output format. ``"xyxy"`` returns ``[x1, y1, x2, y2]`` with an
            exclusive bottom-right corner (``x2 = x1 + w``, ``y2 = y1 + h``);
            any other value returns ``[x, y, w, h]``. The name shadows the
            builtin ``type`` but is kept for compatibility with existing
            callers.

    Returns:
        A list of four Python ints, or ``None`` if the mask has no foreground.

    Raises:
        ValueError: If the mask cannot be interpreted as a 2-D image.
    """
    mask = np.asarray(mask_array)
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # Reduce common "image-like" layouts to a plain (H, W) array.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    elif mask.ndim > 2 and all(dim == 1 for dim in mask.shape[:-2]):
        mask = mask.reshape(mask.shape[-2:])
    if mask.ndim != 2:
        raise ValueError(
            f"Expected a 2-D mask (optionally with singleton axes), got shape {mask.shape}"
        )

    # Rows / columns that contain at least one foreground pixel.
    row_hits = np.flatnonzero(mask.any(axis=1))
    col_hits = np.flatnonzero(mask.any(axis=0))
    if row_hits.size == 0:
        return None

    x, y = int(col_hits[0]), int(row_hits[0])
    # Width and height are inclusive pixel counts, hence the "+ 1".
    w = int(col_hits[-1]) - x + 1
    h = int(row_hits[-1]) - y + 1

    if type == "xyxy":
        return [x, y, x + w, y + h]
    return [x, y, w, h]

### prompting box for the first frame
# The box is written down as [x, y, width, height] and converted to corner
# format [x1, y1, x2, y2] below.
bbox_xywh = [587, 32, 251, 650]
x1, y1 = (int(origin) for origin in bbox_xywh[:2])
x2, y2 = (
    int(origin + extent)
    for origin, extent in zip(bbox_xywh[:2], bbox_xywh[2:])
)
bbox_xyxy = [x1, y1, x2, y2]

Generate semantic outputs and mask outputs.

In [18]:
### define decoding timesteps
# Each entry is the exclusive end index (into `frames`) of one decoding
# segment; a segment starts where the previous one ended.
DECODING_TIMESTEPS = [32, 64, 96]
PROMPT_TEMPLATE = task_modes["video_cap_en"]
# Visualize every `vis_frame_stride`-th frame of a segment, plus its last frame.
vis_frame_stride = 8

for segment_idx, current_segment_end_frame in enumerate(DECODING_TIMESTEPS):
    print(f"\n--- Processing Segment {segment_idx + 1} ---")
    if segment_idx == 0:
        # First segment: prompt with the user-provided box and no caption history.
        segment_start_frame = 0
        visual_prompt = bbox_xyxy
        vp_labels = None
        task_prompt = PROMPT_TEMPLATE.replace("{history}", "None")
    else:
        # Later segments: prompt with the box derived from the previous segment's
        # last mask and feed the previous caption back in as history.
        segment_start_frame = DECODING_TIMESTEPS[segment_idx - 1]
        visual_prompt = last_segment_bbox_xyxy
        vp_labels = None
        task_prompt = PROMPT_TEMPLATE.replace("{history}", history)

    current_frames = frames[segment_start_frame:current_segment_end_frame]

    conv = conv_templates["qwen_2"].copy()
    conv.append_message(conv.roles[0], task_prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)

    with torch.inference_mode():
        output_ids, all_masks_logits, all_scores = model.generate(
            input_ids,
            images=[current_frames],
            # Fixed: the prompt variable is `visual_prompt`; the previous
            # `visual_prompts` was never defined and raised NameError.
            visual_prompts=[visual_prompt],
            vp_labels=[vp_labels],
            types=['video'],
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=512,
            use_cache=True)

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    history = outputs
    print(f"Frames: {segment_start_frame} to {current_segment_end_frame - 1}")
    print("Prediction: ", outputs)

    last_frame_idx = len(current_frames) - 1
    for out_frame_idx in range(len(current_frames)):
        is_last_frame = out_frame_idx == last_frame_idx
        if not (out_frame_idx % vis_frame_stride == 0 or is_last_frame):
            continue

        with Image.open(current_frames[out_frame_idx]) as frame_image:
            img = np.array(frame_image.convert('RGB'))
        fig, ax = plt.subplots(figsize=(9, 6))
        ax.imshow(img)
        ax.set_title(f"Segment {segment_idx + 1}, frame {segment_start_frame + out_frame_idx}, mask score: {all_scores[0][out_frame_idx]}")

        # Draw every mask returned for this frame and remember the last one
        # explicitly instead of relying on a leaked loop variable.
        last_mask = None
        for mask in all_masks_logits[0][out_frame_idx][1]:
            save_mask(mask, ax, borders=True)
            last_mask = mask

        if segment_idx == 0 and out_frame_idx == 0:
            # Show the user-provided box on the very first frame.
            save_box(visual_prompt, ax)

        if is_last_frame:
            # The box around the segment's final mask becomes the visual prompt
            # of the next segment.
            next_bbox_xyxy = get_bounding_box(last_mask) if last_mask is not None else None
            if next_bbox_xyxy is None:
                # No usable mask on the last frame: keep the current prompt
                # rather than handing `None` to the next segment.
                print("Warning: no mask found on the last frame of this segment; reusing the previous box prompt.")
                next_bbox_xyxy = visual_prompt
            last_segment_bbox_xyxy = next_bbox_xyxy

        # Render the figure now and release it so figures do not accumulate
        # over the loop.
        plt.show()
        plt.close(fig)