In [4]:

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Load Titanic dataset from seaborn
df = sns.load_dataset("titanic")
print("Dataset loaded successfully!")
display(df.head())

# Standardize column names (capitalize first letter)
df.columns = [c.capitalize() for c in df.columns]

# ---- QUICK SUMMARY ----
def build_summary(df):
    numeric = df.select_dtypes(include=[np.number])
    return {
        "rows": df.shape[0],
        "cols": df.shape[1],
        "columns": list(df.columns),
        "missing": df.isnull().sum().to_dict(),
        "top_values": {col: df[col].value_counts().head(3).to_dict() for col in df.columns},
        "numeric": numeric.describe().to_dict(),
        "unique": df.nunique().to_dict()
    }

summary = build_summary(df)

# ---- PLOT HELPERS ----
def show_hist(col):
    if col not in df.columns:
        return f"Column '{col}' not found."
    if not np.issubdtype(df[col].dtype, np.number):
        return f"'{col}' is not numeric."
    df[col].dropna().hist(bins=30)
    plt.title(f"Histogram of {col}")
    plt.show()

def show_bar(col):
    if col not in df.columns:
        return f"Column '{col}' not found."
    df[col].value_counts().head(10).plot.bar()
    plt.title(f"Top values in {col}")
    plt.show()

# ---- AGENT FUNCTION ----
def agent(question):
    q = question.lower()

    # Detect column mentioned
    col = None
    for c in summary["columns"]:
        if c.lower() in q:
            col = c
            break

    # Basic info
    if "row" in q:
        return f"Rows: {summary['rows']}"
    if "column" in q:
        return f"Columns ({summary['cols']}): {summary['columns']}"

    # Missing values
    if "missing" in q or "null" in q:
        miss = sorted(summary["missing"].items(), key=lambda x: x[1], reverse=True)
        return "Missing values (top):\n" + "\n".join(f"{k}: {v}" for k,v in miss[:5])

    # Top values
    if "top" in q or "most common" in q:
        if not col:
            return "Specify a column, e.g., 'top values of Sex'"
        return f"Top values in {col}:\n{summary['top_values'][col]}"

    # Numeric summaries
    if "mean" in q or "average" in q or "median" in q:
        if not col:
            return "Ask like 'mean age' or 'median fare'"
        if col not in summary["numeric"]:
            return f"'{col}' is not numeric."
        col_stats = summary["numeric"][col]
        return f"{col} summary: mean={col_stats.get('mean')}, median={col_stats.get('50%')}"

    # Unique values
    if "unique" in q:
        if not col:
            return "Ask like 'unique values in Class'"
        return f"Unique values in {col}: {summary['unique'][col]}"

    # Survival analysis (dataset has 'Survived')
    if "surviv" in q:
        if "sex" in q:
            if "female" in q:
                return f"Female survival rate: {df[df['Sex']=='female']['Survived'].mean():.2f}"
            if "male" in q:
                return f"Male survival rate: {df[df['Sex']=='male']['Survived'].mean():.2f}"
        return f"Overall survival rate: {df['Survived'].mean():.2f}"

    # Plotting
    if "plot" in q or "hist" in q or "bar" in q:
        if not col:
            return "Specify column, e.g., 'histogram of age'"
        if "hist" in q:
            show_hist(col)
            return f"Displayed histogram of {col}"
        if "bar" in q:
            show_bar(col)
            return f"Displayed bar plot of {col}"

    # Fallback
    return "Ask things like:\n- rows\n- columns\n- missing values\n- top values of sex\n- mean age\n- survival rate\n- histogram of age"





Dataset loaded successfully!


Unnamed: 0,survived,pclass,sex,age,sibsp,parch,fare,embarked,class,who,adult_male,deck,embark_town,alive,alone
0,0,3,male,22.0,1,0,7.25,S,Third,man,True,,Southampton,no,False
1,1,1,female,38.0,1,0,71.2833,C,First,woman,False,C,Cherbourg,yes,False
2,1,3,female,26.0,0,0,7.925,S,Third,woman,False,,Southampton,yes,True
3,1,1,female,35.0,1,0,53.1,S,First,woman,False,C,Southampton,yes,False
4,0,3,male,35.0,0,0,8.05,S,Third,man,True,,Southampton,no,True


In [5]:
print("\nAgent demo:")
print("•", agent("How many rows?"))
print("•", agent("Which columns are there?"))
print("•", agent("Missing values?"))
print("•", agent("Top values of Sex?"))
print("•", agent("Mean age?"))
print("•", agent("What is the female survival rate?"))


Agent demo:
• Rows: 891
• Columns (15): ['Survived', 'Pclass', 'Sex', 'Age', 'Sibsp', 'Parch', 'Fare', 'Embarked', 'Class', 'Who', 'Adult_male', 'Deck', 'Embark_town', 'Alive', 'Alone']
• Missing values (top):
Deck: 688
Age: 177
Embarked: 2
Embark_town: 2
Survived: 0
• Top values in Sex:
{'male': 577, 'female': 314}
• Age summary: mean=29.69911764705882, median=28.0
• Overall survival rate: 0.38
