In [43]:
from utils import download_dataset_if_needed

from models.knn_model import KnnModel

import pandas as pd
import numpy as np

import os

from datasets.dataset import Dataset
from datasets import livingroom
from datasets.roomsetup import RoomSetup
import matplotlib.pyplot as plt

import librosa

from sklearn.preprocessing import minmax_scale
from sklearn.model_selection import train_test_split

from utils import download_dataset_if_needed

In [44]:
def min_max_scale(x, min, max):
    return (x - min) / (max - min)


def min_max_unscale(x, min, max):
    return x * (max - min) + min

x_min = -4000
x_max = 500
y_min = -4000
y_max = 2000

In [45]:
download_dataset_if_needed()

path: ./LivingRoom_preprocessed_hack already exist, ignorign dataset downloading


In [46]:
DATASET_PATH = "LivingRoom_preprocessed_hack/Human1"

dr = Dataset(RoomSetup(livingroom.speaker_xyz,
                livingroom.mic_xyzs,
                livingroom.x_min,
                livingroom.x_max,
                livingroom.y_min,
                livingroom.y_max,
                livingroom.walls), DATASET_PATH)
livingroom.mic_xyzs

array([[-3343.275 , -3146.425 ,  1277.9375],
       [-3284.5375,  1587.5   ,  1277.9375],
       [  136.525 ,    53.975 ,  1277.9375],
       [  -76.2   , -3594.1   ,  1277.9375]])

In [47]:
centroid = np.load(os.path.join(DATASET_PATH, "centroid.npy"))
print("Shape of Centroid:")
print(centroid.shape)

#Loading Room Impulse Response (1000 human locations x 10 microphones x M time samples)
RIRs = np.load(os.path.join(DATASET_PATH, "deconvoled_trim.npy"), mmap_mode='r')
print("Shape of RIRs:")
print(RIRs.shape)

Shape of Centroid:
(1000, 2)
Shape of RIRs:
(1000, 4, 667200)


In [48]:
rms_values = []

n_datapoint = RIRs.shape[0]
n_mics = RIRs.shape[1]

for i in range(n_datapoint):
    rms_values.append([])
    for j in range(n_mics):
        rms_values[i].append(np.sqrt(np.mean(RIRs[i, j]**2, axis=-1)))

rms_values

[[0.0005250508, 0.00050231937, 0.0007083557, 0.00042991326],
 [0.00052083656, 0.0004996485, 0.0007070924, 0.00042909803],
 [0.0005219782, 0.00049891206, 0.0007053384, 0.00042922105],
 [0.00052071904, 0.0004915442, 0.000706142, 0.00042609416],
 [0.0005203253, 0.00049116445, 0.0006989711, 0.00042926765],
 [0.00052292855, 0.00047523956, 0.00070268573, 0.00043043995],
 [0.00052255375, 0.0004936408, 0.00070211326, 0.0004275497],
 [0.0005208316, 0.0004898015, 0.0007014926, 0.0004314625],
 [0.00052354555, 0.00050294417, 0.0007030659, 0.0004289229],
 [0.0005212008, 0.00050025404, 0.00070209184, 0.00042767203],
 [0.00051721925, 0.0005035848, 0.00070275366, 0.00042827232],
 [0.0005141903, 0.00048575317, 0.00044236533, 0.00036199432],
 [0.00051343185, 0.000490165, 0.0006880406, 0.0004175915],
 [0.00051372516, 0.00049302826, 0.00068872544, 0.00041907313],
 [0.000513887, 0.0004933613, 0.0006862418, 0.00041899964],
 [0.0005148951, 0.0004929176, 0.00068902154, 0.0004221684],
 [0.0005162434, 0.0004944

In [49]:
rms = np.asarray(rms_values)

rms_normed = np.clip(minmax_scale(rms, axis=1), 0,1)

rms_normed

array([[0.34167767, 0.2600398 , 1.        , 0.        ],
       [0.33000135, 0.25378382, 1.        , 0.        ],
       [0.3359338 , 0.25239635, 1.        , 0.        ],
       ...,
       [0.3619293 , 0.27672982, 0.9999999 , 0.        ],
       [0.35767245, 0.26432538, 0.9999999 , 0.        ],
       [1.        , 0.8847115 , 0.33306885, 0.        ]], dtype=float32)

In [50]:
centroid_normed[:,0] = min_max_scale(centroid[:,0], x_min, x_max)
centroid_normed[:,1] = min_max_scale(centroid[:,1], y_min, y_max)

centroid_normed

array([[0.1708157 , 0.47868705],
       [0.17810198, 0.54258166],
       [0.17933831, 0.62528887],
       ...,
       [0.5071335 , 0.13890059],
       [0.48697703, 0.20347827],
       [0.89876387, 0.68441392]])

In [51]:
knn = KnnModel()

In [52]:
X_train, X_test, y_train, y_test = train_test_split(rms_normed, centroid_normed, test_size=0.2, random_state=42)

In [53]:
y_train

array([[0.36299312, 0.318763  ],
       [0.49827923, 0.85234374],
       [0.42181296, 0.6049897 ],
       ...,
       [0.75794365, 0.10442357],
       [0.37504589, 0.4867089 ],
       [0.35263566, 0.78652167]])

In [54]:
# Train the KNN model
knn.fit(X_train, y_train)

# Make predictions on the testing set
y_pred = knn.predict(X_test)

In [55]:
knn.model.score(X_test, y_test)

0.1983293641982814

In [56]:
y_pred

array([[0.39708833, 0.42692227],
       [0.36415273, 0.37783766],
       [0.56954194, 0.1732475 ],
       [0.50864479, 0.53395064],
       [0.51572787, 0.22084567],
       [0.28866514, 0.86724506],
       [0.56877704, 0.18413417],
       [0.52803218, 0.54580018],
       [0.84610878, 0.14991699],
       [0.33880873, 0.38767211],
       [0.87932219, 0.74125839],
       [0.21767845, 0.36549421],
       [0.58865874, 0.19040582],
       [0.61363612, 0.12821866],
       [0.34944866, 0.33337147],
       [0.27961194, 0.37039559],
       [0.46138713, 0.72567641],
       [0.51273276, 0.54455335],
       [0.52417518, 0.54087366],
       [0.14837726, 0.60592629],
       [0.73152771, 0.64792949],
       [0.63714612, 0.60656979],
       [0.4519479 , 0.56484745],
       [0.48311128, 0.6781836 ],
       [0.34721393, 0.69432189],
       [0.3450538 , 0.38586198],
       [0.22804898, 0.70353647],
       [0.64022202, 0.16704128],
       [0.60652843, 0.19420062],
       [0.44520023, 0.17442868],
       [0.

In [57]:
centroid_pred = y_pred.copy()
centroid_pred[:,0] = min_max_unscale(centroid_pred[:,0], x_min, x_max)
centroid_pred[:,1] = min_max_unscale(centroid_pred[:,1], y_min, y_max)

In [58]:
centroid_pred

array([[-2.21310251e+03, -1.43846641e+03],
       [-2.36131271e+03, -1.73297403e+03],
       [-1.43706127e+03, -2.96051500e+03],
       [-1.71109845e+03, -7.96296157e+02],
       [-1.67922459e+03, -2.67492595e+03],
       [-2.70100686e+03,  1.20347038e+03],
       [-1.44050331e+03, -2.89519497e+03],
       [-1.62385519e+03, -7.25198923e+02],
       [-1.92510493e+02, -3.10049806e+03],
       [-2.47536072e+03, -1.67396732e+03],
       [-4.30501660e+01,  4.47550317e+02],
       [-3.02044698e+03, -1.80703475e+03],
       [-1.35103566e+03, -2.85756507e+03],
       [-1.23863745e+03, -3.23068806e+03],
       [-2.42748102e+03, -1.99977118e+03],
       [-2.74174629e+03, -1.77762647e+03],
       [-1.92375790e+03,  3.54058474e+02],
       [-1.69270259e+03, -7.32679881e+02],
       [-1.64121168e+03, -7.54758030e+02],
       [-3.33230233e+03, -3.64442251e+02],
       [-7.08125323e+02, -1.12423061e+02],
       [-1.13284246e+03, -3.60581236e+02],
       [-1.96623446e+03, -6.10915278e+02],
       [-1.

In [59]:
centroid_test = y_test.copy()
centroid_test[:,0] = min_max_unscale(centroid_test[:,0], x_min, x_max)
centroid_test[:,1] = min_max_unscale(centroid_test[:,1], y_min, y_max)

In [60]:
y_test

array([[ 6.55343082e-01,  6.54923439e-01],
       [ 4.30738642e-01,  4.92277476e-01],
       [ 4.38926865e-01,  2.02849813e-01],
       [ 6.61047928e-01,  7.02093603e-01],
       [ 4.63199162e-01,  2.60286469e-01],
       [ 3.01830386e-01,  8.90863572e-01],
       [ 5.10053797e-01,  1.75624542e-01],
       [ 7.88187927e-01,  7.90745725e-01],
       [ 8.41200973e-01,  9.95572876e-02],
       [ 8.00253807e-01,  7.86135136e-01],
       [ 8.78078969e-01,  7.43771306e-01],
       [ 2.61483107e-01,  4.37457541e-01],
       [ 7.81711744e-01,  1.47653559e-01],
       [ 5.85230780e-01,  1.07037697e-01],
       [ 7.98891235e-01,  2.14209218e-01],
       [ 1.25499919e-01,  5.98457265e-01],
       [ 5.21938498e-01,  6.87053744e-01],
       [ 6.52735954e-01,  7.32695130e-01],
       [ 3.53474027e-01,  3.24860897e-01],
       [ 8.80607685e-02,  3.00104158e-01],
       [ 7.36861266e-02,  2.19159839e-01],
       [ 8.46019912e-01,  6.42252689e-01],
       [ 2.54527119e-01,  6.14690338e-01],
       [ 6.

In [61]:
centroid_test

array([[-1050.95613037,   -70.45936782],
       [-2061.67611058, -1046.33514568],
       [-2024.82910906, -2782.90112155],
       [-1025.28432354,   212.56161767],
       [-1915.60377208, -2438.28118592],
       [-2641.76326138,  1345.18143055],
       [-1704.75791182, -2946.25274819],
       [ -453.15433072,   744.47435037],
       [ -214.59562343, -3402.65627454],
       [ -398.85786846,   716.81081649],
       [  -48.64463981,   462.62783888],
       [-2823.32601876, -1375.25475582],
       [ -482.29715011, -3114.0786461 ],
       [-1366.46149041, -3357.773817  ],
       [ -404.98944029, -2714.74469138],
       [-3435.25036395,  -409.25640754],
       [-1651.27675848,   122.32246552],
       [-1062.68820725,   396.17077806],
       [-2409.36688042, -2050.83462068],
       [-3603.72654182, -2199.37505471],
       [-3668.41243035, -2685.04096338],
       [ -192.91039557,  -146.48386805],
       [-2854.62796477,  -311.85797361],
       [-1166.12880182,   908.91483174],
       [ -876.78