In [None]:
import torch
from internvl2 import InternVLChatModel

In [None]:
# Checkpoint identifier handed to from_pretrained; the tokenizer cell reuses it.
name = "OpenGVLab/InternVL2-2B"
# Load the chat model with bfloat16 weights and place it on GPU index 7.
model = InternVLChatModel.from_pretrained(
    name,
    torch_dtype=torch.bfloat16,
    device_map='cuda:7',
)

In [None]:
from transformers import AutoTokenizer
# Tokenizer for the same checkpoint as the model (`name`).
# NOTE(review): trust_remote_code=True allows tokenizer code shipped with the
# checkpoint to be executed locally — confirm the checkpoint source is trusted.
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)

In [None]:
# Text-only sanity check: the pixel-values argument of chat() is None here.
question = 'Hello, who are you?'
# Generation settings; also reused by the later chat / batch_chat cells.
generation_config = {'max_new_tokens': 1024, 'do_sample': True}
response, history = model.chat(
    tokenizer,
    None,
    question,
    generation_config,
    history=None,
    return_history=True,
)
print(f'User: {question}\nAssistant: {response}')

In [None]:
# Tokenize a chat-turn boundary string to inspect the ids and mask it produces.
text = '<|im_end|><|im_start|>assistant\n'
# Last expression of the cell, so the notebook displays the returned encoding.
tokenizer(text, return_tensors='pt')

In [None]:
from data import load_video
# Read the clip through load_video, then cast to bfloat16 and move the tensor to
# the device the chat model lives on.
video_path = 'localdata/red-panda.mp4'
pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
pixel_values = pixel_values.to(torch.bfloat16)
pixel_values = pixel_values.to(model.device)

# One "FrameN: <image>" line per entry of num_patches_list, so the prompt reads
# Frame1: <image>\nFrame2: <image>\n...\nFrame8: <image>\n{question}
video_prefix = ''.join(f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list)))
question = video_prefix + 'What is the red panda doing?'

response, history = model.chat(
    tokenizer,
    pixel_values,
    question,
    generation_config,
    num_patches_list=num_patches_list,
    history=None,
    return_history=True,
)
print(f'User: {question}\nAssistant: {response}')

In [None]:
# Duplicate the loaded frames along dim 0 so the tensor carries one copy per
# prompt for the two-question batch_chat call further down.
# NOTE: this rebinds `pixel_values`, so running the cell again doubles the tensor
# once more; re-run the video-loading cell first to reset it.
pixel_values = pixel_values.repeat(2, 1, 1, 1)
# Only pixel_values exists at this point on a fresh kernel. input_ids and
# attention_mask are first created in the reward-model section, so printing them
# here raised NameError under Restart & Run All.
print(pixel_values.shape)

In [None]:
# Show num_patches_list exactly as load_video returned it for the clip.
print(num_patches_list)

In [None]:
# Ask two questions in a single batched call. pixel_values was repeated twice in
# an earlier cell, so it holds one copy of the frames per question.
# NOTE(review): num_patches_list=[8, 8] is hard-coded; presumably 8 patches belong
# to each question — verify against batch_chat and load_video(num_segments=8).
model.batch_chat(
    tokenizer,
    pixel_values,
    [
        'What is the red panda doing?',
        'What is the red panda eating?',
    ],
    generation_config,
    num_patches_list=[8, 8],
)

# Test MoE Video Judge

In [None]:
from moe_reward import InternVLChatRewardModeling, InternVLChatRewardModelingConfig
from transformers import AutoTokenizer
from internvl2 import InternVLChatModel, InternVLChatConfig, prepare_chat_input
import torch

from torch import distributed as dist
import os

# Describe a one-process "distributed" job through the environment variables that
# torch.distributed reads, so the process group can be created inside the notebook.
os.environ.update({
    'WORLD_SIZE': str(1),
    'MASTER_ADDR': '127.0.0.1',
    'MASTER_PORT': str(12345),
    'LOCAL_RANK': str(0),
    'RANK': str(0),
})

# init_process_group raises if the default group already exists, which is what
# happens when this cell is re-run in the same kernel. Initialise only once.
if not dist.is_initialized():
    dist.init_process_group(backend='nccl', world_size=1, rank=0)

name = "OpenGVLab/InternVL2-2B"
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)

# Reward-model configuration built on top of the base checkpoint's config.
# aspect2criteria has one key per aspect (3 keys, matching num_aspects) and its
# values together list indices 0..9 (10 indices, matching num_objectives).
# NOTE(review): presumably each aspect aggregates the listed objective heads —
# confirm in moe_reward.
config = InternVLChatRewardModelingConfig.from_pretrained(
    name,
    pad_token_id=tokenizer.pad_token_id,
    num_objectives=10,
    num_aspects=3,
    aspect2criteria={
        0: [0, 1, 2],
        1: [3, 4, 5],
        2: [6, 7, 8, 9],
    },
    gating_temperature=1.0,
    gating_hidden_dim=1024,
    gating_n_hidden=3,
)


In [2]:
# Build the reward model from the checkpoint name and the config defined above.
# NOTE(review): this rebinds `model`, the same name the earlier cells use for the
# InternVLChatModel instance; after this line those cells see the reward model.
model = InternVLChatRewardModeling(name=name, config=config)
# Cast the parameters to bfloat16, then move the model to GPU index 7.
model = model.to(torch.bfloat16).to('cuda:7')

In [None]:
# Look up the tokenizer id of the '<IMG_CONTEXT>' token and store it on the
# wrapped model (model.model), then print it for inspection.
# NOTE(review): presumably the wrapped model uses this id to find image
# placeholder positions in input_ids — confirm in internvl2.
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'
model.model.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
print(model.model.img_context_token_id)

In [None]:
from data import load_video
import torch

def pad_to_batch(pad_token_id, input_ids_list: list, attention_mask_list: list, pixel_values_list: list):
    """Right-pad per-sample token tensors to a common length and concatenate them.

    Every ``input_ids`` / ``attention_mask`` tensor is extended along its last
    dimension up to the longest ``input_ids`` length, then all samples are
    concatenated along dim 0. ``pixel_values_list`` is concatenated along dim 0
    without padding. The input lists are not modified.

    Args:
        pad_token_id: Value written into the padded positions of ``input_ids``.
        input_ids_list: 2-D tensors ``(rows, seq_len)``; ``seq_len`` may differ
            between entries.
        attention_mask_list: One 2-D mask per entry of ``input_ids_list``;
            padded positions are filled with 0.
        pixel_values_list: Tensors concatenated along dim 0 as-is.

    Returns:
        Tuple ``(input_ids, attention_mask, pixel_values)`` of batched tensors.

    Raises:
        ValueError: If ``input_ids_list`` is empty or the two token lists have
            different lengths.
    """
    if len(input_ids_list) == 0:
        raise ValueError('pad_to_batch() needs at least one sequence to pad')
    if len(attention_mask_list) != len(input_ids_list):
        raise ValueError(
            f'got {len(input_ids_list)} input_ids tensors but '
            f'{len(attention_mask_list)} attention masks'
        )

    max_len = max(ids.shape[-1] for ids in input_ids_list)

    # Collect padded tensors in fresh lists so the caller's lists stay untouched.
    padded_ids = []
    padded_masks = []
    for ids, mask in zip(input_ids_list, attention_mask_list):
        # new_full / new_zeros inherit dtype and device from the source tensor.
        ids_fill = ids.new_full((ids.shape[0], max_len - ids.shape[-1]), pad_token_id)
        mask_fill = mask.new_zeros((mask.shape[0], max_len - mask.shape[-1]))
        padded_ids.append(torch.cat([ids, ids_fill], dim=-1))
        padded_masks.append(torch.cat([mask, mask_fill], dim=-1))

    return (
        torch.cat(padded_ids, dim=0),
        torch.cat(padded_masks, dim=0),
        torch.cat(pixel_values_list, dim=0),
    )
    

# Build a two-sample batch for the reward model from one clip and two prompts
# of different length, so that padding is actually exercised.
generation_config = dict(max_new_tokens=1024, do_sample=True)
video_path = 'localdata/red-panda.mp4'
pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
pixel_values = pixel_values.to(torch.bfloat16).to(model.model.device)
video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))])
question1 = video_prefix + 'What is the red panda doing?'
question2 = video_prefix + 'What is the red panda eating? and what is the red panda doing? and how many red pandas are there?'

# Tokenize each prompt separately; both use the same clip tensor.
input_ids1, attention_mask1 = prepare_chat_input(config, tokenizer, pixel_values, question1, generation_config, device=model.model.device)
input_ids2, attention_mask2 = prepare_chat_input(config, tokenizer, pixel_values, question2, generation_config, device=model.model.device)

# Pad to a common length and stack. pixel_values is rebound to the batched
# tensor (the clip once per prompt), which the forward cell below consumes.
input_ids, attention_mask, pixel_values = pad_to_batch(tokenizer.pad_token_id, [input_ids1, input_ids2], [attention_mask1, attention_mask2], [pixel_values, pixel_values])
print(input_ids1.shape, input_ids2.shape)
print(input_ids.shape, attention_mask.shape, pixel_values.shape)

# Inspect the same trailing window of both tensors so padded ids can be checked
# against the zeros in the mask. The mask was previously sliced with [:, :-10],
# which showed a different region than the ids.
print(input_ids[:, -19:], attention_mask[:, -19:])

In [None]:
# Run the reward model on the padded batch. The module is called directly rather
# than through .forward(), because calling the instance runs registered hooks
# while .forward() silently skips them.
# NOTE(review): assumes InternVLChatRewardModeling is a torch.nn.Module (it is
# moved with .to(...) when constructed) — confirm.
outs = model(pixel_values, input_ids, attention_mask)