# In [None]:
import os
import tempfile
import threading

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO

class PPEDetector:
    """Run a YOLO PPE model over a video and write an annotated copy.

    Each detection is drawn as a coloured bounding box with a
    ``"<class name> <confidence>"`` label above it.
    """

    # Class-id -> display name.
    # NOTE(review): assumed to match the label order the model was trained
    # with -- confirm against ``self.model.names``.
    CLASS_NAMES = {
        0: 'Hardhat',
        1: 'Mask',
        2: 'NO-Hardhat',
        3: 'NO-Mask',
        4: 'NO-Safety Vest',
        5: 'Person',
        6: 'Safety Cone',
        7: 'Safety Vest',
        8: 'machinery',
        9: 'vehicle',
    }

    # Box colours per class name, in OpenCV's BGR channel order.
    COLORS = {
        'Hardhat': (0, 255, 0),          # Green
        'Mask': (0, 255, 0),             # Green
        'NO-Hardhat': (0, 0, 255),       # Red
        'NO-Mask': (0, 0, 255),          # Red
        'NO-Safety Vest': (0, 0, 255),   # Red
        'Person': (255, 255, 0),         # Cyan
        'Safety Cone': (0, 165, 255),    # Orange (BGR; was RGB-ordered before)
        'Safety Vest': (0, 255, 0),      # Green
        'machinery': (128, 0, 128),      # Purple
        'vehicle': (128, 0, 128),        # Purple
    }

    # Used when the container does not report a usable frame rate (0 / NaN).
    DEFAULT_FPS = 30.0

    def __init__(self, model_path, conf=0.25, iou=0.45):
        """Load the YOLO weights at ``model_path``.

        Args:
            model_path: Path to the saved Ultralytics model weights.
            conf: Minimum confidence passed to ``predict`` (default 0.25).
            iou: NMS IoU threshold passed to ``predict`` (default 0.45).
        """
        self.model = YOLO(model_path)
        self.conf = conf
        self.iou = iou

        # Per-instance copies, so customising one detector's labels or
        # colours does not leak into the shared class-level defaults.
        self.class_names = dict(self.CLASS_NAMES)
        self.colors = dict(self.COLORS)

    @staticmethod
    def _label_bottom(y1, label_height, padding=10):
        """Return the y coordinate of the label background's bottom edge.

        The label normally sits directly above the box, with its bottom edge
        at ``y1``. When that would put its top edge above row 0, it is
        pushed down so the whole label stays inside the frame.
        """
        return max(y1, label_height + padding)

    def draw_violations(self, frame, results):
        """Draw a labelled bounding box for every detection in ``results``.

        Args:
            frame: OpenCV image to draw on; it is modified in place.
            results: Iterable of result objects from ``self.model.predict``.

        Returns:
            The same ``frame`` object, now annotated.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for result in results:
            boxes = result.boxes.cpu().numpy()
            for box in boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                class_id = int(box.cls[0])
                conf = float(box.conf[0])
                # Fall back to the raw id instead of raising KeyError if the
                # model emits a class missing from the table.
                class_name = self.class_names.get(class_id, str(class_id))
                color = self.colors.get(class_name, (255, 255, 255))

                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

                label = f'{class_name} {conf:.2f}'
                (label_width, label_height), _ = cv2.getTextSize(label, font, 0.5, 2)
                label_bottom = self._label_bottom(y1, label_height)
                cv2.rectangle(frame, (x1, label_bottom - label_height - 10),
                              (x1 + label_width, label_bottom), color, -1)
                cv2.putText(frame, label, (x1, label_bottom - 5), font, 0.5,
                            (255, 255, 255), 2)

        return frame

    def process_video(self, video_path, output_path=None):
        """Annotate every frame of ``video_path`` and write the result.

        Args:
            video_path: Path of the input video.
            output_path: Where to write the annotated MP4. When ``None``
                (default), a uniquely named file is created in the system
                temp directory, so concurrent calls do not overwrite each
                other.

        Returns:
            The path of the written video.

        Raises:
            IOError: If the input cannot be opened or the writer cannot be
                created.
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            cap.release()
            raise IOError("Error: Could not open video.")

        out = None
        try:
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # Keep the fractional rate (e.g. 29.97); truncating to int would
            # drift the output's timing. `not fps > 0` also catches NaN.
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                fps = self.DEFAULT_FPS

            if output_path is None:
                fd, output_path = tempfile.mkstemp(prefix='ppe_output_', suffix='.mp4')
                os.close(fd)

            # NOTE(review): 'mp4v' output may not play natively in every
            # browser -- confirm the Gradio player handles it in deployment.
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
            if not out.isOpened():
                raise IOError(f"Error: Could not open video writer for {output_path}.")

            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # verbose=False suppresses Ultralytics' per-frame console log.
                results = self.model.predict(frame, conf=self.conf, iou=self.iou,
                                             verbose=False)
                # `frame` is a fresh buffer from read(), so draw on it directly.
                out.write(self.draw_violations(frame, results))
        finally:
            # Always release the handles, even if predict/write raised.
            cap.release()
            if out is not None:
                out.release()

        return output_path

# Gradio function
MODEL_PATH = 'ppe_detection_model.pt'

# Cache one detector per worker thread. Loading the weights on every request
# is wasteful, and Ultralytics advises against sharing a single YOLO model
# across threads. Gradio may run callbacks on several worker threads.
_thread_state = threading.local()


def _get_detector():
    """Return this thread's PPEDetector, loading the model on first use."""
    detector = getattr(_thread_state, 'detector', None)
    if detector is None:
        detector = PPEDetector(MODEL_PATH)
        _thread_state.detector = detector
    return detector


def detect_ppe(video):
    """Gradio callback that annotates an uploaded video.

    Args:
        video: Value from the ``gr.Video`` input, presumably a file path
            (``None`` when nothing was uploaded).

    Returns:
        Path of the annotated video, for the ``gr.Video`` output.

    Raises:
        gr.Error: If no video was provided, shown to the user in the UI.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    return _get_detector().process_video(video)

# Gradio interface: one video in, one annotated video out.
iface = gr.Interface(
    fn=detect_ppe,
    inputs=gr.Video(label="Upload a Video"),
    outputs=gr.Video(label="Processed Video"),
    title="PPE Detector",
    description="Upload a video to detect PPE violations.",
)

# Launch only when run as a script or notebook (where __name__ is
# "__main__"), not as a side effect of importing this module.
if __name__ == "__main__":
    # Video processing can take a long time. Queue requests explicitly so
    # older Gradio versions do not time them out; newer versions queue by
    # default.
    iface.queue()
    iface.launch()
