# Federated Learning with LLaVA-NeXT-Video
**Goal**: Fine-tune a violence detection model across decentralized clients without sharing raw video data.

## Key Features:
- Uses **QLoRA** (4-bit quantized base model + LoRA adapters) for memory-efficient federated training.
- **Flower** framework for federated averaging (FedAvg), with client-side fixed-clipping differential privacy.
- Simulates 4 clients + 1 server in one notebook.

In [1]:
# === Core Libraries ===
!pip install -q kagglehub
!pip install -q opencv-python-headless
!pip install -q av
!pip install -q decord
!pip install -q torchvision
!pip install -q scikit-learn
!pip install -q seaborn
!pip install -q tensorboard

# === Hugging Face & Model Training ===
!pip install -q transformers datasets sentencepiece accelerate bitsandbytes peft trl

# === Federated Learning ===
!pip install -q flwr flwr-datasets flwr[simulation]

# === Configuration & Logging ===
!pip install -q omegaconf hydra-core

# === Energy Tracking ===
!pip install -q codecarbon

!pip install -q ipywidgets

In [2]:
# === Standard Library ===
import os
import json
import threading
import logging
from pathlib import Path
from random import sample
from dataclasses import dataclass
from typing import Any

# === Scientific and Visualization Libraries ===
import numpy as np
import pandas as pd
import cv2
import av
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    roc_auc_score
)
from tqdm.notebook import tqdm

# === PyTorch ===
import torch

# === Environment Tracking ===
from codecarbon import EmissionsTracker

# === Hugging Face Transformers ===
from transformers import (
    AutoProcessor,
    LlavaNextVideoForConditionalGeneration,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)

# === PEFT (Parameter-Efficient Fine-Tuning) ===
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)

# === Video Processing ===
from decord import VideoReader, cpu

# === Kaggle ===
import kagglehub

# === Datasets ===
from datasets import Dataset, load_dataset

# === Federated Learning (Flower) ===
import flwr as fl
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from flwr.client.mod import fixedclipping_mod
from flwr.server.strategy import DifferentialPrivacyClientSideFixedClipping

# === Custom Utilities ===
from utils import *

from datasets import Dataset, DatasetDict
import random
import shutil

In [3]:
# Quiet codecarbon's logger: only WARNING and above are emitted from here on.
logging.getLogger("codecarbon").setLevel(logging.WARNING)

In [4]:
# Load the "federated" experiment config and echo it for the record.
# get_config / print_config presumably come from `from utils import *` — confirm.
cfg = get_config("federated")

print_config(cfg)

dataset:
  name: DanJoshua/RWF-2000
model:
  name: llava-hf/LLaVa-NeXT-Video-7b-hf
  quantization: 4
  gradient_checkpointing: true
  use_fast_tokenizer: false
  lora:
    r: 16
    alpha: 64
    target_modules:
    - q_proj
    - v_proj
    dropout: 0.075
    bias: none
  num_frames: 24
  save_model_path: fl_model/${model.name}_final_model.pt
train:
  num_rounds: ${flower.num_rounds}
  save_every_round: 5
  learning_rate_max: 5.0e-05
  learning_rate_min: 1.0e-06
  seq_length: 512
  padding_side: left
  evaluate_split: true
  training_arguments:
    batch_size: 2
    output_dir: null
    learning_rate: 5.0e-05
    per_device_train_batch_size: 1
    gradient_accumulation_steps: 1
    logging_steps: 10
    num_train_epochs: 3
    max_steps: 10
    report_to: null
    save_steps: 1000
    save_total_limit: 10
    gradient_checkpointing: ${model.gradient_checkpointing}
    lr_scheduler_type: constant
flower:
  num_clients: 4
  num_rounds: 10
  fraction_fit: 1.0e-05
  min_fit_clients: 4
  m

In [5]:
# Download the project's video dataset from Kaggle; kagglehub returns the
# local directory the dataset was saved to.
path = kagglehub.dataset_download("yash07yadav/project-data")
print(f"Dataset downloaded to: {path}")    

# Paths
# Split roots; each is expected to contain "NonFight" and "Fight" class
# folders (the defaults used by load_video_paths below).
train_base_path = Path(path) / "Complete Dataset" / "train"
val_base_path = Path(path) / "Complete Dataset" / "val"

def load_video_paths(base_path, class_names=("NonFight", "Fight")):
    """Load video paths with integer labels from a class-per-folder directory.

    Expects ``base_path/<class_name>/<video files>``. Each class is labelled
    with its position in ``class_names`` (default: ``NonFight`` -> 0,
    ``Fight`` -> 1).

    Args:
        base_path: Directory holding one sub-directory per class
            (``str`` or ``pathlib.Path``).
        class_names: Ordered class folder names; the order defines the labels.
            A tuple by default so the default value cannot be mutated.

    Returns:
        A Hugging Face ``Dataset`` with columns ``videos`` (path strings) and
        ``labels`` (ints). Rows are grouped by class and sorted by path within
        each class.

    Raises:
        FileNotFoundError: If a class directory is missing. This keeps a failed
            or partial download from silently becoming a smaller dataset.
    """
    base_path = Path(base_path)
    dataset = {"videos": [], "labels": []}

    for label_idx, class_name in enumerate(class_names):
        # Verify the download actually produced this class folder.
        class_dir = base_path / class_name
        if not class_dir.is_dir():
            raise FileNotFoundError(f"Missing class directory: {class_dir}")

        # Sort so row order, and therefore the seeded partitioning downstream,
        # does not depend on the filesystem's directory-listing order.
        for video_path in sorted(class_dir.glob("*")):
            # Skip sub-directories; only regular files can be copied or decoded.
            if not video_path.is_file():
                continue
            dataset["videos"].append(str(video_path))
            dataset["labels"].append(label_idx)

    return Dataset.from_dict(dataset)
    
def partition_dataset(dataset: "Dataset", num_clients: int = 4, seed: int = 42):
    """Split ``dataset`` into ``num_clients`` disjoint, IID-shuffled partitions.

    Row indices are shuffled with a private ``random.Random(seed)``. For a given
    seed this gives the same order as ``random.seed(seed); random.shuffle(...)``.
    The global ``random`` module state is not modified, so code that needs a
    seeded global RNG must seed it itself.

    Every partition has ``len(dataset) // num_clients`` rows, except the last,
    which also absorbs the remainder. If ``len(dataset) < num_clients``, the
    leading partitions are empty.

    Args:
        dataset: Object supporting ``len()`` and ``.select(indices)``
            (e.g. a Hugging Face ``Dataset``).
        num_clients: Number of partitions to produce; must be >= 1.
        seed: Seed for the shuffle, making the split reproducible.

    Returns:
        List of ``num_clients`` partitions, each produced by
        ``dataset.select(indices)``.

    Raises:
        ValueError: If ``num_clients`` is less than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)

    partition_size = len(indices) // num_clients
    partitions = []
    for i in range(num_clients):
        start = i * partition_size
        # The last client takes the remainder so no row is dropped.
        end = len(indices) if i == num_clients - 1 else start + partition_size
        partitions.append(dataset.select(indices[start:end]))

    return partitions

def save_partitions_to_dirs(partitions, save_base: str = "data/clients", clear_existing: bool = True):
    """Copy each partition's videos into its own ``client_<i>`` folder.

    Partition ``i`` is copied into ``<save_base>/client_<i>/``. Each file is
    renamed ``<class>_<original basename>``, where ``<class>`` is ``fight`` for
    label 1 and ``nonfight`` for any other label, so the label can be recovered
    from the filename.

    Args:
        partitions: Sized sequence of partitions. Each must support
            ``["videos"]`` (source paths) and ``["labels"]`` (int labels),
            e.g. the output of ``partition_dataset``.
        save_base: Root directory for the per-client folders.
        clear_existing: When True (default), delete each ``client_<i>`` folder
            before copying. Otherwise, re-running after the split changes
            (e.g. a different ``num_clients``) leaves files from the previous
            split in the folder, silently mixing two partitions. Pass False for
            the old accumulate-in-place behaviour.
    """
    base = Path(save_base)
    for i, partition in enumerate(partitions):
        client_dir = base / f"client_{i}"
        if clear_existing and client_dir.exists():
            shutil.rmtree(client_dir)
        client_dir.mkdir(parents=True, exist_ok=True)

        for video_path, label in zip(partition["videos"], partition["labels"]):
            class_name = "fight" if label == 1 else "nonfight"
            dst_path = client_dir / f"{class_name}_{os.path.basename(video_path)}"
            shutil.copy(video_path, dst_path)

    print(f"Saved {len(partitions)} clients in {save_base}/")

# Example usage
# Index the train/val videos (labels from load_video_paths' defaults:
# NonFight -> 0, Fight -> 1).
train_dataset = load_video_paths(train_base_path)
val_dataset = load_video_paths(val_base_path)

# One shuffled shard per client; both splits use partition_dataset's default seed.
train_partitions = partition_dataset(train_dataset, num_clients=cfg.flower.num_clients)
val_partitions = partition_dataset(val_dataset, num_clients=cfg.flower.num_clients)

# ------------------- uncomment when number of clients changed -------------------------
# NOTE(review): the calls below are NOT commented out. They run on every
# execution and re-copy every video into data/clients_{train,val}/client_<i>/
# (the ClientApp reads data/clients_train). After changing num_clients, check
# that no files or client_* folders from the previous split are left behind.
save_partitions_to_dirs(train_partitions, save_base="data/clients_train")
save_partitions_to_dirs(val_partitions, save_base="data/clients_val")

Dataset downloaded to: /home/jovyan/.cache/kagglehub/datasets/yash07yadav/project-data/versions/1
Saved 4 clients in data/clients_train/
Saved 4 clients in data/clients_val/


In [6]:
# Build the model's processor and the batch collator used by client training.
# The helper presumably comes from `from utils import *` — confirm its signature there.
processor, data_collator = get_processor_and_data_collator(
    cfg.model.name,
    cfg.model.use_fast_tokenizer,
    cfg.train.padding_side,
)

In [7]:
# Directory passed to the client factory as save_path.
save_path = "./fl_model"

# Client side of the simulation. Each client trains on the partition folders
# under data/clients_train. fixedclipping_mod clips each client's update before
# it is sent, which is what the server's
# DifferentialPrivacyClientSideFixedClipping strategy expects.
client = fl.client.ClientApp(
    client_fn = gen_client_fn(
        data_dir="data/clients_train",
        data_collator=data_collator,
        model_cfg=cfg.model,
        train_cfg=cfg.train.training_arguments,
        save_path=save_path,
    ),
    mods=[fixedclipping_mod] 
)

In [8]:
def server_fn(context: Context):
    """Build the Flower server components for the simulation.

    Wraps FedAvg (server-side evaluation only, no federated evaluation) in
    client-side fixed-clipping differential privacy, for
    ``cfg.flower.num_rounds`` rounds.

    Args:
        context: Flower run context. Unused here; all settings come from the
            notebook-level ``cfg``.

    Returns:
        ``fl.server.ServerAppComponents`` holding the DP-wrapped strategy and
        the server config.
    """
    flower_cfg = cfg.flower

    # Define the Strategy
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=flower_cfg.fraction_fit,
        min_fit_clients=flower_cfg.min_fit_clients,
        min_available_clients=flower_cfg.min_available_clients,
        min_evaluate_clients=flower_cfg.min_evaluate_clients,
        fraction_evaluate=0.0,  # No federated evaluation
        on_fit_config_fn=get_on_fit_config(flower_cfg.num_rounds),
        fit_metrics_aggregation_fn=fit_weighted_average,
        evaluate_fn=get_evaluate_fn(
            cfg.model,
            processor,
            cfg.model.save_model_path,
            val_base_path,
        ),
    )

    # The DP wrapper scales the noise stdev by 1 / num_sampled_clients, so this
    # must be the number of clients actually sampled per round. FedAvg samples
    # max(int(available * fraction_fit), min_fit_clients). Here every simulated
    # client is assumed available (num_supernodes == num_clients). Passing the
    # total client count instead would under-noise whenever fewer are sampled.
    num_sampled_clients = min(
        max(
            int(flower_cfg.num_clients * flower_cfg.fraction_fit),
            flower_cfg.min_fit_clients,
        ),
        flower_cfg.num_clients,
    )

    # Add Differential Privacy
    strategy = DifferentialPrivacyClientSideFixedClipping(
        strategy,
        noise_multiplier=flower_cfg.dp.noise_mult,
        clipping_norm=flower_cfg.dp.clip_norm,
        num_sampled_clients=num_sampled_clients,
    )

    # Number of rounds to run the simulation
    config = fl.server.ServerConfig(num_rounds=flower_cfg.num_rounds)

    return fl.server.ServerAppComponents(strategy=strategy, config=config)

In [9]:
# Server side of the simulation; Flower calls server_fn to obtain the strategy and config.
server = fl.server.ServerApp(server_fn=server_fn)

In [10]:
def run_simulation():
    """Run the Flower simulation while tracking its carbon emissions.

    Steps:
    1. Start a codecarbon ``EmissionsTracker`` that saves results under
       ``FedPer emissions/<model name> Emissions``.
    2. Launch ``cfg.flower.num_clients`` simulated supernodes, each with the
       CPU/GPU budget from ``cfg.flower.client_resources``.
    3. Stop the tracker in a ``finally`` block, so it is not left running if
       the simulation raises.
    """
    print("Starting federated training...")

    emissions_dir = f"FedPer emissions/{cfg.model.name} Emissions"
    os.makedirs(emissions_dir, exist_ok=True)

    # Initialize trackers
    emissions_tracker = EmissionsTracker(
        project_name=f"{cfg.model.name} FedPer Emissions",
        measure_power_secs=1,
        output_dir=emissions_dir,
        save_to_file=True,
        log_level="warning",
    )

    # Per-client resource budget for the simulation backend; cast explicitly
    # because the config values may be parsed as strings or ints.
    client_resources = dict(cfg.flower.client_resources)
    backend_config = {
        "client_resources": {
            "num_cpus": int(client_resources["num_cpus"]),
            "num_gpus": float(client_resources["num_gpus"]),
        },
        "init_args": backend_setup,
    }

    emissions_tracker.start()
    try:
        fl.simulation.run_simulation(
            server_app=server,
            client_app=client,
            num_supernodes=cfg.flower.num_clients,
            backend_config=backend_config,
        )
    finally:
        emissions_tracker.stop()

    print("Training completed!")

In [11]:
if __name__ == "__main__":
    # NOTE(review): Process and time are not used in this cell; presumably
    # leftovers — confirm no later cell relies on them before removing.
    from multiprocessing import Process
    import time
    import gc
    
    # Clear caches: collect unreferenced Python objects and release memory
    # cached by PyTorch's CUDA allocator before the simulation starts.
    gc.collect()
    torch.cuda.empty_cache()
    
    # Run simulation
    run_simulation()



Starting federated training...


 Linux OS detected: Please ensure RAPL files exist at /sys/class/powercap/intel-rapl/subsystem to measure CPU

[92mINFO [0m: Starting Flower ServerApp, config: num_rounds=10, no round_timeout
[92mINFO [0m: 
[92mINFO [0m: [INIT]
[92mINFO [0m: Requesting initial parameters from one random client
[92mINFO [0m: Received initial parameters from one random client
[92mINFO [0m: Starting evaluation of initial global parameters
[92mINFO [0m: initial parameters (loss, other metrics): 0.0, {}
[92mINFO [0m: 
[92mINFO [0m: [ROUND 1]
[92mINFO [0m: configure_fit: strategy sampled 4 clients (out of 4)
[92mINFO [0m: aggregate_fit: received 4 results and 0 failures
[92mINFO [0m: aggregate_fit: central DP noise with 0.0025 stdev added
[92mINFO [0m: fit progress: (1, 0.0, {}, 221.67543571296846)
[92mINFO [0m: configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m: 
[92mINFO [0m: [ROUND 2]
[92mINFO [0m: configure_fit: strategy sampled 4 clients (out of 4

Training completed!
