In [1]:
import pandas as pd
import numpy as np
from scipy.signal import find_peaks
from scipy import stats
import math
from sklearn import linear_model
from sklearn import model_selection
from numpy import mean
from numpy import std
import numpy as np
from sklearn.model_selection import cross_val_score

In [2]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import RepeatedStratifiedKFold, RepeatedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
import pickle
from sklearn.metrics import accuracy_score

In [3]:
PLAYERS_NUMBER = 6
ages = [26, 15, 14, 13, 19, 17]
levels = [1, 2, 1, 2, 1, 2]
backhands = [2, 1, 2, 2, 2, 1]
smash_types = ['Dretes', 'Reves', 'Serve', 'Smash', 'VD', 'VR']
sensor_types = ['ACC', 'GYR']
stroke_numbers = {'Dretes': 30, 'Reves': 30, 'Serve': 15, 'Smash': 15, 'VD': 20, 'VR': 20}

In [24]:
"""FUNCTION TO LOAD ALL SENSOR DATA IN A COMBINED DATASET"""

def load_data (player_number, ages, levels, backhands, smash_types, sensor_types):
    result = []
    interval = player_number * len(smash_types)
    for i in range(interval * len(sensor_types)):
        player = int(i /(len(smash_types) * len(sensor_types)))
        smash = i - player_number * int(i / player_number)
        sensor =  int(i / player_number) % 2
        df = pd.read_csv(f'./Jugador {player+1}/Enregistraments J{player+1}/J{player+1} {smash_types[smash]} {sensor_types[sensor]}.csv').sort_values('UNIX_TIMESTAMP')
        df.drop(columns=['DATE'], inplace=True)
        df.rename(columns={ df.columns[0]: 'X', df.columns[1]: 'Y', df.columns[2]: 'Z', df.columns[3]: 'TIMESTAMP'}, inplace = True)
        df['AGE'] = ages[player]
        df['LEVEL'] = levels[player]
        df['BACKHANDS'] = backhands[player]
        df['SMASH'] = smash
        df['SENSOR'] = sensor
        temp = np.array_split(df, stroke_numbers[smash_types[smash]])
        for j in range(stroke_numbers[smash_types[smash]]):
            dfTemp = temp[j] 
            dfTemp['GROUP'] = str(i+1) + str(j+1)
        result += temp
    return result

data = load_data(PLAYERS_NUMBER, ages, levels, backhands, smash_types, sensor_types)
pd.concat(data)

Unnamed: 0,X,Y,Z,TIMESTAMP,AGE,LEVEL,BACKHANDS,SMASH,SENSOR,GROUP
0,0.013325,0.008126,0.000599,1697802318975,26,1,2,0,0,11
1,-0.110110,-0.069654,-0.005859,1697802318996,26,1,2,0,0,11
2,-0.068091,-0.044559,-0.003320,1697802319016,26,1,2,0,0,11
3,0.019304,0.012501,0.000792,1697802319038,26,1,2,0,0,11
4,0.186550,0.116849,0.005761,1697802319056,26,1,2,0,0,11
...,...,...,...,...,...,...,...,...,...,...
7358,0.040840,0.128904,-0.097510,1697903854058,17,2,1,5,1,7220
7359,0.034853,0.110147,-0.083276,1697903854068,17,2,1,5,1,7220
7360,0.034720,0.108551,-0.085271,1697903854079,17,2,1,5,1,7220
7361,0.034587,0.107221,-0.086867,1697903854090,17,2,1,5,1,7220


In [38]:
def calculate_statistics(df, suffix):
    row = {}
    row['x_mean_' + suffix] = [df['X'].mean()]
    row['y_mean_' + suffix] = [df['Y'].mean()]
    row['z_mean_' + suffix] = [df['Z'].mean()]
    row['x_std_' + suffix] = [df['X'].std()]
    row['y_std_' + suffix] = [df['Y'].std()]
    row['z_std_' + suffix] = [df['Z'].std()]
    row['x_min_' + suffix] = [df['X'].min()]
    row['y_min_' + suffix] = [df['Y'].min()]
    row['z_min_' + suffix] = [df['Z'].min()]
    row['x_max_' + suffix] = [df['X'].max()]
    row['y_max_' + suffix] = [df['Y'].max()]
    row['z_max_' + suffix] = [df['Z'].max()]
    row['x_max_' + suffix] = [df['X'].max()]
    row['y_max_' + suffix] = [df['Y'].max()]
    row['z_max_' + suffix] = [df['Z'].max()]
    row['x_median_' + suffix] = [df['X'].median()]
    row['y_median_' + suffix] = [df['Y'].median()]
    row['z_median_' + suffix] = [df['Z'].median()]
    row['x_mad_' + suffix] = [np.median(np.absolute(df['X'] - row['x_median_' + suffix]))] # median absolute difference
    row['y_mad_' + suffix] = [np.median(np.absolute(df['Y'] - row['y_median_' + suffix]))]
    row['z_mad_' + suffix] = [np.median(np.absolute(df['Z'] - row['z_median_' + suffix]))]
    row['x_iqr_' + suffix] = [np.percentile(df['X'], 75) - np.percentile(df['X'], 25)]
    row['y_iqr_' + suffix] = [np.percentile(df['Y'], 75) - np.percentile(df['Y'], 25)]
    row['z_iqr_' + suffix] = [np.percentile(df['Z'], 75) - np.percentile(df['Z'], 25)]
    row['x_pcount_' + suffix] = [np.sum(df['X'] > 0)]
    row['y_pcount_' + suffix] = [np.sum(df['Y'] > 0)]
    row['z_pcount_' + suffix] = [np.sum(df['Z'] > 0)]
    row['x_ncount_' + suffix] = [np.sum(df['X'] < 0)]
    row['y_ncount_' + suffix] = [np.sum(df['Y'] < 0)]
    row['z_ncount_' + suffix] = [np.sum(df['Z'] < 0)]
    row['x_abvmean_' + suffix] = [np.sum(df['X'] > df['X'].mean())]
    row['y_abvmean_' + suffix] = [np.sum(df['Y'] > df['Y'].mean())]
    row['z_abvmean_' + suffix] = [np.sum(df['Z'] > df['Z'].mean())]
    row['x_cntpeaks_' + suffix] = [len(find_peaks(df['X'])[0])]
    row['y_cntpeaks_' + suffix] = [len(find_peaks(df['Y'])[0])]
    row['z_cntpeaks_' + suffix] = [len(find_peaks(df['Z'])[0])]
    row['x_skew_' + suffix] = [stats.skew(df['X'])]
    row['y_skew_' + suffix] = [stats.skew(df['Y'])]
    row['z_skew_' + suffix] = [stats.skew(df['Z'])]
    row['x_kurt_' + suffix] = [stats.kurtosis(df['X'])]
    row['y_kurt_' + suffix] = [stats.kurtosis(df['Y'])]
    row['z_kurt_' + suffix] = [stats.kurtosis(df['Z'])]
    row['time'] = df.iloc[-1, df.columns.get_loc('UNIX_TIMESTAMP')] - df.iloc[0, df.columns.get_loc('UNIX_TIMESTAMP')]
    if (suffix == 'acc' or suffix == 'ftt_acc'):
        row['x_energy_' + suffix] = [np.sum(df['X']**2)/100]
        row['y_energy_' + suffix] = [np.sum(df['Y']**2)/100]
        row['z_energy_' + suffix] = [np.sum(df['Z']**2)/100]
        row['acceleration'] = [np.mean((df['X']**2 + df['Y']**2 + df['Z']**2)**0.5)]
    return pd.DataFrame.from_dict(row)

def load_csv(smash_types, stroke_numbers, player, smash, sensor):
    df_acc = pd.read_csv(f'./Jugador {player+1}/Enregistraments J{player+1}/J{player+1} {smash_types[smash]} {sensor}.csv').sort_values('UNIX_TIMESTAMP')
    df_acc.rename(columns={ df_acc.columns[0]: 'X', df_acc.columns[1]: 'Y', df_acc.columns[2]: 'Z'}, inplace = True)
    temp = np.array_split(df_acc, stroke_numbers[smash_types[smash]])
    return temp

In [39]:
"""FUNCTION TO LOAD STATISTICAL DATA FOR EACH STROKE"""

def load_data (player_number, ages, levels, backhands, smash_types, stroke_numbers, sensor_types):
    interval = player_number * len(smash_types)
    result = []
    for i in range(interval):
        player = int(i /len(smash_types))
        smash = i - player_number * int(i / player_number)
        temp = load_csv(smash_types, stroke_numbers, player, smash, sensor_types[0])
        temp2 = load_csv(smash_types, stroke_numbers, player, smash, sensor_types[1])
        for j in range(len(temp)):
            df = temp[j]
            row1 = calculate_statistics(df, 'acc')
            df2 = temp2[j]
            row3 = calculate_statistics(df2, 'gyr')

            row_result = pd.concat([row1, row3], axis=1)
            row_result['player'] = float(player)
            row_result['age'] = float(ages[player])
            row_result['level'] = float(levels[player])
            row_result['backhands'] = float(backhands[player])
            row_result['smash'] = float(smash)

            df['distance_xy'] = 0
            df['distance_xz'] = 0
            df['distance_yz'] = 0
            df['angle_z'] = 0
            df['angle_y'] = 0
            df['angle_x'] = 0
            df['xy_x'] = 0
            df['xy_y'] = 0
            df['xz_x'] = 0
            df['xz_y'] = 0
            df['yz_x'] = 0
            df['yz_y'] = 0

            for k in range(len(df)): 
                if k != 0:
                    time = df.iloc[k, df.columns.get_loc('UNIX_TIMESTAMP')] - df.iloc[k-1, df.columns.get_loc('UNIX_TIMESTAMP')]
                else: 
                    time = 0
                df.iloc[k, df.columns.get_loc('distance_xy')] = math.sqrt(df.iloc[k, df.columns.get_loc('X')]**2*(time)**4 + df.iloc[k, df.columns.get_loc('Y')]**2*time**4)
                df.iloc[k, df.columns.get_loc('distance_xz')] = math.sqrt(df.iloc[k, df.columns.get_loc('X')]**2*(time)**4 + df.iloc[k, df.columns.get_loc('Z')]**2*time**4)
                df.iloc[k, df.columns.get_loc('distance_yz')] = math.sqrt(df.iloc[k, df.columns.get_loc('Y')]**2*(time)**4 + df.iloc[k, df.columns.get_loc('Z')]**2*time**4)
                if k != 0:
                    z = k
                else:
                    z = -1
                df.iloc[k, df.columns.get_loc('angle_z')] = (df2.iloc[z+1, df.columns.get_loc('Z')] + df2.iloc[z+2, df.columns.get_loc('Z')])*time
                df.iloc[k, df.columns.get_loc('angle_y')] = (df2.iloc[z+1, df.columns.get_loc('Y')] + df2.iloc[z+2, df.columns.get_loc('Y')])*time
                df.iloc[k, df.columns.get_loc('angle_x')] = (df2.iloc[z+1, df.columns.get_loc('X')] + df2.iloc[z+2, df.columns.get_loc('X')])*time

                df.iloc[k, df.columns.get_loc('xy_x')] = df.iloc[k, df.columns.get_loc('distance_xy')] * math.cos(df.iloc[k, df.columns.get_loc('angle_z')])
                df.iloc[k, df.columns.get_loc('xy_y')] = df.iloc[k, df.columns.get_loc('distance_xy')] * math.sin(df.iloc[k, df.columns.get_loc('angle_z')])
                df.iloc[k, df.columns.get_loc('xz_x')] = df.iloc[k, df.columns.get_loc('distance_xz')] * math.cos(df.iloc[k, df.columns.get_loc('angle_y')])
                df.iloc[k, df.columns.get_loc('xz_y')] = df.iloc[k, df.columns.get_loc('distance_xz')] * math.sin(df.iloc[k, df.columns.get_loc('angle_y')])
                df.iloc[k, df.columns.get_loc('yz_x')] = df.iloc[i, df.columns.get_loc('distance_yz')] * math.cos(df.iloc[k, df.columns.get_loc('angle_x')])
                df.iloc[k, df.columns.get_loc('yz_y')] = df.iloc[i, df.columns.get_loc('distance_yz')] * math.sin(df.iloc[k, df.columns.get_loc('angle_x')])

            # AÑADO LA DISTANCIA TOTAL, EL ÁNGULO TOTAL, EL PUNTO FINAL Y EL PUNTO INTERMEDIO CALCULADOS (SE PUEDEN AÑADIR MÁS PUNTOS PARA TENER MÁS DATOS)
            row_result['distance_xy'] = df['distance_xy'].sum()
            row_result['distance_xz'] = df['distance_xz'].sum()
            row_result['distance_yz'] = df['distance_yz'].sum()
            row_result['angle_z'] = df['angle_x'].sum()
            row_result['angle_y'] = df['angle_y'].sum()
            row_result['angle_x'] = df['angle_z'].sum()
            row_result['xy_y_last'] = df.iloc[stroke_numbers[smash_types[smash]] - 1, df.columns.get_loc('xy_y')]
            row_result['xy_y_middle']= df.iloc[stroke_numbers[smash_types[smash]]//2, df.columns.get_loc('xy_y')]
            row_result['xy_x_last'] = df.iloc[stroke_numbers[smash_types[smash]] - 1, df.columns.get_loc('xy_x')]
            row_result['xy_x_middle']= df.iloc[stroke_numbers[smash_types[smash]]//2, df.columns.get_loc('xy_x')]
            row_result['xz_y_last'] = df.iloc[stroke_numbers[smash_types[smash]] - 1, df.columns.get_loc('xz_y')]
            row_result['xz_y_middle']= df.iloc[stroke_numbers[smash_types[smash]]//2, df.columns.get_loc('xz_y')]
            row_result['xz_x_last'] = df.iloc[stroke_numbers[smash_types[smash]] - 1, df.columns.get_loc('xz_x')]
            row_result['xz_x_middle']= df.iloc[stroke_numbers[smash_types[smash]]//2, df.columns.get_loc('xz_x')]
            row_result['yz_y_last'] = df.iloc[stroke_numbers[smash_types[smash]] - 1, df.columns.get_loc('xz_y')]
            row_result['yz_y_middle']= df.iloc[stroke_numbers[smash_types[smash]]//2, df.columns.get_loc('yz_y')]
            row_result['yz_x_last'] = df.iloc[stroke_numbers[smash_types[smash]] - 1, df.columns.get_loc('yz_x')]
            row_result['yz_x_middle']= df.iloc[stroke_numbers[smash_types[smash]]//2, df.columns.get_loc('yz_x')]
            
            result.append(row_result)
    return pd.concat(result)

data = load_data(PLAYERS_NUMBER, ages, levels, backhands, smash_types, stroke_numbers, sensor_types)
data

Unnamed: 0,x_mean_acc,y_mean_acc,z_mean_acc,x_std_acc,y_std_acc,z_std_acc,x_min_acc,y_min_acc,z_min_acc,x_max_acc,y_max_acc,z_max_acc,x_median_acc,y_median_acc,z_median_acc,x_mad_acc,y_mad_acc,z_mad_acc,x_iqr_acc,y_iqr_acc,z_iqr_acc,x_pcount_acc,y_pcount_acc,z_pcount_acc,x_ncount_acc,y_ncount_acc,z_ncount_acc,x_abvmean_acc,y_abvmean_acc,z_abvmean_acc,x_cntpeaks_acc,y_cntpeaks_acc,z_cntpeaks_acc,x_skew_acc,y_skew_acc,z_skew_acc,x_kurt_acc,y_kurt_acc,z_kurt_acc,time,x_energy_acc,y_energy_acc,z_energy_acc,acceleration,x_mean_gyr,y_mean_gyr,z_mean_gyr,x_std_gyr,y_std_gyr,z_std_gyr,x_min_gyr,y_min_gyr,z_min_gyr,x_max_gyr,y_max_gyr,z_max_gyr,x_median_gyr,y_median_gyr,z_median_gyr,x_mad_gyr,y_mad_gyr,z_mad_gyr,x_iqr_gyr,y_iqr_gyr,z_iqr_gyr,x_pcount_gyr,y_pcount_gyr,z_pcount_gyr,x_ncount_gyr,y_ncount_gyr,z_ncount_gyr,x_abvmean_gyr,y_abvmean_gyr,z_abvmean_gyr,x_cntpeaks_gyr,y_cntpeaks_gyr,z_cntpeaks_gyr,x_skew_gyr,y_skew_gyr,z_skew_gyr,x_kurt_gyr,y_kurt_gyr,z_kurt_gyr,time.1,player,age,level,backhands,smash,distance_xy,distance_xz,distance_yz,angle_z,angle_y,angle_x,xy_y_last,xy_y_middle,xy_x_last,xy_x_middle,xz_y_last,xz_y_middle,xz_x_last,xz_x_middle,yz_y_last,yz_y_middle,yz_x_last,yz_x_middle
0,0.119649,0.075253,-0.011181,0.056442,0.036162,0.011034,-0.110110,-0.069654,-0.029241,0.333318,0.198022,0.007792,0.129523,0.079631,-0.011266,0.037786,0.025356,0.010673,0.072021,0.050243,0.021003,127,127,27,3,3,103,75,74,65,9,9,5,-0.550520,-0.607466,-0.099980,2.905261,2.179081,-1.537183,2614,0.022720,0.009049,0.000320,0.145690,-0.346136,-0.289389,0.029494,0.287170,0.198342,0.087777,-0.980284,-0.889825,-0.254882,0.004922,0.004922,0.386447,-0.316807,-0.314412,0.013236,0.235393,0.145666,0.019821,0.482427,0.259737,0.052546,25,28,186,235,230,73,144,120,104,30,25,29,-0.539721,-0.105965,0.206929,-0.768824,-0.565476,3.489603,2634,0.0,26.0,1.0,2.0,0.0,7751.362958,6600.962656,4191.654719,-730.794677,-908.621519,88.569213,-4.485667,2.516997,35.684831,60.030828,-0.817015,9.925753,30.701116,50.109120,-0.817015,0.0,0.0,0.0
0,0.158366,0.084453,-0.044530,0.064365,0.034816,0.025279,0.000542,0.000191,-0.101574,0.280684,0.144491,-0.000102,0.152324,0.087056,-0.043000,0.050265,0.026830,0.020106,0.106840,0.052667,0.036034,130,130,0,0,0,130,63,66,68,4,3,3,0.031116,-0.248425,-0.484940,-0.808359,-0.773329,-0.605872,2614,0.037948,0.010836,0.003402,0.185495,-0.393011,-0.227487,0.215061,0.426748,0.547005,0.196675,-1.263235,-1.111982,-0.076757,0.256611,0.654765,0.640265,-0.307095,-0.065250,0.179255,0.254416,0.442917,0.146131,0.580103,1.029372,0.322227,44,121,226,216,139,34,169,145,104,6,6,8,-0.610069,-0.188370,0.512483,-0.651549,-1.404568,-0.709934,2625,0.0,26.0,1.0,2.0,0.0,9583.617828,8808.333980,5134.556346,-3574.799940,-1986.363467,1537.315632,39.244966,-77.443169,96.189802,-22.048062,-85.554220,-8.288122,45.291439,70.313240,-85.554220,0.0,0.0,-0.0
0,0.705115,0.575456,-0.977364,0.800909,1.265109,1.529748,-0.274389,-0.661758,-5.194325,2.415067,3.896989,0.082945,0.494091,0.015037,-0.249011,0.646356,0.117098,0.301923,1.460489,0.668150,1.095494,94,68,36,36,62,94,57,34,94,2,3,2,0.479716,1.587956,-1.692850,-1.046606,1.130063,1.539312,2614,1.473821,2.495141,4.260577,1.598933,-1.329833,-0.629818,1.348140,1.904659,3.033405,3.271360,-7.805427,-19.245981,-0.814797,1.354094,4.870294,21.817148,-0.870669,-0.303703,0.231868,1.006691,0.795442,0.751809,1.936058,1.487820,2.358656,68,104,151,192,156,109,160,159,81,6,10,9,-1.550616,-3.350901,4.115060,2.395200,16.671190,19.570853,2623,0.0,26.0,1.0,2.0,0.0,64039.823768,73550.805654,67811.514720,-3408.123687,-8271.828911,6724.809839,-38.913111,-149.488290,33.426953,23.089789,19.520663,-81.484065,-48.340884,121.528266,19.520663,-0.0,-0.0,0.0
0,0.750097,1.026938,-1.270542,0.998745,1.910227,1.902650,-0.722997,-0.250944,-6.398224,3.125900,5.935603,0.295709,0.727732,0.067038,-0.447860,0.778379,0.174045,0.583183,1.519565,1.099968,1.671153,92,76,38,38,54,92,65,33,89,3,3,2,0.594898,1.602820,-1.562669,-0.296258,1.004371,1.222706,2614,2.018204,6.078147,6.768461,2.253215,-1.893532,-0.807340,1.568602,1.692539,3.308180,3.627724,-6.636508,-24.220037,-2.443328,0.894481,6.954446,28.441685,-1.651345,-0.605877,0.797371,1.131005,0.658888,1.319573,2.258718,1.275407,2.566612,42,87,193,217,173,67,145,144,92,7,8,6,-0.407977,-3.332839,4.222910,-0.587256,19.475408,22.712071,2625,0.0,26.0,1.0,2.0,0.0,94162.016143,97406.917130,97646.040084,-6885.480609,-6971.656240,6889.834763,-131.047220,206.989865,-120.339399,161.558075,-89.643040,-273.639762,-175.381026,-21.668635,-89.643040,-0.0,-0.0,0.0
0,0.848838,0.973131,-1.458308,1.081737,2.012801,2.271412,-0.769161,-0.798550,-7.525879,2.437628,6.057998,0.426785,0.928033,-0.015841,-0.533851,1.108709,0.196532,0.728995,2.143037,1.317004,1.874395,88,59,42,42,71,88,67,35,96,2,3,3,-0.024427,1.535184,-1.493724,-1.460534,0.875619,0.993267,2615,2.446183,6.457343,9.420174,2.461836,-1.252286,0.777510,-0.256891,2.391435,4.140240,4.124073,-6.461044,-4.869362,-30.057446,7.765651,29.208992,3.876174,-0.943103,-0.278694,0.084140,1.256584,1.209825,1.563147,2.749226,2.405082,3.349084,81,108,133,179,152,127,140,85,145,9,9,7,0.101236,3.507514,-3.907596,0.891800,16.356610,20.350650,2629,0.0,26.0,1.0,2.0,0.0,97079.686757,111619.083348,110140.471945,799.956021,11140.067038,-6939.599536,-409.987607,127.110852,-2105.047947,-161.786748,893.371813,13.391245,-2529.594695,-263.272013,893.371813,0.0,0.0,-0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,1.124059,0.868385,-0.557649,1.769305,1.793278,0.640079,-2.124098,-1.618071,-2.107992,5.115294,6.160065,0.137966,0.745668,0.317028,-0.226515,0.819809,0.428861,0.299584,1.895584,0.833830,0.956650,146,134,28,38,50,156,75,40,107,5,4,6,0.595994,1.579119,-0.937681,0.043571,1.770909,-0.220320,3708,8.053562,7.272532,1.321944,2.114526,0.177239,-0.005505,0.000754,2.094841,0.360890,0.924471,-7.820592,-1.390942,-4.466952,4.243997,0.905656,1.955913,0.446576,0.054009,0.075893,1.309064,0.212779,0.417309,2.450877,0.449668,0.834652,216,218,204,152,150,164,201,220,203,8,14,12,-0.709408,-0.839562,-1.723582,1.739160,1.608635,6.632414,3719,5.0,17.0,2.0,1.0,5.0,150476.603535,123709.410214,105031.769317,2243.132459,675.118539,-79.980126,607.160878,-95.941516,764.806596,286.767877,19.728675,70.428455,950.510208,-168.317429,19.728675,-0.0,0.0,0.0
0,0.175868,0.126534,-0.090149,0.063361,0.044371,0.048972,0.034111,0.026536,-0.259139,0.366540,0.257108,-0.010025,0.167381,0.122452,-0.078358,0.031544,0.028989,0.022199,0.061695,0.059532,0.047954,184,184,0,0,0,184,85,87,111,9,9,8,0.636391,0.479981,-1.412384,1.131113,0.751953,2.593557,3711,0.064257,0.033063,0.019342,0.235701,0.016071,0.007215,0.022708,0.441973,0.172015,0.203104,-1.055179,-0.413451,-0.384983,1.443622,0.338025,0.474910,-0.018890,0.024078,0.032592,0.240714,0.091058,0.142473,0.499720,0.204398,0.292795,173,202,210,195,166,157,163,195,191,16,15,15,0.363590,-0.156499,-0.061046,1.719475,-0.398642,-0.672954,3722,5.0,17.0,2.0,1.0,5.0,16328.455012,14926.845369,11764.884720,-1202.284235,-198.746054,124.284340,125.089248,99.077492,36.171542,80.914541,-96.802980,-36.527058,-84.167728,-121.459865,-96.802980,-0.0,-0.0,-0.0
0,1.158998,0.109567,-0.447495,1.400385,1.075495,0.455428,-1.683910,-2.239668,-1.848851,4.628438,4.601656,0.093953,1.040378,0.076231,-0.242911,0.786498,0.497102,0.165139,1.652192,0.967944,0.474715,161,109,9,23,75,175,83,82,121,7,4,7,0.319763,0.749889,-1.385160,0.071514,2.112440,1.169408,3708,6.060401,2.138831,0.748032,1.732327,0.061763,0.199856,-0.313364,1.960954,0.766789,1.811544,-6.923450,-1.493108,-7.408870,4.250782,3.628609,1.868247,0.443316,0.146730,0.186572,0.971106,0.332304,0.696070,1.914574,0.630853,1.673926,228,231,221,140,137,147,224,158,253,12,12,11,-1.362104,1.523101,-2.203906,2.623797,4.374600,5.056808,3719,5.0,17.0,2.0,1.0,5.0,125376.431839,114615.660304,70388.259972,-556.106100,1776.274499,-3346.591036,28.496749,1.979776,170.066905,118.102978,-188.640493,-110.105815,-42.558142,64.215854,-188.640493,-0.0,-0.0,0.0
0,1.027020,1.041297,-0.401656,1.670558,2.408641,0.629152,-0.748236,-2.550784,-2.492981,5.904718,7.133821,0.474394,0.547226,0.140985,-0.215343,0.733870,0.408218,0.206552,1.424532,1.592294,0.422294,129,108,38,55,76,146,61,49,126,5,5,7,1.535490,1.176548,-1.677749,1.695599,0.323017,2.508161,3708,7.047879,12.611952,1.021216,2.251000,0.316169,-0.163348,-0.053651,2.080016,0.882088,1.130701,-6.167850,-1.708348,-3.680224,5.627224,2.762729,1.663650,0.366293,-0.011973,0.103562,0.937782,0.606209,0.742829,1.939318,1.162333,1.562382,227,179,211,141,188,157,188,213,220,11,11,10,-0.155978,0.251532,-0.879793,1.115778,0.184891,0.670894,3719,5.0,17.0,2.0,1.0,5.0,158192.081387,109483.024669,131114.913994,2787.002306,264.663371,551.001809,-73.442786,2753.199453,-1180.669030,364.699919,688.652049,1598.789843,450.136767,1057.231814,688.652049,0.0,-0.0,0.0


In [41]:
data = data.reset_index()
data = data.drop('index', axis=1)
pd.set_option('display.max_columns', None)
data.to_csv('final_data')
data

Unnamed: 0,x_mean_acc,y_mean_acc,z_mean_acc,x_std_acc,y_std_acc,z_std_acc,x_min_acc,y_min_acc,z_min_acc,x_max_acc,y_max_acc,z_max_acc,x_median_acc,y_median_acc,z_median_acc,x_mad_acc,y_mad_acc,z_mad_acc,x_iqr_acc,y_iqr_acc,z_iqr_acc,x_pcount_acc,y_pcount_acc,z_pcount_acc,x_ncount_acc,y_ncount_acc,z_ncount_acc,x_abvmean_acc,y_abvmean_acc,z_abvmean_acc,x_cntpeaks_acc,y_cntpeaks_acc,z_cntpeaks_acc,x_skew_acc,y_skew_acc,z_skew_acc,x_kurt_acc,y_kurt_acc,z_kurt_acc,time,x_energy_acc,y_energy_acc,z_energy_acc,acceleration,x_mean_gyr,y_mean_gyr,z_mean_gyr,x_std_gyr,y_std_gyr,z_std_gyr,x_min_gyr,y_min_gyr,z_min_gyr,x_max_gyr,y_max_gyr,z_max_gyr,x_median_gyr,y_median_gyr,z_median_gyr,x_mad_gyr,y_mad_gyr,z_mad_gyr,x_iqr_gyr,y_iqr_gyr,z_iqr_gyr,x_pcount_gyr,y_pcount_gyr,z_pcount_gyr,x_ncount_gyr,y_ncount_gyr,z_ncount_gyr,x_abvmean_gyr,y_abvmean_gyr,z_abvmean_gyr,x_cntpeaks_gyr,y_cntpeaks_gyr,z_cntpeaks_gyr,x_skew_gyr,y_skew_gyr,z_skew_gyr,x_kurt_gyr,y_kurt_gyr,z_kurt_gyr,time.1,player,age,level,backhands,smash,distance_xy,distance_xz,distance_yz,angle_z,angle_y,angle_x,xy_y_last,xy_y_middle,xy_x_last,xy_x_middle,xz_y_last,xz_y_middle,xz_x_last,xz_x_middle,yz_y_last,yz_y_middle,yz_x_last,yz_x_middle
0,0.119649,0.075253,-0.011181,0.056442,0.036162,0.011034,-0.110110,-0.069654,-0.029241,0.333318,0.198022,0.007792,0.129523,0.079631,-0.011266,0.037786,0.025356,0.010673,0.072021,0.050243,0.021003,127,127,27,3,3,103,75,74,65,9,9,5,-0.550520,-0.607466,-0.099980,2.905261,2.179081,-1.537183,2614,0.022720,0.009049,0.000320,0.145690,-0.346136,-0.289389,0.029494,0.287170,0.198342,0.087777,-0.980284,-0.889825,-0.254882,0.004922,0.004922,0.386447,-0.316807,-0.314412,0.013236,0.235393,0.145666,0.019821,0.482427,0.259737,0.052546,25,28,186,235,230,73,144,120,104,30,25,29,-0.539721,-0.105965,0.206929,-0.768824,-0.565476,3.489603,2634,0.0,26.0,1.0,2.0,0.0,7751.362958,6600.962656,4191.654719,-730.794677,-908.621519,88.569213,-4.485667,2.516997,35.684831,60.030828,-0.817015,9.925753,30.701116,50.109120,-0.817015,0.0,0.0,0.0
1,0.158366,0.084453,-0.044530,0.064365,0.034816,0.025279,0.000542,0.000191,-0.101574,0.280684,0.144491,-0.000102,0.152324,0.087056,-0.043000,0.050265,0.026830,0.020106,0.106840,0.052667,0.036034,130,130,0,0,0,130,63,66,68,4,3,3,0.031116,-0.248425,-0.484940,-0.808359,-0.773329,-0.605872,2614,0.037948,0.010836,0.003402,0.185495,-0.393011,-0.227487,0.215061,0.426748,0.547005,0.196675,-1.263235,-1.111982,-0.076757,0.256611,0.654765,0.640265,-0.307095,-0.065250,0.179255,0.254416,0.442917,0.146131,0.580103,1.029372,0.322227,44,121,226,216,139,34,169,145,104,6,6,8,-0.610069,-0.188370,0.512483,-0.651549,-1.404568,-0.709934,2625,0.0,26.0,1.0,2.0,0.0,9583.617828,8808.333980,5134.556346,-3574.799940,-1986.363467,1537.315632,39.244966,-77.443169,96.189802,-22.048062,-85.554220,-8.288122,45.291439,70.313240,-85.554220,0.0,0.0,-0.0
2,0.705115,0.575456,-0.977364,0.800909,1.265109,1.529748,-0.274389,-0.661758,-5.194325,2.415067,3.896989,0.082945,0.494091,0.015037,-0.249011,0.646356,0.117098,0.301923,1.460489,0.668150,1.095494,94,68,36,36,62,94,57,34,94,2,3,2,0.479716,1.587956,-1.692850,-1.046606,1.130063,1.539312,2614,1.473821,2.495141,4.260577,1.598933,-1.329833,-0.629818,1.348140,1.904659,3.033405,3.271360,-7.805427,-19.245981,-0.814797,1.354094,4.870294,21.817148,-0.870669,-0.303703,0.231868,1.006691,0.795442,0.751809,1.936058,1.487820,2.358656,68,104,151,192,156,109,160,159,81,6,10,9,-1.550616,-3.350901,4.115060,2.395200,16.671190,19.570853,2623,0.0,26.0,1.0,2.0,0.0,64039.823768,73550.805654,67811.514720,-3408.123687,-8271.828911,6724.809839,-38.913111,-149.488290,33.426953,23.089789,19.520663,-81.484065,-48.340884,121.528266,19.520663,-0.0,-0.0,0.0
3,0.750097,1.026938,-1.270542,0.998745,1.910227,1.902650,-0.722997,-0.250944,-6.398224,3.125900,5.935603,0.295709,0.727732,0.067038,-0.447860,0.778379,0.174045,0.583183,1.519565,1.099968,1.671153,92,76,38,38,54,92,65,33,89,3,3,2,0.594898,1.602820,-1.562669,-0.296258,1.004371,1.222706,2614,2.018204,6.078147,6.768461,2.253215,-1.893532,-0.807340,1.568602,1.692539,3.308180,3.627724,-6.636508,-24.220037,-2.443328,0.894481,6.954446,28.441685,-1.651345,-0.605877,0.797371,1.131005,0.658888,1.319573,2.258718,1.275407,2.566612,42,87,193,217,173,67,145,144,92,7,8,6,-0.407977,-3.332839,4.222910,-0.587256,19.475408,22.712071,2625,0.0,26.0,1.0,2.0,0.0,94162.016143,97406.917130,97646.040084,-6885.480609,-6971.656240,6889.834763,-131.047220,206.989865,-120.339399,161.558075,-89.643040,-273.639762,-175.381026,-21.668635,-89.643040,-0.0,-0.0,0.0
4,0.848838,0.973131,-1.458308,1.081737,2.012801,2.271412,-0.769161,-0.798550,-7.525879,2.437628,6.057998,0.426785,0.928033,-0.015841,-0.533851,1.108709,0.196532,0.728995,2.143037,1.317004,1.874395,88,59,42,42,71,88,67,35,96,2,3,3,-0.024427,1.535184,-1.493724,-1.460534,0.875619,0.993267,2615,2.446183,6.457343,9.420174,2.461836,-1.252286,0.777510,-0.256891,2.391435,4.140240,4.124073,-6.461044,-4.869362,-30.057446,7.765651,29.208992,3.876174,-0.943103,-0.278694,0.084140,1.256584,1.209825,1.563147,2.749226,2.405082,3.349084,81,108,133,179,152,127,140,85,145,9,9,7,0.101236,3.507514,-3.907596,0.891800,16.356610,20.350650,2629,0.0,26.0,1.0,2.0,0.0,97079.686757,111619.083348,110140.471945,799.956021,11140.067038,-6939.599536,-409.987607,127.110852,-2105.047947,-161.786748,893.371813,13.391245,-2529.594695,-263.272013,893.371813,0.0,0.0,-0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
775,1.124059,0.868385,-0.557649,1.769305,1.793278,0.640079,-2.124098,-1.618071,-2.107992,5.115294,6.160065,0.137966,0.745668,0.317028,-0.226515,0.819809,0.428861,0.299584,1.895584,0.833830,0.956650,146,134,28,38,50,156,75,40,107,5,4,6,0.595994,1.579119,-0.937681,0.043571,1.770909,-0.220320,3708,8.053562,7.272532,1.321944,2.114526,0.177239,-0.005505,0.000754,2.094841,0.360890,0.924471,-7.820592,-1.390942,-4.466952,4.243997,0.905656,1.955913,0.446576,0.054009,0.075893,1.309064,0.212779,0.417309,2.450877,0.449668,0.834652,216,218,204,152,150,164,201,220,203,8,14,12,-0.709408,-0.839562,-1.723582,1.739160,1.608635,6.632414,3719,5.0,17.0,2.0,1.0,5.0,150476.603535,123709.410214,105031.769317,2243.132459,675.118539,-79.980126,607.160878,-95.941516,764.806596,286.767877,19.728675,70.428455,950.510208,-168.317429,19.728675,-0.0,0.0,0.0
776,0.175868,0.126534,-0.090149,0.063361,0.044371,0.048972,0.034111,0.026536,-0.259139,0.366540,0.257108,-0.010025,0.167381,0.122452,-0.078358,0.031544,0.028989,0.022199,0.061695,0.059532,0.047954,184,184,0,0,0,184,85,87,111,9,9,8,0.636391,0.479981,-1.412384,1.131113,0.751953,2.593557,3711,0.064257,0.033063,0.019342,0.235701,0.016071,0.007215,0.022708,0.441973,0.172015,0.203104,-1.055179,-0.413451,-0.384983,1.443622,0.338025,0.474910,-0.018890,0.024078,0.032592,0.240714,0.091058,0.142473,0.499720,0.204398,0.292795,173,202,210,195,166,157,163,195,191,16,15,15,0.363590,-0.156499,-0.061046,1.719475,-0.398642,-0.672954,3722,5.0,17.0,2.0,1.0,5.0,16328.455012,14926.845369,11764.884720,-1202.284235,-198.746054,124.284340,125.089248,99.077492,36.171542,80.914541,-96.802980,-36.527058,-84.167728,-121.459865,-96.802980,-0.0,-0.0,-0.0
777,1.158998,0.109567,-0.447495,1.400385,1.075495,0.455428,-1.683910,-2.239668,-1.848851,4.628438,4.601656,0.093953,1.040378,0.076231,-0.242911,0.786498,0.497102,0.165139,1.652192,0.967944,0.474715,161,109,9,23,75,175,83,82,121,7,4,7,0.319763,0.749889,-1.385160,0.071514,2.112440,1.169408,3708,6.060401,2.138831,0.748032,1.732327,0.061763,0.199856,-0.313364,1.960954,0.766789,1.811544,-6.923450,-1.493108,-7.408870,4.250782,3.628609,1.868247,0.443316,0.146730,0.186572,0.971106,0.332304,0.696070,1.914574,0.630853,1.673926,228,231,221,140,137,147,224,158,253,12,12,11,-1.362104,1.523101,-2.203906,2.623797,4.374600,5.056808,3719,5.0,17.0,2.0,1.0,5.0,125376.431839,114615.660304,70388.259972,-556.106100,1776.274499,-3346.591036,28.496749,1.979776,170.066905,118.102978,-188.640493,-110.105815,-42.558142,64.215854,-188.640493,-0.0,-0.0,0.0
778,1.027020,1.041297,-0.401656,1.670558,2.408641,0.629152,-0.748236,-2.550784,-2.492981,5.904718,7.133821,0.474394,0.547226,0.140985,-0.215343,0.733870,0.408218,0.206552,1.424532,1.592294,0.422294,129,108,38,55,76,146,61,49,126,5,5,7,1.535490,1.176548,-1.677749,1.695599,0.323017,2.508161,3708,7.047879,12.611952,1.021216,2.251000,0.316169,-0.163348,-0.053651,2.080016,0.882088,1.130701,-6.167850,-1.708348,-3.680224,5.627224,2.762729,1.663650,0.366293,-0.011973,0.103562,0.937782,0.606209,0.742829,1.939318,1.162333,1.562382,227,179,211,141,188,157,188,213,220,11,11,10,-0.155978,0.251532,-0.879793,1.115778,0.184891,0.670894,3719,5.0,17.0,2.0,1.0,5.0,158192.081387,109483.024669,131114.913994,2787.002306,264.663371,551.001809,-73.442786,2753.199453,-1180.669030,364.699919,688.652049,1598.789843,450.136767,1057.231814,688.652049,0.0,-0.0,0.0


## REGRESSION, SVC, KNEIGHBORS

In [6]:
# ACTIVAR PARA CARGAR LOS DATOS DIRECTAMENTE
# data = pd.read_csv('final_data')

X = np.array(data.drop(['smash'], axis=1))
y = np.array(data['smash'])


In [55]:
seed = 2
X = np.array(data.drop(['smash'], axis=1))
y = np.array(data['smash'])
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=seed)
model = svm.SVC()
model.fit(X_train, Y_train)
model.score(X_test, Y_test)


0.30128205128205127

In [58]:
seed = 2
X = np.array(data.drop(['smash'], axis=1))
y = np.array(data['smash'])
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=seed)
model = KNeighborsClassifier(n_neighbors=6)
model.fit(X_train, Y_train)
model.score(X_test, Y_test)

0.6153846153846154

In [43]:
X = np.array(data.drop(['smash'], axis=1))
y = np.array(data['smash'])

seed = 7
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=seed)
model = linear_model.LogisticRegression(max_iter=1000)
name='Logistic Regression'
kfold = model_selection.KFold(n_splits=3, random_state=7)

cv_results = model_selection.cross_val_score(model, X_train, Y_train, cv=kfold, scoring='accuracy')
msg = "%s: %f %s (%f)" % (name, cv_results.mean(), "+-", cv_results.std())
model.fit(X_train, Y_train)
predictions = model.predict(X_test)
print(model.score(X=X_test, y=Y_test))
print(accuracy_score(Y_test, predictions))
print(confusion_matrix(Y_test, predictions))

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logist

0.5576923076923077
0.5576923076923077
[[31  1  1  2  2  1]
 [ 1 18  2  1  3  1]
 [ 0  3 16  2  0  0]
 [ 3  3  3  2  1  1]
 [ 5 10  0  4  8  6]
 [ 3  7  1  1  1 12]]


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


## TREES

In [59]:
k_fold = RepeatedKFold(n_splits=20, n_repeats=3, random_state=1)
val_score = []
train_score = []
n_stimators = 1500
model = RandomForestClassifier(warm_start=True)

for i, (train, val) in enumerate(k_fold.split(X_train)):
    print("Iteración:", i+1)
    print("val_size:", len(val))

    model.fit(X_train[train], Y_train[train])
    model.n_estimators += 100

    score_val = model.score(X_train[val], Y_train[val])
    val_score.append(score_val)
    score_train = model.score(X_train[train], Y_train[train])
    train_score.append(score_train)
    print("score_train:", score_train)
    print("score_val:", score_val)

    if i % 1000 == 0: 
        pickle.dump(model, open("model_saved.sav", "wb"))
    if score_val < np.mean(val_score) // 2:
        print("Se ha encontrado una bajada importante de accuracy en validación")
        print("Este es el responsable:")
        print(X_train[val])
        break 
    if np.mean(val_score) >= 0.99 and len(val_score) > 20:
        pickle.dump(model, open("model_saved.sav", "wb"))
        print("STOP")
        break
    print("##############")

Iteración: 1
val_size: 32
score_train: 1.0
score_val: 0.875
##############
Iteración: 2
val_size: 32
score_train: 1.0
score_val: 1.0
##############
Iteración: 3
val_size: 32
score_train: 1.0
score_val: 1.0
##############
Iteración: 4
val_size: 32
score_train: 1.0
score_val: 1.0
##############
Iteración: 5
val_size: 31
score_train: 1.0
score_val: 1.0
##############
Iteración: 6
val_size: 31
score_train: 1.0
score_val: 1.0
##############
Iteración: 7
val_size: 31
score_train: 1.0
score_val: 1.0
##############
Iteración: 8
val_size: 31
score_train: 1.0
score_val: 1.0
##############
Iteración: 9
val_size: 31
score_train: 1.0
score_val: 1.0
##############
Iteración: 10
val_size: 31
score_train: 1.0
score_val: 1.0
##############
Iteración: 11
val_size: 31
score_train: 1.0
score_val: 1.0
##############
Iteración: 12
val_size: 31
score_train: 1.0
score_val: 1.0
##############
Iteración: 13
val_size: 31
score_train: 1.0
score_val: 1.0
##############
Iteración: 14
val_size: 31
score_train: 1.0
s

In [60]:
predictions = model.predict(X_test)
print(model.score(X=X_test, y=Y_test))
print(accuracy_score(Y_test, predictions))
print(confusion_matrix(Y_test, predictions))

0.9038461538461539
0.9038461538461539
[[29  2  0  2  1  0]
 [ 1 37  0  0  0  0]
 [ 2  0 25  0  0  0]
 [ 1  2  1 18  0  0]
 [ 0  0  0  0 14  2]
 [ 0  0  0  0  1 18]]


In [61]:
pickle.dump(model, open("tree.sav", "wb"))

In [62]:
seed = 2
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)
train_test = (X_train, X_test, y_train, y_test)
k_fold = RepeatedKFold(n_splits=10, n_repeats=1, random_state=seed)
val_score = []
train_score = []
model = RandomForestClassifier(warm_start=True, n_estimators=100, max_features='auto')
for i, (train, val) in enumerate(k_fold.split(X_train)):
    print("Iteración:", i+1)
    print("val_size:", len(val))

    model.fit(X_train[train], y_train[train])
    model.n_estimators += 100

    score_val = model.score(X_train[val], y_train[val])
    val_score.append(score_val)
    score_train = model.score(X_train[train], y_train[train])
    train_score.append(score_train)
    print('Score val:', score_val)
    print('Score train:', score_train)
    print('##########################')

    if np.mean(val_score) > 0.99 and len(val_score) > 50:
        pickle.dump(model, open(path + "model_forest_warm_start", "wb"))
        print("STOP")
        break
    print('##########################')


Iteración: 1
val_size: 63
Score val: 0.8571428571428571
Score train: 1.0
##########################
##########################
Iteración: 2
val_size: 63
Score val: 1.0
Score train: 0.9982174688057041
##########################
##########################
Iteración: 3
val_size: 63
Score val: 1.0
Score train: 1.0
##########################
##########################
Iteración: 4
val_size: 63
Score val: 1.0
Score train: 1.0
##########################
##########################
Iteración: 5
val_size: 62
Score val: 1.0
Score train: 1.0
##########################
##########################
Iteración: 6
val_size: 62
Score val: 1.0
Score train: 1.0
##########################
##########################
Iteración: 7
val_size: 62
Score val: 1.0
Score train: 1.0
##########################
##########################
Iteración: 8
val_size: 62
Score val: 1.0
Score train: 1.0
##########################
##########################
Iteración: 9
val_size: 62
Score val: 1.0
Score train: 1.0
################

In [63]:
model.score(X_test, Y_test)

0.9102564102564102

RANDOM FOREST CONVERGES IN THE THRID ITERATION, WITH 300 STIMATORS. AFTER THAT, IT IS OVER TRAINING. THE RESULT IS EQUALY A 90% (SLIGHTLY UPPER)

## AUTOML
CITE: H2O.ai. (2022) h2o: R Interface for H2O. R package version 3.42.0.2. https://github.com/h2oai/h2o-3.

In [4]:
import h2o
from h2o.automl import H2OAutoML

In [5]:
h2o.init()

Checking whether there is an H2O instance running at http://localhost:54321..... not found.
Attempting to start a local H2O server...
  Java Version: java version "15.0.2" 2021-01-19; Java(TM) SE Runtime Environment (build 15.0.2+7-27); Java HotSpot(TM) 64-Bit Server VM (build 15.0.2+7-27, mixed mode, sharing)
  Starting server from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/h2o/backend/bin/h2o.jar
  Ice root: /var/folders/r2/w54gxl1d7_130wxnhpdtlmhm0000gn/T/tmpcsvmi3zr
  JVM stdout: /var/folders/r2/w54gxl1d7_130wxnhpdtlmhm0000gn/T/tmpcsvmi3zr/h2o_juanbayonfernandez_started_from_python.out
  JVM stderr: /var/folders/r2/w54gxl1d7_130wxnhpdtlmhm0000gn/T/tmpcsvmi3zr/h2o_juanbayonfernandez_started_from_python.err
  Server is running at http://127.0.0.1:54321
Connecting to H2O server at http://127.0.0.1:54321 ... successful.


0,1
H2O_cluster_uptime:,02 secs
H2O_cluster_timezone:,Europe/Madrid
H2O_data_parsing_timezone:,UTC
H2O_cluster_version:,3.44.0.2
H2O_cluster_version_age:,1 month
H2O_cluster_name:,H2O_from_python_juanbayonfernandez_hexwmt
H2O_cluster_total_nodes:,1
H2O_cluster_free_memory:,4 Gb
H2O_cluster_total_cores:,8
H2O_cluster_allowed_cores:,8


In [7]:
train = data[data['player'] != 5]
test =  data[data['player'] == 5]
h2train = h2o.H2OFrame(train)
h2test = h2o.H2OFrame(test)

Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%
Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%


In [8]:
h2train['smash'] = h2train['smash'].asfactor()
h2test['smash'] = h2test['smash'].asfactor()

In [9]:
columnas = [a for a in h2train.columns if a != "smash"][1:]
x = columnas
y = "smash"

In [10]:
automl = H2OAutoML(max_models=50, seed=42, max_runtime_secs=300)
automl.train(x=x, y=y, training_frame=h2train, validation_frame=h2test)

AutoML progress: |
16:25:08.225: User specified a validation frame with cross-validation still enabled. Please note that the models will still be validated using cross-validation only, the validation frame will be used to provide purely informative validation metrics on the trained models.
16:25:08.244: AutoML: XGBoost is not available; skipping it.

███████████████████████████████████████████████████████████████| (done) 100%


Unnamed: 0,number_of_trees,number_of_internal_trees,model_size_in_bytes,min_depth,max_depth,mean_depth,min_leaves,max_leaves,mean_leaves
,60.0,360.0,116719.0,4.0,7.0,6.9805555,8.0,38.0,21.180555

0,1,2,3,4,5,Error,Rate
150.0,0.0,0.0,0.0,0.0,0.0,0.0,0 / 150
0.0,150.0,0.0,0.0,0.0,0.0,0.0,0 / 150
0.0,0.0,75.0,0.0,0.0,0.0,0.0,0 / 75
0.0,0.0,0.0,75.0,0.0,0.0,0.0,0 / 75
0.0,0.0,0.0,0.0,100.0,0.0,0.0,0 / 100
0.0,0.0,0.0,0.0,0.0,100.0,0.0,0 / 100
150.0,150.0,75.0,75.0,100.0,100.0,0.0,0 / 650

k,hit_ratio
1,1.0
2,1.0
3,1.0
4,1.0
5,1.0
6,1.0

0,1,2,3,4,5,Error,Rate
29.0,0.0,0.0,0.0,0.0,1.0,0.0333333,1 / 30
0.0,30.0,0.0,0.0,0.0,0.0,0.0,0 / 30
0.0,0.0,15.0,0.0,0.0,0.0,0.0,0 / 15
4.0,0.0,1.0,10.0,0.0,0.0,0.3333333,5 / 15
0.0,0.0,0.0,0.0,20.0,0.0,0.0,0 / 20
1.0,5.0,0.0,0.0,0.0,14.0,0.3,6 / 20
34.0,35.0,16.0,10.0,20.0,15.0,0.0923077,12 / 130

k,hit_ratio
1,0.9076923
2,0.9769231
3,0.9923077
4,0.9923077
5,1.0
6,1.0

0,1,2,3,4,5,Error,Rate
142.0,5.0,1.0,2.0,0.0,0.0,0.0533333,8 / 150
6.0,140.0,1.0,1.0,0.0,2.0,0.0666667,10 / 150
4.0,0.0,70.0,1.0,0.0,0.0,0.0666667,5 / 75
4.0,3.0,3.0,64.0,0.0,1.0,0.1466667,11 / 75
0.0,0.0,0.0,0.0,97.0,3.0,0.03,3 / 100
1.0,0.0,0.0,0.0,0.0,99.0,0.01,1 / 100
157.0,148.0,75.0,68.0,97.0,105.0,0.0584615,38 / 650

k,hit_ratio
1,0.9415385
2,0.98
3,0.9953846
4,1.0
5,1.0
6,1.0

Unnamed: 0,mean,sd,cv_1_valid,cv_2_valid,cv_3_valid,cv_4_valid,cv_5_valid
accuracy,0.9369231,0.0294928,0.9230769,0.9538462,0.9769231,0.9307692,0.9
auc,,0.0,,,,,
err,0.0630769,0.0294928,0.0769231,0.0461538,0.0230769,0.0692308,0.1
err_count,8.2,3.8340578,10.0,6.0,3.0,9.0,13.0
logloss,0.1766871,0.0718807,0.2092006,0.1703314,0.0613475,0.1874408,0.2551151
max_per_class_error,0.16,0.0596285,0.2,0.1333333,0.0666667,0.2,0.2
mean_per_class_accuracy,0.9322223,0.0310937,0.9111111,0.95,0.9777778,0.9194444,0.9027778
mean_per_class_error,0.0677778,0.0310937,0.0888889,0.05,0.0222222,0.0805556,0.0972222
mse,0.0531012,0.0230726,0.0600619,0.0461072,0.0190841,0.0579787,0.082274
pr_auc,,0.0,,,,,

Unnamed: 0,timestamp,duration,number_of_trees,training_rmse,training_logloss,training_classification_error,training_auc,training_pr_auc,validation_rmse,validation_logloss,validation_classification_error,validation_auc,validation_pr_auc
,2023-12-08 16:26:12,12.308 sec,0.0,0.8333333,1.7917595,0.8415385,,,0.8333333,1.7917595,0.7923077,,
,2023-12-08 16:26:13,12.526 sec,5.0,0.5170442,0.733038,0.0215385,,,0.6153593,0.9853189,0.1461538,,
,2023-12-08 16:26:13,12.726 sec,10.0,0.3193252,0.3722822,0.0123077,,,0.4913691,0.6970028,0.1384615,,
,2023-12-08 16:26:13,13.004 sec,15.0,0.2021333,0.2015694,0.0030769,,,0.4179857,0.5405142,0.1153846,,
,2023-12-08 16:26:13,13.220 sec,20.0,0.1275394,0.1099283,0.0015385,,,0.3805708,0.461374,0.1384615,,
,2023-12-08 16:26:13,13.401 sec,25.0,0.0816675,0.0620923,0.0,,,0.3368487,0.3748702,0.0923077,,
,2023-12-08 16:26:14,13.609 sec,30.0,0.0499892,0.0349797,0.0,,,0.3167181,0.3266586,0.0846154,,
,2023-12-08 16:26:14,13.820 sec,35.0,0.0315688,0.0208953,0.0,,,0.3136315,0.3189197,0.0846154,,
,2023-12-08 16:26:14,14.023 sec,40.0,0.0189054,0.0121462,0.0,,,0.3066071,0.3037329,0.0769231,,
,2023-12-08 16:26:14,14.191 sec,45.0,0.0116658,0.0072313,0.0,,,0.307282,0.3041159,0.1076923,,

variable,relative_importance,scaled_importance,percentage
time,238.8201294,1.0,0.1238700
time.1,226.8978424,0.9500784,0.1176862
y_min_acc,181.9181671,0.7617372,0.0943564
x_median_gyr,154.5376434,0.6470880,0.0801548
y_pcount_gyr,133.6334229,0.5595568,0.0693123
x_ncount_gyr,102.4845657,0.4291287,0.0531562
z_abvmean_acc,59.3293152,0.2484268,0.0307726
player,46.1175690,0.1931059,0.0239200
z_std_acc,42.3111076,0.1771673,0.0219457
x_pcount_gyr,37.8962440,0.1586811,0.0196558


In [11]:
leader_board = automl.leaderboard
leader_board

model_id,mean_per_class_error,logloss,rmse,mse
GBM_2_AutoML_1_20231208_162508,0.0622222,0.17321,0.230279,0.0530285
GBM_3_AutoML_1_20231208_162508,0.0688889,0.19811,0.240943,0.0580533
GBM_4_AutoML_1_20231208_162508,0.0727778,0.186029,0.240083,0.0576397
GBM_5_AutoML_1_20231208_162508,0.0755556,0.202272,0.241134,0.0581455
GBM_grid_1_AutoML_1_20231208_162508_model_8,0.0761111,0.245501,0.253669,0.064348
GBM_grid_1_AutoML_1_20231208_162508_model_5,0.08,0.204018,0.251171,0.0630869
GBM_grid_1_AutoML_1_20231208_162508_model_4,0.0805556,0.222461,0.265315,0.070392
GBM_grid_1_AutoML_1_20231208_162508_model_2,0.0811111,0.232755,0.262878,0.0691047
GBM_grid_1_AutoML_1_20231208_162508_model_9,0.0883333,0.218789,0.260136,0.0676708
GBM_grid_1_AutoML_1_20231208_162508_model_10,0.09,0.217183,0.260383,0.0677991


In [18]:
pred = automl.predict(h2test)
pred


gbm prediction progress: |

███████████████████████████████████████████████████████| (done) 100%


predict,p0,p1,p2,p3,p4,p5
0,0.451647,0.00849109,0.00274638,0.440453,0.00246294,0.0941992
0,0.841028,0.0622694,0.00184136,0.0325844,0.00292564,0.0593518
0,0.990204,0.000106265,9.791e-05,0.000416227,5.48927e-05,0.00912097
0,0.998115,0.000103447,0.000256209,0.000274215,5.15906e-05,0.00119966
0,0.985868,0.0013017,0.00117451,0.000763049,0.000448051,0.0104452
0,0.997186,0.000136379,8.89926e-05,0.000162786,6.5693e-05,0.00236052
0,0.995525,0.000207633,0.000169368,0.000298189,0.000122168,0.00367791
0,0.996699,0.000130332,0.000226806,0.000256428,8.4217e-05,0.0026031
0,0.994722,0.000231982,0.000636167,0.00045164,0.000137058,0.00382076
0,0.481833,0.0137793,0.0184137,0.306925,0.00760053,0.171449


In [22]:
best = automl.get_best_model()
model_path = h2o.save_model(model=best, path="./", force=True)

## CONCLUSION

GBM AS BETTER MODEL. THE COLUMNS GENERATED DISTANCE, ANGLE AND POINTS OF THE TRAJECTORY DO NOT MATTER FOR THE OUTOCOME, AND CAN BE DELETED. 