In [None]:
import torch
from transformers import pipeline, Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
from scipy.spatial.distance import euclidean
from collections import deque
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from matplotlib.colors import to_rgb
import os
import anthropic
import networkx as nx
import json
import cv2
import time
from google.colab.patches import cv2_imshow
from IPython.display import clear_output
import sqlite3
from datetime import datetime

# Runtime configuration is read from environment variables, so the API key and
# machine-specific paths stay out of the notebook. os.getenv returns None for a
# variable that is not set.
# TEST_VIDEO_PATH: path of the video to analyse (e.g. a .mov file in your Google Drive).
TEST_VIDEO_PATH = os.getenv("TEST_VIDEO_PATH")
# ANTHROPIC_API_KEY: key handed to SceneReasoningSystem in the usage example.
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")

class SceneReasoningSystem:
    def __init__(self, api_key, history_length=5, db_path='scene_reasoning.db'):
        """Load the vision models, create the Claude client and prepare the log DB.

        Args:
            api_key: Anthropic API key passed to ``anthropic.Anthropic``.
            history_length: Maximum number of analysed frames kept in
                ``self.history``; the deque drops the oldest entry beyond that.
            db_path: Path of the SQLite file that receives the reasoning logs.
        """
        gpu_available = torch.cuda.is_available()
        self.device = torch.device('cuda' if gpu_available else 'cpu')
        print(f"Using device: {self.device}")

        # Detector (processor + model) first, then the crop captioner.
        processor, model = self.load_object_detector()
        self.processor = processor
        self.model = model
        self.captioner = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-base",
            device=self.device,
        )

        # Rolling window of per-frame analysis results.
        self.history = deque(maxlen=history_length)
        self.claude_client = anthropic.Anthropic(api_key=api_key)

        self.db_path = db_path
        self.init_database()

    def init_database(self):
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute('''
        CREATE TABLE IF NOT EXISTS reasoning_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            frame_number INTEGER,
            scene_description TEXT,
            inference TEXT
        )
        ''')
        conn.commit()
        conn.close()

    def log_reasoning(self, frame_number, scene_description, inference):
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        try:
            # Print out types and first 100 characters of each parameter
            print(f"frame_number type: {type(frame_number)}, value: {frame_number}")
            print(f"scene_description type: {type(scene_description)}, first 100 chars: {scene_description[:100]}")
            print(f"inference type: {type(inference)}, first 100 chars: {inference[:100]}")
            
            # Ensure all parameters are strings
            frame_number = str(frame_number)
            scene_description = str(scene_description)
            inference = str(inference)
            
            cursor.execute('''
            INSERT INTO reasoning_logs (timestamp, frame_number, scene_description, inference)
            VALUES (?, ?, ?, ?)
            ''', (
                datetime.now().isoformat(),
                frame_number,
                scene_description,
                inference
            ))
            conn.commit()
        except sqlite3.Error as e:
            print(f"An error occurred: {e}")
        finally:
            conn.close()

    def load_object_detector(self):
        """Load the OWLv2 processor and detection model.

        Returns:
            A ``(processor, model)`` tuple; the model has been moved to
            ``self.device``.
        """
        checkpoint = "google/owlv2-base-patch16-ensemble"
        owl_processor = Owlv2Processor.from_pretrained(checkpoint)
        owl_model = Owlv2ForObjectDetection.from_pretrained(checkpoint)
        owl_model = owl_model.to(self.device)
        return owl_processor, owl_model

    def detect_objects(self, image, texts=[['a man']]):
        inputs = self.processor(text=texts, images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        target_sizes = torch.tensor([image.size[::-1]]).to(self.device)
        results = self.processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)

        objects = []
        for score, label, box in zip(results[0]['scores'], results[0]['labels'], results[0]['boxes']):
            box = [round(i, 2) for i in box.tolist()]
            objects.append({
                'label': texts[0][label],
                'score': score.item(),
                'box': box
            })
        return objects

    def estimate_depth(self, box, image_size):
        width, height = image_size
        box_width = box[2] - box[0]
        box_height = box[3] - box[1]
        box_area = box_width * box_height
        image_area = width * height
        
        size_factor = box_area / image_area
        position_factor = box[3] / height
        
        depth = 1 - (0.1 * size_factor + 0.9 * position_factor)
        return depth

    def calculate_spatial_relations(self, objects, image_size):
        relations = []
        depths = [self.estimate_depth(obj['box'], image_size) for obj in objects]
        
        for i, (obj1, depth1) in enumerate(zip(objects, depths)):
            for j, (obj2, depth2) in enumerate(zip(objects[i+1:], depths[i+1:]), start=i+1):
                distance = euclidean(self.get_center(obj1['box']), self.get_center(obj2['box']))
                relations.append({
                    'obj1': obj1['label'],
                    'obj2': obj2['label'],
                    'distance': distance,
                    'relative_position': self.get_relative_position(obj1['box'], obj2['box']),
                    'depth_difference': depth1 - depth2
                })
        return relations, depths

    def get_center(self, box):
        return ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)

    def get_relative_position(self, box1, box2):
        center1 = self.get_center(box1)
        center2 = self.get_center(box2)
        dx = center2[0] - center1[0]
        dy = center2[1] - center1[1]
        if abs(dx) > abs(dy):
            return "right" if dx > 0 else "left"
        else:
            return "below" if dy > 0 else "above"

    def crop_and_caption_objects(self, image, objects, depths):
        cropped_images = []
        for obj, depth in zip(objects, depths):
            if depth < 0.4:
                spatial_tag = "Foreground"
            elif depth > 0.7:
                spatial_tag = "Background"
            else:
                spatial_tag = "Midground"
            
            cropped_img = image.crop(obj['box'])
            caption = self.captioner(cropped_img)[0]['generated_text']
            cropped_images.append({
                'label': obj['label'],
                'spatial_tag': spatial_tag,
                'caption': caption
            })
        return cropped_images

    def analyze_frame(self, frame, frame_number):
        objects = self.detect_objects(frame)
        spatial_relations, depths = self.calculate_spatial_relations(objects, frame.size)
        cropped_images = self.crop_and_caption_objects(frame, objects, depths)
        scene_graph = self.generate_scene_graph(cropped_images)
        print(scene_graph)
        
        # Generate hierarchical redundant arguments
        hierarchical_args = self.generate_hierarchical_args(scene_graph)
        
        # Perform second round of object detection with hierarchical arguments
        detailed_objects = self.detect_objects(frame, texts=[hierarchical_args])
        detailed_spatial_relations, detailed_depths = self.calculate_spatial_relations(detailed_objects, frame.size)
        
        self.history.append({
            'objects': objects,
            'relations': spatial_relations,
            'captions': cropped_images,
            'scene_graph': scene_graph,
            'detailed_objects': detailed_objects,
            'detailed_relations': detailed_spatial_relations
        })
        return self.infer_actions_and_motives(frame_number)

    def generate_scene_graph(self, cropped_images):
        """Build a scene graph for every entry labelled ``'a man'``.

        For each such entry (index ``idx``) the graph receives:
          * a ``man_<idx>`` node plus alias nodes pointing at it ("is"),
          * one node per fixed body part, linked from the man ("has"), each
            with its own alias nodes,
          * one node per known object whose name occurs as a substring of the
            lower-cased caption, linked from the man ("interacting_with"),
            again with alias nodes.

        Args:
            cropped_images: Dicts with ``'label'`` and ``'caption'`` keys.

        Returns:
            The graph in networkx node-link form (JSON-serialisable).
        """
        graph = nx.DiGraph()

        body_parts = ("head", "eye", "arm", "hand", "leg", "foot")
        known_objects = ("book", "phone", "glass", "chair")

        def link_aliases(target_id, alias_labels, suffix):
            # Each alias becomes its own node with an "is" edge to the target.
            for alias in alias_labels:
                alias_id = f"{alias}_{suffix}"
                graph.add_node(alias_id, label=alias)
                graph.add_edge(alias_id, target_id, relation="is")

        for idx, img_data in enumerate(cropped_images):
            if img_data['label'] != 'a man':
                continue

            man_id = f"man_{idx}"
            graph.add_node(man_id, label="man")
            link_aliases(man_id, ["human", "person", "individual"], idx)

            for part in body_parts:
                part_id = f"{part}_{idx}"
                graph.add_node(part_id, label=part)
                graph.add_edge(man_id, part_id, relation="has")
                link_aliases(
                    part_id,
                    [f"{man_id}'s {part}", f"human {part}", f"person's {part}"],
                    idx,
                )

            # Objects are only added when the caption mentions them.
            caption_lower = img_data['caption'].lower()
            for obj in known_objects:
                if obj not in caption_lower:
                    continue
                obj_id = f"{obj}_{idx}"
                graph.add_node(obj_id, label=obj)
                graph.add_edge(man_id, obj_id, relation="interacting_with")
                link_aliases(
                    obj_id,
                    [f"{man_id}'s {obj}", f"human's {obj}", f"person's {obj}"],
                    idx,
                )

        return nx.node_link_data(graph)  # Convert to JSON-serializable format

    def generate_hierarchical_args(self, scene_graph):
        """Return the distinct node labels of a node-link scene graph.

        Args:
            scene_graph: Graph in networkx node-link form.

        Returns:
            A list of unique ``'label'`` values; ordering follows set
            iteration and is therefore not guaranteed.
        """
        graph = nx.node_link_graph(scene_graph)
        unique_labels = {attrs['label'] for _, attrs in graph.nodes(data=True)}
        return list(unique_labels)

    def infer_actions_and_motives(self, frame_number):
        scene_description = self.generate_scene_description()
        scene_graph = self.history[-1]['scene_graph']
        detailed_objects = self.history[-1]['detailed_objects']
        detailed_relations = self.history[-1]['detailed_relations']
        
        # Convert scene graph to a string representation
        scene_graph_str = json.dumps(scene_graph, indent=2)
        
        # Generate detailed spatial information string
        detailed_spatial_info = self.generate_detailed_spatial_info(detailed_objects, detailed_relations)
        
        prompt = f"""Based on the following scene description, scene graph, and detailed spatial information, infer the actions and motives of the actors in the scene. Provide a detailed analysis of what might be happening. Consider alternative explanations to what is given in the scene descriptions. Use the spatial position of eyes to infer the direction of gaze:

Scene Description:
{scene_description}

Scene Graph:
{scene_graph_str}

Detailed Spatial Information:
{detailed_spatial_info}

Analysis:"""
        print(prompt)
        response = self.claude_client.messages.create(
            model="claude-3-sonnet-20240229",
            max_tokens=1000,
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        inference = response.content
        
        # Log the reasoning trace
        self.log_reasoning(frame_number, scene_description + "\n\nScene Graph:\n" + scene_graph_str + "\n\nDetailed Spatial Information:\n" + detailed_spatial_info, inference)
        
        return inference

    def generate_scene_description(self):
        description = "Scene Description:\n\n"
        for i, frame_data in enumerate(self.history):
            description += f"Frame {i+1}:\n"
            for obj in frame_data['objects']:
                description += f"- Detected {obj['label']} at {obj['box']}\n"
            for relation in frame_data['relations']:
                description += f"- {relation['obj1']} is {relation['relative_position']} {relation['obj2']} (distance: {relation['distance']:.2f}, depth difference: {relation['depth_difference']:.2f})\n"
            for caption_data in frame_data['captions']:
                description += f"- {caption_data['label']} ({caption_data['spatial_tag']}): {caption_data['caption']}\n"
            description += "\n"
        return description

    def generate_detailed_spatial_info(self, detailed_objects, detailed_relations):
        info = "Detailed Spatial Information:\n\n"
        for obj in detailed_objects:
            info += f"- Detected {obj['label']} at {obj['box']} with confidence {obj['score']:.2f}\n"
        for relation in detailed_relations:
            info += f"- {relation['obj1']} is {relation['relative_position']} {relation['obj2']} (distance: {relation['distance']:.2f}, depth difference: {relation['depth_difference']:.2f})\n"
        return info
    
    def visualize_frame(self, frame, objects, depths):
        """Show ``frame`` with one annotated rectangle per detection.

        Args:
            frame: Image handed to ``imshow``.
            objects: Detections with ``'label'``, ``'score'`` and ``'box'``.
            depths: One depth score per detection; it selects the viridis
                colour and the Foreground/Midground/Background tag.
        """
        _, ax = plt.subplots(figsize=(12, 12))
        ax.imshow(frame)
        ax.axis('off')

        for obj, depth in zip(objects, depths):
            box = obj['box']
            left, top = box[0], box[1]
            color = plt.cm.viridis(depth)[:3]  # drop the alpha channel

            outline = plt.Rectangle(
                (left, top),
                box[2] - left,
                box[3] - top,
                fill=False,
                edgecolor=color,
                linewidth=2,
            )
            ax.add_patch(outline)

            if depth < 0.4:
                spatial_tag = "Foreground"
            elif depth > 0.7:
                spatial_tag = "Background"
            else:
                spatial_tag = "Midground"

            caption = f"{obj['label']}: {obj['score']:.2f}\n{spatial_tag}"
            ax.text(left, top, caption,
                    bbox=dict(facecolor='white', alpha=0.8), fontsize=10, color=color)

        plt.title("Detected Objects with Spatial Tags")
        plt.show()

    def visualize_scene_graph(self, scene_graph):
        """Draw a node-link scene graph with its ``relation`` edge labels.

        Args:
            scene_graph: Graph in networkx node-link form.
        """
        graph = nx.node_link_graph(scene_graph)
        layout = nx.spring_layout(graph)

        plt.figure(figsize=(12, 8))
        node_style = dict(
            with_labels=True,
            node_color='lightblue',
            node_size=2000,
            font_size=8,
            font_weight='bold',
        )
        nx.draw(graph, layout, **node_style)

        relation_labels = nx.get_edge_attributes(graph, 'relation')
        nx.draw_networkx_edge_labels(graph, layout, edge_labels=relation_labels)

        plt.title("Scene Graph")
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    def process_single_image(self, image_path):
        if os.path.exists(image_path):
            image = Image.open(image_path)
            inference = self.analyze_frame(image, frame_number=1)
            print(inference)
            
            self.visualize_frame(image, self.history[-1]['objects'], self.history[-1]['relations'][1])
            self.visualize_scene_graph(self.history[-1]['scene_graph'])
        else:
            print(f"Image not found at path: {image_path}")

    def process_video_stream(self, video_path, frame_interval=30):
        """Analyse every ``frame_interval``-th frame of a video file.

        For each selected frame the inference is printed and the detection
        and scene-graph plots are shown; a timing summary is printed at the
        end. The capture is released even when analysis raises.

        Args:
            video_path: Path handed to ``cv2.VideoCapture``.
            frame_interval: A frame is analysed when its 1-based index is a
                multiple of this value.
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Error opening video file")
            return

        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)

            frame_count = 0
            processed_frames = 0
            start_time = time.time()

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                if frame_count % frame_interval != 0:
                    print(f"Skipping frame {frame_count}/{total_frames}")
                    continue

                processed_frames += 1
                print(f"Processing frame {frame_count}/{total_frames}")

                # OpenCV delivers BGR; the models are fed an RGB PIL image.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(rgb_frame)

                inference = self.analyze_frame(pil_image, frame_count)
                print(f"Frame {frame_count}/{total_frames} Analysis:")
                print(inference)

                latest = self.history[-1]
                detections = latest['objects']
                # visualize_frame expects one depth per detection. The old
                # code passed latest['relations'][1] (a relation dict, or
                # IndexError with fewer than two relations) instead.
                depths = [self.estimate_depth(obj['box'], pil_image.size) for obj in detections]
                self.visualize_frame(pil_image, detections, depths)
                self.visualize_scene_graph(latest['scene_graph'])

                cv2_imshow(frame)
                clear_output(wait=True)

            total_time = time.time() - start_time
            # Guard against a zero elapsed time on an empty/instant run.
            processed_fps = processed_frames / total_time if total_time > 0 else 0.0
            print(f"Processed {processed_frames} frames out of {total_frames} total frames")
            print(f"Processing time: {total_time:.2f} seconds")
            print(f"Processed frames per second: {processed_fps:.2f}")
            print(f"Original video FPS: {fps}")
            print(f"Frame interval: {frame_interval}")
        finally:
            # Previously skipped when analysis raised, leaving the file open.
            cap.release()
            cv2.destroyAllWindows()

# Usage example: build the system (this loads the models and creates the log
# database), mount Google Drive, then analyse the test video.
api_key = ANTHROPIC_API_KEY
system = SceneReasoningSystem(api_key)

# Mount Google Drive so files stored there are reachable under /content/drive
from google.colab import drive
drive.mount('/content/drive')

# Video to analyse; the value comes from the TEST_VIDEO_PATH environment variable
video_path = TEST_VIDEO_PATH

# Analyse every 200th frame of the video
system.process_video_stream(video_path, frame_interval=200)

In [None]:
import sqlite3
import pandas as pd

def view_database(db_path='scene_reasoning.db', limit=10):
    """Print and return the first ``limit`` rows of the ``reasoning_logs`` table.

    Args:
        db_path: Path of the SQLite database file.
        limit: Maximum number of rows to read.

    Returns:
        A pandas DataFrame with the selected rows.
    """
    # Bind the limit as a parameter instead of formatting it into the SQL text.
    query = "SELECT * FROM reasoning_logs LIMIT ?"

    conn = sqlite3.connect(db_path)
    try:
        df = pd.read_sql_query(query, conn, params=(limit,))
    finally:
        # Close the connection even when the query fails (e.g. missing table).
        conn.close()

    print(df)

    # Returned so the caller can keep working with the data.
    return df

# Show the first rows of the default log database and keep them for further inspection.
df = view_database()