In [20]:
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoModel, AutoImageProcessor

model_name = "DAMO-NLP-SG/VideoLLaMA3-2B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
video_path = "put your video path here"
image_path = "put your image path here"
question = "Describe this video in detail."

In [21]:
import torch
torch.cuda.empty_cache()

# Test Generation

In [29]:
# Video conversation
video_path = "/home/ubuntu/sample.mp4"
#image_path = "/home/ubuntu/45s.png"
image_path = "/home/ubuntu/graduation.png"


#sample question
#question = "What happens at time: 10 seconds?"


#image based QA
#image_caption = "A modern educational or institutional building with a green lawn, featuring a white facade and large glass windows under a clear blue sky. A police emblem watermark is visible in the upper right corner"
question = "What's his name"


conversation = [
    {"role": "system", "content": "You are a helpful and informed assistant."},
    {
        "role": "user",
        "content": [
            {"type": "video", "video": {"video_path": video_path, "fps": 1, "max_frames": 1, "start_time":50.0, "end_time":70.0}},
            {"type": "text", "text": question},
        ]
    },
]

inputs = processor(conversation=conversation, return_tensors="pt")
inputs = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
if "pixel_values" in inputs:
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
output_ids = model.generate(**inputs, max_new_tokens=256)
response = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(response)


The man in the video is not named.


# Extracting Past KV Cache

In [30]:
a = model.forward(**inputs, return_dict=True, use_cache=True)

In [31]:
inputs

{'input_ids': tensor([[151644,   8948,    198,   2610,    525,    264,  10950,    323,  15987,
           17847,     13, 151645,    198, 151644,    872,    198,   1462,    220,
              20,     15,     13,     15,     82,     25, 151665, 151665, 151665,
          151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
          151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
          151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
          151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
          151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
          151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
          151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
          151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
          151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665, 151665,
   

In [32]:
a.past_key_values[0][0].shape

torch.Size([1, 2, 330, 128])

In [33]:
inputs["pixel_values"].shape

torch.Size([1196, 588])

# Feeding in KV Cache

# Adding special Token

In [34]:
from transformers import AutoTokenizer
# Initialize the tokenizer
model_id = "DAMO-NLP-SG/VideoLLaMA3-2B"
tokenizer = AutoTokenizer.from_pretrained(model_id)

new_tokens = ["[MEMORY]"]
tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})

# Check if the tokens are in the vocabulary
for token in new_tokens:
    token_id = tokenizer.convert_tokens_to_ids(token)
    print(f"Token: {token}, Token ID: {token_id}")

Token: [MEMORY], Token ID: 151668


In [35]:
# Check that it works
test_str = "qwer qwer qwer [MEMORY][MEMORY]"
tokens = tokenizer.tokenize(test_str)
ids = tokenizer.convert_tokens_to_ids(tokens)
reconstruct_tokens = tokenizer.convert_ids_to_tokens(ids)

print(tokens)
print(ids)
print(reconstruct_tokens)

['q', 'wer', 'Ġq', 'wer', 'Ġq', 'wer', 'Ġ', '[MEMORY]', '[MEMORY]']
[80, 6566, 2804, 6566, 2804, 6566, 220, 151668, 151668]
['q', 'wer', 'Ġq', 'wer', 'Ġq', 'wer', 'Ġ', '[MEMORY]', '[MEMORY]']


In [36]:
tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

['<|im_start|>',
 'system',
 'Ċ',
 'You',
 'Ġare',
 'Ġa',
 'Ġhelpful',
 'Ġand',
 'Ġinformed',
 'Ġassistant',
 '.',
 '<|im_end|>',
 'Ċ',
 '<|im_start|>',
 'user',
 'Ċ',
 'Time',
 'Ġ',
 '5',
 '0',
 '.',
 '0',
 's',
 ':',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 '<image>',
 

# collate_fn 

In [37]:
from datasets import load_from_disk

hf_dataset = load_from_disk("/home/ubuntu/temp/sports_dataset")

Loading dataset from disk:   0%|          | 0/116 [00:00<?, ?it/s]

In [38]:
hf_dataset[0].keys()

dict_keys(['chunks', 'mp4'])

In [39]:
print(len(hf_dataset[0]['chunks']))
hf_dataset[0]['chunks'][0:4]

20


[{'activity': 'Police officers walk together on a sidewalk',
  'activity_interval': [2, 3],
  'curr_interval': [0, 5],
  'mp4_path': '/home/ubuntu/temp/mp4s/0.mp4'},
 {'activity': 'Police recruit #2 runs',
  'activity_interval': [7, 8],
  'curr_interval': [5, 10],
  'mp4_path': '/home/ubuntu/temp/mp4s/0.mp4'},
 {'activity': 'Police recruits practice shooting',
  'activity_interval': [9, 10],
  'curr_interval': [10, 15],
  'mp4_path': '/home/ubuntu/temp/mp4s/0.mp4'},
 {'activity': 'Police officers stand in formation',
  'activity_interval': [6, 8],
  'curr_interval': [15, 20],
  'mp4_path': '/home/ubuntu/temp/mp4s/0.mp4'}]

In [40]:
from torch.utils.data import DataLoader

dataloader_hf = DataLoader(hf_dataset, batch_size=1, shuffle=True)

batch = next(iter(dataloader_hf))

In [41]:
tokenizer.pad_token_id

151643

In [42]:
def collect_and_pad(batch, tokenizer):
    pad_token = tokenizer.pad_token_id
    max_len = max([len(b_item["input_ids"][0]) for b_item in batch])
    batch_dict = {"input_ids": [], "attention_mask": []}
    # input_ids
    for b_item in batch:
        # Pad the input_ids to the max length
        input_ids = b_item["input_ids"]
        padding_length = max_len - len(input_ids[0])
        input_ids = torch.cat([input_ids[0], pad_token*torch.ones(padding_length, dtype=torch.long)])
        batch_dict["input_ids"].append(input_ids)
    batch_dict["input_ids"] = torch.stack(batch_dict["input_ids"])
    # attention_mask
    for b_item in batch:
        # Pad the attention_mask to the max length
        attention_mask = b_item["attention_mask"]
        padding_length = max_len - len(attention_mask[0])
        attention_mask = torch.cat([attention_mask[0], torch.zeros(padding_length, dtype=torch.long)])
        batch_dict["attention_mask"].append(attention_mask)
    batch_dict["attention_mask"] = torch.stack(batch_dict["attention_mask"])

    # assum no padding required for image related items
    for key in ["pixel_values", "grid_sizes", "merge_sizes", "modals"]:
        if key in batch[0]:
            batch_dict[key] = batch[0][key]
        #     if(key != "modals"):
        #         batch_dict[key] = torch.stack([b_item[key] for b_item in batch])
        #     else:
        #         batch_dict[key] = [b_item[key] for b_item in batch]
    return batch_dict


def collate_fn(batch, tokenizer, num_memory_slots, max_chunks= 4, batchify=False, cuda=False):
    # batchify the data
    if batchify:
        batch_dict = {}
        batch_dict['mp4'] = [item['mp4'] for item in batch]
        batch_dict['chunks'] = [item['chunks'] for item in batch]
        batch = batch_dict
    # Extract the relevant fields from the batch
    chunks = batch['chunks']

    full_batch_dict = {}
    num_chunks = len(chunks) if len(chunks) < max_chunks else max_chunks
    b_size = len(chunks[0]['activity'])
    for t in range(num_chunks):
        # Tokenize the instructions, time steps, questions, and answers
        messages = [[{"role": "user", "content": 
                        [
                            {"type": "video", "video": {"video_path": chunks[t]['mp4_path'][i], "fps": 1, "max_frames": 100, 
                                                        "start_time":chunks[t]['curr_interval'][0][0], "end_time":chunks[t]['curr_interval'][1][0]}},
                            {"type": "text", "text": " "},
                        ]},
                    {"role": "assistant", "content": "[MEMORY]"*num_memory_slots}] for i in range(b_size)]
        messages = [processor(conversation=message, return_tensors="pt") for message in messages]
        messages = collect_and_pad(messages, tokenizer)

        qa = [[{"role": "user", "content": f"What happened during during T: {chunks[t]['activity_interval'][0][0]}-{chunks[t]['activity_interval'][1][0]}s"}, 
               {"role": "assistant", "content": chunks[t]['activity'][i]}] 
                    for i in range(b_size)]
        qa = [processor(conversation=message, return_tensors="pt") for message in qa]
        qa = collect_and_pad(qa, tokenizer)

        # # Move tokenized output to CUDA
        # mem_tokenized = {key: torch.tensor(value) for key, value in mem_tokenized.items()}
        # qa_tokenized = {key: torch.tensor(value) for key, value in qa_tokenized.items()}
        # labels = labels
        if(cuda):
            messages = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in messages.items()}
            if "pixel_values" in messages:
                messages["pixel_values"] = messages["pixel_values"].to(torch.bfloat16)
            qa = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in qa.items()}
        
        full_batch_dict[f"T{t}"] = {
            'memory_save': messages,
            'QA': qa,
            # 'labels': labels,
        }
    return full_batch_dict

In [43]:
batch_input = collate_fn(batch, tokenizer, 2, cuda=True)

In [44]:
batch_input["T0"]['memory_save']

{'input_ids': tensor([[151644,    872,    198,  ...,     60, 151645,    198]],
        device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'),
 'pixel_values': tensor([[ 0.1299,  0.2559,  0.2637,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.3105,  0.3105,  0.3418,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.2949,  0.4277,  0.4355,  ...,  1.0000,  1.0000,  1.0000],
         ...,
         [-1.0000, -1.0000, -1.0000,  ..., -0.5625, -0.5625, -0.5625],
         [ 0.1924,  0.2002,  0.1924,  ..., -0.4512, -0.4824, -0.4746],
         [-0.6250, -0.4434, -1.0000,  ..., -0.5312, -0.5781, -0.5781]],
        device='cuda:0', dtype=torch.bfloat16),
 'grid_sizes': tensor([[ 5, 26, 46]], device='cuda:0'),
 'merge_sizes': tensor([2], device='cuda:0'),
 'modals': ['video']}

In [48]:
batch_input["T0"]['memory_save']['input_ids'].shape, batch_input["T0"]['memory_save']['attention_mask'].shape, batch_input["T0"]['memory_save']['pixel_values'].shape

(torch.Size([1, 1551]), torch.Size([1, 1551]), torch.Size([5980, 588]))

In [53]:
a = model.forward(**batch_input["T0"]["QA"], return_dict=True, use_cache=True, past_key_values=a.past_key_values)

In [54]:
a

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 4.0000, -2.1406, -0.9297,  ..., -1.8125, -1.8125, -1.8125],
         [ 6.5312,  4.3438,  3.6094,  ..., -1.5469, -1.5469, -1.5469],
         [ 6.5938, 11.2500,  8.2500,  ..., -2.8750, -2.8750, -2.8750],
         ...,
         [13.2500,  6.9375,  3.1406,  ..., -0.3633, -0.3633, -0.3633],
         [ 9.0000,  6.8750,  6.9375,  ..., -2.3594, -2.3594, -2.3594],
         [ 6.6875,  7.7500,  8.1875,  ..., -1.2734, -1.2734, -1.2734]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>), past_key_values=DynamicCache(), hidden_states=None, attentions=None)

In [57]:
a.past_key_values 

DynamicCache()

In [52]:
batch_input["T0"]['QA']['input_ids'].shape

torch.Size([1, 25])

In [194]:
model

Videollama3Qwen2ForCausalLM(
  (model): Videollama3Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06

In [61]:
reconstruct_tokens = tokenizer.convert_ids_to_tokens(batch_input["T0"]['QA']['input_ids'][0])
reconstruct_tokens, batch_input["T0"]['QA']['input_ids'][0]

(['<|im_start|>',
  'user',
  'Ċ',
  'What',
  'Ġhappened',
  'Ġduring',
  'Ġduring',
  'ĠT',
  ':',
  'Ġ',
  '0',
  '-',
  '1',
  's',
  '<|im_end|>',
  'Ċ',
  '<|im_start|>',
  'assistant',
  'Ċ',
  'Police',
  'Ġrecruits',
  'Ġpractice',
  'Ġcombat',
  '<|im_end|>',
  'Ċ'],
 tensor([151644,    872,    198,   3838,   6932,   2337,   2337,    350,     25,
            220,     15,     12,     16,     82, 151645,    198, 151644,  77091,
            198,  22202,  55097,   6588,  12610, 151645,    198],
        device='cuda:0'))

In [63]:
tokenizer.convert_ids_to_tokens([77091]), tokenizer.convert_tokens_to_ids("assistant")

(['assistant'], 77091)