# Tutorial: Classification I: training & predicting

This worksheet covers the [Classification I: training & predicting](https://datasciencebook.ca/classification1.html) chapter of the online textbook, which also lists the learning objectives for this worksheet. You should read the textbook chapter before attempting this worksheet. 

In [None]:
### Run this cell before continuing. 
library(tidyverse)
library(repr)
library(tidymodels)
options(repr.matrix.max.rows = 6)
source("cleanup.R") 

**Question 0.1** Multiple Choice: 
<br> {points: 1}

Before applying the K-nearest neighbours algorithm to a classification task, we need to scale the data. What is the purpose of this step?

A. To help speed up the K-nearest neighbours algorithm.

B. To convert all data observations to numeric values. 

C. To ensure all data observations will be on a comparable scale and contribute equal shares to the calculation of the distance between points.

D. None of the above. 

*Assign your answer to an object called `answer0.1`. Make sure your answer is an uppercase letter. Surround your answer with quotation marks (e.g. `"F"`).*

*Note: we typically **standardize** (i.e., scale **and** center) the data before doing classification. For the K-nearest neighbour algorithm specifically, centering has no effect. But it doesn't hurt, and can help with other predictive data analyses, so we will do it below.*
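
The claim about centering can be checked with a small sketch. The data frame `toy` below is made up for illustration (it is not the fruit dataset). The code compares pairwise Euclidean distances on the raw, centered-only, and fully standardized data, using base R's `scale()` and `dist()`.

```r
# Hypothetical toy data, not taken from the fruit dataset
toy <- data.frame(mass = c(100, 150, 200), width = c(6.0, 7.5, 7.0))

# Pairwise Euclidean distances on the raw, centered-only, and standardized data
raw_dist <- as.numeric(dist(toy))
centered_dist <- as.numeric(dist(scale(toy, center = TRUE, scale = FALSE)))
standardized_dist <- as.numeric(dist(scale(toy, center = TRUE, scale = TRUE)))

# Centering shifts every point by the same amount, so distances do not change
isTRUE(all.equal(raw_dist, centered_dist))

# Scaling changes how much each variable contributes, so distances do change
isTRUE(all.equal(raw_dist, standardized_dist))
```

In this sketch the first comparison should come out `TRUE` and the second `FALSE`, which is why only the scaling part of standardization matters for distance-based methods.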

In [None]:
# Replace the fail() with your answer. 

# your code here
fail() # No Answer - remove if you provide an answer

In [None]:
library(digest)
stopifnot("type of answer0.1 is not character"= setequal(digest(paste(toString(class(answer0.1)), "d61f1")), "f94d540efe5be545d1314ed7fc4cef50"))
stopifnot("length of answer0.1 is not correct"= setequal(digest(paste(toString(length(answer0.1)), "d61f1")), "ae230260b0d96065a2fa62385253b54f"))
stopifnot("value of answer0.1 is not correct"= setequal(digest(paste(toString(tolower(answer0.1)), "d61f1")), "6c38fdd008504ba826aad69718a6df26"))
stopifnot("letters in string value of answer0.1 are correct but case is not correct"= setequal(digest(paste(toString(answer0.1), "d61f1")), "b7d893c6d594e358e4eb8b1d6dcb0a91"))

print('Success!')

## 1. Fruit Data Example 

In the agricultural industry, cleaning, sorting, grading, and packaging food products are all necessary tasks in the post-harvest process. Products are classified based on appearance, size, and shape, attributes that help determine the quality of the food. Sorting can be done by humans, but it is tedious and time-consuming. Automatic sorting could help save time and money. Images of the food products are captured and analysed to determine visual characteristics. 

The [dataset](https://www.kaggle.com/mjamilmoughal/k-nearest-neighbor-classifier-to-predict-fruits/notebook) contains observations of fruit described with four features: 1) mass (in g), 2) width (in cm), 3) height (in cm), and 4) color score (on a scale from 0 to 1).

**Question 1.0** 
<br> {points: 1}

Load the file, `fruit_data.csv`, into your notebook. 

`mutate()` the `fruit_name` column such that it is a *factor* using the `as_factor()` function.

*Assign your data to an object called `fruit_data`.*
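
As a minimal sketch of the factor conversion, the example below uses a made-up tibble named `toy` with a hypothetical character column `kind` (these names are not from the fruit dataset). It shows how `mutate()` and `as_factor()` work together.

```r
library(tidyverse)

# Hypothetical toy tibble; `kind` starts out as a character column
toy <- tibble(kind = c("pear", "plum", "pear"), weight = c(170, 60, 180))

# Overwrite the column with a factor version of itself
toy <- toy |>
    mutate(kind = as_factor(kind))

class(toy$kind)   # "factor"
levels(toy$kind)  # levels follow order of first appearance: "pear" "plum"
```

Storing the class variable as a factor tells modelling functions that it holds categories rather than free text.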

In [None]:
# your code here
fail() # No Answer - remove if you provide an answer

In [None]:
library(digest)
stopifnot("fruit_data should be a data frame"= setequal(digest(paste(toString('data.frame' %in% class(fruit_data)), "6764d")), "f175532894bb9421209030becaeadcb2"))
stopifnot("dimensions of fruit_data are not correct"= setequal(digest(paste(toString(dim(fruit_data)), "6764d")), "fe1754c7fb3f3b81dc2837de63ebc404"))
stopifnot("column names of fruit_data are not correct"= setequal(digest(paste(toString(sort(colnames(fruit_data))), "6764d")), "83a8ab9d1385dbd239f8a661ab3cd20c"))
stopifnot("types of columns in fruit_data are not correct"= setequal(digest(paste(toString(sort(unlist(sapply(fruit_data, class)))), "6764d")), "72e73781ee2f973c18613d10eb11a15b"))
stopifnot("values in one or more numerical columns in fruit_data are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_data, is.numeric))) sort(round(sapply(fruit_data[, sapply(fruit_data, is.numeric)], sum, na.rm = TRUE), 2)) else 0), "6764d")), "ee4c9449a0b7b1f9268393ee12a27311"))
stopifnot("values in one or more character columns in fruit_data are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_data, is.character))) sum(sapply(fruit_data[sapply(fruit_data, is.character)], function(x) length(unique(x)))) else 0), "6764d")), "21be8009c291e4cb872b233a9f73bac4"))
stopifnot("values in one or more factor columns in fruit_data are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_data, is.factor))) sum(sapply(fruit_data[, sapply(fruit_data, is.factor)], function(col) length(unique(col)))) else 0), "6764d")), "8d8a9f99753534ac09b71968bff2e312"))

print('Success!')

Let's take a look at the first few observations in the fruit dataset. Run the cell below.

In [None]:
# Run this cell. 
fruit_data

**Question 1.0.1** Multiple Choice:
<br> {points: 1}

**Which of the columns should we treat as categorical variables?**

A. Fruit label, width, fruit subtype

B. Fruit name, color score, height

C. Fruit label, fruit subtype, fruit name

D. Color score, mass, width 

*Assign your answer to an object called `answer1.0.1`. Make sure your answer is an uppercase letter. Remember to surround your answer with quotation marks (e.g. `"E"`).*

In [None]:
# Replace the fail() with your answer. 

# your code here
fail() # No Answer - remove if you provide an answer

In [None]:
library(digest)
stopifnot("type of answer1.0.1 is not character"= setequal(digest(paste(toString(class(answer1.0.1)), "3c2da")), "b47c76f4f1cb4e5fbff4a23185a46ccc"))
stopifnot("length of answer1.0.1 is not correct"= setequal(digest(paste(toString(length(answer1.0.1)), "3c2da")), "82fe67f7a4e5ce7e7925b0c0e7b46740"))
stopifnot("value of answer1.0.1 is not correct"= setequal(digest(paste(toString(tolower(answer1.0.1)), "3c2da")), "0b7234c113d259ec6ac8fc77a142c4d2"))
stopifnot("letters in string value of answer1.0.1 are correct but case is not correct"= setequal(digest(paste(toString(answer1.0.1), "3c2da")), "3fe5105ef7f249d26137ee36a4a7d8d1"))

print('Success!')

Run the cell below, and find the nearest neighbour based on mass and width to the first observation just by looking at the scatterplot (the first observation has been circled for you).

In [None]:
# Run this cell. 
options(repr.plot.width=10, repr.plot.height=7)
point1 <- c(192, 8.4)
point2 <- c(180, 8)
point44 <- c(194, 7.2)

fruit_data |>
    ggplot(aes(x = mass,
               y = width,
               colour = fruit_name)) +
        labs(x = "Mass (grams)",
             y = "Width (cm)",
             colour = "Name of the Fruit") +
        geom_point(size = 2.5) +
        annotate("path",
                 x = point1[1] + 5 * cos(seq(0, 2 * pi, length.out = 100)),
                 y = point1[2] + 0.1 * sin(seq(0, 2 * pi, length.out = 100))) +
        annotate("text", x = 183, y = 8.5, label = "1", size = 8) +
        theme(text = element_text(size = 20))

**Question 1.1** Multiple Choice: 
<br> {points: 1}

Based on the graph generated, what is the `fruit_name` of the closest data point to the one circled?

A. apple

B. lemon

C. mandarin 

D. orange

*Assign your answer to an object called `answer1.1`. Make sure your answer is an uppercase letter. Surround your answer with quotation marks (e.g. `"F"`).*

In [None]:
# Replace the fail() with your answer. 

# your code here
fail() # No Answer - remove if you provide an answer

In [None]:
library(digest)
stopifnot("type of answer1.1 is not character"= setequal(digest(paste(toString(class(answer1.1)), "ec118")), "e018a26824269560efa6a7d3a78d3b40"))
stopifnot("length of answer1.1 is not correct"= setequal(digest(paste(toString(length(answer1.1)), "ec118")), "c80a712371da11395e42dde4098a6031"))
stopifnot("value of answer1.1 is not correct"= setequal(digest(paste(toString(tolower(answer1.1)), "ec118")), "8bf8f319c1928cf610a2af9eb67b2d89"))
stopifnot("letters in string value of answer1.1 are correct but case is not correct"= setequal(digest(paste(toString(answer1.1), "ec118")), "b30b411f054d6f0fca2466b0bb5bd880"))

print('Success!')

**Question 1.2**
<br> {points: 1}

Using mass and width, calculate the distance between the first observation and the second observation. 

We provide scaffolding to get you started. 

*Assign your answer to an object called `fruit_dist_2`.*
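
For reference, here is a minimal sketch of what `dist()` computes by default: the straight-line (Euclidean) distance, i.e. the square root of the sum of squared differences across variables. The two points `a` and `b` are hypothetical and chosen only so the arithmetic is easy to check.

```r
# Two hypothetical points with two variables each
a <- c(3, 4)
b <- c(0, 0)

# Straight-line (Euclidean) distance computed by hand
by_hand <- sqrt(sum((a - b)^2))

# The same distance via dist(), which expects one observation per row
via_dist <- dist(rbind(a, b))

by_hand               # 5
as.numeric(via_dist)  # 5
```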

In [None]:
# ... <- fruit_data |>
#    slice(1, 2) |> # We use slice to get the first two rows of the fruit dataset
#    select(mass, ...) |>
#    dist()

# your code here
fail() # No Answer - remove if you provide an answer
fruit_dist_2

In [None]:
library(digest)
stopifnot("type of fruit_dist_2[1] is not numeric"= setequal(digest(paste(toString(class(fruit_dist_2[1])), "deafd")), "9ee66dd971961cb9e1eae2aed76b5c2f"))
stopifnot("value of fruit_dist_2[1] is not correct (rounded to 2 decimal places)"= setequal(digest(paste(toString(round(fruit_dist_2[1], 2)), "deafd")), "b8dea29e2c5fa0e1c210605c127b5348"))
stopifnot("length of fruit_dist_2[1] is not correct"= setequal(digest(paste(toString(length(fruit_dist_2[1])), "deafd")), "6f8b9d71d4e86390ee5f0c93753bc718"))
stopifnot("values of fruit_dist_2[1] are not correct"= setequal(digest(paste(toString(sort(round(fruit_dist_2[1], 2))), "deafd")), "b8dea29e2c5fa0e1c210605c127b5348"))

print('Success!')

**Question 1.3**
<br> {points: 1}

Calculate the distance between the first and the 44th observation in the fruit dataset using the mass and width variables. 

*Assign your answer to an object called `fruit_dist_44`.*

In [None]:
# your code here
fail() # No Answer - remove if you provide an answer
fruit_dist_44

In [None]:
library(digest)
stopifnot("type of fruit_dist_44[1] is not numeric"= setequal(digest(paste(toString(class(fruit_dist_44[1])), "35807")), "28c2a99e7a16848406b583e7bac863ac"))
stopifnot("value of fruit_dist_44[1] is not correct (rounded to 2 decimal places)"= setequal(digest(paste(toString(round(fruit_dist_44[1], 2)), "35807")), "f61a6e4addbfb132f867fb134bfa390b"))
stopifnot("length of fruit_dist_44[1] is not correct"= setequal(digest(paste(toString(length(fruit_dist_44[1])), "35807")), "04f6e44080aac266a698a56ea41744fc"))
stopifnot("values of fruit_dist_44[1] are not correct"= setequal(digest(paste(toString(sort(round(fruit_dist_44[1], 2))), "35807")), "f61a6e4addbfb132f867fb134bfa390b"))

print('Success!')

Let's circle these three observations on the plot from earlier.


In [None]:
# Run this cell. 
options(repr.plot.width = 10, repr.plot.height = 7)

point1 <- c(192, 8.4)
point2 <- c(180, 8)
point44 <- c(194, 7.2)

fruit_data |>
    ggplot(aes(x = mass,
               y = width,
               colour = fruit_name)) +
        labs(x = "Mass (grams)",
             y = "Width (cm)",
             colour = "Name of the Fruit") +
        geom_point(size = 2.5) +
        theme(text = element_text(size = 20)) +
        annotate("path",
                 x = point1[1] + 5 * cos(seq(0, 2 * pi, length.out = 100)),
                 y = point1[2] + 0.1 * sin(seq(0, 2 * pi, length.out = 100))) +
        annotate("text", x = 183, y = 8.5, label = "1", size = 8) +
        annotate("path",
                 x = point2[1] + 5 * cos(seq(0, 2 * pi, length.out = 100)),
                 y = point2[2] + 0.1 * sin(seq(0, 2 * pi, length.out = 100))) +
        annotate("text", x = 169, y = 8.1, label = "2", size = 8) +
        annotate("path",
                 x = point44[1] + 5 * cos(seq(0, 2 * pi, length.out = 100)),
                 y = point44[2] + 0.1 * sin(seq(0, 2 * pi, length.out = 100))) +
        annotate("text", x = 204, y = 7.1, label = "44", size = 8)

What do you notice about your answers from **Question 1.2 & 1.3** that you just calculated? Is it what you would expect given the scatter plot above? Why or why not? Discuss with your neighbour. 

*Hint: Look at where the observations are on the scatterplot in the cell above this question. What might happen if we measured mass in kilograms instead of grams?*
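
One way to explore the effect of units is sketched below. It reuses the mass and width of the three circled observations from the plotting cell above; the object names (`points_g`, `points_kg`, and so on) are made up for this illustration. The code computes the distances from observation 1 to the other two, first with mass in grams and then with mass in kilograms.

```r
# Mass (g) and width (cm) of the three circled observations, as plotted above
points_g <- rbind(obs1 = c(192, 8.4), obs2 = c(180, 8), obs44 = c(194, 7.2))

# Same observations with mass converted to kilograms
points_kg <- points_g
points_kg[, 1] <- points_kg[, 1] / 1000

# Distances from observation 1 to the other two, under each choice of unit
dist_g <- as.matrix(dist(points_g))["obs1", c("obs2", "obs44")]
dist_kg <- as.matrix(dist(points_kg))["obs1", c("obs2", "obs44")]

names(which.min(dist_g))   # "obs44"
names(which.min(dist_kg))  # "obs2"
```

The nearest neighbour flips purely because of the unit chosen for mass, even though the fruit themselves have not changed.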


**Question 1.4** Multiple Choice:
<br> {points: 1}

The distance between the first and second observations is 12.01, and the distance between the first and 44th observations is 2.33. By the formula, observations 1 and 44 are closer. However, on the scatterplot, the first observation appears closer to the second observation than to the 44th. 

Which of the following statements is correct?

A. A difference of 12 g in mass between observation 1 and 2 is large compared to a difference of 1.2 cm in width between observation 1 and 44. Consequently, mass will drive the classification results, and width will have less of an effect. 

B. If we measured mass in kilograms, then we’d get different nearest neighbours.

C. We should standardize the data so that all variables will be on a comparable scale. 

D. All of the above. 

*Assign your answer to an object called `answer1.4`. Make sure your answer is an uppercase letter. Surround your answer with quotation marks (e.g. `"F"`).*

In [None]:
# Replace the fail() with your answer. 

# your code here
fail() # No Answer - remove if you provide an answer

In [None]:
library(digest)
stopifnot("type of answer1.4 is not character"= setequal(digest(paste(toString(class(answer1.4)), "505dc")), "a7b2e369a8e216819f6293c05cf59bfc"))
stopifnot("length of answer1.4 is not correct"= setequal(digest(paste(toString(length(answer1.4)), "505dc")), "b2603ce5a8bb0ffe75b1722120ee0ab0"))
stopifnot("value of answer1.4 is not correct"= setequal(digest(paste(toString(tolower(answer1.4)), "505dc")), "0f575034c94b70bc98758ad6484a0c1e"))
stopifnot("letters in string value of answer1.4 are correct but case is not correct"= setequal(digest(paste(toString(answer1.4), "505dc")), "ba394fcaecde028cf11d8bf3d73e61e0"))

print('Success!')

**Question 1.5**
<br> {points: 1}

Let's create a `tidymodels` recipe to *standardize* (i.e., center and scale) the predictors in the fruit dataset. Centering will make sure that every variable has an average of 0, and scaling will make sure that every variable has a standard deviation of 1. We will use the `step_scale` and `step_center` preprocessing steps in the recipe. Then `prep` and `bake` the recipe so that we can examine the output.

Specify your recipe with class variable `fruit_name` and predictors `mass`, `width`, `height`, and `color_score`. 

Name the recipe `fruit_data_recipe`, and name the preprocessed data `fruit_data_scaled`.
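
The general pattern can be sketched on made-up data. The tibble `toy` and its columns `label`, `x1`, and `x2` are hypothetical and stand in for a class variable and two predictors on very different scales; this is an illustration of the recipe steps, not the worksheet solution.

```r
library(tidyverse)
library(tidymodels)

# Hypothetical toy data: two numeric predictors on very different scales
toy <- tibble(
    label = factor(c("a", "a", "b", "b")),
    x1 = c(100, 150, 200, 250),
    x2 = c(0.1, 0.4, 0.2, 0.3)
)

toy_recipe <- recipe(label ~ x1 + x2, data = toy) |>
    step_scale(all_predictors()) |>
    step_center(all_predictors())

# prep() estimates the means and standard deviations; bake() applies them
toy_scaled <- toy_recipe |>
    prep() |>
    bake(new_data = toy)

# Each standardized predictor should have mean ~0 and standard deviation ~1
toy_scaled |>
    summarize(across(c(x1, x2), list(mean = mean, sd = sd)))
```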

In [None]:
# Set the seed. Don't remove this!
set.seed(9999) 

#... <- ...(fruit_name ~ .... + .... + .... + ...., data = ....) |>
#                        ....(all_predictors()) |>
#                        ....(all_predictors())

#... <- fruit_data_recipe |>  
#                            ....() |> 
#                            ....(fruit_data)


# your code here
fail() # No Answer - remove if you provide an answer
fruit_data_scaled

In [None]:
library(digest)
stopifnot("fruit_data_recipe should be a recipe"= setequal(digest(paste(toString('recipe' %in% class(fruit_data_recipe)), "9bce6")), "5d82a247ddbab576493febde7018be2a"))
stopifnot("response variable of fruit_data_recipe is not correct"= setequal(digest(paste(toString(sort(filter(fruit_data_recipe$var_info, role == 'outcome')$variable)), "9bce6")), "c219dbca8bbe803676b318735df97471"))
stopifnot("predictor variable(s) of fruit_data_recipe are not correct"= setequal(digest(paste(toString(sort(filter(fruit_data_recipe$var_info, role == 'predictor')$variable)), "9bce6")), "ab9921e4e51f0a0392a6df8e7e9fe1b7"))
stopifnot("fruit_data_recipe does not contain the correct data, might need to be standardized"= setequal(digest(paste(toString(round(sum(bake(prep(fruit_data_recipe), fruit_data_recipe$template) %>% select_if(is.numeric), na.rm = TRUE), 2)), "9bce6")), "3fcab62bbf9763a1aa7a982ead684d6e"))

stopifnot("fruit_data_scaled should be a data frame"= setequal(digest(paste(toString('data.frame' %in% class(fruit_data_scaled)), "9bce7")), "e5a4930da601f01adcf1967974813330"))
stopifnot("dimensions of fruit_data_scaled are not correct"= setequal(digest(paste(toString(dim(fruit_data_scaled)), "9bce7")), "62491514c2edcc12b111560592364040"))
stopifnot("column names of fruit_data_scaled are not correct"= setequal(digest(paste(toString(sort(colnames(fruit_data_scaled))), "9bce7")), "40c1c6dff4ab2ae4644907d8695a578b"))
stopifnot("types of columns in fruit_data_scaled are not correct"= setequal(digest(paste(toString(sort(unlist(sapply(fruit_data_scaled, class)))), "9bce7")), "654a694251224e1d66db241a2fef982d"))
stopifnot("values in one or more numerical columns in fruit_data_scaled are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_data_scaled, is.numeric))) sort(round(sapply(fruit_data_scaled[, sapply(fruit_data_scaled, is.numeric)], sum, na.rm = TRUE), 2)) else 0), "9bce7")), "8e5c82489a1464131549da8c4fc26647"))
stopifnot("values in one or more character columns in fruit_data_scaled are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_data_scaled, is.character))) sum(sapply(fruit_data_scaled[sapply(fruit_data_scaled, is.character)], function(x) length(unique(x)))) else 0), "9bce7")), "e14335c0822091f08f541c2380dc2beb"))
stopifnot("values in one or more factor columns in fruit_data_scaled are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_data_scaled, is.factor))) sum(sapply(fruit_data_scaled[, sapply(fruit_data_scaled, is.factor)], function(col) length(unique(col)))) else 0), "9bce7")), "b39c769e7f2c1b1eaba8ac5a89f31568"))

print('Success!')

**Question 1.6**
<br> {points: 1}

Let's repeat **Question 1.2 and 1.3** with the scaled variables:

- calculate the distance with the scaled mass and width variables between observations 1 and 2
- calculate the distance with the scaled mass and width variables between observations 1 and 44 

After you do this, think about how these distances compare to the distances you computed in **Question 1.2 and 1.3** for the same points.

*Assign your answers to objects called `distance_2` and `distance_44` respectively.*

In [None]:
# your code here
fail() # No Answer - remove if you provide an answer
distance_2
distance_44

In [None]:
library(digest)
stopifnot("type of distance_2[1] is not numeric"= setequal(digest(paste(toString(class(distance_2[1])), "43ed9")), "34dd1c1a80b045d782ed62baadd6eec6"))
stopifnot("value of distance_2[1] is not correct (rounded to 2 decimal places)"= setequal(digest(paste(toString(round(distance_2[1], 2)), "43ed9")), "6710073eab5e5f72ba5aa99c7c283579"))
stopifnot("length of distance_2[1] is not correct"= setequal(digest(paste(toString(length(distance_2[1])), "43ed9")), "8755a22873f91e07a3fbee190fa3a06a"))
stopifnot("values of distance_2[1] are not correct"= setequal(digest(paste(toString(sort(round(distance_2[1], 2))), "43ed9")), "6710073eab5e5f72ba5aa99c7c283579"))

stopifnot("type of distance_44[1] is not numeric"= setequal(digest(paste(toString(class(distance_44[1])), "43eda")), "10688372c13c6c7154848a996387b01f"))
stopifnot("value of distance_44[1] is not correct (rounded to 2 decimal places)"= setequal(digest(paste(toString(round(distance_44[1], 2)), "43eda")), "debd034b3644cf8b7b069285c1231577"))
stopifnot("length of distance_44[1] is not correct"= setequal(digest(paste(toString(length(distance_44[1])), "43eda")), "9c1b1cbc2c825cc0a5f6425708757471"))
stopifnot("values of distance_44[1] are not correct"= setequal(digest(paste(toString(sort(round(distance_44[1], 2))), "43eda")), "debd034b3644cf8b7b069285c1231577"))

print('Success!')

**Question 1.7**
<br> {points: 1}

Make a scatterplot of scaled mass on the horizontal axis and scaled color score on the vertical axis. Color the points by fruit name. 

*Assign your plot to an object called `fruit_plot`. Make sure to follow the usual guidelines for an effective visualization (e.g., human-readable axis and legend labels).*

In [None]:
# your code here
fail() # No Answer - remove if you provide an answer
fruit_plot

In [None]:
library(digest)
stopifnot("type of plot is not correct (if you are using two types of geoms, try flipping the order of the geom objects!)"= setequal(digest(paste(toString(sapply(seq_len(length(fruit_plot$layers)), function(i) {c(class(fruit_plot$layers[[i]]$geom))[1]})), "60569")), "94ec2a4ac8e46afef12c7401f1e48079"))
stopifnot("variable x is not correct"= setequal(digest(paste(toString(unlist(lapply(sapply(seq_len(length(fruit_plot$layers)), function(i) {rlang::get_expr(c(fruit_plot$layers[[i]]$mapping, fruit_plot$mapping)$x)}), as.character))), "60569")), "04a9bb19408098165ae47c038da6df7e"))
stopifnot("variable y is not correct"= setequal(digest(paste(toString(unlist(lapply(sapply(seq_len(length(fruit_plot$layers)), function(i) {rlang::get_expr(c(fruit_plot$layers[[i]]$mapping, fruit_plot$mapping)$y)}), as.character))), "60569")), "30021d95e370a243ba03f04f532db79c"))
stopifnot("x-axis label is not descriptive, nicely formatted, or human readable"= setequal(digest(paste(toString(rlang::get_expr(c(fruit_plot$layers[[1]]$mapping, fruit_plot$mapping)$x)!= fruit_plot$labels$x), "60569")), "b56993311da6bef2f2d43a9d8b607927"))
stopifnot("y-axis label is not descriptive, nicely formatted, or human readable"= setequal(digest(paste(toString(rlang::get_expr(c(fruit_plot$layers[[1]]$mapping, fruit_plot$mapping)$y)!= fruit_plot$labels$y), "60569")), "b56993311da6bef2f2d43a9d8b607927"))
stopifnot("incorrect colour variable in fruit_plot, specify a correct one if required"= setequal(digest(paste(toString(rlang::get_expr(c(fruit_plot$layers[[1]]$mapping, fruit_plot$mapping)$colour)), "60569")), "58b688015cc8776f13c91ca153be74cf"))
stopifnot("incorrect shape variable in fruit_plot, specify a correct one if required"= setequal(digest(paste(toString(rlang::get_expr(c(fruit_plot$layers[[1]]$mapping, fruit_plot$mapping)$shape)), "60569")), "23865003f0a2648779d5ab229c052aca"))
stopifnot("the colour label in fruit_plot is not descriptive, nicely formatted, or human readable"= setequal(digest(paste(toString(rlang::get_expr(c(fruit_plot$layers[[1]]$mapping, fruit_plot$mapping)$colour) != fruit_plot$labels$colour), "60569")), "b56993311da6bef2f2d43a9d8b607927"))
stopifnot("the shape label in fruit_plot is not descriptive, nicely formatted, or human readable"= setequal(digest(paste(toString(rlang::get_expr(c(fruit_plot$layers[[1]]$mapping, fruit_plot$mapping)$colour) != fruit_plot$labels$shape), "60569")), "23865003f0a2648779d5ab229c052aca"))
stopifnot("fill variable in fruit_plot is not correct"= setequal(digest(paste(toString(quo_name(fruit_plot$mapping$fill)), "60569")), "351a6933e9d53892c5f62983d0a3d3ba"))
stopifnot("fill label in fruit_plot is not informative"= setequal(digest(paste(toString((quo_name(fruit_plot$mapping$fill) != fruit_plot$labels$fill)), "60569")), "23865003f0a2648779d5ab229c052aca"))
stopifnot("position argument in fruit_plot is not correct"= setequal(digest(paste(toString(class(fruit_plot$layers[[1]]$position)[1]), "60569")), "416bdb7085b3d869602fbf2d3a582658"))

stopifnot("fruit_plot$data should be a data frame"= setequal(digest(paste(toString('data.frame' %in% class(fruit_plot$data)), "6056a")), "fb6f6e6f202cf1324f609dd5709119db"))
stopifnot("dimensions of fruit_plot$data are not correct"= setequal(digest(paste(toString(dim(fruit_plot$data)), "6056a")), "31954ec21b8e10831f8cb4aeea0d5f96"))
stopifnot("column names of fruit_plot$data are not correct"= setequal(digest(paste(toString(sort(colnames(fruit_plot$data))), "6056a")), "f8807721570bed893c429a6f18841bb2"))
stopifnot("types of columns in fruit_plot$data are not correct"= setequal(digest(paste(toString(sort(unlist(sapply(fruit_plot$data, class)))), "6056a")), "9a0b02612ebfccceb01dd9b482a1671c"))
stopifnot("values in one or more numerical columns in fruit_plot$data are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_plot$data, is.numeric))) sort(round(sapply(fruit_plot$data[, sapply(fruit_plot$data, is.numeric)], sum, na.rm = TRUE), 2)) else 0), "6056a")), "3161af34c584bcc0d5064cc52a36db3a"))
stopifnot("values in one or more character columns in fruit_plot$data are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_plot$data, is.character))) sum(sapply(fruit_plot$data[sapply(fruit_plot$data, is.character)], function(x) length(unique(x)))) else 0), "6056a")), "2e93ff75226ae8875f0c34fb133d8efb"))
stopifnot("values in one or more factor columns in fruit_plot$data are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_plot$data, is.factor))) sum(sapply(fruit_plot$data[, sapply(fruit_plot$data, is.factor)], function(col) length(unique(col)))) else 0), "6056a")), "1e5ac3f515965e5b6001da33c2b88965"))

print('Success!')

**Question 1.8** 
<br> {points: 3}

Suppose we have a new observation in the fruit dataset with scaled mass 0.5 and scaled color score 0.5.

Just by looking at the scatterplot, how would you classify this observation using K-nearest neighbours if you use K = 3? Explain how you arrived at your answer.
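
To make the K = 3 voting procedure concrete, here is a minimal sketch on made-up data. The tibble `toy`, its columns, and the new point are all hypothetical and unrelated to the fruit plot. The code finds the three closest points to the new observation and takes a majority vote over their labels.

```r
library(tidyverse)

# Hypothetical labelled points (already on a comparable scale)
toy <- tibble(
    x = c(0.1, 0.4, 0.6, 2.0, 2.2),
    y = c(0.2, 0.5, 0.4, 2.1, 1.9),
    label = c("a", "a", "b", "b", "b")
)

# Hypothetical new observation to classify
new_x <- 0.5
new_y <- 0.5

# Step 1: distance from the new point to every labelled point
# Step 2: keep the 3 closest points
neighbours <- toy |>
    mutate(distance = sqrt((x - new_x)^2 + (y - new_y)^2)) |>
    slice_min(distance, n = 3)

# Step 3: majority vote among the 3 neighbours
vote <- neighbours |>
    count(label) |>
    slice_max(n, n = 1) |>
    pull(label)

vote  # "a" (two of the three nearest neighbours are labelled "a")
```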

DOUBLE CLICK TO EDIT **THIS CELL** AND REPLACE THIS TEXT WITH YOUR ANSWER.

**Question 1.9**
<br> {points: 1}

Now, let's use the `tidymodels` package to predict `fruit_name` for another new observation. The new observation we are interested in has a mass of 150 g and a color score of 0.73.

First, create the K-nearest neighbour model specification. Specify that we want $K=5$ neighbors, `set_engine` to be `"kknn"`, and that each neighboring point should have the same weight when voting. Name this model specification `knn_spec`.

Then create a new recipe named `fruit_data_recipe_2` that centers and scales the predictors, but only uses `mass` and `color_score` as predictors.

Combine the model specification and this new recipe in a `workflow`, and fit it to the `fruit_data` dataset. 

Name the fitted model `fruit_fit`.
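
The overall shape of a model specification, recipe, and workflow can be sketched on R's built-in `iris` dataset. The names `toy_spec`, `toy_recipe`, and `toy_fit`, the choice of predictors, and K = 3 are assumptions made for this illustration, not the worksheet solution. The sketch assumes the `kknn` package is installed.

```r
library(tidymodels)

# Model specification: K = 3 neighbours, each with an equal vote ("rectangular")
toy_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = 3) |>
    set_engine("kknn") |>
    set_mode("classification")

# Recipe: standardize the two chosen predictors
toy_recipe <- recipe(Species ~ Sepal.Length + Sepal.Width, data = iris) |>
    step_scale(all_predictors()) |>
    step_center(all_predictors())

# Workflow: bundle the recipe and the model, then fit on the data
toy_fit <- workflow() |>
    add_recipe(toy_recipe) |>
    add_model(toy_spec) |>
    fit(data = iris)

toy_fit
```

Bundling the recipe into the workflow means the same standardization is applied automatically whenever the fitted object is later used for prediction.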

In [None]:
# Set the seed. Don't remove this!
set.seed(9999) 

#... <- nearest_neighbor(weight_func = ..., neighbors = ...) |>
#       ...(...) |>
#       ...(...)

#... <- recipe(... ~ ... + ..., data = fruit_data) |>
#                        ...(...) |>
#                        ...(...)

#... <- ...() |>
#          ...(...) |>
#          ...(...) |>
#          fit(data = ...)


# your code here
fail() # No Answer - remove if you provide an answer

In [None]:
library(digest)
stopifnot("knn_spec should be a model specification"= setequal(digest(paste(toString('model_spec' %in% class(knn_spec)), "9295a")), "d325e8a1486075a21464fe828a1d61d2"))
stopifnot("model specification in knn_spec is not correct"= setequal(digest(paste(toString(knn_spec$mode), "9295a")), "8d485021b126d9b2c98439cf4d171bb8"))
stopifnot("computational engine in knn_spec is not correct"= setequal(digest(paste(toString(knn_spec$engine), "9295a")), "f02416665573367711ad15a99432b855"))
stopifnot("weight function in knn_spec is not correct"= setequal(digest(paste(toString(quo_name(knn_spec$args$weight_func)), "9295a")), "dae33928330e022839c148ad6fdb1b58"))
stopifnot("number of neighbours in knn_spec is not correct"= setequal(digest(paste(toString(quo_name(knn_spec$args$neighbors)), "9295a")), "c882304a47c418ab56b5cc35b418598d"))

stopifnot("fruit_data_recipe_2 should be a recipe"= setequal(digest(paste(toString('recipe' %in% class(fruit_data_recipe_2)), "9295b")), "73f0ab1d2654a05e298adf7f42dedaa6"))
stopifnot("response variable of fruit_data_recipe_2 is not correct"= setequal(digest(paste(toString(sort(filter(fruit_data_recipe_2$var_info, role == 'outcome')$variable)), "9295b")), "84be31a307cbdc3f11ffda2e59026964"))
stopifnot("predictor variable(s) of fruit_data_recipe_2 are not correct"= setequal(digest(paste(toString(sort(filter(fruit_data_recipe_2$var_info, role == 'predictor')$variable)), "9295b")), "efce3e81237f1be10c0c514499a0fac1"))
stopifnot("fruit_data_recipe_2 does not contain the correct data, might need to be standardized"= setequal(digest(paste(toString(round(sum(bake(prep(fruit_data_recipe_2), fruit_data_recipe_2$template) %>% select_if(is.numeric), na.rm = TRUE), 2)), "9295b")), "b53d21f65bc5e3eae9d7be4896c341a7"))

stopifnot("fruit_fit should be a workflow"= setequal(digest(paste(toString('workflow' %in% class(fruit_fit)), "9295c")), "402b0dc729a54a147213ec3a4e1079f5"))
stopifnot("computational engine used in fruit_fit is not correct"= setequal(digest(paste(toString(fruit_fit$fit$actions$model$spec$engine), "9295c")), "9ca79d24339a522ae3044ebe1faa279a"))
stopifnot("model specification used in fruit_fit is not correct"= setequal(digest(paste(toString(fruit_fit$fit$actions$model$spec$mode), "9295c")), "2738b9a015dd79cb44965e6874453a7c"))
stopifnot("fruit_fit must be a trained workflow, make sure to call the fit() function"= setequal(digest(paste(toString(fruit_fit$trained), "9295c")), "402b0dc729a54a147213ec3a4e1079f5"))
stopifnot("predictor variable(s) of fruit_fit are not correct"= setequal(digest(paste(toString(sort(filter(fruit_fit$pre$actions$recipe$recipe$var_info, role == 'predictor')$variable)), "9295c")), "5552f70f91293ee363a33f6bf3336362"))
stopifnot("fruit_fit does not contain the correct data"= setequal(digest(paste(toString(sort(vapply(fruit_fit$pre$mold$predictors[, sapply(fruit_fit$pre$mold$predictors, is.numeric)], function(col) if(!is.null(col)) round(sum(col), 2) else NA_real_, numeric(1)), na.last = NA)), "9295c")), "8e2ddae1a2a099dd4a28a2fcb2a929a8"))
stopifnot("did not fit fruit_fit on the training dataset"= setequal(digest(paste(toString(nrow(fruit_fit$pre$mold$outcomes)), "9295c")), "4dcef0e662bc8d71676ab6a6bb713caf"))
stopifnot("for classification/regression models, weight function is not correct"= setequal(digest(paste(toString(quo_name(fruit_fit$fit$actions$model$spec$args$weight_func)), "9295c")), "2b89fe55646f1dbbf4f5c036405b2429"))
stopifnot("for classification/regression models, response variable of fruit_fit is not correct"= setequal(digest(paste(toString(sort(filter(fruit_fit$pre$actions$recipe$recipe$var_info, role == 'outcome')$variable)), "9295c")), "c883a18529fcd1e841993f91b02165ed"))
stopifnot("for KNN models, number of neighbours is not correct"= setequal(digest(paste(toString(quo_name(fruit_fit$fit$actions$model$spec$args$neighbors)), "9295c")), "5891bc82cddff45d93a7ecbbc0174bf8"))
stopifnot("for clustering models, the clustering is not correct"= setequal(digest(paste(toString(fruit_fit$fit$fit$fit$cluster), "9295c")), "35f06f364bc813d04ea85300403e701b"))
stopifnot("for clustering models, the total within-cluster sum-of-squared distances is not correct"= setequal(digest(paste(toString(if (!is.null(fruit_fit$fit$fit$fit$tot.withinss)) round(fruit_fit$fit$fit$fit$tot.withinss, 2) else NULL), "9295c")), "35f06f364bc813d04ea85300403e701b"))

print('Success!')

**Question 1.10**
<br> {points: 1}

Create a new tibble where `mass = 150` and `color_score = 0.73` and call it `new_fruit`. Then, pass `fruit_fit` and `new_fruit` to the `predict` function to predict the class for the new fruit observation. Save your prediction to an object named `fruit_predicted`.
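
As a hedged sketch of the prediction step, the example below fits a small K-nearest neighbours workflow on the built-in `iris` dataset and predicts the class of one new observation. The names `knn_fit`, `new_flower`, and `flower_predicted`, and the values in `new_flower`, are hypothetical; the sketch assumes the `kknn` package is installed.

```r
library(tidymodels)

# A small fitted workflow to predict from (recipe + model bundled together)
knn_fit <- workflow() |>
    add_recipe(
        recipe(Species ~ Sepal.Length + Sepal.Width, data = iris) |>
            step_scale(all_predictors()) |>
            step_center(all_predictors())
    ) |>
    add_model(
        nearest_neighbor(weight_func = "rectangular", neighbors = 3) |>
            set_engine("kknn") |>
            set_mode("classification")
    ) |>
    fit(data = iris)

# New observation in the ORIGINAL units; column names must match the predictors
new_flower <- tibble(Sepal.Length = 5.0, Sepal.Width = 3.5)

# predict() returns a tibble with one row per new observation
flower_predicted <- predict(knn_fit, new_flower)
flower_predicted
```

Because the recipe lives inside the workflow, the new observation is supplied in its original units and is standardized with the training data's means and standard deviations before the neighbours are found.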

In [None]:
# your code here
fail() # No Answer - remove if you provide an answer

In [None]:
library(digest)
stopifnot("new_fruit should be a data frame"= setequal(digest(paste(toString('data.frame' %in% class(new_fruit)), "aab5")), "cf144a819f45480177dd0b4b8500441f"))
stopifnot("dimensions of new_fruit are not correct"= setequal(digest(paste(toString(dim(new_fruit)), "aab5")), "91f415d9fb5387620589456c6e607f50"))
stopifnot("column names of new_fruit are not correct"= setequal(digest(paste(toString(sort(colnames(new_fruit))), "aab5")), "0d826d006e69eaa8d883fc696c08230c"))
stopifnot("types of columns in new_fruit are not correct"= setequal(digest(paste(toString(sort(unlist(sapply(new_fruit, class)))), "aab5")), "35477eb38d90db941ef1ac720168c45a"))
stopifnot("values in one or more numerical columns in new_fruit are not correct"= setequal(digest(paste(toString(if (any(sapply(new_fruit, is.numeric))) sort(round(sapply(new_fruit[, sapply(new_fruit, is.numeric)], sum, na.rm = TRUE), 2)) else 0), "aab5")), "e32c5ba41b7a27f770e97b2728994791"))
stopifnot("values in one or more character columns in new_fruit are not correct"= setequal(digest(paste(toString(if (any(sapply(new_fruit, is.character))) sum(sapply(new_fruit[sapply(new_fruit, is.character)], function(x) length(unique(x)))) else 0), "aab5")), "819c947572afc2fd3966e5f7ed77b233"))
stopifnot("values in one or more factor columns in new_fruit are not correct"= setequal(digest(paste(toString(if (any(sapply(new_fruit, is.factor))) sum(sapply(new_fruit[, sapply(new_fruit, is.factor)], function(col) length(unique(col)))) else 0), "aab5")), "819c947572afc2fd3966e5f7ed77b233"))

stopifnot("fruit_predicted should be a data frame"= setequal(digest(paste(toString('data.frame' %in% class(fruit_predicted)), "aab6")), "35526846fb5221c24808d98024bc9c47"))
stopifnot("dimensions of fruit_predicted are not correct"= setequal(digest(paste(toString(dim(fruit_predicted)), "aab6")), "3d489b777252ff787fc01898cae2a8fe"))
stopifnot("column names of fruit_predicted are not correct"= setequal(digest(paste(toString(sort(colnames(fruit_predicted))), "aab6")), "2909814820c0b2bdec363da09a3a52ea"))
stopifnot("types of columns in fruit_predicted are not correct"= setequal(digest(paste(toString(sort(unlist(sapply(fruit_predicted, class)))), "aab6")), "6646df19a76d931c97f7cfffd4a5274d"))
stopifnot("values in one or more numerical columns in fruit_predicted are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_predicted, is.numeric))) sort(round(sapply(fruit_predicted[, sapply(fruit_predicted, is.numeric)], sum, na.rm = TRUE), 2)) else 0), "aab6")), "cbac36c78d10ef788d2d1802564e382b"))
stopifnot("values in one or more character columns in fruit_predicted are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_predicted, is.character))) sum(sapply(fruit_predicted[sapply(fruit_predicted, is.character)], function(x) length(unique(x)))) else 0), "aab6")), "cbac36c78d10ef788d2d1802564e382b"))
stopifnot("values in one or more factor columns in fruit_predicted are not correct"= setequal(digest(paste(toString(if (any(sapply(fruit_predicted, is.factor))) sum(sapply(fruit_predicted[, sapply(fruit_predicted, is.factor)], function(col) length(unique(col)))) else 0), "aab6")), "cdf33cdb8985e5cebd00d83feee1cc2f"))

print('Success!')

**Question 1.11** 
<br> {points: 3}

Revisiting `fruit_plot` and considering the prediction given by K-nearest neighbours above, do you think the classification model did a "good" job predicting? Could you have done, or could you do, better? Given what we know thus far in the course, what might we want to do to help with tricky prediction cases such as this?

*You can use the code below to visualize the observation whose label we just tried to predict.*

In [None]:
options(repr.plot.width = 10, repr.plot.height = 7) # you can change the plot size 

fruit_plot + 
    geom_point(aes(x = -0.3, y = -0.4), color = "black", size = 4)

DOUBLE CLICK TO EDIT **THIS CELL** AND REPLACE THIS TEXT WITH YOUR ANSWER.

**Question 1.12**
<br> {points: 1}

Now do K-nearest neighbours classification again with the same data set, same K, and same new observation. However, this time, let's use **all the columns in the data set as predictors (except for the categorical `fruit_label` and `fruit_subtype` variables).**

We have provided the `new_fruit_all` data frame below, which encodes the predictors for our new observation. Your job is to use K-nearest neighbours to predict the class of this point. You can reuse the model specification you created earlier. 

*Assign your answer (the output of `predict`) to an object called `fruit_all_predicted`.*

In [None]:
# This is the new observation whose class label we want to predict
new_fruit_all <- tibble(mass = 150,
                        color_score = 0.73,
                        height = 10,
                        width = 6)


# no hints this time!

# your code here
fail() # No Answer - remove if you provide an answer
fruit_all_predicted

In [None]:
library(digest)
stopifnot("type of as.character(fruit_all_predicted$.pred_class) is not character"= setequal(digest(paste(toString(class(as.character(fruit_all_predicted$.pred_class))), "bd9c1")), "ca3f1dff25e5c4b9780551556458f2c7"))
stopifnot("length of as.character(fruit_all_predicted$.pred_class) is not correct"= setequal(digest(paste(toString(length(as.character(fruit_all_predicted$.pred_class))), "bd9c1")), "90e58a86992e919eab0edeaa4c612f7a"))
stopifnot("value of as.character(fruit_all_predicted$.pred_class) is not correct"= setequal(digest(paste(toString(tolower(as.character(fruit_all_predicted$.pred_class))), "bd9c1")), "98ba8259d3f692c01db77d3b7062eb46"))
stopifnot("letters in string value of as.character(fruit_all_predicted$.pred_class) are correct but case is not correct"= setequal(digest(paste(toString(as.character(fruit_all_predicted$.pred_class)), "bd9c1")), "98ba8259d3f692c01db77d3b7062eb46"))

print('Success!')

**Question 1.13** 
<br> {points: 3}

Did your second classification on the same data set with the same K change the prediction? If so, why do you think this happened?

DOUBLE CLICK TO EDIT **THIS CELL** AND REPLACE THIS TEXT WITH YOUR ANSWER.

## 2. Wheat Seed Dataset

X-ray images can be used to analyze and sort seeds. In [this data set](https://archive.ics.uci.edu/ml/datasets/seeds), we have 7 measurements from x-ray images from 3 varieties of wheat seeds (Kama, Rosa and Canadian). 

**Question 2.0**
<br> {points: 3}

Let's use `tidymodels` to perform K-nearest neighbours to classify the wheat variety of seeds. The data set is available here: https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt. **Read the data set directly from this URL using the `read_table()` function**, which is helpful when the columns are separated by one or more whitespace characters.

The seven measurements below were taken for each wheat kernel:
1. area $A$,
2. perimeter $P$,
3. compactness $C = 4\pi A / P^2$,
4. length of kernel,
5. width of kernel,
6. asymmetry coefficient,
7. length of kernel groove.
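
To get a feel for the compactness measurement in the list above, here is a small sketch that evaluates the formula for two simple shapes. The helper name `compactness` and the variables `circle_c` and `square_c` are made up for this illustration and are not part of the data set.

```r
# Compactness as defined in the list above: C = 4 * pi * A / P^2.
compactness <- function(area, perimeter) {
    4 * pi * area / perimeter^2
}

# A circle of radius 1 has area pi and perimeter 2 * pi, so C works out to 1.
circle_c <- compactness(area = pi, perimeter = 2 * pi)

# A 1 x 1 square has area 1 and perimeter 4, so C works out to pi / 4.
square_c <- compactness(area = 1, perimeter = 4)

circle_c
square_c
```

Under this definition a circle has compactness 1 and less round shapes have smaller values.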

The last column in the data set is the variety label. The mapping for the numbers to varieties is listed below:

- 1 == Kama
- 2 == Rosa
- 3 == Canadian

Use `tidymodels` with this data to perform K-nearest neighbours to classify the wheat variety of a new seed, `new_seed` (defined in the cell below), for which we have observed the seven x-ray image measurements listed above. Specify that we want $K = 5$ neighbours to perform the classification. Don't forget to perform any necessary preprocessing!

*Assign your answer to an object called `seed_predict`.*

Hints: 
- `colnames()` can be used to specify the column names of a data frame.
- the wheat variety column appears numerical, but you want it to be treated as categorical for this analysis, so `as_factor()` might be helpful.
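
The hints above can be sketched on a tiny made-up file. It is not the seeds data; the path `toy_path`, the data frame `toy`, and the column names `size`, `weight`, and `group` are all hypothetical. The sketch reads a header-less, whitespace-separated file, assigns column names, and converts a numeric label column to a factor.

```r
library(tidyverse)

# Write a tiny whitespace-separated file (made-up values) to a temporary path.
toy_path <- tempfile(fileext = ".txt")
writeLines(c("1.5  10.2   1",
             "2.5  11.9 2",
             "3.5   9.7   1"), toy_path)

# col_names = FALSE because the file has no header row;
# read_table() then assigns placeholder column names.
toy <- read_table(toy_path, col_names = FALSE)

# Replace the placeholder names with meaningful ones.
colnames(toy) <- c("size", "weight", "group")

# Treat the numeric label column as categorical.
toy <- toy |>
    mutate(group = as_factor(group))

toy
```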

In [None]:
# Set the seed. Don't remove this!
set.seed(9999) 

# This is the new observation to predict
new_seed <- tibble(area = 12.1,
                   perimeter = 14.2,
                   compactness = 0.9,
                   length = 4.9,
                   width = 2.8,
                   asymmetry_coefficient = 3.0,
                   groove_length = 5.1)

# your code here
fail() # No Answer - remove if you provide an answer

**Question 2.1** Multiple Choice:
<br> {points: 1}

What is the classification of the `new_seed` observation?

A. Kama

B. Rosa

C. Canadian

*Assign your answer to an object called `answer2.1`. Make sure your answer is in uppercase and is surrounded by quotation marks (e.g. `"F"`).*


In [None]:
# your code here
fail() # No Answer - remove if you provide an answer

In [None]:
library(digest)
stopifnot("type of answer2.1 is not character"= setequal(digest(paste(toString(class(answer2.1)), "aef4d")), "1354deb3a629eb26b07cb0d05cc95549"))
stopifnot("length of answer2.1 is not correct"= setequal(digest(paste(toString(length(answer2.1)), "aef4d")), "338b9169d38a74e9deb1549b98ed1611"))
stopifnot("value of answer2.1 is not correct"= setequal(digest(paste(toString(tolower(answer2.1)), "aef4d")), "a97bdd5997909488c1ff7f3510d320ab"))
stopifnot("letters in string value of answer2.1 are correct but case is not correct"= setequal(digest(paste(toString(answer2.1), "aef4d")), "d01e051b1c6abaa2e2d3c6afa267a22b"))

print('Success!')

In [None]:
source("cleanup.R")