In [None]:
import av
import bisect
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, BitsAndBytesConfig, VideoLlavaForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import json
import torch.nn.utils.prune as prune
from tqdm import tqdm
import os
from typing import Tuple, Any
import pandas as pd
from torch.cuda.amp import autocast
# from torch.nn.parallel import DataParallel

from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs

In [2]:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [3]:
torch.cuda.empty_cache()

In [4]:
# Constants
MODEL_ID = "LanguageBind/Video-LLaVA-7B-hf"
MODEL_NAME = MODEL_ID.split("/")[-1]

In [5]:
# File/directory
VIDEO_DIR = "/scratch/as18464/raw_videos"
CSV_FILE = "valid_clips.csv"
CACHE_DIR = "cache/"
DATASET_SIZE = 4

In [6]:
# LoRA hyperparameters
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.1
LORA_TARGET_MODULES = [
    "q_proj",
    "v_proj",
    "k_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

In [7]:
# model constants
BATCH_SIZE = 4
MAX_LENGTH = 128
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.05

In [8]:
def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]

    resize_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            # Convert to numpy array in RGB format
            frame_array = frame.to_ndarray(format="rgb24")
            # Apply resize transform and convert back to numpy
            resized_frame = resize_transform(frame_array).numpy()
            # Convert from CxHxW to HxWxC format and scale back to 0-255 range
            resized_frame = (resized_frame.transpose(1, 2, 0) * 255).astype(np.uint8)
            frames.append(resized_frame)
    
    return np.stack(frames)

def get_frames(video_path: str, num_frames: int = 8) -> np.ndarray:
    """
    Extract frames from video with consistent sampling
    Args:
        video_path (str): Path to video file
        num_frames (int): Number of frames to extract
    Returns:
        np.ndarray: Array of frames with shape (num_frames, height, width, 3)
    """
    container = av.open(video_path)
    
    # Get video stream
    stream = container.streams.video[0]
    total_frames = stream.frames
    fps = stream.average_rate
    
    # Calculate indices to sample
    indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    
    # Read frames at calculated indices
    frames = read_video_pyav(container, indices)
    
    # Ensure we got exactly num_frames
    if len(frames) < num_frames:
        # If we got fewer frames, duplicate the last frame
        last_frame = frames[-1]
        while len(frames) < num_frames:
            frames = np.concatenate([frames, last_frame[np.newaxis, ...]], axis=0)
    elif len(frames) > num_frames:
        # If we got more frames, take the first num_frames
        frames = frames[:num_frames]
    
    container.close()
    return frames

In [9]:
class VideoDataset(Dataset):
    def __init__(self, video_dir: str, csv_file: str, num_frames: int = 8):
        self.video_dir = video_dir
        self.annotations = pd.read_csv(csv_file, sep=',').head(DATASET_SIZE).reset_index(drop=True)
        self.num_frames = num_frames
        self.system_prompt = ("Analyze the American Sign Language (ASL) signs in this video and "
                            "translate them into clear, natural English. Consider the sequence of "
                            "signs as a complete message, and provide an accurate translation that "
                            "captures the full meaning. Respond with only the English translation, "
                            "without descriptions of the signs themselves.")
        
        print(f"Loaded dataset with {len(self.annotations)} entries")
    
    def __len__(self) -> int:
        return len(self.annotations)
    
    def __getitem__(self, idx: int) -> Tuple[str, np.ndarray]:
        row = self.annotations.iloc[idx]
        video_id = str(row['SENTENCE_NAME']).strip()
        sentence = str(row['SENTENCE']).strip()
        
        video_path = os.path.join(self.video_dir, f"{video_id}.mp4")
        if not os.path.isfile(video_path):
            raise FileNotFoundError(f"Video file '{video_path}' not found.")
        
        frames = get_frames(video_path, self.num_frames)
        
        prompt = f"USER: {self.system_prompt}\n<video>\nASSISTANT: {sentence}"

        frames_list = [frame for frame in frames]
        frames_list = [transforms.ToTensor()(frame) for frame in frames_list]
        frame_tensor = torch.stack(frames_list)
        
        return prompt, frame_tensor

In [42]:
def train_epoch(model, train_loader, optimizer, processor, accelerator, epoch):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f'Training Epoch {epoch}')
    
    for batch_idx, (texts, videos) in enumerate(progress_bar):
        vids = list(torch.unbind(videos, dim=0))
        image_lists = []
        for batch in vids:
            images = [img.cpu().permute(1, 2, 0).numpy() for img in batch]
            image_lists.append(images)
        try:
            batch = processor(
                text=texts,
                videos=image_lists,
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
                return_tensors="pt",
            )
            
            labels = batch["input_ids"].clone()
            labels[labels == processor.tokenizer.pad_token_id] = -100

            for i, text in enumerate(texts):
                assistant_start = None
                # Look for sequence: "ASSISTANT:"
                for j in range(len(batch["input_ids"][i])):
                    if processor.tokenizer.decode(batch["input_ids"][i][j:j+4]) == "ASSISTANT:":
                        assistant_start = j
                        break
                
                if assistant_start is not None:
                    # Mask everything before and including "ASSISTANT:"
                    labels[i, :assistant_start+4] = -100

            # To remove later - for debugging
            # print("\n====== Tokens and Labels for Batch", batch_idx, "======")
            for i, text in enumerate(texts):
                print(f"\nOriginal text {i}: {text}")
                print("\nTokens and their labels:")
                tokens = processor.tokenizer.convert_ids_to_tokens(batch["input_ids"][i])
                for j, (token, label) in enumerate(zip(tokens, labels[i])):
                    print(f"Position {j:3d} | Token: {token:15} | Label: {label.item():5}")
                print("-" * 50)
            
            batch["labels"] = labels
            
            input_ids = accelerator.prepare(batch["input_ids"])
            attention_mask = accelerator.prepare(batch["attention_mask"])
            pixel_values_videos = accelerator.prepare(batch["pixel_values_videos"])
            labels = accelerator.prepare(batch["labels"])
            
            optimizer.zero_grad()
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values_videos=pixel_values_videos,
                labels=labels
            )
            loss = outputs.loss

            accelerator.backward(loss)
            
            # torch.nn.utils.clip_gradnorm(model.parameters(), 1.0)
            optimizer.step()

            current_loss = loss.item()
            total_loss += current_loss
            avg_loss = total_loss / (batch_idx + 1)

            progress_bar.set_postfix({
                'batch_loss': f'{current_loss:.4f}',
                'avg_loss': f'{avg_loss:.4f}'
            })

            if accelerator.is_main_process:
                print(f'Epoch {epoch} | Batch {batch_idx}/{len(train_loader)} | '
                      f'Loss: {current_loss:.4f} | Avg Loss: {avg_loss:.4f}')
            
        except Exception as e:
            raise e

    if accelerator.is_main_process and epoch%5 == 0:
        checkpoint_path = f"output/checkpoint_epoch_{epoch}"
        os.makedirs("output", exist_ok=True)
        
        unwrapped_model = accelerator.unwrap_model(model)
        
        if hasattr(unwrapped_model, 'get_peft_state_dict'):
            state_dict = unwrapped_model.get_peft_state_dict()
        else:
            state_dict = unwrapped_model.state_dict()
    
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }
        
        print(f"Saving checkpoint for epoch {epoch} with average loss: {avg_loss:.4f}")
        torch.save(checkpoint, checkpoint_path)
    
    return total_loss / len(train_loader)

In [27]:
# Create dataset and dataloader
def create_data_loader(video_dir, csv_file, batch_size, num_frames=8):
    dataset = VideoDataset(
        video_dir=video_dir,
        csv_file=csv_file,
        num_frames=num_frames
    )
    
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Set to 0 for debugging
        pin_memory=False
    )
    
    return loader

In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [29]:
train_loader = create_data_loader(
    video_dir=VIDEO_DIR,
    csv_file=CSV_FILE,
    batch_size=BATCH_SIZE,
    num_frames=16
)

Loaded dataset with 4 entries


In [30]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

In [31]:
model = VideoLlavaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir=CACHE_DIR,
    quantization_config=quantization_config
)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [32]:
p_model = prepare_model_for_kbit_training(model)

In [33]:
# Configure LoRA
peft_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

In [34]:
# Get PEFT model
p_model = get_peft_model(p_model, peft_config)
p_model.print_trainable_parameters()

trainable params: 22,347,776 || all params: 7,388,626,944 || trainable%: 0.3025


In [35]:
def set_trainable_params(model):
    # First make sure all parameters are not trainable
    for param in model.parameters():
        param.requires_grad = False
    
    # Then enable training only for the LoRA parameters
    for name, param in model.named_parameters():
        if "lora_" in name:  # This targets only the LoRA layers
            param.requires_grad = True

In [36]:
set_trainable_params(p_model)

In [37]:
optimizer = torch.optim.AdamW(p_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [38]:
# initialize the accelerator with the right kwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

In [39]:
# Prepare your model, optimizer, and dataloader with accelerator
p_model, optimizer, train_loader = accelerator.prepare(
    p_model, optimizer, train_loader
)

In [40]:
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"
processor.image_processor.do_rescale = False

In [43]:
for i in range(10):
    train_epoch(p_model, train_loader, optimizer, processor, accelerator, i+1)

Training Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]


Original text 0: USER: Analyze the American Sign Language (ASL) signs in this video and translate them into clear, natural English. Consider the sequence of signs as a complete message, and provide an accurate translation that captures the full meaning. Respond with only the English translation, without descriptions of the signs themselves.
<video>
ASSISTANT: This is all the you know, take off on the idea of the acanthus leaf.

Tokens and their labels:
Position   0 | Token: <s>             | Label:  -100
Position   1 | Token: ▁US             | Label:  -100
Position   2 | Token: ER              | Label:  -100
Position   3 | Token: :               | Label:  -100
Position   4 | Token: ▁Anal           | Label:  -100
Position   5 | Token: y               | Label:  -100
Position   6 | Token: ze              | Label:  -100
Position   7 | Token: ▁the            | Label:  -100
Position   8 | Token: ▁American       | Label:  -100
Position   9 | Token: ▁Sign           | Label:  -100
Position  10

Training Epoch 1:   0%|          | 0/1 [00:52<?, ?it/s]


KeyboardInterrupt: 

### Inference step

In [12]:
def load_trained_model(checkpoint_path, accelerator):
    # First create base model
    model = VideoLlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        cache_dir=CACHE_DIR,
        quantization_config=quantization_config
    )
    
    # Prepare model with LoRA (same as during training)
    p_model = prepare_model_for_kbit_training(model)
    
    # Configure LoRA
    peft_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=LORA_TARGET_MODULES,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type=TaskType.CAUSAL_LM
    )
    
    # Apply LoRA
    p_model = get_peft_model(p_model, peft_config)
    
    checkpoint = torch.load(checkpoint_path)
    
    # Load only the LoRA weights
    p_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    
    return p_model, checkpoint['epoch'], checkpoint['loss']

In [14]:
# Usage:
accelerator = Accelerator()
model, epoch, loss = load_trained_model('output/checkpoint_epoch_20', accelerator)
model = accelerator.prepare(model)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  checkpoint = torch.load(checkpoint_path)


In [15]:
print(epoch)

20


In [16]:
print(loss)

tensor(0.0271, requires_grad=True)


In [42]:
def generate_for_single_video(model, processor, video_path, accelerator):
    # Set model to evaluation mode
    model.eval()
    
    # Get frames from the video
    frames = get_frames(video_path, num_frames=16)  # Using 16 frames as in training
    
    # Convert frames to tensor
    frames_list = [transforms.ToTensor()(frame) for frame in frames]
    frame_tensor = torch.stack(frames_list)
    
    # Convert to format expected by processor
    images = [img.permute(1, 2, 0).cpu().numpy() for img in frame_tensor]
    
    # Create prompt
    tmp_prompt = "Translate the sign language to english text."        
    prompt = f"USER: <video> {tmp_prompt}\n ASSISTANT: Answer:"
    
    # Process inputs
    batch = processor(
        text=prompt,
        videos=[images],  # Wrap in list as processor expects batch
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )
    
    # Prepare inputs
    input_ids = accelerator.prepare(batch["input_ids"])
    attention_mask = accelerator.prepare(batch["attention_mask"])
    pixel_values_videos = accelerator.prepare(batch["pixel_values_videos"])
    
    # Generate text
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values_videos=pixel_values_videos,
            max_length=200
        )
    
    # Decode the generated text
    generated_text = processor.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return generated_text

In [43]:
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"
processor.image_processor.do_rescale = False

In [48]:
video_path = '/scratch/as18464/raw_videos/-06_nJnhORg_3-5-rgb_front.mp4'
generated_text = generate_for_single_video(model, processor, video_path, accelerator)

In [49]:
generated_text

['USER:  Translate the sign language to english text.\n ASSISTANT: Answer: Gener']

In [24]:
# video_path = 'test/--7E2sU6zP4_12-5-rgb_front.mp4'
# container = av.open(video_path)
# total_frames = container.streams.video[0].frames
# indices = np.arange(0, total_frames, 16).astype(int)
# clip = read_video_pyav(container, indices)

# tmp_prompt = "Analyze the American Sign Language (ASL) signs in this video and translate them into clear, natural English. Consider the sequence of signs as a complete message, and provide an accurate translation that captures the full meaning. Respond with only the English translation, without descriptions of the signs themselves."        
# prompt = f"USER: <video> {tmp_prompt}\n ASSISTANT: Answer:"

# inputs = processor(text=prompt, videos=clip, return_tensors="pt")

# # Generate
# generate_ids = model.generate(**inputs, max_length=100)
# print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

USER:  Analyze the American Sign Language (ASL) signs in this video and translate them into clear, natural English. Consider the sequence of signs as a complete message, and provide an accurate translation that captures the full meaning. Respond with only the English translation, without descriptions of the signs themselves.
 ASSISTANT: Answer: andЪЪЪЪЪЪЪЪЪЪЪЪЪЪЪЪЪЪЪЪЪЪЪ
