In [1]:
from tslearn.neighbors import KNeighborsTimeSeriesClassifier
from Parser import parse_frames
from scipy.signal import medfilt
from Functions import analyse_each_rep, find_extremas, filter_extremas, numpy_fillna, DTWDistance
from JointAngles import JointAngles
import numpy as np
import glob
import os
import pickle
import pandas as pd
from functools import reduce


In [3]:
# input_folder = "C:\\Users\\altaok\\Documents\\GitHub\\IndividualProject\\keypoints_for_all\\shoulder press"
input_folder = 'C:\\Users\\ak5u16\\Desktop\\IndividualProject\\keypoints_for_all\\shoulder press'
folder_paths = glob.glob(os.path.join(input_folder, 'shoulder_press*'))
points_folder_name = os.path.basename(input_folder)

print(len(folder_paths))


def get_data_for_dataset(folder_paths, points_folder_name):
    angle_arrays = []
    for folder in folder_paths:
        video_name = os.path.basename(folder)
        label = 0 if '_correct' in folder else 1
        frame_poses = parse_frames(folder)
        joint_angles = JointAngles(points_folder_name, frame_poses)
        
        left_upArm_forearm_angles = np.array(joint_angles.left_upArm_forearm_angles)
        left_upArm_forearm_angles = np.nan_to_num(left_upArm_forearm_angles)
        left_upArm_forearm_angles_filtered = medfilt(medfilt(left_upArm_forearm_angles, 5), 5)
            
        right_upArm_forearm_angles = np.array(joint_angles.right_upArm_forearm_angles)
        right_upArm_forearm_angles = np.nan_to_num(right_upArm_forearm_angles)
        right_upArm_forearm_angles_filtered = medfilt(medfilt(right_upArm_forearm_angles, 5), 5)
            
        left_upArm_trunk_angles = np.array(joint_angles.left_upArm_trunk_angles)
        left_upArm_trunk_angles = np.nan_to_num(left_upArm_trunk_angles)
        left_upArm_trunk_angles_filtered = medfilt(medfilt(left_upArm_trunk_angles, 5), 5)
            
        right_upArm_trunk_angles = np.array(joint_angles.right_upArm_trunk_angles)
        right_upArm_trunk_angles = np.nan_to_num(right_upArm_trunk_angles)
        right_upArm_trunk_angles_filtered = medfilt(medfilt(right_upArm_trunk_angles, 5), 5)
            
        extremas1 = filter_extremas(left_upArm_trunk_angles_filtered, find_extremas(left_upArm_trunk_angles_filtered, maxima=False), maxima=False)
        extremas2 = filter_extremas(right_upArm_trunk_angles_filtered, find_extremas(right_upArm_trunk_angles_filtered, maxima=False), maxima=False)
            
        angle_arrays.append((label, extremas1, extremas2, [left_upArm_forearm_angles, right_upArm_forearm_angles, left_upArm_trunk_angles, right_upArm_trunk_angles]))


    return angle_arrays
                                             
                                    

def fill_dataframe(angle_arrays, exercise_folder_name):
    df_list = []
    
    print('Filling dataset with ' + str(exercise_folder_name) + ' data...')
   
    for tup in angle_arrays:
        label = tup[0]
        extremas1 = tup[1]
        extremas2 = tup[2]
        uf_angles1, uf_angles2, ut_angles1, ut_angles2 = tup[3]
        # Extract rep angles
        each_rep_angles = analyse_each_rep('shoulder press', 'dataset', extremas1=extremas1, uf_angles1=uf_angles1, ut_angles1=ut_angles1, extremas2=extremas2, uf_angles2=uf_angles2, ut_angles2=ut_angles2)
        #print(len(each_rep_angles))
        s1 = pd.Series(each_rep_angles, name='Angle_array')
        s2 = pd.Series([label for n in range(len(each_rep_angles))], name='Label')
        df = pd.concat([s1,s2], axis=1)
        df_list.append(df)

    return pd.concat(df_list).reset_index(drop=True)
        

56


In [4]:
angle_arrays = get_data_for_dataset(folder_paths, points_folder_name) 
df = fill_dataframe(angle_arrays, points_folder_name)
print(df.info())

Filling dataset with shoulder press data...
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1688 entries, 0 to 1687
Data columns (total 2 columns):
Angle_array    1688 non-null object
Label          1688 non-null int64
dtypes: int64(1), object(1)
memory usage: 26.5+ KB
None


In [7]:
df

Unnamed: 0,Angle_array,Label
0,"[69.19123919027534, 59.887551643656195, 74.75717607975666, 64.7345556231223, 77.51367471873984, 69.71255962189733, 84.34238387758403, 75.15955273474901, 90.93826967454103, 80.105591495624, 93.65395199520458, 84.72755842467053, 103.38230086504441, 97.4935251517071, 110.84567770980195, 103.03401195647099, 115.79234625572656, 113.51524715660975, 125.87602051399669, 118.7772458606969, 130.7907669857006, 129.37699603473172, 136.37492173227025, 133.17582222025865, 146.3426445352306, 141.73765476245856, 145.2894555493847, 141.40601959184454, 154.88786616677672, 150.0578224057691, 156.12755922033392, 152.08270921802975, 153.43459555035213, 153.07424779289025, 153.3438873077194, 152.45257882066758, 153.33142582283006, 152.17469985307488, 152.0002157730516, 152.14692829418357, 153.7108489031456, 152.41412454840616, 152.55195240932457, 152.08466165211058, 152.5204516137471, 152.7686978421476, 152.51914676955175, 152.00093812768722, 151.55889140225796, 150.8932554655288, 150.4974758203206, 151.41394925047442, 150.49931021445215, 151.12690863773153, 149.95458343207014, 150.8865099313589, 148.64240393613, 147.34844903066252, 146.66609438782197, 140.869820729344, 144.954541521991, 142.57754751165254, 142.80272891068284, 142.04624861787545, 135.9810690743301, 138.03808963083108, 130.2584241235722, 130.91672166836605, 131.53080063182625, 129.13706025056143, 122.99493212131951, 126.23998518634677, 120.16504600642175, 119.45494489880558, 121.08243135017767, 116.9686926615739, 115.92536848869538, 114.20270235800479, 109.66174608924045, 110.80091923509984, 107.26563565311602, 104.82146604285056, 108.04059216034115, 100.08063096928053, 106.55655730910402, 98.20840862303349, 101.17501963122967, 94.10361988031295, 96.42511844575397, 87.79562786425825, 92.42696464825237, 88.69218889873855, 88.05929950701287, 87.30010804383276, 86.16675061553312, 84.79110309137833, 86.12640554361327, 82.51941268201438, 82.40973947919639, 80.78074689984875, ...]",0
1,"[80.24529718281734, 81.04957369206828, 84.76851549434335, 86.4445580412306, 90.01462668012391, 91.20840463789165, 97.1824982822431, 98.11059859027853, 102.75385866553076, 106.47004708871533, 110.09801406942015, 112.52188588811549, 117.83640578609496, 119.1598389889985, 122.06148964137645, 125.1725845986514, 129.95674739582665, 132.39417823421778, 134.6962651886894, 139.01190756818602, 141.1530107409668, 142.38945831493612, 145.29456629876125, 149.36520107946544, 150.62798749212428, 153.51693591134995, 151.1528027844798, 155.95200545668362, 156.84636347499847, 161.3274945830678, 158.8964022118606, 163.7167322299417, 160.262114180295, 163.8875707579901, 160.85563091389383, 164.24029614490883, 160.59054042679915, 164.2214532147605, 160.5712317274534, 164.29286439400704, 160.3380647592718, 165.4338242328741, 159.51161299351966, 165.93857547204476, 160.58505749418651, 165.41132077755438, 160.8881546066555, 165.37895604774243, 160.4878230149625, 163.9472482495823, 160.8065560976225, 162.6061203057859, 161.00591974077088, 162.12183739566086, 160.4892159990995, 161.84485113116537, 158.0814181872865, 160.50244531121376, 154.64447477410525, 159.43891013275933, 152.80028843639465, 157.3199966931651, 151.8732286251309, 155.90628095078753, 149.64371584411322, 152.14693830533898, 146.1236666133121, 146.5063904440623, 143.37750135903178, 144.2444752205464, 140.35722295441565, 139.95450176606383, 136.13379583914207, 136.74154757562505, 133.49521147435934, 134.96140943202423, 129.8929344045865, 133.85819088489345, 126.13132949875312, 130.87027805567894, 123.3281325270624, 124.88216594700248, 119.2984242463484, 123.75320082788586, 117.93725960619824, 120.22845607597765, 114.7170449049114, 115.47839232431345, 110.23693689678112, 115.00954310837724, 108.88805026093526, 112.56516265080893, 107.89691928822894, 108.96751295950526, 104.97328676164021, 106.81247921713151, 103.74669944459308, 106.71369090540738, 101.77827846181188, 104.2201515759892, ...]",0
2,"[67.07368343918425, 61.96936426977568, 68.36493157228425, 66.30267875997095, 74.43807873240114, 70.26506765540834, 81.39442718363037, 76.02519917933726, 86.85373297840479, 80.17744570329549, 88.78225598099675, 81.66383150812707, 97.15289878811151, 95.92470803627506, 108.06998947306907, 102.6398832838312, 111.34946353752053, 114.6832707991073, 122.73539581271137, 118.09291374422794, 127.07890509479003, 131.668718365544, 136.47002154862764, 133.47279788164926, 144.24079205650065, 143.95836882733738, 146.61572514668939, 148.0785973957856, 152.0428280168499, 152.8158052627335, 156.5122153056858, 152.03060962853954, 158.95878287472019, 152.9968120530055, 154.86810955913464, 154.62635604715274, 156.9641818716781, 155.4651241228125, 156.96506983387792, 155.4659851453653, 155.0273005217828, 152.2306013800666, 154.3128781585861, 152.47043220607742, 154.04692916062984, 152.22325327318302, 151.71150091358612, 152.754415351394, 151.70876887643385, 152.4648181978875, 152.76377757563185, 153.4281959316523, 150.6541309144225, 153.69410735891472, 151.5175508029742, 151.73349588118054, 150.6979066809482, 152.7515304641185, 145.7686560615477, 147.18888213789762, 144.45509220008614, 141.93935071449314, 144.39999395821565, 140.1538410304594, 134.76276335916128, 138.2199658589443, 131.9690636086759, 130.8915132876814, 129.98049480962342, 128.50994747376643, 118.44001091136202, 124.29363623361678, 119.46341589603438, 117.72343593554271, 116.23352096113784, 115.8378129351382, 108.88728448633425, 112.38793941865583, 105.30845198074927, 105.60218408503069, 104.88312881971653, 105.18128376334965, 101.62588416634254, 100.64384701813124, 96.75527173710485, 97.50120088835529, 92.0721417646541, 91.94707118621592, 89.63177706317845, 84.75835741293531, 85.87468200519157, 85.25520980489878, 84.25483587861147, 82.4763334290087, 80.72123272964906, 80.29143347128124, 77.98484204363172, 81.54803487327509, 77.98646764521875, 81.54754531904216, ...]",0
3,"[81.0478930981393, 77.59356465669529, 86.28740668669873, 81.23661168173767, 90.51767039084424, 87.06293352686995, 96.71163516746209, 95.06736933451458, 102.9984204877516, 101.78175464321154, 107.60997069499389, 106.55445074144887, 115.43383683187893, 116.12852208715802, 120.16083747303864, 124.58963701385147, 128.98503806283392, 129.6414881131994, 134.27860126508796, 136.47069196554241, 142.69353296505122, 140.54243893203775, 145.34631614440727, 147.8584497664921, 151.91241580937668, 153.04557970803805, 154.04916444130347, 156.87800916202357, 158.77415547946657, 159.18103474212714, 159.0706454099848, 162.2813972459517, 159.2516474055323, 165.92428302360057, 161.73810509977312, 165.27286178127724, 162.16662714692276, 166.6611557364355, 162.16928443169104, 166.66358941475536, 161.79578386728713, 165.3724330561267, 162.06068233345394, 165.92473963858305, 161.31509113881384, 166.14572190460615, 162.34594659891408, 164.65755626196568, 162.03816719339395, 164.67661882584065, 161.74954501505437, 164.62723636787757, 161.51005681626683, 164.10691907931732, 160.6547463573087, 163.71305090933842, 160.58832084408834, 163.08394305257616, 156.74165902400543, 160.87898996651327, 153.64525008320123, 158.4346507289538, 151.4831584890012, 156.8909039716061, 149.59652382331714, 152.18485702507257, 145.63596889182395, 147.23239721256857, 142.79414738451038, 145.2695797642388, 139.07201468730716, 138.99433959880392, 135.44787575715503, 135.98474302051727, 132.75617725160967, 133.38316723144303, 129.46839115932667, 130.65051620058105, 123.97092624481024, 124.8698885119386, 123.25158772387206, 123.283560317317, 117.83567716389643, 118.5967230139588, 115.9853942399894, 115.47957230584448, 112.24943542758965, 113.16474671488207, 107.28959609676807, 109.95933910543967, 107.12583805224764, 106.95686908273927, 104.79414076697385, 104.48062928478352, 101.2546846617653, 102.78318536643177, 101.72881960845878, 99.13781203018846, 101.73085924311485, 99.13298628746438, ...]",0
4,"[64.97333596341046, 56.742370770602555, 66.2236911612523, 56.69947087865342, 67.04363809340853, 56.77439404520502, 72.9057485233729, 56.774362027549735, 81.11950152585882, 58.40535399254576, 84.17322875449466, 62.35467007798497, 91.4733222404824, 67.64456187243933, 96.28590388144467, 72.48357784743517, 110.44708969555379, 76.87799728653151, 110.80463492922776, 83.36852691907258, 121.95577695189914, 88.63394700372538, 127.57657192089461, 99.61618198343362, 134.9516172959366, 105.69664789406096, 142.6928405575152, 120.16940013586722, 148.08547299645045, 123.71607572795482, 149.56479166429895, 133.43634720848385, 149.57191576755397, 143.8382067456466, 155.2435203016919, 145.9794475729923, 152.77173199225604, 157.58905431601394, 155.22656997349154, 157.58827042664763, 153.83219025511198, 155.60040953695488, 157.05813469026697, 156.65487470294383, 154.47758614498525, 158.5961898198548, 153.42259754194, 157.75747013810582, 155.81935268483144, 159.77394011713076, 153.71615816112026, 157.70821574290449, 153.69009581602722, 156.55660163145993, 153.43769289279854, 156.1749747077114, 151.1157357675413, 157.0294197056823, 152.16854479602506, 155.12248765782638, 150.35251613753556, 156.19220335932067, 146.948125474873, 157.25936699814537, 144.61846070362193, 155.13679646376337, 143.26515188518516, 155.54213617416602, 133.416783144198, 147.57743425213434, 128.3214067257407, 144.30463361085515, 129.02299586155564, 144.5698646489695, 119.31217699435611, 136.70940747112434, 115.81373951938218, 133.4279407414774, 113.39058877300475, 129.21389771500503, 104.82555092252646, 120.9154065444782, 103.95969400571849, 117.76963015469966, 98.39277061419952, 116.46357532284101, 97.306113039102, 109.38285546313485, 94.54405912481285, 105.01972405254772, 87.3831742269073, 102.73872931292736, 87.70511621538928, 98.01146676445315, 82.83345173186058, 93.60084024431518, 78.96712867318044, 90.51373654375138, 77.87717209217979, 91.09040710517773, ...]",0
5,"[73.912896791279, 77.03910576972778, 75.8876859905378, 77.02857891567007, 80.20100122674702, 77.67507695210688, 85.83556023503402, 77.67268310408112, 90.9861544228642, 80.16269047897802, 96.16074543738463, 82.65578263959179, 102.97621130184434, 89.91890683546731, 107.73048456261694, 96.93445610041685, 116.85883641271904, 102.55155568014042, 122.38221528304179, 109.93620753108812, 131.7381401495618, 117.10432217699483, 136.99688533962404, 126.8747652229749, 143.9413673745971, 131.31307418501498, 149.60971079022232, 135.2364663902012, 152.31621236919014, 140.37060629316224, 157.96995396328234, 145.86306795783798, 157.99047409297015, 152.1584888677306, 159.53707538047095, 158.3996773958493, 162.1089591727004, 158.92088907152956, 162.54244234688676, 158.90576967578707, 162.1620602323078, 164.58540345267755, 162.56853692053642, 163.89946815826082, 161.7002630223012, 165.93853729656712, 162.1678225001008, 165.16977584197952, 161.3123556182344, 167.41063570145278, 162.15160240694902, 166.09079420576796, 161.32043326422976, 164.59000804945438, 160.83042155192982, 166.3711305290473, 161.26759354338134, 165.38774980089647, 159.60810662169843, 165.36941577099947, 159.19634298582108, 165.43058236261402, 154.20331103025663, 163.9166131373549, 151.31286233030883, 164.41143710485363, 150.3293385493443, 163.23541433228553, 145.87952531839915, 161.06689310738915, 142.50069683821982, 158.0589400500833, 138.94391448312498, 157.49258196758802, 135.59953121369384, 152.843806660584, 132.5161719879874, 146.94411364734566, 129.4834685001309, 145.81791009107278, 126.03927856732194, 140.11853024315306, 121.06488475436075, 135.29009698816074, 119.2615144739334, 131.29154500202387, 115.07499240038548, 125.05055010371758, 111.3990360085281, 125.63338696949866, 109.93248501232118, 119.1127298650001, 110.0268938904159, 115.93291973827051, 105.5115536866248, 113.2030087675313, 103.234163425775, 107.59592127937924, 98.5838807507194, 108.39398412150095, ...]",0
6,"[62.78378238482622, 57.65202514550158, 71.90740003410998, 62.66433263354804, 75.70868867842236, 64.69217188148801, 78.11543702748048, 70.53216521963681, 85.6181088910475, 76.94363629472905, 91.67732131317688, 83.02681145560383, 96.16802697935371, 88.14578362512022, 108.02122945847616, 94.68343566978535, 109.13270698621592, 104.67860034358769, 117.17953586068923, 108.954919292948, 123.11479930685599, 119.28018392079973, 129.8698619291281, 121.90045715493149, 134.1782591073522, 130.7886283342933, 134.7274793381066, 135.7273801616389, 146.42144703169984, 135.73070806467393, 155.13001247129063, 145.69033319829055, 156.47327803947195, 153.16676068977304, 154.49970215521978, 156.67566070194007, 157.86734019362052, 156.65841066927956, 164.6679884587915, 156.9289440414671, 163.7036751524146, 160.9154931821961, 163.67049032517676, 161.86400821042668, 163.75495615759363, 161.7051729859992, 163.75327208835196, 163.88303327368948, 163.58124825739526, 165.8667772591229, 163.68406675370917, 164.58994881468374, 163.6632081477385, 164.7798882697016, 162.41919950724733, 162.73319402730164, 157.9336360856728, 160.51166562280957, 153.23263806163234, 157.29642762503155, 152.7697045060328, 155.28320996567135, 151.52885929784676, 156.16527086277452, 147.33813479740857, 155.99376990541683, 143.7969795009009, 154.37816870198634, 138.79772325608204, 146.39372973428206, 130.32993121364234, 144.80800622822423, 132.33601255240333, 137.27537789084084, 123.24526388001009, 134.7218337277874, 119.14717320579096, 129.25490623310054, 116.30152230503207, 122.03151438524239, 111.46900474892988, 119.99129953049284, 105.1123348656708, 118.82343592271472, 102.60186457285793, 111.1322988131595, 102.60848083770763, 107.46473924041187, 95.6313762387583, 108.06366328388967, 93.01651382143692, 105.49719625226831, 90.66316592929124, 97.45203661878868, 86.7674375231065, 96.3156163180078, 84.52902783774067, 90.25988061075105, 81.75183106457575, 88.24823137494919, ...]",0
7,"[78.94872163678585, 74.71186421442997, 82.20361356587115, 74.33667449464065, 87.98871956242488, 82.1932619498391, 95.68267127916134, 87.5293255239255, 101.9228499941231, 91.71571702355045, 105.74892834776902, 100.59335675834522, 110.3488229891901, 109.32621990681893, 119.61858352239568, 116.63084223404533, 124.20363676690134, 124.44050401665649, 129.28925995897035, 128.02198219848944, 135.3808757715108, 134.8016401851318, 140.07894153332853, 138.19747930540836, 143.15434431280877, 144.88517233416678, 143.15965935170524, 147.32516506401683, 152.3900735113902, 147.8717675724391, 157.547965129952, 156.4543554374678, 158.67585902753345, 161.36595995977348, 160.4157552145312, 165.1490815976387, 162.56411028463947, 164.66280439326846, 164.25743648231057, 166.1182261883832, 165.11511388252833, 169.70113537016152, 164.94521666591257, 169.75844346311095, 165.4877993760663, 169.73593427047138, 166.80732900440375, 170.42285285257927, 166.64949777955613, 169.95215549305385, 166.37169458155557, 169.77241003700792, 165.4995879420033, 170.26267413056513, 164.44574705213472, 170.2569799296861, 162.94454704204736, 168.98364554989877, 161.81866905967345, 166.15395633934955, 160.99638549675782, 164.14680216390147, 161.42672336416464, 164.2477662169735, 158.03322574026112, 162.38447941367187, 154.15120279758895, 158.79779683780663, 152.76450306882708, 155.93695409075923, 148.51043420751154, 153.07886412455187, 145.16267632599005, 146.85838306891196, 140.0251289105772, 144.45153700823678, 136.17809093603358, 141.106816140076, 133.78525438566035, 136.79777443136865, 130.9696500139644, 131.75335487375446, 126.16892888844131, 129.0182571140805, 123.43279068985498, 125.81858967788328, 124.03184220519252, 125.39483618043587, 119.29973086806193, 125.3995082485358, 113.66222422625037, 118.06610730577752, 112.22833886206742, 116.52250573788065, 108.79978239941978, 111.3833462707207, 106.58577714699932, 109.03134530011916, 105.26519236110124, 105.8234383898659, ...]",0
8,"[66.24968157555176, 57.920186866350726, 71.94358627980945, 59.1079852561884, 75.97360999433616, 61.735237510309766, 81.86490560404286, 64.07253896076729, 87.85495861711543, 68.24827662516225, 92.35437420503325, 74.09724690189323, 96.69204510076368, 80.20460027660953, 96.69083515387092, 84.58865709229087, 113.69766509672492, 89.25562045958814, 116.51624281878117, 87.96723005308694, 125.36716610067029, 110.67675466533191, 134.07397946020828, 115.44441966775838, 137.26233514577532, 121.98124570400292, 149.4599603504625, 132.61756428741396, 148.99396462415328, 135.64290213877135, 155.6442789542662, 146.09184483362083, 158.88687767191846, 151.48817392743905, 159.57737720611763, 155.5055555388301, 156.59105207917162, 154.3736643356661, 156.25003183668133, 156.12211185212513, 156.23860943361333, 166.14951764780108, 156.22107930536785, 165.10762393019348, 156.07241351095965, 165.0628323388002, 154.03576320822188, 165.05073515880602, 154.14222589017797, 164.8410261873151, 154.0239354400621, 164.83831180055597, 154.87133685847178, 164.8395329880919, 156.22945419775795, 162.82099418745787, 156.0561769515772, 162.98526455286134, 153.14757511711508, 160.76907860606306, 151.69725734986284, 155.60385748663037, 146.11514689478915, 154.38811211129234, 144.03170178376897, 153.0390909505349, 131.40292847828258, 151.55301578538706, 130.54600379406915, 142.55002645109667, 127.55838431320602, 142.53545164456804, 119.62125964484378, 133.34138361192427, 119.62330949731377, 128.91926726527774, 108.87320781802597, 126.52657252547404, 103.13980244800848, 125.90740396032263, 102.17546917454703, 112.38654472147883, 97.42718786076907, 105.9140856972764, 89.665261478724, 104.72462591197309, 86.73444561432952, 99.5290136657188, 83.17359016400268, 91.67349664372115, 81.05752675522258, 88.36140636724663, 81.53450765877784, 86.25106177798098, 79.7246053200862, 83.84854966094547, 75.75299993519465, 84.5799916397204, 74.106697515391, 79.54777502231195, ...]",0
9,"[80.73466939210242, 76.33439222980114, 83.22948544019137, 74.79801255621781, 90.02139605082307, 77.02338682388755, 97.24907405316054, 82.19404581720067, 104.47530674434338, 86.57199344362377, 109.5582245210315, 93.45475777111824, 112.85722548611, 99.99745173195272, 111.56671185480599, 106.20127058206097, 128.91278256332419, 115.13957998294813, 133.27103193416195, 115.14131684863354, 137.5314125677238, 126.7766159249407, 143.41102980979016, 132.3807094314713, 148.5222443973807, 137.46235112156634, 153.15736247093955, 146.40736224284862, 155.9985975973284, 146.2497548725304, 159.7085212714693, 154.45944265933247, 159.28711702111713, 156.34812019716074, 161.25959829501534, 160.7262733769476, 165.4957509070138, 163.24519330699954, 165.53655927189575, 163.89184345836682, 165.5280491187474, 164.2132633858492, 165.052672272437, 165.15754871128127, 165.29497663602248, 165.1952966163201, 165.75762494296492, 165.65437505120835, 165.27416234098837, 165.21286122548858, 163.9696752744322, 163.79522080715284, 163.6468498106243, 164.24972464184953, 162.63125160206047, 164.72517667455736, 160.8505574248603, 165.92021394074072, 159.955330275687, 166.16261815216882, 159.59717165575753, 166.20278937012137, 157.65443040906584, 162.75014841521522, 152.7301887816107, 161.33293435600473, 152.83338701452337, 158.8542119384429, 147.97210177624987, 156.95976971855458, 143.8413176623455, 148.86639879988522, 140.85555423558156, 145.69518870787223, 140.2338347903216, 144.0446614346392, 131.21511883629782, 138.51440343954994, 127.66081379095073, 138.51437039434802, 123.50137438205066, 128.84466186115358, 119.7482196578758, 123.45871015446849, 115.1218244806137, 123.52528827317713, 110.46071930348073, 118.61093847546452, 109.10755321698672, 115.05510030855058, 105.36319981043496, 111.58849428233215, 105.1967867556011, 109.09868366174094, 98.94144174762063, 107.51157626809572, 97.36922132305288, 107.46129544848534, 96.68778242243317, 105.079524358307, ...]",0


In [5]:
from sklearn.model_selection import train_test_split
from tslearn.neighbors import KNeighborsTimeSeriesClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics

y = df['Label']
x = df['Angle_array']
#x = numpy_fillna(x)


In [6]:
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
X_train.shape

(1350,)

In [7]:
X_test.shape

(338,)

In [8]:
k_range = range(1, 2) #26
scores = {}
scores_list = []
for k in k_range:
    knn_clf = KNeighborsTimeSeriesClassifier(n_neighbors=k, metric="dtw", n_jobs=-1)
    #knn_clf = KNeighborsClassifier(k, n_jobs=-1)
    knn_clf.fit(numpy_fillna(X_train.values), y_train)
    predicted_labels = knn_clf.predict(numpy_fillna(X_test.values))
    acc = metrics.accuracy_score(y_test, predicted_labels)
    scores[k] = acc
    print("Correct classification rate:", acc)
    print('\n')
    print(metrics.classification_report(y_test, predicted_labels))
    print('F1 score: ' + str(metrics.f1_score(y_test, predicted_labels, average='macro')))

Correct classification rate: 0.9497041420118343


              precision    recall  f1-score   support

           0       0.93      0.98      0.96       194
           1       0.98      0.90      0.94       144

    accuracy                           0.95       338
   macro avg       0.95      0.94      0.95       338
weighted avg       0.95      0.95      0.95       338

0.9480108212770193
