In [26]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import model_selection

def main():
    """Fit the norm-based classifier on the saved MFCC data and print metrics.

    Loads the feature and label arrays from ``../util``, flattens every sample
    to one row, holds out 10% for validation (fixed ``random_state=42``), then
    prints the tuple returned by ``find_accuracies``.
    """
    features = np.load("../util/X_mfccs/X_mfccs_4000.npy")
    # Collapse all per-sample axes into a single feature vector per row.
    features = features.reshape(features.shape[0], -1)
    labels = np.load("../util/labels/labels_4000.npy")

    split = model_selection.train_test_split(
        features, labels, test_size=0.10, random_state=42
    )
    train_x, valid_x, train_y, valid_y = split

    class_0_mean, class_1_mean = find_averages(train_x, train_y)
    valid_pred = predict(class_0_mean, class_1_mean, valid_x)

    confusion = find_counts(valid_pred, valid_y)
    print(find_accuracies(*confusion))
    
    
def find_averages(x_train, y_train):
    """Return the mean norm of the class-0 and class-1 training samples.

    Each sample ``x_train[i]`` is reduced to ``np.linalg.norm(x_train[i])``
    (the 2-norm of its flattened values). Those norms are averaged separately
    over samples labelled 0 and samples labelled 1. Samples with any other
    label are ignored.

    Args:
        x_train: Samples indexed along the first axis, aligned with ``y_train``.
        y_train: Labels of shape ``(n,)`` or ``(n, 1)``; only 0 and 1 count.

    Returns:
        Tuple ``(avg_0, avg_1)`` of mean sample norms for labels 0 and 1.

    Raises:
        ValueError: If no training sample carries label 0 or label 1, since
            that class mean is undefined.
    """
    labels = np.ravel(np.asarray(y_train))
    # One norm per labelled sample, collected into a float array so the
    # per-class means can be taken with boolean masks instead of a manual
    # accumulate-and-count loop.
    norms = np.array(
        [np.linalg.norm(x_train[i]) for i in range(labels.shape[0])], dtype=float
    )

    averages = []
    for label in (0, 1):
        class_norms = norms[labels == label]
        if class_norms.size == 0:
            raise ValueError(f"find_averages: no training samples with label {label}")
        averages.append(class_norms.mean())
    avg_0, avg_1 = averages
    return (avg_0, avg_1)
    
def predict(avg_0, avg_1, x_valid):
    """Label each sample by which class-average norm its own norm is closer to.

    Args:
        avg_0: Mean sample norm for class 0 (e.g. from ``find_averages``).
        avg_1: Mean sample norm for class 1.
        x_valid: Iterable of samples; each is reduced with ``np.linalg.norm``.

    Returns:
        1-D integer array with one 0/1 prediction per sample. A sample is
        labelled 0 only when it is strictly closer to ``avg_0``. Ties, and NaN
        distances (which make the comparison false), are labelled 1.
    """
    # Compute each sample's norm once, rather than once per side of the comparison.
    norms = np.array([np.linalg.norm(x) for x in x_valid], dtype=float)
    closer_to_0 = np.abs(norms - avg_0) < np.abs(norms - avg_1)
    return np.where(closer_to_0, 0, 1)

def find_counts(predictions, truth):
    """Tally binary confusion-matrix counts.

    A prediction ``< 0.5`` counts as predicted class 0 and ``>= 0.5`` as
    predicted class 1. A prediction satisfying neither comparison (e.g. NaN)
    is not counted, and neither is a truth label other than 0 or 1.

    Args:
        predictions: Array of scores/labels; iterated via ``.shape[0]``.
        truth: Ground-truth labels indexed in step with ``predictions``.

    Returns:
        Tuple ``(TP, TN, FP, FN)`` of integer counts.
    """
    true_pos = true_neg = false_pos = false_neg = 0
    for idx in range(predictions.shape[0]):
        score = predictions[idx]
        if score < 0.5:
            # Predicted negative: correct for label 0, a miss for label 1.
            label = truth[idx]
            if label == 0:
                true_neg += 1
            elif label == 1:
                false_neg += 1
        elif score >= 0.5:
            # Predicted positive: correct for label 1, a false alarm for label 0.
            label = truth[idx]
            if label == 1:
                true_pos += 1
            elif label == 0:
                false_pos += 1
    return true_pos, true_neg, false_pos, false_neg

def find_accuracies(TP, TN, FP, FN):
    """Compute accuracy-style metrics from binary confusion-matrix counts.

    With counts as produced by ``find_counts``, TP/FN are samples whose true
    label is 1 and TN/FP are samples whose true label is 0.

    Returns:
        Tuple ``(accuracy, balanced_accuracy, tpr, tnr)`` where

        - ``accuracy = (TP + TN) / (TP + TN + FP + FN)``
        - ``tpr = TP / (TP + FN)`` is the recall on label 1
        - ``tnr = TN / (TN + FP)`` is the recall on label 0
        - ``balanced_accuracy`` is the mean of ``tpr`` and ``tnr``

        The element order is unchanged for existing callers. Any ratio whose
        denominator is zero (e.g. no label-1 samples) is ``nan`` instead of
        raising ``ZeroDivisionError``. A ``nan`` rate propagates into
        ``balanced_accuracy``.
    """
    def _ratio(num, den):
        # An undefined rate becomes nan, so one absent class does not abort
        # the computation of the remaining metrics.
        return num / den if den else float("nan")

    accuracy = _ratio(TP + TN, TP + TN + FP + FN)
    true_positive_rate = _ratio(TP, TP + FN)
    true_negative_rate = _ratio(TN, TN + FP)
    balanced_accuracy = 0.5 * (true_positive_rate + true_negative_rate)
    return accuracy, balanced_accuracy, true_positive_rate, true_negative_rate
    
    
# Script entry point: run the full load/fit/evaluate pipeline. This file looks
# like an exported notebook cell; in a notebook `__name__` is also '__main__',
# so executing the cell runs main() too.
if __name__ == '__main__':
    main()

(0.7611874169251218, 0.7452242399151775, 0.6981891348088531, 0.7922593450215019)
