In [1]:
import os
import sys
import torch
import av
import cv2
import numpy as np
import time
from PIL import Image
from collections import deque
from threading import Thread, Lock
import torchvision.transforms as T
from transformers import VivitImageProcessor, VivitForVideoClassification
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, 
                            QHBoxLayout, QLabel, QPushButton, QSlider, QStyle, 
                            QFileDialog, QFrame)
from PyQt5.QtCore import Qt, QTimer, pyqtSignal, QThread, QMutex
from PyQt5.QtGui import QImage, QPixmap, QPainter, QColor, QPen, QBrush

# Paths and constants
# Checkpoint whose state dict is loaded into the len(LABEL_MAP)-class model below.
# Override it with the CRIMENET_MODEL_PATH environment variable instead of editing
# the notebook; the original absolute path is kept as the default.
SAVED_MODEL_PATH = os.environ.get(
    'CRIMENET_MODEL_PATH',
    'F:/SRC_Bhuvaneswari/typpo/Crimenet/VisTra/Checkpoints/flow/best_model_acc.pt',
)
LABEL_MAP = {0: 'Normal', 1: 'Explosion', 2: 'Fighting', 3: 'Car Accident', 4: 'Shooting', 5: 'Riot'}
COLOR_MAP = {
    'Normal': QColor(50, 200, 50),     # Green
    'Explosion': QColor(255, 127, 0),  # Orange
    'Fighting': QColor(200, 50, 50),   # Red
    'Car Accident': QColor(50, 50, 200), # Blue
    'Shooting': QColor(200, 0, 200),   # Purple
    'Riot': QColor(255, 255, 0)        # Yellow
}
CLIP_LEN = 32  # Number of frames per clip passed to the model
WINDOW_SIZE_SECONDS = 3  # Process 3 seconds of video at a time

# Pick the device first so the checkpoint can be mapped straight onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and processor
print("Loading ViVit model and processor...")
# NOTE(review): do_rescale/offset are overridden to None, presumably because
# T.ToTensor() below already scales pixels to [0, 1]; confirm this matches the
# preprocessing used when the checkpoint was trained.
processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2", do_rescale=None, offset=None)
# The pretrained 400-way classifier head is replaced by a len(LABEL_MAP)-way head
# (hence the "newly initialized" warning); its weights come from the checkpoint below.
model = VivitForVideoClassification.from_pretrained(
    "google/vivit-b-16x2",
    num_labels=len(LABEL_MAP),
    ignore_mismatched_sizes=True
)

# Load saved weights. map_location lets a checkpoint written from a GPU run load
# on a CPU-only machine. torch.load unpickles the file: only load trusted checkpoints.
state_dict = torch.load(SAVED_MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
print(f"Model loaded on {device}")

# Transform for preprocessing frames: resize the shorter side to 256, center-crop
# 224x224, and convert the PIL image to a CHW float tensor in [0, 1].
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor()
])

Loading ViVit model and processor...


Some weights of VivitForVideoClassification were not initialized from the model checkpoint at google/vivit-b-16x2 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([6, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([6]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on cuda


In [2]:
# Class for prediction worker thread
class PredictionWorker(QThread):
    """Background thread that classifies queued video clips.

    The GUI thread enqueues ``(frames, start_time, end_time)`` jobs through
    :meth:`add_frames` (``frames`` is a list of PIL images, times are seconds).
    :meth:`run` pops jobs one at a time, fits each clip to CLIP_LEN frames, runs
    the model and emits ``predictionReady(start_time, end_time, label)`` where
    ``label`` is a value of LABEL_MAP.
    """
    predictionReady = pyqtSignal(float, float, str)

    # Seconds to sleep when there is no work, to avoid busy-waiting.
    IDLE_SLEEP_SECONDS = 0.1

    def __init__(self, model, processor, transform):
        super().__init__()
        self.model = model
        self.processor = processor
        self.transform = transform
        self.frames_queue = deque()   # pending (frames, start_time, end_time) jobs
        self.timestamps = []          # currently unused; kept for compatibility
        self.running = True
        self.mutex = QMutex()         # guards frames_queue

    def add_frames(self, frames, start_time, end_time):
        """Queue a clip covering ``[start_time, end_time]`` seconds for classification."""
        self.mutex.lock()
        try:
            self.frames_queue.append((frames, start_time, end_time))
        finally:
            self.mutex.unlock()

    def stop(self):
        """Ask the run loop to exit after its current iteration."""
        self.running = False

    def _pop_job(self):
        """Atomically pop the oldest queued job; return None when the queue is empty."""
        self.mutex.lock()
        try:
            return self.frames_queue.popleft() if self.frames_queue else None
        finally:
            self.mutex.unlock()

    @staticmethod
    def _fit_clip_length(frames):
        """Return exactly CLIP_LEN frames, or None for an empty clip.

        Short clips are padded by repeating the last frame; long clips are
        uniformly subsampled (the original passed them through unchanged).
        """
        if not frames:
            return None
        if len(frames) < CLIP_LEN:
            return list(frames) + [frames[-1]] * (CLIP_LEN - len(frames))
        if len(frames) > CLIP_LEN:
            indices = np.linspace(0, len(frames) - 1, CLIP_LEN, dtype=int)
            return [frames[i] for i in indices]
        return list(frames)

    def _predict(self, frames):
        """Run the model on one clip and return the predicted LABEL_MAP label."""
        # Preprocess frames: transform yields CHW tensors; the processor takes HWC arrays.
        processed_frames = [self.transform(frame) for frame in frames]
        frames_numpy = [frame.permute(1, 2, 0).numpy() for frame in processed_frames]

        # Process with ViVit processor; send inputs to wherever the model lives
        # rather than relying on a module-level device variable.
        inputs = self.processor(frames_numpy, return_tensors="pt")
        model_device = next(self.model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
        predicted_id = torch.argmax(logits, dim=-1).item()
        return LABEL_MAP[predicted_id]

    def run(self):
        """Consume jobs until :meth:`stop` is called, emitting one prediction per job."""
        while self.running:
            job = self._pop_job()
            if job is None:
                # Only sleep when idle, so a backlog is drained without extra latency.
                time.sleep(self.IDLE_SLEEP_SECONDS)
                continue

            frames, start_time, end_time = job
            clip = self._fit_clip_length(frames)
            if clip is None:
                continue

            try:
                predicted_class = self._predict(clip)
            except Exception as e:
                # Best-effort: a failed clip is reported and skipped, the thread keeps running.
                print(f"Error making prediction: {e}")
                continue

            # Emit the prediction result
            self.predictionReady.emit(start_time, end_time, predicted_class)

In [3]:
# Main application window
class VideoPlayerWindow(QMainWindow):
    """Video player that classifies the video it is playing, window by window.

    Frames are read with OpenCV on a QTimer and shown in a QLabel. They are also
    grouped into windows of WINDOW_SIZE_SECONDS; each finished window is reduced
    to at most CLIP_LEN frames and queued on a background PredictionWorker. Its
    ``predictionReady`` results are stored in ``anomaly_segments`` as
    ``(start_time, end_time, label)`` tuples and drawn on a TimelineWidget.
    """

    # Used when OpenCV reports a non-positive (or NaN) FPS, which would otherwise
    # cause a division by zero in every time computation.
    DEFAULT_FPS = 30.0

    def __init__(self):
        super().__init__()

        # Setup UI
        self.setWindowTitle("Anomaly Detection Video Player")
        self.setGeometry(100, 100, 1000, 600)

        # Create central widget and layout
        # NOTE(review): the ``self.layout`` attribute shadows QWidget.layout(); kept for compatibility.
        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)
        self.layout = QVBoxLayout(self.central_widget)

        # Video display area
        self.video_frame = QLabel()
        self.video_frame.setAlignment(Qt.AlignCenter)
        self.video_frame.setMinimumSize(640, 480)
        self.video_frame.setStyleSheet("background-color: black;")
        self.layout.addWidget(self.video_frame)

        # Current classification display
        self.classification_label = QLabel("Classification: None")
        self.classification_label.setAlignment(Qt.AlignCenter)
        self.classification_label.setStyleSheet("font-size: 18px; font-weight: bold;")
        self.layout.addWidget(self.classification_label)

        # Timeline widget
        self.timeline_widget = TimelineWidget()
        self.layout.addWidget(self.timeline_widget)

        # Controls layout
        self.controls_layout = QHBoxLayout()

        # Play/Pause button
        self.play_button = QPushButton()
        self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPlay))
        self.play_button.clicked.connect(self.toggle_play)
        self.controls_layout.addWidget(self.play_button)

        # Time display
        self.time_label = QLabel("00:00 / 00:00")
        self.controls_layout.addWidget(self.time_label)

        # Position slider (sliderMoved fires only on user drags, not on setValue)
        self.position_slider = QSlider(Qt.Horizontal)
        self.position_slider.sliderMoved.connect(self.set_position)
        self.controls_layout.addWidget(self.position_slider)

        # Open file button
        self.open_button = QPushButton("Open Video")
        self.open_button.clicked.connect(self.open_file)
        self.controls_layout.addWidget(self.open_button)

        self.layout.addLayout(self.controls_layout)

        # Video processing variables
        self.cap = None
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.update_frame)
        self.current_frame = 0  # index of the frame the next cap.read() returns
        self.fps = 0
        self.total_frames = 0
        self.playing = False

        # Anomaly detection variables
        self.anomaly_segments = []
        self.current_window_frames = []
        self.current_window_start_time = 0
        self.frames_since_last_prediction = 0  # frames seen in the current window
        self.prediction_interval = 0  # Will be set based on video fps
        # In-window frame indices worth keeping (None = keep all); set per video.
        self._clip_indices = None

        # Start prediction thread
        self.prediction_worker = PredictionWorker(model, processor, transform)
        self.prediction_worker.predictionReady.connect(self.update_anomaly)
        self.prediction_worker.start()

    def open_file(self):
        """Ask the user for a video file and load it."""
        file_path, _ = QFileDialog.getOpenFileName(self, "Open Video File", "",
                                                 "Video Files (*.mp4 *.avi *.mkv *.mov);;All Files (*)")
        if file_path:
            self.load_video(file_path)

    def load_video(self, file_path):
        """Open ``file_path``, reset playback and detection state, and show its first frame."""
        # Stop current video if playing, and make the button reflect that.
        if self.playing:
            self.timer.stop()
            self.playing = False
        self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPlay))

        # Release previous capture if any
        if self.cap is not None:
            self.cap.release()
            self.cap = None

        # Open new video; on failure self.cap stays None so the controls are inert.
        cap = cv2.VideoCapture(file_path)
        if not cap.isOpened():
            print(f"Error: Could not open video {file_path}")
            cap.release()
            return
        self.cap = cap

        # Get video properties
        reported_fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.fps = reported_fps if reported_fps > 0 else self.DEFAULT_FPS
        self.total_frames = max(0, int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)))
        self.current_frame = 0
        duration = self.total_frames / self.fps

        # Set prediction interval (frames per window) based on fps and window size
        self.prediction_interval = max(1, int(self.fps * WINDOW_SIZE_SECONDS))
        # Only frames the CLIP_LEN subsample will pick need to be held in memory;
        # these are the same indices the flush would select from a full window.
        if self.prediction_interval > CLIP_LEN:
            self._clip_indices = set(
                np.linspace(0, self.prediction_interval - 1, CLIP_LEN, dtype=int).tolist())
        else:
            self._clip_indices = None

        # Clear previous anomalies, including clips still queued for the old video
        self.anomaly_segments = []
        self._discard_pending_predictions()
        self.timeline_widget.set_anomalies([])
        self.timeline_widget.set_duration(duration)

        # Reset UI
        self.position_slider.setRange(0, max(0, self.total_frames - 1))
        self.position_slider.setValue(0)
        self.time_label.setText(f"00:00 / {self.format_time(duration)}")
        self.classification_label.setText("Classification: None")

        # Reset frame collection
        self._reset_prediction_window()

        # Show first frame, then rewind so update_frame starts at frame 0 again;
        # otherwise every index/timestamp is off by one and frame 0 is never classified.
        ret, frame = self.cap.read()
        if ret:
            self.display_frame(frame)
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

    def toggle_play(self):
        """Start or pause playback of the loaded video."""
        if self.cap is None:
            return

        if self.playing:
            self.timer.stop()
            self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPlay))
        else:
            # Tick at the video's own frame rate so playback runs at real speed.
            self.timer.start(max(1, int(round(1000 / self.fps))))
            self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPause))

        self.playing = not self.playing

    def update_frame(self):
        """Timer slot: show the next frame, feed it to the current window, refresh labels."""
        if self.cap is None or not self.playing:
            return

        ret, frame = self.cap.read()
        if not ret:
            # End of video: classify the partial last window too, then stop.
            self._flush_prediction_window((self.current_frame - 1) / self.fps)
            self.timer.stop()
            self.playing = False
            self.play_button.setIcon(self.style().standardIcon(QStyle.SP_MediaPlay))
            return

        # Display the frame
        self.display_frame(frame)

        current_time = self.current_frame / self.fps
        window_index = self.frames_since_last_prediction

        # First frame of a window sets its start time
        if window_index == 0:
            self.current_window_start_time = current_time

        # Keep (as an RGB PIL image) only frames the clip will actually use
        if self._clip_indices is None or window_index in self._clip_indices:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self.current_window_frames.append(Image.fromarray(frame_rgb))
        self.frames_since_last_prediction += 1

        # If the window is complete, hand it to the prediction worker
        if self.frames_since_last_prediction >= self.prediction_interval:
            self._flush_prediction_window(current_time)

        # Update slider, time and classification labels
        self.position_slider.setValue(self.current_frame)
        self._update_status(current_time)

        # Increment frame counter
        self.current_frame += 1

    def display_frame(self, frame):
        """Show a BGR OpenCV frame in the video label, scaled to fit with aspect ratio kept."""
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        h, w, ch = frame_rgb.shape
        bytes_per_line = ch * w
        # QImage wraps frame_rgb's buffer; QPixmap.fromImage copies it before frame_rgb goes away.
        q_img = QImage(frame_rgb.data, w, h, bytes_per_line, QImage.Format_RGB888)
        self.video_frame.setPixmap(QPixmap.fromImage(q_img).scaled(
            self.video_frame.width(), self.video_frame.height(),
            Qt.KeepAspectRatio, Qt.SmoothTransformation))

    def set_position(self, position):
        """Slider slot: seek to frame ``position`` and show it."""
        if self.cap is None:
            return

        # Seek to position
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, position)
        # Frames collected before the seek are not contiguous with later ones.
        self._reset_prediction_window()

        ret, frame = self.cap.read()
        if ret:
            self.display_frame(frame)
            self._update_status(position / self.fps)
            # The read consumed frame `position`; the next update_frame reads position + 1.
            self.current_frame = position + 1
        else:
            self.current_frame = position

    def update_anomaly(self, start_time, end_time, anomaly_type):
        """Slot for PredictionWorker.predictionReady: record and draw a new segment."""
        self.anomaly_segments.append((start_time, end_time, anomaly_type))
        self.timeline_widget.set_anomalies(self.anomaly_segments)

    def get_classification_at_time(self, time_point):
        """Return the label to show at ``time_point`` seconds, or "None".

        A segment covering ``time_point`` wins (the most recently added one if
        several overlap). Otherwise fall back to the segment that ended most
        recently, at most WINDOW_SIZE_SECONDS before ``time_point``: during live
        playback the window holding the current frame is still being collected,
        so without this fallback the label would read "None" throughout the
        first play-through.
        """
        fallback_end, fallback_label = None, "None"
        for start_time, end_time, anomaly_type in reversed(self.anomaly_segments):
            if start_time <= time_point <= end_time:
                return anomaly_type
            recently_ended = end_time < time_point <= end_time + WINDOW_SIZE_SECONDS
            if recently_ended and (fallback_end is None or end_time > fallback_end):
                fallback_end, fallback_label = end_time, anomaly_type
        return fallback_label

    def format_time(self, seconds):
        """Format seconds as MM:SS, or HH:MM:SS when at least one hour."""
        minutes, seconds = divmod(int(seconds), 60)
        hours, minutes = divmod(minutes, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}" if hours else f"{minutes:02d}:{seconds:02d}"

    def closeEvent(self, event):
        """Stop playback, release the capture and shut down the prediction thread."""
        self.timer.stop()
        if self.cap is not None:
            self.cap.release()
            self.cap = None
        self.prediction_worker.stop()
        self.prediction_worker.wait()
        event.accept()

    def _update_status(self, current_time):
        """Refresh the time label and the (colored) classification label for ``current_time``."""
        duration_str = self.format_time(self.total_frames / self.fps)
        self.time_label.setText(f"{self.format_time(current_time)} / {duration_str}")

        current_class = self.get_classification_at_time(current_time)
        self.classification_label.setText(f"Classification: {current_class}")
        color = COLOR_MAP.get(current_class, QColor(100, 100, 100))
        self.classification_label.setStyleSheet(f"font-size: 18px; font-weight: bold; color: rgb({color.red()}, {color.green()}, {color.blue()})")

    def _reset_prediction_window(self):
        """Discard the frames collected so far and start a new, empty window."""
        self.current_window_frames = []
        self.frames_since_last_prediction = 0
        self.current_window_start_time = 0

    def _flush_prediction_window(self, end_time):
        """Queue the current window (subsampled to CLIP_LEN) for prediction, then reset it."""
        frames = self.current_window_frames
        if frames:
            if len(frames) > CLIP_LEN:
                indices = np.linspace(0, len(frames) - 1, CLIP_LEN, dtype=int)
                frames = [frames[i] for i in indices]
            self.prediction_worker.add_frames(frames, self.current_window_start_time, end_time)
        self.current_window_frames = []
        self.frames_since_last_prediction = 0

    def _discard_pending_predictions(self):
        """Drop clips still queued for a previous video so they are not attributed to the new one.

        A clip the worker is already classifying may still report back.
        """
        worker = self.prediction_worker
        worker.mutex.lock()
        try:
            worker.frames_queue.clear()
        finally:
            worker.mutex.unlock()

In [4]:
# Timeline widget to show anomaly segments
class TimelineWidget(QFrame):
    """Horizontal strip plotting classified segments across a video's duration.

    ``set_duration`` fixes the time axis (seconds). ``set_anomalies`` supplies
    ``(start_time, end_time, label)`` tuples; each is drawn as a block colored
    by COLOR_MAP (grey for unknown labels), with the label text when wide enough.
    """

    def __init__(self):
        super().__init__()
        self.setMinimumHeight(40)
        self.setStyleSheet("background-color: #2d2d2d;")

        self.anomalies = []
        self.duration = 0

    def set_anomalies(self, anomalies):
        """Replace the segments to draw and schedule a repaint."""
        self.anomalies = anomalies
        self.update()

    def set_duration(self, duration):
        """Set the time-axis length in seconds and schedule a repaint."""
        self.duration = duration
        self.update()

    def _draw_time_markers(self, painter, width, mid_y):
        """Draw 11 evenly spaced ticks (0%..100% of duration) with MM:SS captions."""
        painter.setPen(QPen(QColor(150, 150, 150), 1))
        step = width / 10
        for tick in range(11):
            x = int(tick * step)
            painter.drawLine(x, mid_y - 5, x, mid_y + 5)

            seconds_at_tick = (tick / 10) * self.duration
            caption = f"{int(seconds_at_tick / 60):02d}:{int(seconds_at_tick % 60):02d}"
            painter.drawText(x - 15, mid_y + 20, caption)

    def paintEvent(self, event):
        # Nothing to draw until a video with a positive duration is set.
        if self.duration <= 0:
            return

        width = self.width()
        height = self.height()
        mid_y = height // 2

        painter = QPainter(self)
        painter.setRenderHint(QPainter.Antialiasing)

        # Background, then the horizontal axis line
        painter.fillRect(self.rect(), QBrush(QColor(45, 45, 45)))
        painter.setPen(QPen(QColor(200, 200, 200), 1))
        painter.drawLine(0, mid_y, width, mid_y)

        self._draw_time_markers(painter, width, mid_y)

        # Segments: skip those starting past the end, clip the rest to the duration
        for start_time, end_time, anomaly_type in self.anomalies:
            if start_time >= self.duration:
                continue

            left = int(start_time / self.duration * width)
            right = int(min(end_time, self.duration) / self.duration * width)
            span = right - left

            fill = COLOR_MAP.get(anomaly_type, QColor(100, 100, 100))
            painter.fillRect(left, 5, span, height - 10, QBrush(fill))

            # Only caption segments wide enough to hold some text
            if span > 50:
                painter.setPen(QPen(QColor(255, 255, 255), 1))
                painter.drawText(left + 5, mid_y + 5, anomaly_type)

In [None]:
# Main application
if __name__ == "__main__":
    # Qt allows a single QApplication per process; reuse it so re-running this
    # cell in the same kernel does not fail.
    app = QApplication.instance() or QApplication(sys.argv)
    window = VideoPlayerWindow()
    window.show()
    exit_code = app.exec_()
    # Under Jupyter, sys.exit would raise SystemExit into the kernel; exit only as a script.
    if "ipykernel" not in sys.modules:
        sys.exit(exit_code)