# AI Data Analysis Agent

* An intelligent agent that analyzes datasets using SQL queries and provides insights through natural language processing.
* The agent can handle CSV/Excel files, perform statistical analysis, create visualizations, and answer questions about data in plain English.
* Features include data preprocessing, automatic chart generation, correlation analysis, and comprehensive data insights.
* Users can upload their data files and ask questions like "What are the top 10 sales by region?" or "Show me a trend analysis of monthly revenue."
* The agent provides both automated analysis and custom visualization capabilities.

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Dhivya-Bharathy/PraisonAI/blob/main/examples/cookbooks/ai_data_analysis_agent.ipynb)


# Dependencies

In [None]:
!pip install praisonai streamlit openai duckdb pandas numpy plotly matplotlib seaborn

# Setup Key

In [6]:
import os

# Keep a key that is already exported in the environment; only fall back to
# the placeholder (replace it with a real key) when none is set. Assigning
# the placeholder unconditionally would overwrite a working key.
openai_key = os.environ.get("OPENAI_API_KEY", "sk-..")

os.environ["OPENAI_API_KEY"] = openai_key
model_choice = "gpt-5-nano"

print("✅ API key configured!")
print(f"✅ Using model: {model_choice}")

✅ API key configured!
✅ Using model: gpt-5-nano


# Tools

In [7]:
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from typing import Any

class DataVisualizationTool:
    """Build Plotly Express figures from a pandas DataFrame."""

    def __init__(self):
        # Chart type names accepted by create_visualization().
        self.supported_charts = ['bar', 'line', 'scatter', 'histogram', 'box', 'pie', 'heatmap', 'area']

    def create_visualization(self, df: pd.DataFrame, chart_type: str, x_column: str, y_column: str = None, title: str = None) -> Any:
        """Create a chart of the requested type.

        Args:
            df: Data to plot.
            chart_type: One of the names in ``self.supported_charts``.
            x_column: Column for the x axis (slice names for ``'pie'``;
                unused for ``'heatmap'``).
            y_column: Column for the y axis (slice values for ``'pie'``;
                unused for ``'histogram'`` and ``'heatmap'``).
            title: Optional chart title.

        Returns:
            The Plotly figure on success. On failure this method does not
            raise; it returns a string: ``"Unsupported chart type"`` for an
            unknown ``chart_type``, otherwise a message starting with
            ``"Error creating visualization: "``. Callers must check for a
            string before calling ``.show()``.
        """
        try:
            if chart_type == 'bar':
                fig = px.bar(df, x=x_column, y=y_column, title=title, color_discrete_sequence=['#1f77b4'])
            elif chart_type == 'line':
                fig = px.line(df, x=x_column, y=y_column, title=title, color_discrete_sequence=['#2ca02c'])
            elif chart_type == 'scatter':
                fig = px.scatter(df, x=x_column, y=y_column, title=title, color_discrete_sequence=['#ff7f0e'])
            elif chart_type == 'histogram':
                fig = px.histogram(df, x=x_column, title=title, color_discrete_sequence=['#d62728'])
            elif chart_type == 'box':
                fig = px.box(df, x=x_column, y=y_column, title=title, color_discrete_sequence=['#9467bd'])
            elif chart_type == 'pie':
                fig = px.pie(df, values=y_column, names=x_column, title=title)
            elif chart_type == 'heatmap':
                # Correlate numeric columns only: DataFrame.corr() on a frame
                # that still holds text columns raises on pandas >= 2.0.
                numeric_df = df.select_dtypes(include=[np.number])
                if numeric_df.shape[1] < 2:
                    return "Error creating visualization: heatmap needs at least two numeric columns"
                corr_matrix = numeric_df.corr()
                fig = px.imshow(corr_matrix, title=title, color_continuous_scale='RdBu')
            elif chart_type == 'area':
                fig = px.area(df, x=x_column, y=y_column, title=title, color_discrete_sequence=['#8c564b'])
            else:
                return "Unsupported chart type"

            # Shared look and feel for every chart type.
            fig.update_layout(
                template="plotly_white",
                font=dict(size=12),
                margin=dict(l=50, r=50, t=50, b=50)
            )
            return fig
        except Exception as e:
            return f"Error creating visualization: {str(e)}"

# Custom Data Preprocessing Tool
import tempfile
import csv

class DataPreprocessingTool:
    """Load an uploaded CSV/Excel file, normalise column types, and save a clean CSV copy."""

    def __init__(self):
        # File extensions preprocess_file() knows how to read.
        self.supported_formats = ['.csv', '.xlsx']

    def preprocess_file(self, file) -> tuple:
        """Read ``file`` into a DataFrame and write a cleaned CSV copy.

        Args:
            file: File-like object with a ``name`` attribute; the extension
                of ``name`` (matched case-insensitively) selects the reader.

        Returns:
            ``(temp_path, columns, df, error)``. On success ``error`` is
            ``None``, ``temp_path`` is the path of a fully quoted CSV copy,
            ``columns`` is the list of column labels and ``df`` the parsed
            frame. On failure the first three items are ``None`` and
            ``error`` is a message; this method does not raise.

        Note:
            The temporary file is created with ``delete=False`` and is not
            removed here, so the caller owns its cleanup.
        """
        try:
            file_name = str(file.name).lower()
            if file_name.endswith('.csv'):
                df = pd.read_csv(file, encoding='utf-8', na_values=['NA', 'N/A', 'missing'])
            elif file_name.endswith('.xlsx'):
                df = pd.read_excel(file, na_values=['NA', 'N/A', 'missing'])
            else:
                return None, None, None, "Unsupported file format"

            # Text columns are left untouched. Casting them to str would turn
            # missing values into the literal 'nan', and doubling quotes by
            # hand is unnecessary because to_csv() escapes them itself.

            # Parse dates and numeric columns
            for col in df.columns:
                # Column labels are not guaranteed to be strings.
                if 'date' in str(col).lower():
                    df[col] = pd.to_datetime(df[col], errors='coerce')
                elif df[col].dtype == 'object':
                    try:
                        df[col] = pd.to_numeric(df[col])
                    except (ValueError, TypeError):
                        # Not a numeric column; keep it as text.
                        pass

            # Reserve a temp path, then write after the handle is closed so
            # the file is never open twice (which fails on some platforms).
            with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_file:
                temp_path = temp_file.name
            df.to_csv(temp_path, index=False, quoting=csv.QUOTE_ALL)

            return temp_path, df.columns.tolist(), df, None
        except Exception as e:
            return None, None, None, f"Error processing file: {e}"

# Custom Statistical Analysis Tool
from typing import Dict, List

class StatisticalAnalysisTool:
    """Compute descriptive, correlation, outlier and trend summaries for a DataFrame."""

    def __init__(self):
        # Advertised analysis names. 'patterns' has no implementation in
        # analyze_data() and currently yields an empty result.
        self.analysis_types = ['descriptive', 'correlation', 'outliers', 'trends', 'patterns']

    def analyze_data(self, df: pd.DataFrame, analysis_type: str) -> Dict[str, Any]:
        """Perform one statistical analysis on the dataset.

        Args:
            df: Data to analyse.
            analysis_type: ``'descriptive'``, ``'correlation'``,
                ``'outliers'`` or ``'trends'``.

        Returns:
            A dict whose keys depend on ``analysis_type``:

            * ``'descriptive'``: ``summary`` (``df.describe()``) and ``info``
              (row, column, missing-cell and duplicate-row counts).
            * ``'correlation'``: ``correlation_matrix`` and
              ``high_correlations``; both absent when there are fewer than
              two numeric columns.
            * ``'outliers'``: ``outliers`` mapping column -> index labels.
            * ``'trends'``: ``time_series`` for the first datetime column;
              absent when no datetime column exists.

            Any other ``analysis_type`` gives an empty dict. If an exception
            occurs the dict is ``{'error': message}`` instead of raising.
        """
        try:
            results = {}

            if analysis_type == 'descriptive':
                results['summary'] = df.describe()
                results['info'] = {
                    'rows': len(df),
                    'columns': len(df.columns),
                    # Plain ints rather than NumPy scalars.
                    'missing_values': int(df.isnull().sum().sum()),
                    'duplicates': int(df.duplicated().sum()),
                }

            elif analysis_type == 'correlation':
                numeric_df = df.select_dtypes(include=[np.number])
                if len(numeric_df.columns) > 1:
                    results['correlation_matrix'] = numeric_df.corr()
                    results['high_correlations'] = self._find_high_correlations(results['correlation_matrix'])

            elif analysis_type == 'outliers':
                numeric_df = df.select_dtypes(include=[np.number])
                results['outliers'] = self._detect_outliers(numeric_df)

            elif analysis_type == 'trends':
                # Include timezone-aware datetime columns as well.
                date_cols = df.select_dtypes(include=['datetime64', 'datetimetz']).columns
                if len(date_cols) > 0:
                    results['time_series'] = self._analyze_trends(df, date_cols[0])

            return results
        except Exception as e:
            return {'error': f"Analysis error: {str(e)}"}

    def _find_high_correlations(self, corr_matrix: pd.DataFrame, threshold: float = 0.7) -> List[tuple]:
        """Return ``(col_a, col_b, corr)`` for each pair with ``|corr| > threshold``.

        Only the upper triangle is scanned, so each pair appears once and the
        diagonal is skipped. NaN correlations never exceed the threshold.
        """
        columns = list(corr_matrix.columns)
        high_corr = []
        for i, first in enumerate(columns):
            for j in range(i + 1, len(columns)):
                value = corr_matrix.iloc[i, j]
                if abs(value) > threshold:
                    high_corr.append((first, columns[j], value))
        return high_corr

    def _detect_outliers(self, df: pd.DataFrame) -> Dict[str, List[int]]:
        """Detect outliers per column with the 1.5 * IQR rule.

        Returns a dict mapping each column that has outliers to the index
        labels of the offending rows; columns without outliers are omitted.
        """
        outliers = {}
        for col in df.columns:
            q1 = df[col].quantile(0.25)
            q3 = df[col].quantile(0.75)
            iqr = q3 - q1
            below = df[col] < q1 - 1.5 * iqr
            above = df[col] > q3 + 1.5 * iqr
            outlier_indices = df[below | above].index.tolist()
            if outlier_indices:
                outliers[col] = outlier_indices
        return outliers

    def _analyze_trends(self, df: pd.DataFrame, date_col: str) -> Dict[str, Any]:
        """Fit a straight line to every numeric column ordered by ``date_col``.

        The fit is against observation order (0, 1, 2, ...), not elapsed
        time, so unevenly spaced dates are treated as evenly spaced.

        Returns a dict mapping column -> ``trend_direction``
        (``'increasing'``, ``'decreasing'`` or ``'stable'``),
        ``trend_strength`` (absolute slope per observation), ``mean`` and
        ``std``. Columns with fewer than two non-null values are skipped.
        """
        df_sorted = df.sort_values(date_col)
        numeric_cols = df.select_dtypes(include=[np.number]).columns

        trends = {}
        for col in numeric_cols:
            values = df_sorted[col].dropna()
            if len(values) < 2:
                continue

            slope = np.polyfit(np.arange(len(values)), values.to_numpy(dtype=float), 1)[0]
            # A least-squares fit of flat data yields a slope that is only
            # approximately zero, so compare against a scale-aware tolerance
            # instead of reporting it as 'decreasing'.
            tolerance = 1e-9 * max(1.0, float(np.abs(values).max()))
            if abs(slope) <= tolerance:
                direction = 'stable'
            elif slope > 0:
                direction = 'increasing'
            else:
                direction = 'decreasing'

            trends[col] = {
                'trend_direction': direction,
                'trend_strength': abs(slope),
                'mean': values.mean(),
                'std': values.std(),
            }
        return trends

# YAML Prompt

In [9]:
# YAML Prompt
# Agent configuration kept as a YAML-formatted string: name, description,
# instructions, tool descriptions, output format and model settings.
# NOTE(review): the visible cells only define this string; nothing here parses
# it or passes it to an agent. Its `model` value repeats `model_choice` from
# the setup cell, so keep the two in sync when changing models.
yaml_prompt = """
name: "AI Data Analysis Agent"
description: "Expert data analyst with SQL and visualization capabilities"
instructions:
  - "You are an expert data analyst with deep knowledge of statistics, SQL, and data visualization"
  - "Analyze user queries and provide comprehensive insights about their data"
  - "Generate appropriate SQL queries when needed for data analysis"
  - "Suggest relevant visualizations based on data types and analysis goals"
  - "Provide actionable insights and recommendations based on data patterns"
  - "Always explain your analysis process and findings clearly"
  - "Use markdown formatting for better readability"
  - "Include statistical significance when applicable"
  - "Highlight any data quality issues or anomalies discovered"
  - "Focus on practical business insights and actionable recommendations"

tools:
  - name: "DataVisualizationTool"
    description: "Creates various types of data visualizations including bar, line, scatter, histogram, box, pie, heatmap, and area charts"
  - name: "DataPreprocessingTool"
    description: "Handles file upload, data cleaning, type conversion, and preprocessing for analysis"
  - name: "StatisticalAnalysisTool"
    description: "Performs descriptive statistics, correlation analysis, outlier detection, and trend analysis"

output_format:
  - "Provide clear, structured analysis results"
  - "Include relevant visualizations when appropriate"
  - "Summarize key findings and insights"
  - "Suggest follow-up analyses if relevant"
  - "Highlight any data quality concerns"
  - "Use bullet points and tables for better organization"

temperature: 0.3
max_tokens: 4000
model: "gpt-5-nano"
"""

# Confirmation message so the notebook cell shows that it ran.
print("✅ YAML Prompt configured!")

✅ YAML Prompt configured!


# Main

In [11]:
# Main Application (Google Colab Version)
import pandas as pd
import plotly.express as px
import numpy as np
import tempfile
import csv
import json
from typing import Dict, Any, List
from google.colab import files
import io

# Initialize tools
viz_tool = DataVisualizationTool()
preprocess_tool = DataPreprocessingTool()
stats_tool = StatisticalAnalysisTool()


def _show_figure(fig) -> None:
    """Display a figure, or print the error text returned in its place.

    ``DataVisualizationTool.create_visualization`` returns a plain string
    instead of a figure when it fails, and a string has no ``.show()``.
    """
    if isinstance(fig, str):
        print(f"❌ {fig}")
    else:
        fig.show()


print("📊 AI Data Analysis Agent")
print("Intelligent data analysis powered by AI - Upload your data and ask questions in natural language!")

# File upload section for Google Colab
print("\n📁 Upload Your Data")
print("Please upload a CSV or Excel file:")

uploaded = files.upload()

if uploaded:
    # Only the first uploaded file is analysed.
    file_name = next(iter(uploaded))
    file_content = uploaded[file_name]

    # Wrap the raw bytes in a file-like object; the preprocessing tool picks
    # its reader from the object's `name` attribute.
    file_obj = io.BytesIO(file_content)
    file_obj.name = file_name

    # Preprocess and save the uploaded file
    temp_path, columns, df, error = preprocess_tool.preprocess_file(file_obj)

    if error:
        print(f"❌ Error: {error}")
    elif temp_path and columns and df is not None:
        # Display dataset information
        print(f"\n📊 Dataset Information:")
        print(f"- Rows: {len(df)}")
        print(f"- Columns: {len(df.columns)}")
        print(f"- Data Types: {len(df.dtypes.unique())}")

        # Data preview
        print(f"\n📋 Data Preview:")
        print(df.head(10))

        # Column information
        print(f"\n🧾 Column Information:")
        col_info = pd.DataFrame({
            'Column': df.columns,
            'Data Type': df.dtypes.astype(str),
            'Non-Null Count': df.count(),
            'Null Count': df.isnull().sum(),
            'Unique Values': df.nunique()
        })
        print(col_info)

        # Analysis section
        print("\n🔍 Data Analysis")
        print("Available analysis options:")
        print("1. Quick Insights")
        print("2. Auto Visualizations")
        print("3. Custom Analysis")

        # Quick insights
        print("\n💡 Quick Insights:")
        try:
            # Run every analysis used below and collect the results by name.
            all_stats = {
                analysis_type: stats_tool.analyze_data(df, analysis_type)
                for analysis_type in ['descriptive', 'correlation', 'outliers']
            }

            # Display insights
            if 'descriptive' in all_stats and 'info' in all_stats['descriptive']:
                info = all_stats['descriptive']['info']
                print(f"📊 Dataset Overview:")
                print(f"- Total records: {info['rows']}")
                # `missing_values` counts missing cells, not rows, so
                # subtracting it from the row count is wrong (and can go
                # negative). Count the rows without any missing value.
                print(f"- Complete records: {len(df.dropna())}")
                print(f"- Duplicate records: {info['duplicates']}")
                print(f"- Missing values: {info['missing_values']}")

            # Outlier detection. analyze_data() nests the per-column mapping
            # under an 'outliers' key; fall back to the outer dict if absent.
            if 'outliers' in all_stats and all_stats['outliers']:
                print(f"⚠️ Potential Outliers Detected:")
                for col, indices in all_stats['outliers'].items():
                    print(f"- {col}: {len(indices)} outliers")

            # High correlations
            if 'correlation' in all_stats and 'high_correlations' in all_stats['correlation']:
                high_corr = all_stats['correlation']['high_correlations']
                if high_corr:
                    print(f"🔗 Strong Correlations:")
                    for var1, var2, corr in high_corr:
                        print(f"- {var1} ↔ {var2}: {corr:.3f}")

        except Exception as e:
            print(f"❌ Error generating insights: {str(e)}")

        # Auto visualization
        print("\n📈 Auto-Generated Visualizations:")
        try:
            # Create multiple visualizations based on data types
            numeric_cols = df.select_dtypes(include=[np.number]).columns
            categorical_cols = df.select_dtypes(include=['object']).columns

            if len(numeric_cols) > 0:
                # Histogram for first numeric column
                fig1 = viz_tool.create_visualization(
                    df, 'histogram', numeric_cols[0],
                    title=f"Distribution of {numeric_cols[0]}"
                )
                _show_figure(fig1)

            if len(numeric_cols) > 1:
                # Scatter plot for first two numeric columns
                fig2 = viz_tool.create_visualization(
                    df, 'scatter', numeric_cols[0], numeric_cols[1],
                    title=f"{numeric_cols[0]} vs {numeric_cols[1]}"
                )
                _show_figure(fig2)

            if len(categorical_cols) > 0 and len(numeric_cols) > 0:
                # Bar chart for categorical vs numeric
                fig3 = viz_tool.create_visualization(
                    df, 'bar', categorical_cols[0], numeric_cols[0],
                    title=f"{categorical_cols[0]} by {numeric_cols[0]}"
                )
                _show_figure(fig3)

        except Exception as e:
            print(f"❌ Error creating visualizations: {str(e)}")

        # Custom analysis
        print("\n🛠️ Custom Analysis")
        print("Available chart types:", viz_tool.supported_charts)
        print("Available columns:", list(df.columns))

        # Example custom visualization
        if len(df.columns) > 1:
            print("\n📊 Example Custom Visualization:")
            chart_type = 'bar'
            x_column = df.columns[0]
            # Use the second column for y when it is numeric, otherwise fall
            # back to the first column.
            y_column = df.columns[1] if df.columns[1] in df.select_dtypes(include=[np.number]).columns else df.columns[0]

            fig = viz_tool.create_visualization(
                df, chart_type, x_column, y_column,
                title=f"Example {chart_type.title()} Chart"
            )
            # This call sits outside any try block, so an error string must
            # be reported rather than crashing the cell.
            _show_figure(fig)

else:
    print("❌ No file uploaded. Please upload a CSV or Excel file.")

# Sample data section
print("\n🧪 Sample Data for Testing")
print("Generating sample sales data...")

# Generate sample data. The seed and the order of the random draws fix the
# generated values, so keep both unchanged for a reproducible preview.
np.random.seed(42)
sample_data = {
    'Date': pd.date_range('2023-01-01', periods=100, freq='D'),
    'Region': np.random.choice(['North', 'South', 'East', 'West'], 100),
    'Product': np.random.choice(['Product A', 'Product B', 'Product C'], 100),
    'Sales': np.random.randint(1000, 10000, 100),
    'Quantity': np.random.randint(10, 100, 100),
    'Customer_Satisfaction': np.random.uniform(3.0, 5.0, 100).round(2)
}
df_sample = pd.DataFrame(sample_data)

print("✅ Sample data generated!")
print("📋 Sample Data Preview:")
print(df_sample.head(10))

# Footer
print("\n" + "="*50)
print("🚀 Powered by AI Data Analysis Agent | Built with PraisonAI")

📊 AI Data Analysis Agent
Intelligent data analysis powered by AI - Upload your data and ask questions in natural language!

📁 Upload Your Data
Please upload a CSV or Excel file:


Saving customers-100.csv to customers-100.csv

📊 Dataset Information:
- Rows: 100
- Columns: 12
- Data Types: 3

📋 Data Preview:
   Index      Customer Id First Name  Last Name  \
0      1  DD37Cf93aecA6Dc     Sheryl     Baxter   
1      2  1Ef7b82A4CAAD10    Preston     Lozano   
2      3  6F94879bDAfE5a6        Roy      Berry   
3      4  5Cef8BFA16c5e3c      Linda      Olsen   
4      5  053d585Ab6b3159     Joanna     Bender   
5      6  2d08FB17EE273F4      Aimee      Downs   
6      7  EA4d384DfDbBf77     Darren       Peck   
7      8  0e04AFde9f225dE      Brett     Mullen   
8      9  C2dE4dEEc489ae0     Sheryl     Meyers   
9     10  8C2811a503C7c5a   Michelle  Gallagher   

                           Company               City  \
0                  Rasmussen Group       East Leonard   
1                      Vega-Gentry  East Jimmychester   
2                    Murillo-Perry      Isabelborough   
3  Dominguez, Mcmillan and Donovan         Bensonview   
4         Martin, Lang a


�� Custom Analysis
Available chart types: ['bar', 'line', 'scatter', 'histogram', 'box', 'pie', 'heatmap', 'area']
Available columns: ['Index', 'Customer Id', 'First Name', 'Last Name', 'Company', 'City', 'Country', 'Phone 1', 'Phone 2', 'Email', 'Subscription Date', 'Website']

📊 Example Custom Visualization:



🧪 Sample Data for Testing
Generating sample sales data...
✅ Sample data generated!
📋 Sample Data Preview:
        Date Region    Product  Sales  Quantity  Customer_Satisfaction
0 2023-01-01   East  Product C   4627        98                   4.77
1 2023-01-02   West  Product B   6450        34                   3.06
2 2023-01-03  North  Product B   2663        27                   4.16
3 2023-01-04   East  Product B   6592        91                   3.88
4 2023-01-05   East  Product B   8392        75                   4.34
5 2023-01-06   West  Product B   2306        63                   3.66
6 2023-01-07  North  Product B   7776        44                   3.31
7 2023-01-08  North  Product C   6864        89                   4.96
8 2023-01-09   East  Product C   8526        70                   4.68
9 2023-01-10  South  Product B   9901        50                   4.72

�� Powered by AI Data Analysis Agent | Built with PraisonAI
