In [12]:
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Load the saved Keras model once at module level; analyze_video_for_offences
# reads `model` as a global when it calls model.predict().
model = tf.keras.models.load_model("custom_cnn_model.keras")
# Labels are looked up by the argmax position of the model output
# (`target_classes[top_idx]` in analyze_video_for_offences), so the order here
# must match the model's output order — presumably the training label order;
# TODO confirm against the training code.
target_classes = ['safe_driving', 'using_phone', 'drinking']

def _merge_timeline_segments(timeline):
    """Collapse runs of consecutive same-class timeline entries into index ranges.

    Args:
        timeline: chronologically ordered list of
            (timestamp_sec, predicted_class, confidence) tuples.

    Returns:
        List of (first_index, last_index, predicted_class) tuples, where both
        indices are inclusive positions in ``timeline``.
    """
    segments = []
    if not timeline:
        return segments

    first = 0
    for i in range(1, len(timeline)):
        if timeline[i][1] != timeline[first][1]:
            segments.append((first, i - 1, timeline[first][1]))
            first = i
    # Close the run that is still open when the timeline ends.
    segments.append((first, len(timeline) - 1, timeline[first][1]))
    return segments


def _read_rgb_frame_at(cap, timestamp_sec):
    """Seek ``cap`` to ``timestamp_sec`` and return that frame as RGB, or None if the read fails."""
    cap.set(cv2.CAP_PROP_POS_MSEC, timestamp_sec * 1000)
    ret, frame = cap.read()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def analyze_video_for_offences(video_path, min_offence_duration=10, fps_extract=5, conf_threshold=0.9, num_end_frames=3):
    """Classify sampled video frames and report sustained offence segments.

    The video is sampled at roughly ``fps_extract`` frames per second, each
    sampled frame is classified with the module-level ``model``, consecutive
    samples with the same label are merged into segments, and every
    'using_phone' / 'drinking' segment lasting at least
    ``min_offence_duration`` seconds is reported. For each reported segment a
    combined JPEG (highest-confidence frame plus the last ``num_end_frames``
    sampled frames) is written to the current working directory.

    Args:
        video_path: path of the video file to analyse.
        min_offence_duration: minimum segment length in seconds, measured
            between the first and last sampled timestamp of the segment.
        fps_extract: target sampling rate in frames per second; must be > 0.
        conf_threshold: predictions whose top probability is below this value
            are labelled "UNKNOWN".
        num_end_frames: number of trailing sampled frames of each segment to
            include in the snapshot; 0 or less includes none.

    Returns:
        ``{"offences": [...]}`` where each entry has the keys "class", "start",
        "end", "duration", "best_confidence", "combined_snapshot" (file name,
        or None when no frame could be read back) and "frame_count".
        Returns ``{}`` if the video cannot be opened or reports no usable
        frame rate.

    Raises:
        ValueError: if ``fps_extract`` is not positive.
    """
    if fps_extract <= 0:
        raise ValueError("fps_extract must be positive")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("❌ Could not open video!")
        return {}

    timeline = []  # (timestamp, predicted_class, confidence)

    # Pass 1: build timeline of predictions
    try:
        real_fps = cap.get(cv2.CAP_PROP_FPS)
        # `not x > 0` also rejects NaN, which some containers report.
        if not real_fps > 0:
            print("❌ Could not determine video frame rate!")
            return {}

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_duration = total_frames / real_fps

        # Clamp to 1 so videos slower than fps_extract sample every frame
        # instead of dividing by zero in the modulo below.
        frame_interval = max(1, int(real_fps / fps_extract))
        frame_idx = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_idx % frame_interval == 0:
                timestamp_sec = frame_idx / real_fps

                # OpenCV decodes to BGR; convert to RGB, resize to 224x224,
                # scale to [0, 1] and add a batch axis before predicting.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img = cv2.resize(frame_rgb, (224, 224)).astype("float32") / 255.0
                img = np.expand_dims(img, axis=0)

                preds = model.predict(img, verbose=0)[0]
                top_idx = np.argmax(preds)
                top_conf = preds[top_idx]

                if top_conf < conf_threshold:
                    pred_class = "UNKNOWN"
                else:
                    pred_class = target_classes[top_idx]

                timeline.append((timestamp_sec, pred_class, top_conf))

            frame_idx += 1
    finally:
        # Release the capture even if decoding or prediction raises.
        cap.release()

    print(f"✅ Processed {len(timeline)} sampled frames over {video_duration:.1f}s")

    # Identify offence segments lasting >= min_offence_duration
    offences = []
    for first, last, cls in _merge_timeline_segments(timeline):
        start = timeline[first][0]
        end = timeline[last][0]
        if cls not in ('using_phone', 'drinking') or end - start < min_offence_duration:
            continue

        # A segment is a contiguous, already chronological slice of the timeline.
        segment_entries = timeline[first:last + 1]

        # Pick frame with max confidence (first one wins on ties)
        best_ts, _, best_conf = max(segment_entries, key=lambda entry: entry[2])

        # Last N frames for context. Guard N <= 0 explicitly: `[-0:]` would
        # select the whole segment rather than nothing.
        tail_entries = segment_entries[-num_end_frames:] if num_end_frames > 0 else []
        end_frame_times = [ts for ts, _, _ in tail_entries]

        offences.append((start, end, cls, best_ts, best_conf, end_frame_times))

    # Pass 2: extract snapshots for each offence + combine into subplots
    results = {"offences": []}

    if offences:
        cap = cv2.VideoCapture(video_path)
        try:
            for (start, end, cls, best_ts, best_conf, end_frame_times) in offences:
                labelled_frames = []  # (rgb_image, subplot_title)

                # Fetch BEST confidence frame
                best_rgb = _read_rgb_frame_at(cap, best_ts)
                if best_rgb is not None:
                    labelled_frames.append((best_rgb, f"BEST ({best_conf:.2f})"))

                # Fetch last N frames from segment end
                for ts in end_frame_times:
                    end_rgb = _read_rgb_frame_at(cap, ts)
                    if end_rgb is not None:
                        labelled_frames.append((end_rgb, f"End @ {int(ts)}s"))

                n_frames = len(labelled_frames)
                combined_name = None

                # plt.subplots rejects zero columns, so only plot when at
                # least one frame could be read back.
                if n_frames:
                    combined_name = f"offence_{cls}_{int(start)}s_to_{int(end)}s_combined.jpg"
                    # squeeze=False always yields a 2-D axes array, so a
                    # single frame needs no special case.
                    fig, axes = plt.subplots(1, n_frames, figsize=(4*n_frames, 4), squeeze=False)
                    try:
                        for ax, (image, label) in zip(axes[0], labelled_frames):
                            ax.imshow(image)
                            ax.set_title(label)
                            ax.axis("off")

                        fig.suptitle(f"{cls.upper()} from {start:.1f}s to {end:.1f}s (duration {end-start:.1f}s)")
                        fig.tight_layout()
                        fig.savefig(combined_name)
                    finally:
                        # Always free the figure so repeated offences do not
                        # accumulate open figures in the kernel.
                        plt.close(fig)

                results["offences"].append({
                    "class": cls,
                    "start": start,
                    "end": end,
                    "duration": end - start,
                    "best_confidence": best_conf,
                    "combined_snapshot": combined_name,
                    "frame_count": n_frames
                })
        finally:
            cap.release()

    # Print final offence report
    print("\n🚨 Offence Report:")
    for offence in results["offences"]:
        print(f" - {offence['class'].upper()} [{offence['best_confidence']*100:.1f}%] "
              f"from {offence['start']:.1f}s to {offence['end']:.1f}s "
              f"→ Combined snapshot: {offence['combined_snapshot']}")

    return results


In [15]:

# Run the offence analysis on the sample clip. min_offence_duration is lowered
# from the function default (10) to 3, so segments of at least 3 s are reported;
# the remaining arguments restate the function defaults explicitly.
results = analyze_video_for_offences(
    video_path="../examples/input/bus_driver.mp4",
    min_offence_duration=3,
    fps_extract=5,
    conf_threshold=0.9,
    num_end_frames=3  # save 3 end-context frames
)


✅ Processed 311 sampled frames over 62.2s

🚨 Offence Report:
 - USING_PHONE [99.9%] from 0.2s to 5.4s → Combined snapshot: offence_using_phone_0s_to_5s_combined.jpg
 - USING_PHONE [99.7%] from 9.2s to 12.2s → Combined snapshot: offence_using_phone_9s_to_12s_combined.jpg
 - USING_PHONE [99.8%] from 15.8s to 20.6s → Combined snapshot: offence_using_phone_15s_to_20s_combined.jpg
 - DRINKING [100.0%] from 56.4s to 59.4s → Combined snapshot: offence_drinking_56s_to_59s_combined.jpg
