In [1]:
# Enable autoreload so edits to imported project modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

import os
# Move one directory up so the relative paths used below ('./download/...', './demo/...') resolve.
# NOTE(review): this is not idempotent — re-running the cell climbs one more directory each time.
os.chdir('../')

In [None]:
# In[1]:
import os
import sys
import copy
import torch
import argparse
from transformers import StoppingCriteria, StoppingCriteriaList
from math import ceil
from PIL import Image
import numpy as np
import decord

decord.bridge.set_bridge('torch')
from torchvision.transforms.functional import InterpolationMode
import json
import time
import datetime
from tqdm import tqdm
import random

random.seed(1234)

from utils.config import Config
from utils.easydict import EasyDict
from transformers import StoppingCriteria, StoppingCriteriaList
from decord import VideoReader, cpu
import torchvision.transforms as T
from dataset.video_transforms import (
    GroupNormalize, GroupScale, GroupCenterCrop, 
    Stack, ToTorchFormatTensor
)
from peft import get_peft_model, LoraConfig, TaskType
from io import BytesIO
from models import *

# Optional petrel_client support: when the package is importable, videos whose
# path contains "s3" are fetched through `client`; otherwise `client` is None
# and only local paths can be read.
try:
    from petrel_client.client import Client
    has_client = True
    print("Client on!")
except Exception:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit are never swallowed here, while any
    # import-time failure of the optional package still falls back cleanly.
    has_client = False
    print("Client off!")

if has_client:
    client = Client('~/petreloss.conf')
else:
    client = None


# In[2]:


def get_args():
    """Return the notebook's default settings as an argparse namespace.

    The parser is always fed an empty argument list, so the defaults declared
    here are what the rest of the notebook sees (the kernel's own command
    line is never parsed).
    """
    # These options are not specific to any particular evaluation task.
    option_specs = (
        ('--model_type', {'default': "VideoChat2_it4_mistral_LinearProAda"}),
        ('--model_dir', {'default': "./download/parameters"}),
        ('--model_pth', {'default': "timesuite"}),
        ('--output_dir', {'default': "Please input model output dir!"}),
        ('--batch_size', {'type': int, 'default': 1}),
        ('--infer_clip_frames', {'type': int, 'default': 8}),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)

    return arg_parser.parse_args(args=[])


args = get_args()
# One "name: value" line per option, with names left-aligned in a 25-character column.
args_list_str = '\n' + '\n'.join(
    f'{option_name:<25}: {option_value}'
    for option_name, option_value in vars(args).items()
)
print(args_list_str)


# In[5]:

# Build the model from the config stored next to the checkpoint, wrap its LLM
# with LoRA adapters, load the checkpoint weights and move everything to the GPU.

# config_file = "configs/config_mistral.json"
config_file = args.model_dir+"/config.json"

cfg = Config.from_file(config_file)
# LoRA is switched off in the config and attached manually below via peft.
cfg.model.use_lora = False
# No pretrained path: the weights come from the checkpoint loaded further down.
cfg.model.pretrained_path=None
cfg.device="cuda"

print("vision_encoder.num_frames:", cfg.model.vision_encoder.num_frames)
# cfg.model.vision_encoder.num_frames = 4

# Resolve the model class by name from the names pulled in by `from models import *`.
# NOTE(review): eval() executes arbitrary code if model_type is ever taken from untrusted input.
model_cls = eval(args.model_type)
# model = VideoChat2_it_mistral(config=cfg.model)
model = model_cls(config=cfg.model)


# add lora to run stage3 model
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, 
    r=16, lora_alpha=32, lora_dropout=0.,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
         "gate_proj", "up_proj", "down_proj", "lm_head"
    ]
)
model.mistral_model = get_peft_model(model.mistral_model, peft_config)


# Load the checkpoint onto the CPU first; the model is moved to cfg.device afterwards.
state_dict = torch.load(args.model_dir+"/"+args.model_pth+".pth", "cpu")


# The checkpoint may either be a raw state dict or wrap it under the 'model' key.
if 'model' in state_dict.keys():
    msg = model.load_state_dict(state_dict['model'], strict=False)
else:
    msg = model.load_state_dict(state_dict, strict=False)
# strict=False: mismatched keys do not raise; presumably `msg` lists them — check this output.
print(msg)

model = model.to(torch.device(cfg.device))
model = model.eval()

print('Model Initialization Finished')







def get_prompt(conv):
    """Flatten a conversation into a single prompt string.

    The prompt starts with ``conv.system + conv.sep``. Each ``(role, message)``
    pair then contributes ``role + " " + message + " " + conv.sep`` when the
    message is non-empty, or just ``role`` when it is empty/None.
    """
    pieces = [conv.system + conv.sep]
    for speaker, text in conv.messages:
        if text:
            pieces.append(speaker + " " + text + " " + conv.sep)
        else:
            pieces.append(speaker)
    return "".join(pieces)


def get_prompt2(conv):
    """Flatten a conversation whose last message is an answer prefix.

    Works like ``get_prompt`` except for the final message, which is always
    emitted as ``role + " " + message`` with no trailing space or separator,
    so that generation continues directly after the prefix.
    """
    final_position = len(conv.messages) - 1
    prompt = conv.system + conv.sep
    for position, (speaker, text) in enumerate(conv.messages):
        if position == final_position:
            prompt += speaker + " " + text
        elif text:
            prompt += speaker + " " + text + " " + conv.sep
        else:
            prompt += speaker
    return prompt


def get_context_emb(conv, model, img_list, answer_prompt=None, print_res=False):
    """Build the LLM input embeddings for a conversation with visual placeholders.

    The conversation is rendered to text (with ``get_prompt2`` when an answer
    prefix is supplied, ``get_prompt`` otherwise), split on the
    ``<VideoHere>`` placeholder (or ``<ImageHere>`` if the former is absent),
    and each text segment is tokenized and embedded. The visual embeddings in
    ``img_list`` are then interleaved between the text segments and everything
    is concatenated along dimension 1.

    Raises:
        AssertionError: if the number of placeholders does not match ``len(img_list)``.
    """
    prompt = get_prompt2(conv) if answer_prompt else get_prompt(conv)
    if print_res:
        print("prompt:",prompt)

    placeholder = '<VideoHere>' if '<VideoHere>' in prompt else '<ImageHere>'
    prompt_segs = prompt.split(placeholder)
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."

    with torch.no_grad():
        seg_tokens = []
        for seg_idx, seg_text in enumerate(prompt_segs):
            # Special tokens (BOS) are only added to the very first segment.
            encoded = model.mistral_tokenizer(
                seg_text, return_tensors="pt", add_special_tokens=seg_idx == 0)
            seg_tokens.append(encoded.to(cfg.device).input_ids)
        # The LLM is wrapped by peft, hence the extra base_model.model hop.
        embed_tokens = model.mistral_model.base_model.model.model.embed_tokens
        seg_embs = [embed_tokens(token_ids) for token_ids in seg_tokens]

    # text_0, visual_0, text_1, visual_1, ..., text_last
    interleaved = []
    for text_emb, visual_emb in zip(seg_embs[:-1], img_list):
        interleaved.append(text_emb)
        interleaved.append(visual_emb)
    interleaved.append(seg_embs[-1])
    return torch.cat(interleaved, dim=1)


def ask(text, conv):
    """Append ``text`` to the conversation as a turn of the first role (``conv.roles[0]``)."""
    user_role = conv.roles[0]
    conv.messages.append([user_role, text])
        

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the sequence ends with any of the given token patterns.

    Args:
        stops: iterable of 1-D token-id tensors; generation stops when the
            tail of the first sequence in the batch equals one of them.
        encounters: accepted for backward compatibility; not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a shared mutable default: every instance gets its own list.
        self.stops = list(stops) if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True when the first sequence ends with one of the stop patterns.

        Extra keyword arguments are accepted (and ignored) so the criterion
        keeps working if the caller passes additional context.
        """
        generated = input_ids[0]
        for stop in self.stops:
            stop_len = len(stop)
            # A pattern longer than what has been generated cannot match yet;
            # comparing anyway would rely on broadcasting or raise on a shape mismatch.
            if stop_len == 0 or generated.shape[0] < stop_len:
                continue
            if torch.all((stop == generated[-stop_len:])).item():
                return True
        return False
    
    
def answer(conv, model, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None, print_res=False):
    """Generate the assistant reply for ``conv`` and record it in the conversation.

    A new turn for the second role (``conv.roles[1]``) is appended with
    ``answer_prompt`` as its initial content, the whole conversation is turned
    into input embeddings, and the LLM generates a continuation. The decoded
    text (up to the first ``</s>``) replaces the content of that new turn.

    Args:
        conv: conversation object with ``roles`` and ``messages``.
        model: model exposing ``mistral_model`` and ``mistral_tokenizer``.
        img_list: visual embeddings, one per placeholder in the prompt.
        answer_prompt: optional text the reply is forced to start with.
        print_res: when True the rendered prompt is printed.
        Remaining arguments are forwarded to ``generate``.

    Returns:
        ``(output_text, output_tokens)`` where ``output_tokens`` is a numpy
        array of the generated token ids after stripping a leading
        ``<unk>``/``<s>``.
    """
    stop_words_ids = [
        torch.tensor([2]).to(cfg.device),
        torch.tensor([29871, 2]).to(cfg.device)]  # '</s>' can be encoded in two different ways.
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    conv.messages.append([conv.roles[1], answer_prompt])
    embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res)
    with torch.no_grad():
        outputs = model.mistral_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
    output_token = outputs[0]
    # Guard each check with a length test: once a leading token has been
    # stripped the sequence may be empty, and indexing [0] would then raise.
    if len(output_token) > 0 and output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning. remove it
        output_token = output_token[1:]
    if len(output_token) > 0 and output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
        output_token = output_token[1:]
    output_text = model.mistral_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('</s>')[0]  # remove the stop sign </s>
    # Store the reply (terminated with the stop sign) in the turn appended above.
    conv.messages[-1][1] = output_text + '</s>'
    return output_text, output_token.cpu().numpy()



def get_index(num_frames, num_segments):
    """Pick ``num_segments`` frame indices spread evenly over a video.

    The range ``[0, num_frames - 1]`` is cut into equal segments and one
    index is taken per segment: half a segment (truncated) plus the rounded
    segment start.

    Returns:
        numpy integer array of length ``num_segments``.
    """
    seg_size = float(num_frames - 1) / num_segments
    half_seg = int(seg_size / 2)
    seg_starts = np.round(seg_size * np.arange(num_segments)).astype(int)
    return half_seg + seg_starts


def load_video(video_path, num_segments=8, return_msg=False, resolution=224):
    """Decode a video, sample frames uniformly and preprocess them.

    Args:
        video_path: local path, or a path containing "s3" that is fetched
            through the module-level ``client`` when one is available.
        num_segments: number of frames to sample (see ``get_index``).
        return_msg: when True, also return a sentence listing the sampled
            timestamps in seconds.
        resolution: side length used for both the rescale and the center crop.

    Returns:
        The transformed frames, or ``(frames, msg)`` when ``return_msg`` is True.
    """
    fetch_remote = client is not None and "s3" in video_path
    if fetch_remote:
        video_bytes = client.get(video_path)
        assert(video_bytes is not None)
        vr = VideoReader(BytesIO(video_bytes), ctx=cpu(0), num_threads=1)
    else:
        vr = VideoReader(uri=video_path, ctx=cpu(0), num_threads=1)
    frame_indices = get_index(len(vr), num_segments)

    # Preprocessing pipeline; mean/std are the normalization constants used throughout this notebook.
    preprocess = T.Compose([
        GroupScale(int(resolution), interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(resolution),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(
            [0.48145466, 0.4578275, 0.40821073],
            [0.26862954, 0.26130258, 0.27577711],
        ),
    ])

    sampled_frames = [Image.fromarray(vr[frame_index].numpy()) for frame_index in frame_indices]
    torch_imgs = preprocess(sampled_frames)

    if not return_msg:
        return torch_imgs

    fps = float(vr.get_avg_fps())
    sec = ", ".join(str(round(f / fps, 1)) for f in frame_indices)
    # The trailing space matters: callers append the question directly after this sentence.
    msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds. "
    return torch_imgs, msg
    

# In[8]:


def generate_videochat(vid_path, user_messages):
    """Run a single-turn chat about one video and return the model's reply.

    The video is sampled at ``model.total_frames`` frames, encoded with
    ``model.encode_long_video`` and placed in a fresh conversation together
    with the sampling-time sentence followed by ``user_messages``.

    Returns:
        ``(llm_answer, chat, img_list)``: the reply text, the conversation
        object and the list holding the video embedding.
    """
    clip_frames = model.clip_frames  # read for parity with the model settings; not used below
    total_frames = model.total_frames
    img_size = cfg.model.vision_encoder.img_size

    frames, time_msg = load_video(vid_path, num_segments=total_frames, return_msg=True, resolution=img_size)

    # Frames arrive stacked along the channel axis (T*C, H, W); unfold to (1, T, 3, H, W).
    TC, H, W = frames.shape
    video = frames.reshape(1, TC//3, 3, H, W).to(cfg.device)

    with torch.no_grad():
        image_emb = model.encode_long_video(video,[time_msg,],"")
        print("Shape of long video embeds: ", image_emb.shape)
    img_list = [image_emb]

    chat = EasyDict({
        "system": "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail. ",
        "roles": ("[INST]", "[/INST]"),
        "messages": [],
        "sep": ""
    })

    # First turn carries the video placeholder, second turn the actual question.
    chat.messages.append([chat.roles[0], "<Video><VideoHere></Video> [/INST]"])
    ask(time_msg+user_messages, chat)

    llm_answer = answer(
        conv=chat, model=model, do_sample=False, img_list=img_list,
        max_new_tokens=256, print_res=True,
    )[0]
    print("LLM answer:", llm_answer,"\n\n\n")

    return llm_answer, chat, img_list

In [None]:
# Quick smoke test on a bundled example video.
# Alternative example: "./example/jesse_dance.mp4"
vid_path = "./demo/example/yoga.mp4"
question="Describe the video in details."

generate_videochat(vid_path=vid_path, user_messages=question)

In [4]:
# MVBench task table.
# Each entry maps a task name to a 4-tuple that MVBench_dataset reads as:
#   (annotation JSON file name, video path prefix, data type, has_bound)
# where data type selects the reader ("video" or "frame" here) and has_bound
# tells the dataset to read 'start'/'end' from each annotation.
data_list = {
    "Action Sequence": ("action_sequence.json", "pnorm2:s3://star/Charades_v1_480/", "video", True), # has start & end
    "Action Prediction": ("action_prediction.json", "pnorm2:s3://star/Charades_v1_480/", "video", True), # has start & end
    "Action Antonym": ("action_antonym.json", "pnorm2:s3://ssv2-video/", "video", False),
    "Fine-grained Action": ("fine_grained_action.json", "pnorm:s3://Moments_in_Time_Raw/videos/", "video", False),
    "Unexpected Action": ("unexpected_action.json", "pnorm2:s3://funqa-test/test/", "video", False),
    "Object Existence": ("object_existence.json", "pnorm2:s3://clevrer/video_validation/", "video", False),
    "Object Interaction": ("object_interaction.json", "pnorm2:s3://star/Charades_v1_480/", "video", True), # has start & end
    "Object Shuffle": ("object_shuffle.json", "pnorm2:s3://perception/videos/", "video", False),
    "Moving Direction": ("moving_direction.json", "pnorm2:s3://clevrer/video_validation/", "video", False),
    "Action Localization": ("action_localization.json", "pnorm2:s3://sta/sta_video/", "video", True),  # has start & end
    "Scene Transition": ("scene_transition.json", "pnorm2:s3://scene-qa/video/", "video", False),
    "Action Count": ("action_count.json", "pnorm2:s3://perception/videos/", "video", False),
    "Moving Count": ("moving_count.json", "pnorm2:s3://clevrer/video_validation/", "video", False),
    "Moving Attribute": ("moving_attribute.json", "pnorm2:s3://clevrer/video_validation/", "video", False),
    "State Change": ("state_change.json", "pnorm2:s3://perception/videos/", "video", False),
    "Fine-grained Pose": ("fine_grained_pose.json", "pnorm2:s3://nturgbd/", "video", False),
    "Character Order": ("character_order.json", "pnorm2:s3://perception/videos/", "video", False),
    "Egocentric Navigation": ("egocentric_navigation.json", "pnorm2:s3://vlnqa/", "video", False),
    "Episodic Reasoning": ("episodic_reasoning.json", "pnorm2:s3://tvqa/frames_fps3_hq/", "frame", True),  # has start & end, read frame
    "Counterfactual Inference": ("counterfactual_inference.json", "pnorm2:s3://clevrer/video_validation/", "video", False),
}

# Directory that holds the annotation JSON files listed above.
data_dir = "./download/datasets/mvbench"

In [5]:
from torch.utils.data import Dataset
import io
from io import BytesIO
from decord import VideoReader, cpu

# The petrel client is optional elsewhere in this notebook, so keep it optional
# here too: an unconditional import would crash on machines without the package
# and make the earlier "Client off!" fallback pointless.
try:
    from petrel_client.client import Client
except Exception:
    client = None
else:
    # Dataset reads use a client created with enable_mc=False.
    client = Client('~/petreloss.conf', enable_mc=False)

In [6]:
# Sampling/preprocessing settings taken from the loaded model and its config.
num_frame, tot_frames, resolution = (
    model.clip_frames,
    model.total_frames,
    cfg.model.vision_encoder.img_size,
)

In [7]:
class MVBench_dataset(Dataset):
    """MVBench multiple-choice video QA samples.

    Every task listed in ``data_list`` contributes the annotations found in
    its JSON file. Indexing the dataset decodes the corresponding video (or
    frame folder / gif), samples ``num_segments`` frames, preprocesses them
    and returns them together with the formatted question and answer.

    Args:
        data_dir: directory containing the per-task annotation JSON files.
        data_list: mapping ``task name -> (json file, path prefix, data type, has_bound)``.
        num_segments: number of frames sampled per item.
        resolution: side length used for both the rescale and the center crop.
    """

    def __init__(self, data_dir, data_list, num_segments=8, resolution=224):
        self.data_list = []
        for k, v in data_list.items():
            with open(os.path.join(data_dir, v[0]), 'r') as f:
                json_data = json.load(f)
            for data in json_data:
                self.data_list.append({
                    'task_type': k,
                    'prefix': v[1],
                    'data_type': v[2],
                    'bound': v[3],
                    'data': data
                })

        # Reader dispatch keyed by the task's data type.
        self.decord_method = {
            'video': self.read_video,
            'gif': self.read_gif,
            'frame': self.read_frame,
        }

        self.num_segments = num_segments

        # transform
        crop_size = resolution
        scale_size = resolution
        input_mean = [0.48145466, 0.4578275, 0.40821073]
        input_std = [0.26862954, 0.26130258, 0.27577711]
        self.transform = T.Compose([
            GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(crop_size),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(input_mean, input_std)
        ])

    def __str__(self):
        """Summarize sample counts per task and the random-guess accuracy.

        The random-guess accuracy of a task is ``#samples / #options``; the
        total uses the same ratio over all tasks.
        """
        len_list = {}
        option_list = {}
        for data in self.data_list:
            task = data['task_type']
            len_list[task] = len_list.get(task, 0) + 1
            option_list[task] = option_list.get(task, 0) + len(data['data']['candidates'])

        correct = 0
        total = 0
        res = f"There are {len(self.data_list)} videos as follow:\n"
        for k, v in len_list.items():
            # Only sample and option counts are accumulated; adding anything
            # else to `correct` would inflate the reported total.
            correct += v
            total += option_list[k]
            task_acc = v / option_list[k] * 100 if option_list[k] else 0.0
            res += f"{v} for {k} ({option_list[k]} options => {task_acc:.2f}%)\n"
        total_acc = correct / total * 100 if total else 0.0
        res += f"Total random accuracy: {total_acc:.2f}%"
        return res.rstrip()

    def __len__(self):
        return len(self.data_list)

    def get_index(self, bound, fps, max_frame, first_idx=0):
        """Return ``self.num_segments`` frame indices inside the requested time window.

        Args:
            bound: ``(start, end)`` in seconds, or a falsy value for the whole clip.
            fps: frames per second used to convert seconds to frame indices.
            max_frame: largest usable frame index.
            first_idx: smallest usable frame index.
        """
        if bound:
            start, end = bound[0], bound[1]
        else:
            # Sentinels that are clamped to [first_idx, max_frame] below.
            start, end = -100000, 100000
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / self.num_segments
        frame_indices = np.array([
            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
            for idx in range(self.num_segments)
        ])
        return frame_indices

    @staticmethod
    def _sampling_msg(frame_indices, fps):
        """Build the sentence that tells the model when the frames were sampled."""
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        # The trailing space matters: the question is appended right after this sentence.
        return f"The video contains {len(frame_indices)} frames sampled at {sec} seconds. "

    def read_video(self, video_path, bound=None, return_time=True):
        """Decode a video file (local or fetched via ``client``) and sample frames from it."""
        if "s3://" in video_path:
            video_bytes = client.get(video_path)
            vr = VideoReader(io.BytesIO(video_bytes), ctx=cpu(0), num_threads=1)
        else:
            vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        max_frame = len(vr) - 1
        fps = float(vr.get_avg_fps())
        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
        images_group = [Image.fromarray(vr[frame_index].numpy()) for frame_index in frame_indices]
        torch_imgs = self.transform(images_group)
        if return_time:
            return torch_imgs, self._sampling_msg(frame_indices, fps)
        return torch_imgs

    def read_gif(self, video_path, bound=None, fps=25, return_time=True):
        """Read a gif (local or fetched via ``client``) and sample frames from it."""
        # imageio and cv2 are only needed by this reader and are not imported
        # at the top of the notebook, so import them here.
        import imageio
        import cv2

        if "s3://" in video_path:
            video_bytes = client.get(video_path)
            gif = imageio.get_reader(io.BytesIO(video_bytes))
        else:
            gif = imageio.get_reader(video_path)
        max_frame = len(gif) - 1

        images_group = list()
        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
        for index, frame in enumerate(gif):
            if index in frame_indices:
                img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
                img = Image.fromarray(img)
                images_group.append(img)
        torch_imgs = self.transform(images_group)
        if return_time:
            return torch_imgs, self._sampling_msg(frame_indices, fps)
        return torch_imgs

    def read_frame(self, video_path, bound=None, fps=3, return_time=True):
        """Read pre-extracted frames named ``00001.jpg``, ``00002.jpg``, ... from a folder."""
        if os.path.exists(video_path):
            max_frame = len(os.listdir(video_path))
        else:
            max_frame = len([k for k in client.list(video_path)])
        images_group = list()
        frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1
        for frame_index in frame_indices:
            if "s3://" in video_path:
                img_bytes = client.get(os.path.join(video_path, f"{frame_index:05d}.jpg"))
                img = Image.open(io.BytesIO(img_bytes))
            else:
                img = Image.open(os.path.join(video_path, f"{frame_index:05d}.jpg"))
            images_group.append(img)
        torch_imgs = self.transform(images_group)
        if return_time:
            return torch_imgs, self._sampling_msg(frame_indices, fps)
        return torch_imgs

    def qa_template(self, data):
        """Format one annotation as ``(question with lettered options, "(X) answer")``."""
        question = f"Question: {data['question']}\n"
        question += "Options:\n"
        answer = data['answer']
        # Stays -1 when the answer text is not among the candidates.
        answer_idx = -1
        for idx, c in enumerate(data['candidates']):
            question += f"({chr(ord('A') + idx)}) {c}\n"
            if c == answer:
                answer_idx = idx
        question = question.rstrip()
        answer = f"({chr(ord('A') + answer_idx)}) {answer}"
        return question, answer

    def __getitem__(self, idx):
        """Return the frames, question, answer, task type and time sentence of sample ``idx``."""
        sample = self.data_list[idx]
        decord_method = self.decord_method[sample['data_type']]
        bound = None
        if sample['bound']:
            bound = (
                sample['data']['start'],
                sample['data']['end'],
            )
        video_path = os.path.join(sample['prefix'], sample['data']['video'])
        torch_imgs, time_inst = decord_method(video_path, bound)
        question, answer = self.qa_template(sample['data'])

        return {
            'video': torch_imgs,
            'question': question,
            'answer': answer,
            'task_type': sample['task_type'],
            'time': time_inst,
        }

In [8]:
# Sample as many frames per item as the model consumes, at the vision encoder's input size.
dataset = MVBench_dataset(
    data_dir=data_dir,
    data_list=data_list,
    num_segments=tot_frames,
    resolution=resolution,
)

In [9]:
def infer_mvbench(
        data_sample, system="", 
        question_prompt='', # add in the end of question
        answer_prompt=None, # add in the begining of answer
        return_prompt='',  # add in the begining of return message
        system_q=False, # whether add question in the system prompt for QFormer
        print_res=True,
        system_llm=False
    ):
    """Answer one MVBench sample and return the first line of the model's reply.

    Args:
        data_sample: item produced by ``MVBench_dataset`` (keys ``video``,
            ``time``, ``question``, ``answer``).
        system: system text of the conversation; also inserted before the
            question when ``system_llm`` is True.
        question_prompt: text appended after the question.
        answer_prompt: prefix the reply is forced to start with.
        return_prompt: text prepended to the returned reply.
        system_q: accepted but not used by this function.
        print_res: forwarded to ``answer`` to print the rendered prompt.
        system_llm: whether ``system`` is repeated inside the user turn.
    """
    time_msg = data_sample["time"]
    frames = data_sample["video"]
    # Frames arrive stacked along the channel axis (T*C, H, W); unfold to (1, T, 3, H, W).
    TC, H, W = frames.shape
    clip = frames.reshape(1, TC//3, 3, H, W).to(cfg.device)
    with torch.no_grad():
        video_emb = model.encode_long_video(clip,[time_msg,],"")
    video_list = [video_emb]

    chat = EasyDict({
        "system": system,
        "roles": ("[INST]", "[/INST]"),
        "messages": [],
        "sep": ""
    })
    chat.messages.append([chat.roles[0], "<Video><VideoHere></Video> [/INST]"])

    if system_llm:
        user_turn = time_msg + system + data_sample['question'] + question_prompt
    else:
        user_turn = time_msg + data_sample['question'] + question_prompt
    ask(user_turn, chat)

    reply = answer(
        conv=chat, model=model, do_sample=False,
        img_list=video_list, max_new_tokens=100,
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # Keep only the first line so that any trailing explanation is dropped.
    first_line = reply.strip().split('\n')[0]
    llm_message = return_prompt + first_line
    print(llm_message)
    print(f"GT: {data_sample['answer']}")
    return llm_message

In [10]:
def check_ans(pred, gt):
    """Return True when the predicted option matches the ground-truth option.

    Only the first whitespace-separated token of each string (the option
    marker such as ``(A)``) is compared, case-insensitively. A match is
    declared when the prediction token (dots removed) is contained in the
    ground-truth token or vice versa.

    A prediction token without any letter or digit (empty string, a lone
    ``(``, ...) is never a match: such a token is a substring of every
    ground-truth option and would otherwise count as correct.
    """
    pred_option = pred.lower().split(' ')[0]
    gt_option = gt.lower().split(' ')[0]

    normalized_pred = pred_option.replace('.', '')
    if not any(ch.isalnum() for ch in normalized_pred):
        return False
    if not gt_option:
        # No ground-truth option to compare against.
        return False

    if normalized_pred in gt_option:
        return True
    if gt_option in pred_option:
        return True
    return False

In [None]:
# Run the model over every MVBench sample, tracking overall and per-task accuracy.
correct = 0
total = 0
res_list = []
acc_dict = {}

# Instruction text passed as the system prompt (and, with system_llm=True, repeated in the user turn).
mvbench_system_prompt = "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n"

for example in tqdm(dataset):
    task_type = example['task_type']
    task_stats = acc_dict.setdefault(task_type, [0, 0])  # [correct, total]
    task_stats[1] += 1
    total += 1

    pred = infer_mvbench(
        example,
        mvbench_system_prompt,
        question_prompt="\nOnly give the best option.",
        answer_prompt="Best option:(",
        return_prompt='(',
        system_q=False,
        print_res=True,
        system_llm=True,
    )
    gt = example['answer']
    res_list.append({'pred': pred, 'gt': gt})

    if check_ans(pred=pred, gt=gt):
        task_stats[0] += 1
        correct += 1

    print(f"Part  Acc: {task_stats[0] / task_stats[1] * 100:.2f}%")
    print(f"Total Acc: {correct / total * 100:.2f}%")
    print('-' * 50, task_type, '-' * 50)

In [None]:
# Persist the per-task counters and every (prediction, ground truth) pair as JSON.
save_path = args.model_dir+"/MVBench_test_"+args.model_pth+"/result"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

results_payload = {
    "acc_dict": acc_dict,
    "res_list": res_list
}
with open(f"{save_path}.json", "w") as f:
    json.dump(results_payload, f)

In [None]:
# Per-task accuracy in percent, plus the sample-weighted average under 'Avg'.
final_res = {
    task_name: counts[0] / counts[1] * 100
    for task_name, counts in acc_dict.items()
}
correct = sum(counts[0] for counts in acc_dict.values())
total = sum(counts[1] for counts in acc_dict.values())
final_res['Avg'] = correct / total * 100

print(final_res)

# with open("upload_leaderboard.json", "w") as f:
#     json.dump(final_res, f)

In [None]:
acc_path = args.model_dir+"/MVBench_test_"+args.model_pth+"/acc"
out = "AS	AP	AA	FA	UA	OE	OI	OS	MD	AL	ST	AC	MC	MA	SC	FP	CO	EN	ER	CI	Avg"
out1 = "AS		AP		AA		FA		UA		OE		OI		OS		MD		AL		ST		AC		MC		MA		SC		FP		CO		EN		ER		CI		Avg"
# Reload the saved counters and turn them into one tab-separated row:
# one accuracy per task (in file order) followed by the overall accuracy.
with open(f"{save_path}.json", "r") as f:
    json_data = json.load(f)

task_counts = list(json_data["acc_dict"].values())
correct = sum(hits for hits, _ in task_counts)
total = sum(seen for _, seen in task_counts)
row_cells = [f"{hits/seen*100:.2f}" for hits, seen in task_counts]
row_cells.append(f"{correct/total*100:.2f}")
out2 = "\t".join(row_cells)
print(out)
print(out2)

with open(f"{acc_path}.txt", "w") as f:
    f.write(out1+"\n" + out2)



In [None]:
def check_answer_egoschema(pred, qid):
    """Score one EgoSchema prediction against ``ans_dict[qid]``.

    Matching is case-insensitive. The prediction scores 1 when it contains
    the correct option marker (e.g. ``(b)``) and no other option marker, or,
    failing that, when it contains the answer text (optionally with
    ``"a "`` / ``"an "`` removed). Otherwise it scores 0.

    Returns:
        1 if the prediction is counted as correct, else 0.
    """
    entry = ans_dict[qid]
    pred_lower = pred.lower()
    answer_option = entry['answer'].lower()

    answer_content = entry['content'].lower()
    # endswith() also copes with an empty content string.
    if answer_content.endswith("."):
        answer_content = answer_content[:-1]

    if answer_option in pred_lower:
        # Everything is lower-cased, so the competing markers must be lower
        # case too; a prediction naming several options is not counted.
        for other_option in ("(a)", "(b)", "(c)", "(d)", "(e)"):
            if other_option != answer_option and other_option in pred_lower:
                return 0
        return 1

    if not answer_content:
        # An empty string is contained in every prediction; never treat it as a match.
        return 0
    if answer_content in pred_lower:
        return 1
    if answer_content.replace("a ", "") in pred_lower:
        return 1
    if answer_content.replace("an ", "") in pred_lower:
        return 1
    return 0

def infer_egoschema(
        data_sample, system="", 
        question_prompt='', # add in the end of question
        answer_prompt=None, # add in the begining of answer
        return_prompt='',  # add in the begining of return message
        system_q=False, # whether add question in the system prompt for QFormer
        print_res=True,
        system_llm=False,
        num_segments=8,
    ):
    """Answer one EgoSchema subset sample and return the first line of the reply.

    Args:
        data_sample: dict with ``video`` (file name) and ``QA`` (list whose
            first element holds the question ``q`` and ground truth ``a``).
        system: system text of the conversation; also inserted before the
            question when ``system_llm`` is True.
        question_prompt: text appended after the question.
        answer_prompt: prefix the reply is forced to start with.
        return_prompt: text prepended to the returned reply.
        system_q: accepted but not used by this function.
        print_res: forwarded to ``answer`` to print the rendered prompt.
        system_llm: whether ``system`` is repeated inside the user turn.
        num_segments: number of frames sampled from the video.
    """
    vid_path = os.path.join("shdd:s3://egoschema/videos", data_sample['video'])
    print(vid_path)
    frames, time_msg = load_video(vid_path, num_segments=num_segments, return_msg=True)
    # Frames arrive stacked along the channel axis (T*C, H, W); unfold to (1, T, 3, H, W).
    TC, H, W = frames.shape
    clip = frames.reshape(1, TC//3, 3, H, W).to(cfg.device)

    with torch.no_grad():
        video_emb = model.encode_long_video(clip,[time_msg,],"")
    video_list = [video_emb]

    chat = EasyDict({
        "system": system,
        "roles": ("[INST]", "[/INST]"),
        "messages": [],
        "sep": ""
    })
    chat.messages.append([chat.roles[0], "<Video><VideoHere></Video> [/INST]"])

    qa_item = data_sample['QA'][0]
    if system_llm:
        user_turn = time_msg + system + qa_item['q'] + question_prompt
    else:
        user_turn = time_msg + qa_item['q'] + question_prompt
    ask(user_turn, chat)

    reply = answer(
        conv=chat, model=model, do_sample=False,
        img_list=video_list, max_new_tokens=100,
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # Keep only the first line so that any trailing explanation is dropped.
    llm_message = return_prompt + reply.strip().split('\n')[0]
    print(llm_message)
    print(f"GT: {data_sample['QA'][0]['a']}")
    return llm_message

In [None]:
# Parse the EgoSchema subset CSV into:
#   json_data: list of samples ({'video', 'QA': [{'i', 'q', 'a'}]}) fed to infer_egoschema
#   ans_dict:  row number (0-based, header excluded) -> correct option marker and answer text
import csv
with open("./download/datasets/egoschema/EgoSchema.csv", mode='r', encoding='utf-8') as file:
    reader = csv.reader(file)

    json_data = []
    ans_dict = {}
    
    for idx, msg in enumerate(reader):
        # Row 0 is the header: show it once and skip it.
        if idx == 0:
            print(msg)
            continue
            
        # Columns as used here: [1] video id, [3] question, [4] answer text, [5:] candidate options.
        video = msg[1] + '.mp4'
        input_str = f"Question: {msg[3].capitalize()}\nOptions:\n"
    
        # Letter the candidates (A), (B), ... and remember which one equals the answer text.
        target_index = -1
        for i, candidate in enumerate(msg[5:]):
            option = chr(ord('A') + i)
            input_str += f"({option}) {candidate}\n"
            if candidate == msg[4]:
                target_index = i
            
        # The answer text must appear among the candidates.
        assert target_index != -1
        correct = chr(ord('A') + target_index)
        
        json_data.append({
            'video': video,
            "QA": [{
                "i": "",
                "q": input_str.strip(),
                "a": f"Answer: ({correct}) {msg[4]}",
            }]
        })

        # idx - 1 so that keys line up with positions in json_data (header row excluded).
        ans_dict[idx - 1] = {
            'video': video,
            'answer': f"({correct})",
            'content': msg[4],
        }

In [None]:
#  position embedding
# num_frame = 16
# resolution = 224
# new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
# model.vision_encoder.encoder.pos_embed = new_pos_emb

# Evaluate the EgoSchema subset: collect raw replies in `output` and keep a running accuracy.
correct = 0
total = 0
total_num = len(json_data)

output = ""

egoschema_system_prompt = "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n"

for idx, example in enumerate(tqdm(json_data)):
    start = time.time()
    llm_message = infer_egoschema(
        example,
        egoschema_system_prompt,
        question_prompt="\nOnly give the best option.",
        answer_prompt="Best option:(",
        return_prompt='(',
        system_q=False,
        print_res=True,
        system_llm=False,
        num_segments=tot_frames
    )
    duration = time.time() - start

    # Two lines per sample: video file name, then the model's reply.
    output += example["video"] + '\n' + llm_message + '\n'

    correct += check_answer_egoschema(llm_message, idx)
    total += 1
    print("Acc:", correct / total)
    print('-' * 20, f'{idx+1}/{total_num} done,', f'cost: {duration:.2f}s', '-' * 20)

In [None]:
# Append the raw replies to the result file and overwrite the accuracy file.
egoschema_out_dir = args.model_dir+"/Egoschema_test_"+args.model_pth
save_path = egoschema_out_dir+"/result_subset"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path + ".txt", "a") as f:
    f.write(output)

acc_path = egoschema_out_dir+"/acc_subset"
with open(f"{acc_path}.txt", "w") as f:
    f.write("Acc: " + str(correct / total))

In [None]:
# Build the full EgoSchema question list; the full set carries no ground truth, so "a" stays empty.
with open("./download/datasets/egoschema/questions.json", "r") as f:
    full_data = json.load(f)

full_egoschema = []
for data in full_data:
    # Five options per question, lettered (A) to (E).
    option_lines = ""
    for i in range(5):
        option_lines += f"({chr(ord('A') + i)}) {data['option ' + str(i)]}\n"
    input_str = f"Question: {data['question'].capitalize()}\nOptions:\n" + option_lines

    full_egoschema.append({
        'q_uid': data['q_uid'],
        'video': data['q_uid'] + '.mp4',
        "QA": [{
            "i": "",
            "q": input_str.strip(),
            "a": "",
        }]
    })


def infer_full_egoschema(
        data_sample, system="",
        question_prompt='',  # appended to the end of the question
        answer_prompt=None,  # prepended to the beginning of the answer
        return_prompt='',  # prepended to the returned message
        system_q=False,  # whether to add the question to the system prompt for QFormer
        print_res=True,
        system_llm=False,
        num_segments=8,
    ):
    """Answer one full-set EgoSchema question and return the first answer line.

    Args:
        data_sample: dict with a 'video' file name and a 'QA' list whose first
            entry holds the question text under 'q'.
        system: system text stored on the conversation; it is also prepended to
            the user prompt when ``system_llm`` is true.
        question_prompt: text appended after the question.
        answer_prompt: forwarded to ``answer``.
        return_prompt: text prepended to the returned message.
        system_q: accepted but not used by this function.
        print_res: forwarded to ``answer``.
        system_llm: whether ``system`` is inserted into the user prompt.
        num_segments: forwarded to ``load_video``.

    Returns:
        ``return_prompt`` followed by the first line of the stripped model output.
    """
    vid_path = os.path.join("shdd:s3://egoschema/videos", data_sample['video'])
    print(vid_path)

    frames, msg = load_video(vid_path, num_segments=num_segments, return_msg=True)
    # The loader returns frames with the channel axis folded into the first
    # dimension; unfold it and add a leading batch axis of size 1.
    stacked_channels, height, width = frames.shape
    frames = frames.reshape(1, stacked_channels // 3, 3, height, width).to(cfg.device)

    with torch.no_grad():
        video_list = [model.encode_long_video(frames, [msg], "")]

    chat = EasyDict({
        "system": system,
        "roles": ("[INST]", "[/INST]"),
        "messages": [],
        "sep": ""
    })
    chat.messages.append([chat.roles[0], "<Video><VideoHere></Video> [/INST]"])

    prompt_head = msg + system if system_llm else msg
    ask(prompt_head + data_sample['QA'][0]['q'] + question_prompt, chat)

    raw_reply = answer(
        conv=chat, model=model, do_sample=False,
        img_list=video_list, max_new_tokens=100,
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # Keep only the first line so that any trailing explanation is dropped.
    first_line = raw_reply.strip().split('\n')[0]
    llm_message = return_prompt + first_line
    print(llm_message)
    return llm_message


#  position embedding
# num_frame = 16
# resolution = 224
# new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
# model.vision_encoder.encoder.pos_embed = new_pos_emb


# Predict an option for every question of the full EgoSchema set and store it
# as an index (0 for "A" ... 4 for "E") keyed by the question uid.
ans_dict = {}

for idx, example in enumerate(tqdm(full_egoschema)):
    start = time.time()
    llm_message = infer_full_egoschema(
        example,
        system="Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n",
        question_prompt="\nOnly give the best option.",
        answer_prompt="Best option:(",
        return_prompt='(',
        system_q=False,
        print_res=True,
        system_llm=False,
        num_segments=tot_frames,
    )

    # The reply has to look like "(X)..." with X one of the five option
    # letters; anything else stops the run.
    assert llm_message[0] == '(' and llm_message[2] == ')'
    ans = ord(llm_message[1]) - ord('A')
    assert 0 <= ans <= 4
    ans_dict[example['q_uid']] = ans

In [None]:
# Write the {q_uid: option index} predictions of the full EgoSchema run.
save_path = args.model_dir+"/Egoschema_test_"+args.model_pth+"/result"
# Create the output directory here too: it is otherwise only created when the
# subset results are saved, and without it open() raises FileNotFoundError
# after the whole evaluation has already been computed.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path + ".json", "w") as f:
    json.dump(ans_dict, f)

# Then you can run https://github.com/egoschema/EgoSchema/blob/main/validate.py to get the score
# python3 validate.py --f ./your_prediction.json

In [None]:
import pysubs2
import re
from torchvision import transforms
from torchvision.transforms import PILToTensor

def clean_text(text):
    """Remove every character that is not an ASCII letter, digit, whitespace or square bracket.

    Args:
        text: the string to clean.

    Returns:
        ``text`` with all other characters (punctuation, symbols, non-ASCII
        characters) deleted.
    """
    # The brackets must sit inside the negated character class. With them
    # outside, the pattern only matched "one symbol followed by a literal []"
    # and the function returned its input practically unchanged.
    cleaned_text = re.sub(r'[^A-Za-z0-9\s\[\]]', '', text)
    return cleaned_text

def read_vtt_and_concatenate(file_path, tokenizer, max_len=4096):
    """Load a subtitle file and return its text, limited to ``max_len`` tokens.

    Caption lines are cleaned with ``clean_text``; empty lines and lines equal
    to the previously kept line are dropped. If the joined text tokenizes to at
    most ``max_len`` ids it is returned as is. Otherwise a head part and a tail
    part are trimmed until together they fit, and are returned joined by
    ' ... '.

    Args:
        file_path: path handed to ``pysubs2.load``.
        tokenizer: callable returning an object with ``input_ids`` and offering
            ``decode``.
        max_len: maximum number of token ids kept.

    Returns:
        The (possibly shortened) subtitle text.
    """
    def encode(text):
        return tokenizer(text, add_special_tokens=False).input_ids

    # Flatten the captions into cleaned lines without empties or immediate repeats.
    subtitles = []
    last_kept = ""
    for caption in pysubs2.load(file_path, encoding="utf-8"):
        for raw_line in caption.text.split('\n'):
            cleaned = clean_text(raw_line)
            if cleaned and cleaned != last_kept:
                subtitles.append(cleaned)
                last_kept = cleaned

    full_text = ' '.join(subtitles)
    if len(encode(full_text)) <= max_len:
        return full_text

    # Too long: tokenize the first and the last `half_len` lines separately ...
    half_len = max_len // 2
    head_ids = encode(' '.join(subtitles[:half_len]))
    tail_ids = encode(' '.join(subtitles[-half_len:]))

    # ... and shrink the longer side (head from its end, tail from its start)
    # until both together fit into the budget.
    while len(head_ids) + len(tail_ids) > max_len:
        if len(head_ids) > len(tail_ids):
            head_ids.pop()
        else:
            tail_ids.pop(0)

    return tokenizer.decode(head_ids) + ' ... ' + tokenizer.decode(tail_ids)

In [None]:
class MME_dataset(Dataset):
    """Video-MME evaluation dataset built on pre-extracted frame images.

    Every item bundles the sampled frames of one video, a sentence listing the
    sampled timestamps, the formatted multiple-choice questions with their
    ground-truth answers, the video's subtitle text (empty if unavailable) and
    the video's duration bucket.
    """

    # Keys under which an annotation entry may carry its duration bucket; the
    # first one present wins.
    # NOTE(review): the original code used both spellings - confirm which one
    # the annotation file actually contains.
    _DURATION_KEYS = ("duration", "duration_category")

    def __init__(
        self, 
        data_prefix="shdd:s3://VideoMME_0629/processed_1fps",
        subtitle_prefix="./download/datasets/videomme/subtitle_0629",
        anno_path="./download/datasets/videomme/Video-MME_0629.json",
        frame_dict_path="./download/datasets/videomme/video_mme_1fps.json",
        num_segments=16, 
        stride=0, # if stride >= 1, will return all frames according to FPS (1/stride), else return partial frames
        resolution=224, 
        max_subtitle_len=4096, # max_tokens for subtitle
    ):
        """Load the annotation list and the per-video frame-name table.

        Args:
            data_prefix: root under which ``<video>/frames/<frame>`` is resolved.
            subtitle_prefix: directory holding ``<video>.srt`` subtitle files.
            anno_path: JSON file with the list of annotated videos.
            frame_dict_path: JSON file mapping a video name to its frame file names.
            num_segments: number of frames sampled when not striding.
            stride: step for strided sampling; values below 1 disable it.
            resolution: side length the frames are resized to.
            max_subtitle_len: token budget for the subtitle text.
        """
        self.data_prefix = data_prefix
        self.subtitle_prefix = subtitle_prefix
        with open(anno_path, 'r') as f:
            self.data_list = json.load(f)
        with open(frame_dict_path, 'r') as f:
            self.frame_dict = json.load(f)
        
        self.num_segments = num_segments
        self.stride = stride
        self.resolution = resolution
        self.max_subtitle_len = max_subtitle_len

        # transform: scale uint8 pixels to [0, 1], then normalize per channel
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        self.transform = transforms.Compose([
            transforms.Lambda(lambda x: x.float().div(255.0)),
            transforms.Normalize(mean, std)
        ])

    @classmethod
    def _duration_of(cls, entry):
        """Return the duration bucket of one annotation entry.

        Raises:
            KeyError: if the entry has none of the known duration keys.
        """
        for key in cls._DURATION_KEYS:
            if key in entry:
                return entry[key]
        raise KeyError(cls._DURATION_KEYS[0])
    
    def __str__(self):
        """Summarize the dataset: video count, QA count and QAs per task type for each duration bucket."""
        # duration bucket -> {task type -> number of questions}
        task_dict = {}
        total = 0
        for data in self.data_list:
            per_task = task_dict.setdefault(self._duration_of(data), {})
            for q in data['questions']:
                per_task[q['task_type']] = per_task.get(q['task_type'], 0) + 1
                total += 1

        res = f"There are {len(self.data_list)} videos.\n"
        res += f"There are {total} QAs.\n"
        for k, v in task_dict.items():
            res += f"------{k}------\n"
            for kk, vv in v.items():
                res += f"{kk}: {vv}\n"
                
        return res.rstrip()
        
    def __len__(self):
        """Return the number of annotated videos."""
        return len(self.data_list)
    
    def get_index(self, max_frame):
        """Return ``self.num_segments`` frame indices spread evenly over ``max_frame`` frames.

        Args:
            max_frame: number of available frames.

        Returns:
            A numpy array of indices, each clamped to ``max_frame - 1``.
        """
        start_idx = 0
        end_idx = max_frame - 1
        seg_size = float(max_frame - start_idx) / self.num_segments
        frame_indices = np.array([
            # min() keeps the index inside the valid range; max() would push
            # every index to the last frame or past it.
            min(int(start_idx + (seg_size / 2) + np.round(seg_size * idx)), end_idx)
            for idx in range(self.num_segments)
        ])
        return frame_indices

    def get_time_stamp(self, video_path):
        """Parse the ``<minutes>:<seconds>`` stamp between the last '_' and '.jpg' of a frame path.

        Returns:
            The stamp converted to whole seconds.
        """
        timestamp = video_path.split("_")[-1].split(".jpg")[0]
        minutes, seconds = map(int, timestamp.split(":"))
        total_seconds = minutes * 60 + seconds
        return total_seconds

    def read_frame(self, video_name):
        """Load, resize and normalize the sampled frames of one video.

        Args:
            video_name: key into the frame-name table.

        Returns:
            A tuple ``(frames, time_instruction)``: the stacked, transformed
            frame tensors and a sentence listing the timestamps of the sampled
            frames.
        """
        full_frame_list = [
            os.path.join(self.data_prefix, video_name, 'frames', p)
            for p in self.frame_dict[video_name]
        ]
            
        images_group = list()
        time_list = []
        if self.stride >= 1 and (len(full_frame_list) / self.stride) > self.num_segments:
            frame_list = full_frame_list[::self.stride]
        else:
            # NOTE(review): this calls a module-level get_index(num_frames,
            # num_segments) that is defined outside this class, not
            # self.get_index - confirm that this is intended.
            frame_indices = get_index(len(full_frame_list), self.num_segments)
            frame_list = [full_frame_list[idx] for idx in frame_indices]
        
        for frame_path in frame_list:
            time_stamp = self.get_time_stamp(frame_path)
            time_list.append(time_stamp)
            if "s3://" in frame_path:
                # Remote frame: fetch the bytes through the storage client.
                img_bytes = client.get(frame_path)
                img = Image.open(BytesIO(img_bytes))
            else:
                img = Image.open(frame_path)
            # Use the resolution configured on this instance rather than a
            # notebook-level global of the same name.
            img = img.resize((self.resolution, self.resolution))
            img = PILToTensor()(img).unsqueeze(0)
            images_group.append(img)
        torch_imgs = self.transform(torch.vstack(images_group))
        sec = ", ".join(map(str, time_list))
        time_instruction = f"The video contains {len(time_list)} frames sampled at {sec} seconds. "
        print(torch_imgs.shape)
        return torch_imgs, time_instruction

    def qa_template(self, data):
        """Format one question entry.

        Args:
            data: dict with 'question', 'answer' (an option letter) and
                'options'; each option starts with its letter and two more
                characters that are stripped before the option text.

        Returns:
            A tuple ``(question, answer)`` where ``question`` lists the options
            as "(X) text" lines and ``answer`` is "(X) text" of the correct one.
        """
        question = f"Question: {data['question']}\n"
        question += "Options:\n"
        answer = data['answer']
        answer = f"({answer}) {data['options'][ord(answer) - ord('A')][3:]}"
        for c in data['options']:
            cur_choice, cur_text = c[0], c[3:]
            question += f"({cur_choice}) {cur_text}\n"
        question = question.rstrip()
        return question, answer

    def __getitem__(self, idx):
        """Return frames, time instruction, QA pairs, subtitle text and duration bucket of video ``idx``."""
        video_name = self.data_list[idx]['videoID']
        torch_imgs, time_instruction = self.read_frame(video_name)
        duration_category = self._duration_of(self.data_list[idx])
        qa_list = []
        for qa in self.data_list[idx]['questions']:
            qa_list.append(self.qa_template(qa))

        # Built outside the try block so the error message below can always
        # refer to it.
        subtitle_path = os.path.join(self.subtitle_prefix, video_name + ".srt")
        subtitle = ""
        try:
            if os.path.exists(subtitle_path):
                subtitle = read_vtt_and_concatenate(subtitle_path, model.mistral_tokenizer, self.max_subtitle_len)
        except Exception:
            # Best effort: an unreadable subtitle file yields an empty subtitle.
            subtitle = ""
            print(f"Error for {subtitle_path}")
            
        return {
            'subtitle': subtitle,
            'video': torch_imgs, 
            'time_instruction': time_instruction,
            'qa_list': qa_list,
            'duration_category': duration_category
        }

    
def infer_mme(
        data_sample, system="", 
        question_prompt='', # add in the end of question
        answer_prompt=None, # add in the beginning of answer
        return_prompt='',  # add in the beginning of return message
        system_q=False, # whether add question in the system prompt for QFormer
        print_res=True,
        system_llm=False,
        no_qformer_instruction=False,
        qformer_instruction=None,
        add_subtitle=False,
    ):
    """Answer every question of one Video-MME sample.

    The video is encoded once and the resulting embedding is reused for all
    questions; each question gets a fresh conversation.

    Args:
        data_sample: dict with 'video' (4-D frame tensor), 'time_instruction',
            'qa_list' (pairs of question text and ground-truth answer) and
            'subtitle'.
        system: system text stored on each conversation; also prepended to the
            user prompt when ``system_llm`` is true.
        question_prompt: text appended after each question.
        answer_prompt: forwarded to ``answer``.
        return_prompt: text prepended to each model reply.
        system_q: must be False (not supported).
        print_res: forwarded to ``answer``.
        system_llm: whether ``system`` is inserted into the user prompt.
        no_qformer_instruction: accepted but not used by this function.
        qformer_instruction: accepted but not used by this function.
        add_subtitle: whether a non-empty subtitle is placed before the video
            placeholder.

    Returns:
        A tuple ``(pred_list, gt_list)``. Each entry is the second character of
        the reply / ground-truth string, i.e. the option letter when the string
        looks like "(A) ...". A reply that is too short yields ''.
    """
    assert system_q == False, "do not support system_q now"
    video = data_sample["video"]
    msg = data_sample["time_instruction"]
    T_, C, H, W = video.shape
    # Add a leading batch axis of size 1.
    video = video.reshape(1, T_, C, H, W).to(cfg.device)
    
    with torch.no_grad():
        video_emb = model.encode_long_video(video, [msg], "")
    video_list = [video_emb[0].unsqueeze(0)]
    print(video_list[0].shape)

    # The opening turn and the prompt prefix are identical for every question
    # of this video, so build them once.
    if add_subtitle and data_sample['subtitle'] != '':
        subtitle = f"This video's subtitles are listed below: {data_sample['subtitle']}"
        video_turn = f"{subtitle}\n<Video><VideoHere></Video> [/INST]"
    else:
        video_turn = "<Video><VideoHere></Video> [/INST]"
    prompt_head = msg + system if system_llm else msg

    pred_list = []
    gt_list = []
    for idx, qa in enumerate(data_sample['qa_list']):
        print(f"----------qa_{idx}---------", flush=True)
        chat = EasyDict({
            "system": system,
            "roles": ("[INST]", "[/INST]"),
            "messages": [],
            "sep": ""
        })
        chat.messages.append([chat.roles[0], video_turn])
        
        ask(prompt_head + qa[0] + question_prompt, chat)
    
        llm_message = answer(
            conv=chat, model=model, do_sample=False, 
            img_list=video_list, max_new_tokens=256, 
            answer_prompt=answer_prompt, print_res=print_res
        )[0]
        # remove potential explanation
        llm_message = return_prompt + llm_message.strip().split('\n')[0]
        print(f"Pred: {llm_message}", flush=True)
        print(f"GT: {qa[1]}", flush=True)
        # An empty model reply leaves fewer than two characters; record an
        # empty prediction (counted as wrong by the caller) instead of raising
        # IndexError and aborting the whole evaluation.
        pred_list.append(llm_message[1] if len(llm_message) > 1 else "")
        gt_list.append(qa[1][1])
    return pred_list, gt_list

In [None]:
stride = 0
max_subtitle_len=8192
data_prefix = "shdd:s3://VideoMME_0629/processed_1fps"
anno_path = "./download/datasets/videomme/Video-MME_0629.json"
frame_dict_path = "./download/datasets/videomme/video_mme_1fps.json"
dataset = MME_dataset(
    data_prefix=data_prefix, 
    anno_path=anno_path, 
    frame_dict_path=frame_dict_path,
    num_segments=tot_frames, 
    stride=stride,
    resolution=resolution,
    max_subtitle_len=max_subtitle_len,
)

# Annotations loaded a second time; every question receives the model's
# answer under the key 'response' during an evaluation pass.
with open(anno_path, 'r') as f:
    res_json_data = json.load(f)


def _evaluate_videomme(dataset, annotations, system, add_subtitle, result_tag,
                       category_list=("short", "medium", "long")):
    """Run one Video-MME evaluation pass and write its result files.

    Three files are written below ``<model_dir>/VideoMME_test_<model_pth>/``:
    ``<result_tag>_result.json`` (per-category counters and per-video
    predictions), ``<result_tag>_result_full.json`` (the annotations with the
    responses filled in) and ``<result_tag>_acc.txt`` (accuracy in percent).

    Args:
        dataset: iterable of samples as produced by ``MME_dataset``.
        annotations: annotation list aligned with ``dataset``; modified in
            place (a 'response' entry is set on every question).
        system: system prompt forwarded to ``infer_mme``.
        add_subtitle: whether ``infer_mme`` should include subtitles.
        result_tag: file-name prefix identifying this pass.
        category_list: duration categories reported first in the accuracy file.

    Returns:
        A tuple ``(correct, total, acc_dict, res_list)``.
    """
    correct = 0
    total = 0
    res_list = []
    acc_dict = {}  # duration category -> [correct, total]

    for idx, example in enumerate(tqdm(dataset)):
        duration_category = example['duration_category']
        if duration_category not in acc_dict:
            acc_dict[duration_category] = [0, 0] # correct, total
        qa_count = len(example['qa_list'])
        acc_dict[duration_category][1] += qa_count
        total += qa_count

        pred_list, gt_list = infer_mme(
            example, 
            system,
            question_prompt="\nOnly give the best option.",
            answer_prompt="Best option:(",
            return_prompt='(',
            system_q=False,
            print_res=True,
            system_llm=True,
            add_subtitle=add_subtitle,
        )

        res_list.append({
            'pred': pred_list,
            'gt': gt_list
        })
        for qa_idx, (pred, gt) in enumerate(zip(pred_list, gt_list)):
            if pred == gt:
                acc_dict[duration_category][0] += 1
                correct += 1
            annotations[idx]['questions'][qa_idx]['response'] = pred

        # Guard the divisions: a video without questions would otherwise
        # abort the run with ZeroDivisionError.
        part_correct, part_total = acc_dict[duration_category]
        part_acc = part_correct / part_total * 100 if part_total else 0.0
        total_acc = correct / total * 100 if total else 0.0
        print(f"Part  Acc: {part_acc:.2f}%")
        print(f"Total Acc: {total_acc:.2f}%")
        print('-' * 50, duration_category, '-' * 50)

    save_path = args.model_dir+"/VideoMME_test_"+args.model_pth+"/"+result_tag+"_result"
    acc_path = args.model_dir+"/VideoMME_test_"+args.model_pth+"/"+result_tag+"_acc"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    with open(f"{save_path}.json", "w") as f:
        json.dump({
            "acc_dict": acc_dict,
            "res_list": res_list
        }, f)

    with open(f"{save_path}_full.json", "w") as f:
        json.dump(annotations, f)

    # Report the expected categories first (only those that occurred), then
    # any other category that was seen, so a missing or unexpected category
    # cannot raise KeyError after all inference work is done.
    reported = [c for c in category_list if c in acc_dict]
    reported += [c for c in acc_dict if c not in category_list]
    with open(f"{acc_path}.txt", "w") as f:
        overall = round(100 * correct / total, 1) if total else 0.0
        f.write("Acc: " + str(overall) + "\n")
        for category in reported:
            cat_correct, cat_total = acc_dict[category]
            cat_acc = round(100 * cat_correct / cat_total, 1) if cat_total else 0.0
            f.write(f"{category} Acc: " + str(cat_acc) + "\n")

    return correct, total, acc_dict, res_list


# Only Vision Information
correct, total, acc_dict, res_list = _evaluate_videomme(
    dataset,
    res_json_data,
    system="Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n", # newPrompt2
    add_subtitle=False,
    result_tag="Wo",
)

# With Subtitle
correct, total, acc_dict, res_list = _evaluate_videomme(
    dataset,
    res_json_data,
    system="Carefully watch the video, read the related subtitles and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n", # newPrompt2
    add_subtitle=True,
    result_tag="WithSub",
)