In [1]:
library(tidyverse)
library(tidymodels)

── [1mAttaching core tidyverse packages[22m ──────────────────────── tidyverse 2.0.0 ──
[32m✔[39m [34mdplyr    [39m 1.1.2     [32m✔[39m [34mreadr    [39m 2.1.4
[32m✔[39m [34mforcats  [39m 1.0.0     [32m✔[39m [34mstringr  [39m 1.5.0
[32m✔[39m [34mggplot2  [39m 3.4.2     [32m✔[39m [34mtibble   [39m 3.2.1
[32m✔[39m [34mlubridate[39m 1.9.2     [32m✔[39m [34mtidyr    [39m 1.3.0
[32m✔[39m [34mpurrr    [39m 1.0.1     
── [1mConflicts[22m ────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m    masks [34mstats[39m::lag()
[36mℹ[39m Use the conflicted package ([3m[34m<http://conflicted.r-lib.org/>[39m[23m) to force all conflicts to become errors
── [1mAttaching packages[22m ────────────────────────────────────── tidymodels 1.1.0 ──

[32m✔[39m [34mbroom       [39m 1.0.4     [32m✔[39m [34mrsample     [39

In [11]:
chr_training <- read_csv("data/chr_training.csv") |>
    mutate(label = factor(label))
chr_testing <- read_csv("data//chr_testing.csv") |>
    mutate(label = factor(label))

head(chr_training)
head(chr_testing)

[1mRows: [22m[34m1950[39m [1mColumns: [22m[34m785[39m
[36m──[39m [1mColumn specification[22m [36m────────────────────────────────────────────────────────[39m
[1mDelimiter:[22m ","
[31mchr[39m   (1): label
[32mdbl[39m (784): 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19...

[36mℹ[39m Use `spec()` to retrieve the full column specification for this data.
[36mℹ[39m Specify the column types or set `show_col_types = FALSE` to quiet this message.
[1mRows: [22m[34m650[39m [1mColumns: [22m[34m785[39m
[36m──[39m [1mColumn specification[22m [36m────────────────────────────────────────────────────────[39m
[1mDelimiter:[22m ","
[31mchr[39m   (1): label
[32mdbl[39m (784): 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19...

[36mℹ[39m Use `spec()` to retrieve the full column specification for this data.
[36mℹ[39m Specify the column types or set `show_col_types = FALSE` to quiet this message.


label,1,2,3,4,5,6,7,8,9,⋯,775,776,777,778,779,780,781,782,783,784
<fct>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,⋯,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0


label,1,2,3,4,5,6,7,8,9,⋯,775,776,777,778,779,780,781,782,783,784
<fct>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,⋯,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0
a,0,0,0,0,0,0,0,0,0,⋯,0,0,0,0,0,0,0,0,0,0


In [5]:
chr_knn_recipe <- recipe(label ~ ., data = chr_training)


chr_knn_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = tune()) |>
    set_engine("kknn") |>
    set_mode("classification")

gridvals <- tibble(neighbors = seq(from = 1, to = 50, by = 1))

chr_vfold <- vfold_cv(chr_training, v = 5, strata = label)

chr_wf <- workflow() |>
    add_recipe(chr_knn_recipe) |>
    add_model(chr_knn_spec) |>
    tune_grid(resamples = chr_vfold, grid = gridvals)

chr_result <- chr_wf |>
    collect_metrics() |>
    filter(.metric == "accuracy")

“Too little data to stratify.
[36m•[39m Resampling will be unstratified.”


In [10]:
best_k <- chr_result |>
    arrange(desc(mean)) |>
    head(1) |>
    pull(neighbors)

best_k

In [14]:
best_k_knn_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = best_k) |>
    set_engine("kknn") |>
    set_mode("classification")

chr_best_k_wf <- workflow() |>
    add_recipe(chr_knn_recipe) |>
    add_model(best_k_knn_spec) |>
    fit(data = chr_training)

chr_best_k_result <- predict(chr_best_k_wf, chr_testing) |>
    bind_cols(chr_testing) |>
    metrics(truth = label, estimate = .pred_class)

chr_best_k_result

.metric,.estimator,.estimate
<chr>,<chr>,<dbl>
accuracy,multiclass,0.8138462
kap,multiclass,0.8063132


In [17]:
chr_best_k_result2 <- predict(chr_best_k_wf, chr_testing) |>
    bind_cols(chr_testing) |>
    conf_mat(truth = label, estimate = .pred_class)

chr_best_k_result2

          Truth
Prediction  a  b  c  d  e  f  g  h  i  j  k  l  m  n  o  p  q  r  s  t  u  v  w
         a 16  0  0  0  0  0  0  1  0  0  0  0  2  0  0  0  0  1  1  0  0  0  0
         b  0 14  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
         c  0  0 25  0  2  0  1  0  0  0  0  2  0  0  0  0  0  2  2  0  0  0  0
         d  0  1  0 19  0  0  0  1  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0
         e  0  1  1  0 21  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
         f  0  0  0  0  0 21  0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0
         g  0  2  1  0  0  0 21  0  0  0  0  0  0  0  0  0  1  1  0  0  0  0  0
         h  0  0  0  0  0  0  1 18  0  0  1  0  1  3  0  0  0  0  0  0  1  0  1
         i  0  1  0  0  0  0  0  0 18  2  0  0  0  0  0  0  0  0  0  0  0  0  0
         j  0  0  0  0  0  0  1  0  0 24  0  0  0  0  0  0  1  0  1  1  1  0  0
         k  1  1  0  0  0  0  0  0  0  0 17  0  0  0  0  0  0  1  0  0  0  0  0
         l  0  0  0  0  