In [117]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from typing import Any, Callable
import helpers

labeled_df = pd.read_csv('data/labeled_penguins.csv')

labeled_df.head(5)

X_train, y_train, X_val, y_val, X_test, y_test, feature_names, label_map = helpers.preprocess_data(df=labeled_df, label="species", train_size=0.6, val_size=0.2, seed=42)
unlabeled_df = pd.read_csv('data/unlabeled_penguins.csv')

unlabeled_df.head(5)

X_unlabeled = unlabeled_df.to_numpy()

mean=(sum(X_train[:,])/len(X_train))
std=(((sum(X_train[:,]**2)-len(X_train)*(mean**2)))/(len(X_train)))**0.5
print(mean,std)

def normalize(X: np.ndarray, mean: np.ndarray, std: np.ndarray):
    X_normalized = (X-mean)/(std)
    return X_normalized

feature_names = ['Bill length (normalized)', 'Body mass (normalized)']

X_train = normalize(X_train, mean, std)
X_val = normalize(X_val, mean, std)
X_test = normalize(X_test, mean, std)
X_unlabeled = normalize(X_unlabeled, mean, std)

def manhattan_dist(sample: np.ndarray, X: np.ndarray):
    distances = sum((abs(sample-X)).T)
    return distances
   
def euclidean_dist(sample: np.ndarray, X: np.ndarray):
    distances = (sum(((sample-X)**2).T))**0.5
    return distances

print(f'Manhattan: {np.round(manhattan_dist(X_val[0], X_train[:3]), decimals=1)}')
print(f'Euclidean: {np.round(euclidean_dist(X_val[0], X_train[:3]), decimals=1)}')

euclidean_dist(X_val[0],X_train)



[  44.73442623 4184.22131148] [  5.30537572 725.93354833]
Manhattan: [4.2 2.7 2.2]
Euclidean: [3.1 2.  2. ]


TypeError: Field elements must be 2- or 3-tuples, got 'array([3.08355209, 2.00636864, 2.0046187 , 1.45006931, 2.15077378,
       2.05984101, 1.51634492, 3.25494153, 2.43813941, 4.55995559,
       2.21670824, 4.28236828, 4.1737135 , 4.06727207, 1.80673597,
       2.89583413, 2.73880517, 1.78572866, 2.48593271, 4.06192708,
       2.97701851, 0.89017565, 1.97960308, 2.19969413, 1.26024664,
       1.99072484, 4.2964432 , 1.91949155, 1.60792544, 1.30879546,
       2.2001969 , 1.3776655 , 1.75148111, 2.58175155, 1.13318262,
       1.32806259, 3.25680593, 2.54955766, 2.82862664, 3.04836407,
       2.77413687, 2.86944081, 1.17909586, 2.54964252, 2.46029009,
       1.75327571, 1.54932486, 2.15828339, 2.44794027, 2.46951182,
       2.22459544, 2.66678781, 1.31529422, 2.37783648, 3.40034412,
       2.50499123, 2.99614201, 2.50954104, 3.52142309, 3.23918577,
       1.56586058, 2.69918063, 1.96295813, 1.83285571, 2.21560132,
       1.59682391, 2.06390085, 1.72404217, 1.1625836 , 4.36275657,
       2.84900057, 2.94223705, 0.82261857, 2.90661042, 1.96770943,
       1.38075664, 1.66731189, 2.1470577 , 4.1737135 , 1.13737926,
       1.99283953, 1.76287897, 1.34744173, 2.32785854, 1.50216463,
       3.32751065, 1.22047376, 2.35270994, 2.33819578, 1.913904  ,
       2.12712697, 2.32844357, 2.07668106, 3.66620441, 2.01890476,
       2.66021024, 3.87932892, 3.1793042 , 2.46482455, 2.56870286,
       1.89518347, 2.52997219, 1.78309185, 1.66807654, 2.4841466 ,
       3.98102091, 3.44404771, 1.97113613, 2.10045286, 1.04586773,
       2.92422609, 3.54418365, 2.56806585, 3.95335374, 2.82797295,
       2.78051808, 1.03744198, 1.38910705, 2.13063803, 2.34302551,
       1.98184448, 2.55240933])'