In [69]:
import plotly.graph_objs as go
import plotly.io as pio
import json
import re

def extract_json_section(input_string, tag):
    """Return the text enclosed by ``<tag>`` and ``</tag>`` in ``input_string``.

    Args:
        input_string: Text to search (for example a raw model response).
        tag: Tag name without angle brackets. It is matched literally, so
            regex metacharacters in it carry no special meaning.

    Returns:
        The content of the first ``<tag>...</tag>`` section with surrounding
        whitespace stripped, or ``None`` when no complete pair is found.
    """
    # Escape the tag so names such as "a.b" cannot act as regex syntax and
    # accidentally match a different tag (e.g. "<axb>").
    escaped_tag = re.escape(tag)
    # Non-greedy group + DOTALL: stop at the first closing tag and allow the
    # section body to span multiple lines.
    match = re.search(
        f"<{escaped_tag}>(.*?)</{escaped_tag}>", input_string, re.DOTALL
    )
    if match is None:
        return None
    return match.group(1).strip()

def parse_json(json_str):
    """Parse a JSON string, tolerating single-quoted strings as a fallback.

    The input is first parsed as strict JSON. If that fails, every single
    quote is replaced with a double quote and parsing is retried once.

    Args:
        json_str: The JSON text to parse.

    Returns:
        The decoded Python object.

    Raises:
        ValueError: If neither attempt yields valid JSON. The decoder error of
            the second attempt is attached as ``__cause__``.
    """
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Strict parsing failed; fall through to the relaxed attempt below.
        pass

    # NOTE: this is a blunt textual substitution. It also rewrites apostrophes
    # inside string values, so input such as {'a': "it's"} still fails.
    fixed_json_str = json_str.replace("'", '"')
    try:
        return json.loads(fixed_json_str)
    except json.JSONDecodeError as e:
        # Chain explicitly so callers can inspect the underlying decoder error.
        raise ValueError(f"Invalid JSON: {e}") from e

def plot_from_ai_output_v2(input_string, img_path):
    """Build, display and save a Plotly figure described by tagged JSON text.

    ``input_string`` is searched for ``<data>``, ``<layout>`` and ``<config>``
    sections. Each section found is parsed with ``parse_json``.

    Args:
        input_string: Text containing the tagged JSON sections.
        img_path: Destination path handed to ``fig.write_image``.

    Raises:
        ValueError: If the data or layout section is missing or empty, or if a
            section cannot be parsed.
    """
    # Extract JSON sections
    data_json = extract_json_section(input_string, "data")
    layout_json = extract_json_section(input_string, "layout")
    config_json = extract_json_section(input_string, "config")

    # Parse JSON strings with relaxed rules
    data = parse_json(data_json) if data_json else None
    layout = parse_json(layout_json) if layout_json else None
    config = parse_json(config_json) if config_json else None

    if not data or not layout:
        raise ValueError("Invalid or missing data or layout JSON.")

    # Accept a single trace object as well as a list of traces; iterating a
    # bare dict would otherwise yield its keys instead of trace dicts.
    if isinstance(data, dict):
        data = [data]

    # Prepare traces for the plot
    traces = []
    for trace_data in data:
        if trace_data.get('type') == 'heatmap':
            # Heatmaps get explicit defaults (Viridis colorscale, x+y+z hover).
            traces.append(
                go.Heatmap(
                    z=trace_data.get('z', []),
                    x=trace_data.get('x', []),
                    y=trace_data.get('y', []),
                    name=trace_data.get('name', ''),
                    colorscale=trace_data.get('colorscale', 'Viridis'),
                    hoverinfo='x+y+z'
                )
            )
        else:
            # Any other trace type (e.g. "waterfall") is handed to Plotly as a
            # plain property dict rather than being silently dropped, which
            # previously produced an empty figure.
            traces.append(trace_data)

    # Create figure with the extracted layout and data
    fig = go.Figure(data=traces, layout=layout)

    # Render the figure with the config (``config`` may be None)
    pio.show(fig, config=config)

    # Save the figure as an image
    fig.write_image(img_path)


In [70]:
import google.generativeai as genai
from config import gemini_key

In [59]:
def input_image_setup(file_loc):
    """Load an image file into a one-element list of image "parts".

    Args:
        file_loc: Path to the image file.

    Returns:
        A list containing one dict with two keys: ``"mime_type"`` (derived
        from the file extension, falling back to ``"image/jpeg"`` when the
        extension is unknown or not an image type) and ``"data"`` (the raw
        file bytes).

    Raises:
        FileNotFoundError: If ``file_loc`` does not exist.
    """
    import mimetypes
    from pathlib import Path

    img = Path(file_loc)
    if not img.exists():
        raise FileNotFoundError(f"Could not find image: {img}")

    # Label the payload with the file's real type (e.g. PNG inputs are no
    # longer declared as JPEG); keep JPEG as the fallback for unknown names.
    mime_type, _ = mimetypes.guess_type(img.name)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    image_parts = [
        {
            "mime_type": mime_type,
            "data": img.read_bytes()
        }
    ]
    return image_parts
def get_image_info(image_loc, prompt):
    """Ask the ``gemini-pro-vision`` model a question about an image.

    The request is made of three parts, in order: a fixed instruction string,
    the image loaded via ``input_image_setup``, and the caller's ``prompt``.

    Args:
        image_loc: Path of the image to send.
        prompt: Question / task text appended after the image.

    Returns:
        The ``text`` attribute of the model response, converted to ``str``.
    """
    genai.configure(api_key=gemini_key)

    # Model handle with fixed generation settings.
    vision_model = genai.GenerativeModel(
        model_name="gemini-pro-vision",
        generation_config=dict(
            temperature=0,
            top_p=1,
            top_k=32,
            max_output_tokens=5000,
        ),
    )

    instruction_text = """ You are an expert in data visualization and graph analysis, adept at interpreting graphical data and generating structured JSON configurations for Plotly"""

    # input_image_setup returns a one-element list; only that element is sent.
    image_part = input_image_setup(image_loc)[0]

    reply = vision_model.generate_content([instruction_text, image_part, prompt])
    return str(reply.text)

In [71]:
import time
import json
import os
import csv
from docx import Document

def trail_run_v2(
    image_loc,
    prompt,
    json_root='D:/Chart/waterfall_charts/json',
    csv_folder='D:/Chart/waterfall_charts/extracted_tables',
):
    """Query the model for one image and persist the sections of its reply.

    The reply from ``get_image_info`` is searched for ``<data>``, ``<layout>``,
    ``<config>`` and ``<csv>`` sections. The three JSON sections are written to
    ``<json_root>/<image name>/{data,layout,config}.json``; a section that is
    absent is written as ``null``. The CSV section, when present, is written
    to ``<csv_folder>/<image name>.csv``.

    Args:
        image_loc: Path of the chart image to process.
        prompt: Prompt text forwarded to ``get_image_info``.
        json_root: Directory under which a per-image JSON folder is created.
        csv_folder: Directory that receives the extracted CSV file.

    Raises:
        ValueError: If a JSON section is present but cannot be parsed by
            ``parse_json``.
    """
    # Extract base name of the image without extension
    image_name = os.path.splitext(os.path.basename(image_loc))[0]

    # Per-image directory for the extracted JSON files
    json_folder = os.path.join(json_root, image_name)
    os.makedirs(json_folder, exist_ok=True)

    # Shared directory for the extracted CSV files
    os.makedirs(csv_folder, exist_ok=True)
    extracted_csv_path = os.path.join(csv_folder, f'{image_name}.csv')

    extracted_json = get_image_info(image_loc, prompt)
    print("Extracted Json:\n", extracted_json)

    data_json = extract_json_section(extracted_json, "data")
    layout_json = extract_json_section(extracted_json, "layout")
    config_json = extract_json_section(extracted_json, "config")
    csv_data = extract_json_section(extracted_json, "csv")

    # Use the same relaxed parser as plot_from_ai_output_v2 so that replies
    # using single-quoted strings are handled consistently in both places.
    data = parse_json(data_json) if data_json else None
    layout = parse_json(layout_json) if layout_json else None
    config = parse_json(config_json) if config_json else None

    # Save extracted JSON files
    for file_name, section in (
        ('data.json', data),
        ('layout.json', layout),
        ('config.json', config),
    ):
        with open(os.path.join(json_folder, file_name), 'w') as f:
            json.dump(section, f, indent=4)

    # Save extracted CSV data (re-serialised through the csv module)
    if csv_data:
        with open(extracted_csv_path, 'w', newline='') as f:
            csv_writer = csv.writer(f)
            csv_reader = csv.reader(csv_data.splitlines())
            csv_writer.writerows(csv_reader)
    
    

In [72]:
import os
import glob

def process_all_images(png_folder, prompt):
    """Run ``trail_run_v2`` on every ``*.png`` file directly inside a folder.

    A failure on one image does not stop the batch: the error is appended to
    ``error.log`` (relative to the current working directory) and printed,
    then processing continues with the next file.

    Args:
        png_folder: Directory searched for ``*.png`` files.
        prompt: Prompt text passed through to ``trail_run_v2``.
    """
    for png_file in glob.glob(os.path.join(png_folder, '*.png')):
        try:
            trail_run_v2(png_file, prompt)
        except Exception as e:
            # Best-effort batch: record the failure, then move on.
            with open('error.log', 'a') as log_file:
                log_file.write(f"Error processing {png_file}: {e}" + "\n")
            print(f"Error processing {png_file}: {e}")


# Example usage: process every PNG chart in the folder with the waterfall prompt.
png_folder = 'charts/waterfall_charts/img'

# Prompt text is read in full from disk.
with open('prompts/waterfall/prompt2.txt', 'r') as file:
    prompt = file.read()

process_all_images(png_folder, prompt)


Extracted Json:
  <data>
[
  {
    "type": "waterfall",
    "x": ["Beginning", "Sales 1", "Sales 2", "Returns 1", "Returns 2", "Net Sales 3", "Expenses 1", "Expenses 2", "Profit 1", "Profit 2", "Profit 3", "Final Profit"],
    "y": [0, 500, 300, -50, -20, 700, -100, -120, 150, 100, -50, 400],
    "measure": ["total", "relative", "relative", "relative", "relative", "total", "relative", "relative", "relative", "relative", "relative", "total"]
  }
]
</data>
<layout>
{
  "title": {
    "text": "Super Complex Financial Analysis 1",
    "font": {
      "family": "Arial, sans-serif",
      "size": 24,
      "color": "#000000"
    }
  },
  "xaxis": {
    "title": {
      "text": "Categories",
      "font": {
        "family": "Arial, sans-serif",
        "size": 18,
        "color": "#000000"
      }
    }
  },
  "yaxis": {
    "title": {
      "text": "Values",
      "font": {
        "family": "Arial, sans-serif",
        "size": 18,
        "color": "#000000"
      }
    }
  }
}
</layout>
<