<a href="https://colab.research.google.com/github/akash13s/ASL-Interpreter/blob/main/asl_to_text.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the packages this notebook needs: accelerate, bitsandbytes, the
# development version of transformers (straight from GitHub), and av.
# NOTE(review): none of these installs is version-pinned, so a fresh run may
# pick up different library versions than the ones this notebook was written against.
!pip install --upgrade -q accelerate bitsandbytes
!pip install git+https://github.com/huggingface/transformers.git
!pip install -q av

In [None]:
# Load IPython's autoreload extension and switch it to mode 2, so imported
# modules (e.g. the local `dataset` module) are re-imported before each cell
# runs and edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from transformers import BitsAndBytesConfig, LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
import torch
import av
import numpy as np

from dataset import VideoDataset, getDefaultTransform
from torch.utils.data import DataLoader

## Load and Set Up the Model

We load the model with 4-bit quantization enabled (float16 compute) to reduce its memory footprint.

In [None]:
# Single source of truth for the checkpoint, so the processor and the model
# are always loaded from the same place.
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"

# 4-bit weight loading through bitsandbytes, with float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

# Processor that accompanies the checkpoint; used later for the chat template,
# input preparation and decoding.
processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)

# Quantized model; device_map='auto' leaves device placement to the library.
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map='auto',
    quantization_config=quantization_config,
)

## Frame Sampling

To process a video, we uniformly sample 32 frames from it.

In [None]:
def sample_frames(frames, num_sample_frames=32):
  """Uniformly sample a fixed number of frames along dimension 1 of `frames`.

  The selected positions are floor(i * (num_frames - 1) / (num_sample_frames - 1))
  for i in 0..num_sample_frames-1, so the first and last frames are always kept.
  When the video has fewer frames than requested, frames are repeated.

  Args:
    frames: Tensor whose dimension 1 indexes the frames of each video. The
      notebook passes batches laid out as (batch_size, num_frames, C, H, W),
      but any number of trailing dimensions is accepted.
    num_sample_frames: How many frames to keep per video.

  Returns:
    A tensor equal to `frames` except that dimension 1 has size
    `num_sample_frames`.

  Raises:
    ValueError: If `num_sample_frames` is negative, or if frames are requested
      from a video that has no frames at all.
  """
  if num_sample_frames < 0:
    raise ValueError(f"num_sample_frames must be non-negative, got {num_sample_frames}")

  num_frames = frames.shape[1]
  if num_frames == 0 and num_sample_frames > 0:
    raise ValueError("cannot sample frames from a video with zero frames (frames.shape[1] == 0)")

  if num_sample_frames <= 1:
    # Zero samples -> empty selection; one sample -> the first frame.
    indices = torch.zeros(num_sample_frames, dtype=torch.long)
  else:
    # Integer arithmetic instead of truncating a float linspace: rounding error
    # in the float version can land just below an exact integer position and
    # silently select the previous frame.
    positions = torch.arange(num_sample_frames, dtype=torch.long)
    indices = torch.div(positions * (num_frames - 1), num_sample_frames - 1, rounding_mode="floor")

  # Index only dimension 1; all leading/trailing dimensions are carried through.
  return frames[:, indices]

## Process and Generate Using the Model

Here we apply the chat template to the prompt and generate the model's output.

In [None]:
def process_and_generate(model, processor, frames, prompt_text, generate_kwargs=None):
  """Ask the model the same question about one video or a batch of videos.

  Args:
    model: Object exposing `.device` and `.generate(**inputs, **options)`.
    processor: Object exposing `apply_chat_template`, `__call__` (input
      preparation) and `batch_decode`.
    frames: Either a single video or a batch of videos. A 5-dimensional input
      is treated as a batch and split along dimension 0 into one video per
      entry; anything else is passed through as a single video.
    prompt_text: Instruction placed in the user turn before the video.
    generate_kwargs: Optional dict of generation options that override the
      defaults (max_new_tokens=100, do_sample=True, top_p=0.9).

  Returns:
    Whatever `processor.batch_decode` returns for the generated sequences,
    one entry per video.
    NOTE(review): the full generated sequences are decoded, so the decoded
    text presumably still includes the prompt portion — confirm before
    comparing against ground truth.
  """
  # Single-turn conversation: the text instruction followed by a video placeholder.
  conversation = [
    {
      "role": "user",
      "content": [
        {"type": "text", "text": prompt_text},
        {"type": "video"},
      ],
    },
  ]

  # Render the conversation into the prompt string the model expects.
  prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

  # One prompt (with one video placeholder) must be paired with exactly one
  # video, so a batched 5-D input is unbatched into a list of videos and the
  # prompt is repeated for each of them.
  if getattr(frames, "ndim", None) == 5:
    videos = list(frames)
  else:
    videos = [frames]
  prompts = [prompt] * len(videos)

  # Keyword arguments keep this call independent of the processor's positional
  # parameter order; padding lets the prompts be stacked into one batch.
  inputs = processor(text=prompts, videos=videos, padding=True, return_tensors="pt").to(model.device)

  # Default generation settings, optionally overridden by the caller.
  generation_options = {"max_new_tokens": 100, "do_sample": True, "top_p": 0.9}
  if generate_kwargs:
    generation_options.update(generate_kwargs)

  # Generate and decode one sequence per video.
  output = model.generate(**inputs, **generation_options)
  generated_text = processor.batch_decode(output, skip_special_tokens=True)

  return generated_text

## Iterate Through the DataLoader for Batch Processing

Now we take a batch from our dataloader, sample frames from each video, and pass them through the model with several different prompts.


In [None]:
# NOTE(review): absolute, user-specific paths — adjust these for your environment.
video_dir = "/scratch/as18464/raw_videos/"
csv_file = "/scratch/rr4577/translation/train.csv"

# Seed PyTorch's RNG so that the shuffled batch drawn below and the sampled
# generations are repeatable when this cell (or the whole notebook) is re-run.
SEED = 42
torch.manual_seed(SEED)

transform = getDefaultTransform()

dataset = VideoDataset(video_dir=video_dir, csv_file=csv_file, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# Only the first batch is used here; this cell is a quick qualitative check,
# not a pass over the whole dataset.
batch = next(iter(dataloader))

# Get frames and sentences from batch
frames, sentences = batch

# Sample frames from each video in the batch
frames = sample_frames(frames)

# Try several prompts
prompts = [
  "Translate the ASL gestures in this video into English text.",
  "Provide a detailed translation of the ASL signing in this video.",
  "Convert the sign language in this video to text.",
]

for prompt in prompts:
  # Generate response from the model
  generated_text = process_and_generate(model, processor, frames, prompt)

  # Show the prompt too, so each output can be attributed to the prompt that produced it.
  print("Prompt:", prompt)
  print("Generated Text:", generated_text)
  print("Ground Truth Sentences:", sentences)