In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
from crewai import Agent, Task, Crew, Process
from crewai_tools import BaseTool

import warnings
warnings.filterwarnings('ignore')

# --- 1. Define Custom Tools ---

class DataLoaderTool(BaseTool):
    """
    Tool that reads a CSV file into a pandas DataFrame and keeps it on the instance.
    """
    name: str = "Data Loader Tool"
    description: str = "Loads data from a specified CSV file path into a pandas DataFrame."
    # Frame produced by the most recent `_run` call; None until a file has been read.
    df: pd.DataFrame = None

    def _run(self, file_path: str) -> str:
        """
        Read the CSV at ``file_path`` and store the result in ``self.df``.

        Args:
            file_path (str): The path to the CSV file.

        Returns:
            str: A success message that includes the DataFrame shape, or an
                error message when the file is missing or loading fails.
        """
        try:
            self.df = pd.read_csv(file_path)
            # Column labels are coerced to str so later code can treat them uniformly.
            self.df.columns = self.df.columns.astype(str)
            message = f"Data loaded successfully from {file_path}. DataFrame shape: {self.df.shape}"
        except FileNotFoundError:
            message = f"Error: File not found at {file_path}"
        except Exception as exc:
            message = f"An error occurred while loading data: {exc}"
        return message

class DataAnalysisTool(BaseTool):
    """
    A tool to perform basic data analysis operations on a DataFrame.

    The DataFrame is read from the ``df`` attribute of the ``data_loader_tool``
    passed to ``_run``; 'handle_missing_values' modifies that DataFrame in place.
    """
    name: str = "Data Analysis Tool"
    description: str = "Performs basic data analysis: checks info, descriptive stats, and missing values. " \
                       "Can also handle missing values (mean, median, mode, drop) and identify column types."

    def _run(self, operation: str, column: str = None, strategy: str = None, columns_list: list = None,
              data_loader_tool: DataLoaderTool = None) -> str:
        """
        Executes a specified data analysis operation.

        Args:
            operation (str): The analysis operation to perform ('info', 'describe', 'missing_values',
                             'handle_missing_values', 'get_column_types').
            column (str, optional): A single column to restrict 'handle_missing_values' to.
                                    Ignored when ``columns_list`` is given.
            strategy (str, optional): Strategy for handling missing values ('mean', 'median', 'mode', 'drop').
            columns_list (list, optional): Columns to restrict 'handle_missing_values' to.
            data_loader_tool (DataLoaderTool): An instance of DataLoaderTool containing the DataFrame.

        Returns:
            str: A text report for the requested operation, or an error message.
        """
        # NOTE(review): the DataFrame only arrives through `data_loader_tool`, so the
        # caller must hand over the loader instance explicitly -- confirm the agent
        # framework is able to supply it.
        if data_loader_tool is None or data_loader_tool.df is None:
            return "Error: DataFrame not available for analysis. Please ensure data is loaded."

        df = data_loader_tool.df

        if operation == 'info':
            import io  # local import: only this branch needs it

            # DataFrame.info writes to a text buffer; a plain list has no .write().
            info_buffer = io.StringIO()
            df.info(buf=info_buffer)
            return "--- Data Information ---\n" + info_buffer.getvalue()
        elif operation == 'describe':
            return "--- Descriptive Statistics ---\n" + df.describe(include='all').to_string()
        elif operation == 'missing_values':
            missing_count = df.isnull().sum()
            missing_percentage = (missing_count / len(df)) * 100
            return (f"--- Missing Values (Count) ---\n{missing_count.to_string()}\n\n"
                    f"--- Missing Values (Percentage) ---\n{missing_percentage.to_string()}")
        elif operation == 'handle_missing_values':
            return self._handle_missing_values(df, strategy, column, columns_list)
        elif operation == 'get_column_types':
            numerical_cols = df.select_dtypes(include=np.number).columns.tolist()
            categorical_cols = df.select_dtypes(include='object').columns.tolist()
            return f"Numerical Columns: {numerical_cols}\nCategorical Columns: {categorical_cols}"
        else:
            return f"Error: Unknown operation '{operation}' for DataAnalysisTool."

    @staticmethod
    def _handle_missing_values(df, strategy, column, columns_list) -> str:
        """Drop or fill missing values of ``df`` in place and return a text report."""
        if not strategy:
            return "Error: 'strategy' parameter is required for 'handle_missing_values' operation."

        # Explicit selection: `columns_list` wins, then the single `column`;
        # None means "every column".
        if columns_list is not None:
            requested = list(columns_list)
        elif column is not None:
            requested = [column]
        else:
            requested = None

        if strategy == 'drop':
            rows_before = df.shape[0]
            unknown = []
            if not requested:
                df.dropna(inplace=True)
            else:
                unknown = [col for col in requested if col not in df.columns]
                present = [col for col in requested if col in df.columns]
                if not present:
                    return f"Error: None of the requested columns {requested} exist in the DataFrame."
                df.dropna(subset=present, inplace=True)
            message = f"Dropped rows with missing values. {rows_before - df.shape[0]} rows removed."
            if unknown:
                message += f" Ignored unknown columns: {unknown}"
            return message

        initial_missing = df.isnull().sum().sum()
        target_cols = df.columns if requested is None else requested
        report = []
        for col in target_cols:
            if col not in df.columns:
                report.append(f"Warning: Column '{col}' not found. Skipping.")
                continue
            series = df[col]
            if not series.isnull().any():
                continue

            numeric = pd.api.types.is_numeric_dtype(series)
            if strategy == 'mode':
                modes = series.mode()
                # mode() is empty when every value in the column is missing.
                fill_value = modes.iloc[0] if not modes.empty else None
            elif numeric and strategy == 'mean':
                fill_value = series.mean()
            elif numeric and strategy == 'median':
                fill_value = series.median()
            else:
                kind = 'numeric' if numeric else 'non-numeric'
                report.append(f"Warning: Strategy '{strategy}' not supported for {kind} column '{col}'. Skipping.")
                continue

            if fill_value is None or pd.isna(fill_value):
                report.append(f"Warning: Column '{col}' has no non-missing values to compute the {strategy}. Skipping.")
                continue

            # Assign the filled column back: a chained `df[col].fillna(..., inplace=True)`
            # operates on a temporary and may leave `df` unchanged.
            df[col] = series.fillna(fill_value)
            shown = f"{fill_value:.2f}" if numeric else f"'{fill_value}'"
            report.append(f"Filled missing values in '{col}' with {strategy}: {shown}")

        final_missing = df.isnull().sum().sum()
        return ("Missing values after handling:\n" + "\n".join(report)
                + f"\nTotal missing values before: {initial_missing}, after: {final_missing}")

class VisualizationTool(BaseTool):
    """
    A tool to create various data visualizations and save them as PNG files.

    The DataFrame is read from the ``df`` attribute of the ``data_loader_tool``
    passed to ``_run``. Plots are written to the 'visualizations' directory.
    """
    name: str = "Visualization Tool"
    description: str = "Generates and saves various plots (histogram, scatter, bar, box, heatmap). " \
                       "Requires 'plot_type', 'x_col', and optionally 'y_col', 'hue_col', 'by_col', 'bins', 'title'."

    def _run(self, plot_type: str, x_col: str = None, y_col: str = None, hue_col: str = None,
              by_col: str = None, bins: int = 10, title: str = None,
              numerical_cols_for_heatmap: list = None,
              data_loader_tool: DataLoaderTool = None) -> str:
        """
        Creates and saves a visualization.

        Args:
            plot_type (str): Type of plot ('histogram', 'scatterplot', 'barplot', 'boxplot', 'heatmap').
            x_col (str, optional): Column for the x-axis (the plotted value for box plots).
                                   May be omitted for 'heatmap', which does not use it.
            y_col (str, optional): Column for the y-axis (for scatter and bar plots).
            hue_col (str, optional): Column for color-coding (for scatter plot).
            by_col (str, optional): Column for grouping (for box plot).
            bins (int, optional): Number of bins for histogram.
            title (str, optional): Custom plot title.
            numerical_cols_for_heatmap (list, optional): Specific numerical columns for heatmap.
            data_loader_tool (DataLoaderTool): An instance of DataLoaderTool containing the DataFrame.

        Returns:
            str: The path of the saved plot on success, otherwise an error message.
        """
        # NOTE(review): the DataFrame only arrives through `data_loader_tool`, so the
        # caller must hand over the loader instance explicitly -- confirm the agent
        # framework is able to supply it.
        if data_loader_tool is None or data_loader_tool.df is None:
            return "Error: DataFrame not available for visualization. Please ensure data is loaded."

        df = data_loader_tool.df
        output_dir = "visualizations"
        os.makedirs(output_dir, exist_ok=True)

        stem = plot_type
        if x_col is not None:
            stem += f"_{x_col}"
        if y_col:
            stem += f"_vs_{y_col}"
        file_name = f"{output_dir}/{stem}.png"

        figure = plt.figure(figsize=(10, 6))
        try:
            columns_label = f"{x_col}, {y_col}" if y_col else f"{x_col}"

            if plot_type == 'histogram':
                if x_col not in df.columns or not pd.api.types.is_numeric_dtype(df[x_col]):
                    return f"Error: Cannot create histogram. Column '{x_col}' not found or not numeric."
                sns.histplot(df[x_col], bins=bins, kde=True)
                plt.title(title if title else f'Distribution of {x_col}')
                plt.xlabel(x_col)
                plt.ylabel('Frequency')
            elif plot_type == 'scatterplot':
                if x_col not in df.columns or y_col not in df.columns or \
                   not pd.api.types.is_numeric_dtype(df[x_col]) or not pd.api.types.is_numeric_dtype(df[y_col]):
                    return f"Error: Cannot create scatter plot. Columns '{x_col}' or '{y_col}' not found or not numeric."
                sns.scatterplot(data=df, x=x_col, y=y_col, hue=hue_col)
                plt.title(title if title else f'Scatter Plot of {y_col} vs {x_col}')
                plt.xlabel(x_col)
                plt.ylabel(y_col)
            elif plot_type == 'barplot':
                if x_col not in df.columns or y_col not in df.columns or not pd.api.types.is_numeric_dtype(df[y_col]):
                    return f"Error: Cannot create bar plot. Columns '{x_col}' or '{y_col}' not found or '{y_col}' not numeric."
                sns.barplot(data=df, x=x_col, y=y_col)
                plt.title(title if title else f'Bar Plot of {y_col} by {x_col}')
                plt.xlabel(x_col)
                plt.ylabel(y_col)
                plt.xticks(rotation=45, ha='right')
                plt.tight_layout()
            elif plot_type == 'boxplot':
                if x_col not in df.columns or not pd.api.types.is_numeric_dtype(df[x_col]):
                    return f"Error: Cannot create box plot. Column '{x_col}' not found or not numeric."
                sns.boxplot(data=df, x=by_col, y=x_col)
                default_title = f'Box Plot of {x_col}' + (f' by {by_col}' if by_col else '')
                plt.title(title if title else default_title)
                plt.xlabel(by_col if by_col else '')
                plt.ylabel(x_col)
                plt.xticks(rotation=45, ha='right')
                plt.tight_layout()
            elif plot_type == 'heatmap':
                cols_for_corr = numerical_cols_for_heatmap if numerical_cols_for_heatmap else df.select_dtypes(include=np.number).columns.tolist()
                if not cols_for_corr:
                    return "Error: No numerical columns found for heatmap."
                correlation_matrix = df[cols_for_corr].corr()
                sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=".2f", linewidths=.5)
                plt.title(title if title else 'Correlation Heatmap of Numerical Features')
                if x_col is None:
                    # No x_col was supplied, so describe the columns actually correlated.
                    columns_label = ", ".join(str(col) for col in cols_for_corr)
            else:
                return f"Error: Unknown plot type '{plot_type}'. Supported types: histogram, scatterplot, barplot, boxplot, heatmap."

            plt.grid(alpha=0.75)
            plt.savefig(file_name)
            return f"Plot '{plot_type}' for column(s) '{columns_label}' saved to {file_name}"
        except Exception as e:
            return f"An error occurred while creating plot '{plot_type}': {e}"
        finally:
            # Runs on success, on validation early-returns and on errors alike, so
            # the figure never stays registered with pyplot.
            plt.close(figure)


# --- 2. Initialize Tools ---
data_loader = DataLoaderTool()


class _BoundDataAnalysisTool(DataAnalysisTool):
    """DataAnalysisTool whose calls always operate on the shared ``data_loader``.

    A tool call made by an agent can only carry plain argument values, not a
    Python object, so the loader instance is supplied here instead of being
    expected as a call argument.
    """

    def _run(self, operation: str, column: str = None, strategy: str = None,
             columns_list: list = None) -> str:
        return super()._run(
            operation=operation,
            column=column,
            strategy=strategy,
            columns_list=columns_list,
            data_loader_tool=data_loader,
        )


class _BoundVisualizationTool(VisualizationTool):
    """VisualizationTool whose calls always operate on the shared ``data_loader``."""

    def _run(self, plot_type: str, x_col: str = None, y_col: str = None, hue_col: str = None,
             by_col: str = None, bins: int = 10, title: str = None,
             numerical_cols_for_heatmap: list = None) -> str:
        return super()._run(
            plot_type=plot_type,
            x_col=x_col,
            y_col=y_col,
            hue_col=hue_col,
            by_col=by_col,
            bins=bins,
            title=title,
            numerical_cols_for_heatmap=numerical_cols_for_heatmap,
            data_loader_tool=data_loader,
        )


data_analysis = _BoundDataAnalysisTool()
data_visualizer = _BoundVisualizationTool()


# --- 3. Define Agents ---
# LLM shared by every agent below.
# The API key for the chosen provider must be available as an environment
# variable (e.g., OPENAI_API_KEY). To use a different model, change the
# constructor call, for example:
#     llm = ChatOpenAI(model_name="gpt-4", temperature=0)
#
# If `langchain_openai` cannot be imported, a stand-in object is used instead so
# the module still loads; replace it with a real LLM before running the crew.
# NOTE(review): only ImportError is handled here -- any other failure while
# constructing ChatOpenAI propagates.
try:
    from langchain_openai import ChatOpenAI
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0) # temperature=0 requests the least random output
except ImportError:
    print("Warning: langchain_openai not found. Using a dummy LLM. Please install and configure your LLM.")
    # Placeholder that answers every prompt with the same fixed string.
    # NOTE(review): presumably not accepted by Agent as a real LLM -- confirm.
    class DummyLLM:
        def invoke(self, prompt):
            return "Dummy LLM response."
    llm = DummyLLM()


# Agent responsible for ingesting the raw file; it only has the loader tool.
data_engineer_agent = Agent(
    llm=llm,
    tools=[data_loader],
    role="Data Engineer",
    goal="Load raw data from files and ensure it's in a usable format for analysis.",
    backstory=(
        "You are a meticulous Data Engineer, skilled in handling various data formats and "
        "ensuring data integrity for downstream analysis. "
        "You are responsible for ingesting data and making it ready for the analysts."
    ),
    allow_delegation=False,
    verbose=True,
)

# Agent that explores and cleans the data; it only has the analysis tool.
data_analyst_agent = Agent(
    llm=llm,
    tools=[data_analysis],
    role="Senior Data Analyst",
    goal=(
        "Perform comprehensive exploratory data analysis, identify patterns, "
        "handle missing values, and summarize key insights from the dataset."
    ),
    backstory=(
        "You are an experienced Data Analyst with a keen eye for detail. You excel at "
        "understanding data structures, cleaning data, and extracting meaningful information "
        "to inform decision-making."
    ),
    # Delegation is enabled so this agent can hand work to the Visualizer Agent.
    allow_delegation=True,
    verbose=True,
)

# Agent that turns the data into charts; it only has the visualization tool.
data_visualizer_agent = Agent(
    llm=llm,
    tools=[data_visualizer],
    role="Master Data Visualizer",
    goal="Create clear, compelling, and insightful visualizations from data to communicate findings effectively.",
    backstory=(
        "You are a highly creative and skilled Data Visualizer, proficient in Matplotlib "
        "and Seaborn. You know how to choose the right chart type to tell the data's story "
        "and make complex data understandable through visual representation."
    ),
    allow_delegation=False,
    verbose=True,
)

# --- 4. Define Tasks ---

# Task 1: Load Data
# The `{file_path}` placeholder is filled from the `inputs` mapping given to
# `Crew.kickoff`, so the file to load is chosen by the caller rather than being
# hard-coded in the prompt.
load_data_task = Task(
    description=(
        "Load the '{file_path}' file using the DataLoaderTool. "
        "Confirm successful loading and report the DataFrame shape."
    ),
    expected_output="A confirmation message about successful data loading and DataFrame shape.",
    agent=data_engineer_agent,
    output_file="data_loading_report.txt"
)

# Task 2: Analyze Data
# This task will run after load_data_task.
# Overview, column typing and missing-value cleanup; receives the output of
# load_data_task as context.
analyze_data_task = Task(
    agent=data_analyst_agent,
    context=[load_data_task],
    description=" ".join([
        "Perform initial data overview: get data info, descriptive statistics, and missing value counts/percentages.",
        "Then, identify numerical and categorical columns.",
        "After that, handle missing values in 'Missing_Col_Numeric' using 'median' strategy and "
        "in 'Missing_Col_Categorical' using 'mode' strategy.",
        "Finally, provide a summary of the data, the cleaning steps taken, and "
        "the identified column types after cleaning.",
    ]),
    expected_output=(
        "A comprehensive report detailing data info, descriptive stats, "
        "missing values (before/after handling), "
        "and a list of numerical and categorical columns."
    ),
    output_file="data_analysis_report.txt",
)

# Task 3: Create Visualizations
# Each numbered step ends with an explicit newline: adjacent string literals are
# concatenated with no separator, which would otherwise fuse the steps together.
create_visualizations_task = Task(
    description=(
        "Based on the analyzed data and column types:\n"
        "1. Create a histogram for 'Age' with 15 bins, titled 'Distribution of Customer Age'.\n"
        "2. Create a histogram for 'Purchase_Amount' with 20 bins, titled 'Distribution of Purchase Amount'.\n"
        "3. Create a scatter plot of 'Income' (x-axis) vs. 'Purchase_Amount' (y-axis), colored by 'Gender', "
        "titled 'Income vs. Purchase Amount by Gender'.\n"
        "4. Create a bar plot showing the average 'Purchase_Amount' for each 'Product_Category', "
        "titled 'Average Purchase Amount by Product Category'.\n"
        "5. Create a box plot of 'Purchase_Amount' grouped by 'Product_Category', "
        "titled 'Purchase Amount Distribution by Product Category'.\n"
        "6. Create a correlation heatmap for all numerical columns identified, "
        "titled 'Correlation Heatmap of Numerical Features'.\n"
        "Report the file paths of all generated plots."
    ),
    expected_output="A list of file paths to the generated PNG visualization files.",
    agent=data_visualizer_agent,
    context=[analyze_data_task], # This task depends on data analysis being complete
    output_file="visualization_report.txt"
)

# --- 5. Build the Crew ---
# Crew wiring: three agents, three tasks, run one after another in list order.
data_analysis_crew = Crew(
    process=Process.sequential,
    agents=[
        data_engineer_agent,
        data_analyst_agent,
        data_visualizer_agent,
    ],
    tasks=[
        load_data_task,
        analyze_data_task,
        create_visualizations_task,
    ],
    output_log_file="crew_activity.log",
    verbose=True,
)

# --- Example Usage (Dummy Data Creation) ---
if __name__ == "__main__":
    # Create a dummy CSV file for demonstration.
    # In a real scenario, you would have your own data.
    data = {
        'CustomerID': range(1, 101),
        'Age': np.random.randint(18, 70, 100),
        'Gender': np.random.choice(['Male', 'Female'], 100),
        'Income': np.random.randint(30000, 120000, 100),
        'Purchase_Amount': np.random.uniform(50, 1000, 100),
        'Product_Category': np.random.choice(['Electronics', 'Clothing', 'Home Goods', 'Books'], 100),
        'Satisfaction_Score': np.random.randint(1, 6, 100),
        'Missing_Col_Numeric': np.random.choice([np.nan, 10, 20, 30], 100, p=[0.1, 0.3, 0.3, 0.3]),
        # An object array with None keeps the gaps as real missing values. Mixing
        # np.nan with strings in a plain list makes NumPy build a string array in
        # which NaN becomes the literal text 'nan'.
        'Missing_Col_Categorical': np.random.choice(
            np.array([None, 'A', 'B', 'C'], dtype=object), 100, p=[0.15, 0.3, 0.3, 0.25]
        ),
    }
    dummy_df = pd.DataFrame(data)
    dummy_df.to_csv('sample_sales_data.csv', index=False)
    print("Created 'sample_sales_data.csv' for demonstration.")

    # Kick off the crew's work
    print("--- Starting Data Analysis Crew ---")
    inputs = {"file_path": "sample_sales_data.csv"} # Initial input for the crew
    result = data_analysis_crew.kickoff(inputs=inputs)
    print("\n--- Crew Execution Finished ---")
    print("Final Result from Crew:")
    print(result)
    print("\nCheck the 'visualizations' folder for generated plots and the '.txt' files for reports.")

