conda create -n drowning_detection python=3.10
conda activate drowning_detection
pip install -r drowning_detection.txt
pip install opencv-python tensorflow


In [2]:
import cv2
import cvlib as cv
from cvlib.object_detection import draw_bbox
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import joblib
from PIL import Image
import time
import albumentations
from IPython.display import display  # For Jupyter
import matplotlib.pyplot as plt

In [3]:
# Define model architecture
class CustomCNN(nn.Module):
    def __init__(self, num_classes):
        super(CustomCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.conv3 = nn.Conv2d(32, 64, 3)
        self.conv4 = nn.Conv2d(64, 128, 5)
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.pool = nn.MaxPool2d(2, 2)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        bs, _, _, _ = x.shape
        x = F.adaptive_avg_pool2d(x, 1).reshape(bs, -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [4]:
# Cell 3: Load models and setup preprocessing
# Update paths to your actual model files
lb = joblib.load('lb.pkl')
model = CustomCNN(num_classes=len(lb.classes_))
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
model.eval()

aug = albumentations.Compose([
    albumentations.Resize(224, 224),
])

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [7]:
# Cell 4: Detection function (modified for Jupyter)
def detect_drowning(source, max_frames=100):
    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        print("Error opening video source")
        return
    
    plt.figure(figsize=(12, 8))
    frame_count = 0
    isDrowning = False
    
    try:
        while frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
                
            # Object detection
            bbox, labels, conf = cv.detect_common_objects(frame)
            
            if len(bbox) == 1:  # Single person detection
                # Prepare image for model
                img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img_rgb = img_rgb.astype(np.uint8)  # Ensure correct type
                aug_img = aug(image=img_rgb)['image']
                tensor_img = torch.tensor(
                    np.transpose(aug_img, (2, 0, 1)).astype(np.float32))
                tensor_img = tensor_img.unsqueeze(0)
                
                # Model prediction
                with torch.no_grad():
                    outputs = model(tensor_img)
                    _, pred = torch.max(outputs, 1)
                
                status = lb.classes_[pred.item()]
                isDrowning = (status == 'drowning')
                frame = draw_bbox(frame, bbox, labels, conf, isDrowning)
                
            elif len(bbox) > 1:  # Multi-person logic
                centers = []
                for box in bbox:
                    cx = (box[0] + box[2]) / 2
                    cy = (box[1] + box[3]) / 2
                    centers.append((cx, cy))
                
                min_distance = float('inf')
                for i in range(len(centers)):
                    for j in range(i+1, len(centers)):
                        dist = np.sqrt((centers[i][0]-centers[j][0])**2 + 
                                      (centers[i][1]-centers[j][1])**2)
                        min_distance = min(min_distance, dist)
                
                isDrowning = (min_distance < 50) if centers else False
                frame = draw_bbox(frame, bbox, labels, conf, isDrowning)
            
            # Display in notebook
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            plt.imshow(frame_rgb)
            plt.title(f"Frame: {frame_count} | Status: {'DROWNING!' if isDrowning else 'Normal'}")
            plt.axis('off')
            display(plt.gcf())
            clear_output(wait=True)
            
            frame_count += 1
            time.sleep(0.03)  # Control frame rate
            
    finally:
        cap.release()
        plt.close()

In [None]:
# Cell 5: Run detection (choose source)
# For webcam: source = 0
# For video file: source = 'path/to/video.mp4'

# Download YOLOv3 weights and config if not present
import os
import urllib.request

yolo_dir = "yolo"
os.makedirs(yolo_dir, exist_ok=True)

cfg_path = os.path.join(yolo_dir, "yolov3.cfg")
weights_path = os.path.join(yolo_dir, "yolov3.weights")
names_path = os.path.join(yolo_dir, "coco.names")

if not os.path.exists(cfg_path):
	urllib.request.urlretrieve(
		"https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg", cfg_path)
if not os.path.exists(weights_path):
	urllib.request.urlretrieve(
		"https://pjreddie.com/media/files/yolov3.weights", weights_path)
if not os.path.exists(names_path):
	urllib.request.urlretrieve(
		"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names", names_path)

# Patch cv.detect_common_objects to use local YOLO files
import cvlib as cv
def detect_common_objects_custom(img, confidence=0.5, model='yolov3', enable_gpu=False):
	return cv.detect_common_objects(
		img, confidence=confidence, model=model, enable_gpu=enable_gpu,
		config=cfg_path, weights=weights_path, classes=names_path
	)
cv.detect_common_objects = detect_common_objects_custom

detect_drowning(source='videos/test/drowning__006_1.mp4', max_frames=100)  # Runs 100 frames from webcam

Downloading yolov4.weights from https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v3_optimal/yolov4.weights


 99% |####################################################################### |

Downloading yolov3_classes.txt from https://github.com/arunponnusamy/object-detection-opencv/raw/master/yolov3.txt


100% |                                                                        |

error: OpenCV(4.11.0) D:\a\opencv-python\opencv-python\opencv\modules\dnn\src\darknet\darknet_io.cpp:705: error: (-215:Assertion failed) separator_index < line.size() in function 'cv::dnn::darknet::ReadDarknetFromCfgStream'
