<a href="https://colab.research.google.com/github/Soutrik05/survival_predictors/blob/main/Untitled8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install ultralytics streamlit opencv-python-headless
!pip install gradio
!pip install supervision

Collecting supervision
  Downloading supervision-0.25.1-py3-none-any.whl.metadata (14 kB)
Downloading supervision-0.25.1-py3-none-any.whl (181 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m181.5/181.5 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: supervision
Successfully installed supervision-0.25.1


In [7]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import gradio as gr
from ultralytics import YOLO
from collections import deque

class AmbulanceCNN(nn.Module):
    """Compact convolutional classifier that emits two logits per image.

    The network stacks three "3x3 conv (padding 1) -> ReLU -> 2x2 max-pool"
    stages (3 -> 32 -> 64 -> 128 channels), followed by a 512-unit hidden
    layer with dropout and a 2-way output layer. The flatten step is sized
    for a 128 x 8 x 8 feature map, which is what a 3 x 64 x 64 input yields
    after the three pooling steps.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: each conv keeps the spatial size, the pool halves it.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head.
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 2)  # index 1 is read by callers as "ambulance"
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return raw (un-normalised) class scores of shape (rows, 2)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Flatten to the width fc1 expects (128 * 8 * 8); the row count is inferred.
        flat = x.view(-1, self.fc1.in_features)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

class LightPatternDetector:
    """Heuristic detector for flashing red/blue emergency lights.

    Every call to :meth:`detect_flashing` measures the fraction of a region
    that is saturated red or blue and pushes it into a fixed-length history.
    Once the history is full, the lights are reported as flashing when that
    fraction is both high enough on average and varies enough over time.

    Args:
        buffer_size: Number of most recent measurements kept in the history.
        variance_threshold: The history's variance must exceed this value.
        mean_threshold: The history's mean must exceed this value.
    """

    def __init__(self, buffer_size=15, variance_threshold=0.001,
                 mean_threshold=0.05):
        self.light_buffer = deque(maxlen=buffer_size)
        self.buffer_size = buffer_size
        self.variance_threshold = variance_threshold
        self.mean_threshold = mean_threshold

    def detect_flashing(self, frame_roi):
        """Record the light coverage of ``frame_roi`` and test for flashing.

        Args:
            frame_roi: BGR image region (it is converted with
                ``cv2.COLOR_BGR2HSV``).

        Returns:
            bool: ``True`` when the history is full and both its variance
            and mean exceed their thresholds; ``False`` otherwise, including
            for an empty region (which is not recorded in the history).
        """
        # An empty region has zero area: nothing to measure, and the area
        # division below would be undefined.
        if frame_roi is None or frame_roi.size == 0:
            return False

        hsv = cv2.cvtColor(frame_roi, cv2.COLOR_BGR2HSV)

        # Saturated, bright blue.
        blue_mask = cv2.inRange(hsv, np.array([100, 100, 100]),
                                np.array([140, 255, 255]))

        # Red straddles the hue wrap-around (OpenCV 8-bit hue spans 0..179),
        # so it needs a band at each end of the hue axis.
        red_mask = cv2.bitwise_or(
            cv2.inRange(hsv, np.array([0, 100, 100]),
                        np.array([10, 255, 255])),
            cv2.inRange(hsv, np.array([170, 100, 100]),
                        np.array([180, 255, 255])),
        )

        # Morphological opening removes isolated noisy pixels.
        kernel = np.ones((5, 5), np.uint8)
        blue_mask = cv2.morphologyEx(blue_mask, cv2.MORPH_OPEN, kernel)
        red_mask = cv2.morphologyEx(red_mask, cv2.MORPH_OPEN, kernel)

        # Fraction of the region covered by either colour.
        lit_pixels = np.count_nonzero(blue_mask) + np.count_nonzero(red_mask)
        area = frame_roi.shape[0] * frame_roi.shape[1]
        self.light_buffer.append(lit_pixels / area)

        # No verdict until a full window of measurements has been collected.
        if len(self.light_buffer) != self.buffer_size:
            return False

        history = np.fromiter(self.light_buffer, dtype=float)
        return bool(history.var() > self.variance_threshold
                    and history.mean() > self.mean_threshold)

def preprocess_for_cnn(roi):
    """Turn a BGR image region into a normalised single-image batch tensor.

    The region is converted from BGR to RGB, resized to 64 x 64, converted to
    a tensor, normalised per channel and given a leading batch dimension.
    """
    rgb_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        # Per-channel statistics (these are the widely used ImageNet values).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(rgb_roi).unsqueeze(0)

def detect_text_opencv(roi):
    """Guess whether a BGR region contains text, using contour counting.

    This is a shape heuristic, not OCR: the region is binarised at a fixed
    threshold, its external contours are extracted, and those whose bounding
    box has a width/height ratio strictly between 0.1 and 10 are counted.

    Returns:
        bool: ``True`` when more than three such contours are found.
    """
    grayscale = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(grayscale, 127, 255, cv2.THRESH_BINARY)
    outlines, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)

    def _has_text_like_shape(outline):
        # Only the bounding-box proportions matter, not its position.
        _, _, width, height = cv2.boundingRect(outline)
        return 0.1 < width / float(height) < 10

    candidates = sum(1 for outline in outlines if _has_text_like_shape(outline))
    return candidates > 3

def load_models():
    """Create the YOLO detector and the CNN classifier used by the pipeline.

    Returns:
        tuple: ``(yolo_model, cnn_model)`` with the CNN in evaluation mode.
    """
    detector = YOLO('yolov8n.pt')
    # NOTE: no trained weights are loaded here, so the classifier runs with
    # its freshly initialised parameters. A real deployment would restore
    # them first, e.g. load_state_dict(torch.load('ambulance_cnn_weights.pth')).
    classifier = AmbulanceCNN().eval()  # eval() returns the module itself
    return detector, classifier

def detect_ambulance(frame, yolo_model, cnn_model, light_detector):
    """Annotate ambulances in a frame and derive the traffic-signal state.

    Every box returned by ``yolo_model`` is cropped from ``frame`` and
    checked. A box is labelled as an ambulance when flashing lights are
    detected in it AND at least one of these holds: more than 30% of its
    pixels are white, ``detect_text_opencv`` reports text-like contours, or
    the CNN assigns class 1 a probability above 0.5.

    Args:
        frame: BGR image; it is annotated in place.
        yolo_model: Callable returning results whose ``boxes`` items expose
            ``xyxy``.
        cnn_model: Two-class classifier; index 1 is read as "ambulance".
        light_detector: Object with a ``detect_flashing(roi)`` method.
            NOTE: the same detector receives every box of every frame, so
            its history mixes measurements from different objects.

    Returns:
        tuple: ``(frame, signal_state)`` where ``signal_state`` is
        ``"GREEN"`` if any box was labelled as an ambulance, else ``"RED"``.
    """
    # HSV bounds for "white": any hue, low saturation, high value.
    white_lower = np.array([0, 0, 200])
    white_upper = np.array([180, 30, 255])

    results = yolo_model(frame)
    frame_h, frame_w = frame.shape[:2]

    ambulance_detected = False
    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            # Clamp to the frame so a negative coordinate cannot wrap around
            # when used as a slice index.
            x1, x2 = (min(max(v, 0), frame_w) for v in (x1, x2))
            y1, y2 = (min(max(v, 0), frame_h) for v in (y1, y2))
            roi = frame[y1:y2, x1:x2]
            if roi.size == 0:
                continue

            # Flashing lights are mandatory for every rule below. The
            # detector is still called for every box so that its history is
            # updated; the costlier checks are skipped when it says no.
            if not light_detector.detect_flashing(roi):
                continue

            # Fraction of white pixels: count mask hits (the mask holds
            # 0/255, one value per pixel) and divide by the pixel count.
            white_mask = cv2.inRange(cv2.cvtColor(roi, cv2.COLOR_BGR2HSV),
                                     white_lower, white_upper)
            white_ratio = np.count_nonzero(white_mask) / white_mask.size

            is_ambulance = white_ratio > 0.3 or detect_text_opencv(roi)
            if not is_ambulance:
                # Two mutually exclusive classes: softmax turns the logits
                # into probabilities that sum to one.
                with torch.no_grad():
                    logits = cnn_model(preprocess_for_cnn(roi))
                    probabilities = F.softmax(logits, dim=1)
                is_ambulance = bool(probabilities[0, 1] > 0.5)

            if not is_ambulance:
                continue

            ambulance_detected = True
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, 'Ambulance', (x1, y1-10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
            cv2.putText(frame, 'Emergency Lights Active', (x1, y2+20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    signal_state = "GREEN" if ambulance_detected else "RED"
    return frame, signal_state

def process_video(video_path):
    """Run ambulance detection over a video and report the final state.

    Frames are processed one at a time and only the most recent result is
    kept, so memory use does not grow with the length of the video.

    Args:
        video_path: Path of the video file to analyse.

    Returns:
        tuple: ``(image, label)``. ``image`` is the last processed frame
        converted to RGB for display, or ``None`` when no frame was read or
        an error occurred. ``label`` is the signal state of the last frame
        (``"RED"`` when no frame was read) or an ``"Error: ..."`` message.
    """
    # Nothing was supplied (e.g. the button was pressed without an upload).
    if not video_path:
        return None, "Error: no video provided"

    try:
        yolo_model, cnn_model = load_models()
        light_detector = LightPatternDetector()
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return None, f"Error: could not open video: {video_path}"

            last_frame = None
            signal_state = "RED"
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                last_frame, signal_state = detect_ambulance(
                    frame, yolo_model, cnn_model, light_detector)
        finally:
            # Release the capture even if detection raised part-way through.
            cap.release()

        if last_frame is None:
            return None, signal_state

        # OpenCV frames are BGR; the image widget displays RGB arrays.
        return cv2.cvtColor(last_frame, cv2.COLOR_BGR2RGB), signal_state
    except Exception as e:
        # UI boundary: surface the failure as the label text instead of raising.
        return None, f"Error: {str(e)}"

def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching it.

    Layout: a title and a summary of the detection methods, then a row with
    the video upload on the left and the processed frame plus signal label
    on the right, and a button that runs ``process_video`` on the upload.
    """
    methods_markdown = """
        ## Detection Methods:
        1. CNN-based vehicle classification
        2. OpenCV color and pattern analysis
        3. YOLO object detection
        4. Emergency light pattern detection
        5. Text detection for 'AMBULANCE' marking
        """

    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown("# Advanced Ambulance Detection System")
        gr.Markdown(methods_markdown)

        with gr.Row():
            with gr.Column():
                uploaded_video = gr.Video(label="Upload Video")
            with gr.Column():
                frame_view = gr.Image(label="Processed Frame")
                signal_view = gr.Label(label="Traffic Signal State")

        # Wire the button: one input (the video), two outputs (frame, label).
        run_button = gr.Button("Process Video")
        handler_inputs = [uploaded_video]
        handler_outputs = [frame_view, signal_view]
        run_button.click(
            fn=process_video,
            inputs=handler_inputs,
            outputs=handler_outputs,
        )

    return app

# Entry point: build the Gradio UI and start serving it.
if __name__ == "__main__":
    interface = create_interface()
    # launch() starts the web server for the interface built above.
    interface.launch()

0: 384x640 16 cars, 1 bus, 492.5ms
Speed: 3.6ms preprocess, 492.5ms inference, 1.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 15 cars, 2 buss, 958.1ms
Speed: 6.8ms preprocess, 958.1ms inference, 1.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 16 cars, 2 buss, 606.3ms
Speed: 24.0ms preprocess, 606.3ms inference, 2.3ms postprocess per image at shape (1, 3, 384, 640)

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
0: 384x640 1 person, 17 cars, 1 bus, 844.2ms
Speed: 28.6ms preprocess, 844.2ms inference, 1.7ms postprocess per image at shape (1, 3, 384, 640)

* Running on public URL: https://ab71d20903e0776b0e.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the 