In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import av
import numpy as np
import torchvision.transforms as T
from transformers import VivitImageProcessor, VivitForVideoClassification
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Configuration: sampling constants, checkpoint location and class names
# ---------------------------------------------------------------------------
CLIP_LEN = 32           # default number of frames sampled from each video
FRAME_SAMPLE_RATE = 1   # default stride used when building frame indices
SAVED_MODEL_PATH = 'F:/SRC_Bhuvaneswari/typpo/Crimenet/VisTra/Checkpoints/v1.0/best_model_acc.pt'

# Class id -> human-readable label; its length sets the classifier head size.
LABEL_MAP = {
    0: 'Normal',
    1: 'Explosion',
    2: 'Fighting',
    3: 'Car Accident',
    4: 'Shooting',
    5: 'Riot',
}

# Hub identifier shared by the image processor and the backbone.
VIVIT_CHECKPOINT = "google/vivit-b-16x2"

# Image processor for the ViViT checkpoint.
# NOTE(review): do_rescale and offset are explicitly overridden with None here;
# frames are already scaled by T.ToTensor() before reaching the processor --
# confirm this matches the preprocessing used during fine-tuning.
processor = VivitImageProcessor.from_pretrained(
    VIVIT_CHECKPOINT,
    do_rescale=None,
    offset=None,
)

# Backbone with a classification head sized to LABEL_MAP. The pretrained head
# has a different shape, so mismatched sizes are ignored and the head is
# freshly initialised (see the warning printed below this cell).
model = VivitForVideoClassification.from_pretrained(
    VIVIT_CHECKPOINT,
    num_labels=len(LABEL_MAP),
    ignore_mismatched_sizes=True,
)


Some weights of VivitForVideoClassification were not initialized from the model checkpoint at google/vivit-b-16x2 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([6, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([6]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
# Set device first so the checkpoint can be mapped onto it while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load saved weights.
# map_location lets a checkpoint that was saved from a CUDA device be
# deserialised on a CPU-only machine instead of failing inside torch.load.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch
# version supports it).
model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=device))
model.to(device)

# Inference mode: disables dropout and other train-only behaviour.
model.eval()

# Function to extract frames from video
def _select_frame_indices(total_frames, num_frames, sample_rate):
    """Return the ordered (possibly repeating) frame indices to sample."""
    if total_frames >= num_frames * sample_rate:
        # Enough material: spread the samples uniformly over the whole video.
        picked = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    else:
        # Video is too short: wrap around so that frames are reused.
        picked = np.arange(0, num_frames * sample_rate, sample_rate) % total_frames
    return [int(i) for i in picked]


def extract_frames_from_video(video_path, num_frames=CLIP_LEN, sample_rate=FRAME_SAMPLE_RATE):
    """Sample frames from a video file as PIL images.

    When the video holds at least ``num_frames * sample_rate`` frames, the
    samples are spread uniformly across it. Otherwise indices wrap around the
    end of the video, so frames are repeated until ``num_frames`` are
    collected.

    Args:
        video_path: Path (or anything ``av.open`` accepts) of the video.
        num_frames: Number of frames to return.
        sample_rate: Stride between indices; also used to decide whether the
            video is long enough for uniform sampling.

    Returns:
        A list of PIL images in sampling order, or ``None`` if the video could
        not be opened or decoded (the error is printed). The list can be
        shorter than ``num_frames`` if the stream ends before a requested
        index is reached.
    """
    try:
        # The context manager closes the container even if decoding fails.
        with av.open(video_path) as container:
            total_frames = container.streams.video[0].frames
            if total_frames <= 0:
                # Frame count is not recorded in the container metadata:
                # count by decoding once, then rewind for the sampling pass.
                total_frames = sum(1 for _ in container.decode(video=0))
                container.seek(0)
            if total_frames <= 0:
                raise ValueError("video stream contains no decodable frames")

            indices = _select_frame_indices(total_frames, num_frames, sample_rate)
            if not indices:
                return []

            # Decode each needed frame once; a set gives O(1) membership tests.
            wanted = set(indices)
            last_wanted = max(wanted)
            decoded = {}
            for i, frame in enumerate(container.decode(video=0)):
                if i in wanted:
                    decoded[i] = frame.to_image()
                if i >= last_wanted:
                    break
    except Exception as e:
        print(f"Error extracting frames from {video_path}: {e}")
        return None

    # Rebuild the clip in index order; repeated indices reuse the same image,
    # which is what makes looping over a short video actually work.
    return [decoded[i] for i in indices if i in decoded]

In [3]:
# Function to preprocess frames and make predictions
def predict_video_class(video_path):
    """Classify a single video and return its predicted label.

    Args:
        video_path: Path of the video file to classify.

    Returns:
        The ``LABEL_MAP`` name of the highest-scoring class, or the string
        ``"Error: Insufficient frames"`` when fewer than ``CLIP_LEN`` frames
        could be extracted.
    """
    # Extract frames
    frames = extract_frames_from_video(video_path)

    if frames is None or len(frames) < CLIP_LEN:
        print(f"Could not extract enough frames from {video_path}")
        return "Error: Insufficient frames"

    # Per-frame spatial preprocessing; ToTensor yields CHW float tensors.
    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor()
    ])

    # The processor is fed HWC numpy arrays, so transform and convert in a
    # single pass (no intermediate stacked tensor is needed).
    frames_numpy = [transform(frame).permute(1, 2, 0).numpy() for frame in frames]

    # Process with ViVit processor
    inputs = processor(frames_numpy, return_tensors="pt")

    # Move to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Perform inference without tracking gradients
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_id = torch.argmax(logits, dim=-1).item()
    return LABEL_MAP[predicted_id]

In [4]:
# Sample usage: classify one test video and print the predicted label.
# NOTE(review): the path below is absolute and machine-specific -- adjust it
# before running this cell elsewhere.
video_file_path = "E:/SRC-Bhuvaneswari/VAD_XDViolence/ViVi/Dataset/XD Violence/Test/Brick.Mansions.2014__#00-16-26_00-17-12_label_B1-0-0.mp4"
# predict_video_class returns a LABEL_MAP name, or the string
# "Error: Insufficient frames" when too few frames could be extracted.
predicted_class = predict_video_class(video_file_path)
print(f"Predicted video class: {predicted_class}")

Predicted video class: Fighting
