In [None]:
# Cell 1: Install packages
!pip install google-generativeai gradio plotly pandas sqlalchemy requests -q

# Import libraries
import json
import os
import re
import sqlite3
from datetime import datetime

import google.generativeai as genai
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Cell 2: Setup Gemini API (UPDATED)
import getpass

import google.generativeai as genai

try:
    from google.colab import userdata
except ImportError:  # not running in Colab: no Secrets store available
    userdata = None

# Method 1: Store API key in Colab Secrets (Recommended)
# Go to 🔑 icon in left sidebar → Add secret → Name: GEMINI_API_KEY
GEMINI_API_KEY = None
if userdata is not None:
    try:
        GEMINI_API_KEY = userdata.get('GEMINI_API_KEY')
    except Exception as secret_error:  # secret missing or notebook access not granted
        print(f"ℹ️ Colab secret GEMINI_API_KEY unavailable ({type(secret_error).__name__}); asking for the key instead.")

if not GEMINI_API_KEY:
    # Method 2: interactive prompt. getpass does not echo the key, so it does
    # not end up in the saved cell output the way input() would.
    GEMINI_API_KEY = getpass.getpass("Enter your Gemini API key: ")

genai.configure(api_key=GEMINI_API_KEY)

# ✅ FIXED: Use updated model name
model = genai.GenerativeModel('gemini-1.5-flash')  # Updated model name

print("✅ Gemini API configured successfully!")

# Test the model
try:
    test_response = model.generate_content("Hello, this is a test.")
    print("✅ Model test successful!")
    print(f"Test response: {test_response.text[:50]}...")
except Exception as e:
    print(f"❌ Model test failed: {e}")
    print("\n🔧 Trying alternative model names...")

    # Try other available models
    alternative_models = [
        'gemini-1.5-pro',
        'gemini-1.0-pro',
        'models/gemini-1.5-flash',
        'models/gemini-1.5-pro'
    ]

    for alt_model in alternative_models:
        try:
            candidate_model = genai.GenerativeModel(alt_model)
            test_response = candidate_model.generate_content("Hello, this is a test.")
        except Exception as alt_error:
            print(f"❌ Failed with model: {alt_model} ({alt_error})")
            continue
        # Rebind `model` only once a candidate actually answers, so a total
        # failure does not silently leave the last broken model in use.
        model = candidate_model
        print(f"✅ Success with model: {alt_model}")
        break
    else:
        print("❌ All models failed. Please check your API key and try again.")

In [None]:
# Cell 3: Upload your three datasets
from google.colab import files
import pandas as pd

# Explain what to upload before the file picker opens.
print("\n".join([
    "📁 Upload your THREE CSV files:",
    "1. Product-Level Eligibility Table",
    "2. Product-Level Ad Sales and Metrics",
    "3. Product-Level Total Sales and Metrics",
    "\nClick 'Choose Files' and select ALL THREE files at once (Ctrl+click or Cmd+click)",
]))

# `uploaded` maps each saved filename to its raw bytes; the files are also
# written to the working directory, which is why read_csv(filename) works below.
uploaded = files.upload()

print(f"\n✅ Total files uploaded: {len(uploaded)}")

# Preview each file so its columns can be checked before the database is built.
for filename, content in uploaded.items():
    print(f"\n📁 File: {filename}")
    print(f"   Size: {len(content)} bytes")

    try:
        df = pd.read_csv(filename)
        print(f"   Rows: {len(df)}, Columns: {len(df.columns)}")
        print(f"   Column names: {list(df.columns)}")
        print("   First few rows:")
        print(df.head(2).to_string())
    except Exception as e:
        print(f"   Error reading file: {e}")

    print("-" * 50)

# Compare the upload count with the three datasets this notebook expects.
expected_files = 3
if len(uploaded) < expected_files:
    print(f"⚠️  Warning: Expected {expected_files} files, but only {len(uploaded)} uploaded.")
    print("   You can run this cell again to upload more files.")
elif len(uploaded) > expected_files:
    print(f"ℹ️  Info: You uploaded {len(uploaded)} files (more than expected {expected_files}).")
    print("   That's fine - we'll use all of them!")
else:
    print(f"✅ Perfect! All {expected_files} files uploaded successfully!")

📁 Upload your THREE CSV files:
1. Product-Level Eligibility Table
2. Product-Level Ad Sales and Metrics
3. Product-Level Total Sales and Metrics

Click 'Choose Files' and select ALL THREE files at once (Ctrl+click or Cmd+click)


Saving Product-Level Ad Sales and Metrics (mapped) - Product-Level Ad Sales and Metrics (mapped).csv to Product-Level Ad Sales and Metrics (mapped) - Product-Level Ad Sales and Metrics (mapped) (1).csv
Saving Product-Level Eligibility Table (mapped) - Product-Level Eligibility Table (mapped).csv to Product-Level Eligibility Table (mapped) - Product-Level Eligibility Table (mapped) (1).csv
Saving Product-Level Total Sales and Metrics (mapped) - Product-Level Total Sales and Metrics (mapped).csv to Product-Level Total Sales and Metrics (mapped) - Product-Level Total Sales and Metrics (mapped) (1).csv

✅ Total files uploaded: 3

📁 File: Product-Level Ad Sales and Metrics (mapped) - Product-Level Ad Sales and Metrics (mapped) (1).csv
   Size: 102585 bytes
   Rows: 3696, Columns: 7
   Column names: ['date', 'item_id', 'ad_sales', 'impressions', 'ad_spend', 'clicks', 'units_sold']
   First few rows:
         date  item_id  ad_sales  impressions  ad_spend  clicks  units_sold
0  2025-06-01    

In [None]:
# Cell 4: Create database (FIXED VERSION)
def setup_database():
    # Create database connection
    conn = sqlite3.connect('ecommerce_data.db')

    # Load uploaded CSV files
    csv_files = [f for f in uploaded.keys() if f.endswith('.csv')]

    tables_created = []
    for csv_file in csv_files:
        # Read CSV
        df = pd.read_csv(csv_file)

        # Clean table name more aggressively
        table_name = csv_file.lower()

        # Remove file extension
        table_name = table_name.replace('.csv', '')

        # Clean up the name - keep only letters, numbers, underscores
        import re
        table_name = re.sub(r'[^a-z0-9_]', '_', table_name)

        # Remove multiple underscores and leading/trailing underscores
        table_name = re.sub(r'_+', '_', table_name).strip('_')

        # Shorten long names
        if len(table_name) > 30:
            if 'ad_sales' in table_name:
                table_name = 'ad_sales_metrics'
            elif 'eligibility' in table_name:
                table_name = 'product_eligibility'
            elif 'total_sales' in table_name:
                table_name = 'total_sales_metrics'
            else:
                table_name = table_name[:30]

        # Ensure it doesn't start with a number
        if table_name[0].isdigit():
            table_name = 't_' + table_name

        # Create table
        df.to_sql(table_name, conn, if_exists='replace', index=False)
        tables_created.append((table_name, list(df.columns)))

        print(f"✅ Created table: {table_name}")
        print(f"   Columns: {list(df.columns)}")
        print(f"   Rows: {len(df)}")
        print(f"   Sample data:")
        print(f"   {df.head(2).to_string()}")
        print()

    conn.close()
    return tables_created

# Build the SQLite database from every uploaded CSV
tables_info = setup_database()

# Summarise each table that was written (name followed by its columns)
print("\n" + "=" * 50)
print("📊 DATABASE SUMMARY:")
print("=" * 50)
for created_table, created_columns in tables_info:
    print(f"Table: {created_table}\nColumns: {', '.join(created_columns)}\n")

✅ Created table: ad_sales_metrics
   Columns: ['date', 'item_id', 'ad_sales', 'impressions', 'ad_spend', 'clicks', 'units_sold']
   Rows: 3696
   Sample data:
            date  item_id  ad_sales  impressions  ad_spend  clicks  units_sold
0  2025-06-01        0    332.96         1963     16.87       8           3
1  2025-06-01        1      0.00         1764     20.39      11           0

✅ Created table: product_eligibility
   Columns: ['eligibility_datetime_utc', 'item_id', 'eligibility', 'message']
   Rows: 4381
   Sample data:
     eligibility_datetime_utc  item_id  eligibility                                                                                                                                                                                                                   message
0       2025-06-04 8:50:07       29        False  This product's cost to Amazon does not allow us to meet customers’ pricing expectations. Consider reducing the cost. It may take a few weeks for

In [None]:
# Quick Verification Cell - Run this instead of the debug cell
import sqlite3
import pandas as pd

def _drop_legacy_duplicate_tables(conn):
    """Drop the long, parenthesised duplicate tables listed below, if present.

    The current Cell 4 name-cleaning never produces these names (it replaces
    parentheses), so they are presumably left over from an earlier run. Only
    tables that actually exist are dropped and reported.
    """
    legacy_tables = [
        'product_level_ad_sales_and_metrics_(mapped)___product_level_ad_sales_and_metrics_(mapped)',
        'product_level_eligibility_table_(mapped)___product_level_eligibility_table_(mapped)',
        'product_level_total_sales_and_metrics_(mapped)___product_level_total_sales_and_metrics_(mapped)'
    ]

    print("🧹 CLEANING UP DUPLICATE TABLES...")
    existing = {row[0] for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")}
    for legacy_table in legacy_tables:
        if legacy_table not in existing:
            continue
        try:
            conn.execute(f"DROP TABLE [{legacy_table}]")
            print(f"🗑️ Removed duplicate: {legacy_table}")
        except sqlite3.Error as e:
            print(f"⚠️ Could not remove {legacy_table}: {e}")
    conn.commit()


def verify_clean_tables(db_path='ecommerce_data.db', clean_tables=None):
    """Check that each expected table exists and can be queried, printing a report.

    Args:
        db_path: SQLite database built by Cell 4.
        clean_tables: table names to check. Defaults to the three tables
            Cell 4 creates.

    Returns:
        Dict ``{table_name: {'columns': [...], 'row_count': n}}`` for every
        table that could be queried. When all requested tables work, legacy
        duplicate tables are dropped as well.
    """
    if clean_tables is None:
        clean_tables = ['ad_sales_metrics', 'product_eligibility', 'total_sales_metrics']

    print("🔍 VERIFYING CLEAN TABLES:")
    print("="*50)

    working_tables = {}
    conn = sqlite3.connect(db_path)
    try:
        for table_name in clean_tables:
            try:
                # Basic row count
                result = pd.read_sql_query(f"SELECT COUNT(*) as row_count FROM [{table_name}]", conn)
                row_count = result.iloc[0]['row_count']

                # Column names
                columns_df = pd.read_sql_query(f"PRAGMA table_info([{table_name}])", conn)
                columns = columns_df['name'].tolist()

                # Sample rows
                sample_data = pd.read_sql_query(f"SELECT * FROM [{table_name}] LIMIT 3", conn)

                print(f"✅ {table_name}:")
                print(f"   Rows: {row_count}")
                print(f"   Columns: {columns}")
                print(f"   Sample data:")
                print(f"   {sample_data.head(2).to_string()}")
                print()

                working_tables[table_name] = {
                    'columns': columns,
                    'row_count': row_count
                }
            except Exception as e:
                print(f"❌ Error with {table_name}: {e}")

        # Only tidy up once every expected table is confirmed working.
        if len(working_tables) == len(clean_tables):
            _drop_legacy_duplicate_tables(conn)
    finally:
        conn.close()

    return working_tables

# Check the three clean tables and report whether the agent can use them
tables_status = verify_clean_tables()

print("="*50)
print("🎯 SUMMARY:")
if len(tables_status) != 3:
    print(f"❌ Only {len(tables_status)} tables working. Check your data upload.")
else:
    print("\n".join([
        "✅ All 3 tables are working perfectly!",
        "✅ Database is ready for the AI Agent!",
        "\n🚀 NEXT STEP: Run Cell 5 (AI Agent initialization)",
    ]))

    # Confirm the metrics the required questions rely on are derivable from these columns
    print("\n".join([
        "\n🧪 CAPABILITY CHECK:",
        "✅ Total Sales: Can sum 'total_sales' column",
        "✅ RoAS: Can calculate 'ad_sales' ÷ 'ad_spend'",
        "✅ Highest CPC: Can find max 'ad_spend' ÷ 'clicks'",
    ]))

In [None]:
# Cell 5: AI Agent Implementation (ROBUST VERSION)
class EcommerceAIAgent:
    """Natural-language to SQL agent over the SQLite database built in Cell 4.

    Uses the module-level Gemini ``model`` (set up in Cell 2) to translate a
    question into SQL, runs it against ``db_path`` and formats the answer.
    When the model call fails, a keyword-based fallback query is used instead.
    """

    def __init__(self, db_path='ecommerce_data.db'):
        self.db_path = db_path
        self.model = model  # Use the model that was successfully initialized
        self.schema_info = self.get_schema_info()

    def get_schema_info(self):
        """Return ``{table: {'columns': [...], 'types': [(col, type), ...]}}``.

        Prints one status line per table. Returns ``{}`` when the database has
        no tables or cannot be read.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()

            # Get all tables
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            if not tables:
                print("❌ No tables found in database!")
                return {}

            schema_info = {}
            for (table_name,) in tables:
                try:
                    # Square brackets quote names containing spaces or punctuation.
                    cursor.execute(f"PRAGMA table_info([{table_name}]);")
                    columns = cursor.fetchall()
                    schema_info[table_name] = {
                        'columns': [col[1] for col in columns],
                        'types': [(col[1], col[2]) for col in columns]
                    }
                    print(f"✅ Loaded schema for table: {table_name}")
                except Exception as e:
                    print(f"❌ Error loading schema for {table_name}: {e}")

            return schema_info

        except Exception as e:
            print(f"❌ Database connection error: {e}")
            return {}
        finally:
            # Close on every path, including the early "no tables" return.
            if conn is not None:
                conn.close()

    @staticmethod
    def _clean_sql(text):
        """Strip surrounding whitespace and markdown code fences (```sql ... ```) from model output."""
        sql_query = text.strip()
        sql_query = re.sub(r'^```(?:sql|sqlite)?\s*', '', sql_query, flags=re.IGNORECASE)
        sql_query = re.sub(r'\s*```$', '', sql_query)
        return sql_query.strip()

    def generate_sql(self, question):
        """Convert a natural-language question into a SQLite query string.

        Returns ``"Error: No database schema available"`` when no schema is
        loaded. Falls back to :meth:`_generate_fallback_sql` when the model
        call fails or returns no SQL.
        """
        if not self.schema_info:
            return "Error: No database schema available"

        schema_text = self._format_schema_for_prompt()

        prompt = f"""
You are a SQL expert. Convert the following question to a SQLite query.

Database Schema:
{schema_text}

Question: {question}

Important Rules:
1. Return ONLY the SQL query, no explanations or markdown
2. Use exact table names and column names from the schema above
3. For RoAS (Return on Ad Spend), use SUM(ad_sales) / SUM(ad_spend), only where ad_spend > 0
4. For total sales, use the total_sales column only (ad_sales is the ad-attributed part of sales; do not add it)
5. For CPC (Cost Per Click), use ad_spend / clicks, only where clicks > 0
6. Always use LIMIT 100 unless asked for more
7. Use proper SQL syntax

SQL Query:
"""

        try:
            response = self.model.generate_content(
                prompt,
                generation_config=genai.GenerationConfig(
                    temperature=0.1,
                    max_output_tokens=300
                )
            )
            sql_query = self._clean_sql(response.text)
        except Exception as e:
            print(f"AI generation error: {e}")
            return self._generate_fallback_sql(question)

        if not sql_query:
            # An empty answer would otherwise be executed as an empty statement.
            return self._generate_fallback_sql(question)
        return sql_query

    def _generate_fallback_sql(self, question):
        """Generate a basic SQL query from keywords when the model is unavailable.

        Handles "total sales", RoAS and CPC questions by searching the schema
        for suitable columns. Otherwise returns the first rows of the first
        table.
        """
        question_lower = question.lower()
        table_names = list(self.schema_info.keys())

        if not table_names:
            return "Error: No tables available"

        all_columns = [(table_name, col)
                       for table_name, info in self.schema_info.items()
                       for col in info['columns']]

        def find(predicate):
            # (table, column) pairs whose lower-cased column name satisfies predicate
            return [(t, c) for t, c in all_columns if predicate(c.lower())]

        def same_table_pair(firsts, seconds):
            # first (table, col_a, col_b) where both columns live in the same table
            for t1, c1 in firsts:
                for t2, c2 in seconds:
                    if t1 == t2:
                        return t1, c1, c2
            return None

        if "total sales" in question_lower:
            # Prefer a dedicated total-sales column: a plain 'sales' match would
            # otherwise pick ad_sales, which is only the ad-attributed subset.
            sales_cols = (find(lambda c: c == 'total_sales')
                          or find(lambda c: 'total' in c and 'sales' in c)
                          or find(lambda c: any(k in c for k in ('sales', 'revenue', 'amount'))))
            if sales_cols:
                table, col = sales_cols[0]
                return f"SELECT SUM([{col}]) as total_sales FROM [{table}];"

        elif "roas" in question_lower or "return on ad spend" in question_lower:
            revenue_cols = find(lambda c: c == 'ad_sales' or 'revenue' in c)
            spend_cols = find(lambda c: 'spend' in c or 'cost' in c)
            pair = same_table_pair(revenue_cols, spend_cols)
            if pair:
                table, rev_col, spend_col = pair
                # * 1.0 avoids SQLite integer division when both columns are INTEGER.
                return (f"SELECT *, ([{rev_col}] * 1.0 / [{spend_col}]) as roas "
                        f"FROM [{table}] WHERE [{spend_col}] > 0 LIMIT 10;")

        elif "cpc" in question_lower or "cost per click" in question_lower:
            cpc_cols = find(lambda c: 'cpc' in c)
            if cpc_cols:
                table, col = cpc_cols[0]
                return f"SELECT * FROM [{table}] ORDER BY [{col}] DESC LIMIT 10;"
            pair = same_table_pair(find(lambda c: 'spend' in c), find(lambda c: c == 'clicks'))
            if pair:
                table, spend_col, clicks_col = pair
                return (f"SELECT *, ([{spend_col}] * 1.0 / [{clicks_col}]) as cpc "
                        f"FROM [{table}] WHERE [{clicks_col}] > 0 ORDER BY cpc DESC LIMIT 10;")

        # Default: show first table
        return f"SELECT * FROM [{table_names[0]}] LIMIT 10;"

    def execute_query(self, sql_query):
        """Run ``sql_query`` and return a DataFrame, or an error-message string on failure."""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            return pd.read_sql_query(sql_query, conn)
        except Exception as e:
            error_msg = f"Error executing query: {str(e)}\nSQL: {sql_query}"
            print(error_msg)
            return error_msg
        finally:
            # Close even when the query fails, so connections are not leaked per question.
            if conn is not None:
                conn.close()

    def format_response(self, question, result, sql_query):
        """Format a query result as text.

        Error strings pass through unchanged. A single cell is shown as
        ``Answer: <value>``. Larger results show a preview of up to 5 rows.
        """
        if isinstance(result, str):  # Error case
            return result

        if len(result) == 0:
            return "No data found for your query."

        if len(result) == 1 and len(result.columns) == 1:
            value = result.iloc[0, 0]
            # pandas returns numpy scalars, and np.int64 is not a Python int.
            if isinstance(value, (int, float, np.integer, np.floating)):
                return f"Answer: {value:,.2f}"
            return f"Answer: {value}"

        shown = min(5, len(result))
        result_preview = result.head(shown).to_string()
        return f"Results for '{question}':\n\n{result_preview}\n\n(Showing first {shown} rows of {len(result)} total rows)"

    def _format_schema_for_prompt(self):
        """Format schema information for the prompt"""
        schema_text = ""
        for table_name, info in self.schema_info.items():
            schema_text += f"\nTable: {table_name}\n"
            schema_text += f"Columns: {', '.join(info['columns'])}\n"
        return schema_text

    def create_visualization(self, result, question):
        """Return a bar chart for small two-column results with a numeric second column, else None."""
        if isinstance(result, str) or len(result) == 0:
            return None

        try:
            if (len(result.columns) == 2 and len(result) <= 20
                    and pd.api.types.is_numeric_dtype(result.iloc[:, 1])):
                return px.bar(result, x=result.columns[0], y=result.columns[1],
                              title=f"Chart: {question}")
        except Exception as e:
            # Charts are optional; a plotting failure must not break the answer.
            print(f"Chart creation skipped: {e}")
        return None

    def process_question(self, question):
        """Answer ``question`` end to end.

        Returns a dict with keys ``question``, ``sql_query``, ``result``,
        ``data`` and ``chart``. ``data`` is ``None`` when SQL generation failed.
        """
        sql_query = self.generate_sql(question)

        if sql_query.startswith("Error"):
            return {
                'question': question,
                'sql_query': sql_query,
                'result': sql_query,
                'data': None,
                'chart': None
            }

        result = self.execute_query(sql_query)
        formatted_response = self.format_response(question, result, sql_query)
        chart = self.create_visualization(result, question)

        return {
            'question': question,
            'sql_query': sql_query,
            'result': formatted_response,
            'data': result,
            'chart': chart
        }

# Initialize agent with comprehensive error handling
print("🚀 Initializing AI Agent...")

# Tables created by Cell 4 that the agents' prompts and fallbacks rely on.
EXPECTED_AGENT_TABLES = {'ad_sales_metrics', 'product_eligibility', 'total_sales_metrics'}

try:
    agent = EcommerceAIAgent()
except Exception as e:
    print(f"❌ Error initializing agent: {e}")
    print("\n🔧 Troubleshooting steps:")
    print("1. Run the verification cell above first")
    print("2. Make sure your database was created successfully")
    print("3. Check that your CSV files were uploaded properly")
else:
    if agent.schema_info:
        print("✅ AI Agent initialized successfully!")
        print(f"📊 Available tables: {list(agent.schema_info.keys())}")

        for table_name, info in agent.schema_info.items():
            print(f"\n📋 Table: {table_name}")
            print(f"   Columns ({len(info['columns'])}): {', '.join(info['columns'][:5])}{'...' if len(info['columns']) > 5 else ''}")

        missing_tables = EXPECTED_AGENT_TABLES - set(agent.schema_info)
        if missing_tables:
            print(f"\n⚠️ Expected tables not found: {sorted(missing_tables)}. Re-run Cells 3-4.")

        print("\n🎯 Ready to answer questions!")

    else:
        print("❌ No schema loaded. Database may be empty or corrupted.")

🚀 Initializing AI Agent...
✅ Loaded schema for table: ad_sales_metrics
✅ Loaded schema for table: product_eligibility
✅ Loaded schema for table: total_sales_metrics
✅ AI Agent initialized successfully!
📊 Available tables: ['ad_sales_metrics', 'product_eligibility', 'total_sales_metrics']

📋 Table: ad_sales_metrics
   Columns (7): date, item_id, ad_sales, impressions, ad_spend...

📋 Table: product_eligibility
   Columns (4): eligibility_datetime_utc, item_id, eligibility, message

📋 Table: total_sales_metrics
   Columns (4): date, item_id, total_sales, total_units_ordered

🎯 Ready to answer questions!


In [None]:
# Cell 6: Enhanced Conversational AI Agent with Chat Memory
import time
from datetime import datetime

class ConversationalEcommerceAgent:
    def __init__(self, db_path='ecommerce_data.db'):
        """Bind the agent to ``db_path`` and to the notebook's Gemini model.

        The schema (with samples and numeric ranges) is loaded immediately.
        Chat history and context memory start empty.
        """
        self.db_path = db_path
        # Reuse the Gemini model chosen in Cell 2 instead of creating a new one.
        self.model = model
        self.schema_info = self.get_schema_info()
        self.conversation_history = []
        self.context_memory = dict(
            last_queries=[],
            user_preferences={},
            frequently_asked={},
        )

    def get_schema_info(self):
        """Get database schema with sample data for better context"""
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()

            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            schema_info = {}
            for table in tables:
                table_name = table[0]
                # Get column info
                cursor.execute(f"PRAGMA table_info([{table_name}]);")
                columns = cursor.fetchall()

                # Get sample data for context
                cursor.execute(f"SELECT * FROM [{table_name}] LIMIT 3;")
                sample_data = cursor.fetchall()

                # Get data ranges for numeric columns
                numeric_columns = [col[1] for col in columns if col[2] in ['REAL', 'INTEGER']]
                ranges = {}
                for col in numeric_columns:
                    try:
                        cursor.execute(f"SELECT MIN([{col}]), MAX([{col}]), AVG([{col}]) FROM [{table_name}] WHERE [{col}] IS NOT NULL;")
                        min_val, max_val, avg_val = cursor.fetchone()
                        ranges[col] = {'min': min_val, 'max': max_val, 'avg': avg_val}
                    except:
                        continue

                schema_info[table_name] = {
                    'columns': [col[1] for col in columns],
                    'types': [(col[1], col[2]) for col in columns],
                    'sample_data': sample_data,
                    'ranges': ranges,
                    'row_count': self._get_row_count(cursor, table_name)
                }

            conn.close()
            return schema_info

        except Exception as e:
            print(f"Error loading schema: {e}")
            return {}

    def _get_row_count(self, cursor, table_name):
        """Get total row count for a table"""
        try:
            cursor.execute(f"SELECT COUNT(*) FROM [{table_name}];")
            return cursor.fetchone()[0]
        except:
            return 0

    def understand_intent(self, question):
        """Enhanced intent understanding with conversation context"""
        question_lower = question.lower()

        # Check for conversation context
        context_prompt = ""
        if self.conversation_history:
            recent_questions = [item['question'] for item in self.conversation_history[-3:]]
            context_prompt = f"\nPrevious questions in this conversation: {recent_questions}"

        # Enhanced schema context
        schema_context = self._create_rich_schema_context()

        intent_prompt = f"""
You are an AI assistant analyzing e-commerce data questions. Based on the conversation context and database schema, understand what the user wants.

Database Context:
{schema_context}
{context_prompt}

Current Question: "{question}"

Analyze the intent and provide:
1. Primary goal (what they want to know)
2. Relevant tables to query
3. Key metrics/columns needed
4. Any filtering or grouping requirements
5. Expected output format

Respond in JSON format:
{{
    "intent": "brief description",
    "tables": ["table1", "table2"],
    "metrics": ["column1", "column2"],
    "filters": "any filtering needed",
    "output_type": "single_value|table|chart|comparison"
}}
"""

        try:
            response = self.model.generate_content(intent_prompt)
            intent_text = response.text.strip()

            # Extract JSON from response
            import json
            if '{' in intent_text and '}' in intent_text:
                json_start = intent_text.find('{')
                json_end = intent_text.rfind('}') + 1
                intent_json = json.loads(intent_text[json_start:json_end])
                return intent_json
        except:
            pass

        # Fallback intent detection
        return self._basic_intent_detection(question_lower)

    def _basic_intent_detection(self, question_lower):
        """Fallback intent detection"""
        if any(word in question_lower for word in ['total', 'sum', 'all']):
            return {
                "intent": "Get total/sum of metrics",
                "tables": ["total_sales_metrics", "ad_sales_metrics"],
                "metrics": ["total_sales", "ad_sales"],
                "filters": "",
                "output_type": "single_value"
            }
        elif any(word in question_lower for word in ['roas', 'return on ad spend']):
            return {
                "intent": "Calculate Return on Ad Spend",
                "tables": ["ad_sales_metrics"],
                "metrics": ["ad_sales", "ad_spend"],
                "filters": "ad_spend > 0",
                "output_type": "table"
            }
        elif any(word in question_lower for word in ['cpc', 'cost per click']):
            return {
                "intent": "Calculate Cost Per Click",
                "tables": ["ad_sales_metrics"],
                "metrics": ["ad_spend", "clicks"],
                "filters": "clicks > 0",
                "output_type": "table"
            }
        else:
            return {
                "intent": "General data exploration",
                "tables": list(self.schema_info.keys()),
                "metrics": [],
                "filters": "",
                "output_type": "table"
            }

    def _create_rich_schema_context(self):
        """Create detailed schema context for better AI understanding"""
        context = "DATABASE SCHEMA AND SAMPLE DATA:\n"

        for table_name, info in self.schema_info.items():
            context += f"\nTable: {table_name} ({info['row_count']} rows)\n"
            context += f"Columns: {', '.join(info['columns'])}\n"

            if info['ranges']:
                context += "Data Ranges:\n"
                for col, range_info in info['ranges'].items():
                    context += f"  {col}: {range_info['min']:.2f} to {range_info['max']:.2f} (avg: {range_info['avg']:.2f})\n"

            if info['sample_data']:
                context += f"Sample Data: {info['sample_data'][0]}\n"

        return context

    def generate_enhanced_sql(self, question, intent):
        """Generate SQL with enhanced context and conversation memory"""
        schema_text = self._create_rich_schema_context()

        # Add conversation context
        conversation_context = ""
        if self.conversation_history:
            conversation_context = "\nPrevious successful queries:\n"
            for item in self.conversation_history[-2:]:
                if not item['sql_query'].startswith('Error'):
                    conversation_context += f"Q: {item['question']}\nSQL: {item['sql_query']}\n\n"

        prompt = f"""
You are an expert SQL analyst for e-commerce data. Generate precise SQL queries.

{schema_text}
{conversation_context}

User Intent Analysis:
- Goal: {intent['intent']}
- Relevant Tables: {intent['tables']}
- Key Metrics: {intent['metrics']}
- Expected Output: {intent['output_type']}

Current Question: "{question}"

IMPORTANT RULES:
1. Use EXACT table and column names from schema above
2. For calculations like RoAS: use (ad_sales/ad_spend) only where ad_spend > 0
3. For CPC: use (ad_spend/clicks) only where clicks > 0
4. Always include meaningful column aliases
5. Use JOINs when combining data from multiple tables
6. Add ORDER BY for meaningful sorting
7. Use LIMIT appropriately (10-100 rows for tables, no limit for single values)

Generate only the SQL query (no explanations):
"""

        try:
            response = self.model.generate_content(
                prompt,
                generation_config=genai.GenerationConfig(
                    temperature=0.1,
                    max_output_tokens=500
                )
            )

            sql_query = response.text.strip()
            sql_query = re.sub(r'^```sql\s*', '', sql_query, flags=re.IGNORECASE)
            sql_query = re.sub(r'^```\s*', '', sql_query)
            sql_query = re.sub(r'\s*```$', '', sql_query)
            sql_query = sql_query.strip()

            return sql_query

        except Exception as e:
            print(f"SQL generation error: {e}")
            return self._generate_smart_fallback(question, intent)

    def _generate_smart_fallback(self, question, intent):
        """Smart fallback SQL generation"""
        question_lower = question.lower()

        if intent['output_type'] == 'single_value':
            if 'total' in question_lower and 'sales' in question_lower:
                return "SELECT SUM(total_sales) as total_sales FROM total_sales_metrics;"

        elif 'roas' in question_lower:
            return """
            SELECT item_id,
                   ad_sales,
                   ad_spend,
                   ROUND(ad_sales/ad_spend, 2) as roas
            FROM ad_sales_metrics
            WHERE ad_spend > 0
            ORDER BY roas DESC
            LIMIT 10;
            """

        elif 'cpc' in question_lower:
            return """
            SELECT item_id,
                   ad_spend,
                   clicks,
                   ROUND(ad_spend/clicks, 2) as cpc
            FROM ad_sales_metrics
            WHERE clicks > 0
            ORDER BY cpc DESC
            LIMIT 10;
            """

        # Default exploration
        return "SELECT * FROM ad_sales_metrics LIMIT 10;"

    def execute_query_with_retry(self, sql_query, max_retries=2):
        """Execute query with retry logic and error handling"""
        for attempt in range(max_retries + 1):
            try:
                conn = sqlite3.connect(self.db_path)
                result = pd.read_sql_query(sql_query, conn)
                conn.close()
                return result

            except Exception as e:
                if attempt < max_retries:
                    # Try to fix common SQL errors
                    fixed_query = self._fix_common_sql_errors(sql_query, str(e))
                    if fixed_query != sql_query:
                        sql_query = fixed_query
                        continue

                return f"Query execution failed: {str(e)}"

        return "Query failed after retries"

    def _fix_common_sql_errors(self, sql_query, error_msg):
        """Fix common SQL errors automatically"""
        if "no such column" in error_msg.lower():
            # Try to fix column name issues
            for table_name, info in self.schema_info.items():
                for col in info['columns']:
                    if col.lower() in sql_query.lower():
                        sql_query = sql_query.replace(col, f"[{col}]")

        if "no such table" in error_msg.lower():
            # Try to fix table name issues
            for table_name in self.schema_info.keys():
                if table_name.lower() in sql_query.lower():
                    sql_query = sql_query.replace(table_name, f"[{table_name}]")

        return sql_query

    def format_conversational_response(self, question, result, sql_query, intent):
        """Render a query result as a conversational, Markdown-formatted reply.

        Args:
            question: The user's question; forwarded to ``_generate_insights``.
            result: DataFrame returned by ``execute_query_with_retry``, or the
                error string it returns on failure.
            sql_query: SQL that produced ``result``. Passed to
                ``_get_data_context`` to describe single-value answers.
            intent: Intent dict. Its ``'output_type'`` key selects single-value
                formatting when it equals ``'single_value'``.

        Returns:
            The reply text.
        """
        import numbers  # local import keeps this method self-contained

        no_data_message = "I didn't find any data matching your question. Could you try asking about something else?"

        if isinstance(result, str):  # execute_query_with_retry reports failures as strings
            return f"I'm sorry, I encountered an issue: {result}\n\nLet me try a different approach. Could you rephrase your question?"

        if len(result) == 0:
            return no_data_message

        # Format based on intent and result type
        if intent['output_type'] == 'single_value' and len(result) == 1 and len(result.columns) == 1:
            value = result.iloc[0, 0]

            # An aggregate over zero matching rows (e.g. SUM) comes back as
            # NULL/NaN; report it as "no data" instead of "**None**"/"**nan**".
            if pd.isna(value):
                return no_data_message

            # numbers.Real also covers numpy scalars such as np.int64, which is
            # NOT a subclass of int. bool is excluded because it subclasses int.
            # NOTE(review): the "$" prefix assumes any value above 1000 is money.
            if isinstance(value, numbers.Real) and not isinstance(value, bool):
                if value > 1000000:
                    formatted_value = f"${value:,.0f}"
                elif value > 1000:
                    formatted_value = f"${value:,.2f}"
                else:
                    formatted_value = f"{value:.2f}"
            else:
                formatted_value = str(value)

            return f"Based on your data, the answer is: **{formatted_value}**\n\nThis calculation was performed across {self._get_data_context(sql_query)} records in your database."

        # Table results with conversational context
        response = "Here's what I found for your question:\n\n"

        if len(result) > 10:
            response += f"**Top Results** (showing 10 of {len(result)} total):\n\n"
            display_result = result.head(10)
        else:
            response += f"**All Results** ({len(result)} items):\n\n"
            display_result = result

        response += display_result.to_string(index=False)
        response += f"\n\n💡 **Insights**: {self._generate_insights(result, question)}"

        return response

    def _get_data_context(self, sql_query):
        """Extract data context from SQL query for better responses"""
        if 'total_sales_metrics' in sql_query:
            return f"{self.schema_info.get('total_sales_metrics', {}).get('row_count', 'many')}"
        elif 'ad_sales_metrics' in sql_query:
            return f"{self.schema_info.get('ad_sales_metrics', {}).get('row_count', 'many')}"
        return "multiple"

    def _generate_insights(self, result, question):
        """Generate quick insights from the results"""
        if len(result) == 0:
            return "No data found."

        insights = []

        # Numeric column insights
        numeric_cols = result.select_dtypes(include=[np.number]).columns
        for col in numeric_cols:
            if len(result[col]) > 1:
                max_val = result[col].max()
                min_val = result[col].min()
                avg_val = result[col].mean()

                if 'roas' in col.lower():
                    insights.append(f"RoAS ranges from {min_val:.2f} to {max_val:.2f}")
                elif 'cpc' in col.lower():
                    insights.append(f"CPC ranges from ${min_val:.2f} to ${max_val:.2f}")
                elif 'sales' in col.lower():
                    insights.append(f"Sales range from ${min_val:,.2f} to ${max_val:,.2f}")

        return ". ".join(insights[:2]) if insights else "Data retrieved successfully."

    def chat(self, question):
        """Answer one user question end to end.

        Pipeline:
        1. Detect the intent.
        2. Generate SQL.
        3. Execute it (with retry).
        4. Format a conversational reply.
        5. Build a chart if one applies.

        The exchange is appended to ``self.conversation_history`` and passed to
        ``_update_context_memory``.

        Args:
            question: Natural-language question from the user.

        Returns:
            dict with keys ``question``, ``sql_query``, ``response`` (reply text),
            ``chart`` (figure or ``None``), ``execution_time`` (seconds) and
            ``intent``.
        """
        import time  # local import: the method must not rely on a module-level `time`

        # perf_counter is monotonic, so the measured duration cannot go negative
        # or jump if the system clock is adjusted mid-request (time.time can).
        started = time.perf_counter()

        intent = self.understand_intent(question)
        sql_query = self.generate_enhanced_sql(question, intent)
        result = self.execute_query_with_retry(sql_query)
        formatted_response = self.format_conversational_response(question, result, sql_query, intent)
        chart = self._create_smart_visualization(result, question, intent)

        elapsed = time.perf_counter() - started

        conversation_item = {
            'timestamp': datetime.now(),
            'question': question,
            'intent': intent,
            'sql_query': sql_query,
            'result': formatted_response,
            'execution_time': elapsed,
            # Error strings are not stored as data; chart rebuilds treat None as "no data".
            'data': None if isinstance(result, str) else result
        }
        self.conversation_history.append(conversation_item)

        self._update_context_memory(question, intent)

        return {
            'question': question,
            'sql_query': sql_query,
            'response': formatted_response,
            'chart': chart,
            'execution_time': elapsed,
            'intent': intent
        }

    def _create_smart_visualization(self, result, question, intent):
        """Create intelligent visualizations based on data and context"""
        if isinstance(result, str) or len(result) == 0:
            return None

        try:
            # Don't create charts for single values
            if len(result) == 1 and len(result.columns) == 1:
                return None

            # Chart for RoAS data
            if intent['output_type'] == 'table' and any('roas' in col.lower() for col in result.columns):
                if len(result) <= 20:
                    roas_col = [col for col in result.columns if 'roas' in col.lower()][0]
                    item_col = [col for col in result.columns if 'item' in col.lower()]
                    if item_col:
                        fig = px.bar(result.head(15),
                                   x=item_col[0],
                                   y=roas_col,
                                   title="Return on Ad Spend (RoAS) by Product",
                                   labels={roas_col: "RoAS", item_col[0]: "Product ID"})
                        return fig

            # Chart for CPC data
            elif intent['output_type'] == 'table' and any('cpc' in col.lower() for col in result.columns):
                if len(result) <= 20:
                    cpc_col = [col for col in result.columns if 'cpc' in col.lower()][0]
                    item_col = [col for col in result.columns if 'item' in col.lower()]
                    if item_col:
                        fig = px.bar(result.head(15),
                                   x=item_col[0],
                                   y=cpc_col,
                                   title="Cost Per Click (CPC) by Product",
                                   labels={cpc_col: "CPC ($)", item_col[0]: "Product ID"})
                        return fig

            # Time series charts
            elif any('date' in col.lower() for col in result.columns):
                date_col = [col for col in result.columns if 'date' in col.lower()][0]
                numeric_cols = result.select_dtypes(include=[np.number]).columns
                if len(numeric_cols) > 0:
                    value_col = numeric_cols[0]
                    fig = px.line(result,
                                x=date_col,
                                y=value_col,
                                title=f"{value_col.title()} Over Time")
                    return fig

        except Exception as e:
            print(f"Visualization error: {e}")

        return None

    def _update_context_memory(self, question, intent):
        """Update conversation context for better future responses"""
        # Track frequently asked questions
        q_key = question.lower()
        self.context_memory['frequently_asked'][q_key] = self.context_memory['frequently_asked'].get(q_key, 0) + 1

        # Track user preferences
        if intent['output_type'] == 'chart':
            self.context_memory['user_preferences']['likes_charts'] = True

        # Keep last 10 queries for context
        self.context_memory['last_queries'].append(question)
        if len(self.context_memory['last_queries']) > 10:
            self.context_memory['last_queries'].pop(0)

    def get_conversation_summary(self, last_n=5):
        """Summarise the most recent questions in the conversation.

        Entries keep their absolute position in the history (e.g. "6." and
        "7." when seven questions were asked and ``last_n=2``), so the
        numbering agrees with the total shown in the header.

        Args:
            last_n: How many of the most recent questions to list (default 5).

        Returns:
            A Markdown summary, or "No conversation history yet." when empty.
        """
        history = self.conversation_history
        if not history:
            return "No conversation history yet."

        summary = f"**Conversation Summary** ({len(history)} questions asked)\n\n"

        first_shown = max(len(history) - max(last_n, 0), 0)
        for number, item in enumerate(history[first_shown:], first_shown + 1):
            summary += f"{number}. Q: {item['question']}\n"
            # .get: a malformed intent must not break the summary panel.
            summary += f"   Intent: {item['intent'].get('intent', 'unknown')}\n"
            summary += f"   Time: {item['execution_time']:.2f}s\n\n"

        return summary

# Initialize the enhanced conversational agent
print("🚀 Initializing Conversational AI Agent...")

try:
    conv_agent = ConversationalEcommerceAgent()
except Exception as e:
    print(f"❌ Error initializing conversational agent: {e}")
    # Fail loudly: the Gradio cell below uses `conv_agent` directly, so
    # swallowing this error would only resurface later as a confusing NameError.
    raise

if conv_agent.schema_info:
    print("✅ Conversational AI Agent ready!")
    print(f"📊 Database loaded with {len(conv_agent.schema_info)} tables")

    # Show capabilities
    print("\n🎯 ENHANCED CAPABILITIES:")
    print("✅ Unlimited questions and follow-ups")
    print("✅ Conversation memory and context")
    print("✅ Smart intent understanding")
    print("✅ Automatic error recovery")
    print("✅ Intelligent visualizations")
    print("✅ Conversational responses")

    print("\n💬 Ready for real-time conversation!")

else:
    print("❌ Failed to load database schema")

🚀 Initializing Conversational AI Agent...
✅ Conversational AI Agent ready!
📊 Database loaded with 3 tables

🎯 ENHANCED CAPABILITIES:
✅ Unlimited questions and follow-ups
✅ Conversation memory and context
✅ Smart intent understanding
✅ Automatic error recovery
✅ Intelligent visualizations
✅ Conversational responses

💬 Ready for real-time conversation!


In [None]:
# Enhanced Gradio Interface for Real-Time Conversation
import gradio as gr
import numpy as np

def chat_with_agent(message, history):
    """Send one user message to ``conv_agent`` and append the exchange to the chat.

    Args:
        message: Text typed by the user. ``None`` or whitespace-only is ignored.
        history: Current chatbot value, a list of ``[user, bot]`` pairs.
            ``None`` is treated as empty.

    Returns:
        ``(new_history, "")``: the updated pair list, and an empty string that
        clears the input textbox. The caller's ``history`` list is not mutated.
    """
    # Work on a copy so the list passed in is never modified in place.
    history = list(history or [])

    if not message or not message.strip():
        return history, ""

    try:
        # Get response from conversational agent
        response = conv_agent.chat(message)

        # Format the response for chat interface
        bot_response = f"**🤖 AI Agent Response:**\n\n{response['response']}"

        # Add SQL query info (collapsible), but only when SQL was actually generated.
        sql_query = str(response.get('sql_query') or "")
        if sql_query and not sql_query.startswith('Error'):
            bot_response += f"\n\n<details><summary>🔍 <i>View SQL Query</i></summary>\n\n```sql\n{sql_query}\n```\n</details>"

        # Add execution time
        bot_response += f"\n\n⚡ *Processed in {response['execution_time']:.2f} seconds*"

        history.append([message, bot_response])

    except Exception as e:
        error_response = f"I apologize, but I encountered an error: {str(e)}\n\nPlease try rephrasing your question or ask something else!"
        history.append([message, error_response])

    return history, ""

def show_chart_for_last_query():
    """Rebuild the chart for the most recent question, if there is one.

    Returns:
        A Plotly figure, or ``None`` in three cases: there is no history, the
        last query failed (its stored ``'data'`` is ``None``), or no chart
        applies.
    """
    if not conv_agent.conversation_history:
        return None

    last_item = conv_agent.conversation_history[-1]
    data = last_item.get('data')
    # chat() stores None for failed queries. Passing that on would make the
    # visualizer call len(None) and raise inside the Gradio event handler.
    if data is None:
        return None

    return conv_agent._create_smart_visualization(
        data,
        last_item['question'],
        last_item['intent']
    )

def get_conversation_summary():
    """Return the agent's Markdown summary of recent questions for the summary panel."""
    agent_summary = conv_agent.get_conversation_summary()
    return agent_summary

def clear_conversation():
    """Drop the agent's stored history and blank the chatbot and summary widgets.

    Returns:
        ``([], "")``: new values for the chatbot and summary outputs.
    """
    conv_agent.conversation_history = list()
    cleared_chat, cleared_summary = [], ""
    return cleared_chat, cleared_summary

# Create the enhanced Gradio interface
# Layout is declarative: components render in the order they are created inside
# the context managers below, so statement order here defines the page layout.
with gr.Blocks(
    title="E-commerce AI Agent - Real-Time Conversation",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .chat-message {
        border-radius: 10px !important;
    }
    """
) as demo:

    # Page heading.
    gr.Markdown("""
    # 🛒 E-commerce AI Agent - Real-Time Conversation

    """)

    with gr.Row():
        # Left (wider) column: chat, input row, and example questions.
        with gr.Column(scale=2):
            # Main chat interface; holds the [user, bot] pairs built by chat_with_agent.
            chatbot = gr.Chatbot(
                height=400,
                show_label=False,
                container=True,
                # NOTE(review): bubble_full_width is deprecated in newer Gradio
                # releases; confirm it is still accepted by the installed version.
                bubble_full_width=False
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Ask me about your e-commerce data... (e.g., 'What are my top performing products?')",
                    show_label=False,
                    scale=4,
                    container=False
                )
                send_btn = gr.Button("Send 📤", scale=1, variant="primary")
                clear_btn = gr.Button("Clear 🗑️", scale=1, variant="secondary")

            # Example questions. No fn is passed, so presumably clicking one only
            # fills `msg` and the user still has to press Send / Enter.
            with gr.Row():
                gr.Examples(
                    examples=[
                        ["What is my total sales revenue?"],
                        ["Calculate the RoAS for all products"],
                        ["Which product has the highest CPC?"],
                        ["Show me products with RoAS above 5"],
                        ["What's the average conversion rate?"],
                        ["Which products are not eligible for advertising?"],
                        ["Compare ad spend vs organic sales"],
                        ["Show me the worst performing products"],
                        ["What's my most profitable product?"],
                        ["Analyze trends in my sales data"]
                    ],
                    inputs=msg,
                    label="💡 Try these example questions:"
                )

        # Right (narrower) column: chart and conversation summary panels.
        with gr.Column(scale=1):
            # Visualization panel
            gr.Markdown("### 📊 Visualizations")
            chart_display = gr.Plot(label="Chart")
            chart_btn = gr.Button("Show Chart for Last Query 📈", variant="secondary")

            # Conversation summary (read-only; refreshed after every message).
            gr.Markdown("### 📝 Conversation Summary")
            summary_display = gr.Textbox(
                label="Summary",
                lines=6,
                interactive=False
            )
            summary_btn = gr.Button("Update Summary 📋", variant="secondary")

    # Event handlers
    def respond_and_update(message, history):
        """Run one chat turn, then refresh the chat, textbox, chart and summary.

        Returns a 4-tuple in the same order as the ``outputs`` lists wired
        below: (chat history, cleared textbox value, chart, summary text).
        """
        new_history, cleared_msg = chat_with_agent(message, history)
        chart = show_chart_for_last_query()
        summary = get_conversation_summary()
        return new_history, cleared_msg, chart, summary

    # Button events: clicking Send and pressing Enter in the textbox share one handler.
    send_btn.click(
        respond_and_update,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg, chart_display, summary_display]
    )

    msg.submit(
        respond_and_update,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg, chart_display, summary_display]
    )

    # NOTE(review): Clear resets only the chatbot and summary; chart_display
    # keeps showing the previous figure.
    clear_btn.click(
        clear_conversation,
        outputs=[chatbot, summary_display]
    )

    chart_btn.click(
        show_chart_for_last_query,
        outputs=chart_display
    )

    summary_btn.click(
        get_conversation_summary,
        outputs=summary_display
    )

    # Whitespace-only Markdown block: renders nothing (placeholder footer).
    gr.Markdown("""

    """)

# Launch the enhanced interface
print("🚀 Launching Real-Time Conversational Interface...")
# debug=True makes launch() block this cell until the server is stopped, so any
# status printed after the call would only appear once the app had shut down.
print("💬 Once the interface URL appears below, you can have unlimited conversations with your AI agent!")
print("⏹️ Interrupt/stop this cell to shut the server down.")

demo.launch(
    # NOTE(review): share=True publishes a public share link with no
    # authentication. Anyone holding the link can query this data (and
    # presumably spend the Gemini quota). Consider auth=... or share=False
    # outside of demos.
    share=True,
    debug=True,
    server_name="0.0.0.0",  # listen on all network interfaces, not only localhost
    server_port=7860,
    show_error=True
)

In [None]:
# Cell 7: Test required questions
required_questions = [
    "What is my total sales?",
    "Calculate the RoAS (Return on Ad Spend)",
    "Which product had the highest CPC (Cost Per Click)?"
]

print("🧪 Testing Required Questions:\n" + "="*50)

for i, question in enumerate(required_questions, 1):
    print(f"\n{i}. Question: {question}")
    print("-" * 40)

    result = agent.process_question(question)

    print(f"SQL Query: {result['sql_query']}")
    # process_question's 'result' text already starts with "Answer:" (the
    # recorded output showed "Answer: Answer: ..."), so only add the label
    # when it is missing.
    answer = str(result['result'])
    print(answer if answer.startswith("Answer:") else f"Answer: {answer}")

    if result.get('chart'):
        result['chart'].show()

    print("\n" + "="*50)

🧪 Testing Required Questions:

1. Question: What is my total sales?
----------------------------------------
SQL Query: SELECT SUM(ad_sales) + SUM(total_sales) AS total_sales FROM ad_sales_metrics AS asm JOIN total_sales_metrics AS tsm ON asm.item_id = tsm.item_id LIMIT 100;
Answer: Answer: 18,974,321.79


2. Question: Calculate the RoAS (Return on Ad Spend)
----------------------------------------
SQL Query: SELECT
  (SUM(ad_sales) + SUM(total_sales)) / SUM(ad_spend) AS RoAS
FROM ad_sales_metrics
JOIN total_sales_metrics
  ON ad_sales_metrics.item_id = total_sales_metrics.item_id AND ad_sales_metrics.date = total_sales_metrics.date
LIMIT 100;
Answer: Answer: 27.15


3. Question: Which product had the highest CPC (Cost Per Click)?
----------------------------------------
SQL Query: SELECT item_id
FROM ad_sales_metrics
WHERE clicks > 0
ORDER BY ad_spend/clicks DESC
LIMIT 1;
Answer: Answer: 22



In [None]:
# Cell 8: Export functionality
def export_project():
    """Write a JSON test report and a standalone script template, then download them.

    Steps:
    1. Re-run every question in ``required_questions`` through
       ``agent.process_question``.
    2. Save the answers and ``agent.schema_info`` to ``project_report.json``.
    3. Write a template ``ecommerce_ai_agent.py``.
    4. Trigger Colab downloads for both files plus ``ecommerce_data.db``.
    """

    # Create a summary report
    report = {
        'project_name': 'E-commerce AI Agent',
        'created_at': datetime.now().isoformat(),
        'database_schema': agent.schema_info,
        'test_results': []
    }

    # Test all required questions
    for question in required_questions:
        result = agent.process_question(question)
        report['test_results'].append({
            'question': question,
            'sql_query': result['sql_query'],
            'answer': result['result']
        })

    # Save report. default=str stops the export from failing on values json
    # cannot encode natively (e.g. numpy scalars or timestamps).
    with open('project_report.json', 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, default=str)

    # Create Python script version. The class needs a real statement in its
    # body; a comment alone would make the generated file a SyntaxError.
    script_content = f'''
# E-commerce AI Agent - Standalone Version
# Generated from Google Colab

import google.generativeai as genai
import pandas as pd
import sqlite3
import gradio as gr

# Configure your API key
genai.configure(api_key="YOUR_API_KEY_HERE")

# [Insert the complete EcommerceAIAgent class here]
class EcommerceAIAgent:
    # ... (copy the class from Cell 5)
    pass  # placeholder body so this file parses until the class is pasted in

# Usage example:
if __name__ == "__main__":
    agent = EcommerceAIAgent()

    # Test questions
    questions = {required_questions}

    for question in questions:
        result = agent.process_question(question)
        print(f"Q: {{result['question']}}")
        print(f"A: {{result['result']}}\\n")
'''

    with open('ecommerce_ai_agent.py', 'w', encoding='utf-8') as f:
        f.write(script_content)

    # Download files
    files.download('project_report.json')
    files.download('ecommerce_ai_agent.py')
    files.download('ecommerce_data.db')

    print("✅ Project exported successfully!")
    print("📁 Files downloaded: project_report.json, ecommerce_ai_agent.py, ecommerce_data.db")

# Call export function
export_project()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

✅ Project exported successfully!
📁 Files downloaded: project_report.json, ecommerce_ai_agent.py, ecommerce_data.db
