In [5]:
# Install required libraries
import bitsandbytes
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from scipy import stats
from typing import Callable, Dict
from datetime import datetime
import gradio as gr
import json
import re


In [1]:
# Hugging Face login.
# The token is read from the HF_TOKEN environment variable (or prompted for) instead of
# being hard-coded, so it never ends up in the saved/shared notebook.
import os
from getpass import getpass

from huggingface_hub import login

HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(HF_TOKEN)


In [6]:
## Configure Mistral-7B-Instruct with 4-bit Quantization

# Hub id of the instruction-tuned checkpoint used by every agent.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Prefer the GPU whenever torch can see one.
# NOTE(review): 4-bit bitsandbytes loading generally needs CUDA; the "cpu" fallback may not
# actually be usable with this quantization config — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Weights stored as 4-bit NF4; float16 is used as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Tokenizer and model come from the same repo; HF_TOKEN (login cell) authorizes the download.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=HF_TOKEN,
    device_map="auto",
    quantization_config=bnb_config,
)


config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.10k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [None]:
# Function to call the LLM
def call_llm(prompt: str, max_tokens: int = 500) -> str:
    """Generate a completion for `prompt` with the loaded Mistral model, asking for JSON.

    The prompt is wrapped in Mistral's ``[INST] ... [/INST]`` markers together with
    instructions requesting a bare JSON object.

    Args:
        prompt: Task description for the model.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        Only the newly generated text, decoded without special tokens and stripped of
        surrounding whitespace. Parsing the JSON is left to the caller.
    """
    # Construct a prompt that strongly emphasizes JSON response format
    full_prompt = (
        f"[INST] {prompt}\n\n"
        f"IMPORTANT: Your response must be a valid JSON object only. Format your response as proper JSON.\n"
        f"DO NOT include any text outside the JSON object.\n"
        f"Example of correct format: {{\"key\": \"value\", \"otherKey\": 123}}\n"
        f"Do not use markdown formatting for JSON. [/INST]"
    )

    # Send the ids to the device transformers reports for the (device_map="auto") model.
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.1,  # Lower temperature for more predictable outputs
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # The generated sequence starts with the prompt tokens; decode only what follows them.
    # Splitting the fully decoded text on "[/INST]" breaks when the prompt itself contains
    # that marker (e.g. pasted into a user query).
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()


In [8]:
## Defining Tools

# Tool 1: Load and Clean Dataset
def load_and_clean_data(file_path: str) -> dict:
    """Load a CSV file and run a fixed cleaning pipeline over it.

    Steps: drop duplicate rows, drop rows that are entirely NaN, fill missing numeric
    values with the column median, convert low-cardinality text columns to ``category``,
    and drop rows outside Tukey's 1.5*IQR fences on any numeric column.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        Dict with "data" (cleaned DataFrame, or None on failure), "status" ("success" or
        "error") and a human-readable "message". Failures are reported, never raised.
    """
    try:
        print(f"Attempting to load data from {file_path}")
        df = pd.read_csv(file_path)
        initial_rows = df.shape[0]

        # Remove duplicates, then rows where every value is missing.
        df = df.drop_duplicates().dropna(how='all')

        # Impute missing values for numeric columns with the median.
        numeric_cols = df.select_dtypes(include="number").columns
        for col in numeric_cols:
            df[col] = df[col].fillna(df[col].median())

        # Text columns with fewer unique values than half the rows become categoricals.
        # Comparing against 0.5 * len(df) (rather than dividing) cannot fail on an empty frame.
        for col in df.select_dtypes(include=['object']).columns:
            if df[col].nunique() < 0.5 * len(df):
                df[col] = df[col].astype('category')

        # Outlier removal (IQR method). All fences are computed on the same frame and then
        # combined, so the result does not depend on column order. Columns with IQR == 0 are
        # skipped: their fences collapse to a single value and would drop every row whose
        # value differs from it (e.g. all the 1s of a mostly-0 flag column).
        keep = pd.Series(True, index=df.index)
        for col in numeric_cols:
            q1 = df[col].quantile(0.25)
            q3 = df[col].quantile(0.75)
            iqr = q3 - q1
            if iqr == 0:
                continue
            keep &= ~((df[col] < (q1 - 1.5 * iqr)) | (df[col] > (q3 + 1.5 * iqr)))
        df = df[keep]

        cleaned_rows = df.shape[0]
        return {
            "data": df,
            "status": "success",
            "message": f"Loaded and cleaned dataset from {file_path}. Rows: {initial_rows} -> {cleaned_rows}. Removed duplicates, imputed missing values, and handled outliers."
        }
    except Exception as e:
        return {"data": None, "status": "error", "message": f"Failed to load/clean data: {str(e)}"}

# Tool 2: Perform EDA
def perform_eda(df: pd.DataFrame) -> dict:
    """Profile `df`: summary statistics, missing counts, skewness, outliers and correlations.

    Args:
        df: Dataset to profile (None or empty is reported as an error).

    Returns:
        Dict with "status", "message" and the analysis keys "summary_stats", "skewness",
        "missing_values", "outliers" (count of values outside the 1.5*IQR fences per numeric
        column) and "corr_matrix". Every result, success or error, carries all five analysis
        keys (empty dicts on error), so callers can read them without guarding.
    """
    def _failure(message: str) -> dict:
        # Same shape for every error path.
        return {"status": "error", "message": message, "summary_stats": {}, "skewness": {},
                "missing_values": {}, "outliers": {}, "corr_matrix": {}}

    try:
        if df is None or df.empty:
            return _failure("No data available for EDA.")

        summary_stats = df.describe(include='all').to_dict()
        missing_values = df.isnull().sum().to_dict()

        numeric_cols = df.select_dtypes(include="number").columns
        skewness = {col: float(stats.skew(df[col].dropna())) for col in numeric_cols}

        # Outlier counts using the IQR rule.
        outliers = {}
        for col in numeric_cols:
            q1 = df[col].quantile(0.25)
            q3 = df[col].quantile(0.75)
            iqr = q3 - q1
            outside = (df[col] < (q1 - 1.5 * iqr)) | (df[col] > (q3 + 1.5 * iqr))
            outliers[col] = int(outside.sum())

        corr_matrix = df[numeric_cols].corr().to_dict()

        return {
            "status": "success",
            "summary_stats": summary_stats,
            "skewness": skewness,
            "missing_values": missing_values,
            "outliers": outliers,
            "corr_matrix": corr_matrix,
            "message": "Comprehensive EDA completed."
        }
    except Exception as e:
        return _failure(f"EDA failed: {str(e)}")

# Tool 3: Generate Visualizations
def generate_plot(df: pd.DataFrame, plot_type: str = "auto", x_col: str = None, y_col: str = None, annot: bool = False, **kwargs) -> dict:
    """Render one seaborn plot of `df` and save it as a timestamped PNG in the working dir.

    Args:
        df: Dataset to plot.
        plot_type: "histogram", "scatter", "line", "box", "count" or "heatmap"
            (case-insensitive), or "auto" to pick one from the supplied columns.
        x_col: Column for the x axis (the value axis for "box").
        y_col: Column for the y axis ("scatter" and "line" only).
        annot: Whether to print correlation values inside heatmap cells.
        **kwargs: Ignored; accepted so callers can pass extra LLM-produced parameters.

    Returns:
        On success {"status": "success", "message", "file", "plot_type"}, where "file" is the
        saved PNG path; otherwise {"status": "error", "message"}.
    """
    numeric_cols = df.select_dtypes(include="number").columns
    categorical_cols = df.select_dtypes(include=['category', 'object']).columns

    # Auto-select plot type if not specified
    if plot_type == "auto":
        if x_col and y_col:
            plot_type = "scatter" if y_col in numeric_cols else "bar"
        elif x_col:
            plot_type = "histogram" if x_col in numeric_cols else "count"
        else:
            plot_type = "heatmap"  # Default to correlation heatmap

    # Validate columns
    if x_col and x_col not in df.columns:
        return {"status": "error", "message": f"Column '{x_col}' not found in dataset."}
    if y_col and y_col not in df.columns:
        return {"status": "error", "message": f"Column '{y_col}' not found in dataset."}

    # Validate the plot request before any figure exists, so rejected requests cannot
    # leave an open matplotlib figure behind.
    kind = str(plot_type).lower()
    invalid = {"status": "error", "message": f"Invalid plot type '{plot_type}' or missing required columns."}
    if kind == "heatmap":
        if x_col or y_col:
            return {"status": "error", "message": "Heatmap does not use specific x_col or y_col; it shows correlation matrix."}
    elif kind in ("scatter", "line"):
        if not (x_col and y_col):
            return invalid
    elif kind in ("histogram", "box", "count"):
        if not x_col:
            return invalid
    else:
        return invalid

    fig, ax = plt.subplots(figsize=(16, 10))  # Large canvas for better visibility
    try:
        if kind == "histogram":
            sns.histplot(data=df, x=x_col, kde=True, ax=ax)
        elif kind == "scatter":
            hue = categorical_cols[0] if len(categorical_cols) > 0 else None
            sns.scatterplot(data=df, x=x_col, y=y_col, hue=hue, ax=ax)
        elif kind == "line":
            sns.lineplot(data=df, x=x_col, y=y_col, ax=ax)
        elif kind == "box":
            # Vertical box plot of x_col's distribution, with a grid for readability.
            sns.boxplot(data=df, y=x_col, color='skyblue', ax=ax)
            ax.grid(True, axis='y', linestyle='--', alpha=0.7)
        elif kind == "count":
            sns.countplot(data=df, x=x_col, ax=ax)
        else:  # heatmap of the numeric correlation matrix
            corr_matrix = df[numeric_cols].corr()
            sns.heatmap(corr_matrix, annot=annot, cmap="coolwarm", linewidths=0.5, vmin=-1, vmax=1,
                        square=True, cbar_kws={"shrink": 0.75}, ax=ax)
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=10)
            plt.setp(ax.get_yticklabels(), fontsize=10)

        ax.set_title(f"{plot_type.capitalize()} Plot: {x_col if x_col else 'Heatmap'} {f'vs {y_col}' if y_col else ''}", fontsize=14, pad=20)
        fig.tight_layout()

        # Microseconds in the name keep two plots made within one second from overwriting each other.
        file_name = f"plot_{plot_type}_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.png"
        fig.savefig(file_name)

        return {"status": "success", "message": f"Generated {plot_type} plot.", "file": file_name, "plot_type": plot_type}

    except Exception as e:
        return {"status": "error", "message": f"Plot generation failed: {str(e)}"}
    finally:
        # Always release the figure, including on failure.
        plt.close(fig)

# Tool 4: Correlation Analysis
def calculate_correlation(df: pd.DataFrame, col1: str = None, col2: str = None, **kwargs) -> dict:
    """Pearson correlation between two numeric columns, plus the full numeric matrix.

    Column selection:
      * neither column given -> the first two numeric columns;
      * only one given       -> it is paired with the first other numeric column;
      * both given           -> used as-is.

    Args:
        df: Dataset to analyse.
        col1: First column name (optional).
        col2: Second column name (optional).
        **kwargs: Ignored; accepted so callers can pass extra LLM-produced parameters.

    Returns:
        On success: "correlation", "col1", "col2", "corr_matrix", "status", "message".
        On error: "status", "message", "correlation" (None) and "corr_matrix" ({}).
    """
    try:
        numeric_cols = list(df.select_dtypes(include="number").columns)

        if not col1 and not col2:
            col1, col2 = (numeric_cols[0], numeric_cols[1]) if len(numeric_cols) >= 2 else (None, None)
        elif not col2:
            # Honour the column the caller asked for instead of silently replacing it.
            col2 = next((c for c in numeric_cols if c != col1), None)
        elif not col1:
            col1 = next((c for c in numeric_cols if c != col2), None)

        if col1 not in numeric_cols or col2 not in numeric_cols:
            return {"status": "error", "message": "Columns must be numeric for correlation.", "correlation": None, "corr_matrix": {}}

        corr = df[col1].corr(df[col2])
        corr_matrix = df[numeric_cols].corr().to_dict()
        return {
            "correlation": float(corr),
            "col1": col1,
            "col2": col2,
            "corr_matrix": corr_matrix,
            "status": "success",
            "message": f"Correlation between {col1} and {col2} calculated with full matrix."
        }
    except Exception as e:
        return {"status": "error", "message": f"Correlation failed: {str(e)}", "correlation": None, "corr_matrix": {}}

def _json_default(obj):
    """Fallback for json.dumps in the report: unwrap numpy-style scalars via `.item()`,
    otherwise fall back to `str(obj)` (e.g. for Timestamps that can appear in describe())."""
    item = getattr(obj, "item", None)
    if callable(item):
        try:
            return item()
        except (TypeError, ValueError):
            pass
    return str(obj)


# Tool 5: Comprehensive Report Generation
def generate_report(df: pd.DataFrame, eda_results: dict, analysis_results: dict, plot_results: list) -> dict:
    """Assemble a plain-text report from EDA, correlation and plotting results.

    Args:
        df: The dataset the results describe (used for the overview section).
        eda_results: Output of `perform_eda` (missing keys render as `{}`).
        analysis_results: Output of `calculate_correlation` (missing keys render as "N/A"/`{}`).
        plot_results: Messages from plot generation; an empty list is reported as
            "No plots generated.".

    Returns:
        {"report": <text>, "status": "success", "message": ...}.
    """
    eda = eda_results or {}
    analysis = analysis_results or {}

    def dump(value) -> str:
        # Values may hold numpy/pandas scalars that the json module cannot encode natively.
        return json.dumps(value, indent=2, default=_json_default)

    report_lines = [
        "=== Data Science Report ===",
        f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "\n1. Dataset Overview:",
        f"  - Rows: {df.shape[0]}, Columns: {df.shape[1]}",
        # Column labels need not be strings (e.g. headerless data uses integers).
        f"  - Columns: {', '.join(map(str, df.columns))}",
        "\n2. Exploratory Data Analysis:",
        f"  - Summary Stats: {dump(eda.get('summary_stats', {}))}",
        f"  - Missing Values: {dump(eda.get('missing_values', {}))}",
        f"  - Skewness: {dump(eda.get('skewness', {}))}",
        f"  - Outliers: {dump(eda.get('outliers', {}))}",
        "\n3. Correlation Analysis:",
        f"  - Specific Correlation: {analysis.get('correlation', 'N/A')} between {analysis.get('col1', 'N/A')} and {analysis.get('col2', 'N/A')}",
        f"  - Full Correlation Matrix: {dump(analysis.get('corr_matrix', {}))}",
        "\n4. Visualizations:",
        "\n  - " + "\n  - ".join(plot_results or ["No plots generated."]),
        "\n=== End of Report ==="
    ]
    return {"report": "\n".join(report_lines), "status": "success", "message": "Comprehensive report generated."}


In [9]:
## Defining Pydantic Models

# Data Ingestion Agent Output
class DataIngestionOutput(BaseModel):
    """Result of the data-ingestion agent's load-and-clean step."""

    status: str  # "success"/"error" as reported by load_and_clean_data, or "error" from the agent wrapper
    message: str  # human-readable summary or error text
    rows: int | None = None  # rows in the cleaned frame; None when loading failed
    columns: int | None = None  # columns in the cleaned frame; None when loading failed

# EDA Agent Output
class EDAOutput(BaseModel):
    """Result of the EDA agent; the dict fields mirror the keys returned by `perform_eda`."""

    status: str  # "success" or "error"
    message: str  # human-readable summary or error text
    summary_stats: dict = {}  # describe(include='all') statistics per column
    skewness: dict = {}  # numeric column -> skewness
    missing_values: dict = {}  # column -> number of missing values
    outliers: dict = {}  # numeric column -> count of values outside the 1.5*IQR fences
    corr_matrix: dict = {}  # nested-dict form of the numeric correlation matrix

# Visualization Agent Output
class VisualizationOutput(BaseModel):
    """Result of the visualization agent.

    NOTE(review): `generate_plot`'s result dict has no plot_type/x_column/y_column keys, so
    these stay None when the model is built via ``output_model(**result)``.
    """

    status: str  # "success" or "error"
    message: str  # human-readable summary or error text
    plot_type: str | None = None  # kind of plot that was requested/generated
    x_column: str | None = None  # column used on the x axis
    y_column: str | None = None  # column used on the y axis

# Analysis Agent Output
class AnalysisOutput(BaseModel):
    """Result of the analysis agent; fields mirror the keys returned by `calculate_correlation`."""

    status: str  # "success" or "error"
    message: str  # human-readable summary or error text
    correlation: float | None = None  # correlation between col1 and col2; None on error
    col1: str | None = None  # first column of the pair
    col2: str | None = None  # second column of the pair
    corr_matrix: dict = {}  # nested-dict form of the numeric correlation matrix

# Report Generation Agent Output
class ReportOutput(BaseModel):
    """Result of the report-generation agent.

    `report_text` defaults to an empty string so an error result can be built from just
    `status` and `message`. The generic agent wrapper builds its error results that way;
    with a required field, that construction itself raised a validation error.
    """

    status: str  # "success" or "error"
    message: str  # human-readable summary or error text
    report_text: str = ""  # full report text; empty when generation failed

# Orchestrator Agent Output
class OrchestratorOutput(BaseModel):
    """Routing decision produced by `orchestrator_agent`."""

    status: str  # "success", or "error" when the LLM reply could not be parsed
    message: str  # the LLM's explanation, or an error text
    target_agent: str  # "data", "eda", "visualization", "analysis", "report" — or "none" on error
    parameters: dict  # tool kwargs pulled from the query (e.g. file_path, plot_type, x_col, y_col, col1, col2)


In [10]:
## Implementing Agents

# Create agent function
def create_agent(role: str, tools: Dict[str, Callable], output_model: type[BaseModel]) -> Callable:
    """Build an agent that asks the LLM which of `tools` to run and then executes it.

    Args:
        role: Role name inserted in the prompt. "data ingestion" and "EDA" additionally
            select special handling that stores results back into `context`.
        tools: Mapping from tool name (as shown to the LLM) to the callable to run.
        output_model: Pydantic model used to wrap tool results and error reports.

    Returns:
        ``agent(query, context)``. It returns an `output_model` instance, or a
        `DataIngestionOutput`/`EDAOutput` for the two special roles. `context` may carry
        "parameters" (merged under the LLM's parameters) and app objects such as "df".
    """
    import inspect

    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    def _call_tool(tool: Callable, tool_params: dict, context: dict) -> dict:
        """Call `tool` with the kwargs it accepts, filling same-named args from `context`."""
        signature_params = inspect.signature(tool).parameters
        takes_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in signature_params.values())
        kwargs = {k: v for k, v in tool_params.items() if takes_var_kw or k in signature_params}
        for name, param in signature_params.items():
            if param.kind in variadic:
                continue
            if context.get(name) is not None:
                # App-held objects (DataFrame, EDA results, ...) cannot come from the LLM's
                # JSON, so they win over any same-named LLM parameter.
                kwargs[name] = context[name]
        return tool(**kwargs)

    def agent(query: str, context: dict) -> BaseModel:
        tool_names = list(tools.keys())
        parameters = context.get("parameters", {})

        prompt = (
            f"You are a {role} agent. Available tools: {', '.join(tool_names)}.\n"
            f"User query: {query}\n\n"
            f"Decide which tool to use based on the query. Respond with a JSON object:\n"
            f"- 'action': the name of the tool to use\n"
            f"- 'parameters': dict of parameters needed for the tool\n"
            f"- 'message': explanation of your decision\n"
        )
        llm_response = call_llm(prompt)
        try:
            json_response = extract_json(llm_response)
        except ValueError:
            # Unparseable model output is treated the same as "no decision".
            json_response = {}

        if not isinstance(json_response, dict) or "action" not in json_response:
            return output_model(status="error", message="Failed to process request.")

        action = json_response.get("action")
        llm_params = json_response.get("parameters")
        if not isinstance(llm_params, dict):
            # The LLM sometimes returns a string or list here; ignore it rather than crash.
            llm_params = {}
        tool_params = {**parameters, **llm_params}

        if not isinstance(action, str) or action not in tools:
            return output_model(status="error", message="Failed to process request.")

        try:
            if role == "data ingestion" and action == "load_and_clean_data":
                file_path = tool_params.get("file_path")
                if file_path:
                    result = load_and_clean_data(file_path)
                    succeeded = result["status"] == "success"
                    if succeeded and result["data"] is not None:
                        context["df"] = result["data"]
                    return DataIngestionOutput(
                        status=result["status"],
                        message=result["message"],
                        rows=result["data"].shape[0] if succeeded else None,
                        columns=result["data"].shape[1] if succeeded else None
                    )
            elif role == "EDA" and action == "perform_eda":
                if context.get("df") is not None:
                    result = perform_eda(context["df"])
                    context["eda_results"] = result
                    return EDAOutput(**result)
            else:
                # Generic tools take the DataFrame (and other app objects) from `context`.
                result = _call_tool(tools[action], tool_params, context)
                return output_model(**result)
        except Exception as e:
            return output_model(status="error", message=f"Error executing {action}: {str(e)}")
        return output_model(status="error", message="Failed to process request.")

    return agent

# Extract json function
def extract_json(response: str) -> dict:
    """Best-effort extraction of a JSON object from free-form LLM output.

    Tried in order:
      1. the whole response as JSON;
      2. the outermost ``{...}`` span as JSON;
      3. the same span as a Python literal (handles single-quoted dicts such as
         ``{'target_agent': 'eda'}``, which the orchestrator prompt used to request);
      4. the span after quoting bare or single-quoted keys.

    Args:
        response: Raw text returned by the model.

    Returns:
        The first parse that yields a dict, otherwise ``{}``. Never raises on malformed input.
    """
    import ast  # literal_eval only evaluates literals, so it is safe on untrusted text

    parse_errors = (ValueError, SyntaxError, TypeError, MemoryError, RecursionError)

    try:
        parsed = json.loads(response)
    except parse_errors:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    match = re.search(r'(\{.*\})', response or "", re.DOTALL)
    if not match:
        return {}
    candidate = match.group(1)

    repaired = re.sub(r"'([^']*)':", r'"\1":', candidate)
    repaired = re.sub(r'([{,])\s*([a-zA-Z0-9_]+):', r'\1"\2":', repaired)

    for parse, text in ((json.loads, candidate), (ast.literal_eval, candidate), (json.loads, repaired)):
        try:
            parsed = parse(text)
        except parse_errors:
            continue
        if isinstance(parsed, dict):
            return parsed
    return {}

# Define agents: each pairs a role prompt with the single tool it may call and the
# Pydantic model its results are wrapped in.
data_agent = create_agent(
    role="data ingestion",
    tools={"load_and_clean_data": load_and_clean_data},
    output_model=DataIngestionOutput,
)
eda_agent = create_agent(
    role="EDA",
    tools={"perform_eda": perform_eda},
    output_model=EDAOutput,
)
visualization_agent = create_agent(
    role="visualization",
    tools={"generate_plot": generate_plot},
    output_model=VisualizationOutput,
)
analysis_agent = create_agent(
    role="analysis",
    tools={"calculate_correlation": calculate_correlation},
    output_model=AnalysisOutput,
)
report_agent = create_agent(
    role="report generation",
    tools={"generate_report": generate_report},
    output_model=ReportOutput,
)


In [11]:
# Implement the Orchestrator Agent

# Agent names the orchestrator can route to (the UI handler branches on these).
_ORCHESTRATOR_AGENTS = ("data", "eda", "visualization", "analysis", "report")

# Plot keywords checked in order; the first one found as a substring of the query wins.
_PLOT_KEYWORDS = ("scatter", "line", "histogram", "box", "count", "heatmap")


def _normalize_agent_name(raw) -> str:
    """Map the LLM's `target_agent` value onto one of `_ORCHESTRATOR_AGENTS` where possible.

    Accepts variants like "EDA", "'eda'" or "visualization agent". Anything else is returned
    lower-cased so the caller can see what the model produced.
    """
    name = str(raw).strip().strip("'\"").lower()
    if name in _ORCHESTRATOR_AGENTS:
        return name
    first_word = re.split(r"[^a-z]+", name, maxsplit=1)[0]
    return first_word if first_word in _ORCHESTRATOR_AGENTS else name


def _mentioned_columns(query: str, df) -> list:
    """Return columns of `df` named in `query`, ordered by where they first appear in it.

    Matching is case-insensitive and the name must not be embedded in a longer word, so a
    column called "a" is not matched by every query containing that letter.
    """
    if df is None:
        return []
    query_lower = query.lower()
    hits = []
    for col in df.columns:
        needle = str(col).lower()
        if not needle:
            continue
        match = re.search(r"(?<!\w)" + re.escape(needle) + r"(?!\w)", query_lower)
        if match:
            hits.append((match.start(), col))
    return [col for _, col in sorted(hits, key=lambda hit: hit[0])]


def orchestrator_agent(query: str, context: dict) -> OrchestratorOutput:
    """Ask the LLM which agent should handle `query` and extract tool parameters from it.

    Args:
        query: The user's free-text request.
        context: App state; only ``context.get("df")`` (a DataFrame or None) is read.

    Returns:
        An `OrchestratorOutput`. On an unparseable LLM reply the status is "error" and
        target_agent is "none". Parameters may contain "file_path" (data), "plot_type",
        "x_col", "y_col" (visualization), or "col1", "col2" (analysis).
    """
    # The key may be present with value None, so never call methods on it unguarded.
    df = context.get("df")
    has_data = df is not None

    prompt = (
        f"You are a data science orchestration system. Based on the user query, determine which "
        f"specialized agent should handle the request.\n\n"
        f"User query: \"{query}\"\n"
        f"Data loaded: {'Yes' if has_data else 'No'}\n\n"
        f"Available agents:\n"
        f"- 'data': For loading datasets\n"
        f"- 'eda': For exploratory data analysis\n"
        f"- 'visualization': For creating plots\n"
        f"- 'analysis': For statistical analysis\n"
        f"- 'report': For generating reports (automatically perform prior steps if needed)\n\n"
        # Ask for valid (double-quoted) JSON, consistent with call_llm's instructions.
        'Respond with a JSON object: {"target_agent": "name", "message": "explanation"}'
    )

    llm_response = call_llm(prompt)
    try:
        json_response = extract_json(llm_response)
    except ValueError:
        json_response = {}

    if not isinstance(json_response, dict) or "target_agent" not in json_response:
        return OrchestratorOutput(status="error", message="Invalid request.", target_agent="none", parameters={})

    target_agent = _normalize_agent_name(json_response.get("target_agent", "none"))
    message = str(json_response.get("message", "Processing..."))
    parameters = {}

    if target_agent == "data":
        file_path_match = re.search(r'(/\S+\.\w+)', query)
        if file_path_match:
            parameters["file_path"] = file_path_match.group(1)
    elif target_agent == "visualization":
        query_lower = query.lower()
        plot_type = next((keyword for keyword in _PLOT_KEYWORDS if keyword in query_lower), "auto")
        parameters["plot_type"] = plot_type

        columns = _mentioned_columns(query, df)
        if len(columns) >= 1:
            parameters["x_col"] = columns[0]
        if len(columns) >= 2:
            parameters["y_col"] = columns[1]

        # If no columns specified but plot type requires them, use numeric defaults.
        if not parameters.get("x_col") and plot_type in ["scatter", "line", "histogram", "box", "count"] and df is not None:
            numeric_cols = df.select_dtypes(include="number").columns
            parameters["x_col"] = numeric_cols[0] if len(numeric_cols) > 0 else None
            if plot_type in ["scatter", "line"] and len(numeric_cols) > 1:
                parameters["y_col"] = numeric_cols[1]
    elif target_agent == "analysis":
        columns = _mentioned_columns(query, df)
        if len(columns) >= 2:
            parameters["col1"] = columns[0]
            parameters["col2"] = columns[1]

    return OrchestratorOutput(status="success", message=message, target_agent=target_agent, parameters=parameters)


In [12]:
# Main Inference with Gradio

def _uploaded_file_path(file) -> str:
    """Return the filesystem path of a Gradio upload.

    Depending on the Gradio version, `gr.File` delivers either a plain path string or a
    tempfile-like object exposing `.name`; both are accepted.
    """
    return file if isinstance(file, str) else file.name


def _format_mean(value) -> str:
    """Format a describe() mean with two decimals, or "N/A" when missing/non-numeric."""
    import numbers

    return f"{value:.2f}" if isinstance(value, numbers.Real) else "N/A"


def _eda_lines(context: dict) -> list:
    """Run EDA on context["df"], store it in context["eda_results"], return output lines."""
    result = perform_eda(context["df"])
    context["eda_results"] = result
    lines = [f"Assistant: {result['message']}"]
    if result["status"] == "success":
        lines.append("Key Insights:")
        for col, col_stats in result["summary_stats"].items():
            if isinstance(col_stats, dict):
                missing = result['missing_values'].get(col, 0)
                lines.append(f"  {col}: mean={_format_mean(col_stats.get('mean'))}, missing={missing}")
    return lines


def _visualization_outputs(context: dict, parameters: dict) -> tuple:
    """Generate the requested plot; return (output lines, image paths)."""
    result = generate_plot(context["df"], **parameters)
    context["plot_results"].append(result["message"])
    # Use the path generate_plot reports: globbing on the requested type misses "auto"
    # plots (the file is named after the resolved type) and glob order is arbitrary.
    images = [result["file"]] if result["status"] == "success" and result.get("file") else []
    return [f"Assistant: {result['message']}"], images


def _analysis_lines(context: dict, parameters: dict) -> list:
    """Run the correlation analysis, store it in context["analysis_results"], return lines."""
    result = calculate_correlation(context["df"], **parameters)
    context["analysis_results"] = result
    if result["status"] != "success":
        return [f"Assistant: {result['message']}"]
    lines = [
        f"Assistant: Correlation between {result['col1']} and {result['col2']} is {result['correlation']:.4f}",
        "Full Correlation Matrix:",
    ]
    for col1, corrs in result["corr_matrix"].items():
        for col2, corr in corrs.items():
            lines.append(f"  {col1} vs {col2}: {corr:.4f}")
    return lines


def _report_outputs(context: dict) -> tuple:
    """Fill in any missing EDA/analysis/plot results, build the report; return (lines, images)."""
    lines, images = [], []
    df = context["df"]
    if not context["eda_results"]:
        context["eda_results"] = perform_eda(df)
        lines.append("Assistant: Performed EDA for report.")
    if not context["analysis_results"]:
        context["analysis_results"] = calculate_correlation(df)
        lines.append("Assistant: Performed correlation analysis for report.")
    if not context["plot_results"]:
        plot_result = generate_plot(df)
        context["plot_results"].append(plot_result["message"])
        lines.append("Assistant: Generated a default plot for report.")
        if plot_result.get("file"):
            images.append(plot_result["file"])
    result = generate_report(df, context["eda_results"], context["analysis_results"], context["plot_results"])
    lines.append(f"Assistant: {result['message']}")
    lines.append(result["report"])
    return lines, images


def run_data_science_assistant_gradio(query, file=None):
    """Gradio handler: optionally load an uploaded CSV, then route `query` to one agent.

    Args:
        query: Free-text request from the textbox; may be empty.
        file: Value of the `gr.File` component, or None when nothing was uploaded.

    Returns:
        (text, images): the newline-joined transcript and a list of plot file paths for
        the gallery.
    """
    context = {"df": None, "eda_results": {}, "analysis_results": {}, "plot_results": []}
    output_text = []
    output_images = []

    # Handle file upload if provided
    if file is not None:
        result = load_and_clean_data(_uploaded_file_path(file))
        output_text.append(f"Data Agent: {result['message']}")
        if result["status"] != "success":
            return "\n".join(output_text), []
        context["df"] = result["data"]

    if not query:
        return "\n".join(output_text), output_images

    orch_output = orchestrator_agent(query, context)
    output_text.append(f"Orchestrator: {orch_output.message}")
    target_agent = orch_output.target_agent
    parameters = orch_output.parameters
    has_data = context["df"] is not None and not context["df"].empty

    if target_agent == "data":
        if file is None:
            output_text.append("Assistant: Please upload a dataset file first.")
    elif target_agent == "report" and not has_data:
        output_text.append("Assistant: No data loaded. Please upload a dataset first.")
    elif target_agent in ("eda", "visualization", "analysis") and not has_data:
        output_text.append("Assistant: Please upload a dataset first.")
    elif target_agent == "eda":
        output_text.extend(_eda_lines(context))
    elif target_agent == "visualization":
        lines, images = _visualization_outputs(context, parameters)
        output_text.extend(lines)
        output_images.extend(images)
    elif target_agent == "analysis":
        output_text.extend(_analysis_lines(context, parameters))
    elif target_agent == "report":
        lines, images = _report_outputs(context)
        output_text.extend(lines)
        output_images.extend(images)

    return "\n".join(output_text), output_images

# Gradio Interface
_APP_TITLE = "Multi-Agent System for Data Analysis"
_APP_HELP = (
    "Upload a dataset and enter a query to perform data science tasks. "
    "Examples: 'perform EDA', 'scatter plot', 'correlation analysis', 'generate report'."
)

with gr.Blocks(title=_APP_TITLE) as demo:
    gr.Markdown(f"# {_APP_TITLE}")
    gr.Markdown(_APP_HELP)

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            file_input = gr.File(label="Upload Dataset (CSV)")
            query_input = gr.Textbox(
                label="Enter your query",
                placeholder="e.g., 'perform EDA' or 'scatter plot'",
            )
            submit_btn = gr.Button("Submit")
        # Right column: transcript and generated figures.
        with gr.Column():
            output_text = gr.Textbox(label="Output", lines=20)
            output_gallery = gr.Gallery(label="Visualizations")

    # The handler returns (text, image paths), matching the two output components in order.
    submit_btn.click(
        fn=run_data_science_assistant_gradio,
        inputs=[query_input, file_input],
        outputs=[output_text, output_gallery],
    )

# Launch the Gradio app. share=True also publishes a temporary public URL, so anyone with
# the link can upload files and run the model.
demo.launch(share=True)


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://1ca1599af3a95d63cc.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


