In [2]:
# === Load environment variables ===
from dotenv import load_dotenv
import os

# Load .env file from project root
load_dotenv()

# Fetch the API key from .env
my_api_key = os.getenv("GOOGLE_API_KEY")

# Quick check (optional â€“ remove later)
if my_api_key:
    print("API key loaded successfully.")
else:
    print("API key NOT found. Check your .env file.")

# === Configure Gemini API ===
import google.generativeai as genai
genai.configure(api_key=my_api_key)

print("Gemini configured and ready.")


API key loaded successfully.
Gemini configured and ready.


In [3]:
model = genai.GenerativeModel("models/gemini-2.5-flash")

In [4]:
from agents.report_generator import ReportGenerator
report_gen = ReportGenerator()
import os

In [5]:
import pandas as pd
from agents.data_cleaner import DataCleaner
from agents.visualizer import Visualizer
from agents.forecaster import Forecaster

cleaner = DataCleaner()
viz = Visualizer()
forecaster = Forecaster()

In [6]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from agent.code_generator import CodeGenerator
from agent.explanation_agent import ExplanationAgent
from agent.utils import capture_output

# Dummy LLM function (replace with ADK later)
def llm(prompt):
    response = model.generate_content(prompt)
    return response.text

code_gen = CodeGenerator(llm)
exp_agent = ExplanationAgent(llm)


In [7]:
df = pd.read_csv("sample_data/sales.csv")
df.head()


Unnamed: 0,Product_ID,Sale_Date,Sales_Rep,Region,Sales_Amount,Quantity_Sold,Product_Category,Unit_Cost,Unit_Price,Customer_Type,Discount,Payment_Method,Sales_Channel,Region_and_Sales_Rep
0,1052,03-02-2023,Bob,North,5053.97,18,Furniture,152.75,267.22,Returning,0.09,Cash,Online,North-Bob
1,1093,21-04-2023,Bob,West,4384.02,17,Furniture,3816.39,4209.44,Returning,0.11,Cash,Retail,West-Bob
2,1015,21-09-2023,David,South,4631.23,30,Food,261.56,371.4,Returning,0.2,Bank Transfer,Retail,South-David
3,1072,24-08-2023,Bob,South,2167.94,39,Clothing,4330.03,4467.75,New,0.02,Credit Card,Retail,South-Bob
4,1061,24-03-2023,Charlie,East,3750.2,13,Electronics,637.37,692.71,New,0.08,Credit Card,Online,East-Charlie


In [8]:
def route_task(task, df):
    task = task.lower()

    if "clean" in task:
        return cleaner.clean(df)

    if "plot" in task or "visualize" in task:
        column = df.columns[1]
        path = viz.generate_plot(df, column)
        return {"plot_path": path}

    if "forecast" in task:
        column = df.columns[1]
        return forecaster.forecast_next(df, column)

    return "Task not understood."


In [9]:
def ensure_revenue_column(df):
    cols = df.columns.str.lower()

    # CASE 1: Already contains a revenue-like column
    for c in df.columns:
        if "revenue" in c.lower():
            df.rename(columns={c: "revenue"}, inplace=True)
            return df

    # CASE 2: Contains price & units_sold
    price_col = None
    units_col = None

    for c in df.columns:
        if "price" in c.lower():
            price_col = c
        if "unit" in c.lower() or "quantity" in c.lower():
            units_col = c

    if price_col and units_col:
        df["revenue"] = df[price_col] * df[units_col]
        return df

    # CASE 3: No revenue possible
    df["revenue"] = 0
    return df


In [10]:
df = ensure_revenue_column(df)


In [11]:
def ask_ai(task):
    df = pd.read_csv("sample_data/sales.csv")
    result = route_task(task, df)
    return result

ask_ai("clean the data")


{'missing_values': {'Product_ID': 0,
  'Sale_Date': 0,
  'Sales_Rep': 0,
  'Region': 0,
  'Sales_Amount': 0,
  'Quantity_Sold': 0,
  'Product_Category': 0,
  'Unit_Cost': 0,
  'Unit_Price': 0,
  'Customer_Type': 0,
  'Discount': 0,
  'Payment_Method': 0,
  'Sales_Channel': 0,
  'Region_and_Sales_Rep': 0},
 'duplicates_removed': 0,
 'cleaned_df':      Product_ID   Sale_Date Sales_Rep Region  Sales_Amount  Quantity_Sold  \
 0          1052  03-02-2023       Bob  North       5053.97             18   
 1          1093  21-04-2023       Bob   West       4384.02             17   
 2          1015  21-09-2023     David  South       4631.23             30   
 3          1072  24-08-2023       Bob  South       2167.94             39   
 4          1061  24-03-2023   Charlie   East       3750.20             13   
 ..          ...         ...       ...    ...           ...            ...   
 995        1010  15-04-2023   Charlie  North       4733.88              4   
 996        1067  07-09-2023 

In [12]:
user_query = "Show me the total revenue per month with a bar chart."


In [13]:
def clean_code(llm_output):
    # remove markdown fences
    code = llm_output.replace("```python", "").replace("```", "")
    
    # keep ALL lines except pure text lines
    cleaned = []
    for line in code.split("\n"):
        if line.strip() == "":
            continue
        
        # remove conversational garbage
        if any(x in line.lower() for x in [
            "here is the code",
            "sure",
            "below is",
            "the following",
            "as requested"
        ]):
            continue
        
        cleaned.append(line)
    
    return "\n".join(cleaned)

def fix_seaborn_palette(code):
    # If LLM adds deprecated seaborn palette usage, remove palette or add legend=False
    if "sns." in code and "palette=" in code and "hue=" not in code:
        code = code.replace("palette=", "legend=False, palette_removed=")
    return code

In [14]:
import matplotlib.pyplot as plt

def force_png_no_alpha():
    fig = plt.gcf()
    fig.patch.set_facecolor("white")


In [15]:
def patch_savefig_alpha(code):
    # force PNGs to be saved in RGB mode
    if "plt.savefig" in code:
        code = code.replace(
            "plt.savefig(",
            "plt.gcf().patch.set_facecolor('white'); plt.savefig("
        )
    return code

In [16]:
import re

def auto_print_last_var(code):
    lines = code.strip().split("\n")
    last_var = None

    for line in reversed(lines):
        if "=" in line and not line.strip().startswith("#"):
            last_var = line.split("=")[0].strip()
            break

    if last_var and f"print({last_var})" not in code:
        code += f"\nprint({last_var})\n"

    return code


In [17]:
def fix_unclosed_braces(code):
    open_count = code.count("{")
    close_count = code.count("}")

    if open_count > close_count:
        code += "}" * (open_count - close_count)

    return code


In [18]:
# 1. Generate raw LLM output
raw_code = code_gen.generate(user_query, df)

# 2. Clean the LLM code (remove ```python, etc.)
cleaned_code = clean_code(raw_code)

# 3. Fix seaborn palette usage if necessary
generated_code = fix_seaborn_palette(cleaned_code)
# Inject forced patch before savefig
if "plt.savefig" in generated_code:
    generated_code = generated_code.replace(
        "plt.savefig(",
        "force_png_no_alpha(); plt.savefig("
    )

print("===== RAW CODE =====")
print(raw_code)
print("====================")

print("\n===== CLEANED CODE =====")
print(generated_code)
print("====================")


===== RAW CODE =====
import pandas as pd
import matplotlib.pyplot as plt

# Ensure 'Sale_Date' is a datetime object
df['Sale_Date'] = pd.to_datetime(df['Sale_Date'])

# Extract month from 'Sale_Date'
df['Sale_Month'] = df['Sale_Date'].dt.to_period('M')

# Group by month and sum the 'revenue'
monthly_revenue = df.groupby('Sale_Month')['revenue'].sum().reset_index()

# Convert 'Sale_Month' back to string for better plotting labels if needed, or keep as Period for sorting
monthly_revenue['Sale_Month_Str'] = monthly_revenue['Sale_Month'].astype(str)

# Print the computed values
print("Total revenue per month:")
print(monthly_revenue[['Sale_Month_Str', 'revenue']])

# Create the bar chart
plt.figure(figsize=(12, 6))
plt.bar(monthly_revenue['Sale_Month_Str'], monthly_revenue['revenue'], color='skyblue')

# Add labels and title
plt.xlabel('Month')
plt.ylabel('Total Revenue')
plt.title('Total Revenue Per Month')
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', linestyle='--', alpha=0.7)


In [19]:
def auto_debug_code(bad_code, error):
    prompt = f"""
The following Python code produced an error:
{bad_code}

Error:
{error}

Fix the code. Only output corrected Python code.
"""
    return llm(prompt)


In [20]:
local_env = {"df": df, "plt": plt, "sns": sns, "pd": pd, "np": np, "force_png_no_alpha": force_png_no_alpha}

# Make sure plot file is deleted first
if os.path.exists("output_plot.png"):
    os.remove("output_plot.png")

exec_output = capture_output(generated_code, local_env)
if "ERROR" in exec_output:
    fixed = auto_debug_code(raw_code, exec_output)
    fixed = clean_code(fixed)
    exec_output = capture_output(fixed, local_env)

print(exec_output)

print("Image Exists:", os.path.exists("output_plot.png"))


Total revenue per month:
   Sale_Month_Str       revenue
0         2023-01  1.056883e+09
1         2023-02  6.507573e+08
2         2023-03  7.610588e+08
3         2023-04  7.974712e+08
4         2023-05  7.596467e+08
5         2023-06  8.789984e+08
6         2023-07  6.874590e+08
7         2023-08  9.110048e+08
8         2023-09  5.820823e+08
9         2023-10  8.466485e+08
10        2023-11  6.896101e+08
11        2023-12  8.321018e+08
12        2024-01  3.344795e+06

Image Exists: True


In [21]:
python_summary = exec_output[:500]  # short summary
print(python_summary)


Total revenue per month:
   Sale_Month_Str       revenue
0         2023-01  1.056883e+09
1         2023-02  6.507573e+08
2         2023-03  7.610588e+08
3         2023-04  7.974712e+08
4         2023-05  7.596467e+08
5         2023-06  8.789984e+08
6         2023-07  6.874590e+08
7         2023-08  9.110048e+08
8         2023-09  5.820823e+08
9         2023-10  8.466485e+08
10        2023-11  6.896101e+08
11        2023-12  8.321018e+08
12        2024-01  3.344795e+06



In [22]:
explanation = exp_agent.explain(user_query, python_summary)
print(explanation)


PermissionDenied: 403 Your API key was reported as leaked. Please use another API key.

In [None]:
os.makedirs("reports", exist_ok=True)
with open("reports/insight_report.txt", "w") as f:
    f.write(explanation)


In [None]:
with open("reports/insight_report.txt", "w") as f:
    f.write(explanation)

plt.savefig('output_plot.png')
print("Report saved!")

In [None]:
summary_text = "This is an automatically generated insight report."

images = ["output_plot.png"]

pdf_path = report_gen.generate_report(
    title="AI Generated Data Analysis Report",
    summary=summary_text,
    images=images
)

pdf_path


In [None]:
summary_text = explanation  # from explanation_agent


In [None]:
def generate_full_report(user_query):
    # 1. Generate code
    raw = code_gen.generate(user_query, df)
    cleaned = clean_code(raw)
    code = fix_seaborn_palette(cleaned)

    # 2. Execute code
    local_env = {"df": df, "plt": plt, "sns": sns, "pd": pd, "np": np}
    
    # remove old image
    if os.path.exists("output_plot.png"):
        os.remove("output_plot.png")

    exec_output = capture_output(code, local_env)

    # Auto debug
    if "ERROR" in exec_output:
        repaired = auto_debug_code(code, exec_output)
        repaired = clean_code(repaired)
        code = repaired
        exec_output = capture_output(code, local_env)

    # 3. Verify the plot
    print("Plot Exists:", os.path.exists("output_plot.png"))
    
    if not os.path.exists("output_plot.png"):
        raise Exception("The plot was not created, so report cannot be generated.")
    
    print("Plot file size:", os.path.getsize("output_plot.png"))

    # 4. Explanation
    python_summary = exec_output[:500]
    explanation = exp_agent.explain(user_query, python_summary)

    # 5. Generate report
    report_path = report_gen.generate_report(
        title=f"Report: {user_query}",
        summary=explanation,
        images=["output_plot.png"]
    )

    return explanation, report_path


In [None]:
summary, report_file = generate_full_report("Show total revenue per month")
summary, report_file
