In [None]:
# Prompt sent to the model. This f-string is evaluated as soon as the cell runs,
# so `info1`, `example_plot_code` and `documentation` must already exist in the
# kernel. `example_plot_code` and `documentation` are assigned in cells further
# down this notebook and `info1` is not defined in the visible cells, so those
# must be executed first for a clean top-to-bottom run.
plotprompt = f'''
You are a specialist in two aspects, drawing charts with plotly, and providing detailed descriptions about
the chart. You receive the data in the format of string containing three jsons (data.json containing data, layout.json containing layout and config.json containing the relevant configuration). Your task is to extract relevant information from the jsons and create a csv. In addition, you are provided with an example of Python code
drawing a chart for reference. You also receive some parameters that could be used to increase the diversity. You need to
generate Python code to plot the given data as a chart figure and provide a detailed description about the figure.
Additional requirements:
The chart should have the title, labels on x-axis and y-axis. The chart should have legend. You can annotate data values
above the point on the chart figure. Do not use show function to show the figure. The csv data should be listed in the
code.
The output contains two parts. The first part is the generated Python code wrapped in <code start> and <code end>.
Next is the detailed description about the chart wrapped in <description start> and <description end>.
The code should be able to be executed without external files.
The given data: {info1}.
The given code example: {example_plot_code}.
As for additional parameters, you could consider: {documentation}.
Ensure that you have wrapped the python code in <code start> and <code end> tags'''


In [None]:
# Reference script that is interpolated into the prompt as a worked example.
# This is a plain (non-f) string, so the braces in the JSON and in the inner
# f-strings are literal text.
# The example intentionally does not display the figure: the prompt tells the
# model not to use the show function, so the example must not demonstrate it.
example_plot_code = '''import json
import plotly.graph_objects as go

# The input string containing data, layout, and config
input_string = """
<data>
[
  {
    "type": "scatter",
    "x": [1972, 1974, 1976, 1978, 1980, 1982, 1984, 1986, 1988, 1990, 1992, 1994, 1996, 1998, 2000, 2002, 2004, 2006, 2008, 2010, 2012, 2014, 2016],
    "y": [2.01, 2.03, 2.25, 2.23, 2.19, 2.19, 2.25, 2.25, 2.21, 2.21, 2.19, 2.19, 2.19, 2.21, 2.23, 2.23, 2.27, 2.29, 2.29, 2.19, 2.19, 2.17, 2.15],
    "mode": "lines",
    "name": "Happiness",
    "line": {
      "color": "rgba(255, 127, 14, 1)",
      "width": 3
    }
  },
  {
    "type": "scatter",
    "x": [1972, 1974, 1976, 1978, 1980, 1982, 1984, 1986, 1988, 1990, 1992, 1994, 1996, 1998, 2000, 2002, 2004, 2006, 2008, 2010, 2012, 2014, 2016],
    "y": [20000, 22000, 24000, 26000, 28000, 30000, 32000, 34000, 36000, 38000, 40000, 42000, 44000, 46000, 48000, 50000, 52000, 53000, 54000, 55000, 55000, 53000, 51000],
    "mode": "lines",
    "name": "GDP Per Capita",
    "line": {
      "color": "rgba(14, 127, 255, 1)",
      "width": 3
    }
  }
]
</data>
<layout>
{
  "title": {
    "text": "Average Happiness and GDP Per Capita, 1972-2016",
    "font": {
      "family": "Arial, sans-serif",
      "size": 24,
      "color": "#000000"
    }
  },
  "xaxis": {
    "title": {
      "text": "Year",
      "font": {
        "family": "Arial, sans-serif",
        "size": 18,
        "color": "#000000"
      }
    },
    "showgrid": true,
    "gridcolor": "rgba(0, 0, 0, 0.1)",
    "zeroline": true,
    "zerolinecolor": "rgba(0, 0, 0, 0.1)",
    "tickangle": 45
  },
  "yaxis": {
    "title": {
      "text": "Value",
      "font": {
        "family": "Arial, sans-serif",
        "size": 18,
        "color": "#000000"
      }
    },
    "showgrid": true,
    "gridcolor": "rgba(0, 0, 0, 0.1)",
    "zeroline": true,
    "zerolinecolor": "rgba(0, 0, 0, 0.1)"
  },
  "legend": {
    "orientation": "h",
    "x": 0.5,
    "xanchor": "center",
    "y": -0.2,
    "font": {
      "family": "Arial, sans-serif",
      "size": 12,
      "color": "#000000"
    }
  },
  "margin": {
    "l": 60,
    "r": 30,
    "b": 60,
    "t": 60
  },
  "plot_bgcolor": "#ffffff",
  "paper_bgcolor": "#ffffff"
}
</layout>
<config>
{
  "responsive": true,
  "displayModeBar": true,
  "modeBarButtonsToRemove": ["toImage"],
  "scrollZoom": true
}
</config>
"""
# Function to extract content between tags
def extract_json_part(string, tag):
    start_tag = f'<{tag}>'
    end_tag = f'</{tag}>'
    start_idx = string.find(start_tag) + len(start_tag)
    end_idx = string.find(end_tag)
    return string[start_idx:end_idx].strip()

# Extracting the data, layout, and config parts
data_string = extract_json_part(input_string, 'data')
layout_string = extract_json_part(input_string, 'layout')
config_string = extract_json_part(input_string, 'config')

# Parsing the strings into JSON objects
data = json.loads(data_string)
layout = json.loads(layout_string)
config = json.loads(config_string)

# Create the Plotly figure
fig = go.Figure(data=data, layout=layout)

# Update the layout to match the parsed JSON layout
fig.update_layout(layout)

# Save the figure as an image (the figure is not displayed interactively;
# config only affects interactive rendering, so it is not applied here)
fig.write_image("new_image.png")
'''

In [None]:
# Load the Plotly parameter notes that the prompt template interpolates as
# `documentation`. The encoding is given explicitly so decoding does not depend
# on the platform's locale default.
with open('plotly_documentation.txt', encoding='utf-8') as doc:
    documentation = doc.read()

In [None]:
# Code to get information using Gemini model

def gemini_model(prompt, model_name="gemini-pro"):
    """Send ``prompt`` to a Gemini model and return the reply text.

    Args:
        prompt: Text passed unchanged to ``generate_content``.
        model_name: Name handed to ``genai.GenerativeModel``. Defaults to
            ``"gemini-pro"``, the value that used to be hard-coded, so
            existing single-argument calls behave as before.

    Returns:
        ``response.text`` coerced to ``str``.

    Relies on the notebook-level names ``genai`` and ``gemini_key``, which
    are not defined in this cell.
    """
    # The client is configured on every call, exactly as the original did.
    genai.configure(api_key=gemini_key)

    model = genai.GenerativeModel(model_name=model_name)

    response = model.generate_content(prompt)
    return str(response.text)

# Send the assembled prompt to the model and print the raw reply text.
gemini_output = gemini_model(plotprompt)
print(gemini_output)

In [None]:
import re

# The response string from the AI (same object as `gemini_output`)
ai_response = gemini_output

# Function to extract the generated Python code from the model response
def extract_code(response):
    """Return the Python source contained in ``response``, or ``None``.

    Two wrappers are recognised:

    * ``<code start> ... <code end>`` tags, which is what the prompt asks the
      model to emit. A Markdown fence inside the tags is unwrapped.
    * A ```` ```python ... ``` ```` Markdown fence anywhere in the response,
      which is the only form the previous implementation accepted.

    Args:
        response: Raw text returned by the model. ``None`` or an empty
            string is treated as "no code".

    Returns:
        The extracted code with surrounding whitespace stripped, or ``None``
        when neither wrapper is present.
    """
    if not response:
        return None

    tagged = re.search(r"<code start>(.*?)<code end>", response, re.DOTALL)
    if tagged:
        inner = tagged.group(1)
        # Models often put a Markdown fence inside the tags; keep only its body.
        fenced_inner = re.search(r"```(?:python|py)?(.*?)```", inner, re.DOTALL)
        code = (fenced_inner.group(1) if fenced_inner else inner).strip()
        if code:
            return code

    # Fallback: the original behaviour, a ```python fence anywhere in the text.
    fenced = re.search(r"```python(.*?)```", response, re.DOTALL)
    if fenced:
        return fenced.group(1).strip()
    return None

# Extract the code
code_block = extract_code(ai_response)

# Save the extracted code to plot.py
if code_block:
    # Explicit encoding so any non-ASCII text in the generated script is
    # written the same way on every platform.
    with open("plot.py", "w", encoding="utf-8") as file:
        file.write(code_block)
    print("Code has been saved to plot.py")
else:
    # Nothing is written in this branch, so a plot.py left over from an
    # earlier run (if any) stays on disk unchanged.
    print("No Python code block found in the model response")


In [None]:
# Run the script saved to plot.py by the previous cell. plot.py holds
# model-generated code, so review it before executing this cell.
!python plot.py