In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Move one directory up so the relative paths used below (e.g. "configs/...",
# "./example/...") resolve. NOTE(review): re-running this cell moves up again.
os.chdir('../')

In [3]:
from utils.config import Config
# Path is relative to the working directory set by the previous cell.
config_file = "configs/config_mistral.json"
# `cfg` is read later for `cfg.model` (model construction) and `cfg.device`.
cfg = Config.from_file(config_file)

In [4]:
import io

from models import VideoChat2_it_mistral
from utils.easydict import EasyDict
import torch

from transformers import StoppingCriteria, StoppingCriteriaList

from PIL import Image
import numpy as np
import numpy as np
from decord import VideoReader, cpu
import torchvision.transforms as T
from torchvision.transforms import PILToTensor
from torchvision import transforms
from dataset.video_transforms import (
    GroupNormalize, GroupScale, GroupCenterCrop, 
    Stack, ToTorchFormatTensor
)
from torch.utils.data import Dataset
from torchvision.transforms.functional import InterpolationMode

from torchvision import transforms

import matplotlib.pyplot as plt

from IPython.display import Video, HTML

from peft import get_peft_model, LoraConfig, TaskType
import copy

import json
from collections import OrderedDict

from tqdm import tqdm

import time
import decord
decord.bridge.set_bridge("torch")

In [5]:
# load stage2 model
# Override the configured frame count of the vision encoder before the model is built.
cfg.model.vision_encoder.num_frames = 4
model = VideoChat2_it_mistral(config=cfg.model)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# add lora to run stage3 model
# peft_config = LoraConfig(
#     task_type=TaskType.CAUSAL_LM, inference_mode=False, 
#     r=16, lora_alpha=32, lora_dropout=0.,
#     target_modules=[
#         "q_proj", "k_proj", "v_proj", "o_proj",
#          "gate_proj", "up_proj", "down_proj", "lm_head"
#     ]
# )
# model.mistral_model = get_peft_model(model.mistral_model, peft_config)

In [6]:
# Optional: load the stage-3 checkpoint (uncomment and point to your local weights).
# The checkpoint may either be the raw state dict or wrap it under the 'model' key.
# state_dict = torch.load("your_model_path/videochat2/videochat2_mistral_7b_stage3.pth", "cpu")


# if 'model' in state_dict.keys():
#     msg = model.load_state_dict(state_dict['model'], strict=False)
# else:
#     msg = model.load_state_dict(state_dict, strict=False)
# print(msg)

# Move the model to the device named in the config and switch to evaluation mode.
model = model.to(torch.device(cfg.device))
model = model.eval()

In [13]:
def get_prompt(conv):
    """Flatten a conversation into one prompt string.

    The prompt starts with ``conv.system + conv.sep``. Each ``(role, message)``
    pair in ``conv.messages`` then contributes ``"<role> <message> <sep>"`` when
    the message is non-empty, or just the bare role when the message is empty
    or ``None`` (i.e. the turn the model is expected to complete).

    Args:
        conv: Object exposing ``system``, ``sep`` and ``messages`` attributes.

    Returns:
        str: The concatenated prompt.
    """
    pieces = [conv.system + conv.sep]
    for role, message in conv.messages:
        if not message:
            # Open turn: only the role tag is emitted.
            pieces.append(role)
            continue
        pieces.append(role + " " + message + " " + conv.sep)
    return "".join(pieces)


def get_prompt2(conv):
    """Flatten a conversation whose last message is a partial answer to continue.

    Behaves like ``get_prompt`` for every message except the last one: the
    final ``(role, message)`` pair is appended as ``"<role> <message>"`` with
    no trailing space or separator, so generation continues right after it.

    Args:
        conv: Object exposing ``system``, ``sep`` and ``messages`` attributes.
            The final message must be a string.

    Returns:
        str: The concatenated prompt.
    """
    final_pos = len(conv.messages) - 1
    prompt = conv.system + conv.sep
    for pos, (role, message) in enumerate(conv.messages):
        if pos == final_pos:
            # Answer prefix: leave it open-ended for the model to continue.
            prompt += role + " " + message
        elif message:
            prompt += role + " " + message + " " + conv.sep
        else:
            prompt += role
    return prompt


def get_context_emb(conv, model, img_list, answer_prompt=None, print_res=False):
    """Build the LLM input embeddings by interleaving text and visual embeddings.

    The conversation is flattened to a prompt (``get_prompt2`` when an answer
    prefix is supplied, ``get_prompt`` otherwise) and split on the
    ``<VideoHere>`` placeholder if present, else on ``<ImageHere>``. Each text
    segment is tokenized and embedded, and the entries of ``img_list`` are
    inserted between consecutive segments.

    Args:
        conv: Conversation object with ``system``, ``sep`` and ``messages``.
        model: Model exposing ``mistral_tokenizer`` and
            ``mistral_model.model.embed_tokens``.
        img_list: Visual embeddings, one per placeholder in the prompt.
        answer_prompt: Optional answer prefix; only its truthiness is used here.
        print_res: If True, print the flattened prompt.

    Returns:
        torch.Tensor: Text and visual embeddings concatenated along dim 1.

    Raises:
        AssertionError: If the number of placeholders does not match ``img_list``.
    """
    prompt = get_prompt2(conv) if answer_prompt else get_prompt(conv)
    if print_res:
        print(prompt)

    placeholder = '<VideoHere>' if '<VideoHere>' in prompt else '<ImageHere>'
    prompt_segs = prompt.split(placeholder)
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."

    # For a LoRA-wrapped model the embedding layer lives at
    # model.mistral_model.base_model.model.model.embed_tokens instead.
    embed_tokens = model.mistral_model.model.embed_tokens
    # Follow the embedding layer's device rather than assuming "cuda:0", so the
    # notebook also works when cfg.device points elsewhere.
    device = embed_tokens.weight.device

    with torch.no_grad():
        seg_tokens = [
            model.mistral_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [embed_tokens(seg_t) for seg_t in seg_tokens]

    # text_0, img_0, text_1, img_1, ..., text_last
    mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
    return torch.cat(mixed_embs, dim=1)


def ask(text, conv):
    """Append a user turn to the conversation.

    Args:
        text: The user's message.
        conv: Conversation object; ``conv.roles[0]`` is used as the role tag
            and ``conv.messages`` is mutated in place.
    """
    user_role = conv.roles[0]
    conv.messages.append([user_role, text])
        

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence ends with any stop token sequence.

    Args:
        stops: Iterable of 1-D token-id tensors; generation stops when the tail
            of ``input_ids[0]`` equals one of them. Defaults to no stops.
        encounters: Accepted for interface compatibility; not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a shared mutable default list across instances.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True when the generated ids of the first sequence end with a stop sequence."""
        generated = input_ids[0]
        for stop in self.stops:
            # A stop sequence longer than what has been generated cannot match;
            # without this guard the comparison would broadcast instead.
            if len(stop) > generated.shape[0]:
                continue
            if torch.all((stop == generated[-len(stop):])).item():
                return True
        return False
    
    
def answer(conv, model, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None, print_res=False):
    """Generate the assistant reply for the current conversation.

    An assistant turn (``conv.roles[1]``) holding ``answer_prompt`` is appended
    to ``conv.messages``, the context embeddings are built with
    ``get_context_emb`` and passed to ``model.mistral_model.generate``. After
    decoding, the text before the first ``'</s>'`` is kept and written back into
    the appended turn (with ``'</s>'`` re-attached).

    Args:
        conv: Conversation object; mutated in place.
        model: Model exposing ``mistral_model`` and ``mistral_tokenizer``.
        img_list: Visual embeddings matching the placeholders in the prompt.
        do_sample, max_new_tokens, num_beams, min_length, top_p,
        repetition_penalty, length_penalty, temperature: Forwarded to ``generate``.
        answer_prompt: Optional prefix the answer must start with.
        print_res: If True, the flattened prompt is printed.

    Returns:
        tuple: ``(output_text, output_token)`` where ``output_token`` is a numpy
        array of the generated token ids (after stripping leading <unk>/<s>).
    """
    conv.messages.append([conv.roles[1], answer_prompt])
    embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res)

    # Keep the stop sequences on the same device as the inputs instead of
    # hard-coding "cuda:0".
    stop_words_ids = [
        torch.tensor([2], device=embs.device),
        torch.tensor([29871, 2], device=embs.device)]  # '</s>' can be encoded in two different ways.
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    with torch.no_grad():
        outputs = model.mistral_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
    output_token = outputs[0]
    # The length checks prevent an IndexError when stripping leaves nothing behind.
    if len(output_token) > 0 and output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning. remove it
        output_token = output_token[1:]
    if len(output_token) > 0 and output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
        output_token = output_token[1:]
    output_text = model.mistral_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('</s>')[0]  # remove the stop sign </s>
#     output_text = output_text.split('[/INST]')[-1].strip()
    conv.messages[-1][1] = output_text + '</s>'
    return output_text, output_token.cpu().numpy()

In [14]:
def get_index(num_frames, num_segments):
    """Pick ``num_segments`` roughly evenly spaced frame indices.

    The range ``[0, num_frames - 1]`` is divided into ``num_segments`` equal
    strides; each returned index is half a stride (truncated) plus the rounded
    start of its segment.

    Args:
        num_frames: Total number of frames in the video.
        num_segments: Number of frames to sample.

    Returns:
        np.ndarray: The sampled frame indices.
    """
    stride = float(num_frames - 1) / num_segments
    half_stride = int(stride / 2)
    picked = []
    for seg_idx in range(num_segments):
        picked.append(half_stride + int(np.round(stride * seg_idx)))
    return np.array(picked)


def load_video(video_path, num_segments=8, return_msg=False, resolution=224):
    """Decode evenly spaced frames from a video and preprocess them.

    Frames chosen by ``get_index`` are converted to PIL images, then scaled,
    center-cropped to ``resolution``, stacked, converted to a tensor and
    normalized with the mean/std constants below.

    Args:
        video_path: Path of the video file, opened with decord's ``VideoReader``.
        num_segments: Number of frames to sample.
        return_msg: If True, also return a sentence listing the sampled timestamps.
        resolution: Side length used for both scaling and cropping.

    Returns:
        The transformed frames (callers in this notebook unpack its shape as
        ``(T*C, H, W)``), or ``(frames, msg)`` when ``return_msg`` is True.
    """
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    frame_indices = get_index(len(vr), num_segments)

    # Decode the selected frames into PIL images.
    images_group = [Image.fromarray(vr[frame_index].numpy()) for frame_index in frame_indices]

    # Preprocessing pipeline; the same side length drives scaling and cropping.
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]
    transform = T.Compose([
        GroupScale(int(resolution), interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(resolution),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(input_mean, input_std)
    ])
    torch_imgs = transform(images_group)

    if not return_msg:
        return torch_imgs

    fps = float(vr.get_avg_fps())
    sec = ", ".join(str(round(f / fps, 1)) for f in frame_indices)
    # " " should be added in the start and end
    msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
    return torch_imgs, msg

In [15]:
def get_sinusoid_encoding_table(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784): 
    """Build a sinusoidal position table and resize it to the current input layout.

    A table with ``pre_n_position`` rows is generated first (the layout the
    checkpoint was trained with: ``ckpt_num_frame`` frames of a square patch
    grid). If the current spatial grid differs, the table is bicubically
    interpolated per frame; if the current frame count differs, it is linearly
    interpolated along time.

    Args:
        n_position: Number of positions required now (frames * grid * grid).
        d_hid: Embedding dimension.
        cur_frame: Number of frames used now.
        ckpt_num_frame: Number of frames the checkpoint was trained with.
        pre_n_position: Number of positions the checkpoint was trained with.

    Returns:
        torch.FloatTensor: Table of shape ``(1, positions, d_hid)``.
    """
    # angle[pos, j] = pos / 10000^(2 * (j // 2) / d_hid), computed in one
    # vectorised step instead of a Python double loop.
    positions = np.arange(pre_n_position, dtype=np.float64)[:, None]
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid
    sinusoid_table = positions / np.power(10000, exponents)
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 
    sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
    
    print(f"n_position: {n_position}")
    print(f"pre_n_position: {pre_n_position}")
    
    if n_position != pre_n_position:
        # Checkpoint grid side derived from the arguments (14 for the defaults)
        # instead of a hard-coded constant.
        ckpt_grid = int((pre_n_position // ckpt_num_frame) ** 0.5)
        new_grid = int((n_position // cur_frame) ** 0.5) # testing size
        if new_grid != ckpt_grid:
            print(f'Pretraining uses {ckpt_grid}x{ckpt_grid}, but current version is {new_grid}x{new_grid}')
            print(f'Interpolate the position embedding')
            sinusoid_table = sinusoid_table.reshape(-1, ckpt_num_frame, ckpt_grid, ckpt_grid, d_hid)
            # B, T, H, W, C -> BT, H, W, C -> BT, C, H, W
            sinusoid_table = sinusoid_table.reshape(-1, ckpt_grid, ckpt_grid, d_hid).permute(0, 3, 1, 2)
            sinusoid_table = torch.nn.functional.interpolate(
                sinusoid_table, size=(new_grid, new_grid), mode='bicubic', align_corners=False)
            # BT, C, H, W -> BT, H, W, C ->  B, T, H, W, C
            sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, ckpt_num_frame, new_grid, new_grid, d_hid)
            sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C
    
    if cur_frame != ckpt_num_frame:
        print(f'Pretraining uses {ckpt_num_frame} frames, but current frame is {cur_frame}')
        print(f'Interpolate the position embedding')
        # Spatial grid at this point already matches the current resolution.
        cur_grid = int((n_position // cur_frame) ** 0.5) # testing size
        sinusoid_table = sinusoid_table.reshape(-1, ckpt_num_frame, cur_grid, cur_grid, d_hid)
        sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, d_hid, ckpt_num_frame)  # BHW, C, T
        sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=cur_frame, mode='linear')
        sinusoid_table = sinusoid_table.reshape(1, cur_grid, cur_grid, d_hid, cur_frame).permute(0, 4, 1, 2, 3) # B, T, H, W, C
        sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C
        
    return sinusoid_table

In [23]:
# Video demo: sample frames, resize the position embedding, and encode the clip.
vid_path = "./example/yoga.mp4"
# vid_path = "./demo/example/jesse_dance.mp4"

# num_frame = 8
num_frame = 16
# resolution = 384
resolution = 224
vid, msg = load_video(vid_path, num_segments=num_frame, return_msg=True, resolution=resolution)
# One position per 16x16 patch per frame; the table is interpolated to `num_frame` frames.
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
model.vision_encoder.encoder.pos_embed = new_pos_emb

print(msg)
    
# Split the fused frame/channel axis: (T*3, H, W) -> (1, T, 3, H, W)
TC, H, W = vid.shape
video = vid.reshape(1, TC//3, 3, H, W).to("cuda:0")

# `img_list` holds the visual embeddings later substituted for <VideoHere>.
img_list = []
with torch.no_grad():
    image_emb, _ = model.encode_img(video, "Watch the video and answer the question.")
#     image_emb, _ = model.encode_img(video, "")

img_list.append(image_emb)

# Render an inline player for the source clip as the cell output.
HTML(f'<video alt="test" controls><source src="{vid_path}" type="video/mp4"></video>')

n_position: 3136
pre_n_position: 784
Pretraining uses 4 frames, but current frame is 16
Interpolate the position embedding
The video contains 16 frames sampled at 0.3, 0.9, 1.5, 2.2, 2.8, 3.4, 4.0, 4.7, 5.3, 5.9, 6.5, 7.2, 7.8, 8.4, 9.0, 9.6 seconds.




In [17]:
# Fresh conversation state: empty system prompt and separator, Mistral-style role tags.
chat = EasyDict({
    "system": "",
    "roles": ("[INST]", "[/INST]"),
    "messages": [],
    "sep": ""
})

# First turn carries the video placeholder that get_context_emb replaces with `img_list[0]`.
chat.messages.append([chat.roles[0], "<Video><VideoHere></Video> [/INST]"])
ask("Describe the video in details.", chat)

# answer() returns (text, token ids); only the text is kept. Greedy decoding (do_sample=False).
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=512, print_res=True)[0]
print(llm_message)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The current flash attention version does not support sliding window attention, for a more memory efficient implementation make sure to upgrade flash-attn library.


[INST] <Video><VideoHere></Video> [/INST] [INST] Describe the video in details. [/INST]
A young woman practicing yoga on the roof of a house in the mountains. the woman is doing a yoga pose called the warrior. the woman is wearing a black top and gray pants. the woman is looking at the camera. the woman is standing on one leg and stretching her other leg out to the side. the woman is holding her arms out to the side. the woman is standing on a yoga mat. the woman is standing in front of a beautiful view of the mountains. the woman is practicing yoga in the morning. the woman is practicing yoga in the sunshine. the woman is practicing yoga in the fresh air. the woman is practicing yoga in the mountains. the woman is practicing yoga in the summer. the woman is practicing yoga in the fall. the woman is practicing yoga in the winter. the woman is practicing yoga in the spring. the woman is practicing yoga in the morning. the woman is practicing yoga in the afternoon. the woman is practicin

In [26]:
# Fresh conversation for a multiple-choice question about the same video.
chat = EasyDict({
    "system": "",
    "roles": ("[INST]", "[/INST]"),
    "messages": [],
    "sep": ""
})

chat.messages.append([chat.roles[0], "<Video><VideoHere></Video> [/INST]"])
ask("What's she doing? (A) playing basterball. (B) doing homework. (C) practicing yoga. Only give the option", chat)

# `answer_prompt` pre-fills the start of the reply so the model continues with the option letter.
llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=512, print_res=True, answer_prompt="Best Option:(")[0]
print(llm_message)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[INST] <Video><VideoHere></Video> [/INST] [INST] What's she doing? (A) playing basterball. (B) doing homework. (C) practicing yoga. Only give the option [/INST] Best Option:(
C) practicing yoga.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mountains.

a young woman practicing yoga on a rooftop with a view of the mo

In [None]:
# Image demo: load a single image, resize the image position embedding, and encode it.
# img_path = "./demo/example/dog.png"
img_path = "./demo/example/bear.jpg"
img = Image.open(img_path).convert('RGB')

plt.imshow(img)

resolution = 224
# resolution = 384
# Single-frame table: the checkpoint layout is declared as one 14x14 grid.
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2, cur_frame=1, ckpt_num_frame=1, pre_n_position=14*14)
model.vision_encoder.encoder.img_pos_embed = new_pos_emb

transform = transforms.Compose(
    [
        transforms.Resize(
            (resolution, resolution), interpolation=InterpolationMode.BICUBIC
        ),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ]
)

# Two unsqueeze calls add leading batch and frame axes: (C, H, W) -> (1, 1, C, H, W).
img = transform(img).unsqueeze(0).unsqueeze(0).cuda()
# NOTE: this rebinds `img_list`, replacing the video embedding from the earlier cell.
img_list = []
with torch.no_grad():
#     image_emb, _ = model.encode_img(img, "")
    image_emb, _ = model.encode_img(img, "Observe the image and answer the question.")
img_list.append(image_emb)

In [None]:
# Fresh conversation for the image demo.
chat = EasyDict({
    "system": "",
    "roles": ("[INST]", "[/INST]"),
    "messages": [],
    "sep": ""
})

# The prompt has no <VideoHere>, so get_context_emb splits on <ImageHere> instead.
chat.messages.append([chat.roles[0], f"<Image><ImageHere></Image> [/INST]"])
ask("Describe the following image in details.", chat)

llm_message = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True,)[0]
print(llm_message)

In [None]:
def check_answer_egoschema(pred, qid, answers=None):
    """Score one EgoSchema prediction against the ground truth.

    Matching is case-insensitive. A prediction is correct when it contains the
    ground-truth option label (e.g. ``"(c)"``) and no other option label, or,
    when the label is absent, when it contains the ground-truth answer text
    (with a trailing period removed, optionally with ``"a "`` / ``"an "``
    stripped).

    Args:
        pred: Model output text.
        qid: Key of the question in the answer table.
        answers: Mapping ``qid -> {'answer': '(X)', 'content': str}``. Defaults
            to the notebook-level ``ans_dict``.

    Returns:
        int: 1 if the prediction is judged correct, else 0.
    """
    if answers is None:
        answers = ans_dict

    pred_lower = pred.lower()
    gt_label = answers[qid]['answer'].lower()
    answer_content = answers[qid]['content'].lower()
    if answer_content.endswith("."):
        answer_content = answer_content[:-1]

    if gt_label in pred_lower:
        # Reject predictions that also mention a different option. Labels are
        # lower-cased because they are searched in the lower-cased prediction.
        for label in ("(a)", "(b)", "(c)", "(d)", "(e)"):
            if label != gt_label and label in pred_lower:
                return 0
        return 1

    # Fall back to matching the answer text; empty candidates are skipped
    # because "" is a substring of every string.
    candidates = (
        answer_content,
        answer_content.replace("a ", ""),
        answer_content.replace("an ", ""),
    )
    return 1 if any(cand and cand in pred_lower for cand in candidates) else 0

def infer_egoschema(
        data_sample, system="", 
        question_prompt='', # add in the end of question
        answer_prompt=None, # add in the begining of answer
        return_prompt='',  # add in the begining of return message
        system_q=False, # whether add question in the system prompt for QFormer
        print_res=True,
        system_llm=False,
        num_segments=8,
    ):
    """Run the model on one EgoSchema sample and return its (first-line) answer.

    Uses the notebook-level ``model``. The video is loaded from the hard-coded
    ``your_data_path/egoschema/videos`` directory, encoded, and queried with the
    sample's question.

    Args:
        data_sample: Dict with ``'video'`` (file name) and ``'QA'`` (list whose
            first entry holds the question ``'q'`` and ground truth ``'a'``).
        system: System prompt passed to the visual encoder and the conversation.
        question_prompt: Text appended to the question.
        answer_prompt: Prefix the answer is forced to start with.
        return_prompt: Text prepended to the returned message.
        system_q: If True, the question is appended to ``system`` for ``encode_img``.
        print_res: If True, the flattened prompt is printed by ``answer``.
        system_llm: If True, ``system`` is also prepended to the LLM question.
        num_segments: Number of frames sampled from the video.

    Returns:
        str: ``return_prompt`` followed by the first line of the model output.
    """
    question = data_sample['QA'][0]['q']

    vid_path = os.path.join("your_data_path/egoschema/videos", data_sample['video'])
    video, _ = load_video(vid_path, num_segments=num_segments, return_msg=True)
    TC, H, W = video.shape
    video = video.reshape(1, TC//3, 3, H, W).to("cuda:0")
    
    video_list = []
    with torch.no_grad():
        if system_q:
            # Samples built in this notebook carry no top-level 'question' key;
            # fall back to the QA entry instead of raising KeyError.
            encoder_question = data_sample['question'] if 'question' in data_sample else question
            video_emb, _ = model.encode_img(video, system + encoder_question)
        else:
            video_emb, _ = model.encode_img(video, system)
    video_list.append(video_emb)

    chat = EasyDict({
        "system": system,
        "roles": ("[INST]", "[/INST]"),
        "messages": [],
        "sep": ""
    })

    chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> [/INST]"])
    
    if system_llm:
        prompt = system + question + question_prompt
    else:
        prompt = question + question_prompt
    
    ask(prompt, chat)

    llm_message = answer(
        conv=chat, model=model, do_sample=False, 
        img_list=video_list, max_new_tokens=100, 
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # remove potential explanation
    llm_message = return_prompt + llm_message.strip().split('\n')[0]
    print(llm_message)
    print(f"GT: {data_sample['QA'][0]['a']}")
    return llm_message


import csv
# You can find the csv files in https://github.com/imagegridworth/IG-VLM/blob/main/data/multiple_choice_qa/EgoSchema.csv
# Parse the EgoSchema subset CSV into `json_data` (model inputs) and `ans_dict` (ground truth).
# Column use below: msg[1] = video id, msg[3] = question, msg[4] = answer text,
# msg[5:] = candidate options.
with open("your_data_path/EgoSchema.csv", mode='r', encoding='utf-8') as file:
    reader = csv.reader(file)

    json_data = []
    ans_dict = {}
    
    for idx, msg in enumerate(reader):
        # The first row is the header: print it and skip.
        if idx == 0:
            print(msg)
            continue
            
        video = msg[1] + '.mp4'
        input_str = f"Question: {msg[3].capitalize()}\nOptions:\n"
    
        # Label the options (A), (B), ... and locate the one equal to the answer text.
        target_index = -1
        for i, candidate in enumerate(msg[5:]):
            option = chr(ord('A') + i)
            input_str += f"({option}) {candidate}\n"
            if candidate == msg[4]:
                target_index = i
            
        # The answer text must match one of the options exactly.
        assert target_index != -1
        correct = chr(ord('A') + target_index)
        
        json_data.append({
            'video': video,
            "QA": [{
                "i": "",
                "q": input_str.strip(),
                "a": f"Answer: ({correct}) {msg[4]}",
            }]
        })

        # Keyed by zero-based sample index (header row excluded), matching the
        # `idx` used by the evaluation loop over `json_data`.
        ans_dict[idx - 1] = {
            'video': video,
            'answer': f"({correct})",
            'content': msg[4],
        }


#  position embedding
# Resize the video position embedding for 16 frames at 224x224 (one position per 16x16 patch).
num_frame = 16
resolution = 224
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
model.vision_encoder.encoder.pos_embed = new_pos_emb

# Running accuracy counters.
correct = 0
total = 0
total_num = len(json_data)

# Accumulates "<video>\n<prediction>\n" for every sample.
output = ""

for idx, example in enumerate(tqdm(json_data)):
    start = time.time()
    llm_message = infer_egoschema(
        example, 
        "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n", 
        question_prompt="\nOnly give the best option.", 
        answer_prompt="Best option:(",
        return_prompt='(',
        system_q=False,
        print_res=False,
        system_llm=False,
        num_segments=16
    )
    
    duration = time.time() - start
    output += (example["video"] + '\n')
    output += (llm_message + '\n')
    # `idx` lines up with the keys of `ans_dict` built while parsing the CSV.
    correct += check_answer_egoschema(llm_message, idx)
    total += 1
    print("Acc:", correct / total)
    print('-' * 20, f'{idx+1}/{total_num} done,', f'cost: {duration:.2f}s', '-' * 20)

# `output` is a single string; writelines() iterates it character by character,
# which produces the same file content as write().
with open("./demo/egoschema/your_prediction.txt", "w") as f:
    f.writelines(output)

In [None]:
# You can find the questions file at https://github.com/egoschema/EgoSchema/blob/main/questions.json
with open("your_data_path/EgoSchema/questions.json", "r") as f:
    full_data = json.load(f)

# Convert every question into the same sample layout used for the subset
# ('video' + 'QA'), keeping 'q_uid' as the key for the prediction file.
full_egoschema = []
for data in full_data:
    video = data['q_uid'] + '.mp4'
    input_str = f"Question: {data['question'].capitalize()}\nOptions:\n"

    # Options are stored under the keys 'option 0' ... 'option 4' and labelled (A)-(E).
    for i, candidate in enumerate(['option 0', 'option 1', 'option 2', 'option 3', 'option 4']):
        option = chr(ord('A') + i)
        input_str += f"({option}) {data[candidate]}\n"
    
    full_egoschema.append({
        'q_uid': data['q_uid'],
        'video': video,
        "QA": [{
            "i": "",
            "q": input_str.strip(),
            "a": "",  # no ground truth is stored for the full set
        }]
    })


def infer_full_egoschema(
        data_sample, system="", 
        question_prompt='', # add in the end of question
        answer_prompt=None, # add in the begining of answer
        return_prompt='',  # add in the begining of return message
        system_q=False, # whether add question in the system prompt for QFormer
        print_res=True,
        system_llm=False,
        num_segments=8,
    ):
    """Run the model on one sample of the full EgoSchema set.

    Same flow as the subset inference, but no ground truth is printed and the
    video path is echoed. Uses the notebook-level ``model`` and the hard-coded
    ``your_data_path/egoschema/videos`` directory.

    Args:
        data_sample: Dict with ``'video'`` (file name) and ``'QA'`` (list whose
            first entry holds the question under ``'q'``).
        system: System prompt passed to the visual encoder and the conversation.
        question_prompt: Text appended to the question.
        answer_prompt: Prefix the answer is forced to start with.
        return_prompt: Text prepended to the returned message.
        system_q: If True, the question is appended to ``system`` for ``encode_img``.
        print_res: If True, the flattened prompt is printed by ``answer``.
        system_llm: If True, ``system`` is also prepended to the LLM question.
        num_segments: Number of frames sampled from the video.

    Returns:
        str: ``return_prompt`` followed by the first line of the model output.
    """
    question = data_sample['QA'][0]['q']

    vid_path = os.path.join("your_data_path/egoschema/videos", data_sample['video'])
    print(vid_path)
    video, _ = load_video(vid_path, num_segments=num_segments, return_msg=True)
    TC, H, W = video.shape
    video = video.reshape(1, TC//3, 3, H, W).to("cuda:0")
    
    video_list = []
    with torch.no_grad():
        if system_q:
            # Samples built in this notebook carry no top-level 'question' key;
            # fall back to the QA entry instead of raising KeyError.
            encoder_question = data_sample['question'] if 'question' in data_sample else question
            video_emb, _ = model.encode_img(video, system + encoder_question)
        else:
            video_emb, _ = model.encode_img(video, system)
    video_list.append(video_emb)

    chat = EasyDict({
        "system": system,
        "roles": ("[INST]", "[/INST]"),
        "messages": [],
        "sep": ""
    })

    chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video> [/INST]"])
    
    if system_llm:
        prompt = system + question + question_prompt
    else:
        prompt = question + question_prompt
    
    ask(prompt, chat)

    llm_message = answer(
        conv=chat, model=model, do_sample=False, 
        img_list=video_list, max_new_tokens=100, 
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # remove potential explanation
    llm_message = return_prompt + llm_message.strip().split('\n')[0]
    print(llm_message)
    return llm_message


#  position embedding
# Resize the video position embedding for 16 frames at 224x224 (one position per 16x16 patch).
num_frame = 16
resolution = 224
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
model.vision_encoder.encoder.pos_embed = new_pos_emb


# NOTE: rebinds the notebook-level `ans_dict`; here it maps q_uid -> predicted option index.
ans_dict = {}

for idx, example in enumerate(tqdm(full_egoschema)):
    start = time.time()
    llm_message = infer_full_egoschema(
        example, 
        "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n", 
        question_prompt="\nOnly give the best option.", 
        answer_prompt="Best option:(",
        return_prompt='(',
        system_q=False,
        print_res=False,
        system_llm=False,
        num_segments=16,
    )

    # The reply must look like "(X)..." with X in A-E; it is converted to 0-4.
    assert llm_message[0] == '(' and llm_message[2] == ')'
    ans = ord(llm_message[1]) - ord('A')
    assert ans in [0, 1, 2, 3, 4]
    ans_dict[example['q_uid']] = ans


with open("./demo/egoschema/your_prediction.json", "w") as f:
    json.dump(ans_dict, f)

# Then you can run https://github.com/egoschema/EgoSchema/blob/main/validate.py to get the score
# python3 validate.py --f ./your_prediction.json

In [None]:
import webvtt
import re

def clean_text(text):
    """Strip every character that is not an ASCII letter, a digit or whitespace.

    Args:
        text: Input string.

    Returns:
        str: ``text`` with all other characters removed.
    """
    disallowed = re.compile(r'[^A-Za-z0-9\s]')
    return disallowed.sub('', text)


def read_vtt_and_concatenate(file_path, tokenizer, max_len=4096):
    """Read a WebVTT subtitle file into one string capped at ``max_len`` tokens.

    Caption lines are cleaned with ``clean_text``; empty lines and lines equal
    to the previously kept line are dropped. If the joined text fits within
    ``max_len`` tokens it is returned unchanged. Otherwise a head part and a
    tail part are tokenized, trimmed one token at a time until together they
    fit, and decoded back with ``' ... '`` between them.

    Args:
        file_path: Path of the ``.vtt`` file.
        tokenizer: Callable tokenizer returning an object with ``input_ids``
            and providing ``decode``.
        max_len: Maximum number of tokens allowed.

    Returns:
        str: The full or trimmed subtitle text.
    """
    subtitles = []
    last_kept = ""
    for caption in webvtt.read(file_path):
        # A caption may span several lines; handle each one separately.
        for raw_line in caption.text.split('\n'):
            line = clean_text(raw_line)
            if line and line != last_kept:
                subtitles.append(line)
                last_kept = line

    full_text = ' '.join(subtitles)
    if len(tokenizer(full_text, add_special_tokens=False).input_ids) <= max_len:
        return full_text

    # Too long: keep a head part and a tail part. The slice bound below is
    # applied to subtitle lines, not tokens.
    half_len = max_len // 2
    head_ids = tokenizer(' '.join(subtitles[:half_len]), add_special_tokens=False).input_ids
    tail_ids = tokenizer(' '.join(subtitles[-half_len:]), add_special_tokens=False).input_ids

    # Drop tokens from the end of the head or the start of the tail, whichever
    # side is currently longer (ties trim the tail), until both fit.
    while len(head_ids) + len(tail_ids) > max_len:
        if len(head_ids) > len(tail_ids):
            head_ids.pop()
        else:
            tail_ids.pop(0)

    return tokenizer.decode(head_ids) + ' ... ' + tokenizer.decode(tail_ids)

    
class MME_dataset(Dataset):
    """Dataset over a video QA annotation file (presumably Video-MME - TODO confirm).

    Each annotation entry is expected to provide ``'url'``,
    ``'duration_category'`` and a list of ``'questions'`` (each with
    ``'question'``, ``'choices'``, ``'answer'`` and ``'task_type'``).
    Pre-extracted frames are read from ``<data_prefix>/frames/<video>/<num_segments>/``
    and optional subtitles from ``<data_prefix>/subtitle/<video>.vtt``.

    Args:
        data_prefix: Root directory holding the ``frames`` and ``subtitle`` folders.
        anno_path: Path of the JSON annotation file.
        num_segments: Number of frames per video (also the frame sub-folder name).
        resolution: Side length used for scaling and center-cropping.
        max_subtitle_len: Token budget passed to ``read_vtt_and_concatenate``.
    """

    def __init__(self, data_prefix, anno_path, num_segments=16, resolution=224, max_subtitle_len=4096):
        self.data_prefix = data_prefix
        with open(anno_path, 'r') as f:
            self.data_list = json.load(f)
            
        self.num_segments = num_segments
        self.max_subtitle_len = max_subtitle_len
        
        # transform
        crop_size = resolution
        scale_size = resolution
        input_mean = [0.48145466, 0.4578275, 0.40821073]
        input_std = [0.26862954, 0.26130258, 0.27577711]
        self.transform = T.Compose([
            GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(crop_size),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(input_mean, input_std) 
        ])
    
    def __str__(self):
        """Summarise the number of videos, QAs and QAs per duration category / task type."""
        # duration_category -> task_type -> count. Kept local so the summary no
        # longer reads or mutates the notebook-level `ans_dict`.
        task_dict = {}
        total = 0
        for data in self.data_list:
            per_category = task_dict.setdefault(data['duration_category'], {})
            for q in data['questions']:
                per_category[q['task_type']] = per_category.get(q['task_type'], 0) + 1
                total += 1

        res = f"There are {len(self.data_list)} videos.\n"
        res += f"There are {total} QAs.\n"
        for k, v in task_dict.items():
            res += f"------{k}------\n"
            for kk, vv in v.items():
                res += f"{kk}: {vv}\n"
                
        return res.rstrip()
        
    def __len__(self):
        """Number of annotated videos."""
        return len(self.data_list)
    
    def get_index(self, bound, fps, max_frame, first_idx=0):
        """Pick ``self.num_segments`` frame indices, optionally within a time window.

        Args:
            bound: ``(start_sec, end_sec)`` or a falsy value for the whole video.
            fps: Frames per second used to convert seconds to frame indices.
            max_frame: Largest valid frame index.
            first_idx: Smallest valid frame index.

        Returns:
            np.ndarray: One index per segment, taken near the segment centre.
        """
        if bound:
            start, end = bound[0], bound[1]
        else:
            # Sentinels that get clamped to [first_idx, max_frame] below.
            start, end = -100000, 100000
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / self.num_segments
        frame_indices = np.array([
            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
            for idx in range(self.num_segments)
        ])
        return frame_indices

    @staticmethod
    def _natural_key(name):
        """Sort key that orders embedded numbers numerically ('2.jpg' before '10.jpg')."""
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', name)]

    def read_frame(self, video_path, bound=None):
        """Load the pre-extracted frames of one video and apply the transform.

        Args:
            video_path: Directory of the video; frames are read from its
                ``<num_segments>`` sub-folder.
            bound: Unused; kept for signature parity with ``read_video``.

        Returns:
            The transformed frames.

        Raises:
            FileNotFoundError: If the frame directory does not exist.
        """
        video_path = os.path.join(video_path, str(self.num_segments))
        
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Frame directory not found: {video_path}")

        # os.listdir returns entries in arbitrary order; sort them so the
        # frames keep a deterministic (presumably temporal) order.
        frame_list = sorted(os.listdir(video_path), key=self._natural_key)
            
        images_group = [
            Image.open(os.path.join(video_path, frame_name))
            for frame_name in frame_list
        ]
        torch_imgs = self.transform(images_group)
        return torch_imgs
    
    def read_video(self, video_path, bound=None):
        """Decode ``self.num_segments`` frames directly from a video file.

        Args:
            video_path: Path of the video file.
            bound: Optional ``(start_sec, end_sec)`` window.

        Returns:
            The transformed frames.
        """
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        max_frame = len(vr) - 1
        fps = float(vr.get_avg_fps())
        
        images_group = list()
        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0) 
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].numpy())
            images_group.append(img)
        torch_imgs = self.transform(images_group)
        return torch_imgs

    def qa_template(self, data):
        """Format one question entry as ``(question_text, answer_text)``.

        Each choice is expected to look like ``"A. text"``: the first character
        is the option letter and the text starts at offset 3.
        """
        question = f"Question: {data['question']}\n"
        question += "Options:\n"
        answer = data['answer']
        answer = f"({answer}) {data['choices'][ord(answer) - ord('A')][3:]}"
        for c in data['choices']:
            cur_choice, cur_text = c[0], c[3:]
            question += f"({cur_choice}) {cur_text}\n"
        question = question.rstrip()
        return question, answer

    def __getitem__(self, idx):
        """Return the frames, formatted QAs, subtitle text and duration category of one video."""
        video_name = self.data_list[idx]['url'].split("watch?v=")[1]
        video_path = os.path.join(self.data_prefix, "frames", video_name)

        # We store the videos with only 16 or 32 frames for testing,
        # since directly reading the whole videos costs a lot of time.
        # You can also read the whole video via self.read_video(video_path)
        torch_imgs = self.read_frame(video_path)
        duration_category = self.data_list[idx]['duration_category']
        qa_list = []
        for qa in self.data_list[idx]['questions']:
            qa_list.append(self.qa_template(qa))

        # Subtitles are best-effort: a missing or unreadable file yields "".
        # The path is built before the try so the error message can always use it.
        subtitle = ""
        subtitle_path = os.path.join(self.data_prefix, "subtitle", video_name + ".vtt")
        try:
            if os.path.exists(subtitle_path):
                # NOTE: relies on the notebook-level `model` for its tokenizer.
                subtitle = read_vtt_and_concatenate(subtitle_path, model.mistral_tokenizer, self.max_subtitle_len)
        except Exception:
            subtitle = ""
            print(f"Error for {subtitle_path}")
            
        return {
            'subtitle': subtitle,
            'video': torch_imgs, 
            'qa_list': qa_list,
            'duration_category': duration_category
        }
    

def infer_mme(
        data_sample, system="",
        question_prompt='', # added at the end of the question
        answer_prompt=None, # added at the beginning of the answer
        return_prompt='',  # added at the beginning of the returned message
        system_q=False, # whether to add the question to the system prompt for QFormer
        print_res=True,
        system_llm=False,
        add_subtitle=False,
    ):
    """Answer every question of one sample and collect option letters.

    Relies on the module-level ``model``, ``ask`` and ``answer`` defined
    earlier in the notebook.

    Args:
        data_sample: Mapping with ``"video"`` (a 3-D tensor unpacked as
            ``(TC, H, W)`` and reshaped to ``(1, TC // 3, 3, H, W)``, i.e.
            presumably frames stacked along the channel axis -- confirm against
            the dataset transform), ``'qa_list'`` (pairs of question text and
            reference answer) and, when ``add_subtitle`` is set, ``'subtitle'``.
        system: System text passed to ``model.encode_img`` and stored in the
            conversation; also prepended to the question if ``system_llm``.
        question_prompt: Text appended to every question.
        answer_prompt: Forwarded to ``answer`` unchanged.
        return_prompt: Text prepended to the first line of the model reply.
        system_q: Not supported; must be ``False``.
        print_res: Forwarded to ``answer`` unchanged.
        system_llm: Prepend ``system`` to the question sent to the LLM.
        add_subtitle: Put the sample's non-empty subtitle before the video
            placeholder in the first turn.

    Returns:
        ``(pred_list, gt_list)``: for each question, the character at index 1
        of the prediction / reference string (the option letter when the
        string looks like ``"(A) ..."``). A string shorter than two characters
        yields ``''`` instead of raising.
    """
    assert system_q == False, "do not support system_q now"
    video = data_sample["video"]
    TC, H, W = video.shape
    video = video.reshape(1, TC//3, 3, H, W).to("cuda:0")

    with torch.no_grad():
        # Still reachable when asserts are stripped (python -O).
        if system_q:
            raise NotImplementedError
        video_emb, _ = model.encode_img(video, system)
    video_list = [video_emb]

    # The first conversation turn is identical for every question of the sample.
    video_turn = "<Video><VideoHere></Video> [/INST]"
    if add_subtitle and data_sample['subtitle'] != '':
        subtitle = f"This video's subtitles are listed below: {data_sample['subtitle']}"
        video_turn = f"{subtitle}\n{video_turn}"

    prompt_prefix = system if system_llm else ''

    pred_list = []
    gt_list = []
    for idx, qa in enumerate(data_sample['qa_list']):
        print(f"----------qa_{idx}---------", flush=True)
        chat = EasyDict({
            "system": system,
            "roles": ("[INST]", "[/INST]"),
            "messages": [],
            "sep": ""
        })
        chat.messages.append([chat.roles[0], video_turn])

        prompt = prompt_prefix + qa[0] + question_prompt
        ask(prompt, chat)

        llm_message = answer(
            conv=chat, model=model, do_sample=False,
            img_list=video_list, max_new_tokens=100,
            answer_prompt=answer_prompt, print_res=print_res
        )[0]
        # Keep only the first line to drop any trailing explanation.
        llm_message = return_prompt + llm_message.strip().split('\n')[0]
        print(f"Pred: {llm_message}", flush=True)
        print(f"GT: {qa[1]}", flush=True)
        # Slice instead of index: an empty or one-character reply becomes an
        # empty (incorrect) prediction rather than an IndexError that would
        # abort the whole evaluation run.
        pred_list.append(llm_message[1:2])
        gt_list.append(qa[1][1:2])
    return pred_list, gt_list

    
# Evaluation settings: number of sampled frames and input resolution.
num_frame, resolution = 16, 224

# Swap the vision encoder's position embedding for a sinusoid table sized for
# `num_frame` frames. (resolution // 16) ** 2 is presumably the per-frame patch
# count for 16x16 patches -- confirm against the encoder config.
new_pos_emb = get_sinusoid_encoding_table(
    n_position=num_frame * (resolution // 16) ** 2,
    cur_frame=num_frame,
)
model.vision_encoder.encoder.pos_embed = new_pos_emb

# Placeholder locations: point these at the local Video-MME data and annotations.
data_dir = "your_data_path/videomme"
anno_path = "your_data_path/Video-MME.json"
dataset = MME_dataset(data_dir, anno_path, num_segments=num_frame, resolution=resolution)

# Raw annotations; the evaluation loop below writes a 'response' into each question.
with open(anno_path, 'r') as f:
    res_json_data = json.load(f)

# Output prefix: ".json" and "_full.json" are appended when results are saved.
save_path = "./demo/videomme/your_prediction"

# Running tallies over the whole benchmark.
correct = 0
total = 0
res_list = []
# Maps duration_category -> [correct, total].
acc_dict = {}

for idx, example in enumerate(tqdm(dataset)):
    duration_category = example['duration_category']
    if duration_category not in acc_dict:
        acc_dict[duration_category] = [0, 0] # correct, total
    qa_count = len(example['qa_list'])
    acc_dict[duration_category][1] += qa_count
    total += qa_count
    pred_list, gt_list = infer_mme(
        example,
        "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n",
        question_prompt="\nOnly give the best option.",
        answer_prompt="Best option:(",
        return_prompt='(',
        system_q=False,
        print_res=False,
        system_llm=True,
        # add_subtitle=True, # Uncomment this line to add subtitles; the whole subtitles are used by default.
    )
    res_list.append({
        'pred': pred_list,
        'gt': gt_list
    })
    # Score each question and record the prediction in the annotation copy;
    # res_json_data[idx] is assumed to be the same video as dataset[idx].
    for qa_idx, (pred, gt) in enumerate(zip(pred_list, gt_list)):
        if pred == gt:
            acc_dict[duration_category][0] += 1
            correct += 1
        res_json_data[idx]['questions'][qa_idx]['response'] = pred
    # Guard the ratios: a video without questions seen first would otherwise
    # raise ZeroDivisionError and abort the run.
    part_correct, part_total = acc_dict[duration_category]
    part_acc = part_correct / part_total * 100 if part_total else 0.0
    total_acc = correct / total * 100 if total else 0.0
    print(f"Part  Acc: {part_acc:.2f}%")
    print(f"Total Acc: {total_acc:.2f}%")
    print('-' * 50, duration_category, '-' * 50)

# Create the output directory first: without it, open(..., "w") raises
# FileNotFoundError and the results of the whole evaluation run are lost.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Summary: per-category [correct, total] counts plus raw pred/gt letters.
with open(f"{save_path}.json", "w") as f:
    json.dump({
        "acc_dict": acc_dict,
        "res_list": res_list
    }, f)

# Full annotation file with a 'response' attached to each question.
with open(f"{save_path}_full.json", "w") as f:
    json.dump(res_json_data, f)

# Then you can run https://github.com/BradyFU/Video-MME/blob/main/evaluation/eval_your_results.py to get the score
# python3 eval.py --results_file your_prediction_full.json --video_duration_type short,medium,long