# Visualize Clusters


## Set up


In [1]:
from pathlib import Path
import json

import altair as alt
import lightgbm as lgb
import pandas as pd
from pdpilot import PDPilotWidget

from clustering import plot_vine_clusters

In [None]:
# Route chart data through the VegaFusion data transformer. Presumably this is
# needed because the clustered ICE plots below carry more rows than Altair's
# default transformer allows — confirm.
alt.data_transformers.enable(name="vegafusion")

In [3]:
# Folder the notebook reads the model, data and precomputed JSON from, and
# where the rendered VINE PNGs are written.
output_dir = Path("results")
output_dir.mkdir(exist_ok=True, parents=True)

In [4]:
seed = 1  # random seed, forwarded to PDPilotWidget

In [5]:
# Load the trained LightGBM model saved in the results folder.
booster = lgb.Booster(model_file=output_dir.joinpath("model.txt"))

In [6]:
# Dataset with feature columns plus the "y" target column.
df = pd.read_csv(output_dir.joinpath("data.csv"))

In [7]:
# Split the loaded frame into the target array and the feature matrix.
y = df["y"].to_numpy()
df_X = df.drop(columns="y")
features = df_X.columns.tolist()
X = df_X.to_numpy()

## Visualize ICE plots

Only one `PDPilotWidget` can run in a notebook at a time. To compare the cluster descriptions for the `x2` feature, switch `pd_data` between the two precomputed files (`pdpilot_max_depth_1.json` and `pdpilot_max_depth_3.json`) and re-run the cell.


In [None]:
# Precomputed PDPilot results exist for max_depth 1 and 3. Change this value to
# compare their cluster descriptions (only one widget can run at a time).
pdpilot_max_depth = 1

w = PDPilotWidget(
    predict=booster.predict,
    df=df_X,
    labels=y,
    pd_data=(output_dir / f"pdpilot_max_depth_{pdpilot_max_depth}.json").as_posix(),
    height=650,
    seed=seed,
)

w

VINE is not implemented as a widget for Jupyter notebooks, so we use Altair to visualize the clustered ICE plots in a style similar to PDPilot.


In [9]:
def read_json(file_path):
    """Parse the JSON file at *file_path* (str or path-like) and return its contents."""
    with Path(file_path).open("rb") as handle:
        # json.load passes the raw bytes to json.loads, which detects the encoding.
        return json.load(handle)

In [10]:
# Build one VINE cluster plot for feature x2 per (initial number of clusters,
# prune_clusters) configuration. Save each plot as a PNG next to its JSON input
# and keep it in `plots`, keyed by that pair, for the cells below to display.
plots = {}

for num_clusters, prune_clusters in [(2, True), (5, True), (2, False), (5, False)]:
    vine_data_path = (
        output_dir / f"vine_n_clusters_{num_clusters}_prune_{prune_clusters}.json"
    )

    vine_data = read_json(vine_data_path)

    plot = plot_vine_clusters(
        vine_data, feature="x2", title=f"{num_clusters} initial clusters"
    )

    # e.g. vine_n_clusters_2_prune_True.json -> vine_n_clusters_2_prune_True.png
    vine_image_path = vine_data_path.with_suffix(".png")
    plot.save(vine_image_path, ppi=200)

    plots[(num_clusters, prune_clusters)] = plot

### VINE - `prune_clusters=True`


In [None]:
# 2 initial clusters, prune_clusters=True
plots[2, True]

In [None]:
# 5 initial clusters, prune_clusters=True
plots[5, True]

### VINE - `prune_clusters=False`


In [None]:
# 2 initial clusters, prune_clusters=False
plots[2, False]

In [None]:
# 5 initial clusters, prune_clusters=False
plots[5, False]