In [2]:
:dep linfa = "0.7.0"
:dep linfa-clustering = "0.7.0"
:dep linfa-nn = "0.7.0"
:dep ndarray = "0.15.6"
:dep plotly = "0.8.4"
:dep polars = { version = "0.39.2", features = ["lazy", "dtype-categorical", "ndarray"] }
:dep rand = "0.8.5"

:dep helper-functions = { path = "../../../projects/helper-functions/" }

In [3]:
use plotly::{
    common::*,
    layout::*,
    Trace,
};

use polars::prelude::*;

use helper_functions::prelude::*;

In [4]:
let raw_dataset = LazyCsvReader::new("data/k-means-dataset.csv")
    .finish().unwrap()
    .select([
        col("Feature_1").alias("feature_1"),
        col("Feature_2").alias("feature_2"),
        col("Cluster")
            .cast(DataType::String) // You can cast from numeric to categorical directly
            .cast(DataType::Categorical(None, CategoricalOrdering::default()))
            .alias("cluster")
    ])
    .collect().unwrap();

raw_dataset.head(None)

shape: (10, 3)
┌───────────┬───────────┬─────────┐
│ feature_1 ┆ feature_2 ┆ cluster │
│ ---       ┆ ---       ┆ ---     │
│ f64       ┆ f64       ┆ cat     │
╞═══════════╪═══════════╪═════════╡
│ -7.338988 ┆ -7.729954 ┆ 2       │
│ -7.740041 ┆ -7.264665 ┆ 2       │
│ -1.686653 ┆ 7.793442  ┆ 0       │
│ 4.422198  ┆ 3.071947  ┆ 1       │
│ -8.917752 ┆ -7.888196 ┆ 2       │
│ 5.497538  ┆ 1.813231  ┆ 1       │
│ -2.336017 ┆ 9.399604  ┆ 0       │
│ 5.05281   ┆ 1.409445  ┆ 1       │
│ -2.988372 ┆ 8.828627  ┆ 0       │
│ -3.700501 ┆ 9.67084   ┆ 0       │
└───────────┴───────────┴─────────┘

In [5]:
let x = "feature_1";
let y = "feature_2";
let group_column = None;
let trace = "scatter";
let opacity = 1.0;

let traces = get_traces(
    x,
    y,
    group_column,
    &raw_dataset,
    trace,
    opacity,
);

let layout = Some(
    Layout::new()
        .title(
            Title::new("K-Means Dataset")
        )
        .x_axis(
            Axis::new()
                .title(Title::new(x))
                .range_mode(RangeMode::ToZero)
        )
        .y_axis(
            Axis::new()
                .title(Title::new(y))
                .range_mode(RangeMode::ToZero)
        )
);

let plot = show_plot(traces, layout);

![](plots/plot.png)

In [6]:
let x = "feature_1";
let y = "feature_2";

let dataset = create_linfa_dataset(
    vec![x],
    Some(y),
    vec![x, y],
    &raw_dataset,
);

In [7]:
let max_n_iterations = 200;
let tolerance = 1e-5;
let n_clusters = 3;

let model = kmeans_model(
    &dataset,
    n_clusters,
    max_n_iterations,
    tolerance,
);

model

KMeans { centroids: [[4.747103374826418],
 [-6.8971537521181006],
 [-2.6199507136625297]], shape=[3, 1], strides=[1, 1], layout=CFcf (0xf), const ndim=2, cluster_count: [100.0, 100.0, 100.0], shape=[3], strides=[1], layout=CFcf (0xf), const ndim=1, inertia: 0.911791517567829, dist_fn: L2Dist }

In [8]:
let predictions = kmeans_predict(&dataset, &model);
predictions

shape: (300, 1)
┌─────────────┐
│ predictions │
│ ---         │
│ i32         │
╞═════════════╡
│ 1           │
│ 1           │
│ 2           │
│ 0           │
│ 1           │
│ …           │
│ 2           │
│ 0           │
│ 2           │
│ 1           │
│ 1           │
└─────────────┘

In [9]:
let final_dataframe = polars::functions::concat_df_horizontal(
    &[raw_dataset.clone(), predictions.clone()]
).unwrap()
.lazy()
.select([col("*").exclude(["cluster"])])
.sort(
    ["predictions"],
    SortMultipleOptions {
        descending: vec![false],
        nulls_last: true,
        ..Default::default()
    }
)
.with_column(
    col("predictions")
        .cast(DataType::String)
        .cast(DataType::Categorical(None, CategoricalOrdering::default()))
)
.rename(["predictions"], ["cluster"])
.collect()
.unwrap();

final_dataframe

shape: (300, 3)
┌───────────┬───────────┬─────────┐
│ feature_1 ┆ feature_2 ┆ cluster │
│ ---       ┆ ---       ┆ ---     │
│ f64       ┆ f64       ┆ cat     │
╞═══════════╪═══════════╪═════════╡
│ 4.422198  ┆ 3.071947  ┆ 0       │
│ 5.497538  ┆ 1.813231  ┆ 0       │
│ 5.05281   ┆ 1.409445  ┆ 0       │
│ 4.996894  ┆ 1.28026   ┆ 0       │
│ 2.614736  ┆ 2.159624  ┆ 0       │
│ …         ┆ …         ┆ …       │
│ -3.355991 ┆ 7.499439  ┆ 2       │
│ -2.185114 ┆ 8.629204  ┆ 2       │
│ -2.72887  ┆ 9.371399  ┆ 2       │
│ -3.660191 ┆ 9.389984  ┆ 2       │
│ -4.116681 ┆ 9.19892   ┆ 2       │
└───────────┴───────────┴─────────┘

In [10]:
let x = "feature_1";
let y = "feature_2";
let group_column = Some("cluster");
let trace = "scatter";
let opacity = 1.0;

let traces = get_traces(
    x,
    y,
    group_column,
    &final_dataframe,
    trace,
    opacity,
);

let layout = Some(
    Layout::new()
        .title(
            Title::new("K-Means Predictions")
        )
        .x_axis(
            Axis::new()
                .title(Title::new(x))
                .range_mode(RangeMode::ToZero)
        )
        .y_axis(
            Axis::new()
                .title(Title::new(y))
                .range_mode(RangeMode::ToZero)
        )
        .legend(
            Legend::new()
                .title(Title::new(group_column.unwrap()))
        )
);

show_plot(traces, layout);

![](plots/plot2.png)