In [87]:
import torch
from collections import Counter

def loadData(fileName):
    """Load a comma-separated numeric text file into a float32 tensor.

    Every non-blank line is split on ',' and each field is parsed with
    ``float``; the parsed rows are stacked into a single tensor.

    Args:
        fileName: Path of the text file to read.

    Returns:
        A ``torch.float32`` tensor with one row per non-blank line.

    Raises:
        ValueError: If a field cannot be parsed as a float, or (from
            ``torch.tensor``) if the rows have differing lengths.
    """
    data = []
    with open(fileName, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at end of file) hold no
                # data; without this guard float('') would raise ValueError.
                continue
            data.append([float(value) for value in line.split(',')])
    return torch.tensor(data, dtype=torch.float32)

# Read the training and test files (paths are relative to the working
# directory) into tensors via loadData.
trainingSet = loadData("yeast_train.txt")
testingSet = loadData("yeast_test.txt")

In [88]:
# Quick look at the loaded tensors: both shapes, then the first five training rows.
print(trainingSet.shape, testingSet.shape, sep="\n")
print(trainingSet[:5])

# Second test row with its final column sliced off.
testInstance = testingSet[1, :-1]
print(testInstance)

# Number of rows in the test tensor.
print(len(testingSet))

torch.Size([1039, 9])
torch.Size([445, 9])
tensor([[0.5800, 0.6100, 0.4700, 0.1300, 0.5000, 0.0000, 0.4800, 0.2200, 1.0000],
        [0.4300, 0.6700, 0.4800, 0.2700, 0.5000, 0.0000, 0.5300, 0.2200, 1.0000],
        [0.6400, 0.6200, 0.4900, 0.1500, 0.5000, 0.0000, 0.5300, 0.2200, 1.0000],
        [0.5800, 0.4400, 0.5700, 0.1300, 0.5000, 0.0000, 0.5400, 0.2200, 2.0000],
        [0.4200, 0.4400, 0.4800, 0.5400, 0.5000, 0.0000, 0.4800, 0.2200, 1.0000]])
tensor([0.4000, 0.4200, 0.5700, 0.3500, 0.5000, 0.0000, 0.5300, 0.2500])
445


In [89]:
def getDistance(x, y):
    """Return the Euclidean distance between ``x`` and ``y`` as a 0-dim tensor.

    The element-wise difference ``x - y`` is squared, summed over every
    element, and square-rooted. The subtraction follows the usual tensor
    broadcasting rules.
    """
    delta = x - y
    squaredTotal = delta.pow(2).sum()
    return squaredTotal.sqrt()

In [None]:
def getNeighbors(trainingSet, tester, k):
    """Return the labels of the ``k`` training rows closest to ``tester``.

    The last column of ``trainingSet`` is treated as the label; all other
    columns are features. Distances are Euclidean and are computed for every
    row in one vectorized operation instead of a Python-level loop.

    Args:
        trainingSet: 2-D tensor whose rows are ``[features..., label]``.
        tester: Feature tensor to compare against each training row. It is
            broadcast against the feature columns, so it must be
            broadcast-compatible with them.
        k: Number of neighbours requested.

    Returns:
        A list of Python floats: the labels of the nearest rows, closest
        first. Rows at equal distance keep their original row order. If ``k``
        exceeds the number of training rows, every row's label is returned
        (previously this raised ``IndexError``); ``k <= 0`` or an empty
        training set yields ``[]``.
    """
    total = trainingSet.shape[0]
    if total == 0:
        return []

    features = trainingSet[:, :-1]
    labels = trainingSet[:, -1]

    # One distance per training row: sqrt of the row-wise sum of squared differences.
    distances = torch.sqrt(torch.sum((features - tester) ** 2, dim=1))

    # A stable sort keeps equal distances in row order, i.e. ties are broken
    # by the lower row index.
    order = torch.sort(distances, stable=True).indices

    # Clamp k to the valid range so an over-large k cannot index past the end.
    count = max(0, min(k, total))
    return labels[order[:count]].tolist()

In [91]:
def guessClass(neighbors, classOrder):
    """Pick a class label by majority vote among ``neighbors``.

    Args:
        neighbors: Non-empty sequence of neighbour labels, nearest first.
        classOrder: Labels in priority order; when several labels tie for the
            most votes, the one appearing earliest in ``classOrder`` wins.

    Returns:
        The winning label. With a single neighbour that neighbour's label is
        returned directly. If none of the tied labels occurs in
        ``classOrder``, the tied label that appears first in ``neighbors`` is
        returned (previously the function fell through and returned ``None``).

    Raises:
        IndexError: If ``neighbors`` is empty.
    """
    if len(neighbors) <= 1:
        # Single neighbour: nothing to vote on. An empty sequence raises
        # IndexError here.
        return neighbors[0]

    # Count once; Counter keeps labels in order of first appearance.
    votes = Counter(neighbors)
    mostVotes = max(votes.values())
    options = [label for label, total in votes.items() if total == mostVotes]

    for cls in classOrder:
        if cls in options:
            return cls

    # No tied label is known to classOrder: fall back to the tied label seen
    # first among the neighbours instead of implicitly returning None.
    return options[0]

In [92]:
def mykNN(trainingSet, testingSet, k):
    """Classify every row of ``testingSet`` with k-nearest-neighbours.

    Args:
        trainingSet: 2-D tensor whose rows are ``[features..., label]``.
        testingSet: 2-D tensor of rows to classify. Rows are normally
            ``[features..., label]`` like the training rows. A matrix with
            exactly one column fewer than ``trainingSet`` is treated as
            features only (no label column).
        k: Number of neighbours that vote on each prediction.

    Returns:
        A list of ``(guess, actual)`` tuples, one per test row. ``actual`` is
        the row's label as a Python float, or ``None`` for a features-only
        test matrix.
    """
    # Distinct training labels in order of first appearance, stored as plain
    # floats so they compare directly with the neighbour labels; guessClass
    # uses this order to break voting ties.
    classOrder = list(dict.fromkeys(trainingSet[:, -1].tolist()))

    # One column fewer than the training matrix means there is no label
    # column to strip from the test rows.
    isUnlabeled = testingSet.shape[1] == trainingSet.shape[1] - 1

    predictions = []

    for i in range(testingSet.shape[0]):
        if isUnlabeled:
            tester = testingSet[i]
            actual = None
        else:
            tester = testingSet[i, :-1]
            actual = testingSet[i, -1].item()

        # Pass the whole feature vector. Passing tester[0] would hand over
        # only the first feature, which then broadcasts against every
        # training feature and yields meaningless distances.
        neighbors = getNeighbors(trainingSet, tester, k)
        guess = guessClass(neighbors, classOrder)
        predictions.append((guess, actual))

    return predictions

In [93]:
def loocv(trainingSet, k):
    """Leave-one-out cross-validation error of k-NN on ``trainingSet``.

    Each row is held out in turn, predicted by ``mykNN`` from all remaining
    rows, and scored by the absolute difference between the predicted and the
    actual label value.

    Args:
        trainingSet: 2-D tensor whose rows are ``[features..., label]``.
        k: Number of neighbours passed to ``mykNN``.

    Returns:
        The mean absolute label difference over all held-out rows, as a
        Python float.

    Raises:
        ZeroDivisionError: If ``trainingSet`` has no rows.
    """
    errors = []

    for i in range(trainingSet.shape[0]):
        trainer = torch.cat((trainingSet[:i], trainingSet[i+1:]))

        # Hand mykNN the complete held-out row, label column included. mykNN
        # slices the last column off each test row as the label, so passing a
        # features-only row would make it discard the last feature instead.
        heldOut = trainingSet[i].unsqueeze(0)

        predicted = mykNN(trainer, heldOut, k)[0][0]
        actual = trainingSet[i, -1].item()

        # float() normalises a 0-dim tensor prediction to a plain number.
        errors.append(abs(float(predicted) - actual))
    return sum(errors) / len(errors)

In [94]:
def runkNN(trainingSet, testingSet, maxK=3):
    """Choose k by leave-one-out cross-validation, then evaluate on the test set.

    Every k in ``1..maxK`` is scored with ``loocv`` on ``trainingSet``; the k
    with the lowest error is used to classify ``testingSet`` with ``mykNN``.

    Prints three lines: the chosen k; the number of exact matches and the
    number of predictions; and the mean absolute label difference together
    with the number of predictions.

    Args:
        trainingSet: 2-D tensor whose rows are ``[features..., label]``.
        testingSet: 2-D tensor with the same row layout, used for the final
            evaluation.
        maxK: Largest k to try (default 3, the previously hard-coded range).

    Raises:
        ValueError: If ``maxK`` is smaller than 1.
    """
    if maxK < 1:
        raise ValueError("maxK must be at least 1")

    errors = []
    for k in range(1, maxK + 1):
        # float() keeps the scores plain numbers even if loocv returns a 0-dim tensor.
        mae = float(loocv(trainingSet, k))
        errors.append((mae, k))

    # Tuples compare by error first and by k second, so ties go to the smaller k.
    bestK = min(errors)[1]

    print(f"K chosen to be: {bestK}")

    predictions = mykNN(trainingSet, testingSet, bestK)

    correct = 0
    totalError = 0
    for guess, actual in predictions:
        # Normalise a 0-dim tensor guess so the totals print as plain numbers.
        guess = float(guess)
        if guess == actual:
            correct += 1
        totalError += abs(guess - actual)

    print(correct, len(predictions))
    print(totalError / len(predictions), len(predictions))

In [95]:
# Pick k by leave-one-out cross-validation on the training set, then print the test-set results.
runkNN(trainingSet, testingSet)

[1.0]
[9.0]
[1.0]
[1.0]
[9.0]
[9.0]
[9.0]
[1.0]
[9.0]
[9.0]
[9.0]
[1.0]
[1.0]
[9.0]
[9.0]
[9.0]
[9.0]
[9.0]
[1.0]
[1.0]
[1.0]
[1.0]
[9.0]
[5.0]
[1.0]
[9.0]
[1.0]
[1.0]
[1.0]
[9.0]
[9.0]
[1.0]
[1.0]
[9.0]
[1.0]
[9.0]
[9.0]
[9.0]
[1.0]
[1.0]
[1.0]
[9.0]
[9.0]
[9.0]
[1.0]
[9.0]
[9.0]
[1.0]
[9.0]
[9.0]
[9.0]
[1.0]
[9.0]
[9.0]
[9.0]
[9.0]
[1.0]
[5.0]
[1.0]
[9.0]
[9.0]
[9.0]
[9.0]
[5.0]
[9.0]
[9.0]
[1.0]
[1.0]
[1.0]
[2.0]
[9.0]
[9.0]
[9.0]
[9.0]
[5.0]
[9.0]
[5.0]
[9.0]
[9.0]
[1.0]
[2.0]
[1.0]
[1.0]
[1.0]
[1.0]
[9.0]
[9.0]
[1.0]
[1.0]
[2.0]
[5.0]
[9.0]
[9.0]
[5.0]
[9.0]
[1.0]
[9.0]
[9.0]
[9.0]
[1.0]
[9.0]
[9.0]
[2.0]
[9.0]
[9.0]
[1.0]
[9.0]
[1.0]
[9.0]
[1.0]
[2.0]
[9.0]
[5.0]
[9.0]
[5.0]
[5.0]
[9.0]
[1.0]
[5.0]
[9.0]
[9.0]
[9.0]
[5.0]
[9.0]
[1.0]
[9.0]
[9.0]
[5.0]
[5.0]
[5.0]
[1.0]
[9.0]
[1.0]
[1.0]
[1.0]
[1.0]
[1.0]
[9.0]
[9.0]
[1.0]
[1.0]
[9.0]
[9.0]
[9.0]
[1.0]
[1.0]
[9.0]
[1.0]
[1.0]
[9.0]
[1.0]
[1.0]
[1.0]
[1.0]
[5.0]
[2.0]
[2.0]
[9.0]
[9.0]
[9.0]
[9.0]
[9.0]
[9.0]
[5.0]
[1.0]
[9.0]
[5.0