In [1]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
# Move the working directory one level up; relative paths used later in this
# notebook (e.g. './download/...') are then resolved from that parent directory.
# NOTE(review): the target is relative to the current directory, so re-running
# this cell on a live kernel moves up one more level each time.
os.chdir('../')

In [None]:
# In[1]:
import os
import sys
import copy
import torch
import argparse
from transformers import StoppingCriteria, StoppingCriteriaList
from math import ceil
from PIL import Image
import numpy as np
import decord

decord.bridge.set_bridge('torch')
from torchvision.transforms.functional import InterpolationMode
import json
import time
import datetime
from tqdm import tqdm
import random

random.seed(1234)

from utils.config import Config
from utils.easydict import EasyDict
from transformers import StoppingCriteria, StoppingCriteriaList
from decord import VideoReader, cpu
import torchvision.transforms as T
from dataset.video_transforms import (
    GroupNormalize, GroupScale, GroupCenterCrop, 
    Stack, ToTorchFormatTensor
)
from peft import get_peft_model, LoraConfig, TaskType
from io import BytesIO
from models import *

# petrel_client is an optional dependency: when it cannot be imported the
# notebook continues without it and `client` stays None.
try:
    from petrel_client.client import Client
    has_client = True
    print("Client on!")
except ImportError:
    # Only an import failure means "no client". A bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit and hide unrelated errors.
    has_client = False
    print("Client off!")

# `client` is consulted by load_video(); None means videos are opened directly by path.
client = Client('~/petreloss.conf') if has_client else None


# In[2]:


def get_args(argv=None):
    """Build the run-time options used by this notebook.

    Args:
        argv: Optional sequence of command-line style tokens, for example
            ``['--batch_size', '4']``. When omitted, an empty list is parsed,
            so every option takes its default and the kernel's own
            ``sys.argv`` is never consulted.

    Returns:
        argparse.Namespace with the attributes ``model_type``, ``model_dir``,
        ``model_pth``, ``output_dir``, ``batch_size`` and ``infer_clip_frames``.
    """
    parser = argparse.ArgumentParser()
    # Not related to the test task
    parser.add_argument('--model_type', default="VideoChat2_it4_mistral_LinearProAda")
    parser.add_argument('--model_dir', default="./download/parameters")
    parser.add_argument('--model_pth', default="timesuite")
    parser.add_argument('--output_dir', default="Please input model output dir!")
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--infer_clip_frames', type=int, default=8)

    # Always pass an explicit list: inside a notebook sys.argv holds the
    # kernel's own arguments, which argparse would reject.
    args = parser.parse_args(args=[] if argv is None else list(argv))
    return args


args = get_args()
# Render every parsed option as "name: value", with the name left-aligned in a
# 25-character column, and start the listing on a fresh line.
option_lines = [f'{name:<25}: {value}' for name, value in vars(args).items()]
args_list_str = '\n' + '\n'.join(option_lines)
print(args_list_str)


# In[5]:

# config_file = "configs/config_mistral.json"
config_file = os.path.join(args.model_dir, "config.json")

cfg = Config.from_file(config_file)
# LoRA is attached explicitly below via get_peft_model, so the config flag is off here.
cfg.model.use_lora = False
# No pretrained path is handed to the constructor; the checkpoint is loaded explicitly further down.
cfg.model.pretrained_path = None

# Keep GPU 7 as the preferred device, but fall back instead of failing on hosts
# that expose fewer than eight GPUs (or none at all).
preferred_gpu = 7
visible_gpus = torch.cuda.device_count()
if visible_gpus > preferred_gpu:
    cfg.device = f"cuda:{preferred_gpu}"
elif visible_gpus > 0:
    cfg.device = "cuda:0"
    print(f"cuda:{preferred_gpu} is not available; using {cfg.device} instead")
else:
    cfg.device = "cpu"
    print("CUDA is not available; using cpu instead")


print("vision_encoder.num_frames:", cfg.model.vision_encoder.num_frames)
# cfg.model.vision_encoder.num_frames = 4

# NOTE(review): eval() turns the class name into the class object but would also
# execute any other expression; args.model_type must never come from untrusted input.
model_cls = eval(args.model_type)
# model = VideoChat2_it_mistral(config=cfg.model)
model = model_cls(config=cfg.model)


# add lora to run stage3 model
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, 
    r=16, lora_alpha=32, lora_dropout=0.,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
         "gate_proj", "up_proj", "down_proj", "lm_head"
    ]
)
model.mistral_model = get_peft_model(model.mistral_model, peft_config)


# state_dict = torch.load("./download/parameters/videochat2_mistral_7b_stage3.pth", "cpu")
# The checkpoint is first materialised on the CPU; the model is moved to cfg.device afterwards.
state_dict = torch.load(os.path.join(args.model_dir, args.model_pth + ".pth"), "cpu")


# The weights are either nested under a 'model' key or stored as the bare state dict.
# strict=False tolerates key mismatches; the printed `msg` is the result returned by load_state_dict.
if 'model' in state_dict:
    msg = model.load_state_dict(state_dict['model'], strict=False)
else:
    msg = model.load_state_dict(state_dict, strict=False)
print(msg)

model = model.to(torch.device(cfg.device))
model = model.eval()

print('Model Initialization Finished')







def get_prompt(conv):
    """Flatten a conversation into a single prompt string.

    The result starts with ``conv.system + conv.sep``. Each ``(role, message)``
    entry of ``conv.messages`` then contributes ``role + " " + message + " " +
    conv.sep`` when the message is non-empty, or just ``role`` when it is empty
    or None.
    """
    pieces = [conv.system, conv.sep]
    for speaker, text in conv.messages:
        if not text:
            # An open turn: only the role tag is emitted.
            pieces.append(speaker)
            continue
        pieces.extend((speaker, " ", text, " ", conv.sep))
    return "".join(pieces)


def get_prompt2(conv):
    """Flatten a conversation whose final message is a partial answer to be continued.

    Works like ``get_prompt`` for every entry except the last one, which is
    always emitted as ``role + " " + message`` with no trailing space or
    separator.
    """
    prompt = conv.system + conv.sep
    final_pos = len(conv.messages) - 1
    for pos, (speaker, text) in enumerate(conv.messages):
        if pos == final_pos:
            # The last turn is left open so generation continues right after it.
            prompt += speaker + " " + text
        elif text:
            prompt += speaker + " " + text + " " + conv.sep
        else:
            prompt += speaker
    return prompt


def get_context_emb(conv, model, img_list, answer_prompt=None, print_res=False):
    """Build the input embeddings for generation by interleaving text and visual embeddings.

    The prompt is rendered with ``get_prompt2`` when ``answer_prompt`` is truthy
    and with ``get_prompt`` otherwise. It is split on ``<VideoHere>`` if that
    placeholder occurs in it, else on ``<ImageHere>``. Each text segment is
    tokenized and embedded, and the entries of ``img_list`` are inserted between
    consecutive segments.

    Args:
        conv: Conversation object accepted by ``get_prompt`` / ``get_prompt2``.
        model: Object exposing ``mistral_tokenizer`` and ``mistral_model``.
        img_list: Visual embeddings, one per placeholder in the prompt.
        answer_prompt: Selects the prompt builder (see above).
        print_res: When true, the rendered prompt is printed.

    Returns:
        The concatenation (along dim 1) of all text and visual embeddings.

    Raises:
        AssertionError: If the number of placeholders differs from ``len(img_list)``.
    """
    render = get_prompt2 if answer_prompt else get_prompt
    prompt = render(conv)
    if print_res:
        print("prompt:",prompt)

    placeholder = '<VideoHere>' if '<VideoHere>' in prompt else '<ImageHere>'
    prompt_segs = prompt.split(placeholder)
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."

    with torch.no_grad():
        seg_tokens = []
        for seg_idx, seg in enumerate(prompt_segs):
            # only add bos to the first seg
            encoded = model.mistral_tokenizer(
                seg, return_tensors="pt", add_special_tokens=seg_idx == 0)
            seg_tokens.append(encoded.to(cfg.device).input_ids)
        # The embedding layer is reached through the PEFT wrapper around the language model.
        embed_tokens = model.mistral_model.base_model.model.model.embed_tokens
#         embed_tokens = model.mistral_model.model.embed_tokens
        seg_embs = [embed_tokens(token_ids) for token_ids in seg_tokens]

    # text_0, visual_0, text_1, visual_1, ..., text_last
    interleaved = []
    for text_emb, visual_emb in zip(seg_embs[:-1], img_list):
        interleaved.append(text_emb)
        interleaved.append(visual_emb)
    interleaved.append(seg_embs[-1])
    return torch.cat(interleaved, dim=1)


def ask(text, conv):
    """Append ``text`` to ``conv.messages`` as a new turn of the first role (``conv.roles[0]``)."""
    first_role = conv.roles[0]
    conv.messages.append([first_role, text])
        

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence ends with one of the given token-id sequences.

    Only ``input_ids[0]`` is inspected, so the criterion reflects the first
    sequence of the batch.

    Args:
        stops: Iterable of 1-D tensors of token ids; generation stops when the
            tail of the sequence equals any of them. Defaults to no stops.
        encounters: Accepted for interface compatibility; not used.
    """
    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Build a fresh list per instance rather than sharing one mutable default.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        sequence = input_ids[0]
        for stop in self.stops:
            stop_len = len(stop)
            # A stop longer than the sequence cannot match. Without this guard the
            # comparison would broadcast (false positives) or raise on a shape mismatch.
            if stop_len == 0 or stop_len > sequence.shape[0]:
                continue
            if torch.all(stop == sequence[-stop_len:]).item():
                return True
        return False
    
    
def answer(conv, model, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None, print_res=False):
    """Generate the reply for the current conversation and record it in ``conv``.

    A new ``[conv.roles[1], answer_prompt]`` entry is appended to
    ``conv.messages``, the context embeddings are built with
    ``get_context_emb`` and passed to ``model.mistral_model.generate``. The
    text of the appended entry is finally replaced by the decoded output
    followed by ``'</s>'``.

    Args:
        conv: Conversation object with ``roles`` and ``messages``.
        model: Object exposing ``mistral_model`` and ``mistral_tokenizer``.
        img_list: Visual embeddings matching the placeholders in the prompt.
        do_sample, max_new_tokens, num_beams, min_length, top_p,
        repetition_penalty, length_penalty, temperature: Forwarded unchanged
            to ``generate``.
        answer_prompt: Optional text placed at the start of the reply turn.
        print_res: Forwarded to ``get_context_emb`` (prints the prompt).

    Returns:
        Tuple ``(output_text, output_tokens)``: the decoded text up to the
        first ``'</s>'`` and the generated token ids as a NumPy array.
    """
    stop_words_ids = [
        torch.tensor([2]).to(cfg.device),
        torch.tensor([29871, 2]).to(cfg.device)]  # '</s>' can be encoded in two different ways.
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    conv.messages.append([conv.roles[1], answer_prompt])
    embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res)
    with torch.no_grad():
        outputs = model.mistral_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
    output_token = outputs[0]
    # The model might output an unknown token <unk> (id 0) at the beginning, and some
    # users find a start token <s> (id 1) there; drop at most one of each, in that order.
    # The length check keeps an output that becomes empty from raising IndexError.
    for leading_id in (0, 1):
        if len(output_token) > 0 and output_token[0] == leading_id:
            output_token = output_token[1:]
    output_text = model.mistral_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('</s>')[0]  # remove the stop sign </s>
#     output_text = output_text.split('[/INST]')[-1].strip()
    conv.messages[-1][1] = output_text + '</s>'
    return output_text, output_token.cpu().numpy()



def get_index(num_frames, num_segments):
    """Pick ``num_segments`` frame indices spread evenly over a video.

    The range ``0 .. num_frames - 1`` is divided into ``num_segments`` equal
    strides; each returned index is half a stride (truncated) plus the rounded
    start of its stride.

    Returns:
        NumPy array with ``num_segments`` integer frame indices.
    """
    stride = float(num_frames - 1) / num_segments
    half_stride = int(stride / 2)
    picks = []
    for seg_no in range(num_segments):
        picks.append(half_stride + int(np.round(stride * seg_no)))
    return np.array(picks)


In [3]:
def load_video(video_path, num_segments=8, return_msg=False, resolution=224):
    """Read ``num_segments`` frames from a video and run them through the preprocessing pipeline.

    The video bytes are fetched through the module-level ``client`` when one is
    configured and ``video_path`` contains ``"s3"``; otherwise the path is
    opened directly. Frame positions come from ``get_index``. The frames pass
    through GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor and
    GroupNormalize, in that order.

    Args:
        video_path: Location of the video.
        num_segments: Number of frames to sample.
        return_msg: When true, also return a sentence listing the sampled timestamps.
        resolution: Value used for both the scale size and the crop size.

    Returns:
        The transformed frames, or ``(frames, msg)`` when ``return_msg`` is true.
    """
    fetch_with_client = client is not None and "s3" in video_path
    if fetch_with_client:
        video_bytes = client.get(video_path)
        assert video_bytes is not None
        vr = VideoReader(BytesIO(video_bytes), ctx=cpu(0), num_threads=1)
    else:
        vr = VideoReader(uri=video_path, ctx=cpu(0), num_threads=1)
    frame_indices = get_index(len(vr), num_segments)

    # transform: per-channel normalization statistics and the frame pipeline
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]
    preprocess = T.Compose([
        GroupScale(int(resolution), interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(resolution),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(input_mean, input_std)
    ])

    frames = [Image.fromarray(vr[frame_index].numpy()) for frame_index in frame_indices]
    torch_imgs = preprocess(frames)
    if not return_msg:
        return torch_imgs

    # Convert each sampled frame index to seconds using the average frame rate.
    fps = float(vr.get_avg_fps())
    sec = ", ".join(str(round(frame_index / fps, 1)) for frame_index in frame_indices)
    msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds. "
    # sec = [str(round(f / fps, 1)) for f in frame_indices]
    # msg = f"The video contains {len(frame_indices)} frames uniformly sampled from {sec[0]} to {sec[-1]} seconds. "
    return torch_imgs, msg
    

In [4]:
def generate_videochat(vid_path, user_messages):
    """Answer a question about a video with the globally loaded model.

    Loads ``model.total_frames`` frames from ``vid_path``, encodes them with
    ``model.encode_long_video``, starts a fresh conversation containing the
    video placeholder and the user's text (prefixed with the frame-timestamp
    message from ``load_video``), and generates the reply without sampling.

    Args:
        vid_path: Location of the video, as accepted by ``load_video``.
        user_messages: The question or instruction text for the model.

    Returns:
        Tuple ``(llm_answer, chat, img_list)``: the generated text, the
        conversation object including the reply, and the list holding the
        video embedding.
    """
    tot_frames = model.total_frames
    resolution = cfg.model.vision_encoder.img_size

    vid, msg = load_video(vid_path, num_segments=tot_frames, return_msg=True, resolution=resolution)

    # The model expects inputs of shape: T x C x H x W
    # `vid` arrives with frames and channels merged in its first dimension (TC, H, W);
    # split that into (T, 3) and add a leading batch dimension of 1.
    TC, H, W = vid.shape
    video = vid.reshape(1, TC//3, 3, H, W).to(cfg.device)

    img_list = []
    with torch.no_grad():
        image_emb = model.encode_long_video(video,[msg,],"")
        print("Shape of long video embeds: ", image_emb.shape)
#         image_emb, _ = model.encode_img(video, "")
    img_list.append(image_emb)

    chat = EasyDict({
    "system": "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail. ",
    "roles": ("[INST]", "[/INST]"),
    "messages": [],
    "sep": ""
    })

    # First turn carries the video placeholder; the second carries the timestamp message plus the question.
    chat.messages.append([chat.roles[0], "<Video><VideoHere></Video> [/INST]"])
    ask(msg+user_messages, chat)

    llm_answer = answer(conv=chat, model=model, do_sample=False, img_list=img_list, max_new_tokens=256, print_res=True)[0]
    print("LLM answer:", llm_answer,"\n\n\n")

    return llm_answer, chat, img_list

In [None]:
# Demo inputs. Each commented-out group below is an alternative video/question
# pair, some followed by the answer recorded for it; uncomment one pair to try it.

# vid_path = "./download/demo_video/ikun.mp4"
# question="At what time in the video does the man dribble the basketball?"
# question="What did the man do after throwing away the basketball?"


# vid_path = "./download/demo_video/sora.mp4"
# question = "When does the close-up of this woman's face appear?"
# LLM answer: From 37.8 to 59.9 seconds.
# question = "Describe this video in detail."
# LLM answer: A woman is seen walking down a street in the rain while wearing sunglasses and a leather jacket. She continues walking down the street and looking off into the distance.  


# vid_path = "./download/demo_video/rocket.mp4"
# question = "At what time in the video the rocket ignite and launch?"
# LLM answer: From 10.0 to 13.0 seconds.
# question = "Describe this video in detail."
# LLM answer: The video shows a rocket launching from a building and then a space station docking with the international space station. The rocket is white and red with a red star on it. The space station is white and gray with solar panels. The video is in Chinese and there are no people in the video.  


# vid_path = "./download/demo_video/movie.mp4"
# question = "When in the video do black men and white women hug each other?"
# LLM answer: 10.0 - 15.0 seconds
# question = "What are the two men in the video putting in the briefcase in the car?."
# LLM answer: Money.  


# vid_path = "./download/demo_video/legendof1900.mp4"
# question = "When in the video is the man in a white suit lighting a cigarette by the piano?"
# LLM answer: From 168.0 to 171.0 seconds. 
# question = "Why are the strings in the video hot enough to light a cigarette?"


# vid_path = "./download/demo_video/TheWanderingEarth2.mp4"
# question = "At what moment in the picture did the person in the red life jacket and camouflage uniform open the isolation door?"


vid_path = "./download/demo_video/mrbean.mp4"
question = "At what point does the man in the brown suit extend his hand to his deskmate?"


# Bind the results instead of leaving the call as the cell's last expression:
# otherwise the notebook echoes the whole returned tuple, including the video
# embedding tensor. The answer itself is already printed by generate_videochat,
# and `chat` / `img_list` stay available for follow-up cells.
llm_answer, chat, img_list = generate_videochat(vid_path, question)