# INSTALL LIBRARIES

In [None]:
!pip install instructor openai graphviz pydantic networkx plotly gradio -q

# AI Setup

## Using OpenAI

In [2]:
# Approach 1: OpenAI key
# This setup is disabled: the code below is wrapped in a triple-quoted string,
# so evaluating the cell only produces that string and nothing is executed.
# Remove the surrounding quotes (and insert a key) to use this approach.

'''
import instructor
from openai import OpenAI

openai_key = "insert_your_key_here"

client = OpenAI(
 api_key=openai_key,
 )

model_id = "gpt-4o-mini"
client = instructor.from_openai(client)
'''


'\nimport instructor\nfrom openai import OpenAI\n\nopenai_key = "insert_your_key_here"\n\nclient = OpenAI(\n api_key=openai_key,\n )\n\nmodel_id = "gpt-4o-mini"\nclient = instructor.from_openai(client)\n'

## Using Azure OpenAI

In [3]:
# Approach 2: Azure OpenAI key
# The endpoint and key are taken from the environment when the variables are
# set, so real credentials do not have to be pasted into the notebook. The
# original placeholder strings remain as the fallback values.
import os

import instructor
from openai import AzureOpenAI

azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "your_azure_endpoint")
azure_api_key = os.environ.get("AZURE_OPENAI_API_KEY", "your_azure_key")

az_client = AzureOpenAI(
    azure_endpoint=azure_endpoint,
    api_key=azure_api_key,
    api_version="2024-02-15-preview",
)

# Name of the model passed as `model=` when requesting completions.
model_id = "your_model_id"

# Wrap the client with instructor; generate_graph calls it with `response_model=`.
client = instructor.from_openai(az_client)


# WORKFLOW LOGIC

In [None]:
import gradio as gr

from pydantic import BaseModel, Field
import base64
import networkx as nx
import matplotlib.pyplot as plt
import io
from PIL import Image
import os
import logging
import traceback
from openai import OpenAI

# Set up logging
# basicConfig sets the root logger to INFO; `logger` is the module-level
# logger used by every function in this cell to report progress and errors.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pydantic Structured models

# One vertex of the knowledge graph.
class Node(BaseModel):
    id: int  # identifier used as the networkx node key; edges refer to it
    label: str  # text drawn on the node by visualize_graph
    color: str  # colour handed to matplotlib when the node is drawn

# One directed connection between two nodes of the knowledge graph.
class Edge(BaseModel):
    source: int  # id of the node the edge starts from
    target: int  # id of the node the edge points to
    label: str  # text drawn along the edge by visualize_graph
    color: str = "black"  # edge colour; black when not supplied

# Structured response model requested from the LLM: the nodes and edges of
# the graph. generate_graph passes it as `response_model`.
class KnowledgeGraph(BaseModel):
    # Both lists default to empty. Only `default_factory` is given: pairing it
    # with the `...` "required" marker is contradictory, and some pydantic
    # versions reject that combination outright.
    nodes: list[Node] = Field(default_factory=list)
    edges: list[Edge] = Field(default_factory=list)

# image encoding logic
def encode_image(image_path):
    """Return the contents of the file at *image_path* as base64 text.

    The file is read in binary mode, base64-encoded, and the result is
    decoded to a UTF-8 ``str``.

    Raises:
        Exception: Any failure (for example a missing file) is logged and
            then re-raised unchanged.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode("utf-8")
    except Exception as exc:
        logger.error(f"Error encoding image: {exc!s}")
        raise

import matplotlib.colors as mcolors

# graph visualization function

def visualize_graph(graph):
    """Draw a knowledge graph with networkx/matplotlib and return it as an image.

    Args:
        graph: Object exposing ``nodes`` (items with ``id``, ``label`` and
            ``color``) and ``edges`` (items with ``source``, ``target``,
            ``label`` and ``color``), such as a ``KnowledgeGraph``.

    Returns:
        The PIL image opened from an in-memory PNG rendering of the figure.

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """
    fig = None
    try:
        G = nx.DiGraph()

        for node in graph.nodes:
            G.add_node(node.id, label=node.label, color=node.color)

        for edge in graph.edges:
            G.add_edge(edge.source, edge.target, label=edge.label, color=edge.color)

        # An edge may reference an id that is missing from graph.nodes; networkx
        # then creates that node without attributes. Give such nodes defaults so
        # the 'label'/'color' lookups below cannot fail with KeyError.
        for node_id in G.nodes:
            attrs = G.nodes[node_id]
            attrs.setdefault('label', str(node_id))
            attrs.setdefault('color', 'lightgray')

        pos = nx.spring_layout(G, k=0.9, iterations=50)

        # Keep a handle on the figure so it can be closed once rendered.
        fig = plt.figure(figsize=(14, 10))
        plt.title("Enhanced Home Layout Visualization", fontsize=16, fontweight='bold')

        # Enhance node appearance
        node_colors = [mcolors.to_rgba(attrs['color'], alpha=0.8) for _, attrs in G.nodes(data=True)]
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=3000, edgecolors='gray', linewidths=2)

        # Improve label rendering
        labels = {node_id: attrs['label'] for node_id, attrs in G.nodes(data=True)}
        nx.draw_networkx_labels(G, pos, labels, font_size=10, font_weight='bold')

        # Enhance edge appearance
        edge_colors = [mcolors.to_rgba(attrs['color'], alpha=0.6) for _, _, attrs in G.edges(data=True)]
        nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=2, arrowsize=20)

        # Improve edge label rendering
        edge_labels = {(u, v): d['label'] for u, v, d in G.edges(data=True)}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8, font_color='darkblue')

        plt.axis('off')
        plt.tight_layout()

        # Map each colour to the label of the first node that uses it.
        color_to_room = {}
        for _, attrs in G.nodes(data=True):
            color_to_room.setdefault(attrs['color'], attrs['label'])

        # Add a legend with color names and corresponding room types
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                           label=f"{color}: {room}",
                           markerfacecolor=color, markersize=10)
                           for color, room in color_to_room.items()]
        plt.legend(handles=legend_elements, title="Room Types", loc='best', fontsize=8)

        # Save the image with higher DPI
        img_buf = io.BytesIO()
        plt.savefig(img_buf, format='png', dpi=300, bbox_inches='tight')
        img_buf.seek(0)
        return Image.open(img_buf)
    except Exception as e:
        logger.error(f"Error visualizing graph: {str(e)}")
        raise
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # without this each call would leak one figure.
        if fig is not None:
            plt.close(fig)

def generate_graph(image_path) -> KnowledgeGraph:
    """Ask the model to describe the image at *image_path* as a knowledge graph.

    The image is base64-encoded and sent inline as a data URL together with a
    fixed text prompt. ``response_model=KnowledgeGraph`` requests the reply as
    a ``KnowledgeGraph`` instance.

    Args:
        image_path: Path of the image file to send.

    Returns:
        The response object produced by the client for ``KnowledgeGraph``.

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """
    import mimetypes  # local import: only this function needs it

    try:
        # Encode the image
        base64_image = encode_image(image_path)

        # Declare the real media type of the file in the data URL instead of
        # always claiming PNG; fall back to PNG when it cannot be determined.
        mime_type, _ = mimetypes.guess_type(image_path)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/png"

        logger.info("Sending request to OpenAI API")
        response = client.chat.completions.create(
            model=model_id,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the relationship between rooms, windows and doors openings and connectivity. Start from the entrance and describe the layout as if you are giving a walking tour of the layout. Output a array of nodes(id, label, color) and the edges(source, target, label). Help me understand the following image by describing it as a detailed knowledge graph:"},
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}}
                    ],
                }
            ],
            response_model=KnowledgeGraph,
        )
        logger.info("Received response from OpenAI API")
        return response
    except Exception as e:
        logger.error(f"Error generating graph: {str(e)}")
        raise

def process_image(input_image):
    """Turn an input image into a rendered knowledge-graph image.

    Runs generate_graph on *input_image* and passes the result to
    visualize_graph. This is the callback given to the Gradio interface.

    Returns:
        The image produced by visualize_graph, or ``None`` when any step
        fails (the error is logged rather than propagated).
    """
    try:
        logger.info(f"Processing image: {input_image}")
        return visualize_graph(generate_graph(input_image))
    except Exception as err:
        logger.error(f"Error in process_image: {err!s}")
    # Reached only after a failure has been logged above.
    return None

# Create Gradio interface
# The input component is configured with type="filepath" and the output
# component with type="pil"; process_image is the callback connecting them.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Image(type="pil"),
    title="Image to Knowledge Graph",
    description="Upload an image to generate a knowledge graph representation."
)

# Start the Gradio app when this code runs as the main module.
if __name__ == "__main__":
    try:
        iface.launch()
    except Exception as launch_error:
        # Two separate error records: a one-line summary, then the stack trace.
        for message in (
            f"Error launching Gradio interface: {launch_error!s}",
            traceback.format_exc(),
        ):
            logger.error(message)