In [2]:
import os
import json
import random
from glob import glob
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# === SAFE GLOBALS FOR PICKLE ===
torch.serialization.add_safe_globals([LabelEncoder])

# === DATASET LOADER ===
def load_dataset(root_dir):
    dataset = []
    class_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]

    for class_dir in class_dirs:
        class_name = os.path.basename(class_dir)
        video_dirs = sorted(glob(os.path.join(class_dir, 'video_*')))

        for video_dir in video_dirs:
            video_name = os.path.basename(video_dir)
            caption_filename = f"captions_{video_name.split('_')[-1]}.json"
            caption_path = os.path.join(video_dir, caption_filename)

            if not os.path.exists(caption_path):
                continue

            try:
                with open(caption_path, 'r') as f:
                    captions = json.load(f)
            except json.JSONDecodeError:
                print(f"❌ Skipping malformed JSON: {caption_path}")
                continue

            for i in range(64):
                frame_key = f"frame_{i:02d}"
                if frame_key not in captions:
                    continue
                caption = captions[frame_key]

                frame_found = False
                for ext in ['jpg', 'png']:
                    frame_filename = f"frame_{i:02d}.{ext}"
                    frame_path = os.path.join(video_dir, frame_filename)
                    if os.path.exists(frame_path):
                        frame_found = True
                        break

                if not frame_found:
                    continue

                try:
                    image = Image.open(frame_path).convert("RGB")
                    dataset.append({
                        "class": class_name,
                        "video": video_name,
                        "frame": frame_path,
                        "caption": caption,
                        "image": image
                    })
                except:
                    continue

    return dataset

# === DATASET CLASS ===
class CrimeDataset(Dataset):
    def __init__(self, data, processor):
        self.data = data
        self.processor = processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx]["image"]
        label = self.data[idx]["label"]
        inputs = self.processor(images=image, return_tensors="pt")
        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }

# === MODEL ===
class CLIPClassifier(nn.Module):
    def __init__(self, clip_model, num_classes):
        super(CLIPClassifier, self).__init__()
        self.clip = clip_model
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, pixel_values):
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        pooled = vision_outputs.pooler_output
        embeddings = self.clip.visual_projection(pooled)
        return self.classifier(embeddings)

# === TRAINING FUNCTION ===
def train_clip_classifier(root_dir, epochs=4, batch_size=16, freeze_clip=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    raw_data = load_dataset(root_dir)
    if len(raw_data) == 0:
        raise ValueError("❌ No valid data found. Please check your dataset path and captions.")

    random.shuffle(raw_data)

    label_encoder = LabelEncoder()
    all_labels = label_encoder.fit_transform([item["class"] for item in raw_data])
    for i in range(len(raw_data)):
        raw_data[i]["label"] = all_labels[i]

    num_classes = len(label_encoder.classes_)

    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)

    if freeze_clip:
        for param in clip_model.vision_model.parameters():
            param.requires_grad = False

    model = CLIPClassifier(clip_model, num_classes)
    model = nn.DataParallel(model)  # ✅ Multi-processing ready
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
    loss_fn = nn.CrossEntropyLoss()

    train_data, val_data = train_test_split(raw_data, test_size=0.2, random_state=42)
    train_dataset = CrimeDataset(train_data, processor)
    val_dataset = CrimeDataset(val_data, processor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=False)


    for epoch in range(epochs):
        model.train()
        total_loss, correct = 0, 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            pixel_values = batch["pixel_values"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)

            outputs = model(pixel_values)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        acc = correct / len(train_dataset)
        print(f"✅ Epoch {epoch+1} - Loss: {total_loss:.4f} - Accuracy: {acc:.4f}")

        model.eval()
        val_correct = 0
        with torch.no_grad():
            for batch in val_loader:
                pixel_values = batch["pixel_values"].to(device, non_blocking=True)
                labels = batch["label"].to(device, non_blocking=True)

                outputs = model(pixel_values)
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()

        val_acc = val_correct / len(val_dataset)
        print(f"✅ Validation Accuracy: {val_acc:.4f}")

    torch.save(model.state_dict(), "clip_crime_classifier1.pt")
    torch.save(label_encoder, "label_encoder.pt")
    print("✅ Model and label encoder saved.")
    return model, label_encoder, processor

# === INFERENCE FUNCTIONS ===
def load_model_for_inference(model_path, label_encoder_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    label_encoder = torch.load(label_encoder_path, map_location=device, weights_only=False)
    label_names = label_encoder.classes_

    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    model = CLIPClassifier(clip_model, num_classes=len(label_names))
    model = nn.DataParallel(model)  # ✅ MPU-safe
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()

    return model, processor, label_names, device

def predict_image(img_path, model, processor, label_names, device):
    image = Image.open(img_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(pixel_values)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        pred_idx = probs.argmax(dim=1).item()
        return label_names[pred_idx], probs[0][pred_idx].item()

def run_inference(image_path, model_path="clip_crime_classifier1.pt", label_encoder_path="label_encoder.pt"):
    model, processor, label_names, device = load_model_for_inference(model_path, label_encoder_path)

    if os.path.isfile(image_path):
        pred, conf = predict_image(image_path, model, processor, label_names, device)
        print(f"[{image_path}] ➜ {pred} ({conf:.2f})")
        return pred, conf

    elif os.path.isdir(image_path):
        print(f"\n📁 Inference on folder: {image_path}")
        results = []
        for fname in os.listdir(image_path):
            if fname.lower().endswith((".jpg", ".png", ".jpeg")):
                fpath = os.path.join(image_path, fname)
                pred, conf = predict_image(fpath, model, processor, label_names, device)
                print(f"{fname}: {pred} ({conf:.2f})")
                results.append((fname, pred, conf))
        return results

    else:
        print(f"❌ Invalid path: {image_path}")
        return None

# === ENTRY POINT ===
if __name__ == "__main__":
    model, label_encoder, processor = train_clip_classifier("/Users/preetham_aleti/Desktop/DATASET")
    run_inference("/Users/preetham_aleti/Desktop/DATASET/ROBBERY/video_03/frame_13.jpg")


Epoch 1: 100%|██████████| 636/636 [15:47<00:00,  1.49s/it]


✅ Epoch 1 - Loss: 104.6688 - Accuracy: 0.9595
✅ Validation Accuracy: 0.9992


Epoch 2: 100%|██████████| 636/636 [15:43<00:00,  1.48s/it]


✅ Epoch 2 - Loss: 3.3190 - Accuracy: 0.9995
✅ Validation Accuracy: 0.9996


Epoch 3: 100%|██████████| 636/636 [16:15<00:00,  1.53s/it]


✅ Epoch 3 - Loss: 1.4929 - Accuracy: 0.9998
✅ Validation Accuracy: 0.9996


Epoch 4: 100%|██████████| 636/636 [19:44<00:00,  1.86s/it] 


✅ Epoch 4 - Loss: 0.8564 - Accuracy: 0.9999
✅ Validation Accuracy: 1.0000
✅ Model and label encoder saved.


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy._core.multiarray._reconstruct was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy._core.multiarray._reconstruct])` or the `torch.serialization.safe_globals([numpy._core.multiarray._reconstruct])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [5]:
def load_model_for_inference(model_path, label_encoder_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    label_encoder = torch.load(label_encoder_path, map_location=device, weights_only=False)
    label_names = label_encoder.classes_

    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    model = CLIPClassifier(clip_model, num_classes=len(label_names))
    model = nn.DataParallel(model)  # ✅ MPU-safe
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()

    return model, processor, label_names, device
run_inference("/Users/preetham_aleti/Desktop/DATASET/ROBBERY/video_03/frame_13.jpg")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


[/Users/preetham_aleti/Desktop/DATASET/ROBBERY/video_03/frame_13.jpg] ➜ ROBBERY (1.00)


(np.str_('ROBBERY'), 0.9993220567703247)

In [None]:
import os
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image

def load_clip_model(model_path):
    """
    Load CLIP model with advanced weight loading strategies.

    Args:
        model_path (str): Path to the model checkpoint file

    Returns:
        tuple: Loaded model and processor
    """
    # Initialize base CLIP model
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

    # Ensure the model path exists
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found at {model_path}")

    try:
        # Load state dict
        state_dict = torch.load(model_path, map_location="cpu")

        # Print out available keys for debugging
        print("Available keys in state dict:")
        print("\n".join(state_dict.keys()))

        # Advanced weight loading strategies
        try:
            # Strategy 1: Try loading as-is
            model.load_state_dict(state_dict, strict=False)
            print("✅ Loaded model weights successfully with loose matching.")
        except Exception as full_load_error:
            print(f"Full model load failed: {full_load_error}")

            # Strategy 2: Try loading vision model weights specifically
            try:
                # Extract vision model weights
                vision_dict = {
                    k.replace("vision_model.", ""): v
                    for k, v in state_dict.items()
                    if k.startswith("vision_model.")
                }

                # If no vision model keys found, try matching partial keys
                if not vision_dict:
                    vision_dict = {
                        k.split(".")[-1]: v
                        for k, v in state_dict.items()
                        if any(part in k for part in ['vision', 'embed', 'layer', 'norm'])
                    }

                # Partial loading with custom matching
                missing_keys = []
                unexpected_keys = []
                model.vision_model.load_state_dict(
                    vision_dict,
                    strict=False,
                    missing_keys=missing_keys,
                    unexpected_keys=unexpected_keys
                )

                print("⚠️ Loaded vision model weights with partial matching.")
                if missing_keys:
                    print("Missing keys:", missing_keys)
                if unexpected_keys:
                    print("Unexpected keys:", unexpected_keys)

            except Exception as vision_load_error:
                print(f"Vision model load failed: {vision_load_error}")
                raise RuntimeError("Unable to load model weights") from vision_load_error

        model.eval()
    except Exception as e:
        print(f"Critical error loading model weights: {e}")
        raise

    # Initialize the processor for CLIP
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    return model, processor

def classify_image(model, processor, image_path, text_labels, top_k=4):
    """
    Classify an image using CLIP model with given text labels.

    Args:
        model (CLIPModel): Loaded CLIP model
        processor (CLIPProcessor): CLIP processor
        image_path (str): Path to the input image
        text_labels (list): List of text labels to compare against
        top_k (int, optional): Number of top predictions to return. Defaults to 5.

    Returns:
        list: Top K predictions with labels and scores
    """
    # Validate inputs
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}")

    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")

    # Encode inputs
    inputs = processor(text=text_labels, images=image, return_tensors="pt", padding=True)

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Normalize embeddings & compute similarity
    image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
    text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
    similarities = (image_embeds @ text_embeds.T)[0]  # shape: [num_text_labels]

    # Get top-K predictions
    top_k_indices = similarities.topk(top_k).indices
    top_k_labels = [text_labels[i] for i in top_k_indices]
    top_k_scores = [similarities[i].item() for i in top_k_indices]

    return list(zip(top_k_labels, top_k_scores))

def main():
    # Configuration
    model_path = "clip_crime_classifier.pt"  # Path to fine-tuned model weights
    image_path = "/Users/preetham_aleti/Desktop/istockphoto-955124060-612x612.jpg"# Replace with actual image path

    # Text labels (same as original script)
    text_labels = [
    "knife", "gun", "blood", "weapon", "dead", "lying", "suspicious", "fight", "attack", "injury", "explosion", "fire",
    "pistol", "rifle", "body", "corpse", "skull", "molotov", "bullet shells", "syringe", "chains", "rope", "handcuffs",
    "bat", "metal rod", "glass shard", "drugs", "cash bundle", "suitcase", "torn clothes", "broken phone", "firearm",
    "blunt object", "axe", "machete", "hammer", "crowbar", "scalpel", "shiv", "box cutter", "brass knuckles",
    "stun gun", "pepper spray", "garrote", "baton", "metal pipe", "switchblade", "nail gun", "bullet", "shell casing",
    "spent cartridge", "gunpowder residue", "blood spatter", "brain matter", "hair strands", "fingerprint", "shoeprint",
    "footprint", "broken glass", "semen stain", "urine sample", "burn marks", "ligature marks", "defensive wounds",
    "duct tape", "zip tie", "gag", "blindfold", "chloroform bottle", "acid bottle", "gloves", "balaclava", "ski mask",
    "face mask", "surgical gloves", "latex gloves", "bloody cloth", "wet towel", "gasoline can", "fire debris",
    "ash residue", "charred remains", "scorch mark", "bullet hole", "burnt paper", "exploded device", "broken lock",
    "bent door", "scratched surface", "lockpick set", "ID card", "fake passport", "credit card", "stolen phone",
    "burner phone", "sim card", "CCTV footage", "notebook", "map with markings", "surveillance photo", "threat letter",
    "ransom note", "USB drive", "hard disk", "memory card", "wallet", "driver's license", "personal photo", "watch",
    "ring", "bloodied shoe", "muddy boots", "missing clothing", "ripped shirt", "torn dress", "evidence bag",
    "crime scene tape", "silencer", "tripwire", "booby trap", "grenade", "sniper rifle", "projectile", "smoke bomb",
    "carbon monoxide canister", "fire starter", "fuse wire", "circuit board", "detonator", "gas leak", "blood trail",
    "drag marks", "vomit stain", "fingerprint powder", "DNA swab", "bite mark", "nail scrapings", "glass particles",
    "paint chip", "soil sample", "fibers", "gunshot residue", "blood droplets", "taser", "wire cutters", "pliers",
    "torch", "headlamp", "ladder", "paracord", "spiked object", "whip", "screwdriver", "socket wrench", "drill",
    "nail", "saw", "shovel", "climbing gear", "ski cap", "work boots", "combat boots", "disguise", "makeup kit",
    "fake beard", "wig", "mirror fragment", "red-stained clothing", "soaked glove", "burnt shoes", "burnt wallet",
    "twine", "twisted metal", "credit card reader", "POS skimmer", "pin pad", "cloned card", "forged signature",
    "bank note", "fake cheque", "contract", "confession note", "manifest", "registry", "diary", "blackmail note",
    "spy camera", "drone", "bug detector", "tracking device", "microphone", "hidden recorder", "wiretap",
    "GPS tracker", "browser history", "deleted messages", "burned documents", "printer ink", "USB cable",
    "data cable", "footprint in blood", "smeared prints", "smashed camera", "punch hole", "broken furniture",
    "spray paint", "graffiti", "blood pool", "trail of tears", "hysterical victim", "witness statement",
    "torn notebook", "security tag", "alarm sensor", "motion detector", "entry log", "access card", "forged badge",
    "suspicious note", "powdery substance", "chemical residue", "gas mask", "hazmat suit", "burner laptop",
    "encrypted phone", "person running", "flames fire","Explosion aftermath","Rescue operation","Disaster response","Fire damage",
        "Explosion aftermath", "Fire damage", "Scattered debris", "Black smoke", "Firefighters at scene",
"Destroyed vehicles", "Collapsed buildings", "Rescue operation", "Emergency response units",
"Injured civilians", "Charred remains", "Flash burns", "Smoke inhalation victims", "Craters on ground",
"Damaged infrastructure", "Shattered glass", "Shockwave damage", "Burnt clothing", "Heat distortion",
"Bomb squad presence", "Evacuation in progress", "Military cordon", "Hazmat suits", "Surveillance footage of blast",
"Sound of explosion", "Witness reports", "Forensic chemical traces", "Explosive residue", "Satellite imagery of blast",
"Thermal imaging", "Before/after photos", "Seismograph spike", "Damaged electronics", "Overturned vehicles",
"Melted metals", "High-temperature indicators", "Blast pressure marks", "Security camera destruction",
"Media reporting live", "Flash of light", "Sudden power outages", "Flying debris injuries",
"Hospital surge", "Burn units activated", "Soot-covered victims", "Displaced people", "Airspace restrictions",
"Shockwave-injured animals", "Search and rescue dogs", "Structural engineers onsite",

"Physical altercation", "Punches thrown", "Visible bruises", "Blood stains", "Crowd gathering",
"Police breaking up fight", "Verbal aggression", "Weapons drawn", "Security camera footage", "Bodycam recordings",
"CCTV footage", "Broken bottles", "Pulled hair", "Torn clothing", "Screaming", "Knife wounds",
"Hospital reports", "Eyewitness testimony", "Self-defense stances", "Bystander videos", "Arrest records",
"911 calls", "Social media footage", "Conflicting statements", "Injury reports", "Physical evidence on ground",
"Trampled area", "Thrown objects", "Police batons", "Tasers used", "Pepper spray evidence", "Handcuff marks",
"Footage of chase", "Disturbed public area", "Fight started in queue", "Nightclub security footage",
"Bar fight scene", "Street brawl", "Schoolyard altercation", "Confiscated weapons", "Face injuries",
"Split lips", "Black eyes", "Smashed furniture", "Police statement", "Witness cell footage", "Brawl aftermath",
"Adrenaline effects", "Medical treatment logs", "Offensive gestures",

"Forced entry", "Broken locks", "Surveillance footage", "Mask-wearing suspect", "Glove prints",
"Empty safe", "Stolen valuables", "Threats with weapon", "Demand notes", "Security alarm triggered",
"Witness testimony", "Gunshot sound", "Panic button pressed", "Store clerks hiding", "Cash register emptied",
"Abandoned getaway vehicle", "Dropped loot", "Forensic evidence", "Glove fibers", "Security tape review",
"Bank teller account", "Vault tampering", "ATM ripped open", "Smashed glass", "Crowbar found",
"CCTV malfunction", "Phone jammers", "Multiple suspects", "Police pursuit", "Police sketches",
"Security guard overpowered", "Use of disguise", "Timing of heist", "Inside job suspicion",
"Drilled locks", "Camera blackout", "Customer witness", "Fake IDs", "Unusual deposits", "Tool marks",
"Fingerprint recovery", "Silent alarm logs", "DNA swab from scene", "Safe-cracking tools", "Entry point photos",
"Surrounding business footage", "Abandoned evidence", "Loot bag found", "Fake license plates",
"Escape route tracing", "Police barrier breach",

"Unpaid items in bag", "Security footage", "No receipt", "Concealed merchandise", "Price tag removal",
"Security tag tampering", "In-store alarms", "Store detective report", "Suspicious behavior", "Camera footage",
"Bag check failure", "Hidden compartments", "Items under clothing", "Unusual customer route", "Acting nervously",
"Rushed exit", "Cashier distraction", "Blind spots visited", "Multiple entry attempts", "Dressing room concealment",
"Magnetic detacher possession", "Recovered stolen goods", "Admission of guilt", "Inventory mismatch",
"Staff report", "Frequent offender", "Cellphone jammer", "Fake returns", "Returned stolen items",
"Store layout exploitation", "Distraction technique", "Witness confrontation", "Unattended carts",
"Companion as lookout", "Fake disability use", "Suspicious bag size", "Overcrowded checkout lane",
"Tampered packaging", "Evasive responses", "Rapid clothing change", "Discarded wrappers", "Matching descriptions",
"Unscanned items", "Detained by staff", "Portable scanner blocker", "Frequent item handling",
"Retail chain alert", "Repeat pattern detection", "Reluctance to open bag", "Disguised identity"
]


    try:
        # Load model and processor
        model, processor = load_clip_model(model_path)

        # Classify image
        predictions = classify_image(model, processor, image_path, text_labels)

        # Display results
        print("\n--- Top Predictions ---")
        for rank, (label, score) in enumerate(predictions, 1):
            print(f"Top {rank}: {label} (Score: {score:.3f})")

    except Exception as e:
        print(f"An error occurred: {e}")

if __name__ == "__main__":
    main()

In [7]:
import torch
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, BartForConditionalGeneration
import matplotlib.pyplot as plt
import os
import cv2
from typing import List, Tuple, Dict
import json
from tqdm import tqdm


class CLIPVideoInference:
    def __init__(self, model_path: str = "clip_crime_classifier.pt", device: str = None):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"Using device: {self.device}")
        if self.device.type == "cuda":
            print(f"GPU: {torch.cuda.get_device_name(0)}")

        print(f"Loading model from: {model_path}")
        try:
            self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

            checkpoint = torch.load(model_path, map_location=self.device)
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            elif 'state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint)

            self.model.to(self.device)
            self.model.eval()
            print("✅ Custom CLIP model loaded successfully!")
        except Exception as e:
            print(f"❌ Failed to load model: {e}")

    def load_captions_from_file(self, file_path: str) -> List[str]:
        with open(file_path, 'r') as f:
            captions = [line.strip() for line in f.readlines() if line.strip()]
        print(f"Loaded {len(captions)} captions")
        return captions

    def extract_frames_with_timestamps(self, video_path: str, num_frames: int = 80) -> Tuple[List[np.ndarray], List[float]]:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames / fps

        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        frames, timestamps = [], []

        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(rgb)
                timestamps.append(idx / fps)

        cap.release()
        print(f"Extracted {len(frames)} frames from {video_path}")
        return frames, timestamps

    def encode_frame(self, frame: np.ndarray) -> torch.Tensor:
        pil_image = Image.fromarray(frame)
        inputs = self.processor(images=[pil_image], return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
            features = features / features.norm(dim=-1, keepdim=True)
        return features

    def encode_texts(self, captions: List[str]) -> torch.Tensor:
        all_feats = []
        for i in range(0, len(captions), 32):
            batch = captions[i:i+32]
            inputs = self.processor(text=batch, return_tensors="pt", padding=True, truncation=True).to(self.device)
            with torch.no_grad():
                feats = self.model.get_text_features(**inputs)
                feats = feats / feats.norm(dim=-1, keepdim=True)
                all_feats.append(feats)
        return torch.cat(all_feats)

    def get_best_caption(self, frames: List[np.ndarray], captions: List[str]) -> Tuple[str, float]:
        frame_feats = [self.encode_frame(f) for f in frames]
        avg_feat = torch.mean(torch.stack(frame_feats), dim=0)
        text_feats = self.encode_texts(captions)
        sim = torch.matmul(avg_feat, text_feats.T).squeeze()
        best_idx = torch.argmax(sim).item()
        return captions[best_idx], sim[best_idx].item()

    def generate_video_story(self, video_path: str, captions: List[str], segments: int = 20, frames_per_segment: int = 4) -> Dict:
        total_frames = segments * frames_per_segment
        frames, timestamps = self.extract_frames_with_timestamps(video_path, total_frames)

        if len(frames) < total_frames:
            print("⚠️ Not enough frames, reducing number of segments")
            segments = len(frames) // frames_per_segment

        used = set()
        story = []

        for i in range(segments):
            start = i * frames_per_segment
            end = start + frames_per_segment
            seg_frames = frames[start:end]
            seg_start, seg_end = timestamps[start], timestamps[end-1]

            available = [cap for cap in captions if cap not in used] or captions
            caption, score = self.get_best_caption(seg_frames, available)
            used.add(caption)

            story.append({
                "sequence": i+1,
                "time_range": f"{seg_start:.1f}s - {seg_end:.1f}s",
                "caption": caption,
                "confidence": round(score, 4)
            })

            print(f"[{i+1}] {caption} ({score:.3f})")

        return {
            "video_path": video_path,
            "segments": story,
            "duration": timestamps[-1] if timestamps else 0
        }

    def display_story(self, result: Dict):
        print("\n" + "="*80)
        print(f"🎬 VIDEO STORY: {os.path.basename(result['video_path'])}")
        print("="*80)
        for item in result["segments"]:
            print(f"🕒 {item['time_range']} | 💬 {item['caption']} (confidence: {item['confidence']:.3f})")
        print("="*80 + "\n")

    def get_captions_list(self, result: Dict) -> List[str]:
        return [item['caption'] for item in result['segments']]


def summarize_captions(captions: List[str], device: torch.device = torch.device("cpu")) -> str:
    print("\n📚 Summarizing captions using BART...")
    model_name = "facebook/bart-large-cnn"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)

    text_input = " ".join(captions)
    inputs = tokenizer([text_input], max_length=1024, return_tensors='pt', truncation=True).to(device)

    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=300,
        min_length=100,
        length_penalty=2.0,
        num_beams=6,
        no_repeat_ngram_size=4,
        early_stopping=True,
        repetition_penalty=1.5,
        temperature=0.9,
        top_k=50,
        top_p=0.95
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary.strip()


# === Main Script ===
if __name__ == "__main__":
    video_path = "/Users/preetham_aleti/Downloads/Explosion/Explosion050_x264A/Explosion050_x264A.mp4"
    captions_file = "/Users/preetham_aleti/Desktop/captions.txt"
    model_path = "clip_crime_classifier1.pt"

    # Inference
    inferencer = CLIPVideoInference(model_path=model_path)
    available_captions = inferencer.load_captions_from_file(captions_file)
    result = inferencer.generate_video_story(video_path, available_captions)
    inferencer.display_story(result)

    # Summarization
    segment_captions = inferencer.get_captions_list(result)
    summary = summarize_captions(segment_captions)
    print("\n📌 Final Video Summary:\n" + summary)


Using device: cpu
Loading model from: clip_crime_classifier1.pt
❌ Failed to load model: Error(s) in loading state_dict for CLIPModel:
	Missing key(s) in state_dict: "logit_scale", "text_model.embeddings.token_embedding.weight", "text_model.embeddings.position_embedding.weight", "text_model.encoder.layers.0.self_attn.k_proj.weight", "text_model.encoder.layers.0.self_attn.k_proj.bias", "text_model.encoder.layers.0.self_attn.v_proj.weight", "text_model.encoder.layers.0.self_attn.v_proj.bias", "text_model.encoder.layers.0.self_attn.q_proj.weight", "text_model.encoder.layers.0.self_attn.q_proj.bias", "text_model.encoder.layers.0.self_attn.out_proj.weight", "text_model.encoder.layers.0.self_attn.out_proj.bias", "text_model.encoder.layers.0.layer_norm1.weight", "text_model.encoder.layers.0.layer_norm1.bias", "text_model.encoder.layers.0.mlp.fc1.weight", "text_model.encoder.layers.0.mlp.fc1.bias", "text_model.encoder.layers.0.mlp.fc2.weight", "text_model.encoder.layers.0.mlp.fc2.bias", "text_m

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



📌 Final Video Summary:
CCTV footage shows a street scene with walkers, bikers, an auto, and a woman in green saree balancing flowers on her head. A gas station is engulfed in fire with vehicles parked nearby and people rushing to safety. The dashcam continues to record the residential street in the immediate aftermath of a powerful explosion. A substantial amount of smoke and debris still fills the air, though it may be showing very slight signs of further dispersal. Large and small fragments are still visible within the smoke.


In [13]:
import os
import cv2
import shutil
from collections import Counter
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel

def extract_frames_from_video(video_path, frames_dir, frame_rate=1):
    """
    Extract frames from video at given frame_rate (frames per second).
    Saves extracted frames as .jpg files in frames_dir.
    """
    if os.path.exists(frames_dir):
        shutil.rmtree(frames_dir)  # Clean up old frames folder
    os.makedirs(frames_dir)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video file: {video_path}")

    fps = cap.get(cv2.CAP_PROP_FPS)
    interval = int(fps // frame_rate) if fps > 0 else 1
    count = 0
    saved_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if count % interval == 0:
            frame_path = os.path.join(frames_dir, f"frame_{saved_count:05d}.jpg")
            cv2.imwrite(frame_path, frame)
            saved_count += 1
        count += 1

    cap.release()
    print(f"Extracted {saved_count} frames to {frames_dir}")

def load_clip_model(model_path):
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Checkpoint not found: {model_path}")
    state_dict = torch.load(model_path, map_location="cpu")
    try:
        model.load_state_dict(state_dict, strict=False)
        print("✅ Model loaded.")
    except Exception as e:
        raise RuntimeError(f"Failed to load model weights: {e}")
    model.eval()
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor

def classify_image(model, processor, image_path, text_labels, top_k=5):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(text=text_labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
    text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
    similarities = (image_embeds @ text_embeds.T)[0]
    top_k_indices = similarities.topk(top_k).indices
    return [(text_labels[i], similarities[i].item()) for i in top_k_indices]

def run_inference_on_frames(frames_dir):
    # Use your existing crime and evidence labels here (shortened for example)
    model_path = "clip_crime_classifier1.pt"
    model, processor = load_clip_model(model_path)

    crime_labels = ["robbery", "explosion", "assault", "kidnapping", "arson", "vandalism", "shooting", "theft", "fight", "accident","normal action"]
    # Use your full evidence_labels list from your original code
    evidence_labels = [
    "knife", "gun", "blood", "weapon", "dead", "lying", "suspicious", "fight", "attack", "injury", "explosion", "fire",
    "pistol", "rifle", "body", "corpse", "skull", "molotov", "bullet shells", "syringe", "chains", "rope", "handcuffs",
    "bat", "metal rod", "glass shard", "drugs", "cash bundle", "suitcase", "torn clothes", "broken phone", "firearm",
    "blunt object", "axe", "machete", "hammer", "crowbar", "scalpel", "shiv", "box cutter", "brass knuckles",
    "stun gun", "pepper spray", "garrote", "baton", "metal pipe", "switchblade", "nail gun", "bullet", "shell casing",
    "spent cartridge", "gunpowder residue", "blood spatter", "brain matter", "hair strands", "fingerprint", "shoeprint",
    "footprint", "broken glass", "semen stain", "urine sample", "burn marks", "ligature marks", "defensive wounds",
    "duct tape", "zip tie", "gag", "blindfold", "chloroform bottle", "acid bottle", "gloves", "balaclava", "ski mask",
    "face mask", "surgical gloves", "latex gloves", "bloody cloth", "wet towel", "gasoline can", "fire debris",
    "ash residue", "charred remains", "scorch mark", "bullet hole", "burnt paper", "exploded device", "broken lock",
    "bent door", "scratched surface", "lockpick set", "ID card", "fake passport", "credit card", "stolen phone",
    "burner phone", "sim card", "CCTV footage", "notebook", "map with markings", "surveillance photo", "threat letter",
    "ransom note", "USB drive", "hard disk", "memory card", "wallet", "driver's license", "personal photo", "watch",
    "ring", "bloodied shoe", "muddy boots", "missing clothing", "ripped shirt", "torn dress", "evidence bag",
    "crime scene tape", "silencer", "tripwire", "booby trap", "grenade", "sniper rifle", "projectile", "smoke bomb",
    "carbon monoxide canister", "fire starter", "fuse wire", "circuit board", "detonator", "gas leak", "blood trail",
    "drag marks", "vomit stain", "fingerprint powder", "DNA swab", "bite mark", "nail scrapings", "glass particles",
    "paint chip", "soil sample", "fibers", "gunshot residue", "blood droplets", "taser", "wire cutters", "pliers",
    "torch", "headlamp", "ladder", "paracord", "spiked object", "whip", "screwdriver", "socket wrench", "drill",
    "nail", "saw", "shovel", "climbing gear", "ski cap", "work boots", "combat boots", "disguise", "makeup kit",
    "fake beard", "wig", "mirror fragment", "red-stained clothing", "soaked glove", "burnt shoes", "burnt wallet",
    "twine", "twisted metal", "credit card reader", "POS skimmer", "pin pad", "cloned card", "forged signature",
    "bank note", "fake cheque", "contract", "confession note", "manifest", "registry", "diary", "blackmail note",
    "spy camera", "drone", "bug detector", "tracking device", "microphone", "hidden recorder", "wiretap",
    "GPS tracker", "browser history", "deleted messages", "burned documents", "printer ink", "USB cable",
    "data cable", "footprint in blood", "smeared prints", "smashed camera", "punch hole", "broken furniture",
    "spray paint", "graffiti", "blood pool", "trail of tears", "hysterical victim", "witness statement",
    "torn notebook", "security tag", "alarm sensor", "motion detector", "entry log", "access card", "forged badge",
    "suspicious note", "powdery substance", "chemical residue", "gas mask", "hazmat suit", "burner laptop",
    "encrypted phone", "person running", "flames fire","Explosion aftermath","Rescue operation","Disaster response","Fire damage",
        "Explosion aftermath", "Fire damage", "Scattered debris", "Black smoke", "Firefighters at scene",
"Destroyed vehicles", "Collapsed buildings", "Rescue operation", "Emergency response units",
"Injured civilians", "Charred remains", "Flash burns", "Smoke inhalation victims", "Craters on ground",
"Damaged infrastructure", "Shattered glass", "Shockwave damage", "Burnt clothing", "Heat distortion",
"Bomb squad presence", "Evacuation in progress", "Military cordon", "Hazmat suits", "Surveillance footage of blast",
"Sound of explosion", "Witness reports", "Forensic chemical traces", "Explosive residue", "Satellite imagery of blast",
"Thermal imaging", "Before/after photos", "Seismograph spike", "Damaged electronics", "Overturned vehicles",
"Melted metals", "High-temperature indicators", "Blast pressure marks", "Security camera destruction",
"Media reporting live", "Flash of light", "Sudden power outages", "Flying debris injuries",
"Hospital surge", "Burn units activated", "Soot-covered victims", "Displaced people", "Airspace restrictions",
"Shockwave-injured animals", "Search and rescue dogs", "Structural engineers onsite",

"Physical altercation", "Punches thrown", "Visible bruises", "Blood stains", "Crowd gathering",
"Police breaking up fight", "Verbal aggression", "Weapons drawn", "Security camera footage", "Bodycam recordings",
"CCTV footage", "Broken bottles", "Pulled hair", "Torn clothing", "Screaming", "Knife wounds",
"Hospital reports", "Eyewitness testimony", "Self-defense stances", "Bystander videos", "Arrest records",
"911 calls", "Social media footage", "Conflicting statements", "Injury reports", "Physical evidence on ground",
"Trampled area", "Thrown objects", "Police batons", "Tasers used", "Pepper spray evidence", "Handcuff marks",
"Footage of chase", "Disturbed public area", "Fight started in queue", "Nightclub security footage",
"Bar fight scene", "Street brawl", "Schoolyard altercation", "Confiscated weapons", "Face injuries",
"Split lips", "Black eyes", "Smashed furniture", "Police statement", "Witness cell footage", "Brawl aftermath",
"Adrenaline effects", "Medical treatment logs", "Offensive gestures",

"Forced entry", "Broken locks", "Surveillance footage", "Mask-wearing suspect", "Glove prints",
"Empty safe", "Stolen valuables", "Threats with weapon", "Demand notes", "Security alarm triggered",
"Witness testimony", "Gunshot sound", "Panic button pressed", "Store clerks hiding", "Cash register emptied",
"Abandoned getaway vehicle", "Dropped loot", "Forensic evidence", "Glove fibers", "Security tape review",
"Bank teller account", "Vault tampering", "ATM ripped open", "Smashed glass", "Crowbar found",
"CCTV malfunction", "Phone jammers", "Multiple suspects", "Police pursuit", "Police sketches",
"Security guard overpowered", "Use of disguise", "Timing of heist", "Inside job suspicion",
"Drilled locks", "Camera blackout", "Customer witness", "Fake IDs", "Unusual deposits", "Tool marks",
"Fingerprint recovery", "Silent alarm logs", "DNA swab from scene", "Safe-cracking tools", "Entry point photos",
"Surrounding business footage", "Abandoned evidence", "Loot bag found", "Fake license plates",
"Escape route tracing", "Police barrier breach",

"Unpaid items in bag", "Security footage", "No receipt", "Concealed merchandise", "Price tag removal",
"Security tag tampering", "In-store alarms", "Store detective report", "Suspicious behavior", "Camera footage",
"Bag check failure", "Hidden compartments", "Items under clothing", "Unusual customer route", "Acting nervously",
"Rushed exit", "Cashier distraction", "Blind spots visited", "Multiple entry attempts", "Dressing room concealment",
"Magnetic detacher possession", "Recovered stolen goods", "Admission of guilt", "Inventory mismatch",
"Staff report", "Frequent offender", "Cellphone jammer", "Fake returns", "Returned stolen items",
"Store layout exploitation", "Distraction technique", "Witness confrontation", "Unattended carts",
"Companion as lookout", "Fake disability use", "Suspicious bag size", "Overcrowded checkout lane",
"Tampered packaging", "Evasive responses", "Rapid clothing change", "Discarded wrappers", "Matching descriptions",
"Unscanned items", "Detained by staff", "Portable scanner blocker", "Frequent item handling",
"Retail chain alert", "Repeat pattern detection", "Reluctance to open bag", "Disguised identity"
] # truncated for example

    frame_paths = sorted([os.path.join(frames_dir, f) for f in os.listdir(frames_dir) if f.endswith(".jpg")])
    crime_votes = []
    evidence_dict = Counter()

    print(f"🔍 Processing {len(frame_paths)} frames from: {frames_dir}")

    for frame in frame_paths:
        crime_preds = classify_image(model, processor, frame, crime_labels, top_k=1)
        top_crime = crime_preds[0][0]
        crime_votes.append(top_crime)

        evidence_preds = classify_image(model, processor, frame, evidence_labels, top_k=3)
        for label, _ in evidence_preds:
            evidence_dict[label] += 1

    most_common_crime = Counter(crime_votes).most_common(1)[0][0]
    top_evidence = evidence_dict.most_common(10)

    print("\n📌 Final Crime Classification:", most_common_crime.upper())
    print("🧾 Top Evidence Found:")
    for label, count in top_evidence:
        print(f"  - {label} ({count} frames)")

    return {
        "crime_type": most_common_crime,
        "top_evidence": top_evidence
    }

def run_inference(video_path):
    frames_dir = "/tmp/video_frames"  # Temporary folder for frames
    extract_frames_from_video(video_path, frames_dir, frame_rate=1)  # 1 FPS extraction
    return run_inference_on_frames(frames_dir)

# Run inference on your video:
summary = run_inference("/Users/preetham_aleti/Desktop/IMG_8679.MOV")
print(summary)


Extracted 3 frames to /tmp/video_frames
✅ Model loaded.
🔍 Processing 3 frames from: /tmp/video_frames

📌 Final Crime Classification: NORMAL ACTION
🧾 Top Evidence Found:
  - personal photo (1 frames)
  - burner laptop (1 frames)
  - Portable scanner blocker (1 frames)
  - Loot bag found (1 frames)
  - Unscanned items (1 frames)
  - Abandoned evidence (1 frames)
  - Witness cell footage (1 frames)
  - Bodycam recordings (1 frames)
  - Security camera footage (1 frames)
{'crime_type': 'normal action', 'top_evidence': [('personal photo', 1), ('burner laptop', 1), ('Portable scanner blocker', 1), ('Loot bag found', 1), ('Unscanned items', 1), ('Abandoned evidence', 1), ('Witness cell footage', 1), ('Bodycam recordings', 1), ('Security camera footage', 1)]}
