In [None]:
library('palmerpenguins')

In [None]:
library('tidymodels')

In [None]:
library('rpart.plot')

In [None]:
# Keep only the rows of `penguins` that have no missing values.
penguins2 = drop_na(penguins)

# Preview the first rows of the cleaned data.
head(penguins2)

In [None]:
# Number of rows per species in the cleaned data.
count(penguins2, species)

In [None]:
# List the engines registered for the decision_tree model type.
show_engines(x = 'decision_tree')

In [None]:
# Model specification: a classification tree backed by the rpart engine.
mod = decision_tree() |>
    set_mode('classification') |>
    set_engine('rpart')

# Fit species against every other column of the cleaned data.
mod_fit = fit(mod, species ~ ., data = penguins2)

# Display the fitted model.
mod_fit

In [None]:
# Append the model's predictions to the data and preview the first rows.
mod_fit |>
    augment(new_data = penguins2) |>
    head()

In [None]:
# Plot size for the tree diagram.
options(repr.plot.height = 7, repr.plot.width = 8)

# Pull the underlying engine object out through the accessor
# (as the variable-importance cell does) instead of reaching into `$fit`.
# roundint = FALSE: the fitted object here does not carry the training data,
# so rpart.plot cannot work out which variables are integers and warns;
# disabling the rounding avoids that warning.
mod_fit |>
    extract_fit_engine() |>
    rpart.plot(roundint = FALSE)

In [None]:
# Toy vector of six distinct labels.
# NOTE(review): `v` is not referenced by any later cell visible here — confirm it is still needed.
v = c('a', 'b', 'c', 'd', 'e', 'f')

# Scratch arithmetic kept as a comment; it evaluates to 0. Presumably the
# Gini-impurity formula for a node where one class holds 6 of 6 values — confirm.
# (1 - ((6 / 6)^2)) * 100

In [None]:
# Gini impurity of a vector of class labels, expressed as a percentage.
#
# values: vector of class labels (character, factor, ...). NA entries are
#         ignored.
# Returns (1 - sum(p_k^2)) * 100, where p_k is the share of the non-missing
# values that fall in class k; 0 means every value is in a single class.
# Returns NaN when there are no non-missing values, because the class shares
# are undefined in that case.
gini_impurity = function(values) {
    counts = table(values)
    n_obs = sum(counts)

    if (n_obs == 0) {
        return(NaN)
    }

    # table() drops NAs, so divide by the tabulated total rather than
    # length(values); otherwise the shares would not sum to 1.
    freqs = counts / n_obs
    (1 - sum(freqs^2)) * 100
}

# Weighted mean Gini impurity (in percent) of `species` after a binary split.
#
# split_point: threshold; rows with split_var <  split_point form one child,
#              rows with split_var >= split_point form the other.
# penguins_df: data frame holding `species` and the splitting column.
#              Defaults to the global `penguins2`, looked up at call time.
# split_var:   name (string) of the column compared against split_point;
#              defaults to 'flipper_length_mm', the only column the original
#              version supported.
# Returns the impurities of the two children averaged with weights equal to
# their row counts. Rows whose split_var is NA fall in neither child.
mean_gini_impurity = function(split_point, penguins_df = penguins2,
                              split_var = 'flipper_length_mm') {
    # Filter each side once and reuse it for both the impurity and the weight.
    below = penguins_df |>
        filter(.data[[split_var]] < split_point) |>
        pull(species)
    at_or_above = penguins_df |>
        filter(.data[[split_var]] >= split_point) |>
        pull(species)

    # weighted.mean() omits entries whose weight is 0, so an empty child
    # (e.g. when splitting at the smallest observed value) does not
    # contaminate the result.
    weighted.mean(
        c(gini_impurity(below), gini_impurity(at_or_above)),
        c(length(below), length(at_or_above))
    )
}

# Example: weighted impurity when splitting flipper length at 190.
mean_gini_impurity(split_point = 190, penguins_df = penguins2)

In [None]:
# Impurity of the species labels before any split.
penguins2 |>
    pull(species) |>
    gini_impurity()

In [None]:
# Short, wide canvas for the impurity curve.
options(repr.plot.width = 6, repr.plot.height = 2)

# Every distinct flipper length, in ascending order, is a candidate split point.
unique_values = penguins2 |>
    distinct(flipper_length_mm) |>
    arrange(flipper_length_mm) |>
    pull(flipper_length_mm)

# Score each candidate (using the function's default data frame) and draw the curve.
tibble::tibble(
    flipper_length_mm = unique_values,
    mean_gini = purrr::map_vec(
        unique_values,
        \(candidate) mean_gini_impurity(split_point = candidate)
    )
) |>
    ggplot(aes(flipper_length_mm, mean_gini)) +
    geom_point() +
    geom_line()

In [None]:
# Drop in Gini impurity: impurity of all species labels minus the weighted
# impurity after splitting flipper length at 207.
(penguins2 |> pull(species) |> gini_impurity()) -
    mean_gini_impurity(split_point = 207, penguins_df = penguins2)

In [None]:
# List the engines registered for the rand_forest model type.
show_engines(x = 'rand_forest')

In [None]:
# Model specification: a 1000-tree classification forest on the ranger engine,
# recording impurity-based variable importance.
# Note: this reuses the names `mod` / `mod_fit` from the decision-tree cells,
# so from here on they refer to the random forest.
mod = rand_forest(trees = 1000) |>
    set_engine('ranger', importance = 'impurity') |>
    set_mode('classification')

# Random forests are stochastic; seed R's RNG immediately before fitting so
# the fitted model and its importances are the same on every re-run.
set.seed(2024)

# Fit species against every other column of the cleaned data.
mod_fit = mod |> fit(species ~ ., data = penguins2)

In [None]:
# Display the fitted model currently bound to `mod_fit`
# (the random forest, when cells are run top to bottom).
mod_fit

In [None]:
# Append the model's predictions to the data and preview the first rows.
mod_fit |>
    augment(new_data = penguins2) |>
    head()

In [None]:
library('vip')

In [None]:
# Plot height for the importance chart.
options(repr.plot.height = 4)

# Variable-importance plot of the underlying engine fit, showing up to 25 features.
vip(extract_fit_engine(mod_fit), num_features = 25)