## Inference

Python notebook to run inference on the trained model

In [None]:
import os

import torch
from accelerate import Accelerator
from torchvision import transforms
from transformers import AutoProcessor

from components.model import get_model
from components.pre_processor import get_frames

In [2]:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
torch.cuda.empty_cache()

In [None]:
# Constants
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"

In [None]:
# File/directory
CACHE_DIR = "../cache/"

In [None]:
# Quantization parameters
USE_QLORA = False
USE_8BIT = False
USE_DBL_QUANT = False

In [None]:
# LoRA hyperparameters
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.1

In [None]:
# model constants
MAX_LENGTH = 3500

In [None]:
def load_trained_model(checkpoint_path):
    p_model = get_model(
        model_id=MODEL_ID,
        use_qlora=USE_QLORA,
        use_8bit=USE_8BIT,
        use_double_quant=USE_DBL_QUANT,
        lora_r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        cache_dir=CACHE_DIR
    )
    checkpoint = torch.load(checkpoint_path)

    # Load only the LoRA weights
    p_model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    return p_model, checkpoint['epoch'], checkpoint['loss']

In [14]:
# Usage:
accelerator = Accelerator()
model, epoch, loss = load_trained_model('./output/checkpoint_epoch_20')
model = accelerator.prepare(model)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  checkpoint = torch.load(checkpoint_path)


In [15]:
print(epoch)

20


In [16]:
print(loss)

tensor(0.0271, requires_grad=True)


In [42]:
def generate_for_single_video(model, processor, video_path, accelerator):
    # Set model to evaluation mode
    model.eval()

    # Get frames from the video
    frames = get_frames(video_path, num_frames=16)  # Using 16 frames as in training

    # Convert frames to tensor
    frames_list = [transforms.ToTensor()(frame) for frame in frames]
    frame_tensor = torch.stack(frames_list)

    # Convert to format expected by processor
    images = [img.permute(1, 2, 0).cpu().numpy() for img in frame_tensor]

    # Create prompt
    tmp_prompt = "Translate the sign language to english text."
    prompt = f"USER: <video> {tmp_prompt}\n ASSISTANT: Answer:"

    # Process inputs
    batch = processor(
        text=prompt,
        videos=[images],  # Wrap in list as processor expects batch
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )

    # Prepare inputs
    input_ids = accelerator.prepare(batch["input_ids"])
    attention_mask = accelerator.prepare(batch["attention_mask"])
    pixel_values_videos = accelerator.prepare(batch["pixel_values_videos"])

    frame_count = pixel_values_videos.shape[1]
    height, width = pixel_values_videos.shape[3], pixel_values_videos.shape[4]
    n_video_tokens = (input_ids == processor.tokenizer.convert_tokens_to_ids("<video>")).sum(dim=1)
    expected_tokens = frame_count * (height // processor.patch_size) * (width // processor.patch_size) // 4
    token_diffs = expected_tokens - n_video_tokens
    
    # Adjust input_ids and attention_mask
    max_length = input_ids.size(1) + max(0, token_diffs.max().item())
    adjusted_input_ids = torch.full((input_ids.size(0), max_length), processor.tokenizer.pad_token_id, device=accelerator.device)
    adjusted_attention_mask = torch.zeros((input_ids.size(0), max_length), device=accelerator.device)
    
    for i in range(input_ids.size(0)):
        current_length = input_ids.size(1)
        diff = token_diffs[i].item()
    
        # Add tokens or truncate as needed
        if diff > 0:
            # Add extra <video> tokens
            adjusted_input_ids[i, :current_length] = input_ids[i]
            adjusted_input_ids[i, current_length:current_length + diff] = processor.tokenizer.convert_tokens_to_ids("<video>")
            adjusted_attention_mask[i, :current_length + diff] = attention_mask[i]
        else:
            # Truncate tokens
            adjusted_input_ids[i, :current_length + diff] = input_ids[i, :current_length + diff]
            adjusted_attention_mask[i, :current_length + diff] = attention_mask[i, :current_length + diff]
    
    # Replace original tensors with adjusted ones
    input_ids = adjusted_input_ids
    attention_mask = adjusted_attention_mask

    # Generate text
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            max_length=200
        )

    # Decode the generated text
    return processor.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

In [43]:
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"
processor.image_processor.do_rescale = False

In [48]:
video_path = '/scratch/as18464/raw_videos/-06_nJnhORg_3-5-rgb_front.mp4'
generated_text = generate_for_single_video(model, processor, video_path, accelerator)

In [49]:
generated_text

['USER:  Translate the sign language to english text.\n ASSISTANT: Answer: Gener']