<a href="https://colab.research.google.com/github/akash13s/ASL-Interpreter/blob/llava-next/llava_next.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:
!pip install av numpy torch torchvision peft pandas tqdm
!pip install --upgrade bitsandbytes transformers



In [30]:
import av
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BitsAndBytesConfig, LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from tqdm import tqdm
import os
from typing import Tuple
import pandas as pd
from torch.nn.parallel import DataParallel

In [32]:
# Configure PyTorch's CUDA allocator through its environment variable.
os.environ.update({'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True'})

# Drop any cached GPU memory held by this kernel.
torch.cuda.empty_cache()

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [34]:
# --- Model identity ---
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
MODEL_NAME = MODEL_ID.rsplit("/", 1)[-1]  # last component of the hub id

# --- Data locations ---
VIDEO_DIR = "/scratch/mg7609/ASL-Interpreter/data/raw_videos"
CSV_FILE = "/scratch/mg7609/ASL-Interpreter/data/valid_clips.csv"
CACHE_DIR = "cache/"
DATASET_SIZE = 4  # number of CSV rows VideoDataset keeps

# --- Optimisation settings ---
BATCH_SIZE = 1
MAX_LENGTH = 3500  # max_length handed to the processor in train_epoch
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.05

# --- LoRA settings ---
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.1
LORA_TARGET_MODULES = [
    "q_proj", "v_proj", "k_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

In [91]:
# bitsandbytes settings used when the base model is loaded: 4-bit weights,
# float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(bnb_4bit_compute_dtype=torch.float16, load_in_4bit=True)

In [42]:
# Load the base video model from the hub (cached under CACHE_DIR) with the
# quantisation settings defined in the previous cell.
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir=CACHE_DIR,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Run PEFT's k-bit preparation helper on the quantized base model. The result
# (`p_model`) is what the LoRA adapters are attached to in the cells below.
p_model = prepare_model_for_kbit_training(model)

In [None]:
# LoRA adapter settings, all taken from the constants cell.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=LORA_TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)

In [None]:
# Wrap the prepared model with the LoRA adapters described by `peft_config`.
# NOTE(review): this rebinds `p_model` to a wrapper of itself, so re-running the
# cell without re-running the preparation cell above would wrap it again.
p_model = get_peft_model(p_model, peft_config)
# Print PEFT's summary of trainable parameters.
p_model.print_trainable_parameters()

In [None]:
# Move the PEFT-wrapped model to the default CUDA device.
# NOTE(review): the base model was loaded with device_map="auto" and 4-bit
# quantization — confirm that this extra .cuda() call is needed and supported.
p_model = p_model.cuda()
# When more than one GPU is visible, wrap the model in DataParallel.
# NOTE(review): device_map="auto" may already have placed parts of the model on
# several GPUs; verify that combining it with DataParallel behaves as intended.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    p_model = DataParallel(p_model)

In [None]:
# Hand the optimizer only the parameters that require gradients. Frozen weights
# would never be updated, so registering them only enlarges the optimizer's
# parameter groups (and the optimizer state saved in every checkpoint).
trainable_params = [param for param in p_model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"
# VideoDataset converts frames with ToTensor before they reach the processor,
# so the processor's own rescaling step is switched off.
processor.image_processor.do_rescale = False
processor.video_processor.do_rescale = False

In [92]:
def read_video_pyav(container, indices, frame_size=(224, 224)):
    '''
    Decode selected frames of a video with the PyAV decoder and resize them.

    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): Frame indices to decode, in ascending order
            (the first and last entries bound the decoding loop).
        frame_size (`Tuple[int, int]`): Size passed to `transforms.Resize`
            for every frame. Defaults to (224, 224).
    Returns:
        result (np.ndarray): uint8 RGB frames of shape
            (num_decoded_frames, height, width, 3). Duplicate or out-of-range
            indices yield fewer frames than `len(indices)`.
    Raises:
        ValueError: If none of the requested frames could be decoded.
    '''
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    # Set lookup instead of scanning the index array for every decoded frame.
    wanted = {int(index) for index in indices}

    # Build the transforms once per call rather than once per frame.
    to_pil = transforms.ToPILImage()
    resize = transforms.Resize(frame_size)

    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in wanted:
            # Decode to an RGB uint8 array, resize through PIL, and read the
            # pixels back directly as uint8 (no float round trip needed).
            frame_array = frame.to_ndarray(format="rgb24")
            resized_frame = resize(to_pil(frame_array))
            frames.append(np.asarray(resized_frame, dtype=np.uint8))

    if not frames:
        raise ValueError(
            f"No frames decoded for requested indices {start_index}..{end_index}."
        )

    return np.stack(frames)

def get_frames(video_path: str, num_frames: int = 8) -> np.ndarray:
    """
    Extract a fixed number of evenly spaced frames from a video.

    Args:
        video_path (str): Path to video file
        num_frames (int): Number of frames to extract (must be positive)
    Returns:
        np.ndarray: Frames as returned by `read_video_pyav`, stacked along
            the first axis, with exactly `num_frames` entries. When the video
            yields fewer distinct frames, the last one is repeated.
    Raises:
        ValueError: If `num_frames` is not positive or the video contains no
            decodable frames.
    """
    if num_frames <= 0:
        raise ValueError(f"num_frames must be positive, got {num_frames}.")

    # The context manager closes the container even if decoding raises.
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        total_frames = stream.frames
        if total_frames <= 0:
            # The container does not report a frame count; count by decoding.
            # read_video_pyav seeks back to the start before it decodes again.
            total_frames = sum(1 for _ in container.decode(video=0))
        if total_frames <= 0:
            raise ValueError(f"Video file '{video_path}' contains no decodable frames.")

        # Evenly spaced indices across the whole clip
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        frames = read_video_pyav(container, indices)

    missing = num_frames - len(frames)
    if missing > 0:
        # Pad with copies of the last frame in a single concatenation.
        padding = np.repeat(frames[-1:], missing, axis=0)
        frames = np.concatenate([frames, padding], axis=0)

    # Trim in the (unexpected) case that more frames came back than requested.
    return frames[:num_frames]

In [125]:
class VideoDataset(Dataset):
    """
    Pairs each annotated clip with its training prompt.

    Rows come from a CSV with `SENTENCE_NAME` (video file stem) and `SENTENCE`
    (target English translation) columns; only the first `DATASET_SIZE` rows
    are kept. Each item is `(prompt, frames)` where `prompt` already contains
    the target sentence and `frames` is a tensor stacked from `num_frames`
    frames converted with `transforms.ToTensor`.
    """

    # Instruction placed in the USER turn of every prompt.
    INSTRUCTION = "Analyze the American Sign Language (ASL) signs in this video and translate them into clear, natural English. Consider the sequence of signs as a complete message, and provide an accurate translation that captures the full meaning. Respond with only the English translation, without descriptions of the signs themselves."

    def __init__(self, video_dir: str, csv_file: str, num_frames: int = 8):
        """
        Args:
            video_dir (str): Directory containing `<SENTENCE_NAME>.mp4` files
            csv_file (str): Path to the annotation CSV
            num_frames (int): Frames sampled from every clip
        """
        self.video_dir = video_dir
        self.annotations = pd.read_csv(csv_file, sep=',').head(DATASET_SIZE).reset_index(drop=True)
        self.num_frames = num_frames
        print(f"Loaded dataset with {len(self.annotations)} entries")

    def __len__(self) -> int:
        """Number of annotated clips in the dataset."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Tuple[str, torch.Tensor]:
        """
        Build the prompt and frame tensor for one clip.

        Raises:
            FileNotFoundError: If the clip's .mp4 file is missing.
        """
        row = self.annotations.iloc[idx]
        video_id = str(row['SENTENCE_NAME']).strip()
        sentence = str(row['SENTENCE']).strip()

        video_path = os.path.join(self.video_dir, f"{video_id}.mp4")
        if not os.path.isfile(video_path):
            raise FileNotFoundError(f"Video file '{video_path}' not found.")

        frames = get_frames(video_path, self.num_frames)

        # The target sentence is part of the prompt so the model can be
        # trained on the full USER/ASSISTANT exchange.
        prompt = f"USER: <video> {self.INSTRUCTION}\nASSISTANT: Answer: {sentence}"

        # Convert each frame with ToTensor and stack them into one tensor.
        to_tensor = transforms.ToTensor()
        frame_tensor = torch.stack([to_tensor(frame) for frame in frames])

        return prompt, frame_tensor

In [None]:
# Create dataset and dataloader
def create_data_loader(video_dir, csv_file, batch_size, num_frames=8):
    """
    Build the shuffled DataLoader used for training.

    Args:
        video_dir: Directory holding the video files
        csv_file: Path to the annotation CSV
        batch_size: Number of clips per batch
        num_frames: Frames sampled from every clip
    Returns:
        DataLoader over a `VideoDataset`, loading in the main process.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Set to 0 for debugging
        pin_memory=False,
    )
    clips = VideoDataset(video_dir=video_dir, csv_file=csv_file, num_frames=num_frames)
    return DataLoader(clips, **loader_options)

In [None]:
# Training loader: 16 frames are sampled from every clip.
train_loader = create_data_loader(VIDEO_DIR, CSV_FILE, BATCH_SIZE, num_frames=16)

In [137]:
def _align_video_tokens(input_ids, attention_mask, labels, video_token_id,
                        expected_tokens, pad_token_id):
    """Give every row exactly `expected_tokens` `<video>` placeholder tokens.

    Rows with too few placeholders get the missing ones appended at the end
    (attended to, ignored by the loss). Rows with too many lose their trailing
    surplus placeholders. Rows are then right-padded back to a common length.
    Inputs are returned unchanged when every row already matches.
    """
    per_row_counts = (input_ids == video_token_id).sum(dim=1)
    if bool((per_row_counts == expected_tokens).all()):
        return input_ids, attention_mask, labels

    new_ids, new_masks, new_labels = [], [], []
    for row_ids, row_mask, row_labels in zip(input_ids, attention_mask, labels):
        is_video = row_ids == video_token_id
        surplus = int(is_video.sum()) - expected_tokens
        if surplus > 0:
            # Drop the last `surplus` placeholders; keep everything else.
            keep = torch.ones_like(is_video)
            video_positions = is_video.nonzero(as_tuple=True)[0]
            keep[video_positions[-surplus:]] = False
            row_ids, row_mask, row_labels = row_ids[keep], row_mask[keep], row_labels[keep]
        elif surplus < 0:
            missing = -surplus
            row_ids = torch.cat([row_ids, row_ids.new_full((missing,), video_token_id)])
            row_mask = torch.cat([row_mask, row_mask.new_ones(missing)])
            row_labels = torch.cat([row_labels, row_labels.new_full((missing,), -100)])
        new_ids.append(row_ids)
        new_masks.append(row_mask)
        new_labels.append(row_labels)

    pad_sequence = torch.nn.utils.rnn.pad_sequence
    fill_id = 0 if pad_token_id is None else pad_token_id
    return (
        pad_sequence(new_ids, batch_first=True, padding_value=fill_id),
        pad_sequence(new_masks, batch_first=True, padding_value=0),
        pad_sequence(new_labels, batch_first=True, padding_value=-100),
    )


def train_epoch(model, train_loader, optimizer, processor, device, epoch):
    """
    Train `model` for one pass over `train_loader` and save a checkpoint.

    Args:
        model: Model to train; may be wrapped in `DataParallel`.
        train_loader: Iterable of `(texts, videos)` batches. `videos` is
            indexed as (batch, frame, channel, height, width).
        optimizer: Optimizer stepping the model's parameters.
        processor: Processor that tokenises `texts` and preprocesses the
            frames; its tokenizer must know the `<video>` token.
        device: Device the model inputs are moved to.
        epoch: Epoch number, used in log output and the checkpoint name.
    Returns:
        float: Mean loss over the batches of this epoch.
    Raises:
        ValueError: If `train_loader` yields no batches.
    """
    model.train()
    tokenizer = processor.tokenizer
    video_token_id = tokenizer.convert_tokens_to_ids("<video>")

    total_loss = 0.0
    avg_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(train_loader, desc=f'Training Epoch {epoch}')

    for batch_idx, (texts, videos) in enumerate(progress_bar):
        # One list of channel-last numpy frames per clip in the batch.
        clips = [[frame.permute(1, 2, 0).numpy() for frame in clip] for clip in videos]

        inputs = processor(
            text=texts,
            videos=clips,
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt"
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        pixel_values_videos = inputs["pixel_values_videos"]

        # Padding positions do not contribute to the loss.
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100

        # Number of <video> placeholders each sample must carry.
        # NOTE(review): the "// 4" presumably reflects pooling of the patch
        # grid inside the model — confirm against the model configuration.
        frame_count = pixel_values_videos.shape[1]
        height, width = pixel_values_videos.shape[3], pixel_values_videos.shape[4]
        patch_size = processor.patch_size
        expected_tokens = frame_count * (height // patch_size) * (width // patch_size) // 4

        # NOTE(review): appending placeholders after the answer text is a
        # workaround carried over from the previous implementation; ideally
        # the processor itself would emit the right number of placeholders.
        input_ids, attention_mask, labels = _align_video_tokens(
            input_ids, attention_mask, labels,
            video_token_id, expected_tokens, tokenizer.pad_token_id,
        )

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        pixel_values_videos = pixel_values_videos.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            labels=labels
        )

        loss = outputs.loss

        # DataParallel returns one loss per replica.
        if isinstance(model, DataParallel):
            loss = loss.mean()

        loss.backward()
        optimizer.step()

        current_loss = loss.item()
        total_loss += current_loss
        num_batches = batch_idx + 1
        avg_loss = total_loss / num_batches

        progress_bar.set_postfix({
            'batch_loss': f'{current_loss:.4f}',
            'avg_loss': f'{avg_loss:.4f}'
        })

        # Print detailed loss information for each iteration
        print(f'Epoch {epoch} | Batch {batch_idx}/{len(train_loader)} | '
              f'Loss: {current_loss:.4f} | Avg Loss: {avg_loss:.4f}')

    if num_batches == 0:
        raise ValueError("train_loader produced no batches; nothing was trained.")

    # Save checkpoint after every epoch
    checkpoint_path = f"output/checkpoint_epoch_{epoch}"
    os.makedirs("output", exist_ok=True)  # Ensure output directory exists

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.module.state_dict() if isinstance(model, DataParallel) else model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss
    }

    print(f"Saving checkpoint for epoch {epoch} with average loss: {avg_loss:.4f}")
    torch.save(checkpoint, checkpoint_path)

    return avg_loss

In [None]:
# Four passes over the training data; each call also writes a checkpoint.
for i in range(4):
    train_epoch(
        model=p_model,
        train_loader=train_loader,
        optimizer=optimizer,
        processor=processor,
        device=device,
        epoch=i,
    )

Training Epoch 0:   0%|          | 0/4 [00:00<?, ?it/s]

Adjusting <video> tokens from 345 to 2304


Training Epoch 0:  25%|██▌       | 1/4 [00:37<01:51, 37.07s/it, batch_loss=4.0374, avg_loss=4.0374]

Epoch 0 | Batch 0/4 | Loss: 4.0374 | Avg Loss: 4.0374
Adjusting <video> tokens from 345 to 2304


Training Epoch 0:  50%|█████     | 2/4 [01:12<01:11, 35.83s/it, batch_loss=4.0029, avg_loss=4.0201]

Epoch 0 | Batch 1/4 | Loss: 4.0029 | Avg Loss: 4.0201
