In [1]:
from data.normlengthdata import NormLengthData
from match import Match
from encoder import LabelEncoder
from model import GaitLengthPredictor

In [2]:
# Load the two match recordings from their JSON files, in order.
match_1_instance, match_2_instance = map(
    Match, ("../data/match_1.json", "../data/match_2.json")
)

In [3]:
# Drop "no action" records: the label appears only twice in the second match.
# Build a filtered list instead of calling pop() while enumerating the same list:
# popping shifts the remaining items left, so the element that follows a removed
# one would be skipped (and an adjacent "no action" record would survive).
# Slice assignment mutates the existing list object in place, so any other
# reference to `match_2_instance.data` sees the filtered content as well.
match_2_instance.data[:] = [
    element for element in match_2_instance.data if element["label"] != "no action"
]

In [4]:
# Wrap each match in a NormLengthData dataset, keeping the match order.
match_1_data, match_2_data = (
    NormLengthData(recording) for recording in (match_1_instance, match_2_instance)
)

In [5]:
# Load the label encoder from its JSON file under ../data/model/.
encoder = LabelEncoder.load("../data/model/encoder.json")

In [6]:
# Combine the datasets of both matches through NormLengthData's `+` operator.
all_data = match_1_data + match_2_data

In [7]:
# Pull the features and targets out of the combined dataset, then let the
# dataset partition them into its train / validation / test splits.
# NOTE(review): no seed is passed here — confirm whether the split is
# reproducible across kernel restarts.
X = all_data.X
y = all_data.y
all_data.train_test_split(X, y)

In [8]:
# Print the sizes of the training, validation and test sets.
all_data.info()

Size of training set: 829 
Size of validation set: 178 
Size of test set: 178


In [9]:
# Transform the splits with the loaded encoder; the result is unpacked as the
# transformed train, validation and test features, in that order.
X_train_transformed, X_val_transformed, X_test_transformed = all_data.transform(encoder)

In [10]:
# Create a predictor with its default constructor arguments and fit it on the
# transformed training features against the training targets.
gait_predictor = GaitLengthPredictor()
gait_predictor.fit(X_train_transformed, all_data.y_train)

In [11]:
# Report MAE / MSE / RMSE / R2 on the validation split and on the test split.
# NOTE(review): the recorded test R2 is negative (about -0.19) while the
# validation R2 is about 0.26 — assuming the usual R2 definition, the model
# does worse than a constant mean prediction on the test split; worth
# investigating before relying on these predictions.
gait_predictor.evaluate(
    X_val_transformed, all_data.y_val, X_test_transformed, all_data.y_test
)

Validation Metrics: 
MAE: 12.569494479567194 
MSE: 405.99080141435076 
RMSE: 20.149213419246685 
R-squared (R2): 0.2613283207325052

Test Metrics: 
MAE: 12.457286466114667 
MSE: 423.2854716746332 
RMSE: 20.57390268458158 
R-squared (R2): -0.19465952799840158



In [12]:
# Display the encoder's mapping from action label to integer index.
# NOTE(review): in the recorded output '<Undefined>' maps to 18 while the
# real labels use 0-7 — confirm this gap is intended.
encoder.class_to_index

{'run': 0,
 'walk': 1,
 'tackle': 2,
 'pass': 3,
 'rest': 4,
 'cross': 5,
 'dribble': 6,
 'shot': 7,
 '<Undefined>': 18}

In [13]:
# Average gait length per action label over both matches combined
# (Match objects are merged with `+`); a reference point for the predictions.
(match_1_instance + match_2_instance).average_gait_length_per_action

{'tackle': 46,
 'dribble': 39,
 'shot': 33,
 'rest': 115,
 'cross': 52,
 'pass': 42,
 'walk': 55,
 'run': 37}

In [14]:
# Predict gait length for three sample actions. The indices are looked up by
# label in the encoder instead of being hard-coded, so the cell stays correct
# if the encoder's label -> index mapping changes. With the mapping loaded in
# this notebook this yields [[2], [7], [5]].
# Each inner list is one sample; predict() returns one value per sample.
test_action = [
    [encoder.class_to_index[label]] for label in ("tackle", "shot", "cross")
]
gait_predictor.predict(test_action)

array([52.90692371, 31.70909723, 53.52598328])

In [15]:
# Persist the fitted predictor; no path is passed, so save() decides the
# destination itself.
gait_predictor.save()

In [16]:
# Rebind gait_predictor to an instance returned by load() (again with no
# explicit path) to exercise the save/load round trip.
gait_predictor = GaitLengthPredictor.load()

In [17]:
# Sanity check: the reloaded predictor should return the same predictions for
# the same sample actions as the in-memory model did before saving.
gait_predictor.predict(test_action)

array([52.90692371, 31.70909723, 53.52598328])