In [2]:
import os
import sqlite3
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
from dotenv import load_dotenv
import json
from utils import *
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate


In [3]:
def explore_database(db_path):
    """Print and return the column names of every table in a SQLite database.

    Args:
        db_path: Path of the SQLite file, passed straight to ``sqlite3.connect``.

    Returns:
        dict mapping each table name found in ``sqlite_master`` to the list of
        its column names, in the order reported by ``PRAGMA table_info``.

    Raises:
        sqlite3.Error: if the database cannot be read. The connection is
            closed before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall()]

        print(f"📊 Database: {db_path}")
        print(f"📋 Tables found: {len(tables)}\n")

        schema = {}
        for table in tables:
            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier instead: names with spaces, reserved words or
            # embedded double quotes would otherwise break the statement.
            quoted_name = '"' + table.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name})")
            # Column 1 of each table_info row is the column name.
            columns = [col[1] for col in cursor.fetchall()]
            schema[table] = columns

            # Show at most five columns per table to keep the listing short.
            preview = ', '.join(columns[:5])
            ellipsis = '...' if len(columns) > 5 else ''
            print(f"  • {table}: {preview}{ellipsis}")

        return schema
    finally:
        # Release the connection even when one of the queries above fails.
        conn.close()

# Explore chinook database
schema = explore_database("../data/chinook.db")

📊 Database: ../data/chinook.db
📋 Tables found: 13

  • albums: AlbumId, Title, ArtistId
  • sqlite_sequence: name, seq
  • artists: ArtistId, Name
  • customers: CustomerId, FirstName, LastName, Company, Address...
  • employees: EmployeeId, LastName, FirstName, Title, ReportsTo...
  • genres: GenreId, Name
  • invoices: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity...
  • invoice_items: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity
  • media_types: MediaTypeId, Name
  • playlists: PlaylistId, Name
  • playlist_track: PlaylistId, TrackId
  • tracks: TrackId, Name, AlbumId, MediaTypeId, GenreId...
  • sqlite_stat1: tbl, idx, stat


In [4]:
schema_str = str(schema)

In [5]:
len(schema_str)

1000

In [6]:
context_usage=calculate_context_percentage(schema_str)

In [7]:
context_usage

{'tokens': 288,
 'percentage': 0.22499999999999998,
 'context_window': 128000,
 'remaining': 127712,
 'fits': True}

In [8]:
from dotenv import find_dotenv, load_dotenv
load_dotenv(find_dotenv())

True

In [9]:
# Read the API key from the environment; os.getenv returns None when the
# variable is not set.
# NOTE(review): the variable name is spelled "OPEENAI_API_KEY" (double E) —
# this looks like a typo for "OPENAI_API_KEY"; confirm against the .env file.
token = os.getenv("OPEENAI_API_KEY")
model = 'o4-mini'

# Chat client pointed at a custom API base URL instead of the library default.
# NOTE(review): confirm that this endpoint accepts temperature=0.7 for the
# selected model — TODO verify.
llm = ChatOpenAI(
    openai_api_base="https://chat.int.bayer.com/api/v2",
    openai_api_key=token,
    model=model,
    temperature=0.7
)


In [10]:
import pandas as pd

class DataAnalysisAgent:
    """Answers natural-language questions about a SQLite database.

    Workflow (see ``analyze``): the LLM turns the question into SQL, the SQL
    is vetted by ``_is_safe_query``, executed with pandas, and the first rows
    of the result are sent back to the LLM for a short bullet-point summary.

    The generated SQL is untrusted, so two independent guards are applied:
    a textual check in ``_is_safe_query`` and ``PRAGMA query_only`` on the
    connection, which makes SQLite itself refuse any write.
    """

    def __init__(self, db_path, llm, schema):
        """Open the database connection and store the collaborators.

        Args:
            db_path: Path of the SQLite file, passed to ``sqlite3.connect``.
            llm: Chat model exposing ``invoke(prompt)``; the returned object
                must have a ``content`` attribute holding the reply text.
            schema: dict mapping table name -> list of column names, used to
                describe the database in the SQL-generation prompt.
        """
        self.db_path = db_path
        self.llm = llm
        self.schema = schema
        self.conn = sqlite3.connect(db_path)
        # Defence in depth: even if a modifying statement slips past
        # _is_safe_query, SQLite rejects it on this connection.
        self.conn.execute("PRAGMA query_only = ON")

    def analyze(self, user_query: str, allowed_tables: list = None) -> dict:
        """Run the full question -> SQL -> data -> insights pipeline.

        Args:
            user_query: The question in natural language.
            allowed_tables: Optional list of table names to expose to the LLM
                in the prompt. ``None`` exposes every table in the schema.
                This only narrows the prompt; it is not enforced on the
                generated SQL.

        Returns:
            On success a dict with keys ``status`` ("success"), ``query``,
            ``sql``, ``data`` (DataFrame), ``insights`` and ``rows``.
            On failure a dict with keys ``status`` ("error"), ``error``,
            ``query`` and ``sql`` (the statement that was rejected or failed).
        """
        print(f"\n{'='*60}")
        print(f"🔍 Query: {user_query}")
        print(f"{'='*60}")

        # Step 1: Generate SQL
        sql_query = self._generate_sql(user_query, allowed_tables)
        print(f"\n💻 Generated SQL:\n{sql_query}\n")

        # Step 2: Execute with safety checks
        if not self._is_safe_query(sql_query):
            return {
                "status": "error",
                "error": "Unsafe query detected",
                "query": user_query,
                "sql": sql_query
            }

        try:
            df = pd.read_sql_query(sql_query, self.conn)
            print(f"✅ Query executed: {len(df)} rows returned\n")

            # Step 3: Generate insights
            insights = self._generate_insights(user_query, df)
        except Exception as e:
            # Invalid SQL from the LLM is an expected outcome; report it to
            # the caller, together with the offending SQL, instead of raising.
            return {
                "status": "error",
                "error": str(e),
                "query": user_query,
                "sql": sql_query
            }

        return {
            "status": "success",
            "query": user_query,
            "sql": sql_query,
            "data": df,
            "insights": insights,
            "rows": len(df)
        }

    def _generate_sql(self, user_query: str, allowed_tables: list = None) -> str:
        """Ask the LLM to translate ``user_query`` into a SQLite SELECT.

        Args:
            user_query: The question in natural language.
            allowed_tables: Optional list restricting which tables of
                ``self.schema`` are listed in the prompt.

        Returns:
            The model's reply with markdown code fences removed and
            surrounding whitespace stripped.
        """
        tables_info = "\n".join([
            f"- {table}: {', '.join(cols)}"
            for table, cols in self.schema.items()
            if allowed_tables is None or table in allowed_tables
        ])

        prompt = ChatPromptTemplate.from_template(
            """You are an expert SQL generator for SQLite databases.           
                DATABASE SCHEMA:
                {schema}
                USER QUERY: {query}
                Generate a safe, efficient SELECT query. Rules:
                1. ONLY use SELECT statements (no INSERT, UPDATE, DELETE, DROP)
                2. Include LIMIT clause if not specified (default LIMIT 100)
                3. Use proper JOINs when needed
                4. Return ONLY the SQL query, no explanations
                SQL Query:""")

        response = self.llm.invoke(prompt.format(schema=tables_info, query=user_query))

        # Clean the response
        sql = response.content.strip()
        # Remove markdown code blocks if present
        sql = sql.replace("```sql", "").replace("```", "").strip()

        return sql

    def _is_safe_query(self, sql: str) -> bool:
        """Decide whether ``sql`` may be executed.

        A statement is accepted only when, after leading whitespace and SQL
        comments, it starts with SELECT or WITH and contains none of the
        modifying keywords as a whole word. Whole-word matching keeps
        identifiers such as ``LastUpdate`` or ``CreatedAt`` usable, while the
        leading-keyword rule rejects PRAGMA, ATTACH and similar statements.

        The check is deliberately conservative: a blocked keyword inside a
        string literal still causes rejection.

        Args:
            sql: The SQL text to inspect.

        Returns:
            True when the statement passes both checks, False otherwise.
        """
        # Local import so the class does not depend on the notebook's import cell.
        import re

        if not sql or not sql.strip():
            print("⚠️  Empty query")
            return False

        sql_upper = sql.upper()

        # Drop leading whitespace, "-- ..." line comments and "/* ... */"
        # block comments before looking at the first keyword.
        body = re.sub(
            r"\A(?:\s+|--[^\n]*(?:\n|\Z)|/\*.*?\*/)+", "", sql_upper, flags=re.DOTALL
        )
        if not re.match(r"(?:SELECT|WITH)\b", body):
            print("⚠️  Only SELECT statements are allowed")
            return False

        dangerous_keywords = ['DROP', 'DELETE', 'UPDATE', 'INSERT', 'ALTER', 'CREATE', 'TRUNCATE']

        for keyword in dangerous_keywords:
            # \b anchors make this a whole-word match, not a substring match.
            if re.search(rf"\b{keyword}\b", sql_upper):
                print(f"⚠️  Dangerous keyword detected: {keyword}")
                return False
        return True

    def _generate_insights(self, query: str, df: pd.DataFrame) -> str:
        """Ask the LLM for 2-3 bullet-point insights about a query result.

        Args:
            query: The original natural-language question.
            df: The query result; only its first 10 rows are sent to the LLM.

        Returns:
            The model's reply with surrounding whitespace stripped.
        """
        data_summary = df.head(10).to_string()

        prompt = ChatPromptTemplate.from_template(
            """Based on this query and results, provide 2-3 key insights in bullet points.
                QUERY: {query}
                RESULTS (first 10 rows):
                {data}
                Key Insights (2-3 bullets):""")

        response = self.llm.invoke(prompt.format(query=query, data=data_summary))

        return response.content.strip()

    def close(self):
        """Close the database connection opened in ``__init__``."""
        self.conn.close()

# Instantiating the agent opens a SQLite connection in __init__; it stays open
# until analysis_agent.close() is called.
analysis_agent = DataAnalysisAgent("../data/chinook.db", llm, schema)

In [11]:
analysis_agent.analyze("Give me artist with most albulmns")


🔍 Query: Give me artist with most albulmns

💻 Generated SQL:
SELECT a.Name, COUNT(al.AlbumId) AS album_count
FROM artists AS a
JOIN albums AS al ON a.ArtistId = al.ArtistId
GROUP BY a.ArtistId
ORDER BY album_count DESC
LIMIT 1;

✅ Query executed: 1 rows returned



{'status': 'success',
 'query': 'Give me artist with most albulmns',
 'sql': 'SELECT a.Name, COUNT(al.AlbumId) AS album_count\nFROM artists AS a\nJOIN albums AS al ON a.ArtistId = al.ArtistId\nGROUP BY a.ArtistId\nORDER BY album_count DESC\nLIMIT 1;',
 'data':           Name  album_count
 0  Iron Maiden           21,
 'insights': '• Iron Maiden tops the list with 21 albums, making them the single most prolific artist in this dataset.  \n• Their 21-album output stands out sharply against all other artists, marking them as a clear outlier.  \n• This level of consistency and longevity suggests Iron Maiden’s enduring productivity over multiple decades.',
 'rows': 1}