Price and Carat

In [2]:
library(tidyverse)
library(tidymodels)
library(kknn)
library(recipes)

── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
✔ ggplot2 3.4.2     ✔ purrr   1.0.1
✔ tibble  3.2.1     ✔ dplyr   1.1.1
✔ tidyr   1.3.0     ✔ stringr 1.5.0
✔ readr   2.1.3     ✔ forcats 0.5.2
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──

✔ broom        1.0.2     ✔ rsample      1.1.1
✔ dials        1.1.0     ✔ tune         1.0.1
✔ infer        1.0.4     ✔ workflows    1.1.2
✔

In [None]:
# Fix the RNG seed so the random train/test split and the cross-validation
# folds below are reproducible.
set.seed(5)

# Load the data and split it 75% / 25% into training and test sets,
# stratifying on the response (price).
diamonds <- read_csv("diamonds.csv")
diamonds_split <- initial_split(diamonds, prop = 0.75, strata = price)
diamonds_train <- training(diamonds_split)
diamonds_test <- testing(diamonds_split)

# Preprocessing: predict price from carat alone, with the predictor scaled
# and centered (K-NN is distance based). The recipe is defined on the
# training set only.
diam_recipe <- recipe(price ~ carat, data = diamonds_train) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors())

# K-NN regression spec with unweighted ("rectangular") neighbours; the number
# of neighbours is left as a tuning parameter.
diam_spec <- nearest_neighbor(weight_func = "rectangular",
                              neighbors = tune()) |>
  set_engine("kknn") |>
  set_mode("regression")

# 5-fold cross-validation on the training data, stratified on price.
diam_vfold <- vfold_cv(diamonds_train, v = 5, strata = price)

diam_wkflw <- workflow() |>
  add_recipe(diam_recipe) |>
  add_model(diam_spec)

# Candidate values of K: 1, 4, 7, ..., 199.
gridvals <- tibble(neighbors = seq(from = 1, to = 200, by = 3))

# Cross-validated RMSE for every candidate K.
diam_results <- diam_wkflw |>
  tune_grid(resamples = diam_vfold, grid = gridvals) |>
  collect_metrics() |>
  filter(.metric == "rmse")

# Keep exactly one row: the lowest mean RMSE, breaking ties in favour of the
# smaller K. Filtering on `mean == min(mean)` could return several rows on a
# tie, which would make `kmin` a vector and break the final model spec and
# the plot title that use it.
diam_min <- diam_results |>
  arrange(mean, neighbors) |>
  slice_head(n = 1)

kmin <- diam_min |> pull(neighbors)

# Final model specification: same unweighted K-NN regression as used for
# tuning, now with the number of neighbours fixed at the selected value.
diam_spec <- nearest_neighbor(
  mode = "regression",
  neighbors = kmin,
  weight_func = "rectangular"
) |>
  set_engine("kknn")

# Bundle the preprocessing recipe with the final spec and fit the workflow
# on the full training set.
diam_fit <- fit(
  workflow() |>
    add_recipe(diam_recipe) |>
    add_model(diam_spec),
  data = diamonds_train
)

# Predict on the held-out test set, attach the predictions to the test rows,
# and keep only the RMSE row of the regression metrics.
diam_summary <- predict(diam_fit, diamonds_test) |>
  bind_cols(diamonds_test) |>
  metrics(truth = price, estimate = .pred) |>
  filter(.metric == "rmse")

# Grid of carat values spanning the observed range, used to draw the fitted
# K-NN curve. The step must be small relative to the range of carat: a step
# of 10 is wider than the whole range, so the grid collapsed to a single
# value and geom_line() had nothing to draw.
carat_prediction_grid <- tibble(
  carat = seq(
    from = diamonds |> pull(carat) |> min(),
    to = diamonds |> pull(carat) |> max(),
    by = 0.01
  )
)

# Model predictions at each grid value, kept alongside the carat values.
diam_preds <- diam_fit |>
  predict(carat_prediction_grid) |>
  bind_cols(carat_prediction_grid)

# Scatter plot of all diamonds with the fitted K-NN regression curve
# overlaid in blue; the title reports the number of neighbours used.
plot_final <- ggplot(diamonds, aes(x = carat, y = price)) +
  geom_point(alpha = 0.4) +
  geom_line(data = diam_preds,
            mapping = aes(x = carat, y = .pred),
            color = "blue") +
  xlab("Carat") +
  ylab("Price (USD)") +
  scale_y_continuous(labels = dollar_format()) +
  ggtitle(paste0("K = ", kmin)) +
  theme(text = element_text(size = 12))
plot_final

── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
✔ ggplot2 3.4.2     ✔ purrr   1.0.1
✔ tibble  3.2.1     ✔ dplyr   1.1.1
✔ tidyr   1.3.0     ✔ stringr 1.5.0
✔ readr   2.1.3     ✔ forcats 0.5.2
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──

✔ broom        1.0.2     ✔ rsample      1.1.1
✔ dials        1.1.0     ✔ tune         1.0.1
✔ infer        1.0.4     ✔ workflows    1.1.2
✔

Price and Carat + Clarity