In [1]:
from typing import List
from collections import Counter

def raw_majority_vote(labels: List[str]) -> str:
    """Return the label that occurs most often in ``labels``.

    Ties get no special treatment: ``Counter.most_common`` lists equal counts
    in the order they were first encountered, so the earliest-seen tied label
    wins. An empty list raises IndexError.
    """
    top_entry = Counter(labels).most_common(1)[0]  # (label, count)
    return top_entry[0]

assert raw_majority_vote(list('abcb')) == 'b'

In [2]:
def majority_vote(labels: List[str]) -> str:
    """Return the most common label, breaking ties by distance.

    Assumes that labels are ordered from nearest to farthest. When several
    labels share the top count, the farthest label is dropped and the vote is
    repeated. A single remaining label always wins, so this terminates.

    Implemented as a loop rather than recursion, so long, tie-heavy inputs
    cannot exceed Python's recursion limit. The caller's list is not modified.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    if not labels:
        raise ValueError('majority_vote requires at least one label')
    remaining = list(labels)  # copy, because we pop from it below
    while True:
        vote_counts = Counter(remaining)
        winner, winner_count = vote_counts.most_common(1)[0]
        num_winners = sum(1 for count in vote_counts.values() if count == winner_count)
        if num_winners == 1:
            return winner
        # Tie: discard the farthest label and vote again.
        remaining.pop()

assert majority_vote(['a', 'b', 'c', 'b', 'a']) == 'b'

In [3]:
from typing import NamedTuple
from scratch.linear_algebra import Vector, distance

class LabeledPoint(NamedTuple):
    """A point in feature space together with its class label."""
    point: Vector  # feature values
    label: str     # class label


def knn_classify(k: int, labeled_points: List[LabeledPoint], new_point: Vector) -> str:
    """Classify ``new_point`` by a majority vote of its ``k`` nearest neighbours.

    Args:
        k: number of nearest neighbours that vote; must be at least 1. If it
            exceeds ``len(labeled_points)``, every point votes.
        labeled_points: the labelled training points; must be non-empty.
        new_point: the point to classify.

    Returns:
        The winning label, as chosen by ``majority_vote``.

    Raises:
        ValueError: if ``k < 1`` or ``labeled_points`` is empty. Without this
            check, a negative ``k`` would slice off the *farthest* points and
            silently vote with the rest.
    """
    if k < 1:
        raise ValueError(f'k must be at least 1, got {k}')
    if not labeled_points:
        raise ValueError('labeled_points must not be empty')
    # Order nearest to farthest; majority_vote relies on this order to break ties.
    by_distance = sorted(labeled_points, key=lambda lp: distance(lp.point, new_point))
    # Find the labels for the k closest
    k_nearest_labels = [lp.label for lp in by_distance[:k]]
    # And let them vote
    return majority_vote(k_nearest_labels)

In [4]:
from pathlib import Path

# Iris data file, resolved relative to the notebook's working directory.
iris_dataset = Path('datasets', 'iris.dat')
assert iris_dataset.is_file(), f'Cannot find {iris_dataset}'

In [8]:
import csv
from typing import Dict
from collections import defaultdict

def parse_iris_row(row: List[str]) -> LabeledPoint:
    """Turn one CSV row of the iris data into a LabeledPoint.

    Columns: sepal_length, sepal_width, petal_length, petal_width, class.
    Every column but the last is parsed as a float. For the class column, only
    the text after its last '-' is kept, so 'Iris-virginica' becomes 'virginica'.
    """
    features = list(map(float, row[:-1]))
    species = row[-1].rsplit('-', 1)[-1]
    return LabeledPoint(features, species)

# The csv module requires files opened with newline='' so it can handle line
# endings (including newlines inside quoted fields) itself.
with iris_dataset.open('r', newline='') as f:
    reader = csv.reader(f)
    # csv.reader yields an empty list for a blank line; skip those.
    iris_data = [parse_iris_row(row) for row in reader if row]

# We'll also group just the points by species/label so we can plot them
points_by_species: Dict[str, List[Vector]] = defaultdict(list)
for iris in iris_data:
    points_by_species[iris.label].append(iris.point)