In [1]:
import av
import bisect
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, BitsAndBytesConfig, VideoLlavaForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import json
import torch.nn.utils.prune as prune
from tqdm import tqdm
import os
from typing import Tuple, Any
import pandas as pd
from torch.cuda.amp import autocast
from torch.nn.parallel import DataParallel

In [2]:
# Configure PyTorch's CUDA caching allocator to use expandable segments, which
# reduces fragmentation for variable-size video batches.
# NOTE(review): the allocator reads this when CUDA is first initialised, so it
# only takes effect if no CUDA work has happened yet in this kernel.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

In [3]:
# Release cached-but-unused blocks held by PyTorch's CUDA allocator (useful
# when re-running cells in a long-lived kernel). Guarded for CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [4]:
# Constants: Hugging Face Hub checkpoint to fine-tune, plus its bare repository
# name (everything after the last "/").
MODEL_ID = "LanguageBind/Video-LLaVA-7B-hf"
MODEL_NAME = MODEL_ID.rpartition("/")[2]

In [5]:
# File/directory locations. The data paths can be overridden through the
# environment (ASL_VIDEO_DIR / ASL_CSV_FILE) so the notebook is not tied to one
# user's scratch space; the defaults are the original hard-coded values.
VIDEO_DIR = os.environ.get("ASL_VIDEO_DIR", "/scratch/as18464/raw_videos")
CSV_FILE = os.environ.get("ASL_CSV_FILE", "valid_clips.csv")  # needs SENTENCE_NAME and SENTENCE columns
CACHE_DIR = "cache/"  # passed as cache_dir when downloading the model weights
# Number of leading CSV rows used for training (VideoDataset keeps .head(DATASET_SIZE)).
# NOTE(review): the saved outputs in this notebook ("Loaded dataset with 10000
# entries", 2500 batches of 4) could not come from this value -- they are stale;
# re-run to refresh them.
DATASET_SIZE = 5000

In [6]:
# LoRA hyperparameters
LORA_R = 8            # adapter rank
LORA_ALPHA = 32       # adapter scaling numerator (PEFT scales updates by alpha / r)
LORA_DROPOUT = 0.1    # dropout applied inside the adapters
# Module names that receive adapters: attention projections (q/k/v/o) and the
# MLP projections (gate/up/down). Kept as a list (not a tuple) because that is
# the form LoraConfig receives below.
# NOTE(review): matching is by module name, so this likely also attaches
# adapters to the vision towers' q/k/v projections -- the trainable-parameter
# count printed after get_peft_model should reflect that; confirm it is intended.
LORA_TARGET_MODULES = "q_proj v_proj k_proj o_proj gate_proj up_proj down_proj".split()

In [7]:
# model constants
BATCH_SIZE = 4          # clips per optimizer step
MAX_LENGTH = 350        # max_length passed (with truncation) to the processor in train_epoch
LEARNING_RATE = 1e-4    # AdamW learning rate
WEIGHT_DECAY = 0.01     # AdamW weight decay

# Seed the global RNGs so DataLoader shuffling and LoRA adapter initialisation
# are reproducible across Restart & Run All (nothing else in the notebook seeds them).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [8]:
def read_video_pyav(container, indices, size=(224, 224)):
    '''
    Decode selected frames from a PyAV container and resize them.

    Args:
        container (`av.container.input.InputContainer`): PyAV container. It is
            rewound to the start before decoding.
        indices (`List[int]`): Non-negative frame indices to decode. Duplicate
            indices are collapsed, so fewer than ``len(indices)`` frames may be
            returned.
        size (`Tuple[int, int]`): Output (height, width). Defaults to (224, 224).
    Returns:
        result (np.ndarray): uint8 array of decoded frames of shape
            (num_frames, height, width, 3), in ascending frame order.
    Raises:
        ValueError: if ``indices`` is empty, contains a negative index, or none
            of the requested frames exist in the stream.
    '''
    if len(indices) == 0:
        raise ValueError("read_video_pyav: 'indices' must not be empty")

    # A set gives O(1) membership tests instead of scanning the index array per frame.
    wanted = {int(i) for i in indices}
    if min(wanted) < 0:
        raise ValueError(f"read_video_pyav: negative frame index in {sorted(wanted)}")
    last_wanted = max(wanted)

    # Resize in PIL and convert straight back to uint8. The previous
    # ToTensor() -> *255 -> astype(uint8) round trip truncated float error
    # (e.g. 254.99998 -> 254), corrupting pixel values by one level.
    resize_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(size),
    ])

    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > last_wanted:
            break  # nothing further is requested; stop decoding early
        if i in wanted:
            rgb = frame.to_ndarray(format="rgb24")  # (H, W, 3) uint8
            frames.append(np.asarray(resize_transform(rgb), dtype=np.uint8))

    if not frames:
        raise ValueError(
            f"read_video_pyav: none of the requested frames (max index {last_wanted}) "
            "could be decoded from the stream"
        )
    return np.stack(frames)

def get_frames(video_path: str, num_frames: int = 8) -> np.ndarray:
    """
    Uniformly sample ``num_frames`` frames across a video file.

    Frame indices are spread evenly from the first to the last frame. Short
    clips can produce duplicate indices (which decode to a single frame), so
    the last decoded frame is repeated until exactly ``num_frames`` are returned.

    Args:
        video_path (str): Path to video file
        num_frames (int): Number of frames to extract (must be >= 1)
    Returns:
        np.ndarray: Array of frames with shape (num_frames, height, width, 3),
            as produced by read_video_pyav.
    Raises:
        ValueError: if ``num_frames`` < 1 or the video has no decodable frames.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    # The context manager closes the container even if decoding raises.
    with av.open(video_path) as container:
        total_frames = container.streams.video[0].frames
        if total_frames <= 0:
            # Some containers do not record a frame count (reported as 0).
            # Sampling from total_frames - 1 == -1 would request no valid frames,
            # so count them by decoding instead.
            container.seek(0)
            total_frames = sum(1 for _ in container.decode(video=0))
        if total_frames <= 0:
            raise ValueError(f"No decodable video frames in '{video_path}'")

        # Evenly spaced indices from the first to the last frame (inclusive).
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        frames = read_video_pyav(container, indices)

    # Duplicate indices collapse to one decoded frame; pad with the last frame.
    missing = num_frames - len(frames)
    if missing > 0:
        padding = np.repeat(frames[-1:], missing, axis=0)
        frames = np.concatenate([frames, padding], axis=0)
    return frames[:num_frames]

In [9]:
class VideoDataset(Dataset):
    """
    Pairs each annotated ASL clip with its full training prompt.

    Rows are read from ``csv_file``: ``SENTENCE_NAME`` names the clip
    (``<video_dir>/<SENTENCE_NAME>.mp4``) and ``SENTENCE`` is the English
    target. Only the first ``dataset_size`` rows are kept; when not given, the
    notebook constant ``DATASET_SIZE`` is used.
    """

    # Instruction placed after the <video> token; the target sentence follows "Answer:".
    INSTRUCTION = (
        "Translate the American Sign Language (ASL) demonstrated in the video to English text, "
        "where each frame shows ASL signs used at different time points chronologically."
    )

    def __init__(self, video_dir: str, csv_file: str, num_frames: int = 8, dataset_size=None):
        """
        Args:
            video_dir: directory containing the ``.mp4`` clips.
            csv_file: annotation CSV with SENTENCE_NAME and SENTENCE columns.
            num_frames: frames sampled per clip.
            dataset_size: number of leading CSV rows to keep (None -> DATASET_SIZE).
        """
        if dataset_size is None:
            dataset_size = DATASET_SIZE
        self.video_dir = video_dir
        self.annotations = pd.read_csv(csv_file, sep=',').head(dataset_size).reset_index(drop=True)
        self.num_frames = num_frames
        print(f"Loaded dataset with {len(self.annotations)} entries")

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Tuple[str, torch.Tensor]:
        """
        Returns:
            (prompt, frame_tensor): the "USER: <video> ... ASSISTANT: Answer: <sentence>"
            string and a float tensor of shape (num_frames, 3, H, W) with values
            in [0, 1] (transforms.ToTensor applied per frame).
        Raises:
            FileNotFoundError: if the clip's ``.mp4`` file does not exist.
        """
        row = self.annotations.iloc[idx]
        video_id = str(row['SENTENCE_NAME']).strip()
        sentence = str(row['SENTENCE']).strip()

        video_path = os.path.join(self.video_dir, f"{video_id}.mp4")
        if not os.path.isfile(video_path):
            raise FileNotFoundError(f"Video file '{video_path}' not found.")

        frames = get_frames(video_path, self.num_frames)  # (num_frames, H, W, 3) uint8
        prompt = f"USER: <video> {self.INSTRUCTION}\nASSISTANT: Answer: {sentence}"

        # ToTensor: HWC uint8 -> CHW float in [0, 1]; stacked along a new frame axis.
        to_tensor = transforms.ToTensor()
        frame_tensor = torch.stack([to_tensor(frame) for frame in frames])

        return prompt, frame_tensor

In [10]:
def _videos_to_frame_lists(videos):
    """Split a batched video tensor into one list of per-frame (H, W, C) numpy arrays per clip."""
    return [[frame.permute(1, 2, 0).numpy() for frame in clip] for clip in videos]


def train_epoch(model, train_loader, optimizer, processor, device, epoch,
                max_grad_norm=None, output_dir="output"):
    """
    Run one training epoch, then save a checkpoint.

    Args:
        model: the (PEFT-wrapped) Video-LLaVA model, optionally inside DataParallel.
        train_loader: yields ``(texts, videos)`` batches; presumably ``videos`` is a
            (B, T, C, H, W) float tensor in [0, 1] from VideoDataset -- confirm.
        optimizer: optimizer over the model parameters.
        processor: the Video-LLaVA processor (tokenizer + image processor).
        device: device the input tensors are moved to.
        epoch (int): epoch index, used for logging and the checkpoint file name.
        max_grad_norm (float | None): if set, clip the total gradient norm of the
            trainable parameters to this value; None (default) disables clipping.
        output_dir (str): directory the checkpoint is written to.
    Returns:
        float: mean training loss over the epoch's batches (nan if the loader is empty).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(train_loader, desc=f'Training Epoch {epoch}')

    for batch_idx, (texts, videos) in enumerate(progress_bar):
        inputs = processor(
            text=texts,
            videos=_videos_to_frame_lists(videos),
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt"
        )

        # Loss ignores padding positions (-100).
        # TODO(review): labels are the full input ids, so the loss also covers the
        # "USER: <video> ..." instruction, not only the answer sentence; consider
        # masking everything before "Answer:" with -100. Also, if pad_token_id
        # coincides with a real token id, those positions are masked too -- confirm.
        labels = inputs["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100

        optimizer.zero_grad()
        outputs = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            pixel_values_videos=inputs["pixel_values_videos"].to(device),
            labels=labels.to(device)
        )
        loss = outputs.loss
        if isinstance(model, DataParallel):
            # DataParallel gathers one loss per replica; reduce to a scalar.
            loss = loss.mean()

        loss.backward()
        if max_grad_norm is not None:
            trainable = [p for p in model.parameters() if p.requires_grad]
            torch.nn.utils.clip_grad_norm_(trainable, max_grad_norm)
        optimizer.step()

        current_loss = loss.item()
        total_loss += current_loss
        num_batches += 1
        avg_loss = total_loss / num_batches

        progress_bar.set_postfix({
            'batch_loss': f'{current_loss:.4f}',
            'avg_loss': f'{avg_loss:.4f}'
        })
        # tqdm.write prints above the bar instead of breaking it mid-line.
        progress_bar.write(f'Epoch {epoch} | Batch {batch_idx}/{len(train_loader)} | '
                           f'Loss: {current_loss:.4f} | Avg Loss: {avg_loss:.4f}')

    # Defined even when the loader yields no batches (previously a NameError).
    avg_loss = total_loss / num_batches if num_batches else float("nan")

    # Save checkpoint after every epoch.
    # NOTE(review): this stores the full state dict, including the frozen
    # quantized base weights; saving only the LoRA weights would be far smaller.
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, f"checkpoint_epoch_{epoch}")
    base_model = model.module if isinstance(model, DataParallel) else model
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': base_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss
    }

    print(f"Saving checkpoint for epoch {epoch} with average loss: {avg_loss:.4f}")
    torch.save(checkpoint, checkpoint_path)

    return avg_loss

In [11]:
# Create dataset and dataloader
def create_data_loader(video_dir, csv_file, batch_size, num_frames=8,
                       num_workers=0, shuffle=True):
    """
    Build a DataLoader over VideoDataset.

    Args:
        video_dir (str): directory containing the ``.mp4`` clips.
        csv_file (str): annotation CSV passed to VideoDataset.
        batch_size (int): clips per batch.
        num_frames (int): frames sampled per clip.
        num_workers (int): DataLoader worker processes. Default 0 decodes in the
            main process (easiest to debug); larger values can overlap video
            decoding with training.
        shuffle (bool): reshuffle every epoch (default True).
    Returns:
        DataLoader yielding ``(prompts, frame_tensors)`` batches.
    """
    dataset = VideoDataset(
        video_dir=video_dir,
        csv_file=csv_file,
        num_frames=num_frames
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=False
    )

In [12]:
# Target device for the input tensors in train_epoch; CPU when no GPU is visible.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [13]:
# Training loader: BATCH_SIZE clips per batch, 16 uniformly sampled frames per clip.
# NOTE(review): Video-LLaVA is usually fed 8-frame clips -- confirm 16 is intended.
train_loader = create_data_loader(VIDEO_DIR, CSV_FILE, BATCH_SIZE, num_frames=16)

Loaded dataset with 10000 entries


In [14]:
# Load the base model's weights in 4-bit, computing in float16.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    load_in_4bit=True,
)

In [15]:
# Download (into CACHE_DIR) and load the quantized base model; device_map="auto"
# lets the loader decide device placement.
model = VideoLlavaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    cache_dir=CACHE_DIR,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.float16,
)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [16]:
# PEFT's k-bit training preparation. Gradient checkpointing is PEFT's default
# here (spelled out explicitly); the training log's "use_cache=True is
# incompatible with gradient checkpointing" message comes from it.
p_model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

In [17]:
# Configure LoRA: adapters on LORA_TARGET_MODULES, bias terms left untouched,
# task type set to causal language modelling.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=LORA_TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)

In [18]:
# Wrap the prepared model with LoRA adapters and report how many parameters will train.
p_model = get_peft_model(model=p_model, peft_config=peft_config)
p_model.print_trainable_parameters()

trainable params: 22,347,776 || all params: 7,388,626,944 || trainable%: 0.3025


In [19]:
def set_trainable_params(model):
    """
    Freeze every parameter except the LoRA adapter weights.

    A parameter is left trainable exactly when its qualified name contains
    ``"lora_"``; all others get ``requires_grad = False``. Modifies ``model``
    in place and returns None.
    """
    for param_name, tensor in model.named_parameters():
        tensor.requires_grad = "lora_" in param_name

In [20]:
# Keep only the LoRA adapter weights trainable.
set_trainable_params(model=p_model)

In [21]:
# Move the PEFT model to the GPU and, when several GPUs are visible, replicate it
# with DataParallel.
# NOTE(review): the model was loaded in 4-bit with device_map="auto"; moving it
# with .cuda() and replicating quantized weights via DataParallel may not be
# supported on multi-GPU hosts -- confirm, or prefer a single device / DDP.
p_model = p_model.cuda()
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f"Using {gpu_count} GPUs!")
    p_model = DataParallel(p_model)

In [22]:
# Processor for the same checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"  # pad after the answer for training
# VideoDataset already scales frames to [0, 1] via ToTensor, so the image
# processor must not rescale again.
processor.image_processor.do_rescale = False
# NOTE(review): the run log warns that the processor lacks patch_size /
# vision_feature_select_strategy. If these are set, the processor presumably
# expands <video> into many tokens, and MAX_LENGTH=350 truncation would then cut
# them -- raise MAX_LENGTH if you enable it.

# AdamW over all parameters. Frozen ones never receive gradients, so the
# optimizer leaves them untouched.
optimizer = torch.optim.AdamW(
    p_model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

In [23]:
# Train for NUM_EPOCHS epochs, keeping each epoch's mean loss (train_epoch's
# return value was previously discarded). Appending inside the loop keeps the
# completed epochs' losses even if training is interrupted.
# NOTE(review): the saved log shows ~82 s per batch over 2500 batches, i.e.
# roughly 57 h per epoch.
NUM_EPOCHS = 4
epoch_losses = []
for epoch in range(NUM_EPOCHS):
    epoch_losses.append(
        train_epoch(p_model, train_loader, optimizer, processor, device, epoch)
    )
epoch_losses

Training Epoch 0:   0%|          | 0/2500 [00:00<?, ?it/s]Expanding inputs for image tokens in Video-LLaVa should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.44.
  return fn(*args, **kwargs)
Expanding inputs for image tokens in Video-LLaVa should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.
`use_cache=True` is incompatible with gradient checkpointing. Setting

Epoch 0 | Batch 0/2500 | Loss: 4.2296 | Avg Loss: 4.2296


Training Epoch 0:   0%|          | 2/2500 [02:45<57:18:45, 82.60s/it, batch_loss=3.9798, avg_loss=4.1047]

Epoch 0 | Batch 1/2500 | Loss: 3.9798 | Avg Loss: 4.1047


Training Epoch 0:   0%|          | 3/2500 [04:07<57:01:17, 82.21s/it, batch_loss=3.4986, avg_loss=3.9026]

Epoch 0 | Batch 2/2500 | Loss: 3.4986 | Avg Loss: 3.9026


Training Epoch 0:   0%|          | 4/2500 [05:29<56:48:20, 81.93s/it, batch_loss=3.3622, avg_loss=3.7675]

Epoch 0 | Batch 3/2500 | Loss: 3.3622 | Avg Loss: 3.7675


Training Epoch 0:   0%|          | 5/2500 [06:51<56:48:58, 81.98s/it, batch_loss=3.0351, avg_loss=3.6210]

Epoch 0 | Batch 4/2500 | Loss: 3.0351 | Avg Loss: 3.6210


Training Epoch 0:   0%|          | 6/2500 [08:14<57:05:41, 82.41s/it, batch_loss=2.7090, avg_loss=3.4690]

Epoch 0 | Batch 5/2500 | Loss: 2.7090 | Avg Loss: 3.4690


Training Epoch 0:   0%|          | 7/2500 [09:37<57:07:15, 82.49s/it, batch_loss=2.3969, avg_loss=3.3159]

Epoch 0 | Batch 6/2500 | Loss: 2.3969 | Avg Loss: 3.3159


Training Epoch 0:   0%|          | 8/2500 [10:58<56:53:45, 82.19s/it, batch_loss=2.2422, avg_loss=3.1817]

Epoch 0 | Batch 7/2500 | Loss: 2.2422 | Avg Loss: 3.1817


Training Epoch 0:   0%|          | 9/2500 [12:19<56:33:22, 81.74s/it, batch_loss=1.5686, avg_loss=3.0024]

Epoch 0 | Batch 8/2500 | Loss: 1.5686 | Avg Loss: 3.0024


Training Epoch 0:   0%|          | 10/2500 [13:40<56:24:15, 81.55s/it, batch_loss=1.4513, avg_loss=2.8473]

Epoch 0 | Batch 9/2500 | Loss: 1.4513 | Avg Loss: 2.8473


Training Epoch 0:   0%|          | 11/2500 [15:02<56:23:39, 81.57s/it, batch_loss=1.1989, avg_loss=2.6975]

Epoch 0 | Batch 10/2500 | Loss: 1.1989 | Avg Loss: 2.6975


Training Epoch 0:   0%|          | 12/2500 [16:23<56:20:52, 81.53s/it, batch_loss=1.5045, avg_loss=2.5981]

Epoch 0 | Batch 11/2500 | Loss: 1.5045 | Avg Loss: 2.5981


Training Epoch 0:   1%|          | 13/2500 [17:44<56:15:36, 81.44s/it, batch_loss=1.4299, avg_loss=2.5082]

Epoch 0 | Batch 12/2500 | Loss: 1.4299 | Avg Loss: 2.5082


Training Epoch 0:   1%|          | 14/2500 [19:06<56:15:43, 81.47s/it, batch_loss=1.2022, avg_loss=2.4149]

Epoch 0 | Batch 13/2500 | Loss: 1.2022 | Avg Loss: 2.4149


Training Epoch 0:   1%|          | 15/2500 [20:28<56:20:59, 81.63s/it, batch_loss=1.4260, avg_loss=2.3490]

Epoch 0 | Batch 14/2500 | Loss: 1.4260 | Avg Loss: 2.3490


Training Epoch 0:   1%|          | 16/2500 [21:49<56:12:53, 81.47s/it, batch_loss=1.2469, avg_loss=2.2801]

Epoch 0 | Batch 15/2500 | Loss: 1.2469 | Avg Loss: 2.2801


Training Epoch 0:   1%|          | 17/2500 [23:10<56:07:36, 81.38s/it, batch_loss=1.0560, avg_loss=2.2081]

Epoch 0 | Batch 16/2500 | Loss: 1.0560 | Avg Loss: 2.2081


Training Epoch 0:   1%|          | 18/2500 [24:32<56:10:28, 81.48s/it, batch_loss=1.3413, avg_loss=2.1599]

Epoch 0 | Batch 17/2500 | Loss: 1.3413 | Avg Loss: 2.1599


Training Epoch 0:   1%|          | 19/2500 [25:53<56:04:39, 81.37s/it, batch_loss=1.1597, avg_loss=2.1073]

Epoch 0 | Batch 18/2500 | Loss: 1.1597 | Avg Loss: 2.1073


Training Epoch 0:   1%|          | 20/2500 [27:14<56:06:06, 81.44s/it, batch_loss=1.0291, avg_loss=2.0534]

Epoch 0 | Batch 19/2500 | Loss: 1.0291 | Avg Loss: 2.0534


Training Epoch 0:   1%|          | 21/2500 [28:35<55:56:11, 81.23s/it, batch_loss=0.8366, avg_loss=1.9954]

Epoch 0 | Batch 20/2500 | Loss: 0.8366 | Avg Loss: 1.9954


Training Epoch 0:   1%|          | 22/2500 [29:56<55:50:52, 81.14s/it, batch_loss=0.8919, avg_loss=1.9453]

Epoch 0 | Batch 21/2500 | Loss: 0.8919 | Avg Loss: 1.9453


Training Epoch 0:   1%|          | 23/2500 [31:19<56:07:40, 81.57s/it, batch_loss=1.4768, avg_loss=1.9249]

Epoch 0 | Batch 22/2500 | Loss: 1.4768 | Avg Loss: 1.9249


Training Epoch 0:   1%|          | 24/2500 [32:41<56:12:34, 81.73s/it, batch_loss=1.1990, avg_loss=1.8947]

Epoch 0 | Batch 23/2500 | Loss: 1.1990 | Avg Loss: 1.8947


Training Epoch 0:   1%|          | 25/2500 [34:03<56:13:06, 81.77s/it, batch_loss=1.1012, avg_loss=1.8629]

Epoch 0 | Batch 24/2500 | Loss: 1.1012 | Avg Loss: 1.8629


Training Epoch 0:   1%|          | 26/2500 [35:23<55:59:35, 81.48s/it, batch_loss=0.8286, avg_loss=1.8231]

Epoch 0 | Batch 25/2500 | Loss: 0.8286 | Avg Loss: 1.8231


Training Epoch 0:   1%|          | 27/2500 [36:45<55:54:17, 81.38s/it, batch_loss=0.8177, avg_loss=1.7859]

Epoch 0 | Batch 26/2500 | Loss: 0.8177 | Avg Loss: 1.7859


Training Epoch 0:   1%|          | 28/2500 [38:06<55:56:17, 81.46s/it, batch_loss=1.1735, avg_loss=1.7640]

Epoch 0 | Batch 27/2500 | Loss: 1.1735 | Avg Loss: 1.7640


Training Epoch 0:   1%|          | 29/2500 [39:27<55:49:41, 81.34s/it, batch_loss=0.9544, avg_loss=1.7361]

Epoch 0 | Batch 28/2500 | Loss: 0.9544 | Avg Loss: 1.7361


Training Epoch 0:   1%|          | 30/2500 [40:49<55:53:45, 81.47s/it, batch_loss=1.2203, avg_loss=1.7189]

Epoch 0 | Batch 29/2500 | Loss: 1.2203 | Avg Loss: 1.7189


Training Epoch 0:   1%|          | 31/2500 [42:10<55:48:54, 81.38s/it, batch_loss=1.1651, avg_loss=1.7011]

Epoch 0 | Batch 30/2500 | Loss: 1.1651 | Avg Loss: 1.7011


Training Epoch 0:   1%|▏         | 32/2500 [43:33<56:02:13, 81.74s/it, batch_loss=1.3948, avg_loss=1.6915]

Epoch 0 | Batch 31/2500 | Loss: 1.3948 | Avg Loss: 1.6915


Training Epoch 0:   1%|▏         | 33/2500 [44:55<56:03:10, 81.80s/it, batch_loss=1.2959, avg_loss=1.6795]

Epoch 0 | Batch 32/2500 | Loss: 1.2959 | Avg Loss: 1.6795


Training Epoch 0:   1%|▏         | 34/2500 [46:16<55:56:38, 81.67s/it, batch_loss=0.9896, avg_loss=1.6592]

Epoch 0 | Batch 33/2500 | Loss: 0.9896 | Avg Loss: 1.6592


Training Epoch 0:   1%|▏         | 35/2500 [47:38<55:51:23, 81.58s/it, batch_loss=1.0895, avg_loss=1.6429]

Epoch 0 | Batch 34/2500 | Loss: 1.0895 | Avg Loss: 1.6429


Training Epoch 0:   1%|▏         | 36/2500 [49:01<56:19:09, 82.28s/it, batch_loss=1.5341, avg_loss=1.6399]

Epoch 0 | Batch 35/2500 | Loss: 1.5341 | Avg Loss: 1.6399


Training Epoch 0:   1%|▏         | 37/2500 [50:22<56:02:13, 81.91s/it, batch_loss=1.0660, avg_loss=1.6244]

Epoch 0 | Batch 36/2500 | Loss: 1.0660 | Avg Loss: 1.6244


Training Epoch 0:   2%|▏         | 38/2500 [51:44<55:58:45, 81.85s/it, batch_loss=1.0972, avg_loss=1.6105]

Epoch 0 | Batch 37/2500 | Loss: 1.0972 | Avg Loss: 1.6105


Training Epoch 0:   2%|▏         | 39/2500 [53:05<55:43:41, 81.52s/it, batch_loss=0.7935, avg_loss=1.5896]

Epoch 0 | Batch 38/2500 | Loss: 0.7935 | Avg Loss: 1.5896


Training Epoch 0:   2%|▏         | 39/2500 [53:20<56:05:57, 82.06s/it, batch_loss=0.7935, avg_loss=1.5896]


KeyboardInterrupt: 