In [None]:
# Install the Groq Python client library used by the cells below
!pip install groq



In [None]:
import cv2
import time
import os
import json
import base64
import re
from datetime import datetime
from groq import Groq

class SimpleDrivingAssistant:
    """Sample frames from a driving video, describe each one with a Groq vision
    model, and render the descriptions back onto the video as a text overlay.

    Typical use: call ``process_video()`` to sample frames and query the model,
    then ``create_output_video()`` to write an annotated copy of the input video.
    """

    # Overlay colours in OpenCV's BGR channel order.
    OVERLAY_COLORS = {
        "text_background": (0, 0, 0),
        "text": (255, 255, 255),
        "hazard": (0, 0, 255),     # Red
        "sign": (0, 255, 255),     # Yellow
        "normal": (255, 255, 255)  # White
    }

    def __init__(self, video_path, groq_api_key, output_dir="output", frame_interval=1.0,
                 model="llama-3.2-11b-vision-preview"):
        """
        Initialize the simple driving assistant.

        Args:
            video_path: Path to the input video file
            groq_api_key: Your Groq API key
            output_dir: Directory to save processed frames and descriptions
            frame_interval: Interval in seconds between frames to process (must be > 0)
            model: Groq vision model used for frame analysis. The default keeps the
                original behaviour; override it if the API rejects the model name.

        Raises:
            ValueError: If frame_interval is not a positive number.
        """
        if not frame_interval > 0:
            raise ValueError(f"frame_interval must be positive, got {frame_interval!r}")

        self.video_path = video_path
        self.groq_api_key = groq_api_key
        self.client = Groq(api_key=groq_api_key)
        self.output_dir = output_dir
        self.frame_interval = frame_interval
        self.model = model
        self.analysis_results = []

        # makedirs also creates output_dir as the parent; exist_ok makes re-runs safe.
        self.frames_dir = os.path.join(output_dir, "frames")
        os.makedirs(self.frames_dir, exist_ok=True)

    @staticmethod
    def _frame_schedule(fps, total_frames, frame_interval):
        """Return ``(frame_index, timestamp_seconds)`` pairs sampled every
        ``frame_interval`` seconds, starting at 0 and covering the whole video
        (including a trailing partial interval). Returns ``[]`` when any input is
        non-positive or NaN, so a broken container never causes a division by zero
        or an endless loop.
        """
        if not (fps > 0 and total_frames > 0 and frame_interval > 0):
            return []
        duration = total_frames / fps
        schedule = []
        step = 0
        while True:
            # Multiply instead of accumulating to avoid floating-point drift.
            timestamp = step * frame_interval
            if timestamp >= duration:
                break
            frame_index = min(int(round(timestamp * fps)), total_frames - 1)
            schedule.append((frame_index, timestamp))
            step += 1
        return schedule

    def extract_frames(self):
        """Extract frames from the video every ``self.frame_interval`` seconds.

        Returns:
            List of ``(frame_path, timestamp_seconds)`` tuples for the frames that
            were read and written successfully.

        Raises:
            ValueError: If the video cannot be opened or reports no usable FPS.
        """
        cap = cv2.VideoCapture(self.video_path)

        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {self.video_path}")

        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if not fps > 0:
                raise ValueError(f"Could not determine FPS of video file: {self.video_path}")
            duration = total_frames / fps

            print(f"Video FPS: {fps}")
            print(f"Total frames: {total_frames}")
            print(f"Duration: {duration:.2f} seconds")

            schedule = self._frame_schedule(fps, total_frames, self.frame_interval)

            frame_paths = []
            for i, (frame_idx, timestamp) in enumerate(schedule):
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()

                if not ret:
                    print(f"Failed to read frame at index {frame_idx}")
                    continue

                frame_path = os.path.join(self.frames_dir, f"frame_{i:04d}.jpg")
                if not cv2.imwrite(frame_path, frame):
                    print(f"Failed to write frame to {frame_path}")
                    continue
                frame_paths.append((frame_path, timestamp))

                print(f"Extracted frame {i+1}/{len(schedule)} at time {timestamp:g}s")
        finally:
            cap.release()

        return frame_paths

    @staticmethod
    def _empty_analysis(lane_status="Unknown"):
        """Return a fresh analysis dict with every expected key and no detections."""
        return {
            "road_signs": [],
            "intersections": [],
            "pedestrians": [],
            "vehicles": [],
            "lane_status": lane_status,
            "immediate_hazards": []
        }

    @staticmethod
    def _parse_analysis(response_text):
        """Parse the model's reply into a normalized analysis dict.

        The outermost ``{...}`` span is extracted (the model sometimes adds prose or
        code fences). It is parsed as-is first. Heuristic repairs are only tried
        when that fails, because applying them unconditionally corrupts valid JSON
        (e.g. collapsing ``""`` breaks empty-string values). The result always
        contains every expected key: list fields are lists and ``lane_status`` is a
        string. Extra keys from the model are kept.
        """
        empty = SimpleDrivingAssistant._empty_analysis
        json_start = response_text.find('{')
        json_end = response_text.rfind('}') + 1
        if json_start < 0 or json_end <= json_start:
            print("Error parsing JSON: no JSON object found in response")
            print(f"Received: {response_text}")
            return empty()

        json_str = response_text[json_start:json_end]

        # Increasingly aggressive repair candidates, tried in order.
        repaired = json_str.replace('"["', '["').replace('"]"', '"]')
        repaired = re.sub(r",\s*([}\]])", r"\1", repaired)  # drop trailing commas
        candidates = [json_str, repaired, repaired.replace('""', '"')]

        parsed = None
        last_error = None
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
                break
            except json.JSONDecodeError as e:
                last_error = e
        else:
            print(f"Error parsing JSON: {str(last_error)}")
            print(f"Received: {json_str}")
            return empty()

        if not isinstance(parsed, dict):
            print(f"Error parsing JSON: expected an object, got {type(parsed).__name__}")
            print(f"Received: {json_str}")
            return empty()

        analysis = dict(parsed)
        for key, default in empty().items():
            value = analysis.get(key)
            if value is None:
                analysis[key] = default
            elif key == "lane_status":
                if not isinstance(value, str):
                    analysis[key] = str(value)
            elif not isinstance(value, list):
                # A bare string/dict from the model becomes a one-item list.
                analysis[key] = [value] if value else []
        return analysis

    def get_simple_analysis(self, image_path):
        """
        Get simple analysis from Groq focusing only on key driving elements.

        Args:
            image_path: Path to the image file (sent to the API as a JPEG data URL)

        Returns:
            Analysis dict with keys road_signs, intersections, pedestrians, vehicles,
            lane_status and immediate_hazards. If the request fails, every list is
            empty and lane_status holds the error message. If the reply cannot be
            parsed, lane_status is "Unknown".
        """
        try:
            # Read image file and encode as base64
            with open(image_path, "rb") as f:
                base64_image = base64.b64encode(f.read()).decode("utf-8")

            # Format as data URL
            frame_data = f"data:image/jpeg;base64,{base64_image}"

            # Call Groq API with focused prompt
            chat_completion = self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": """Analyze this driving scene and respond with ONLY a JSON object in this exact format:
                            {
                                "road_signs": ["description of each sign visible"],
                                "intersections": ["description of upcoming intersections"],
                                "pedestrians": ["location of pedestrians", "etc"],
                                "vehicles": ["cars in front", "oncoming traffic", "etc"],
                                "lane_status": "brief description of lane positioning",
                                "immediate_hazards": ["description of any immediate hazards"]
                            }

                            Keep your responses short and focused on what you can actually see.
                            If you don't see any of these elements, use an empty array [].
                            The lane_status should always be a string.
                            DO NOT include any additional text before or after the JSON."""},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": frame_data,
                                },
                            },
                        ],
                    }
                ],
                model=self.model,
            )

            # content can be None (e.g. empty reply); treat that as unparseable text.
            response_text = (chat_completion.choices[0].message.content or "").strip()
        except Exception as e:
            # Best effort: one failed request (rate limit, network) must not abort the run.
            print(f"Error getting analysis: {str(e)}")
            return self._empty_analysis(lane_status=f"Error: {str(e)}")

        return self._parse_analysis(response_text)

    def process_video(self, request_delay=0.5):
        """Extract frames and get an analysis for each of them.

        Results replace any previous contents of ``self.analysis_results``, so
        calling this twice does not duplicate entries. Progress is saved to disk
        every 5 frames and after the last frame.

        Args:
            request_delay: Seconds to sleep between API requests (rate limiting).

        Returns:
            List of dicts with keys frame_index, timestamp (seconds), frame_path
            and analysis.
        """
        print(f"Processing video: {self.video_path}")

        self.analysis_results = []
        frame_paths = self.extract_frames()
        total = len(frame_paths)

        for i, (frame_path, timestamp) in enumerate(frame_paths):
            print(f"Getting analysis for frame {i+1}/{total} at {timestamp:g}s")

            analysis = self.get_simple_analysis(frame_path)

            self.analysis_results.append({
                "frame_index": i,
                "timestamp": timestamp,
                "frame_path": frame_path,
                "analysis": analysis
            })

            is_last = i == total - 1
            if i % 5 == 0 or is_last:
                self.save_analysis()

            # No need to wait after the final request.
            if not is_last and request_delay > 0:
                time.sleep(request_delay)

        if not frame_paths:
            # Still write the (empty) results so the output file reflects this run.
            self.save_analysis()

        return self.analysis_results

    def save_analysis(self):
        """Write ``self.analysis_results`` to ``<output_dir>/driving_analysis.json``."""
        output_file = os.path.join(self.output_dir, "driving_analysis.json")
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(self.analysis_results, f, indent=2)
        print(f"Saved analysis to {output_file}")

    @staticmethod
    def _format_items(value, limit=None, default=""):
        """Render one analysis field as display text.

        For lists, at most ``limit`` entries are kept (all when ``limit`` is None).
        ``None`` entries are skipped, the rest are converted with ``str`` and
        joined with ", ". Any other non-empty value is converted with ``str``.
        Empty values yield ``default``.
        """
        if not value:
            return default
        if isinstance(value, list):
            items = value if limit is None else value[:limit]
            return ", ".join(str(item) for item in items if item is not None)
        return str(value)

    @staticmethod
    def _fit_text(text, max_width, font, scale, thickness):
        """Shorten ``text`` with a trailing "..." so it renders within ``max_width`` px.

        Whitespace runs (including newlines) are collapsed to single spaces first,
        since putText draws a single line.
        """
        text = " ".join(str(text).split())
        if max_width <= 0:
            return ""

        def width_of(candidate):
            return cv2.getTextSize(candidate, font, scale, thickness)[0][0]

        if width_of(text) <= max_width:
            return text

        ellipsis = "..."
        # Binary search for the longest prefix that still fits alongside the ellipsis.
        lo, hi = 0, len(text)
        while lo < hi:
            mid = (lo + hi + 1) // 2
            if width_of(text[:mid] + ellipsis) <= max_width:
                lo = mid
            else:
                hi = mid - 1
        return text[:lo].rstrip() + ellipsis

    def _draw_overlay(self, frame, analysis):
        """Draw the translucent info panel and analysis text onto ``frame`` in place."""
        height, width = frame.shape[:2]
        colors = self.OVERLAY_COLORS
        font = cv2.FONT_HERSHEY_SIMPLEX
        scale = 0.6

        # Blend a dark panel over the bottom 130 px so the text stays readable.
        panel = frame.copy()
        cv2.rectangle(panel, (0, height - 130), (width, height), colors["text_background"], -1)
        alpha = 0.7  # Transparency factor
        cv2.addWeighted(panel, alpha, frame, 1 - alpha, 0, frame)

        margin = 20
        right_x = width // 2
        # Column widths used to clip text so the left column cannot run into the
        # right one and nothing runs off the edge of the frame.
        full_width = width - 2 * margin
        left_width = right_x - 2 * margin
        right_width = width - right_x - margin

        def put(text, x, y, max_width, color, thickness=1):
            fitted = self._fit_text(text, max_width, font, scale, thickness)
            cv2.putText(frame, fitted, (x, y), font, scale, color, thickness)

        put(f"Lane: {analysis.get('lane_status', 'Unknown')}",
            margin, height - 100, full_width, colors["normal"])

        signs = self._format_items(analysis.get("road_signs"), limit=2)
        if signs:
            put("Signs: " + signs, margin, height - 75, left_width, colors["sign"])

        intersections = self._format_items(analysis.get("intersections"), limit=1)
        if intersections:
            put("Intersection: " + intersections, margin, height - 50, left_width, colors["sign"])

        vehicles = self._format_items(analysis.get("vehicles"), limit=2, default="None detected")
        put("Vehicles: " + vehicles, right_x, height - 75, right_width, colors["normal"])

        pedestrians = self._format_items(analysis.get("pedestrians"), limit=2, default="None detected")
        put("Pedestrians: " + pedestrians, right_x, height - 50, right_width, colors["normal"])

        hazards = self._format_items(analysis.get("immediate_hazards"))
        if hazards:
            put("CAUTION: " + hazards, margin, height - 25, full_width, colors["hazard"], thickness=2)

    def create_output_video(self, output_video_path=None):
        """
        Create a new video with simple driving information overlay.

        Each frame shows the most recent analysis whose timestamp has been reached.
        Frames before the first analysis are written unchanged.

        Args:
            output_video_path: Path for the output video file. Defaults to a
                timestamped .mp4 inside output_dir.

        Returns:
            The path of the written video.

        Raises:
            ValueError: If the input cannot be opened, reports no usable FPS, or
                the output writer cannot be opened.
        """
        if not output_video_path:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_video_path = os.path.join(self.output_dir, f"simple_driving_assistant_{timestamp}.mp4")

        cap = cv2.VideoCapture(self.video_path)

        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {self.video_path}")

        out = None
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                raise ValueError(f"Could not determine FPS of video file: {self.video_path}")
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
            if not out.isOpened():
                raise ValueError(f"Could not open video writer for: {output_video_path}")

            results = sorted(self.analysis_results, key=lambda r: r["timestamp"])
            current_analysis = None
            next_index = 0
            frame_count = 0

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                timestamp = frame_count / fps

                # Advance to the latest analysis reached so far (a while loop, so
                # several analyses falling within one frame don't cause lag).
                while next_index < len(results) and timestamp >= results[next_index]["timestamp"]:
                    current_analysis = results[next_index]["analysis"]
                    next_index += 1

                if current_analysis:
                    self._draw_overlay(frame, current_analysis)
                out.write(frame)

                frame_count += 1

                if frame_count % 100 == 0:
                    if total_frames > 0:
                        print(f"Processed {frame_count}/{total_frames} frames ({frame_count/total_frames*100:.1f}%)")
                    else:
                        print(f"Processed {frame_count} frames")
        finally:
            cap.release()
            if out is not None:
                out.release()

        print(f"Output video saved to: {output_video_path}")
        return output_video_path

In [None]:
# Usage example
import getpass
import os

if __name__ == "__main__":
    # Replace with your actual video path
    video_path = "/content/output (13).mp4"

    # Never hardcode the API key in the notebook: saved or shared copies would leak it.
    # Read it from the GROQ_API_KEY environment variable, or prompt for it interactively.
    groq_api_key = os.environ.get("GROQ_API_KEY") or getpass.getpass("Groq API key: ")

    # Create the assistant
    assistant = SimpleDrivingAssistant(
        video_path=video_path,
        groq_api_key=groq_api_key,
        output_dir="driving_assistant_output"
    )

    # Process the video: sample frames and query the model for each one
    assistant.process_video()

    # Create output video with information overlay
    assistant.create_output_video()

Processing video: /content/output (13).mp4
Video FPS: 30.0
Total frames: 576
Duration: 19.20 seconds
Extracted frame 1/19 at time 0s
Extracted frame 2/19 at time 1s
Extracted frame 3/19 at time 2s
Extracted frame 4/19 at time 3s
Extracted frame 5/19 at time 4s
Extracted frame 6/19 at time 5s
Extracted frame 7/19 at time 6s
Extracted frame 8/19 at time 7s
Extracted frame 9/19 at time 8s
Extracted frame 10/19 at time 9s
Extracted frame 11/19 at time 10s
Extracted frame 12/19 at time 11s
Extracted frame 13/19 at time 12s
Extracted frame 14/19 at time 13s
Extracted frame 15/19 at time 14s
Extracted frame 16/19 at time 15s
Extracted frame 17/19 at time 16s
Extracted frame 18/19 at time 17s
Extracted frame 19/19 at time 18s
Getting analysis for frame 1/19 at 0s
Saved analysis to driving_assistant_output/driving_analysis.json
Getting analysis for frame 2/19 at 1s
Getting analysis for frame 3/19 at 2s
Getting analysis for frame 4/19 at 3s
Getting analysis for frame 5/19 at 4s
Error parsing JSO