# Tutorial 6: Classification

### Lecture and Tutorial Learning Goals:

After completing this week's lecture and tutorial work, you will be able to:

* Recognize situations where a simple classifier would be appropriate for making predictions.
* Explain the k-nearest neighbour classification algorithm.
* Interpret the output of a classifier.
* Compute, by hand, the distance between points when there are two explanatory variables/predictors.
* Describe what a training data set is and how it is used in classification.
* In a dataset with two explanatory variables/predictors, perform k-nearest neighbour classification in Python using `scikit-learn` to predict the class of a single new observation.

In [None]:
### Run this cell before continuing.
import random

import altair as alt
import pandas as pd
import numpy as np
import sklearn
from sklearn.compose import make_column_transformer
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

alt.data_transformers.disable_max_rows()

**Question 0.1** Multiple Choice: 
<br> {points: 1}

Before applying k-nearest neighbour to a classification task, we need to scale the data. What is the purpose of this step?

A. To help speed up the knn algorithm. 

B. To convert all data observations to numeric values. 

C. To ensure all data observations will be on a comparable scale and contribute equal shares to the calculation of the distance between points.

D. None of the above. 

*Assign your answer to an object called `answer0_1`. Make sure the correct answer is an uppercase letter. Surround your answer with quotation marks (e.g. `"F"`).*

*Note: we typically **standardize** (i.e., scale **and** center) the data before doing classification. For the K-nearest neighbour algorithm specifically, centering has no effect. But it doesn't hurt, and can help with other predictive data analyses, so we will do it below.*

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(answer0_1)).encode("utf-8")+b"446dfe89269717c0").hexdigest() == "150d32139df34a5ea7b21b18e4f49ac80e2a0c21", "type of answer0_1 is not str. answer0_1 should be an str"
assert sha1(str(len(answer0_1)).encode("utf-8")+b"446dfe89269717c0").hexdigest() == "6ea6ccdb0dec838e73504a53fb7e9d22be4c9d01", "length of answer0_1 is not correct"
assert sha1(str(answer0_1.lower()).encode("utf-8")+b"446dfe89269717c0").hexdigest() == "dbd2d53ea4dd66a6107f10cd20d93fee85352571", "value of answer0_1 is not correct"
assert sha1(str(answer0_1).encode("utf-8")+b"446dfe89269717c0").hexdigest() == "625a3b5dc049596396c42d114a1a57d6b532e701", "correct string value of answer0_1 but incorrect case of letters"

print('Success!')
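The note in Question 0.1 says that centering has no effect on K-nearest neighbour. Centering subtracts the same column mean from every row, so every point moves by the same amount and all pairwise Euclidean distances stay the same. Dividing by the standard deviation does change them. Here is a minimal sketch that checks both claims on a few made-up (mass, width) pairs. The values are illustrative and are not read from `fruit_data.csv`.

```python
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

# Made-up observations: mass (g) and width (cm)
points = np.array([[192.0, 8.4], [180.0, 8.0], [194.0, 7.2], [154.0, 7.1]])

# Centering: subtract each column's mean from every row
centered = points - points.mean(axis=0)

# Pairwise distances before and after centering
raw_dist = euclidean_distances(points)
centered_dist = euclidean_distances(centered)

# Equal up to floating-point rounding, so nearest neighbours are unchanged
print(np.allclose(raw_dist, centered_dist))  # True

# Scaling (dividing by each column's standard deviation) does change the distances
scaled_dist = euclidean_distances(centered / points.std(axis=0))
print(np.allclose(raw_dist, scaled_dist))  # False
```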

## 1. Fruit Data Example 

In the agricultural industry, cleaning, sorting, grading, and packaging food products are all necessary tasks in the post-harvest process. Products are classified based on appearance, size, and shape, which are attributes that help determine the quality of the food. Humans can do the sorting, but it is tedious and time-consuming. Automatic sorting could save time and money: images of the food products are captured and analysed to determine their visual characteristics.

The [dataset](https://www.kaggle.com/mjamilmoughal/k-nearest-neighbor-classifier-to-predict-fruits/notebook) contains observations of fruit described by four features: (1) mass (in g), (2) width (in cm), (3) height (in cm), and (4) color score (on a scale from 0 to 1). Each observation also records the fruit's label, name, and subtype.

**Question 1.0** 
<br> {points: 1}

Load the file, `fruit_data.csv`, into your notebook. 

*Assign your data to an object called `fruit_data`.*

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_data is None)).encode("utf-8")+b"f9e47e031b6e996e").hexdigest() == "8a76cfb19a4aefc72985b960ffabedb26f39d35d", "type of fruit_data is None is not bool. fruit_data is None should be a bool"
assert sha1(str(fruit_data is None).encode("utf-8")+b"f9e47e031b6e996e").hexdigest() == "2ffa809a3b7dd033ce1836d16916176daaf94cbb", "boolean value of fruit_data is None is not correct"

assert sha1(str(type(fruit_data.shape)).encode("utf-8")+b"f5e435f983057466").hexdigest() == "c63e2a141f6a8c861227e7dced76347252667e75", "type of fruit_data.shape is not tuple. fruit_data.shape should be a tuple"
assert sha1(str(len(fruit_data.shape)).encode("utf-8")+b"f5e435f983057466").hexdigest() == "ff85e2cb6e0a09eb0a90b797074b8c5b74ff88f5", "length of fruit_data.shape is not correct"
assert sha1(str(sorted(map(str, fruit_data.shape))).encode("utf-8")+b"f5e435f983057466").hexdigest() == "515b62ce68f845edaead7200219f91253b24fa22", "values of fruit_data.shape are not correct"
assert sha1(str(fruit_data.shape).encode("utf-8")+b"f5e435f983057466").hexdigest() == "910e8570605ce3ffb236104c46ec701eb25e1f3b", "order of elements of fruit_data.shape is not correct"

assert sha1(str(type(fruit_data.fruit_name.dtype)).encode("utf-8")+b"0419694a9f7e91b5").hexdigest() == "e701cd35cc504e09e0b38ade1e32edaa3c049ced", "type of fruit_data.fruit_name.dtype is not correct"
assert sha1(str(fruit_data.fruit_name.dtype).encode("utf-8")+b"0419694a9f7e91b5").hexdigest() == "1fa6b04330c2782bb9362fcbaba556b989a8c7be", "value of fruit_data.fruit_name.dtype is not correct"

assert sha1(str(type(fruit_data.fruit_name.unique())).encode("utf-8")+b"9e33f63bdc8e9854").hexdigest() == "4a8d4c28c7932113685a1c5515f11b9fedd501ba", "type of fruit_data.fruit_name.unique() is not correct"
assert sha1(str(fruit_data.fruit_name.unique()).encode("utf-8")+b"9e33f63bdc8e9854").hexdigest() == "5ba8b9e612d08b5dd8fdc0e0af53ab688443bdd6", "value of fruit_data.fruit_name.unique() is not correct"

assert sha1(str(type(fruit_data.mass.values)).encode("utf-8")+b"0570f4697871177a").hexdigest() == "4ecd5ba72c127179ea967c6574857a9341839bb0", "type of fruit_data.mass.values is not correct"
assert sha1(str(fruit_data.mass.values).encode("utf-8")+b"0570f4697871177a").hexdigest() == "ce8e21db1a6c4eac0519ce13c313f6a6254de6ad", "value of fruit_data.mass.values is not correct"

print('Success!')

Let's take a look at the first few observations in the fruit dataset. Run the cell below.

In [None]:
# Run this cell.
fruit_data.head()

**Question 1.0.1** Multiple Choice:
<br> {points: 1}

**Which of the columns should we treat as categorical variables?**

A. Fruit label, width, fruit subtype

B. Fruit name, color score, height

C. Fruit label, fruit subtype, fruit name

D. Color score, mass, width 

*Assign your answer to an object called `answer1_0_1`. Make sure the correct answer is an uppercase letter. Remember to surround your answer with quotation marks (e.g. `"E"`).*

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(answer1_0_1)).encode("utf-8")+b"6d732f793df60f86").hexdigest() == "45a78cb7c5ff5fcad6569ec3fc7163bcff1167f0", "type of answer1_0_1 is not str. answer1_0_1 should be an str"
assert sha1(str(len(answer1_0_1)).encode("utf-8")+b"6d732f793df60f86").hexdigest() == "6d9abebce94f2e135d540f419831444b1c5be76b", "length of answer1_0_1 is not correct"
assert sha1(str(answer1_0_1.lower()).encode("utf-8")+b"6d732f793df60f86").hexdigest() == "a9e55b2dd111b7aa0b5b35a48aa2ae4ccbacf1d9", "value of answer1_0_1 is not correct"
assert sha1(str(answer1_0_1).encode("utf-8")+b"6d732f793df60f86").hexdigest() == "bad2f5612d3cf707d510cc8f1946af4bf6512e29", "correct string value of answer1_0_1 but incorrect case of letters"

print('Success!')

Run the cell below. Then, just by looking at the scatterplot, find the nearest neighbour (based on mass and width) of the first observation. The first observation has been circled for you.

In [None]:
# Run this cell.
point1 = [192, 8.4]
point2 = [180, 8]
point44 = [194, 7.2]

fruit_chart = (
    alt.Chart(fruit_data)
    .mark_point(size=15)
    .encode(
        x=alt.X("mass", title="Mass (grams)"),
        y=alt.Y("width", title="Width (cm)", scale=alt.Scale(zero=False)),
        color=alt.Color("fruit_name", title="Name of the Fruit"),
    )
)

(
    fruit_chart
    + alt.Chart(pd.DataFrame([point1], columns=["x", "y"]))
    .mark_point(size=150)
    .encode(x="x", y="y", color=alt.value("black"))
    + alt.Chart(pd.DataFrame([[1, 183, 8.5]], columns=["text", "x", "y"]))
    .mark_text(size=15)
    .encode(x="x", y="y", text="text", color=alt.value("black"))
).configure_axis(labelFontSize=20, titleFontSize=20).configure_legend(
    titleFontSize=15, labelFontSize=15
).properties(
    width=400, height=300
)

**Question 1.1** Multiple Choice: 
<br> {points: 1}

Based on the graph generated, what is the `fruit_name` of the closest data point to the one circled?

A. apple

B. lemon

C. mandarin 

D. orange

*Assign your answer to an object called `answer1_1`. Make sure the correct answer is an uppercase letter. Surround your answer with quotation marks (e.g. `"F"`).*

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(answer1_1)).encode("utf-8")+b"0781edf709d6550b").hexdigest() == "b1bf47fe9abb56d6f8dfc5d62939c46dd3fae00b", "type of answer1_1 is not str. answer1_1 should be an str"
assert sha1(str(len(answer1_1)).encode("utf-8")+b"0781edf709d6550b").hexdigest() == "fa88fb5c25a8028ee793c3a15f0a0a7621658ba5", "length of answer1_1 is not correct"
assert sha1(str(answer1_1.lower()).encode("utf-8")+b"0781edf709d6550b").hexdigest() == "88cf147f1c2ed843536e047af7a53a44c28a6a9a", "value of answer1_1 is not correct"
assert sha1(str(answer1_1).encode("utf-8")+b"0781edf709d6550b").hexdigest() == "a69743f0650f345edcf6a62d3051af3009c32fc2", "correct string value of answer1_1 but incorrect case of letters"

print('Success!')

**Question 1.2**
<br> {points: 1}

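Using mass and width, use the `euclidean_distances` function to calculate the distance between the first and second observations. The function returns a matrix of pairwise distances between the rows you pass in. For two observations, it returns a 2 × 2 matrix, and the off-diagonal entries hold the distance between them.

We provide some scaffolding to get you started.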

*Assign your answer to an object called `fruit_dist_2`.*

In [None]:
# ___ = euclidean_distances(
#     fruit_data.loc[0:1, ["mass", ___]]
# )

# your code here
raise NotImplementedError
fruit_dist_2

In [None]:
from hashlib import sha1
assert str(type(fruit_dist_2)) == "<class 'numpy.ndarray'>", "type of fruit_dist_2 is not correct"
assert str(fruit_dist_2) == "[[ 0.         12.00666482]\n [12.00666482  0.        ]]", "value of fruit_dist_2 is not correct"

print('Success!')
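One of the learning goals is to compute distances by hand. With two predictors, the Euclidean distance between observations $a$ and $b$ is $\sqrt{(a_1 - b_1)^2 + (a_2 - b_2)^2}$. The minimal sketch below compares that formula with `euclidean_distances` on two made-up points, chosen so that the answer is a whole number. These are not rows of the fruit data.

```python
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

# Two made-up observations with two predictors each
a = np.array([1.0, 2.0])
b = np.array([4.0, 6.0])

# By hand: square the differences, add them up, take the square root
by_hand = np.sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2)

# euclidean_distances returns the matrix of all pairwise distances between rows
matrix = euclidean_distances(np.array([a, b]))

print(by_hand)  # 5.0
print(matrix)   # zeros on the diagonal, 5 off the diagonal
```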

**Question 1.3**
<br> {points: 1}

Using the mass and width variables, calculate the distance between the first and the 44th observations in the fruit dataset.

*Hint: remember that Python indexing starts at 0, so the 44th observation in a pandas DataFrame (with the default integer index) corresponds to index 43.*

*Assign your answer to an object called `fruit_dist_44`.*
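In Question 1.2 the two rows were next to each other, so a slice worked. Rows 0 and 43 are not adjacent, and a slice like `0:43` would also return every row in between because `.loc` slices include both endpoints. One way to select just two specific rows is to pass a list of index labels to `.loc`. The sketch below uses a small made-up data frame, not `fruit_data`.

```python
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances

# Small made-up data frame with the default integer index 0..4
toy = pd.DataFrame(
    {"mass": [192, 180, 176, 86, 194], "width": [8.4, 8.0, 7.4, 6.2, 7.2]}
)

# A list of labels selects exactly those rows; the slice 0:4 would include rows 1-3 too
two_rows = toy.loc[[0, 4], ["mass", "width"]]

print(two_rows)
print(euclidean_distances(two_rows))
```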

In [None]:
# your code here
raise NotImplementedError
fruit_dist_44

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_dist_44)).encode("utf-8")+b"5a253c87c5a9d447").hexdigest() == "0868878f9f32e6a450229903e9d416c10cf66699", "type of fruit_dist_44 is not correct"
assert sha1(str(fruit_dist_44).encode("utf-8")+b"5a253c87c5a9d447").hexdigest() == "f45cfaf9e84b423072d68cd71089d0c0d3c96b1f", "value of fruit_dist_44 is not correct"

print('Success!')

Let's circle these three observations on the plot from earlier.


In [None]:
# Run this cell.
point1 = [192, 8.4]
point2 = [180, 8]
point44 = [194, 7.2]

(
    fruit_chart
    + alt.Chart(
        # Circle observations 1, 2 and 44 at the coordinates defined above
        pd.DataFrame([point1, point2, point44], columns=["x", "y"])
    )
    .mark_point(size=150)
    .encode(x="x", y="y", color=alt.value("black"))
    + alt.Chart(
        pd.DataFrame(
            [[1, 183, 8.5], [2, 169, 8.1], [44, 204, 7.1]], columns=["text", "x", "y"]
        )
    )
    .mark_text(size=15)
    .encode(x="x", y="y", text="text", color=alt.value("black"))
).configure_axis(labelFontSize=20, titleFontSize=20).configure_legend(
    titleFontSize=15, labelFontSize=15
).properties(width=400, height=300)

What do you notice about the distances you just calculated in **Questions 1.2 and 1.3**? Is this what you would expect, given the scatterplot above? Why or why not? Discuss with your neighbour.

*Hint: look at where the observations fall on the scatterplot above. What might happen if we measured mass in kilograms instead of grams?*


**Question 1.4** Multiple Choice:
<br> {points: 1}

The distance between the first and second observations is 12.01, and the distance between the first and 44th observations is 2.33. By this calculation, observations 1 and 44 are closer. However, on the scatterplot, the first observation appears closer to the second observation than to the 44th.

Which of the following statements is correct?

A. A difference of 12 g in mass between observations 1 and 2 is large compared to a difference of 1.2 cm in width between observations 1 and 44. Consequently, mass will drive the classification results, and width will have less of an effect.

B. If we measured mass in kilograms, then we’d get different nearest neighbours.

C. We should standardize the data so that all variables will be on a comparable scale. 

D. All of the above. 

*Assign your answer to an object called `answer1_4`. Make sure the correct answer is an uppercase letter. Surround your answer with quotation marks (e.g. `"F"`).*

In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(answer1_4)).encode("utf-8")+b"3592d870fb3df059").hexdigest() == "be2bfab502f9eb6570070b5ba47505e1b8fdbfd9", "type of answer1_4 is not str. answer1_4 should be an str"
assert sha1(str(len(answer1_4)).encode("utf-8")+b"3592d870fb3df059").hexdigest() == "8eada8ebfeed63cd0dee1d9bfb77308ab821c206", "length of answer1_4 is not correct"
assert sha1(str(answer1_4.lower()).encode("utf-8")+b"3592d870fb3df059").hexdigest() == "61776c5aa7adca3656ea04aaee659914e55ea515", "value of answer1_4 is not correct"
assert sha1(str(answer1_4).encode("utf-8")+b"3592d870fb3df059").hexdigest() == "65a6e1e05b568caa5f02c37b01bc2eb7b60037ef", "correct string value of answer1_4 but incorrect case of letters"

print('Success!')
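To make statement B concrete, the minimal sketch below takes the (mass, width) coordinates of observations 1, 2 and 44 as written in the plotting cells above. It computes the distances from observation 1 with mass in grams, and again with mass converted to kilograms. The nearest neighbour switches from observation 44 to observation 2, because width dominates the distance once mass is measured in kilograms.

```python
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

# Mass (g) and width (cm) of observations 1, 2 and 44, as used in the plotting cells
grams = np.array([[192.0, 8.4], [180.0, 8.0], [194.0, 7.2]])

# Same observations with mass converted to kilograms
kilograms = grams.copy()
kilograms[:, 0] = kilograms[:, 0] / 1000

# Distances from observation 1 (row 0) to observations 2 and 44 (rows 1 and 2)
d_grams = euclidean_distances(grams[[0]], grams[1:])[0]
d_kg = euclidean_distances(kilograms[[0]], kilograms[1:])[0]

print(d_grams)  # approx. [12.007, 2.332]: observation 44 is nearer
print(d_kg)     # approx. [0.400, 1.200]: observation 2 is nearer
```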

**Question 1.5**
<br> {points: 1}

Let's create a `preprocessor` to *standardize* (i.e., center and scale) the numeric variables in the fruit dataset. Centering gives every variable a mean of 0, and scaling gives every variable a standard deviation of 1. We will use `StandardScaler` in the `preprocessor`. Then we call `fit_transform` on the preprocessor so that we can examine the output.

Standardize the predictors `mass`, `width`, `height`, and `color_score`. Keep the other columns (`fruit_label`, `fruit_name`, and `fruit_subtype`) unchanged by using `"passthrough"` in the preprocessor.

Name the preprocessor `fruit_data_preprocessor`, and name the preprocessed data `fruit_data_scaled`.

*Note that we save the preprocessed data in a data frame so that we can use it in upcoming exercises. The transformed columns come out in the order in which the transformers are listed. This is why the passthrough columns appear first in the scaffolding below.*

In [None]:
# ___ = ___(
#     (
#         "passthrough",
#         [
#             ___,
#             ___,
#             ___,
#         ],
#     ),
#     (StandardScaler(), [___, ___, ___, ___]),
# )
# ___ = pd.DataFrame(
#     fruit_data_preprocessor.___(___),
#     columns=[
#         "fruit_label",
#         "fruit_name",
#         "fruit_subtype",
#         "mass",
#         "width",
#         "height",
#         "color_score",
#     ],
# )

# your code here
raise NotImplementedError
fruit_data_scaled.head()

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_data_scaled is None)).encode("utf-8")+b"8bae2b657ab2b05b").hexdigest() == "63e10969cd0f405c4aee5ffc9a5ca4c44c8c9ba3", "type of fruit_data_scaled is None is not bool. fruit_data_scaled is None should be a bool"
assert sha1(str(fruit_data_scaled is None).encode("utf-8")+b"8bae2b657ab2b05b").hexdigest() == "cd37ee211a08ca65eaf42e7e44ff994db3d2f004", "boolean value of fruit_data_scaled is None is not correct"

assert sha1(str(type(fruit_data_scaled.shape)).encode("utf-8")+b"ac72f3be9f0d64fa").hexdigest() == "f08fe0899e90b9c30260ed585eb886394b88e10c", "type of fruit_data_scaled.shape is not tuple. fruit_data_scaled.shape should be a tuple"
assert sha1(str(len(fruit_data_scaled.shape)).encode("utf-8")+b"ac72f3be9f0d64fa").hexdigest() == "416a2b6b161619b45aa5c5fb7de4b05f65a86436", "length of fruit_data_scaled.shape is not correct"
assert sha1(str(sorted(map(str, fruit_data_scaled.shape))).encode("utf-8")+b"ac72f3be9f0d64fa").hexdigest() == "c1a6e6d951f0feaaaebfef75258ce5b5a23e49eb", "values of fruit_data_scaled.shape are not correct"
assert sha1(str(fruit_data_scaled.shape).encode("utf-8")+b"ac72f3be9f0d64fa").hexdigest() == "8471f6a8a6ae496d2660548a9eef1c1152b1dec3", "order of elements of fruit_data_scaled.shape is not correct"

assert sha1(str(type(fruit_data_scaled.fruit_name.dtype)).encode("utf-8")+b"c78bf5438950f73e").hexdigest() == "9da722098f61b619d1cbe7a778029f1292ad2799", "type of fruit_data_scaled.fruit_name.dtype is not correct"
assert sha1(str(fruit_data_scaled.fruit_name.dtype).encode("utf-8")+b"c78bf5438950f73e").hexdigest() == "9a9993631efdb37bb1eda6e65eacca8833306a61", "value of fruit_data_scaled.fruit_name.dtype is not correct"

assert sha1(str(type(np.mean(fruit_data_scaled.mass.dropna()))).encode("utf-8")+b"0f9fb26701d359d8").hexdigest() == "fd1cd55056166a8ea98aae9113c7593158579e7c", "type of np.mean(fruit_data_scaled.mass.dropna()) is not correct"
assert sha1(str(np.mean(fruit_data_scaled.mass.dropna())).encode("utf-8")+b"0f9fb26701d359d8").hexdigest() == "9e73c9307979b00671c50b4b482b06fd3e51cbd5", "value of np.mean(fruit_data_scaled.mass.dropna()) is not correct"

assert sha1(str(type(np.mean(fruit_data_scaled.height.dropna()))).encode("utf-8")+b"8d0e6a59500e58eb").hexdigest() == "994dbba05de88d2dd73f47c09e890423ccc80eee", "type of np.mean(fruit_data_scaled.height.dropna()) is not correct"
assert sha1(str(np.mean(fruit_data_scaled.height.dropna())).encode("utf-8")+b"8d0e6a59500e58eb").hexdigest() == "8833f854a3e292e371fb30d70aeccf75181c751e", "value of np.mean(fruit_data_scaled.height.dropna()) is not correct"

assert sha1(str(type(np.mean(fruit_data_scaled.width.dropna()))).encode("utf-8")+b"b002a24b12f12fe0").hexdigest() == "45f8a0043c1c210433a11cfaa87d06813b53f8b4", "type of np.mean(fruit_data_scaled.width.dropna()) is not correct"
assert sha1(str(np.mean(fruit_data_scaled.width.dropna())).encode("utf-8")+b"b002a24b12f12fe0").hexdigest() == "03c3028f99b430fbea07705e5dfc0cca991ca6a5", "value of np.mean(fruit_data_scaled.width.dropna()) is not correct"

assert sha1(str(type(np.mean(fruit_data_scaled.color_score.dropna()))).encode("utf-8")+b"e83fb670fbb4026e").hexdigest() == "fe19a4f7a8632b7b13c81fb8e3e175d633dbe09e", "type of np.mean(fruit_data_scaled.color_score.dropna()) is not correct"
assert sha1(str(np.mean(fruit_data_scaled.color_score.dropna())).encode("utf-8")+b"e83fb670fbb4026e").hexdigest() == "1229ea910b41a2b2303138bf23ea8cef24c404ce", "value of np.mean(fruit_data_scaled.color_score.dropna()) is not correct"

assert sha1(str(type(np.std(fruit_data_scaled.mass.dropna()))).encode("utf-8")+b"b5d9318608624271").hexdigest() == "0b0836338f74739a67a0b6d011bc80a739c2e8e1", "type of np.std(fruit_data_scaled.mass.dropna()) is not correct"
assert sha1(str(np.std(fruit_data_scaled.mass.dropna())).encode("utf-8")+b"b5d9318608624271").hexdigest() == "df180d6abd7c462a1daa8ff05a4e1fae8eba951c", "value of np.std(fruit_data_scaled.mass.dropna()) is not correct"

assert sha1(str(type(np.std(fruit_data_scaled.height.dropna()))).encode("utf-8")+b"104af07c7bab237b").hexdigest() == "0ad652ec51010523d5c7069b99f4b022a4d1f1be", "type of np.std(fruit_data_scaled.height.dropna()) is not correct"
assert sha1(str(np.std(fruit_data_scaled.height.dropna())).encode("utf-8")+b"104af07c7bab237b").hexdigest() == "084b8411572633531d2f3a5e8c3e0350421cc02f", "value of np.std(fruit_data_scaled.height.dropna()) is not correct"

assert sha1(str(type(np.std(fruit_data_scaled.width.dropna()))).encode("utf-8")+b"0b960402c1535b69").hexdigest() == "ed7427d81c9dbdb2160ec16c81e38838c2f663c5", "type of np.std(fruit_data_scaled.width.dropna()) is not correct"
assert sha1(str(np.std(fruit_data_scaled.width.dropna())).encode("utf-8")+b"0b960402c1535b69").hexdigest() == "d2a6a724b8939ae23005d6b46670ce4af78afa40", "value of np.std(fruit_data_scaled.width.dropna()) is not correct"

assert sha1(str(type(np.std(fruit_data_scaled.color_score.dropna()))).encode("utf-8")+b"6bfc459dbf3d0ef8").hexdigest() == "bf9b50c934a85544c2fba49c2cfbeb106493336d", "type of np.std(fruit_data_scaled.color_score.dropna()) is not correct"
assert sha1(str(np.std(fruit_data_scaled.color_score.dropna())).encode("utf-8")+b"6bfc459dbf3d0ef8").hexdigest() == "13512d7e5632a431a3aed025783af1d1ff25ab2b", "value of np.std(fruit_data_scaled.color_score.dropna()) is not correct"

assert sha1(str(type(fruit_data_preprocessor is None)).encode("utf-8")+b"209a3bad4c49b981").hexdigest() == "48c7bc71dafbde453388fc9cfa8995247c08698f", "type of fruit_data_preprocessor is None is not bool. fruit_data_preprocessor is None should be a bool"
assert sha1(str(fruit_data_preprocessor is None).encode("utf-8")+b"209a3bad4c49b981").hexdigest() == "622a7e785e42d5864666c9378ed4b5cb8d454695", "boolean value of fruit_data_preprocessor is None is not correct"

assert sha1(str(type(fruit_data_preprocessor.transformers_[1][2])).encode("utf-8")+b"8d747fc63ab19d79").hexdigest() == "f75d4d502b96d1a5fd988ed2b6fd691f67c2e7d8", "type of fruit_data_preprocessor.transformers_[1][2] is not list. fruit_data_preprocessor.transformers_[1][2] should be a list"
assert sha1(str(len(fruit_data_preprocessor.transformers_[1][2])).encode("utf-8")+b"8d747fc63ab19d79").hexdigest() == "154c3660ca141c06939b948549679750bc3d303c", "length of fruit_data_preprocessor.transformers_[1][2] is not correct"
assert sha1(str(sorted(map(str, fruit_data_preprocessor.transformers_[1][2]))).encode("utf-8")+b"8d747fc63ab19d79").hexdigest() == "756553ee643d0766f612ef66a4be1433bdc5e9db", "values of fruit_data_preprocessor.transformers_[1][2] are not correct"
assert sha1(str(fruit_data_preprocessor.transformers_[1][2]).encode("utf-8")+b"8d747fc63ab19d79").hexdigest() == "fee36da06ced1231c40039a8f0fa2c146ab00cb0", "order of elements of fruit_data_preprocessor.transformers_[1][2] is not correct"

print('Success!')
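Here is a minimal sketch of the same preprocessing pattern, applied to a small made-up data frame. The column names `label`, `a` and `b` are hypothetical. It shows three things: `"passthrough"` leaves a column untouched, the output columns follow the order of the transformers, and `StandardScaler` produces columns with mean 0 and population standard deviation (`ddof=0`) 1.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data: one categorical column and two numeric predictors
toy = pd.DataFrame(
    {"label": ["x", "y", "x", "y"], "a": [1.0, 2.0, 3.0, 4.0], "b": [10.0, 30.0, 20.0, 40.0]}
)

# Pass the categorical column through untouched; standardize the numeric ones
toy_preprocessor = make_column_transformer(
    ("passthrough", ["label"]),
    (StandardScaler(), ["a", "b"]),
)

# Output columns follow the order of the transformers listed above
toy_scaled = pd.DataFrame(
    toy_preprocessor.fit_transform(toy), columns=["label", "a", "b"]
)

# Mixing strings and numbers gives object columns, so convert before summarizing
numeric = toy_scaled[["a", "b"]].astype(float)
print(numeric.mean())         # both approximately 0
print(numeric.std(ddof=0))    # both 1 (StandardScaler uses ddof=0)
```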

**Question 1.6**
<br> {points: 1}

Let's repeat **Questions 1.2 and 1.3** with the scaled variables:

- calculate the distance between observations 1 and 2 using the scaled mass and width variables
- calculate the distance between observations 1 and 44 using the scaled mass and width variables

After you do this, think about how these distances compare to the distances you computed in **Questions 1.2 and 1.3** for the same points.

*Assign your answers to objects called `distance_2` and `distance_44` respectively.*

In [None]:
# your code here
raise NotImplementedError
print(distance_2)
print(distance_44)

In [None]:
from hashlib import sha1
assert sha1(str(type(distance_2 is None)).encode("utf-8")+b"af48b7c63917e5a9").hexdigest() == "15c4ac7742a7cd66de6ebd32d82fcf3181cf7de3", "type of distance_2 is None is not bool. distance_2 is None should be a bool"
assert sha1(str(distance_2 is None).encode("utf-8")+b"af48b7c63917e5a9").hexdigest() == "a062be0e9450be8edb38d536ea5fd8208b250e23", "boolean value of distance_2 is None is not correct"

assert sha1(str(type(distance_44 is None)).encode("utf-8")+b"c1e6f3e1e40fd77d").hexdigest() == "3e9aacb71f53ad29afc0225e9b43fa8df1341706", "type of distance_44 is None is not bool. distance_44 is None should be a bool"
assert sha1(str(distance_44 is None).encode("utf-8")+b"c1e6f3e1e40fd77d").hexdigest() == "6cc94ce85800bf955ee3045914683736abec22c8", "boolean value of distance_44 is None is not correct"

assert sha1(str(type(distance_2)).encode("utf-8")+b"5906162c68d69397").hexdigest() == "50bdab37a3ef2285459e645c2a303ee0612ad3ac", "type of type(distance_2) is not correct"

assert sha1(str(type(distance_44)).encode("utf-8")+b"97abe283234cd926").hexdigest() == "4cd47e5ce28d679f809645db119d7ef458cc4122", "type of type(distance_44) is not correct"

assert sha1(str(type(distance_2)).encode("utf-8")+b"6f4d3bf4200b8aa3").hexdigest() == "b7865d0668fc39234c10b4db1ddb43bdd6a4bab4", "type of distance_2 is not correct"
assert sha1(str(distance_2).encode("utf-8")+b"6f4d3bf4200b8aa3").hexdigest() == "f9d209ad60af53ebfb7404139fa0c3e7a590c7ad", "value of distance_2 is not correct"

assert sha1(str(type(distance_44)).encode("utf-8")+b"df0fef3b4d8156d9").hexdigest() == "01d025676cee99bd2ff8d5aed8de378abbcd8ba4", "type of distance_44 is not correct"
assert sha1(str(distance_44).encode("utf-8")+b"df0fef3b4d8156d9").hexdigest() == "59551a0666b1bcd8708c4310832b70341ed16415", "value of distance_44 is not correct"

print('Success!')

**Question 1.7**
<br> {points: 1}

Make a scatterplot of scaled mass on the horizontal axis and scaled color score on the vertical axis. Color the points by fruit name. 

*Assign your plot to an object called `fruit_plot`. Follow good visualization practices, for example by giving the axes and the legend human-readable titles.*

In [None]:
# your code here
raise NotImplementedError
fruit_plot

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_plot is None)).encode("utf-8")+b"75047ef440b37ad1").hexdigest() == "f441dcffdc3b23a3e581f6845f8ae70ac8b336cb", "type of fruit_plot is None is not bool. fruit_plot is None should be a bool"
assert sha1(str(fruit_plot is None).encode("utf-8")+b"75047ef440b37ad1").hexdigest() == "e4339c89b9a49473bc7869b58bca99131a01bad9", "boolean value of fruit_plot is None is not correct"

assert sha1(str(type(fruit_plot.encoding.x.field)).encode("utf-8")+b"ee223293dca6ddf9").hexdigest() == "2f3df132cd2a32938207678014296f6cab6f6ae0", "type of fruit_plot.encoding.x.field is not str. fruit_plot.encoding.x.field should be an str"
assert sha1(str(len(fruit_plot.encoding.x.field)).encode("utf-8")+b"ee223293dca6ddf9").hexdigest() == "65edff7c2d5238263eabe2d21fd416202cbdd5b6", "length of fruit_plot.encoding.x.field is not correct"
assert sha1(str(fruit_plot.encoding.x.field.lower()).encode("utf-8")+b"ee223293dca6ddf9").hexdigest() == "708da8f4f43d70ed0edfa8c2224a4518fbe999d6", "value of fruit_plot.encoding.x.field is not correct"
assert sha1(str(fruit_plot.encoding.x.field).encode("utf-8")+b"ee223293dca6ddf9").hexdigest() == "708da8f4f43d70ed0edfa8c2224a4518fbe999d6", "correct string value of fruit_plot.encoding.x.field but incorrect case of letters"

assert sha1(str(type(fruit_plot.encoding.y.field)).encode("utf-8")+b"a648111d62e037a0").hexdigest() == "4f46007db9238dcd8571f2957e405143b884c166", "type of fruit_plot.encoding.y.field is not str. fruit_plot.encoding.y.field should be an str"
assert sha1(str(len(fruit_plot.encoding.y.field)).encode("utf-8")+b"a648111d62e037a0").hexdigest() == "9013fb9bfe2a5ca359edba04a21c7bcfb432de5d", "length of fruit_plot.encoding.y.field is not correct"
assert sha1(str(fruit_plot.encoding.y.field.lower()).encode("utf-8")+b"a648111d62e037a0").hexdigest() == "f8ec4a6b9d93fc4fef3a217ca1713d9a6cd12286", "value of fruit_plot.encoding.y.field is not correct"
assert sha1(str(fruit_plot.encoding.y.field).encode("utf-8")+b"a648111d62e037a0").hexdigest() == "f8ec4a6b9d93fc4fef3a217ca1713d9a6cd12286", "correct string value of fruit_plot.encoding.y.field but incorrect case of letters"

assert sha1(str(type(fruit_plot.encoding.color.field)).encode("utf-8")+b"16a494d0558e6061").hexdigest() == "fd4df5f99c4b27a4ba7bc9187aeaad43be90fb69", "type of fruit_plot.encoding.color.field is not str. fruit_plot.encoding.color.field should be an str"
assert sha1(str(len(fruit_plot.encoding.color.field)).encode("utf-8")+b"16a494d0558e6061").hexdigest() == "b7793015f8a77483e35e697c229538a49bb50f44", "length of fruit_plot.encoding.color.field is not correct"
assert sha1(str(fruit_plot.encoding.color.field.lower()).encode("utf-8")+b"16a494d0558e6061").hexdigest() == "f6f24873d3843051bf724c1e2ae958ca839b6fe6", "value of fruit_plot.encoding.color.field is not correct"
assert sha1(str(fruit_plot.encoding.color.field).encode("utf-8")+b"16a494d0558e6061").hexdigest() == "f6f24873d3843051bf724c1e2ae958ca839b6fe6", "correct string value of fruit_plot.encoding.color.field but incorrect case of letters"

assert sha1(str(type(fruit_plot.mark)).encode("utf-8")+b"d4780c6be34e207d").hexdigest() == "e137e650878d3e6d782921590493244209a8c9d3", "type of fruit_plot.mark is not str. fruit_plot.mark should be an str"
assert sha1(str(len(fruit_plot.mark)).encode("utf-8")+b"d4780c6be34e207d").hexdigest() == "82fb8e0cbce9be28b41ef5487ae27021cd1388f0", "length of fruit_plot.mark is not correct"
assert sha1(str(fruit_plot.mark.lower()).encode("utf-8")+b"d4780c6be34e207d").hexdigest() == "1ff4458cc6294806d6cd3f46ab7c82cc1c2c4492", "value of fruit_plot.mark is not correct"
assert sha1(str(fruit_plot.mark).encode("utf-8")+b"d4780c6be34e207d").hexdigest() == "1ff4458cc6294806d6cd3f46ab7c82cc1c2c4492", "correct string value of fruit_plot.mark but incorrect case of letters"

assert sha1(str(type(fruit_plot.encoding.x.title != fruit_plot.encoding.x.field)).encode("utf-8")+b"a121b621eb95244e").hexdigest() == "63ac1108325fddf2dd09e500c532f4a530269ae6", "type of fruit_plot.encoding.x.title != fruit_plot.encoding.x.field is not bool. fruit_plot.encoding.x.title != fruit_plot.encoding.x.field should be a bool"
assert sha1(str(fruit_plot.encoding.x.title != fruit_plot.encoding.x.field).encode("utf-8")+b"a121b621eb95244e").hexdigest() == "7adee825267515d27421132882971cfd7228085c", "boolean value of fruit_plot.encoding.x.title != fruit_plot.encoding.x.field is not correct"

assert sha1(str(type(fruit_plot.encoding.y.title != fruit_plot.encoding.y.field)).encode("utf-8")+b"957ae86b1df4f876").hexdigest() == "bf309aba949f8bb39792cf607d5018811d6c6ed1", "type of fruit_plot.encoding.y.title != fruit_plot.encoding.y.field is not bool. fruit_plot.encoding.y.title != fruit_plot.encoding.y.field should be a bool"
assert sha1(str(fruit_plot.encoding.y.title != fruit_plot.encoding.y.field).encode("utf-8")+b"957ae86b1df4f876").hexdigest() == "8cff8c7fd968c478f1a54c8c31a4f17967ffcf01", "boolean value of fruit_plot.encoding.y.title != fruit_plot.encoding.y.field is not correct"

assert sha1(str(type(fruit_plot.encoding.color.title != fruit_plot.encoding.color.field)).encode("utf-8")+b"7040079b43c16a6f").hexdigest() == "86e38eb1e34cc8b367b814c38c22a3e7a711830e", "type of fruit_plot.encoding.color.title != fruit_plot.encoding.color.field is not bool. fruit_plot.encoding.color.title != fruit_plot.encoding.color.field should be a bool"
assert sha1(str(fruit_plot.encoding.color.title != fruit_plot.encoding.color.field).encode("utf-8")+b"7040079b43c16a6f").hexdigest() == "1c5726c3dbc3097a110a3288afa8231c0df42f51", "boolean value of fruit_plot.encoding.color.title != fruit_plot.encoding.color.field is not correct"

print('Success!')

**Question 1.8** 
<br> {points: 3}

Suppose we have a new observation in the fruit dataset with scaled mass 0.5 and scaled color score 0.5.

Just by looking at the scatterplot, how would you classify this observation using K-nearest neighbours if you use K = 3? Explain how you arrived at your answer.

DOUBLE CLICK TO EDIT **THIS CELL** AND REPLACE THIS TEXT WITH YOUR ANSWER.
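To make the voting step concrete, here is a minimal from-scratch sketch of K-nearest neighbour classification with K = 3. It works on a handful of already-standardized points whose coordinates and labels are made up for illustration and are not taken from the fruit data. The helper `knn_predict` is hypothetical, and it breaks ties by whichever label `Counter` lists first.

```python
from collections import Counter

import numpy as np

# Hypothetical standardized training points (scaled mass, scaled color score) and labels
train_X = np.array([[0.4, 0.6], [0.7, 0.4], [0.5, 0.2], [-1.0, -1.0], [1.5, 1.5]])
train_y = np.array(["apple", "apple", "orange", "mandarin", "lemon"])


def knn_predict(new_point, X, y, k=3):
    """Classify new_point by majority vote among its k nearest training points."""
    # Euclidean distance from the new point to every training point
    distances = np.sqrt(((X - new_point) ** 2).sum(axis=1))
    # Indices of the k smallest distances
    nearest = np.argsort(distances)[:k]
    # Most common label among those neighbours
    return Counter(y[nearest]).most_common(1)[0][0]


print(knn_predict(np.array([0.5, 0.5]), train_X, train_y, k=3))  # apple
```

Here the three nearest points are two apples and one orange, so the majority vote is "apple".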

**Question 1.9**
<br> {points: 1}

Now, let's use the `scikit-learn` package to predict `fruit_name` for another new observation. The new observation we are interested in has a mass of 150 g and a color score of 0.73.

First, create the K-nearest neighbours model specification. Specify that we want $K=5$ neighbours and `weights="distance"`. Name this model specification `knn_spec`.

Then create a new preprocessor named `fruit_data_preprocessor_2` that centers and scales the predictors, using only `mass` and `color_score` as predictors. We can drop all other unused columns. Name the predictor data frame `X` and the target `y`.

Combine this preprocessor with your K-nearest neighbours model specification in a pipeline (e.g., with `make_pipeline`), and fit it to the `fruit_data` dataset.

*Name the fitted model `fruit_fit`.*
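The sketch below shows the same workflow on a small hypothetical data frame (columns `size`, `weight`, `id`, `label` are made up and are not the fruit data), using a different K so it is only an illustration. Note that `make_column_transformer` drops any column it is not told about by default (`remainder="drop"`), and that `make_pipeline` names each step after its lowercased class name.

```python
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data (NOT the fruit data set): two numeric predictors,
# one extra column we do not want to use, and a class label.
toy = pd.DataFrame({
    "size": [1.0, 1.2, 0.9, 3.0, 3.2, 2.9],
    "weight": [10, 12, 11, 30, 31, 29],
    "id": [101, 102, 103, 104, 105, 106],
    "label": ["small", "small", "small", "large", "large", "large"],
})

knn = KNeighborsClassifier(n_neighbors=3, weights="distance")

# Standardize (center and scale) only the listed columns; every other column,
# such as "id", is dropped because remainder="drop" is the default.
preprocessor = make_column_transformer(
    (StandardScaler(), ["size", "weight"]),
)

X_toy = toy.drop(columns=["label"])
y_toy = toy["label"]

# Chain preprocessing and the classifier, then fit both in one call.
toy_fit = make_pipeline(preprocessor, knn).fit(X_toy, y_toy)
print(list(toy_fit.named_steps))
```

Because the steps are named automatically, the fitted classifier can be reached afterwards as `toy_fit.named_steps["kneighborsclassifier"]`.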

In [None]:
# ___ = KNeighborsClassifier(n_neighbors=___, weights="distance")

# ____ = make_column_transformer(
#     (___, [___, ___]),
# )

# X = ____.drop(
#         columns=[___, ___, ___, ___, ___]
#     )
# y = ___[___]

# ___ = ___(___, ___).fit(___, ___)

# your code here
raise NotImplementedError
fruit_fit

In [None]:
from hashlib import sha1
assert sha1(str(type(knn_spec is None)).encode("utf-8")+b"fcc36f081841dd6b").hexdigest() == "a2e19a3d3fcdf82156fb6760300e8f7aa39226bf", "type of knn_spec is None is not bool. knn_spec is None should be a bool"
assert sha1(str(knn_spec is None).encode("utf-8")+b"fcc36f081841dd6b").hexdigest() == "73f4d88d7b62818f8b569b1dc0d857816193128a", "boolean value of knn_spec is None is not correct"

assert sha1(str(type(knn_spec.n_neighbors)).encode("utf-8")+b"8579d2104597037a").hexdigest() == "1a5697071a6225947a408058124d17aee1b1e976", "type of knn_spec.n_neighbors is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(knn_spec.n_neighbors).encode("utf-8")+b"8579d2104597037a").hexdigest() == "9933fe23b6979ebb42e57983e6ec77510dec3303", "value of knn_spec.n_neighbors is not correct"

assert sha1(str(type(knn_spec.effective_metric_)).encode("utf-8")+b"dd5e7346384c6662").hexdigest() == "a81f74d08d5ca35535c25bec200a3dd4425b47e1", "type of knn_spec.effective_metric_ is not str. knn_spec.effective_metric_ should be an str"
assert sha1(str(len(knn_spec.effective_metric_)).encode("utf-8")+b"dd5e7346384c6662").hexdigest() == "7e7ca9a479db2a4aaf225212cf57c8d0cfba707f", "length of knn_spec.effective_metric_ is not correct"
assert sha1(str(knn_spec.effective_metric_.lower()).encode("utf-8")+b"dd5e7346384c6662").hexdigest() == "fcfcd2e1ee93768f7cf08176b2270cd7f605ff9c", "value of knn_spec.effective_metric_ is not correct"
assert sha1(str(knn_spec.effective_metric_).encode("utf-8")+b"dd5e7346384c6662").hexdigest() == "fcfcd2e1ee93768f7cf08176b2270cd7f605ff9c", "correct string value of knn_spec.effective_metric_ but incorrect case of letters"

assert sha1(str(type(fruit_data_preprocessor_2 is None)).encode("utf-8")+b"a4e1a6627d3b7582").hexdigest() == "7a8f43afc5639f4c8318ee7637f4af00f92b68b3", "type of fruit_data_preprocessor_2 is None is not bool. fruit_data_preprocessor_2 is None should be a bool"
assert sha1(str(fruit_data_preprocessor_2 is None).encode("utf-8")+b"a4e1a6627d3b7582").hexdigest() == "98c44333a06f440481a244cb296b9019d34119a9", "boolean value of fruit_data_preprocessor_2 is None is not correct"

assert sha1(str(type(fruit_data_preprocessor_2.transformers_[0][2])).encode("utf-8")+b"6af8eb3c4b97d90d").hexdigest() == "42e2398f9aef2c240254e1d30e4103db0adbabaf", "type of fruit_data_preprocessor_2.transformers_[0][2] is not list. fruit_data_preprocessor_2.transformers_[0][2] should be a list"
assert sha1(str(len(fruit_data_preprocessor_2.transformers_[0][2])).encode("utf-8")+b"6af8eb3c4b97d90d").hexdigest() == "28dfab95ca57a31cf8b62c2c96a191ae360dfb00", "length of fruit_data_preprocessor_2.transformers_[0][2] is not correct"
assert sha1(str(sorted(map(str, fruit_data_preprocessor_2.transformers_[0][2]))).encode("utf-8")+b"6af8eb3c4b97d90d").hexdigest() == "3d58831c0c197177aed1cb08ed50f164d5b4cf98", "values of fruit_data_preprocessor_2.transformers_[0][2] are not correct"
assert sha1(str(fruit_data_preprocessor_2.transformers_[0][2]).encode("utf-8")+b"6af8eb3c4b97d90d").hexdigest() == "5e2ce654a14ca9f037b51ea1d2f86ab1f7b3dec8", "order of elements of fruit_data_preprocessor_2.transformers_[0][2] is not correct"

assert sha1(str(type(fruit_fit is None)).encode("utf-8")+b"600a9939c4f53015").hexdigest() == "bd41bb844166e96752a2c115d576779ce31d4cd3", "type of fruit_fit is None is not bool. fruit_fit is None should be a bool"
assert sha1(str(fruit_fit is None).encode("utf-8")+b"600a9939c4f53015").hexdigest() == "15c6f2a1bef18c0ebeb012a868f487e7cca6a944", "boolean value of fruit_fit is None is not correct"

assert sha1(str(type(type(fruit_fit))).encode("utf-8")+b"6a3ea094cc03bb31").hexdigest() == "11a32c7ea6098dcede9e9e27aca6021ebc2376a5", "type of type(fruit_fit) is not correct"
assert sha1(str(type(fruit_fit)).encode("utf-8")+b"6a3ea094cc03bb31").hexdigest() == "95147d15da8850f5133362de3d07a96cc89bb637", "value of type(fruit_fit) is not correct"

assert sha1(str(type(fruit_fit.named_steps.kneighborsclassifier.n_neighbors)).encode("utf-8")+b"42bcc6001f49f1b9").hexdigest() == "6341f918c0698d5f9d9deb4283ae4f1a6c39bd82", "type of fruit_fit.named_steps.kneighborsclassifier.n_neighbors is not int. Please make sure it is int and not np.int64, etc. You can cast your value into an int using int()"
assert sha1(str(fruit_fit.named_steps.kneighborsclassifier.n_neighbors).encode("utf-8")+b"42bcc6001f49f1b9").hexdigest() == "d2ecd0d79066ce7554d6f7c64a7ff25d5b7db173", "value of fruit_fit.named_steps.kneighborsclassifier.n_neighbors is not correct"

assert sha1(str(type(fruit_fit.named_steps.kneighborsclassifier.effective_metric_)).encode("utf-8")+b"0176f1f191204557").hexdigest() == "799774a15eb1d4a9e0189e613d767c498c0deb41", "type of fruit_fit.named_steps.kneighborsclassifier.effective_metric_ is not str. fruit_fit.named_steps.kneighborsclassifier.effective_metric_ should be an str"
assert sha1(str(len(fruit_fit.named_steps.kneighborsclassifier.effective_metric_)).encode("utf-8")+b"0176f1f191204557").hexdigest() == "f8f6e83faf38c612ffdab75917437480888ed647", "length of fruit_fit.named_steps.kneighborsclassifier.effective_metric_ is not correct"
assert sha1(str(fruit_fit.named_steps.kneighborsclassifier.effective_metric_.lower()).encode("utf-8")+b"0176f1f191204557").hexdigest() == "46cb3b170a0bd9f4598694cc69729f1291840aef", "value of fruit_fit.named_steps.kneighborsclassifier.effective_metric_ is not correct"
assert sha1(str(fruit_fit.named_steps.kneighborsclassifier.effective_metric_).encode("utf-8")+b"0176f1f191204557").hexdigest() == "46cb3b170a0bd9f4598694cc69729f1291840aef", "correct string value of fruit_fit.named_steps.kneighborsclassifier.effective_metric_ but incorrect case of letters"

print('Success!')

**Question 1.10**
<br> {points: 1}

Create a new data frame named `new_fruit` containing a single row with `mass = 150` and `color_score = 0.73` (in that column order). Then call the `predict` method of `fruit_fit` on `new_fruit` to predict the class of the new fruit observation. Save your prediction to an object named `fruit_predicted`.
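Here is a minimal sketch of predicting a single new observation, again on a hypothetical toy data set (the `size`/`weight`/`label` columns and values are made up, not the fruit data). The key points are that the new data frame must use the same column names as the training predictors, and that `predict` is a method of the fitted pipeline, which reuses the training means and standard deviations to standardize the new row before running K-nearest neighbours.

```python
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical training data (NOT the fruit data set).
train = pd.DataFrame({
    "size": [1.0, 1.2, 0.9, 3.0, 3.2, 2.9],
    "weight": [10, 12, 11, 30, 31, 29],
    "label": ["small", "small", "small", "large", "large", "large"],
})

toy_fit = make_pipeline(
    make_column_transformer((StandardScaler(), ["size", "weight"])),
    KNeighborsClassifier(n_neighbors=3),
).fit(train[["size", "weight"]], train["label"])

# A one-row data frame whose column names match the training predictors.
new_obs = pd.DataFrame([[1.1, 11]], columns=["size", "weight"])

# predict standardizes new_obs with the training statistics, then classifies it.
prediction = toy_fit.predict(new_obs)
print(prediction)
```

`predict` returns a NumPy array with one entry per row of the new data frame, so a single observation gives an array of length one.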

In [None]:
# your code here
raise NotImplementedError
fruit_predicted

In [None]:
from hashlib import sha1
assert sha1(str(type(new_fruit is None)).encode("utf-8")+b"97b269ad098320c0").hexdigest() == "b9f5d21c5410c97790a73b68bd99b9661a73c8ec", "type of new_fruit is None is not bool. new_fruit is None should be a bool"
assert sha1(str(new_fruit is None).encode("utf-8")+b"97b269ad098320c0").hexdigest() == "b15381acd945bd543df0888e8458d3ba0ba6d4cf", "boolean value of new_fruit is None is not correct"

assert sha1(str(type(new_fruit.shape)).encode("utf-8")+b"e3aefa0f36ef0d2f").hexdigest() == "22ab9a201edc0368833d18965ca4c8933937fbb0", "type of new_fruit.shape is not tuple. new_fruit.shape should be a tuple"
assert sha1(str(len(new_fruit.shape)).encode("utf-8")+b"e3aefa0f36ef0d2f").hexdigest() == "b37650460cea14c7c4d2bb8d99a3a4180f3d7704", "length of new_fruit.shape is not correct"
assert sha1(str(sorted(map(str, new_fruit.shape))).encode("utf-8")+b"e3aefa0f36ef0d2f").hexdigest() == "eee72d32d3b2aed5103b1dba4747bdbccced6c4d", "values of new_fruit.shape are not correct"
assert sha1(str(new_fruit.shape).encode("utf-8")+b"e3aefa0f36ef0d2f").hexdigest() == "01055a1d66a2322c9336f877a652da17a77ef025", "order of elements of new_fruit.shape is not correct"

assert sha1(str(type(new_fruit.mass.values)).encode("utf-8")+b"dd7ce6e267f3c34c").hexdigest() == "34989edebf2578a4910b1e6d9d704e153f6dd316", "type of new_fruit.mass.values is not correct"
assert sha1(str(new_fruit.mass.values).encode("utf-8")+b"dd7ce6e267f3c34c").hexdigest() == "6a54979c4686072de4c98bf222a991e035531cb1", "value of new_fruit.mass.values is not correct"

assert sha1(str(type(new_fruit.color_score.values)).encode("utf-8")+b"4242be496fd6567a").hexdigest() == "1fb6ccc45c22cd0db67bbc6c8c7ffcab23990060", "type of new_fruit.color_score.values is not correct"
assert sha1(str(new_fruit.color_score.values).encode("utf-8")+b"4242be496fd6567a").hexdigest() == "c8d2c9927919ac8da323399f11f816c70a76f860", "value of new_fruit.color_score.values is not correct"

assert sha1(str(type(fruit_predicted)).encode("utf-8")+b"1a6cdceb98df7859").hexdigest() == "0375cc262d0e65aeecd321e124d420716dc69bdb", "type of fruit_predicted is not correct"
assert sha1(str(fruit_predicted).encode("utf-8")+b"1a6cdceb98df7859").hexdigest() == "bfe087b00b30fda3c333dc0a2cf12af0c1b14984", "value of fruit_predicted is not correct"

print('Success!')

**Question 1.11** 
<br> {points: 3}

Revisit `fruit_plot` and consider the prediction made by K-nearest neighbours above. Do you think the classifier did a "good" job? Could we do better? Given what we know so far in the course, what might we do to help with tricky prediction cases such as this one?
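One way to judge a prediction is to look at exactly which training points voted for it. The sketch below does this on hypothetical toy data (not the fruit data) by running only the preprocessing part of the pipeline (`toy_fit[:-1]`, pipeline slicing) on a "tricky" new point, then asking the fitted classifier (`toy_fit[-1]`) for its neighbours with `kneighbors`, which returns their distances and row positions in ascending order of distance.

```python
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical training data (NOT the fruit data set).
train = pd.DataFrame({
    "size": [1.0, 1.2, 0.9, 3.0, 3.2, 2.9],
    "weight": [10, 12, 11, 30, 31, 29],
    "label": ["small", "small", "small", "large", "large", "large"],
})
toy_fit = make_pipeline(
    make_column_transformer((StandardScaler(), ["size", "weight"])),
    KNeighborsClassifier(n_neighbors=3),
).fit(train[["size", "weight"]], train["label"])

# A hypothetical "tricky" observation that sits between the two groups.
new_obs = pd.DataFrame([[2.0, 20]], columns=["size", "weight"])

# Apply every pipeline step except the last (the preprocessing), so the new
# observation is on the same standardized scale as the training data.
new_scaled = toy_fit[:-1].transform(new_obs)

# Ask the fitted classifier (the last step) for its K nearest training points.
distances, indices = toy_fit[-1].kneighbors(new_scaled)

# Show the neighbours' labels together with their distances to new_obs.
neighbours = train.iloc[indices[0]].assign(distance=distances[0])
print(neighbours)
```

If the neighbours are split between classes and sit at similar distances, the prediction is fragile and deserves a closer look.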

*You can use the code below to visualize the observation whose label we just tried to predict.*

In [None]:
fruit_plot + (
    alt.Chart(pd.DataFrame([[-0.3, -0.4]], columns=["x", "y"]))
    .mark_circle(size=50)
    .encode(x="x", y="y", color=alt.value("black"))
)

DOUBLE CLICK TO EDIT **THIS CELL** AND REPLACE THIS TEXT WITH YOUR ANSWER.

**Question 1.12**
<br> {points: 1}

Now perform K-nearest neighbours classification again with the same data set, the same K, and the same new observation. This time, however, use **all the columns in the data set as predictors, except for the categorical `fruit_label` and `fruit_subtype` variables** (and, of course, the target `fruit_name`). Because the predictors have changed, you will need to make a new preprocessor.

We have provided the `new_fruit_all` data frame below, which encodes the predictors for our new observation. Your job is to use K-nearest neighbours to predict the class of this point. You can reuse the model specification you created earlier.

Name the new predictor data frame `X_2` and the new target `y_2`.

*Assign your answer (the output of `predict`) to an object called `fruit_all_predicted`.*
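A column can be stored as numbers and still be categorical. For example, `fruit_label` is an integer code for the fruit name. The sketch below uses hypothetical toy data with a made-up numeric `label_code` column to show why such columns should be dropped explicitly, after which all remaining columns are standardized. None of these names come from the fruit data.

```python
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data (NOT the fruit data set). "label_code" is numeric,
# but it is only an integer encoding of the class, so it must not be a predictor.
toy = pd.DataFrame({
    "size": [1.0, 1.2, 0.9, 3.0, 3.2, 2.9],
    "weight": [10, 12, 11, 30, 31, 29],
    "height": [5.0, 5.2, 4.9, 9.0, 9.5, 8.8],
    "label_code": [0, 0, 0, 1, 1, 1],
    "label": ["small", "small", "small", "large", "large", "large"],
})

# Keep every column except the label columns as a predictor.
X_2_toy = toy.drop(columns=["label_code", "label"])
y_2_toy = toy["label"]

# Standardize all remaining predictor columns.
preprocessor_2 = make_column_transformer(
    (StandardScaler(), list(X_2_toy.columns)),
)
fit_2 = make_pipeline(preprocessor_2, KNeighborsClassifier(n_neighbors=3)).fit(
    X_2_toy, y_2_toy
)

# The new observation must have the same predictor columns as X_2_toy.
new_obs_all = pd.DataFrame([[1.1, 11, 5.1]], columns=["size", "weight", "height"])
print(fit_2.predict(new_obs_all))
```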

In [None]:
# This is the new observation to predict class label for
new_fruit_all = pd.DataFrame(
    [[150, 6, 10, 0.73]],
    columns=[
        "mass",
        "width",
        "height",
        "color_score",
    ],
)

# no hints this time!

# your code here
raise NotImplementedError
fruit_all_predicted

In [None]:
from hashlib import sha1
assert sha1(str(type(fruit_all_predicted)).encode("utf-8")+b"f3fc7f4debd00044").hexdigest() == "3439e2e5ad5308bc2f02ddca8a7da70ae326ce20", "type of fruit_all_predicted is not correct"
assert sha1(str(fruit_all_predicted).encode("utf-8")+b"f3fc7f4debd00044").hexdigest() == "8418a29b45d0cf4aa31ef49c2126def519ade9a3", "value of fruit_all_predicted is not correct"

print('Success!')

**Question 1.13** 
<br> {points: 3}

Did your second classification on the same data set with the same K change the prediction? If so, why do you think this happened?

DOUBLE CLICK TO EDIT **THIS CELL** AND REPLACE THIS TEXT WITH YOUR ANSWER.

## 2. Wheat Seed Dataset

X-ray images can be used to analyze and sort seeds. In [this data set](https://archive.ics.uci.edu/ml/datasets/seeds), we have seven measurements taken from X-ray images of wheat kernels belonging to three varieties (Kama, Rosa and Canadian).

**Question 2.0**
<br> {points: 3}

Let's use `scikit-learn` to perform K-nearest neighbours classification of the wheat variety of seeds. The data set is available here: https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt. **Download the data set directly from this URL using the `pd.read_csv` function with `delimiter='\t'`**, which tells pandas that the columns are separated by tab characters.

The following seven measurements were taken for each wheat kernel:
1. area $A$
2. perimeter $P$
3. compactness $C = 4 \pi A / P^2$
4. length of kernel
5. width of kernel
6. asymmetry coefficient
7. length of kernel groove
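The compactness formula measures how close the kernel outline is to a circle: for any closed shape $4\pi A \le P^2$, so $C \le 1$, with equality only for a perfect circle. The hypothetical helper `compactness` below checks this on two simple shapes (not seed measurements).

```python
import math

def compactness(area, perimeter):
    """Compactness C = 4 * pi * A / P**2, as defined for the seeds data."""
    return 4 * math.pi * area / perimeter**2

# A circle of radius r: A = pi * r**2, P = 2 * pi * r, so C = 1.
r = 2.0
circle = compactness(math.pi * r**2, 2 * math.pi * r)

# A square of side s: A = s**2, P = 4 * s, so C = pi / 4 (about 0.785).
s = 3.0
square = compactness(s**2, 4 * s)

print(round(circle, 6), round(square, 6))
```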

The last column in the data set is the variety label. The numbers map to varieties as follows:

- 1 == Kama
- 2 == Rosa
- 3 == Canadian
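If you want readable variety names instead of numeric codes, one option is to store the mapping in a dictionary and apply it with `Series.map`. The `labels` series below is a hypothetical stand-in for the label column read from the file.

```python
import pandas as pd

# Mapping from the numeric codes in the last column to variety names.
variety_names = {1: "Kama", 2: "Rosa", 3: "Canadian"}

# Hypothetical label column (a stand-in for the one read from the file).
labels = pd.Series([1, 3, 2, 1])

named = labels.map(variety_names)
print(named.tolist())
```

Codes that do not appear in the dictionary become `NaN`, which makes unexpected labels easy to spot.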

Use `scikit-learn` with this data to perform K-nearest neighbours classification and predict the wheat variety of a new seed whose measurements (from an X-ray image) are given in the `new_seed` data frame below. Specify that we want $K = 5$ neighbours to perform the classification.

*Assign your answer to an object called `seed_predict`.*

Hints:
- The `names` argument of `pd.read_csv` can be used to specify the column names of the data frame.
- There are some `NaN` values in the data set. Use `dropna` to drop the rows containing them before passing the data to the K-nearest neighbours model.
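Here is a minimal sketch of both hints on a tiny, made-up tab-separated text (held in memory with `io.StringIO` instead of the real URL). The column names are hypothetical. Passing `names` tells pandas the file has no header row, and a row with a missing field is padded with `NaN`, which `dropna` then removes.

```python
import io

import pandas as pd

# A tiny, made-up tab-separated file (NOT the real seeds data). The second
# row is missing its last value, so pandas fills that cell with NaN.
raw = "1.0\t2.0\t1\n3.0\t4.0\n5.0\t6.0\t2\n"

df = pd.read_csv(
    io.StringIO(raw),
    delimiter="\t",
    names=["feature_a", "feature_b", "label"],  # hypothetical column names
)

# Drop every row that contains at least one NaN value.
clean = df.dropna()
print(df.shape, clean.shape)
```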

In [None]:
# This is the new observation to predict
new_seed = pd.DataFrame(
    [[12.1, 14.2, 0.9, 4.9, 2.8, 3.0, 5.1]],
    columns=[
        "area",
        "perimeter",
        "compactness",
        "length",
        "width",
        "asymmetry_coefficient",
        "groove_length",
    ],
)

# your code here
raise NotImplementedError
seed_predict

**Question 2.1** Multiple Choice:
<br> {points: 1}

What is the classification of the `new_seed` observation?

A. Kama

B. Rosa

C. Canadian

*Assign your answer to an object called `answer2_1`. Make sure your answer is in uppercase and is surrounded by quotation marks (e.g. `"F"`).*


In [None]:
# your code here
raise NotImplementedError

In [None]:
from hashlib import sha1
assert sha1(str(type(answer2_1)).encode("utf-8")+b"5184f8031f2bcfba").hexdigest() == "e2ab6eb25886ff45887345b0a732b93e18b47717", "type of answer2_1 is not str. answer2_1 should be an str"
assert sha1(str(len(answer2_1)).encode("utf-8")+b"5184f8031f2bcfba").hexdigest() == "3dd14138a255b00eeaee7f06e21d6ee7ed51a68c", "length of answer2_1 is not correct"
assert sha1(str(answer2_1.lower()).encode("utf-8")+b"5184f8031f2bcfba").hexdigest() == "213ee448c316f4f4edb30878dc15f182dd06c65d", "value of answer2_1 is not correct"
assert sha1(str(answer2_1).encode("utf-8")+b"5184f8031f2bcfba").hexdigest() == "9f874e70f664c55492674835ac7b7bf632f01427", "correct string value of answer2_1 but incorrect case of letters"

print('Success!')