In [2]:
:dep ndarray = "0.15.6"
:dep linfa = "0.7.0"
:dep linfa-linear = "0.7.0"
:dep plotly = "0.8.4"
:dep polars = { version = "0.39.2", features = ["lazy", "ndarray"] }

:dep helper-functions = { path = "../helper-functions/" }

In [3]:
use plotly::{
    common::*,
    layout::*,
    Trace,
};

use polars::prelude::*;

use helper_functions::prelude::*;

In [4]:
// Load the penguins CSV lazily, keep only the two measurements this notebook
// models (cast to f64), and drop every row where either value is missing.
// The path is relative to the kernel's working directory.
let raw_dataset = LazyCsvReader::new("data/penguins.csv")
    .finish()
    .expect("failed to open data/penguins.csv (path is relative to the notebook's working directory)")
    .select([
        col("flipper_length_mm").cast(DataType::Float64),
        col("body_mass_g").cast(DataType::Float64),
    ])
    // NOTE(review): `cast` is the non-strict variant, so entries that fail to
    // parse (presumably "NA" markers in the CSV) become null and are removed
    // here together with genuinely empty cells — confirm against the data.
    .drop_nulls(None)
    .collect()
    .expect("failed to read/cast flipper_length_mm and body_mass_g from data/penguins.csv");

raw_dataset

shape: (342, 2)
┌───────────────────┬─────────────┐
│ flipper_length_mm ┆ body_mass_g │
│ ---               ┆ ---         │
│ f64               ┆ f64         │
╞═══════════════════╪═════════════╡
│ 181.0             ┆ 3750.0      │
│ 186.0             ┆ 3800.0      │
│ 195.0             ┆ 3250.0      │
│ 193.0             ┆ 3450.0      │
│ 190.0             ┆ 3650.0      │
│ …                 ┆ …           │
│ 207.0             ┆ 4000.0      │
│ 202.0             ┆ 3400.0      │
│ 193.0             ┆ 3775.0      │
│ 210.0             ┆ 4100.0      │
│ 198.0             ┆ 3775.0      │
└───────────────────┴─────────────┘

In [5]:
// Scatter plot of flipper length against body mass, with no grouping column
// and fully opaque markers. Both axes are forced to include zero.
let x = "body_mass_g";
let y = "flipper_length_mm";
let group_column = None;
let trace = "scatter";
let opacity = 1.0;

let traces = get_traces(x, y, group_column, &raw_dataset, trace, opacity);

// Build both axes from one template. The closure lives only inside this
// block, so it does not become a persistent session variable.
let layout = {
    let axis = |label: &str| {
        Axis::new()
            .title(Title::new(label))
            .range_mode(RangeMode::ToZero)
    };
    Some(
        Layout::new()
            .title(Title::new("Flipper Length vs Body Mass"))
            .x_axis(axis(x))
            .y_axis(axis(y)),
    )
};

show_plot(traces, layout);

![](plots/plot.png)

In [6]:
// Regression setup: body mass is the single feature and flipper length is the target.
let x = "body_mass_g";
let y = "flipper_length_mm";

// NOTE(review): argument meaning is inferred from the call site —
// (feature columns, optional target column, columns to take from the frame,
// source frame). Confirm against helper_functions.
let dataset = {
    let feature_columns = vec![x];
    let target_column = Some(y);
    let used_columns = vec![x, y];
    create_linfa_dataset(feature_columns, target_column, used_columns, &raw_dataset)
};

In [7]:
// Fit a linear regression on `dataset` (flipper length regressed on body mass,
// per the previous cell). `linear_regression_model` presumably comes from
// helper_functions::prelude. The trailing bare `model` makes evcxr display
// the fitted model's Debug form (intercept and params).
let model = linear_regression_model(&dataset);
model

FittedLinearRegression { intercept: 136.72955927266196, params: [0.015275915608037312], shape=[1], strides=[1], layout=CFcf (0xf), const ndim=1 }

In [8]:
// Apply the fitted model to every row of `dataset`. As the cell output shows,
// the helper returns a one-column f64 DataFrame named "predictions". The
// trailing bare expression renders it in the notebook.
let predictions = linear_regression_predict(&dataset, &model);
predictions

shape: (342, 1)
┌─────────────┐
│ predictions │
│ ---         │
│ f64         │
╞═════════════╡
│ 194.014243  │
│ 194.778039  │
│ 186.376285  │
│ 189.431468  │
│ 192.486651  │
│ …           │
│ 197.833222  │
│ 188.667672  │
│ 194.396141  │
│ 199.360813  │
│ 194.396141  │
└─────────────┘

In [9]:
// Attach the "predictions" column to the raw measurements so both can be
// plotted from one frame. This relies on the predictions being row-aligned
// with `raw_dataset`.
// NOTE(review): polars' horizontal concat appears to null-pad shorter frames
// rather than fail, which would silently misalign rows. Check the heights
// explicitly so a mismatch fails loudly instead.
assert_eq!(
    raw_dataset.height(),
    predictions.height(),
    "predictions must contain exactly one row per row of raw_dataset",
);

// The clones keep `raw_dataset` and `predictions` usable in later cells.
// DataFrame columns are reference-counted Series, so this does not copy data.
let final_dataframe = polars::functions::concat_df_horizontal(
    &[raw_dataset.clone(), predictions.clone()]
)
.expect("failed to concatenate raw_dataset and predictions horizontally (duplicate column names?)");

final_dataframe

shape: (342, 3)
┌───────────────────┬─────────────┬─────────────┐
│ flipper_length_mm ┆ body_mass_g ┆ predictions │
│ ---               ┆ ---         ┆ ---         │
│ f64               ┆ f64         ┆ f64         │
╞═══════════════════╪═════════════╪═════════════╡
│ 181.0             ┆ 3750.0      ┆ 194.014243  │
│ 186.0             ┆ 3800.0      ┆ 194.778039  │
│ 195.0             ┆ 3250.0      ┆ 186.376285  │
│ 193.0             ┆ 3450.0      ┆ 189.431468  │
│ 190.0             ┆ 3650.0      ┆ 192.486651  │
│ …                 ┆ …           ┆ …           │
│ 207.0             ┆ 4000.0      ┆ 197.833222  │
│ 202.0             ┆ 3400.0      ┆ 188.667672  │
│ 193.0             ┆ 3775.0      ┆ 194.396141  │
│ 210.0             ┆ 4100.0      ┆ 199.360813  │
│ 198.0             ┆ 3775.0      ┆ 194.396141  │
└───────────────────┴─────────────┴─────────────┘

In [11]:
// Overlay the fitted regression line (the "predictions" column) on a
// half-transparent scatter of the observed measurements.
let x = "body_mass_g";
let y = "flipper_length_mm";
let group_column = None;

let mut traces = get_traces(x, y, group_column, &final_dataframe, "scatter", 0.5);
traces.extend(get_traces(
    x,
    "predictions",
    group_column,
    &final_dataframe,
    "linear",
    1.0,
));

// Same axis template for both axes: titled with the column name, range forced
// to include zero. The closure is scoped to this block only.
let layout = {
    let axis = |label: &str| {
        Axis::new()
            .title(Title::new(label))
            .range_mode(RangeMode::ToZero)
    };
    Some(
        Layout::new()
            .title(Title::new("Flipper Length vs Body Mass"))
            .x_axis(axis(x))
            .y_axis(axis(y)),
    )
};

show_plot(traces, layout);

![](plots/plot2.png)