# Import Libraries

In [32]:
!pip install av numpy torch torchvision peft pandas tqdm
!pip install --upgrade bitsandbytes transformers

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [33]:
import av
import numpy as np
import torch
import os
import pandas as pd

from tqdm import tqdm
from typing import Tuple
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DataParallel

from transformers import BitsAndBytesConfig, LlavaNextVideoForConditionalGeneration, AutoProcessor
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from accelerate import Accelerator, DistributedDataParallelKwargs

# Setup

In [34]:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [35]:
torch.cuda.empty_cache()

In [36]:
# Constants
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
MODEL_NAME = MODEL_ID.split("/")[-1]

In [37]:
# File/directory
VIDEO_DIR = "/scratch/as18464/raw_videos"
CSV_FILE = "../data/valid_clips.csv"
CACHE_DIR = "./cache/"
OUTPUT_DIR = "./output/"

In [38]:
# LoRA hyperparameters
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.1
LORA_TARGET_MODULES = [
    "q_proj",
    "v_proj",
    "k_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

In [39]:
# Quantization parameters
USE_QLORA = False
USE_8BIT = False
USE_DBL_QUANT = False

In [40]:
# model constants
DATASET_SIZE = 4
BATCH_SIZE = 1
MAX_LENGTH = 3500
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.05

In [41]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Quantizations


In [11]:
def get_8bit_qlora(use_double_quant: bool):
    return BitsAndBytesConfig(
        load_in_8bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=use_double_quant
    )


def get_4bit_qlora(use_double_quant: bool):
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=use_double_quant
    )


def get_lora():
    return BitsAndBytesConfig(
        load_in_8bit=False,
        load_in_4bit=False,
        llm_int8_threshold=0.5,  # Lower threshold for increased precision
        llm_int8_skip_modules=None,  # None if skipping is not needed
        llm_int8_enable_fp32_cpu_offload=False,
        llm_int8_has_fp16_weight=True,  # Use FP16 weights for better precision
        # Ensures highest precision in computations
        bnb_4bit_compute_dtype=torch.float16
    )


def get_bnb_config(use_qlora: bool, use_8bit: bool, use_double_quant: bool):
    if use_qlora:
        if use_8bit:
            return get_8bit_qlora(use_double_quant)

        return get_4bit_qlora(use_double_quant)

    return get_lora()

# Model

In [12]:
def get_base_model(model_id, bnb_config, cache_dir):
    return LlavaNextVideoForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        cache_dir=cache_dir,
        quantization_config=bnb_config,
        device_map="auto",
    )


def get_lora_config(lora_r, lora_alpha, lora_dropout):
    return LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=LORA_TARGET_MODULES,
        bias="none",
        task_type=TaskType.CAUSAL_LM
    )


def set_trainable_params(model):
    # First make sure all parameters are not trainable
    for param in model.parameters():
        param.requires_grad = False

    # Then enable training only for the LoRA parameters
    for name, param in model.named_parameters():
        if "lora_" in name:  # This targets only the LoRA layers
            param.requires_grad = True


def get_model(
        model_id: str,
        use_qlora: bool,
        use_8bit: bool,
        use_double_quant: bool,
        lora_r: int,
        lora_alpha: int,
        lora_dropout: float,
        cache_dir: str
):
    bnb_config = get_bnb_config(use_qlora, use_8bit, use_double_quant)
    model = get_base_model(model_id, bnb_config, cache_dir)
    lora_config = get_lora_config(lora_r, lora_alpha, lora_dropout)

    model = prepare_model_for_kbit_training(model)
    # model.gradient_checkpointing_enable()

    peft_model = get_peft_model(model, lora_config)
    set_trainable_params(peft_model)
    return peft_model


# Pre-processor

In [13]:
def read_video_pyav(container, indices):
    """
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    """
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]

    resize_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            # Convert to numpy array in RGB format
            frame_array = frame.to_ndarray(format="rgb24")
            # Apply resize transform and convert back to numpy
            resized_frame = resize_transform(frame_array).numpy()
            # Convert from CxHxW to HxWxC format and scale back to 0-255 range
            resized_frame = (resized_frame.transpose(1, 2, 0) * 255).astype(np.uint8)
            frames.append(resized_frame)

    return np.stack(frames)


def get_frames(video_path: str, num_frames: int = 8) -> np.ndarray:
    """
    Extract frames from video with consistent sampling
    Args:
        video_path (str): Path to video file
        num_frames (int): Number of frames to extract
    Returns:
        np.ndarray: Array of frames with shape (num_frames, height, width, 3)
    """
    container = av.open(video_path)

    # Get video stream
    stream = container.streams.video[0]
    total_frames = stream.frames
    fps = stream.average_rate

    # Calculate indices to sample
    indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

    # Read frames at calculated indices
    frames = read_video_pyav(container, indices)

    # Ensure we got exactly num_frames
    if len(frames) < num_frames:
        # If we got fewer frames, duplicate the last frame
        last_frame = frames[-1]
        while len(frames) < num_frames:
            frames = np.concatenate([frames, last_frame[np.newaxis, ...]], axis=0)
    elif len(frames) > num_frames:
        # If we got more frames, take the first num_frames
        frames = frames[:num_frames]

    container.close()
    return frames

# Dataset

In [14]:
class VideoDataset(Dataset):
    def __init__(self, video_dir: str, csv_file: str, dataset_size: int, num_frames: int = 8):
        self.video_dir = video_dir
        self.annotations = pd.read_csv(csv_file, sep=',').head(dataset_size).reset_index(drop=True)
        self.num_frames = num_frames
        self.system_prompt = ("Analyze the American Sign Language (ASL) signs in this video and "
                              "translate them into clear, natural English. Consider the sequence of "
                              "signs as a complete message, and provide an accurate translation that "
                              "captures the full meaning. Respond with only the English translation, "
                              "without descriptions of the signs themselves.")

        print(f"Loaded dataset with {len(self.annotations)} entries")

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Tuple[str, np.ndarray]:
        row = self.annotations.iloc[idx]
        video_id = str(row['SENTENCE_NAME']).strip()
        sentence = str(row['SENTENCE']).strip()

        video_path = os.path.join(self.video_dir, f"{video_id}.mp4")
        if not os.path.isfile(video_path):
            raise FileNotFoundError(f"Video file '{video_path}' not found.")

        frames = get_frames(video_path, self.num_frames)

        prompt = f"USER: {self.system_prompt}\n<video>\nASSISTANT: {sentence}"

        frames_list = [frame for frame in frames]
        frames_list = [transforms.ToTensor()(frame) for frame in frames_list]
        frame_tensor = torch.stack(frames_list)

        return prompt, frame_tensor


# Dataloader

In [15]:
def create_data_loader(
        video_dir: str,
        csv_file: str,
        batch_size: int,
        dataset_size: int,
        num_frames: int = 8
):
    dataset = VideoDataset(
        video_dir=video_dir,
        csv_file=csv_file,
        dataset_size=dataset_size,
        num_frames=num_frames
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Set to 0 for debugging
        pin_memory=False
    )

    return loader

# Trainer

In [16]:
def train_epoch(config, epoch):
    loss = None
    total_loss = 0
    avg_loss = 0

    model = config['model']
    optimizer = config['optimizer']
    train_loader = config['train_loader']
    processor = config['processor']
    accelerator = config['accelerator']
    output_dir = config['output_dir']

    model.train()
    progress_bar = tqdm(train_loader, desc=f'Training Epoch {epoch}')

    for batch_idx, (texts, videos) in enumerate(progress_bar):
        vids = list(torch.unbind(videos, dim=0))
        image_lists = []
        for batch in vids:
            images = [img.cpu().permute(1, 2, 0).numpy() for img in batch]
            image_lists.append(images)
        try:
            batch = processor(
                text=texts,
                videos=image_lists,
                padding=True,
                truncation=True,
                max_length=config.get('max_length'),
                return_tensors="pt"
            )

            labels = batch["input_ids"].clone()
            labels[labels == processor.tokenizer.pad_token_id] = -100

            for i, text in enumerate(texts):
                assistant_start = None
                # Look for sequence: "ASSISTANT:"
                for j in range(len(batch["input_ids"][i])):
                    if processor.tokenizer.decode(batch["input_ids"][i][j:j + 4]) == "ASSISTANT:":
                        assistant_start = j
                        break

                if assistant_start is not None:
                    # Mask everything before and including "ASSISTANT:"
                    labels[i, :assistant_start + 4] = -100

            # To remove later - for debugging
            # print("\n====== Tokens and Labels for Batch", batch_idx, "======")
            # for i, text in enumerate(texts):
            #     print(f"\nOriginal text {i}: {text}")
            #     print("\nTokens and their labels:")
            #     tokens = processor.tokenizer.convert_ids_to_tokens(batch["input_ids"][i])
            #     for j, (token, label) in enumerate(zip(tokens, labels[i])):
            #         print(f"Position {j:3d} | Token: {token:15} | Label: {label.item():5}")
            #     print("-" * 50)

            batch["labels"] = labels

            input_ids = accelerator.prepare(batch["input_ids"])
            attention_mask = accelerator.prepare(batch["attention_mask"])
            pixel_values_videos = accelerator.prepare(batch["pixel_values_videos"])
            labels = accelerator.prepare(batch["labels"])

            frame_count = pixel_values_videos.shape[1]
            height, width = pixel_values_videos.shape[3], pixel_values_videos.shape[4]
            n_video_tokens = (input_ids == processor.tokenizer.convert_tokens_to_ids("<video>")).sum(dim=1)
            expected_tokens = frame_count * (height // processor.patch_size) * (width // processor.patch_size) // 4
            token_diffs = expected_tokens - n_video_tokens
            
            # Adjust input_ids, attention_mask, and labels
            max_length = input_ids.size(1) + max(0, token_diffs.max().item())
            adjusted_input_ids = torch.full((input_ids.size(0), max_length), processor.tokenizer.pad_token_id, device=accelerator.device)
            adjusted_attention_mask = torch.zeros((input_ids.size(0), max_length), device=accelerator.device)
            adjusted_labels = torch.full((input_ids.size(0), max_length), -100, device=accelerator.device)
            
            for i in range(input_ids.size(0)):
                current_length = input_ids.size(1)
                diff = token_diffs[i].item()
            
                # Add tokens or truncate as needed
                if diff > 0:
                    # Add extra <video> tokens
                    adjusted_input_ids[i, :current_length] = input_ids[i]
                    adjusted_input_ids[i, current_length:current_length + diff] = processor.tokenizer.convert_tokens_to_ids("<video>")
                    adjusted_attention_mask[i, :current_length + diff] = attention_mask[i]
                    adjusted_labels[i, :current_length] = labels[i]
                else:
                    # Truncate tokens
                    adjusted_input_ids[i, :current_length + diff] = input_ids[i, :current_length + diff]
                    adjusted_attention_mask[i, :current_length + diff] = attention_mask[i, :current_length + diff]
                    adjusted_labels[i, :current_length + diff] = labels[i, :current_length + diff]
            
            # Replace original tensors with adjusted ones
            input_ids = adjusted_input_ids
            attention_mask = adjusted_attention_mask
            labels = adjusted_labels

            optimizer.zero_grad()

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values_videos=pixel_values_videos,
                labels=labels
            )
            loss = outputs.loss

            accelerator.backward(loss)

            # torch.nn.utils.clip_gradnorm(model.parameters(), 1.0)
            optimizer.step()

            current_loss = loss.item()
            total_loss += current_loss
            avg_loss = total_loss / (batch_idx + 1)

            progress_bar.set_postfix({
                'batch_loss': f'{current_loss:.4f}',
                'avg_loss': f'{avg_loss:.4f}'
            })

            if accelerator.is_main_process:
                print(f'Epoch {epoch} | Batch {batch_idx}/{len(train_loader)} | '
                      f'Loss: {current_loss:.4f} | Avg Loss: {avg_loss:.4f}')

        except Exception as e:
            raise e

    if accelerator.is_main_process: # and epoch % 5 == 0:
        checkpoint_path = f"{output_dir}/checkpoint_epoch_{epoch}"
        os.makedirs(f"{output_dir}", exist_ok=True)

        unwrapped_model = accelerator.unwrap_model(model)

        if hasattr(unwrapped_model, 'get_peft_state_dict'):
            state_dict = unwrapped_model.get_peft_state_dict()
        else:
            state_dict = unwrapped_model.state_dict()

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }

        print(f"Saving checkpoint for epoch {epoch} with average loss: {avg_loss:.4f}")
        torch.save(checkpoint, checkpoint_path)

    return total_loss / len(train_loader)

# Training

In [17]:
train_loader = create_data_loader(
    video_dir=VIDEO_DIR,
    csv_file=CSV_FILE,
    batch_size=BATCH_SIZE,
    dataset_size=DATASET_SIZE,
    num_frames=16
)

Loaded dataset with 4 entries


In [18]:
p_model = get_model(
    model_id=MODEL_ID,
    use_qlora=USE_QLORA,
    use_8bit=USE_8BIT,
    use_double_quant=USE_DBL_QUANT,
    lora_r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    cache_dir=CACHE_DIR
)

optimizer = torch.optim.AdamW(p_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# initialize the accelerator with the right kwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

# Prepare your model, optimizer, and dataloader with accelerator
p_model, optimizer, train_loader = accelerator.prepare(
    p_model, optimizer, train_loader
)

processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"
processor.image_processor.do_rescale = False
processor.video_processor.do_rescale = False

Downloading shards: 100%|██████████| 3/3 [05:36<00:00, 112.13s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:36<00:00, 12.16s/it]


In [19]:
config = {
    "model": p_model,
    "train_loader": train_loader,
    "optimizer": optimizer,
    "processor": processor,
    "accelerator": accelerator,
    "output_dir": OUTPUT_DIR,
    "max_length": MAX_LENGTH
}

In [20]:
for i in range(2):
    train_epoch(config, i + 1)

  return fn(*args, **kwargs)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Training Epoch 1:  25%|██▌       | 1/4 [00:19<00:59, 19.75s/it, batch_loss=8.8511, avg_loss=8.8511]

Epoch 1 | Batch 0/4 | Loss: 8.8511 | Avg Loss: 8.8511


Training Epoch 1:  50%|█████     | 2/4 [00:32<00:30, 15.39s/it, batch_loss=7.5054, avg_loss=8.1783]

Epoch 1 | Batch 1/4 | Loss: 7.5054 | Avg Loss: 8.1783


Training Epoch 1:  75%|███████▌  | 3/4 [00:44<00:14, 14.26s/it, batch_loss=6.5325, avg_loss=7.6297]

Epoch 1 | Batch 2/4 | Loss: 6.5325 | Avg Loss: 7.6297


Training Epoch 1: 100%|██████████| 4/4 [00:56<00:00, 14.07s/it, batch_loss=5.6880, avg_loss=7.1443]


Epoch 1 | Batch 3/4 | Loss: 5.6880 | Avg Loss: 7.1443
Saving checkpoint for epoch 1 with average loss: 7.1443


Training Epoch 2:  25%|██▌       | 1/4 [00:12<00:38, 12.86s/it, batch_loss=4.9410, avg_loss=4.9410]

Epoch 2 | Batch 0/4 | Loss: 4.9410 | Avg Loss: 4.9410


Training Epoch 2:  50%|█████     | 2/4 [00:25<00:25, 12.56s/it, batch_loss=4.1837, avg_loss=4.5623]

Epoch 2 | Batch 1/4 | Loss: 4.1837 | Avg Loss: 4.5623


Training Epoch 2:  75%|███████▌  | 3/4 [00:38<00:12, 12.74s/it, batch_loss=3.8051, avg_loss=4.3099]

Epoch 2 | Batch 2/4 | Loss: 3.8051 | Avg Loss: 4.3099


Training Epoch 2: 100%|██████████| 4/4 [00:49<00:00, 12.38s/it, batch_loss=3.5048, avg_loss=4.1087]


Epoch 2 | Batch 3/4 | Loss: 3.5048 | Avg Loss: 4.1087
Saving checkpoint for epoch 2 with average loss: 4.1087


# Inference

In [42]:
def load_trained_model(checkpoint_path):
    p_model = get_model(
        model_id=MODEL_ID,
        use_qlora=USE_QLORA,
        use_8bit=USE_8BIT,
        use_double_quant=USE_DBL_QUANT,
        lora_r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        cache_dir=CACHE_DIR
    )
    checkpoint = torch.load(checkpoint_path)

    # Load only the LoRA weights
    p_model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    return p_model, checkpoint['epoch'], checkpoint['loss']

In [43]:
# Usage:
accelerator = Accelerator()
model, epoch, loss = load_trained_model('./output/checkpoint_epoch_2')
model = accelerator.prepare(model)

Loading checkpoint shards: 100%|██████████| 3/3 [00:20<00:00,  6.79s/it]
  checkpoint = torch.load(checkpoint_path)


In [44]:
print(epoch)

2


In [45]:
print(loss)

tensor(3.5048, device='cuda:0', requires_grad=True)


In [49]:
def generate_for_single_video(model, processor, video_path, accelerator):
    # Set model to evaluation mode
    model.eval()

    # Get frames from the video
    frames = get_frames(video_path, num_frames=16)  # Using 16 frames as in training

    # Convert frames to tensor
    frames_list = [transforms.ToTensor()(frame) for frame in frames]
    frame_tensor = torch.stack(frames_list)

    # Convert to format expected by processor
    images = [img.permute(1, 2, 0).cpu().numpy() for img in frame_tensor]

    # Create prompt
    tmp_prompt = "Translate the sign language to english text."
    prompt = f"USER: <video> {tmp_prompt}\n ASSISTANT: Answer:"

    # Process inputs
    batch = processor(
        text=prompt,
        videos=[images],  # Wrap in list as processor expects batch
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )

    # Prepare inputs
    input_ids = accelerator.prepare(batch["input_ids"])
    attention_mask = accelerator.prepare(batch["attention_mask"])
    pixel_values_videos = accelerator.prepare(batch["pixel_values_videos"])

    frame_count = pixel_values_videos.shape[1]
    height, width = pixel_values_videos.shape[3], pixel_values_videos.shape[4]
    n_video_tokens = (input_ids == processor.tokenizer.convert_tokens_to_ids("<video>")).sum(dim=1)
    expected_tokens = frame_count * (height // processor.patch_size) * (width // processor.patch_size) // 4
    token_diffs = expected_tokens - n_video_tokens
    
    # Adjust input_ids and attention_mask
    max_length = input_ids.size(1) + max(0, token_diffs.max().item())
    adjusted_input_ids = torch.full((input_ids.size(0), max_length), processor.tokenizer.pad_token_id, device=accelerator.device)
    adjusted_attention_mask = torch.zeros((input_ids.size(0), max_length), device=accelerator.device)
    
    for i in range(input_ids.size(0)):
        current_length = input_ids.size(1)
        diff = token_diffs[i].item()
    
        # Add tokens or truncate as needed
        if diff > 0:
            # Add extra <video> tokens
            adjusted_input_ids[i, :current_length] = input_ids[i]
            adjusted_input_ids[i, current_length:current_length + diff] = processor.tokenizer.convert_tokens_to_ids("<video>")
            adjusted_attention_mask[i, :current_length + diff] = attention_mask[i]
        else:
            # Truncate tokens
            adjusted_input_ids[i, :current_length + diff] = input_ids[i, :current_length + diff]
            adjusted_attention_mask[i, :current_length + diff] = attention_mask[i, :current_length + diff]
    
    # Replace original tensors with adjusted ones
    input_ids = adjusted_input_ids
    attention_mask = adjusted_attention_mask
    
    # Generate text
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            max_length=3500
        )

    # Decode the generated text
    return processor.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

In [47]:
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"
processor.image_processor.do_rescale = False
processor.video_processor.do_rescale = False

In [None]:
video_path = '/scratch/as18464/raw_videos/--7E2sU6zP4_11-5-rgb_front.mp4'
generated_text = generate_for_single_video(model, processor, video_path, accelerator)

In [None]:
generated_text