# In [None]:
import os
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any, Optional, Union
import json
import requests
from pydantic import BaseModel, Field
import pydantic_ai as pai
from io import BytesIO
import base64
import traceback

# Define the Pydantic AI models for structured LLM responses
class GraphRequest(BaseModel):
    """A request to create a graph from CSV data."""

    # Required: plot kind to draw.
    graph_type: str = Field(
        default=...,
        description="Type of graph to create (e.g., 'bar', 'line', 'scatter', 'histogram', 'boxplot', 'heatmap')",
    )
    # Required: column plotted on the x-axis.
    x_column: str = Field(
        default=...,
        description="Column name for the x-axis",
    )
    # Optional: column plotted on the y-axis.
    y_column: Optional[str] = Field(
        default=None,
        description="Column name for the y-axis (optional for some plot types)",
    )
    # Required: title rendered above the plot.
    title: str = Field(
        default=...,
        description="Title of the graph",
    )
    # Optional: column used for colour grouping.
    hue: Optional[str] = Field(
        default=None,
        description="Column name for color grouping (optional)",
    )
    
class CSVQuestion(BaseModel):
    """A query about CSV data that can be answered with text or require a graph."""

    # Required discriminator: which of the two payload fields below is meaningful.
    question_type: str = Field(
        default=...,
        description="Either 'text' for a text answer or 'graph' for a visualization",
    )
    # Populated for textual answers.
    answer: Optional[str] = Field(
        default=None,
        description="The text answer if question_type is 'text'",
    )
    # Populated for visualization answers.
    graph_request: Optional[GraphRequest] = Field(
        default=None,
        description="The graph specification if question_type is 'graph'",
    )

# Ollama LLM client
class OllamaClient:
    """Thin client that posts prompts to an Ollama ``/api/generate`` endpoint.

    Attributes:
        model_name: Model identifier sent in the request payload.
        api_url: URL the generate request is posted to.
        timeout: Value forwarded to ``requests.post(timeout=...)``; ``None``
            (the default) keeps the library default of waiting indefinitely.
    """

    def __init__(self, model_name="llama3.1-8b-q4_0", timeout=None):
        # NOTE: this must be ``__init__`` (not ``init``) so that Python runs
        # it on construction; otherwise the attributes below never exist.
        self.model_name = model_name
        self.api_url = "http://localhost:11434/api/generate"
        self.timeout = timeout

    def generate(self, prompt, system_prompt=None):
        """Send ``prompt`` to the model and return the generated text.

        Args:
            prompt: The user prompt.
            system_prompt: Optional system prompt; added to the payload under
                the ``"system"`` key only when truthy.

        Returns:
            The ``"response"`` field of the JSON reply (``""`` if absent), or
            a string starting with ``"Error generating response: "`` when the
            request or decoding fails. This method never raises.
        """
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": False,  # ask for one complete JSON body, not a stream
        }

        if system_prompt:
            payload["system"] = system_prompt

        try:
            response = requests.post(self.api_url, json=payload, timeout=self.timeout)
            response.raise_for_status()
            return response.json().get("response", "")
        except Exception as e:
            # Deliberate best-effort: callers receive the failure as text.
            return f"Error generating response: {str(e)}"

# CSV Data Processor
class CSVDataProcessor:
    """Loads a CSV into a DataFrame and maintains a summary dictionary of it.

    Attributes:
        df: The loaded ``pandas.DataFrame``, or ``None`` before any load.
        file_path: The argument last passed successfully to ``pd.read_csv``.
        columns: List of column labels of ``df``.
        summary: Dictionary built by :meth:`generate_summary`.
    """

    def __init__(self):
        # NOTE: must be ``__init__`` so a fresh instance starts with
        # ``df is None`` instead of lacking the attribute entirely.
        self.df = None
        self.file_path = None
        self.columns = []
        self.summary = {}

    def load_csv(self, file_path):
        """Read ``file_path`` with ``pd.read_csv`` and rebuild the summary.

        Returns:
            ``(True, message)`` on success, ``(False, message)`` if reading or
            summarising raised; the exception text is embedded in the message.
        """
        try:
            self.df = pd.read_csv(file_path)
            self.file_path = file_path
            self.columns = list(self.df.columns)
            self.generate_summary()
            return True, "CSV file loaded successfully."
        except Exception as e:
            return False, f"Error loading CSV file: {str(e)}"

    @staticmethod
    def _to_float(value):
        """Return ``value`` as a float, or ``None`` if missing or not convertible."""
        if pd.isna(value):
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def generate_summary(self):
        """Generate a summary of the CSV data.

        Populates ``self.summary`` with row/column counts, dtypes, the first
        five rows, numeric and categorical column lists, missing-value counts
        and per-column statistics. Does nothing when no data is loaded.
        """
        if self.df is None:
            return

        df = self.df
        # 'number' covers every numeric dtype rather than only int64/float64.
        numeric_columns = list(df.select_dtypes(include=['number']).columns)
        categorical_columns = list(df.select_dtypes(include=['object', 'category']).columns)

        self.summary = {
            "num_rows": len(df),
            "num_columns": len(self.columns),
            "column_types": {col: str(df[col].dtype) for col in self.columns},
            "sample_data": df.head(5).to_dict(orient="records"),
            "numeric_columns": numeric_columns,
            "categorical_columns": categorical_columns,
            # Plain ints keep the summary JSON-serialisable as numbers.
            "missing_values": {col: int(count) for col, count in df.isna().sum().items()},
        }

        statistics = {}

        # Basic statistics for numeric columns (each aggregate computed once).
        for col in numeric_columns:
            series = df[col]
            statistics[col] = {
                "min": self._to_float(series.min()),
                "max": self._to_float(series.max()),
                "mean": self._to_float(series.mean()),
                "median": self._to_float(series.median()),
            }

        # Basic statistics for categorical columns: cardinality and top five values.
        for col in categorical_columns:
            top_counts = df[col].value_counts().head(5)
            statistics[col] = {
                "unique_values": df[col].nunique(),
                "top_values": {str(k): int(v) for k, v in top_counts.items()},
            }

        self.summary["statistics"] = statistics

# LLM Query Processor
class LLMQueryProcessor:
    """Turns a natural-language question into a ``CSVQuestion`` via the LLM.

    The processor sends the CSV summary and the user's question to an
    ``OllamaClient`` and parses the JSON reply into a ``CSVQuestion``.
    """

    def __init__(self, csv_processor: CSVDataProcessor):
        # NOTE: must be ``__init__``; with a plain ``init`` method, passing
        # ``csv_processor`` to the constructor raises ``TypeError``.
        self.csv_processor = csv_processor
        self.llm_client = OllamaClient()
        self.system_prompt = """
        You are a data analysis assistant that interprets questions about CSV data and provides 
        answers based on analysis of the data. For each question, determine whether the answer 
        should be textual or requires a graph.
        
        If the question requires analysis that can be answered with text, generate a concise, accurate 
        answer based on the data.
        
        If the question would benefit from a visualization, specify the appropriate graph type, 
        columns, and styling parameters.
        
        Always return your response as valid JSON matching the CSVQuestion schema.
        """

    @staticmethod
    def _extract_json(text: str) -> str:
        """Return the outermost ``{...}`` span of ``text``, or ``text`` stripped.

        Tolerates replies that wrap the JSON object in code fences or prose.
        """
        cleaned = text.strip()
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start != -1 and end > start:
            return cleaned[start:end + 1]
        return cleaned

    def process_query(self, query: str) -> CSVQuestion:
        """Ask the LLM about the loaded CSV and return a structured result.

        Args:
            query: The user's question.

        Returns:
            A ``CSVQuestion`` with ``question_type`` ``"graph"`` (and a
            ``graph_request``) when the reply asks for a graph, otherwise
            ``"text"`` with an ``answer``. Failures are reported as a text
            answer; this method does not raise.
        """
        if self.csv_processor.df is None:
            return CSVQuestion(
                question_type="text",
                answer="Please upload a CSV file first."
            )
        
        try:
            # Prepare context about the data
            df_info = json.dumps(self.csv_processor.summary, default=str)
            prompt = f"""
            # CSV Data Information
            {df_info}
            
            # User Question
            {query}
            
            # Response Instructions
            Analyze the question and the CSV data information. Then, respond with a valid JSON object 
            matching the CSVQuestion schema which includes either:
            1. For analysis questions: question_type="text" and answer=<your analytical answer>
            2. For visualization questions: question_type="graph" and graph_request containing graph_type, x_column, y_column (if applicable), title, and hue (if applicable)
            
            The graph types supported are: bar, line, scatter, histogram, boxplot, heatmap.
            
            # CSVQuestion Schema
            
            class GraphRequest:
                graph_type: str  # Type of graph (bar, line, scatter, histogram, boxplot, heatmap)
                x_column: str  # Column for x-axis
                y_column: Optional[str]  # Column for y-axis (if applicable)
                title: str  # Graph title
                hue: Optional[str]  # Column for color grouping (if applicable)
                
            class CSVQuestion:
                question_type: str  # Either 'text' or 'graph'
                answer: Optional[str]  # Text answer if question_type is 'text'
                graph_request: Optional[GraphRequest]  # Graph specification if question_type is 'graph'
            
            
            Return only the JSON object, without any additional text:
            """
            
            # Get response from LLM
            llm_response = self.llm_client.generate(prompt, self.system_prompt)

            # The client reports transport failures as a plain string with
            # this prefix; pass it through instead of failing in json.loads.
            if llm_response.startswith("Error generating response:"):
                return CSVQuestion(question_type="text", answer=llm_response)

            # Parse the response as JSON, tolerating fences/prose around it
            response_dict = json.loads(self._extract_json(llm_response))
            if not isinstance(response_dict, dict):
                raise ValueError("LLM response is not a JSON object")
            
            # Construct the CSVQuestion model
            if response_dict.get("question_type") == "graph" and response_dict.get("graph_request"):
                graph_data = response_dict["graph_request"]
                graph_request = GraphRequest(
                    graph_type=graph_data["graph_type"],
                    x_column=graph_data["x_column"],
                    y_column=graph_data.get("y_column"),
                    title=graph_data["title"],
                    hue=graph_data.get("hue")
                )
                return CSVQuestion(
                    question_type="graph",
                    graph_request=graph_request
                )

            # ``or`` also covers an explicit JSON null answer.
            return CSVQuestion(
                question_type="text",
                answer=response_dict.get("answer") or "I couldn't analyze this question properly."
            )
            
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Error processing query: {error_details}")
            return CSVQuestion(
                question_type="text",
                answer=f"Error processing your query: {str(e)}"
            )

# Graph Generator
class GraphGenerator:
    """Renders a ``GraphRequest`` against the loaded DataFrame as a PNG."""

    # Plot kinds handled by ``create_graph``.
    _SUPPORTED_TYPES = ("bar", "line", "scatter", "histogram", "boxplot", "heatmap")

    def __init__(self, csv_processor: CSVDataProcessor):
        # NOTE: must be ``__init__``; with a plain ``init`` method, passing
        # ``csv_processor`` to the constructor raises ``TypeError``.
        self.csv_processor = csv_processor

    def create_graph(self, graph_request: GraphRequest) -> Optional[str]:
        """Draw the requested graph and return it as base64-encoded PNG text.

        Args:
            graph_request: Plot specification. Its ``hue`` is reset to
                ``None`` in place when it names a column that does not exist.

        Returns:
            The PNG bytes encoded as a base64 string, or ``None`` when no data
            is loaded, a required column is missing, the graph type is not
            supported, a scatter plot lacks a y column, or plotting raised.
        """
        df = self.csv_processor.df
        if df is None:
            return None

        figure = None
        try:
            x_col = graph_request.x_column
            y_col = graph_request.y_column

            # Ensure columns exist in dataframe (before any figure is created)
            if x_col not in df.columns:
                return None
            if y_col and y_col not in df.columns:
                return None
            if graph_request.hue and graph_request.hue not in df.columns:
                graph_request.hue = None
            hue_col = graph_request.hue if graph_request.hue else None

            # Accept e.g. "Bar" or " line " from the LLM.
            kind = graph_request.graph_type.strip().lower()
            if kind not in self._SUPPORTED_TYPES:
                print(f"Error creating graph: unsupported graph type {graph_request.graph_type!r}")
                return None
            if kind == "scatter" and not y_col:
                return None  # Scatter plot requires both x and y

            figure = plt.figure(figsize=(10, 6))

            if kind == "bar":
                if y_col:
                    sns.barplot(data=df, x=x_col, y=y_col, hue=hue_col)
                else:
                    df[x_col].value_counts().plot(kind='bar')

            elif kind == "line":
                if y_col:
                    sns.lineplot(data=df, x=x_col, y=y_col, hue=hue_col)
                else:
                    df[x_col].plot(kind='line')

            elif kind == "scatter":
                sns.scatterplot(data=df, x=x_col, y=y_col, hue=hue_col)

            elif kind == "histogram":
                sns.histplot(df[x_col], kde=True)

            elif kind == "boxplot":
                if y_col:
                    sns.boxplot(data=df, x=x_col, y=y_col, hue=hue_col)
                else:
                    sns.boxplot(x=df[x_col])

            else:  # heatmap
                if y_col:
                    # Pivot x against y; cell values come from ``hue`` when
                    # given, otherwise from the first numeric column.
                    heatmap_data = df.pivot_table(
                        index=x_col,
                        columns=y_col,
                        values=hue_col if hue_col else df.select_dtypes(include=['number']).columns[0],
                        aggfunc='mean'
                    )
                    sns.heatmap(heatmap_data, annot=True, cmap="YlGnBu")
                else:
                    # If only x is provided, create a correlation heatmap of numeric columns
                    correlation = df.select_dtypes(include=['number']).corr()
                    sns.heatmap(correlation, annot=True, cmap="coolwarm")

            plt.title(graph_request.title)
            plt.xticks(rotation=45)
            plt.tight_layout()

            # Convert plot to base64 image
            buffer = BytesIO()
            plt.savefig(buffer, format='png')
            return base64.b64encode(buffer.getvalue()).decode('utf-8')

        except Exception as e:
            print(f"Error creating graph: {str(e)}")
            return None
        finally:
            # Always release the figure, including on errors, so repeated
            # requests do not accumulate open figures.
            if figure is not None:
                plt.close(figure)

# Gradio Application
class GradioCSVApp:
    """Gradio front end wiring CSV upload, LLM querying and graph rendering."""

    def __init__(self):
        # NOTE: must be ``__init__`` so the collaborators below exist before
        # any event handler runs.
        self.csv_processor = CSVDataProcessor()
        self.llm_processor = LLMQueryProcessor(self.csv_processor)
        self.graph_generator = GraphGenerator(self.csv_processor)

    @staticmethod
    def _decode_image(encoded_png):
        """Decode a base64 PNG string into a uint8 pixel array.

        ``gr.Image`` interprets a plain string as a file path or URL, so the
        base64 text produced by the graph generator is decoded first.
        """
        pixels = plt.imread(BytesIO(base64.b64decode(encoded_png)), format="png")
        # matplotlib returns PNG data as floats in [0, 1]; scale to 0-255.
        return (pixels * 255).round().astype("uint8")

    def upload_csv(self, file):
        """Load the uploaded file and describe it.

        Args:
            file: Value delivered by the ``gr.File`` input: an object exposing
                the path as ``.name``, a plain path string, or ``None``.

        Returns:
            ``(status_message, None)``; the second item clears the image output.
        """
        if file is None:
            return "Please upload a CSV file.", None

        # Works for both a file-like wrapper (``.name``) and a bare path.
        path = getattr(file, "name", file)
        success, message = self.csv_processor.load_csv(path)
        if not success:
            return f"❌ {message}", None

        # Generate and return summary information
        columns = ", ".join(self.csv_processor.columns)
        num_rows = self.csv_processor.summary["num_rows"]
        num_cols = self.csv_processor.summary["num_columns"]
        return f"✅ CSV loaded successfully: {num_rows} rows, {num_cols} columns.\n\nColumns: {columns}", None

    def process_question(self, question, state):
        """Answer ``question`` with text or a rendered graph.

        Args:
            question: The user's question.
            state: Session state; returned unchanged.

        Returns:
            ``(answer_text, image_or_None, state)``. Always a 3-tuple, matching
            the three outputs wired to the submit button.
        """
        if self.csv_processor.df is None:
            return "Please upload a CSV file first.", None, state

        try:
            # Process the question using the LLM
            result = self.llm_processor.process_query(question)

            if result.question_type == "text":
                return result.answer, None, state

            if result.question_type == "graph":
                # Generate the graph from the request
                graph_image = self.graph_generator.create_graph(result.graph_request)

                if not graph_image:
                    return "Failed to create the requested graph. Please check your question and try again.", None, state

                request = result.graph_request
                graph_description = (
                    f"📊 Graph: {request.title}\n"
                    f"Type: {request.graph_type}\n"
                    f"X-axis: {request.x_column}\n"
                    f"Y-axis: {request.y_column if request.y_column else 'N/A'}\n"
                    f"Color grouping: {request.hue if request.hue else 'N/A'}"
                )
                return graph_description, self._decode_image(graph_image), state

            # Unexpected question type: still return all three outputs.
            fallback = result.answer or "I couldn't determine how to answer this question."
            return fallback, None, state

        except Exception as e:
            return f"Error: {str(e)}", None, state

    def launch(self):
        """Build the Blocks UI, register the handlers and start the server.

        Returns:
            Whatever ``app.launch(share=False)`` returns.
        """
        with gr.Blocks(title="CSV Question Answering & Visualization", theme=gr.themes.Soft()) as app:
            gr.Markdown("# CSV Question Answering & Visualization")
            gr.Markdown("Upload a CSV file, then ask questions about the data. The system can provide text answers or generate visualizations.")
            
            with gr.Row():
                with gr.Column(scale=1):
                    file_input = gr.File(label="Upload CSV File (max 25MB)")
                    upload_button = gr.Button("Upload and Process")
                    file_info = gr.Textbox(label="File Information", interactive=False)
                    
                with gr.Column(scale=2):
                    question_input = gr.Textbox(label="Ask a question about your data", placeholder="e.g., What is the average price? or Show me a histogram of prices")
                    submit_button = gr.Button("Submit Question")
                    
                    answer_output = gr.Textbox(label="Answer", interactive=False)
                    graph_output = gr.Image(label="Visualization", interactive=False)
            
            # Add state to maintain context
            state = gr.State({})
            
            # Set up event handlers
            upload_button.click(
                fn=self.upload_csv,
                inputs=[file_input],
                outputs=[file_info, graph_output]
            )
            
            submit_button.click(
                fn=self.process_question,
                inputs=[question_input, state],
                outputs=[answer_output, graph_output, state]
            )
            
            # Examples
            gr.Examples(
                examples=[
                    ["What is the average price?"],
                    ["Show me a histogram of prices"],
                    ["What is the correlation between square footage and price?"],
                    ["Show me a scatter plot of price vs. square footage"],
                    ["What are the top 5 most expensive neighborhoods?"],
                    ["Create a bar chart showing average price by neighborhood"]
                ],
                inputs=question_input
            )
            
            gr.Markdown("""
            ## Tips for asking questions:
            - Ask for statistics: "What is the average/median/max of [column]?"
            - Ask for correlations: "Is there a correlation between [column1] and [column2]?"
            - Request visualizations: "Show me a [graph type] of [columns]"
            - Ask for trends: "How does [column1] change with [column2]?"
            """)
        
        return app.launch(share=False)

# Main application entry point
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    # (The dunder names are required: ``_name_`` is undefined and would raise
    # NameError as soon as the module is loaded.)
    app = GradioCSVApp()
    app.launch()