<a href="https://colab.research.google.com/github/Nitesh10coder/airflow-sim/blob/main/Desicion_tree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import io
import base64
from typing import Optional

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris, load_wine, load_breast_cancer, make_classification
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
)
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
# Import Streamlit, installing it on the fly when the environment lacks it.
try:
    import streamlit as st
except ModuleNotFoundError:
    # Install through the running interpreter's own pip instead of the
    # IPython-only `!pip` shell escape: this is valid plain Python (so the
    # file can also be launched with `streamlit run`) and it targets the same
    # environment that is executing this code.
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "streamlit"])
    # Let the import system notice the freshly installed package.
    importlib.invalidate_caches()
    import streamlit as st

# Configure the page before any other Streamlit call is made.
st.set_page_config(layout="wide", page_title="Decision Tree Visualizer")

# ---------------------------
# Helpers
# ---------------------------

@st.cache_data
def load_sample_dataset(name: str) -> pd.DataFrame:
    """Return one of the bundled demo datasets as a single DataFrame.

    Args:
        name: One of "Iris", "Wine", "Breast Cancer" or
            "Synthetic (make_classification)".

    Returns:
        A DataFrame holding the feature columns plus a ``target`` column.

    Raises:
        ValueError: If ``name`` is not one of the supported dataset names.
    """
    sklearn_loaders = {
        "Iris": load_iris,
        "Wine": load_wine,
        "Breast Cancer": load_breast_cancer,
    }
    loader = sklearn_loaders.get(name)
    if loader is not None:
        # With as_frame=True the returned ``frame`` already contains the
        # 'target' column, so it must not be concatenated a second time.
        return loader(as_frame=True).frame

    if name != "Synthetic (make_classification)":
        raise ValueError("Unknown dataset")

    features, labels = make_classification(n_samples=300, n_features=4, n_informative=3, random_state=1)
    synthetic = pd.DataFrame(features, columns=[f"f{idx}" for idx in range(features.shape[1])])
    synthetic["target"] = labels
    return synthetic

def encode_target(y: pd.Series) -> "tuple[np.ndarray, Optional[LabelEncoder]]":
    """Turn a target column into values a classifier can be fitted on.

    Boolean and non-numeric targets (object, category, pandas string,
    datetime, ...) are converted to strings and label-encoded to integers.
    Numeric targets are passed through untouched.

    Args:
        y: The raw target column.

    Returns:
        A ``(values, encoder)`` pair. ``encoder`` is the fitted
        ``LabelEncoder`` when encoding was applied, otherwise ``None``.
    """
    # The annotation is a string so it is never evaluated at definition time
    # and does not depend on subscriptable built-in generics being available.
    needs_encoding = pd.api.types.is_bool_dtype(y) or not pd.api.types.is_numeric_dtype(y)
    if needs_encoding:
        encoder = LabelEncoder()
        return encoder.fit_transform(y.astype(str)), encoder
    # Numeric target: leave as-is.
    return y.values, None

def plot_decision_tree(clf: DecisionTreeClassifier, feature_names, class_names):
    """Draw a fitted decision tree on a new 12x8 inch figure.

    Args:
        clf: The fitted tree to draw.
        feature_names: Names shown for the split features.
        class_names: Class labels; each one is converted to ``str`` first.

    Returns:
        The matplotlib figure containing the tree drawing.
    """
    figure, axes = plt.subplots(figsize=(12, 8))
    class_labels = list(map(str, class_names))
    drawing_options = {
        "feature_names": feature_names,
        "class_names": class_labels,
        "filled": True,
        "impurity": True,
        "rounded": True,
        "fontsize": 10,
    }
    plot_tree(clf, ax=axes, **drawing_options)
    figure.tight_layout()
    return figure

def plot_confusion(cm, labels):
    """Render a confusion matrix on a new 5x4 inch figure.

    Args:
        cm: The confusion matrix to display.
        labels: Tick labels, one per class.

    Returns:
        The matplotlib figure holding the matrix plot (blue colormap, no
        colorbar).
    """
    figure, axes = plt.subplots(figsize=(5, 4))
    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(
        cmap="Blues",
        ax=axes,
        colorbar=False,
    )
    figure.tight_layout()
    return figure

def plot_feature_importances(importances, feature_names):
    """Plot feature importances as horizontal bars, largest first.

    Args:
        importances: Array of importance scores, one per feature.
        feature_names: Feature names in the same order as ``importances``.

    Returns:
        The matplotlib figure (6x4 inches) containing the bar plot.
    """
    figure, axes = plt.subplots(figsize=(6, 4))
    # argsort is ascending; reversing it ranks the features from high to low.
    order = np.argsort(importances)[::-1]
    ranked_scores = importances[order]
    ranked_names = np.array(feature_names)[order]
    sns.barplot(x=ranked_scores, y=ranked_names, ax=axes)
    axes.set_title("Feature importances")
    figure.tight_layout()
    return figure

def decision_boundary_plot(clf, X, y, feature_names, use_pca=False):
    """Shade a classifier's predictions over a 2D grid and overlay the samples.

    Args:
        clf: Fitted classifier whose ``predict`` accepts two-column input.
        X: 2D array; only columns 0 and 1 are read.
        y: Values used to colour the scattered samples.
        feature_names: Sequence whose first two items label the x and y axes.
        use_pca: Accepted for interface compatibility; not used by this
            function.

    Returns:
        The matplotlib figure (6x5 inches) with the filled contour and the
        scatter plot.
    """
    margin = 1.0
    first_col, second_col = X[:, 0], X[:, 1]
    # 300x300 evaluation grid extending one unit beyond the data on each side.
    x_axis = np.linspace(first_col.min() - margin, first_col.max() + margin, 300)
    y_axis = np.linspace(second_col.min() - margin, second_col.max() + margin, 300)
    xx, yy = np.meshgrid(x_axis, y_axis)
    grid_points = np.c_[xx.ravel(), yy.ravel()]
    predicted = clf.predict(grid_points).reshape(xx.shape)

    figure, axes = plt.subplots(figsize=(6, 5))
    axes.contourf(xx, yy, predicted, alpha=0.3, cmap="coolwarm")
    axes.scatter(first_col, second_col, c=y, s=30, cmap="coolwarm", edgecolor="k")
    axes.set_xlabel(f"{feature_names[0]}")
    axes.set_ylabel(f"{feature_names[1]}")
    axes.set_title("Decision boundary (visualization)")
    figure.tight_layout()
    return figure

def df_to_download_bytes(obj, filename: Optional[str] = None):
    """Serialize ``obj`` with joblib into memory and return the raw bytes.

    Args:
        obj: Any object joblib can dump (e.g. a dict bundling a model with
            its preprocessing objects).
        filename: Unused. It is kept, now optional, so that callers passing
            only ``obj`` no longer fail with a ``TypeError`` and callers that
            still pass a name keep working.

    Returns:
        The serialized payload as ``bytes``.
    """
    buffer = io.BytesIO()
    joblib.dump(obj, buffer)
    return buffer.getvalue()

# ---------------------------
# UI
# ---------------------------

# Page heading followed by a short description of the workflow.
st.title("Decision Tree Classifier — Interactive Visualizer")
intro_text = """
Upload your CSV or pick a sample dataset, choose the target column, tune hyperparameters interactively, train the model, and visualize the fitted decision tree, feature importances, evaluation metrics and a 2D decision-boundary visualization (with optional PCA when >2 features).
"""
st.markdown(intro_text)

# Sidebar: choose the data, the preprocessing steps and the tree
# hyperparameters. The names defined here (df, X, y_raw, scale_data,
# test_size, random_state, the hyperparameter values and run_train) are read
# by the training section further down.
with st.sidebar:
    st.header("Data")
    dataset_source = st.selectbox(
        "Load dataset",
        ("Sample datasets", "Upload CSV"),
    )

    if dataset_source == "Sample datasets":
        sample_name = st.selectbox("Choose sample dataset", ("Iris", "Wine", "Breast Cancer", "Synthetic (make_classification)"))
        df = load_sample_dataset(sample_name)
    else:
        uploaded_file = st.file_uploader("Upload CSV file", type=["csv"])
        if uploaded_file is not None:
            try:
                df = pd.read_csv(uploaded_file)
            except Exception as e:
                st.error(f"Error reading CSV: {e}")
                st.stop()
        else:
            st.info("Upload a CSV to proceed or pick a sample dataset above.")
            st.stop()

    st.write("Preview of data (first 5 rows):")
    st.dataframe(df.head())

    st.markdown("---")
    st.header("Preprocessing")
    drop_na = st.checkbox("Drop rows with NA", value=True)
    if drop_na:
        df = df.dropna()
    encode_categorical = st.checkbox("One-hot encode non-numeric features", value=True)

    # target selection
    # The target is picked from the un-encoded frame and only the features are
    # one-hot encoded afterwards. Encoding the whole frame first would split a
    # categorical target into several dummy columns, leaving no single target
    # column to select.
    all_columns = list(df.columns)
    target_col = st.selectbox("Select target column", all_columns, index=len(all_columns) - 1)
    X = df.drop(columns=[target_col])
    y_raw = df[target_col]
    if encode_categorical:
        X = pd.get_dummies(X)
    st.write(f"Features: {len(X.columns)} — Target dtype: {y_raw.dtype}")

    scale_data = st.checkbox("Standard scale features", value=False)
    test_size = st.slider("Test set proportion", min_value=0.05, max_value=0.5, value=0.25, step=0.05)
    random_state = st.number_input("Random state", value=1, step=1)

    st.markdown("---")
    st.header("Decision Tree Hyperparameters")
    criterion = st.selectbox("Criterion", ("gini", "entropy"))
    splitter = st.selectbox("Splitter", ("best", "random"))
    max_depth_choice = st.checkbox("Limit max_depth", value=False)
    max_depth = st.slider("max_depth", min_value=1, max_value=50, value=3) if max_depth_choice else None
    min_samples_split = st.slider("min_samples_split", min_value=2, max_value=50, value=2)
    min_samples_leaf = st.slider("min_samples_leaf", min_value=1, max_value=50, value=1)
    # "auto" is intentionally not offered: scikit-learn deprecated it for
    # decision trees in 1.1 and rejects it from 1.3 on, so selecting it would
    # make training fail.
    max_features = st.selectbox("max_features", ("None", "sqrt", "log2", "int"))
    if max_features == "int":
        max_features_int = st.slider("max_features (int)", min_value=1, max_value=max(1, X.shape[1]), value=min(3, X.shape[1]))
        max_features_param = max_features_int
    elif max_features == "None":
        # The UI uses the string "None"; the estimator expects the None object.
        max_features_param = None
    else:
        max_features_param = max_features

    ccp_alpha = st.slider("ccp_alpha (pruning)", min_value=0.0, max_value=1.0, value=0.0, step=0.0001, format="%.4f")

    st.markdown("---")
    st.header("Train")
    run_train = st.button("Train decision tree")

# ---------------------------
# Training & Visualization
# ---------------------------

# Nothing to train on if the sidebar never produced a dataset.
if 'df' not in locals():
    st.stop()

# NOTE(review): st.button is presumably True only on the rerun triggered by
# the click, so toggling the PCA checkbox below would rerun the script with
# run_train False and hide the results. Confirm, and consider keeping the
# trained state in st.session_state.
if run_train:
    # prepare X and y
    y, label_encoder = encode_target(y_raw)
    X_values = X.values
    feature_names = list(X.columns)
    class_names, class_counts = np.unique(y, return_counts=True)
    # Labels shown in plots: the original class values when the target was
    # label-encoded, otherwise the numeric classes themselves.
    display_names = label_encoder.classes_ if label_encoder is not None else class_names

    if scale_data:
        scaler = StandardScaler()
        X_values = scaler.fit_transform(X_values)
    else:
        scaler = None

    # Hyperparameters shared by the main tree and the PCA visualization tree.
    tree_params = {
        "criterion": criterion,
        "splitter": splitter,
        "max_depth": max_depth,
        "min_samples_split": int(min_samples_split),
        "min_samples_leaf": int(min_samples_leaf),
        "random_state": int(random_state),
        "ccp_alpha": float(ccp_alpha),
    }
    clf = DecisionTreeClassifier(max_features=max_features_param, **tree_params)

    # A stratified split needs at least two samples of every class; fall back
    # to a plain random split when that does not hold.
    can_stratify = len(class_names) > 1 and class_counts.min() >= 2
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X_values,
            y,
            test_size=test_size,
            random_state=int(random_state),
            stratify=y if can_stratify else None,
        )
        clf.fit(X_train, y_train)
    except ValueError as e:
        # Report split or fit problems (e.g. non-numeric features, invalid
        # parameters) in the page instead of crashing the script.
        st.error(f"Training failed: {e}")
        st.stop()

    # metrics
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    report = classification_report(y_test, y_pred, zero_division=0)
    # Passing labels keeps the matrix square over every class, so it always
    # matches the display labels even if a class is missing from the test set.
    cm = confusion_matrix(y_test, y_pred, labels=class_names)

    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Model metrics")
        st.write(f"Accuracy on test set: **{acc:.4f}**")
        st.write("Classification report:")
        st.text(report)

        model_bundle = {"model": clf, "scaler": scaler, "features": feature_names, "label_encoder": label_encoder}
        st.download_button(
            label="Download trained model (.joblib)",
            # The helper's second argument is supplied explicitly; omitting it
            # raised a TypeError when the helper requires a filename.
            data=df_to_download_bytes(model_bundle, "decision_tree_model.joblib"),
            file_name="decision_tree_model.joblib",
            mime="application/octet-stream",
        )

    with col2:
        st.subheader("Cross-validation")
        try:
            cvs = cross_val_score(clf, X_values, y, cv=5)
            st.write(f"CV scores (5-fold): {np.round(cvs, 4)}")
            st.write(f"CV mean: {cvs.mean():.4f}  std: {cvs.std():.4f}")
        except Exception as e:
            st.write("Cross-validation failed:", e)

    # Tree plot
    st.subheader("Decision tree structure")
    # Name the classes after the ones the fitted tree actually knows about.
    if label_encoder is not None:
        tree_class_names = label_encoder.inverse_transform(clf.classes_)
    else:
        tree_class_names = clf.classes_
    try:
        fig_tree = plot_decision_tree(clf, feature_names, tree_class_names)
        st.pyplot(fig_tree)
        # Close each figure once rendered so figures do not pile up in
        # pyplot's registry across script reruns.
        plt.close(fig_tree)
    except Exception as e:
        st.write("Could not plot tree:", e)
        st.code(export_text(clf, feature_names=feature_names))

    # Confusion matrix
    st.subheader("Confusion Matrix")
    fig_cm = plot_confusion(cm, labels=display_names)
    st.pyplot(fig_cm)
    plt.close(fig_cm)

    # Feature importances
    st.subheader("Feature importances")
    importances = clf.feature_importances_
    fig_imp = plot_feature_importances(importances, feature_names)
    st.pyplot(fig_imp)
    plt.close(fig_imp)

    # Decision boundary (if 2 features) or PCA option
    st.subheader("2D decision boundary (visualization)")
    n_features = X_values.shape[1]
    if n_features == 2:
        fig_db = decision_boundary_plot(clf, X_values, y, feature_names)
        st.pyplot(fig_db)
        plt.close(fig_db)
    elif n_features < 2:
        # A 2-component PCA cannot be computed from fewer than two features.
        st.info("Decision boundary plot requires at least 2 features.")
    else:
        use_pca_vis = st.checkbox("Project features to 2D using PCA for visualization", value=True)
        if use_pca_vis:
            pca = PCA(n_components=2, random_state=1)
            X_pca = pca.fit_transform(X_values)
            # train a new tree on PCA-projected data only for visualization
            clf_viz = DecisionTreeClassifier(max_features=None, **tree_params)
            clf_viz.fit(X_pca, y)
            fig_db = decision_boundary_plot(clf_viz, X_pca, y, ["PC1", "PC2"])
            st.pyplot(fig_db)
            plt.close(fig_db)
        else:
            st.info("Decision boundary plot requires exactly 2 features or PCA projection enabled.")

    st.success("Training and visualizations completed.")
else:
    st.info("Adjust settings in the sidebar and click Train decision tree to start.")

2026-01-12 06:33:13.192 No runtime found, using MemoryCacheStorageManager
2026-01-12 06:33:13.252 No runtime found, using MemoryCacheStorageManager


In [6]:
# Make pyngrok importable, installing it only when it is missing. The install
# goes through the running interpreter's pip rather than the IPython-only
# `!pip` shell escape, so this cell is also valid plain Python.
try:
    from pyngrok import ngrok
except ModuleNotFoundError:
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "pyngrok"])
    # Let the import system notice the freshly installed package.
    importlib.invalidate_caches()
    from pyngrok import ngrok

import os

# Prefer a token supplied through the NGROK_AUTHTOKEN environment variable so
# the secret never has to be written into the notebook. Without it, replace
# 'YOUR_NGROK_AUTH_TOKEN' with your actual ngrok token.
ngrok.set_auth_token(os.environ.get("NGROK_AUTHTOKEN", 'YOUR_NGROK_AUTH_TOKEN'))

Collecting pyngrok
  Downloading pyngrok-7.5.0-py3-none-any.whl.metadata (8.1 kB)
Downloading pyngrok-7.5.0-py3-none-any.whl (24 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.5.0
