In [1]:
import pandas as pd

# NOTE(review): absolute, machine-specific path; it will not resolve on another
# machine unless the same directory layout exists. Later cells use relative
# "test_data/..." paths instead.
df = pd.read_csv("/home/tym/projects/llm-data-analyst-agent/test_data/Heart_Disease_Prediction.csv")

In [2]:
# Print the column names, non-null counts and dtypes of the loaded frame.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 270 entries, 0 to 269
Data columns (total 14 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   Age                      270 non-null    int64  
 1   Sex                      270 non-null    int64  
 2   Chest pain type          270 non-null    int64  
 3   BP                       270 non-null    int64  
 4   Cholesterol              270 non-null    int64  
 5   FBS over 120             270 non-null    int64  
 6   EKG results              270 non-null    int64  
 7   Max HR                   270 non-null    int64  
 8   Exercise angina          270 non-null    int64  
 9   ST depression            270 non-null    float64
 10  Slope of ST              270 non-null    int64  
 11  Number of vessels fluro  270 non-null    int64  
 12  Thallium                 270 non-null    int64  
 13  Heart Disease            270 non-null    object 
dtypes: float64(1), int64(12), 

In [3]:
# Pairwise correlations between the numeric columns only, rounded to 3 decimals.
# Non-numeric columns (here the object-typed "Heart Disease") are excluded.
num_df = df.select_dtypes(include="number")
round(num_df.corr(), 3)

Unnamed: 0,Age,Sex,Chest pain type,BP,Cholesterol,FBS over 120,EKG results,Max HR,Exercise angina,ST depression,Slope of ST,Number of vessels fluro,Thallium
Age,1.0,-0.094,0.097,0.273,0.22,0.123,0.128,-0.402,0.098,0.194,0.16,0.356,0.106
Sex,-0.094,1.0,0.035,-0.063,-0.202,0.042,0.039,-0.076,0.18,0.097,0.051,0.087,0.391
Chest pain type,0.097,0.035,1.0,-0.043,0.09,-0.099,0.074,-0.318,0.353,0.167,0.137,0.226,0.263
BP,0.273,-0.063,-0.043,1.0,0.173,0.156,0.116,-0.039,0.083,0.223,0.142,0.086,0.132
Cholesterol,0.22,-0.202,0.09,0.173,1.0,0.025,0.168,-0.019,0.078,0.028,-0.006,0.127,0.029
FBS over 120,0.123,0.042,-0.099,0.156,0.025,1.0,0.053,0.022,-0.004,-0.026,0.044,0.124,0.049
EKG results,0.128,0.039,0.074,0.116,0.168,0.053,1.0,-0.075,0.095,0.12,0.161,0.114,0.007
Max HR,-0.402,-0.076,-0.318,-0.039,-0.019,0.022,-0.075,1.0,-0.381,-0.349,-0.387,-0.265,-0.253
Exercise angina,0.098,0.18,0.353,0.083,0.078,-0.004,0.095,-0.381,1.0,0.275,0.256,0.153,0.321
ST depression,0.194,0.097,0.167,0.223,0.028,-0.026,0.12,-0.349,0.275,1.0,0.61,0.255,0.324


In [9]:
# Correlation of every numeric feature with the "Survived" column.
# NOTE(review): this needs `df` to contain "Survived", but the frame loaded in
# In [1] (Heart_Disease_Prediction.csv) has no such column. The cell therefore
# depends on a different dataset having been loaded in a cell that is not
# visible here, and will fail on a fresh Restart & Run All.
df.select_dtypes(include="number").drop(columns=["Survived"]).corrwith(df["Survived"])

PassengerId   -0.005007
Pclass        -0.338481
Age           -0.077221
SibSp         -0.035322
Parch          0.081629
Fare           0.257307
dtype: float64

In [10]:
# Same feature-vs-"Survived" correlations as above, rounded to 3 decimals and
# converted to a plain {column: value} dict.
# NOTE(review): like the previous cell, this assumes `df` has a "Survived" column.
feature_target_corr = (
    df.select_dtypes(include="number")
      .drop(columns=["Survived"])
      .corrwith(df["Survived"])
      .round(3)
      .to_dict()
)

In [11]:
# Total number of missing cells across all columns of the frame.
df.isna().sum().sum()

np.int64(866)

In [12]:
# Full (unrounded) correlation matrix of the numeric columns; the last line
# displays it as the cell output.
num_df = df.select_dtypes(include="number")
corr = num_df.corr()
corr

Unnamed: 0,PassengerId,Survived,Pclass,Age,SibSp,Parch,Fare
PassengerId,1.0,-0.005007,-0.035144,0.036847,-0.057527,-0.001652,0.012658
Survived,-0.005007,1.0,-0.338481,-0.077221,-0.035322,0.081629,0.257307
Pclass,-0.035144,-0.338481,1.0,-0.369226,0.083081,0.018443,-0.5495
Age,0.036847,-0.077221,-0.369226,1.0,-0.308247,-0.189119,0.096067
SibSp,-0.057527,-0.035322,0.083081,-0.308247,1.0,0.414838,0.159651
Parch,-0.001652,0.081629,0.018443,-0.189119,0.414838,1.0,0.216225
Fare,0.012658,0.257307,-0.5495,0.096067,0.159651,0.216225,1.0


# Data context as a class

In [13]:
from dataclasses import dataclass
import pandas as pd


@dataclass
class DataContext:
    df: pd.DataFrame | None = None
    source: str | None = None
    format: str | None = None

    def is_loaded(self) -> bool:
        return self.df is not None


DATA_CONTEXT = DataContext()

# Tools

### load_data, dataset_head, dataset_info

In [14]:
import pandas as pd
from pathlib import Path


def load_data(path: str) -> dict:
    """Read a tabular file into the shared ``DATA_CONTEXT`` and summarize it.

    The reader is selected from the lower-cased file extension:
    ``.csv``, ``.tsv``, ``.parquet``/``.pq``, ``.xlsx``, ``.json``/``.jsonl``
    and ``.pkl``.

    Args:
        path: Location of the file to read.

    Returns:
        ``{"status": "ok", "format": <format name>, "rows": <row count>,
        "columns": <column count>}``.

    Raises:
        ValueError: If the extension is not one of the supported ones.
        TypeError: If the file did not deserialize to a ``pandas.DataFrame``
            (a pickle can hold any object).
    """
    ext = Path(path).suffix.lower()

    # extension -> (format name, reader taking the path)
    readers = {
        ".csv": ("csv", pd.read_csv),
        ".parquet": ("parquet", pd.read_parquet),
        ".pq": ("parquet", pd.read_parquet),
        ".tsv": ("tsv", lambda p: pd.read_csv(p, sep="\t")),
        ".xlsx": ("excel", pd.read_excel),
        ".json": ("json", lambda p: pd.read_json(p, lines=False)),
        ".jsonl": ("json", lambda p: pd.read_json(p, lines=True)),
        # NOTE(review): unpickling can execute arbitrary code; only load
        # .pkl files that come from a trusted source.
        ".pkl": ("pickle", pd.read_pickle),
    }

    if ext not in readers:
        raise ValueError(f"Unsupported format: {ext}")

    fmt, reader = readers[ext]
    df = reader(path)

    # Validate BEFORE touching the shared context, so a bad file can never
    # leave DATA_CONTEXT holding something that is not a DataFrame.
    if not isinstance(df, pd.DataFrame):
        raise TypeError(
            f"Expected a pandas DataFrame in {path!r}, got {type(df).__name__}"
        )

    DATA_CONTEXT.df = df
    DATA_CONTEXT.source = path
    DATA_CONTEXT.format = fmt

    return {
        "status": "ok",
        "format": fmt,
        "rows": len(df),
        "columns": len(df.columns)
    }


def dataset_head(n: int = 5) -> list[dict]:
    """Return the first ``n`` rows of the loaded dataset as a list of row dicts.

    Raises:
        RuntimeError: If no dataset is loaded in ``DATA_CONTEXT``.
    """
    if DATA_CONTEXT.is_loaded():
        preview = DATA_CONTEXT.df.head(n)
        return preview.to_dict(orient="records")

    raise RuntimeError("No dataset loaded")


def dataset_info(max_top_values: int = 5) -> dict:
    """Summarize the loaded dataset: size, missing data and per-column facts.

    Args:
        max_top_values: Maximum number of most frequent values reported for
            each non-numeric column.

    Returns:
        A dict with ``rows``, ``n_columns``, ``missing_pct`` (share of missing
        cells in percent, 2 decimals) and ``columns``. ``columns`` maps each
        column name to its ``dtype``, ``n_missing``, ``col_missing_pct``,
        ``n_unique`` and ``semantic_type`` ("numeric" or "categorical").
        Non-numeric columns with at least one non-missing value also carry
        ``top_values`` (value -> count).

    Raises:
        RuntimeError: If no dataset is loaded in ``DATA_CONTEXT``.
    """
    if not DATA_CONTEXT.is_loaded():
        raise RuntimeError("No dataset loaded")

    df = DATA_CONTEXT.df

    n_rows = len(df)
    n_cols = len(df.columns)

    # Guard the percentages: an empty frame has no cells to divide by.
    total_cells = n_rows * n_cols
    total_missing = int(df.isna().sum().sum())
    missing_pct = round(total_missing / total_cells * 100, 2) if total_cells else 0.0

    columns_info = {}

    for col in df.columns:
        s = df[col]

        n_missing = int(s.isna().sum())
        col_missing_pct = round(n_missing / n_rows * 100, 2) if n_rows else 0.0
        n_unique = int(s.nunique(dropna=True))

        if pd.api.types.is_numeric_dtype(s):
            semantic = "numeric"
            top_values = None
        else:
            semantic = "categorical"
            top_values = (
                s.value_counts(dropna=True)
                 .head(max_top_values)
                 .to_dict()
            )

        entry = {
            "dtype": str(s.dtype),
            "n_missing": n_missing,
            "col_missing_pct": col_missing_pct,
            "n_unique": n_unique,
            "semantic_type": semantic,
        }
        # Only report top values when there is something to report.
        if top_values:
            entry["top_values"] = top_values
        columns_info[col] = entry

    return {
        "rows": n_rows,
        "n_columns": n_cols,
        "missing_pct": missing_pct,
        "columns": columns_info
    }

### correlation_matrix + heatmap

In [15]:
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path


def correlation_matrix(
    threshold: float = 0.2,
    label: str = None
) -> dict:

    if not DATA_CONTEXT.is_loaded():
        raise RunTimeError("No dataset loaded")

    df = DATA_CONTEXT.df
    num_df = df.select_dtypes(include="number")
    corr = num_df.corr()

    pairs = []
    for i, col1 in enumerate(corr.columns):
        for col2 in corr.columns[i+1:]:
            value = corr.loc[col1, col2]
            if abs(value) >= threshold:
                pairs.append({
                    "feature_1": col1,
                    "feature_2": col2,
                    "correlation": round(float(value), 3)
                })

    if label:
        feature_target_corr = (
            df.select_dtypes(include="number")
              .drop(columns=[label])
              .corrwith(df[label])
              .abs()
              .sort_values(ascending=False)
              .round(3)
              .to_dict()
        )
    
        return {
            "threshold": threshold,
            "matrix": corr.round(3).to_dict(),
            "feature_target_abs_corr": feature_target_corr,
            "high_correlation_pairs": pairs
        }
    
    else:
        return {
            "threshold": threshold,
            "matrix": corr.round(3).to_dict(),
            "high_correlation_pairs": pairs
        }


def plot_correlation_heatmap(dataset_name: str) -> dict:
    """Save an annotated correlation heatmap of the numeric columns as a PNG.

    The image is written into a ``plots`` directory (created if missing)
    relative to the current working directory.

    Args:
        dataset_name: Used in the plot title and in the output file name.

    Returns:
        ``{"type": "heatmap", "path": <path of the written PNG>}``.

    Raises:
        RuntimeError: If no dataset is loaded in ``DATA_CONTEXT``.
    """
    if not DATA_CONTEXT.is_loaded():
        raise RuntimeError("No dataset loaded")

    corr = DATA_CONTEXT.df.select_dtypes(include="number").corr()

    Path("plots").mkdir(exist_ok=True)
    # NOTE(review): there is no separator between "correlation_heatmap" and the
    # dataset name; kept as-is so already generated file names stay valid.
    path = f"plots/correlation_heatmap{dataset_name}.png"

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            corr,
            annot=True,
            cmap="coolwarm",
            center=0,
            ax=ax,
        )
        ax.set_title(f"Correlation heatmap ({dataset_name})")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    return {
        "type": "heatmap",
        "path": path
    }

In [38]:
def missing_values_report() -> dict:
    """Report the columns of the loaded dataset that contain missing values.

    Returns:
        A dict mapping each column with at least one missing value to
        ``{"missing": <count>, "percent": <share of rows in percent, 2 decimals>}``.
        Columns without missing values are omitted.

    Raises:
        RuntimeError: If no dataset is loaded in ``DATA_CONTEXT``.
    """
    if not DATA_CONTEXT.is_loaded():
        raise RuntimeError("No dataset loaded")

    frame = DATA_CONTEXT.df
    n_rows = len(frame)

    na_counts = frame.isna().sum()
    incomplete = na_counts.loc[na_counts > 0]

    return {
        column: {
            "missing": int(n_missing),
            "percent": round((n_missing / n_rows) * 100, 2)
        }
        for column, n_missing in incomplete.items()
    }

In [39]:
def basic_statistics() -> dict:
    """Return descriptive statistics for every numeric column.

    Returns:
        A dict mapping each numeric column to ``count`` (int) and ``mean``,
        ``std``, ``min``, ``25%``, ``50%``, ``75%``, ``max`` (rounded to 4
        decimals). Empty when the dataset has no numeric columns.

    Raises:
        RuntimeError: If no dataset is loaded in ``DATA_CONTEXT``.
    """
    if not DATA_CONTEXT.is_loaded():
        raise RuntimeError("No dataset loaded")

    num_df = DATA_CONTEXT.df.select_dtypes(include="number")

    # DataFrame.describe() raises ValueError on a frame without columns, so a
    # dataset with no numeric columns simply has no statistics to report.
    if num_df.shape[1] == 0:
        return {}

    stats = num_df.describe().to_dict()
    rounded_keys = ("mean", "std", "min", "25%", "50%", "75%", "max")

    result = {}

    for col, values in stats.items():
        summary = {"count": int(values.get("count", 0))}
        for key in rounded_keys:
            summary[key] = round(values.get(key, 0), 4)
        result[col] = summary

    return result

# Agent simulation

In [40]:
class AnalystAgent:
    def run(self, command: str):
        if command.startswith("load"):
            path = command.split(maxsplit=1)[1]
            return load_data(path)

        if command.startswith("info"):
            if not DATA_CONTEXT.is_loaded():
                path = command.split(maxsplit=1)[1]
                load_data(path)
            
            return dataset_info()

        if command.startswith("head"):
            if not DATA_CONTEXT.is_loaded():
                path = command.split(maxsplit=1)[1]
                load_data(path)
                
            return dataset_head()

        if command.startswith("corr"):
            if not DATA_CONTEXT.is_loaded():
                path = command.split(maxsplit=1)[1]
                load_data(path)
                
            return correlation_matrix(label="Survived")

        if command.startswith("heatmap"):
            if not DATA_CONTEXT.is_loaded():
                path = command.split(maxsplit=1)[1]
                load_data(path)
            
            path = DATA_CONTEXT.source
            dataset_name = path.split("/")[-1].split(".")[0]
            return plot_correlation_heatmap(dataset_name)

        if command.startswith("miss"):
            if not DATA_CONTEXT.is_loaded():
                path = command.split(maxsplit=1)[1]
                load_data(path)
                
            return missing_values_report()

        if command.startswith("bs"):
            if not DATA_CONTEXT.is_loaded():
                path = command.split(maxsplit=1)[1]
                load_data(path)
                
            return basic_statistics()
        
        raise ValueError(f"Unknown command: {command}")

In [22]:
# Scratch check of the dataset-name extraction used by the "heatmap" command:
# take the last "/"-separated component and cut it at its first dot.
path = "test_data/Titanic-Dataset.csv"
path.split("/")[-1].split(".")[0]

'Titanic-Dataset'

In [43]:
# Exercise every agent command once against the Titanic dataset.
# Only the first call has to load the file: the other commands read the path
# argument only when no dataset is loaded yet.
agent = AnalystAgent()

result1 = agent.run("load test_data/Titanic-Dataset.csv")

result2 = agent.run("info test_data/Titanic-Dataset.csv")

result3 = agent.run("head test_data/Titanic-Dataset.csv")

result4 = agent.run("corr test_data/Titanic-Dataset.csv")

result5 = agent.run("heatmap test_data/Titanic-Dataset.csv")

result6 = agent.run("miss test_data/Titanic-Dataset.csv")

result7 = agent.run("bs test_data/Titanic-Dataset.csv")

In [44]:
# Show the output of the "bs" (basic statistics) command.
result7

{'PassengerId': {'count': 891,
  'mean': 446.0,
  'std': 257.3538,
  'min': 1.0,
  '25%': 223.5,
  '50%': 446.0,
  '75%': 668.5,
  'max': 891.0},
 'Survived': {'count': 891,
  'mean': 0.3838,
  'std': 0.4866,
  'min': 0.0,
  '25%': 0.0,
  '50%': 0.0,
  '75%': 1.0,
  'max': 1.0},
 'Pclass': {'count': 891,
  'mean': 2.3086,
  'std': 0.8361,
  'min': 1.0,
  '25%': 2.0,
  '50%': 3.0,
  '75%': 3.0,
  'max': 3.0},
 'Age': {'count': 714,
  'mean': 29.6991,
  'std': 14.5265,
  'min': 0.42,
  '25%': 20.125,
  '50%': 28.0,
  '75%': 38.0,
  'max': 80.0},
 'SibSp': {'count': 891,
  'mean': 0.523,
  'std': 1.1027,
  'min': 0.0,
  '25%': 0.0,
  '50%': 0.0,
  '75%': 1.0,
  'max': 8.0},
 'Parch': {'count': 891,
  'mean': 0.3816,
  'std': 0.8061,
  'min': 0.0,
  '25%': 0.0,
  '50%': 0.0,
  '75%': 0.0,
  'max': 6.0},
 'Fare': {'count': 891,
  'mean': 32.2042,
  'std': 49.6934,
  'min': 0.0,
  '25%': 7.9104,
  '50%': 14.4542,
  '75%': 31.0,
  'max': 512.3292}}