<a href="https://colab.research.google.com/github/CodeWithSridhar/gdi-agent/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
from typing import Any
from google.genai import types
import google.genai as genai
from google.genai.types import Tool, GenerativeContentBlob

# Initialize Gemini 2.0 Flash: the API key is read from the GEMINI_API_KEY
# environment variable (None is passed through if it is unset) and a
# module-level `model` handle is created for "gemini-2.0-flash-exp".
# NOTE(review): `configure` / `GenerativeModel` look like the legacy
# `google.generativeai` API rather than `google.genai` (imported above), which
# appears to be client-based (`genai.Client(api_key=...)`) -- confirm against
# the installed SDK before relying on this cell.
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
model = genai.GenerativeModel("gemini-2.0-flash-exp")

class DataUnderstandingAgent:
    """Agent for data profiling and validation.

    Provides two profiling helpers, ``analyze_data_types`` and
    ``check_missing_values``, and registers both as tools in ``self.tools``.
    """

    def __init__(self):
        self.name = "Data Understanding Agent"
        self.tools = self._create_tools()

    def _create_tools(self):
        """Create the tool wrappers for the data-analysis helpers.

        NOTE(review): ``Tool.function_tool`` is used as written in the
        original code; confirm that this factory exists in the installed
        ``google.genai.types`` version.
        """
        return [
            Tool.function_tool(
                name="analyze_data_types",
                description="Analyze data types and structure",
                callback=self.analyze_data_types
            ),
            Tool.function_tool(
                name="check_missing_values",
                description="Check for missing values in dataset",
                callback=self.check_missing_values
            )
        ]

    @staticmethod
    def _load_table(file_path):
        """Load *file_path* into a DataFrame, choosing the reader by suffix.

        ``.xls`` / ``.xlsx`` paths go through ``pandas.read_excel``; every
        other path is read as CSV, exactly as before.
        """
        import pandas as pd

        if str(file_path).lower().endswith((".xls", ".xlsx")):
            return pd.read_excel(file_path)
        return pd.read_csv(file_path)

    def analyze_data_types(self, file_path: str) -> dict:
        """Analyze data types in a CSV/Excel file.

        Args:
            file_path: Path to the CSV (or ``.xls``/``.xlsx``) file.

        Returns:
            A dict with the column names (``columns``), the dtype name of
            each column (``dtypes``), the ``(rows, columns)`` tuple
            (``shape``) and the ``DataFrame.describe()`` output (``summary``).
        """
        df = self._load_table(file_path)
        return {
            "columns": df.columns.tolist(),
            "dtypes": df.dtypes.astype(str).to_dict(),
            "shape": df.shape,
            "summary": df.describe().to_dict()
        }

    def check_missing_values(self, file_path: str) -> dict:
        """Count missing values per column.

        Args:
            file_path: Path to the CSV (or ``.xls``/``.xlsx``) file.

        Returns:
            A dict with ``missing_counts`` (null count per column) and
            ``missing_percentage`` (null share per column, 0-100). For a
            file with no data rows every percentage is reported as ``0.0``.
        """
        df = self._load_table(file_path)
        missing = df.isnull().sum().to_dict()
        row_count = len(df)
        return {
            "missing_counts": missing,
            # Guard the division: a header-only file has zero rows.
            "missing_percentage": {
                column: (count / row_count * 100) if row_count else 0.0
                for column, count in missing.items()
            }
        }

class InsightAgent:
    """Agent for ML insights and EDA.

    Provides ``generate_insights`` (correlations, quartiles, summary
    statistics) and ``detect_trends`` (first-vs-last comparison of a value
    column ordered by date), and registers both as tools in ``self.tools``.
    """

    def __init__(self):
        self.name = "Insight Agent"
        self.tools = self._create_tools()

    def _create_tools(self):
        """Create the tool wrappers for the insight helpers.

        NOTE(review): ``Tool.function_tool`` is used as written in the
        original code; confirm that this factory exists in the installed
        ``google.genai.types`` version.
        """
        return [
            Tool.function_tool(
                name="generate_insights",
                description="Generate statistical insights",
                callback=self.generate_insights
            ),
            Tool.function_tool(
                name="detect_trends",
                description="Detect trends in time-series data",
                callback=self.detect_trends
            )
        ]

    def generate_insights(self, file_path: str, target_column: str = None) -> dict:
        """Generate statistical insights for a CSV file.

        Args:
            file_path: Path to the CSV file.
            target_column: Accepted for interface compatibility; currently
                not used by the computation.

        Returns:
            A dict with ``correlations`` (pairwise correlation of the numeric
            columns), ``outliers`` (25th/75th percentiles of the numeric
            columns) and ``distribution`` (``DataFrame.describe()`` output).
            ``correlations`` and ``outliers`` are empty dicts when the file
            has no numeric columns.
        """
        import pandas as pd
        import numpy as np
        df = pd.read_csv(file_path)

        # Restrict to numeric columns: on pandas >= 2.0, corr()/quantile()
        # raise when the frame still contains string columns.
        numeric = df.select_dtypes(include=[np.number])
        has_numeric = numeric.shape[1] > 0

        insights = {
            "correlations": numeric.corr().to_dict() if has_numeric else {},
            "outliers": numeric.quantile([0.25, 0.75]).to_dict() if has_numeric else {},
            "distribution": df.describe().to_dict()
        }
        return insights

    def detect_trends(self, file_path: str, date_column: str, value_column: str) -> dict:
        """Detect the trend of a value column over time.

        The rows are ordered by ``date_column`` (parsed with
        ``pandas.to_datetime``) and the last value is compared with the
        first one.

        Args:
            file_path: Path to the CSV file.
            date_column: Name of the column holding the dates.
            value_column: Name of the column holding the values.

        Returns:
            A dict with ``trend_direction`` (``"increasing"`` when the last
            value is strictly greater than the first, otherwise
            ``"decreasing"``), ``min_value``, ``max_value`` and
            ``average_value`` as floats.
        """
        import pandas as pd
        df = pd.read_csv(file_path)
        df[date_column] = pd.to_datetime(df[date_column])
        df = df.sort_values(date_column)

        values = df[value_column]
        # Equal first/last values fall into the "decreasing" branch.
        direction = "increasing" if values.iloc[-1] > values.iloc[0] else "decreasing"
        return {
            "trend_direction": direction,
            "min_value": float(values.min()),
            "max_value": float(values.max()),
            "average_value": float(values.mean())
        }

class VisualizationAgent:
    """Agent for chart generation.

    Produces matplotlib source code as text; it does not execute it.
    """

    def __init__(self):
        self.name = "Visualization Agent"

    def generate_visualization_code(self, data_type: str, columns: list) -> str:
        """Generate matplotlib code for a visualization.

        The generated snippet reads the CSV from a ``data_path`` variable
        that is expected to be defined by whoever runs the snippet.

        Args:
            data_type: ``"histogram"`` (uses ``columns[0]``) or ``"scatter"``
                (uses ``columns[0]`` and ``columns[1]``).
            columns: Column names to plot.

        Returns:
            The Python source of the plot, or ``""`` for any other
            ``data_type``.

        Raises:
            IndexError: If ``columns`` has fewer entries than the chart
                type needs.
        """
        # Column names are embedded with repr() rather than pasted between
        # quotes, so names containing quotes or backslashes still yield
        # valid Python. For plain names the output is unchanged.
        if data_type == "histogram":
            col = str(columns[0])
            title = f"Distribution of {col}"
            lines = [
                "",
                "import matplotlib.pyplot as plt",
                "import pandas as pd",
                "df = pd.read_csv(data_path)",
                f"df[{col!r}].hist(bins=30)",
                f"plt.title({title!r})",
                "plt.savefig('histogram.png')",
                "",
            ]
            return "\n".join(lines)
        if data_type == "scatter":
            x_col = str(columns[0])
            y_col = str(columns[1])
            lines = [
                "",
                "import matplotlib.pyplot as plt",
                "import pandas as pd",
                "df = pd.read_csv(data_path)",
                f"plt.scatter(df[{x_col!r}], df[{y_col!r}])",
                f"plt.xlabel({x_col!r})",
                f"plt.ylabel({y_col!r})",
                "plt.title('Scatter Plot')",
                "plt.savefig('scatter.png')",
                "",
            ]
            return "\n".join(lines)
        return ""

class PolicyRecommendationAgent:
    """Agent for policy recommendations.

    Holds a fixed focus statement per domain (``education``, ``health``,
    ``sanitation``) and turns an insights dict into a list of
    recommendation strings.
    """

    def __init__(self):
        self.name = "Policy Recommendation Agent"
        self.domain_policies = {
            "education": "Focus on resource allocation, teacher training, student performance",
            "health": "Focus on disease prevention, resource distribution, health outcomes",
            "sanitation": "Focus on waste management, hygiene practices, infrastructure improvement"
        }

    def generate_recommendations(self, insights: dict, domain: str) -> list:
        """Generate policy recommendations based on insights.

        Args:
            insights: Insight dict; only the presence of the keys
                ``correlations``, ``outliers`` and ``trend_direction`` is
                checked (plus the value of ``trend_direction``).
            domain: Domain name. It is matched against the known domains
                ignoring case and surrounding whitespace; unknown domains
                get ``"General recommendations"``.

        Returns:
            A list of strings: the domain header, the domain focus statement,
            then one recommendation per insight key that is present.
        """
        # The header already upper-cases the domain, so the lookup should not
        # be case-sensitive either ("Education" must find "education").
        policy_key = domain.strip().lower()
        recommendations = [
            f"Domain: {domain.upper()}",
            self.domain_policies.get(policy_key, "General recommendations"),
        ]

        if "correlations" in insights:
            recommendations.append("Recommendation 1: Investigate high correlations for causal relationships")
        if "outliers" in insights:
            recommendations.append("Recommendation 2: Investigate outliers for data quality and anomalies")
        if "trend_direction" in insights:
            recommendations.append(f"Recommendation 3: Trend is {insights.get('trend_direction')}, adjust resources accordingly")

        return recommendations

class ReportGeneratorAgent:
    """Agent for final report generation"""
    def __init__(self):
        self.name = "Report Generator Agent"

    def generate_report(self, data_understanding: dict, insights: dict, recommendations: list) -> str:
        """Generate comprehensive report"""
        report = f"""
