In [37]:
import numpy as np
import re
import csv
import random

random.seed(42)

def tokenize(text):
    """Return the lowercase alphabetic words of ``text`` in order of appearance.

    The text is lowercased first, then every run of ASCII letters that is
    delimited by word boundaries is collected. Punctuation and whitespace act
    as separators; letter runs directly adjacent to digits, underscores or
    other word characters do not match and are therefore omitted.

    Args:
        text: Input string.

    Returns:
        list of str: the matched words (duplicates preserved).
    """
    return re.findall(r"\b[a-zA-Z]+\b", text.lower())
def load_data(file_name, max_labels=100, allowed_labels=(0, 1, 2)):
    """Load a labelled text CSV and return a shuffled, per-class capped sample.

    The first row of the file is treated as a header and skipped. For every
    other non-blank row, column 0 holds a 1-based integer class id and
    columns 1 and 2 hold text. The two text columns are joined with a space,
    tokenized with ``tokenize`` and reduced to a set of unique words.

    Args:
        file_name: Path of the UTF-8 encoded CSV file to read.
        max_labels: Maximum number of samples kept for each class.
        allowed_labels: Zero-based class ids to keep; rows with any other
            class id are dropped. Defaults to the first three classes.

    Returns:
        list of ``(tokens, label)`` tuples, where ``tokens`` is a ``set`` of
        words and ``label`` is the zero-based class id. The order is the
        result of ``random.shuffle`` (so it depends on the global ``random``
        state), truncated to at most ``max_labels`` samples per class.

    Raises:
        ValueError: if the first column of a non-blank data row is not an integer.
        IndexError: if a kept row has fewer than three columns.
    """
    data = []
    # newline="" hands newline handling to the csv module, as its documentation
    # requires; otherwise newlines inside quoted fields can be misread.
    with open(file_name, 'r', encoding="utf-8", newline="") as file:
        reader = csv.reader(file)
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; there is nothing to parse.
                continue
            label = int(row[0]) - 1  # class ids in the file are 1-based
            if label in allowed_labels:
                data.append((set(tokenize(row[1] + ' ' + row[2])), label))

    # Shuffle before capping so the kept samples are a random subset per class.
    random.shuffle(data)

    labels_counter = {}
    short_data = []
    for sample in data:
        label = sample[1]
        kept = labels_counter.get(label, 0)
        if kept < max_labels:
            labels_counter[label] = kept + 1
            short_data.append(sample)
    return short_data

# Paths are relative to the notebook's working directory. Plain file names
# resolve on every OS, whereas a ".\"-prefixed name is only a path on Windows.
# Training keeps load_data's default per-class cap; the test split is capped at 80.
row_data_train = load_data("train.csv")
row_data_test = load_data("test.csv", 80)

print(row_data_train)
print(row_data_test)

[({'scene', 'his', 'dominated', 'dari', 'duo', 'and', 'persian', 'is', 'a', 'music', 'ex', 'wife', 'started', 'afghanistan', 'pakistan', 'he', 'musical', 'popular', 'singer', 'in', 'during', 'were', 'born', 'pashto', 'prominent', 'mangal', 'who', 'early', 'afghan', 'naghma', 'the', 'laghman', 'sings'}, 1), ({'an', 'ermocrate', 'flowers', 'italian', 'of', 'still', 'painter', 'with', 'was', 'lifes', 'bucchi', 'mainly'}, 1), ({'records', 'here', 'later', 'addition', 'columbia', 'co', 'anywhere', 'debut', 'chris', 'and', 'year', 'by', 'american', 'an', 'nothing', 'is', 'a', 'one', 'but', 'wrote', 'long', 'signed', 'music', 'hot', 'cagle', 'allan', 'gary', 's', 'number', 'randy', 'was', 'songs', 'released', 'anything', 'singer', 'which', 'in', 'charted', 'radio', 'born', 'single', 'charts', 'william', 'billboard', 'brice', 'goes', 'that', 'august', 'houser', 'to', 'also', 'on', 'the', 'country'}, 1), ({'things', 'privately', 'english', 'his', 'merchant', 'of', 'ladslove', 'navy', 'and', 'ac

In [38]:

# Build the vocabulary: every distinct word seen in the training split,
# sorted so that each word gets a stable column position.
all_words = set()
for tokens, _ in row_data_train:
    all_words.update(tokens)
all_words = sorted(all_words)
len(all_words)

3576

In [39]:
def one_hot(data, all_words):
    """Encode samples as binary word-presence vectors.

    Args:
        data: Sequence of samples; ``sample[0]`` is a container of words
            supporting ``in`` and ``sample[1]`` is an integer label.
        all_words: Sequence of vocabulary words; position ``j`` defines
            column ``j`` of the encoding.

    Returns:
        tuple ``(encoded, labels)``: ``encoded`` is an int array of shape
        ``(len(data), len(all_words))`` with a 1 where the sample contains the
        column's word, and ``labels`` is an int array of shape ``(len(data),)``.
    """
    encoded = np.zeros((len(data), len(all_words)), dtype=int)
    labels = np.zeros(len(data), dtype=int)
    for row, sample in enumerate(data):
        tokens = sample[0]
        # Booleans are cast to 0/1 on assignment into the int row.
        encoded[row] = [word in tokens for word in all_words]
        labels[row] = sample[1]
    return encoded, labels

# Encode both splits against the training vocabulary (train first, then test).
# Words that are not in all_words have no column and are therefore ignored.
(X_train, y_train), (X_test, y_test) = (
    one_hot(split, all_words) for split in (row_data_train, row_data_test)
)

In [46]:
class WeightedKNN:
    """k-nearest-neighbour classifier with distance-weighted voting.

    For a query point the ``n_neighbors`` closest training rows are selected,
    each neighbour votes for its label with a weight, and the per-class vote
    totals are passed through a softmax to produce the returned probabilities.

    Args:
        n_neighbors: Number of nearest training samples that vote.
        metric: ``'l1'`` for the L1 norm; any other value uses the L2 norm.
        n_classes: Length of the probability vector. When omitted (or 0) it
            is inferred in ``fit`` as ``max(label) + 1`` so every label seen
            in training is a valid index.
        weights: ``'uniform'`` gives every neighbour the same vote; any other
            value (default ``'distance'``) uses a softmin over the neighbour
            distances, so closer neighbours count more.
        verbose: When True, ``predict_proba`` prints the neighbour labels,
            distances, weights and resulting probabilities. Defaults to True
            to keep the diagnostic output this class has always produced;
            pass False when predicting many samples.
    """

    def __init__(self, n_neighbors=5, metric='l2', n_classes=None, weights='distance', verbose=True):
        self.n_neighbors = n_neighbors
        self.metric = (lambda a: np.linalg.norm(a, axis=1, ord=1)) if metric == 'l1' else (lambda a: np.linalg.norm(a, axis=1))
        # The class count is resolved in fit() from the labels actually passed
        # there, instead of depending on a module-level ``y_train`` existing.
        self.n_classes = n_classes
        self._infer_n_classes = not n_classes
        self.weights = weights
        self.verbose = verbose

    def fit(self, X_train, y_train):
        """Store the training set and, if needed, infer the number of classes.

        Args:
            X_train: 2-D array-like of features, one row per sample.
            y_train: 1-D array-like of non-negative integer labels.
        """
        self.X_train = np.asarray(X_train)
        self.y_train = np.asarray(y_train)
        if self._infer_n_classes:
            # max + 1 (not the number of distinct labels) so that a label id
            # is always a valid index into the score vector.
            self.n_classes = int(self.y_train.max()) + 1 if self.y_train.size else 0

    def _softmin(self, dists):
        """Return softmax(-dists): weights summing to 1, larger for smaller distances."""
        logits = - dists
        # Subtracting the maximum keeps the exponentials numerically stable.
        exps = np.exp(logits - np.max(logits))
        return exps / np.sum(exps)

    def predict_proba(self, x):
        """Return the class-probability vector for a single sample ``x``.

        Args:
            x: 1-D feature vector broadcastable against the training rows.

        Returns:
            numpy array of length ``n_classes`` that sums to 1.
        """
        dists = self.metric(self.X_train-x)
        idx = np.argsort(dists)[:self.n_neighbors]
        neigh_y = self.y_train[idx]
        neigh_d = dists[idx]
        if self.weights == 'uniform':
            weights = np.full(len(idx), 1.0 / max(len(idx), 1))
        else:
            weights = self._softmin(neigh_d)
        if self.verbose:
            print(neigh_y)
            print(neigh_d)
            print(weights)
        scores = np.zeros(self.n_classes)
        for w, y in zip(weights, neigh_y):
            scores[y] += w
        # NOTE(review): the vote totals already sum to 1; this second softmax
        # flattens them and gives unvoted classes non-zero mass. It is kept
        # because it preserves the argmax and the existing returned values.
        probs = np.exp(scores - np.max(scores))
        probs /= np.sum(probs)
        if self.verbose:
            print(probs)
        return probs

    def predict(self, X):
        """Return the most probable class index for every row of ``X``."""
        return np.array([np.argmax(self.predict_proba(x)) for x in X])


In [58]:
# Fit an L1-distance model voting over the 25 nearest neighbours, then compare
# the prediction for the first test sample with its true label.
first_sample = slice(0, 1)
model = WeightedKNN(metric='l1', n_neighbors=25)
model.fit(X_train, y_train)
predicted = model.predict(X_test[first_sample])
print(predicted, y_test[first_sample])

[1 2 2 2 1 0 2 2 2 2 1 2 1 2 0 0 1 2 2 0 2 2 1 0 0]
[12. 13. 13. 14. 15. 15. 15. 15. 16. 16. 16. 16. 16. 17. 17. 17. 17. 18.
 18. 18. 18. 18. 18. 18. 18.]
[0.45277501 0.16656662 0.16656662 0.06127643 0.02254234 0.02254234
 0.02254234 0.02254234 0.00829286 0.00829286 0.00829286 0.00829286
 0.00829286 0.00305077 0.00305077 0.00305077 0.00305077 0.00112232
 0.00112232 0.00112232 0.00112232 0.00112232 0.00112232 0.00112232
 0.00112232]
[0.2413662  0.38389943 0.37473437]
[1] [2]
