# Federated Learning with LLaVA-NeXT-Video (Jupyter Notebook)
**Goal**: Fine-tune a violence detection model across decentralized clients without sharing raw video data.

## Key Features:
- Uses **QLoRA** (4-bit quantization) for efficient federated training.
- **Flower** framework for federated averaging.
- Simulates 2 clients + 1 server in one notebook.

In [12]:
!pip install decord
!pip install flwr
!pip install bitsandbytes
!pip install -q kagglehub



In [13]:
import torch
import os
import json
import numpy as np
import kagglehub
from pathlib import Path
from random import sample

from decord import VideoReader, cpu
from dataclasses import dataclass
from transformers import (
    AutoProcessor,
    LlavaNextVideoForConditionalGeneration,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
import flwr as fl
from flwr.server import ServerConfig, start_server
from flwr.server.strategy import FedAvg
import threading

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Decord: {VideoReader.__module__} loaded successfully")

test_tensor = torch.randn(3,3).cuda()
print(f"\nCUDA Test Tensor: {test_tensor.mean().item()}")

print(f"\nFlower version: {fl.__version__}")

import bitsandbytes
print(f"bitsandbytes version: {bitsandbytes.__version__}")

PyTorch: 2.6.0+cu124
CUDA: 12.4
GPU: Tesla T4
Decord: decord.video_reader loaded successfully

CUDA Test Tensor: 0.7448925375938416

Flower version: 1.17.0
bitsandbytes version: 0.45.5


In [14]:
# Download dataset
path = kagglehub.dataset_download("yash07yadav/project-data")
print(f"Dataset downloaded to: {path}")

# Define paths
base_path = Path(path) / "Complete Dataset" / "train"
fight_dir = base_path / "Fight"
non_fight_dir = base_path / "NonFight"

# Verify download
print(f"\nFound {len(list(fight_dir.glob('*')))} fight videos")
print(f"Found {len(list(non_fight_dir.glob('*')))} non-fight videos")



# %% [Cell 3] Modified Dataset Utilities


# %% [Cell 4] Updated Client Class


    # Rest of the client methods remain the same...

Downloading from https://www.kaggle.com/api/v1/datasets/download/yash07yadav/project-data?dataset_version_number=1...


100%|██████████| 12.4G/12.4G [07:20<00:00, 30.1MB/s]

Extracting files...





Dataset downloaded to: /root/.cache/kagglehub/datasets/yash07yadav/project-data/versions/1

Found 1000 fight videos
Found 1000 non-fight videos


In [15]:
def create_client_files(client_count=2, samples_per_client=8):
    # Get all video paths
    fight_videos = [str(f) for f in fight_dir.glob('*')]
    nonfight_videos = [str(f) for f in non_fight_dir.glob('*')]

    # Create balanced datasets for clients
    for client_id in range(client_count):
        client_data = {
            "videos": (
                [{"path": p, "label": 1} for p in sample(fight_videos, samples_per_client//2)] +
                [{"path": p, "label": 0} for p in sample(nonfight_videos, samples_per_client//2)]
            )
        }

        with open(f'client{client_id}_data.json', 'w') as f:
            json.dump(client_data, f)

    print(f"Created {client_count} client files with {samples_per_client} samples each")

create_client_files()

Created 2 client files with 8 samples each


In [16]:
def read_video(video_path, num_frames=4):
    """Updated video reader that handles Path objects"""
    try:
        vr = VideoReader(str(video_path), ctx=cpu(0))  # Convert Path to string
        if len(vr) < num_frames:
            return np.zeros((num_frames, 224, 224, 3))

        step = max(1, len(vr) // num_frames)
        frames = vr.get_batch(range(0, len(vr), step)).asnumpy()[:num_frames]
        return frames
    except Exception as e:
        print(f"Error reading {video_path}: {str(e)}")
        return np.zeros((num_frames, 224, 224, 3))

In [17]:
@dataclass
class Config:
    MODEL_ID = "llava-hf/LLaVa-NeXT-Video-7b-hf"
    NUM_FRAMES = 4  # Reduced for memory
    BATCH_SIZE = 1
    USE_QLORA = True
    LORA_RANK = 4  # Reduced rank
    SERVER_ADDRESS = "127.0.0.1:8080"
    GRADIENT_ACCUMULATION_STEPS = 4  # Process small batches but accumulate gradients

config = Config()

In [18]:
# Dataset Utilities (Memory Optimized)
def read_video(video_path, num_frames=4):
    """Read video with frame skipping and error handling"""
    try:
        vr = VideoReader(video_path, ctx=cpu(0))
        if len(vr) < num_frames:
            return np.zeros((num_frames, 224, 224, 3))

        # Skip frames to save memory
        step = max(1, len(vr) // num_frames)
        frames = vr.get_batch(range(0, len(vr), step)).asnumpy()[:num_frames]
        return frames
    except Exception as e:
        print(f"Error reading {video_path}: {str(e)}")
        return np.zeros((num_frames, 224, 224, 3))

In [32]:
def get_model_and_processor():
    processor = AutoProcessor.from_pretrained(config.MODEL_ID)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        llm_int8_enable_fp32_cpu_offload=True  # Enable CPU offloading
    )

    # Custom device map for memory optimization
    device_map = {
        "model.embed_tokens": 0,
        "model.layers.0": 0,
        "model.layers.1": 0,
        "model.layers.2": 0,
        "model.layers.3": 0,
        "model.norm": 0,
        "lm_head": 0,
        "model.layers.4": "cpu",  # Offload later layers to CPU
        "model.layers.5": "cpu",
        "model.layers.6": "cpu",
        "model.layers.7": "cpu"
    }

    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        config.MODEL_ID,
        quantization_config=bnb_config,
        device_map=device_map,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        offload_folder="offload"
    )

    lora_config = LoraConfig(
        r=config.LORA_RANK,
        lora_alpha=8,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
    )

    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    model.gradient_checkpointing_enable()
    model.config.use_cache = False

    print("Model loaded with mixed CPU/GPU offloading")
    return model, processor

In [33]:
# Memory-Optimized Client for Kaggle Dataset
class LlavaClient(fl.client.NumPyClient):
    def __init__(self, dataset_path):
        try:
            torch.cuda.empty_cache()

            with open(dataset_path) as f:
                data = json.load(f)
                self.dataset = data["videos"][:8] if isinstance(data, dict) else data[:8]

            self.model, self.processor = get_model_and_processor()
            print(f"Client initialized with {len(self.dataset)} samples")

        except Exception as e:
            torch.cuda.empty_cache()
            raise RuntimeError(f"Client initialization failed: {str(e)}\n"
                             "Possible solutions:\n"
                             "1. Restart runtime and try again\n"
                             "2. Reduce NUM_FRAMES in config\n"
                             "3. Use smaller batch size")

    def _prepare_dataset(self):
        """Convert video data to HuggingFace Dataset format"""
        processed_data = []

        for item in self.data["videos"]:
            try:
                frames = read_video(item["path"])
                processed_data.append({
                    "pixel_values": frames,
                    "label": item["label"]
                })
            except Exception as e:
                print(f"Skipping corrupted video {item['path']}: {str(e)}")
                continue

        return Dataset.from_dict({
            "pixel_values": [x["pixel_values"] for x in processed_data],
            "label": [x["label"] for x in processed_data]
        })

    def get_parameters(self, config):
        return [p.cpu().numpy() for p in self.model.parameters() if p.requires_grad]

    def fit(self, parameters, config):
        # Update model parameters
        current_params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        for (n, p), new_p in zip(current_params.items(), parameters):
            p.data = torch.from_numpy(new_p).to(p.device)

        # Memory-efficient training
        training_args = TrainingArguments(
            per_device_train_batch_size=config.BATCH_SIZE,
            gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
            output_dir=f"./output_client_{config.get('client_id', 0)}",
            learning_rate=config.get("lr", 1e-5),
            max_steps=4,
            fp16=True,
            optim="adamw_8bit",
            report_to="none",
            save_strategy="no",
            remove_unused_columns=True  # Saves memory
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.dataset,
        )

        # Clear cache before and after training
        torch.cuda.empty_cache()
        trainer.train()
        torch.cuda.empty_cache()

        return self.get_parameters({}), len(self.dataset), {}

    def evaluate(self, parameters, config):
        return 0.0, len(self.dataset), {"accuracy": 0.0}

In [34]:
def start_server():
    strategy = FedAvg(
        fraction_fit=1.0,
        min_fit_clients=2,
        min_available_clients=2,
        on_fit_config_fn=lambda rnd: {
            "lr": 1e-5,
            "batch_size": config.BATCH_SIZE,
            "gradient_accumulation_steps": config.GRADIENT_ACCUMULATION_STEPS
        }
    )

    # Updated server configuration
    fl.server.start_server(
        server_address=config.SERVER_ADDRESS,
        config=fl.server.ServerConfig(num_rounds=2),
        strategy=strategy
    )

In [35]:
# Run sequentially to save memory
def run_colab_simulation():
    print("Starting memory-optimized federated training...")

    # Start server
    server_thread = threading.Thread(target=start_server)
    server_thread.start()

    # Run clients one after another
    for client_id in range(2):
        print(f"\nStarting client {client_id}...")
        client = LlavaClient(f"client{client_id}_data.json")
        fl.client.start_numpy_client(server_address=config.SERVER_ADDRESS, client=client)
        torch.cuda.empty_cache()  # Clear memory between clients

    server_thread.join()
    print("Training completed!")


if __name__ == "__main__":
    run_colab_simulation()

	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower server, config: num_rounds=2, no round_timeout
INFO:flwr:Starting Flower server, config: num_rounds=2, no round_timeout


Starting memory-optimized federated training...

Starting client 0...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

RuntimeError: Client initialization failed: image_newline doesn't have any device set.
Possible solutions:
1. Restart runtime and try again
2. Reduce NUM_FRAMES in config
3. Use smaller batch size