In [1]:
import os
import sys
import time
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import classification_report, confusion_matrix
from sklearn.ensemble import RandomForestClassifier

from src.data import load_annotation, load_data
from src.utils.train import train_test_split

import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [2]:
# config
DATAFRAME_PATH = "../data/raw/data_frames"
ANNOTATION_PATH = "../data/processed/Annotation.csv"
# Passed to get_complete_ids below and to the dataset configs; defined once only.
CATEGORY = "Urination"
# Probability cut-off for the positive class.
# NOTE(review): confirm this is the intended operating point for evaluation.
THRESHOLD = 0.3

FEATURE_NAMES = ['Min', 'Max', 'Median', 'Mean', 'LogVariance', 'LinearTrend']
SOURCES = ['TotalWeight', 'WaterDistance', 'AudioDelay', 'RadarSum', 'AudioDelay2', 'AudioDelay4']

ANNOTATIONS = load_annotation.get_annotation(ANNOTATION_PATH)
USER_IDS = load_annotation.get_complete_ids(ANNOTATION_PATH, CATEGORY)
# NOTE(review): the evaluated model is loaded from a pickle trained in another session;
# confirm train_test_split is deterministic (seeded) so TEST_IDS were held out from it.
TRAIN_IDS, TEST_IDS = train_test_split(USER_IDS)
# Guard against user-level leakage between the two splits.
assert not set(TRAIN_IDS) & set(TEST_IDS), "train and test user IDs overlap"

print("Training on {} cases, {} ...".format(len(TRAIN_IDS), TRAIN_IDS[:5]))
print("Testing  on {} cases, {} ...".format(len(TEST_IDS), TEST_IDS[:5]))

Training on 117 cases, [2021, 2005, 1918, 2027, 2015] ...
Testing  on 30 cases, [1802, 1806, 1830, 1831, 1836] ...


In [3]:
def _rf_dataset_config(use_ids):
    """Return a RandomForestDataset config for `use_ids`; every other setting is shared.

    Reads ANNOTATION_PATH, FEATURE_NAMES, SOURCES and CATEGORY from the config cell.
    """
    return {
        'USE_IDS': use_ids,
        'ANNOTATION_PATH': ANNOTATION_PATH,
        'FEATURE_NAMES': FEATURE_NAMES,
        'SOURCES': SOURCES,
        'CATEGORY': CATEGORY,
    }


# Train and test configs differ only in their user IDs; building both from one
# helper keeps the remaining settings from drifting apart.
rf_train_config = _rf_dataset_config(TRAIN_IDS)
rf_test_config = _rf_dataset_config(TEST_IDS)

dataset = {}
# NOTE(review): dataset['train'] is not used by the cells visible here (the model
# is loaded from a pickle) -- confirm it is still needed.
dataset['train'] = load_data.RandomForestDataset(rf_train_config)
dataset['test'] = load_data.RandomForestDataset(rf_test_config)

In [4]:
# Extract features and labels for every held-out user.
# NOTE(review): the captured log shows several users failing to load (2062-2073) while
# evaluation proceeds without them -- confirm their exclusion is intended.
test_x, test_y = dataset['test'].get_all_features_and_labels()
# Sanity check: exactly one label per feature row.
assert len(test_x) == len(test_y), (len(test_x), len(test_y))

Updating user : 1802
Updating user : 1806
Updating user : 1830
Updating user : 1831
Updating user : 1836
Updating user : 1864
Updating user : 1871
Updating user : 1876
Updating user : 1889
Updating user : 1904
Updating user : 1914
Updating user : 1921
Updating user : 1924
Updating user : 1925
Updating user : 1927
Updating user : 1932
Updating user : 1937
Updating user : 1939
Updating user : 1949
Updating user : 1951
Updating user : 1996
Updating user : 1999
Updating user : 2013
Updating user : 2024
Updating user : 2042
Updating user : 2062
updating user 2062 failed
Updating user : 2066
updating user 2066 failed
Updating user : 2067
updating user 2067 failed
Updating user : 2070
updating user 2070 failed
Updating user : 2073
updating user 2073 failed


In [28]:
def classification_result(model, testX, testY, threshold = 0.5):
    """Print a classification report for `model` at a custom decision threshold.

    Rather than using ``model.predict`` (argmax over classes), a sample is
    predicted as 1 when the model's probability for class 1 is strictly greater
    than ``threshold``, and as 0 otherwise.

    Args:
        model: fitted classifier exposing ``predict_proba`` and ``classes_``
            (e.g. a scikit-learn estimator).
        testX: features accepted by ``model.predict_proba``.
        testY: ground-truth labels for ``testX``.
        threshold: cut-off applied to the class-1 probability.

    Returns:
        None; the threshold and the report are printed.

    Raises:
        ValueError: if class 1 is not among ``model.classes_``.
    """
    classes = list(model.classes_)
    if 1 not in classes:
        raise ValueError(f"model has no class 1 to threshold; classes_ = {classes}")
    # predict_proba orders its columns like model.classes_, so look up the class-1
    # column instead of assuming it is index 1.
    positive_proba = model.predict_proba(testX)[:, classes.index(1)]
    testYPred = (positive_proba > threshold).astype(int)
    print(f"threshold = {threshold}", "\n")
    print(classification_report(testY, testYPred))

In [29]:
import pickle

# Pre-trained model artifact; the file name appears to carry its training timestamp.
MODEL_PATH = "../randomforest-20210108-032342.pkl"

# SECURITY: pickle.load can execute arbitrary code -- only load model files from a trusted source.
# NOTE(review): this model was not trained in this notebook; confirm TEST_IDS were held out
# from its training data, otherwise the reports below are optimistic.
with open(MODEL_PATH, "rb") as model_file:
    rf = pickle.load(model_file)

# The evaluation cells call rf.predict_proba; fail fast on a wrong or corrupt artifact.
assert hasattr(rf, "predict_proba"), f"{MODEL_PATH} does not hold a probabilistic classifier"

In [33]:
# Lower cut-off: more samples are labelled positive (class 1).
classification_result(rf, test_x, test_y, threshold=0.2)

threshold = 0.2 

              precision    recall  f1-score   support

         0.0       0.98      0.96      0.97      2008
         1.0       0.76      0.87      0.81       303

    accuracy                           0.95      2311
   macro avg       0.87      0.91      0.89      2311
weighted avg       0.95      0.95      0.95      2311



In [31]:
# 0.5 equals classification_result's default cut-off.
classification_result(rf, test_x, test_y, threshold=0.5)

threshold = 0.5 

              precision    recall  f1-score   support

         0.0       0.97      0.99      0.98      2008
         1.0       0.91      0.81      0.85       303

    accuracy                           0.96      2311
   macro avg       0.94      0.90      0.92      2311
weighted avg       0.96      0.96      0.96      2311



In [32]:
# Evaluate at the operating threshold declared in the config cell (THRESHOLD = 0.3)
# instead of repeating the magic number here.
classification_result(
    model = rf,
    testX = test_x,
    testY = test_y,
    threshold = THRESHOLD
)

threshold = 0.3 

              precision    recall  f1-score   support

         0.0       0.98      0.97      0.98      2008
         1.0       0.83      0.85      0.84       303

    accuracy                           0.96      2311
   macro avg       0.90      0.91      0.91      2311
weighted avg       0.96      0.96      0.96      2311

