In [57]:
# !pip install gradio

In [58]:
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from pydantic import BaseModel, ValidationError

In [59]:
# Pydantic schema that validates the free-text question before it is parsed
class QueryModel(BaseModel):
    """Validated container for the question typed into the app.

    Construction raises ``pydantic.ValidationError`` when ``question`` cannot
    be validated as a string (for example when it is ``None``).
    """

    # Raw user text; answer_query lower-cases it before keyword matching.
    question: str

In [60]:
# Load an uploaded CSV and normalize its column names
def load_csv(file):
    """Read an uploaded CSV into a DataFrame with normalized column names.

    Parameters
    ----------
    file : str | os.PathLike | object with a ``.name`` path | file-like | None
        Whatever the Gradio ``File`` component hands to the callback. With
        ``type="filepath"`` this is a plain path string. Tempfile-style upload
        wrappers expose the path as ``.name``, and file-like objects without
        ``.name`` are passed straight to ``pd.read_csv``.

    Returns
    -------
    pandas.DataFrame | str
        The parsed frame, with column names converted to str, stripped and
        lower-cased. On failure, or when nothing was uploaded, a message
        starting with ``"Error loading CSV:"`` is returned. Callers tell the
        two apart with ``isinstance(result, str)``.
    """
    import os

    if file is None:
        return "Error loading CSV: no file uploaded."

    # A str/PathLike already *is* the path. pathlib.Path also has a `.name`,
    # but that is only the basename, so it must not be unwrapped.
    if isinstance(file, (str, os.PathLike)):
        source = file
    else:
        source = getattr(file, "name", file)

    try:
        df = pd.read_csv(source)
    except Exception as e:  # best-effort: surface any parse/IO error as text in the UI
        return f"Error loading CSV: {e}"

    # Normalize names so later lookups are case- and padding-insensitive.
    df.columns = df.columns.astype(str).str.strip().str.lower()
    return df



In [61]:
# Answer simple statistics questions about any CSV file
def _match_operation(question):
    """Return the operation name for the first keyword group found in ``question``, or None.

    Matching is plain substring search on the (already lower-cased) question.
    The table order reproduces the original if/elif priority, so e.g. a
    question containing both "average" and "count" is treated as "mean".
    """
    keyword_table = (
        (("average", "mean"), "mean"),
        (("max",), "max"),
        (("min",), "min"),
        (("total", "count"), "count"),
        (("standard deviation", "std"), "std"),
        (("most expensive", "highest value"), "highest_row"),
        (("cheapest", "lowest value"), "lowest_row"),
    )
    for keywords, operation in keyword_table:
        if any(word in question for word in keywords):
            return operation
    return None


def answer_query(file, question):
    """Answer a keyword-driven statistics question about an uploaded CSV.

    Parameters
    ----------
    file :
        The upload, passed through unchanged to ``load_csv``.
    question : str
        Free-text question. It is validated with ``QueryModel``, then
        lower-cased and scanned for keywords (average/mean, max, min,
        total/count, standard deviation/std, most expensive/highest value,
        cheapest/lowest value), first match wins in that order.

    Returns
    -------
    str
        The answer, or an explanatory message: load error, validation error,
        unrecognized query, no numeric columns, ambiguous column, or a column
        with no numeric values.
    """
    df = load_csv(file)
    if isinstance(df, str):
        return df  # load_csv reports failures as a message string

    try:
        validated_query = QueryModel(question=question)
    except ValidationError as e:
        return f"Invalid question format: {e}"

    question = validated_query.question.lower()

    operation = _match_operation(question)
    if operation is None:
        return "Query not recognized. Try asking about average, max, min, total, or standard deviation."

    # Counting rows needs no numeric data, so answer it before the
    # numeric-column check (a text-only CSV can still be counted).
    if operation == "count":
        return f"Total Rows: {len(df)}"

    numeric_cols = df.select_dtypes(include=["number"]).columns
    if numeric_cols.empty:
        return "No numeric columns found in the file. Cannot compute statistical queries."

    col_name = extract_column_name(question, numeric_cols)
    if not col_name:
        return f"Available numeric columns: {', '.join(numeric_cols)}"

    values = df[col_name]
    if operation == "mean":
        return f"Average {col_name}: {values.mean():,.2f}"
    if operation == "max":
        return f"Max {col_name}: {values.max():,.2f}"
    if operation == "min":
        return f"Min {col_name}: {values.min():,.2f}"
    if operation == "std":
        return f"Standard Deviation of {col_name}: {values.std():,.2f}"

    # Row lookups: idxmax/idxmin cannot locate a row in an all-NaN column,
    # so report that instead of crashing the callback.
    if values.notna().sum() == 0:
        return f"No numeric values found in column {col_name}."
    if operation == "highest_row":
        max_row = df.loc[values.idxmax()]
        return f"Most Expensive (Highest {col_name}): {max_row.to_dict()}"
    min_row = df.loc[values.idxmin()]
    return f"Cheapest (Lowest {col_name}): {min_row.to_dict()}"

# Helper function to pick the column a question refers to
def extract_column_name(question, numeric_cols):
    """Pick the numeric column that ``question`` mentions.

    A column matches when its name, or its name with underscores read as
    spaces (``sale_price`` -> ``sale price``), occurs in ``question``. When
    several columns match, the longest name wins, so a ``price_per_sqft``
    column is not shadowed by a shorter ``price`` column. Ties keep column
    order. If nothing matches and there is exactly one candidate column,
    that column is assumed.

    Parameters
    ----------
    question : str
        The question text; expected to be lower-cased like the column names.
    numeric_cols : sequence of str
        Candidate column names, e.g. a pandas ``Index``.

    Returns
    -------
    str | None
        The chosen column name, or None when the choice is ambiguous.
    """
    matches = [
        col
        for col in numeric_cols
        if col in question or col.replace("_", " ") in question
    ]
    if matches:
        return max(matches, key=len)
    return numeric_cols[0] if len(numeric_cols) == 1 else None


In [62]:
# Resolve user-typed column names and plot a graph dynamically
def _resolve_column(columns, requested):
    """Return the entry of ``columns`` that a user-typed name refers to, or None.

    Matching is case-insensitive, ignores surrounding whitespace, and treats
    spaces and underscores as interchangeable, so "Sale Price" finds a column
    named either ``sale price`` or ``sale_price``.
    """
    if requested is None:
        return None
    key = str(requested).strip().lower()
    if not key:
        return None
    underscored = key.replace(" ", "_")
    for col in columns:
        col_text = str(col)
        if col_text in (key, underscored) or col_text.lower().replace(" ", "_") == underscored:
            return col
    return None


def plot_graph(file, x_col, y_col, plot_type):
    """Build a matplotlib chart from two CSV columns for the Gradio ``Plot`` output.

    Parameters
    ----------
    file :
        The upload, passed through unchanged to ``load_csv``.
    x_col, y_col : str
        User-typed column names, matched via ``_resolve_column``. ``y_col`` is
        coerced to numeric (unparseable values become NaN and those rows are
        dropped). ``x_col`` is coerced too unless it is an object/text column.
        For ``"histogram"`` only ``y_col`` is used.
    plot_type : str
        One of ``"line"``, ``"scatter"``, ``"bar"``, ``"histogram"``.

    Returns
    -------
    matplotlib.figure.Figure | str
        The figure, or a message string on error. The figure is closed in
        pyplot before returning, so repeated clicks don't accumulate open
        figures. The caller renders the Figure object itself.
        NOTE(review): ``gr.Plot`` is not meant to display plain strings;
        raising ``gr.Error`` may surface these messages better. Confirm.
    """
    df = load_csv(file)
    if isinstance(df, str):
        return df

    # Normalized forms of the user input, used only in the error message.
    shown_x = str(x_col or "").strip().lower().replace(" ", "_")
    shown_y = str(y_col or "").strip().lower().replace(" ", "_")

    is_histogram = plot_type == "histogram"
    x_name = _resolve_column(df.columns, x_col)
    y_name = _resolve_column(df.columns, y_col)
    if y_name is None or (not is_histogram and x_name is None):
        return f"Invalid column names: {shown_x} or {shown_y} not found in dataset."

    try:
        df[y_name] = pd.to_numeric(df[y_name], errors="coerce")
        if is_histogram:
            required = [y_name]  # x is irrelevant; its NaNs must not drop y values
        else:
            if df[x_name].dtype != "O":
                df[x_name] = pd.to_numeric(df[x_name], errors="coerce")
            required = [x_name, y_name]
        df = df.dropna(subset=required)
    except Exception as e:
        return f"Error processing columns: {str(e)}"

    # Validate before creating a figure so an early return cannot leak one.
    if plot_type not in ("line", "scatter", "bar", "histogram"):
        return "Invalid plot type. Choose from: line, scatter, bar, histogram."

    fig, ax = plt.subplots(figsize=(8, 5))
    y_label = str(y_name).replace("_", " ").title()

    if is_histogram:
        ax.hist(df[y_name], bins=10, alpha=0.7, color="black")
        ax.set_xlabel(y_label)
        ax.set_ylabel("Frequency")
        ax.set_title(f"Distribution of {y_label}")
    else:
        if plot_type == "line":
            # Sort numeric x so the connecting line doesn't zig-zag through unsorted rows.
            line_df = df.sort_values(x_name) if pd.api.types.is_numeric_dtype(df[x_name]) else df
            ax.plot(line_df[x_name], line_df[y_name], marker="o", linestyle="-", color="black")
        elif plot_type == "scatter":
            ax.scatter(df[x_name], df[y_name], color="black")
        else:  # bar: string x gives categorical bars
            ax.bar(df[x_name].astype(str), df[y_name], color="black")
        x_label = str(x_name).replace("_", " ").title()
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(f"{y_label} vs {x_label}")

    # grid(True), not grid(): the no-argument form toggles visibility and would
    # switch the grid *off* under a style that already enables it.
    ax.grid(True)
    ax.tick_params(axis="x", labelrotation=45)  # rotate x labels for readability
    fig.tight_layout()

    # Detach from pyplot's figure registry; the returned Figure still renders.
    plt.close(fig)
    return fig



In [63]:
# Gradio interface: assemble the Blocks UI in a builder, then launch it
def _build_app():
    """Create the Blocks app: a CSV upload, a Q&A pair, and plotting controls.

    Components are created in display order, because Blocks lays them out in
    the order they are instantiated. The two buttons are wired to
    ``answer_query`` and ``plot_graph``.
    """
    with gr.Blocks() as blocks:
        gr.Markdown("## 🏡 CSV Question Answering & Visualization")

        # Shared upload plus the question/answer pair
        csv_upload = gr.File(label="Upload CSV", type="filepath")
        question_box = gr.Textbox(label="Ask a Question")
        answer_box = gr.Textbox(label="Answer", interactive=False)

        # Plot controls and output panel
        x_axis_box = gr.Textbox(label="X-Axis Column")
        y_axis_box = gr.Textbox(label="Y-Axis Column")
        chart_kind = gr.Dropdown(
            choices=["line", "scatter", "bar", "histogram"],
            label="Select Plot Type",
        )
        chart_panel = gr.Plot(label="Plot")

        ask_button = gr.Button("Get Answer")
        draw_button = gr.Button("Generate Plot")

        ask_button.click(
            fn=answer_query,
            inputs=[csv_upload, question_box],
            outputs=answer_box,
        )
        draw_button.click(
            fn=plot_graph,
            inputs=[csv_upload, x_axis_box, y_axis_box, chart_kind],
            outputs=chart_panel,
        )
    return blocks


app = _build_app()
app.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://6e2a8e504a144cd72b.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


