In [2]:
import os
import sys
import warnings
import pandas as pd
from tqdm.notebook import tqdm

# Project root = parent of the notebook's working directory; it is added to
# sys.path so the local `express` package imported in the next cell resolves.
# The membership guard keeps this cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
base_path = os.path.abspath(os.path.join(os.getcwd(), ".."))
print(f"base_path: {base_path}")
if base_path not in sys.path:
    sys.path.append(base_path)

base_path: /home/miru/sr-press


In [3]:
from express.databases import SQLiteDatabase
from express.datasets import PressingDataset

from express import features as fs
from express import labels as ls

In [4]:
# SQLite stores under <base_path>/stores. The " (1)" copies are the databases
# the datasets are later built from (train_db_1 / test_db_1); the un-suffixed
# stores are opened as well, but no visible cell uses them.
# NOTE(review): drop train_db / test_db if they are no longer needed.
TRAIN_DB_1_PATH = os.path.join(base_path, "stores/train_database (1).sqlite")
TRAIN_DB_PATH = os.path.join(base_path, "stores/train_database.sqlite")
TEST_DB_1_PATH = os.path.join(base_path, "stores/test_database (1).sqlite")
TEST_DB_PATH = os.path.join(base_path, "stores/test_database.sqlite")

# Open the stores one at a time, train first, then test.
train_db_1 = SQLiteDatabase(TRAIN_DB_1_PATH)
train_db = SQLiteDatabase(TRAIN_DB_PATH)
test_db_1 = SQLiteDatabase(TEST_DB_1_PATH)
test_db = SQLiteDatabase(TEST_DB_PATH)

# Echo each handle with its variable name.
for db_label, db_handle in (
    ("train_db_1:", train_db_1),
    ("train_db:", train_db),
    ("test_db_1:", test_db_1),
    ("test_db:", test_db),
):
    print(db_label, db_handle)

train_db_1: <express.databases.sqlite.SQLiteDatabase object at 0x7f8374095460>
train_db: <express.databases.sqlite.SQLiteDatabase object at 0x7f8374095a60>
test_db_1: <express.databases.sqlite.SQLiteDatabase object at 0x7f809a614100>
test_db: <express.databases.sqlite.SQLiteDatabase object at 0x7f836c5e9f70>


In [6]:
# Shape of the games table in each split's database (train queried first).
train_games_shape = train_db_1.games().shape
test_games_shape = test_db_1.games().shape
print(train_games_shape, test_games_shape)

(136, 11) (64, 11)


In [7]:
def _callable_names(callables):
    """Return the ``__name__`` of each callable, preserving input order."""
    return [func.__name__ for func in callables]


# Registered feature / label function names; these strings are what
# PressingDataset accepts in its xfns / yfns arguments.
all_features = _callable_names(fs.all_features)
all_labels = _callable_names(ls.all_labels)
for heading, names in (("Features:", all_features), ("Labels:", all_labels)):
    print(heading, names)

Features: ['actiontype', 'actiontype_onehot', 'result', 'result_onehot', 'bodypart', 'bodypart_onehot', 'time', 'startlocation', 'relative_startlocation', 'endlocation', 'relative_endlocation', 'startpolar', 'endpolar', 'movement', 'team', 'time_delta', 'space_delta', 'goalscore', 'angle', 'under_pressure', 'speed', 'freeze_frame_360', 'dist_opponent', 'defenders_in_3m_radius', 'closest_11_players', 'get_column_sum_to_player']
Labels: ['concede_shots', 'counterpress', 'possession_change_by_2_actions', 'possession_change_by_4_actions', 'possession_change_by_6_actions', 'possession_change_by_2_actions_and_3m_distance', 'possession_change_by_4_actions_and_3m_distance', 'possession_change_by_6_actions_and_3m_distance', 'possession_change_by_2_actions_and_5m_distance', 'possession_change_by_4_actions_and_5m_distance', 'possession_change_by_6_actions_and_5m_distance', 'possession_change_by_2_actions_and_7m_distance', 'possession_change_by_4_actions_and_7m_distance', 'possession_change_by_6_a

In [8]:
# Shared dataset configuration. Train and test must be built with identical
# feature/label functions and context window, otherwise a model fitted on one
# split cannot be evaluated on the other; defining them once prevents the two
# constructor calls from drifting apart.
DATASETS_DIR = os.path.join(base_path, "stores", "datasets")
FEATURE_FNS = ["startlocation", "freeze_frame_360"]
LABEL_FNS = ["counterpress", "possession_change_by_5_seconds"]
# Number of preceding actions attached as context; the recorded feature frame
# has columns suffixed _a0 (the action itself) through _a3.
NB_PREV_ACTIONS = 3


def make_pressing_dataset(split: str) -> PressingDataset:
    """Build an uncached PressingDataset stored under ``DATASETS_DIR/<split>``.

    Args:
        split: Sub-directory name for the split, e.g. ``"train"`` or ``"test"``.

    Returns:
        A PressingDataset configured with the shared FEATURE_FNS, LABEL_FNS and
        NB_PREV_ACTIONS.
    """
    return PressingDataset(
        path=os.path.join(DATASETS_DIR, split),
        # Fresh copies so each dataset owns its own lists, as the original
        # per-call list literals did.
        xfns=list(FEATURE_FNS),
        yfns=list(LABEL_FNS),
        load_cached=False,
        nb_prev_actions=NB_PREV_ACTIONS,
    )


train_dataset = make_pressing_dataset("train")
test_dataset = make_pressing_dataset("test")

In [9]:
# Populate each dataset from its matching database. This is slow: in the
# recorded run the train split took roughly 25 minutes and the test split
# roughly 11 minutes.
for dataset, source_db in ((train_dataset, train_db_1), (test_dataset, test_db_1)):
    dataset.create(source_db)

100%|██████████| 136/136 [21:00<00:00,  9.27s/it]
100%|██████████| 136/136 [03:47<00:00,  1.67s/it]
100%|██████████| 64/64 [09:38<00:00,  9.04s/it]
100%|██████████| 64/64 [01:34<00:00,  1.47s/it]


In [13]:
# Class balance of the possession-change label on the train split.
possession_labels_train = train_dataset.labels['possession_change_by_5_seconds']
possession_labels_train.value_counts()

possession_change_by_5_seconds
False    26607
True     10341
Name: count, dtype: int64

In [14]:
# Feature frame for the train split (pandas truncates the rich display).
train_features = train_dataset.features
train_features

Unnamed: 0_level_0,Unnamed: 1_level_0,start_x_a0,start_y_a0,start_x_a1,start_y_a1,start_x_a2,start_y_a2,start_x_a3,start_y_a3,freeze_frame_360_a0,freeze_frame_360_a1,freeze_frame_360_a2,freeze_frame_360_a3
game_id,action_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
3788741,3,41.7375,61.285,31.2375,42.585,28.0000,43.945,52.0625,34.425,"[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': False, 'actor': False, 'keeper':..."
3788741,19,84.5250,59.500,41.4750,59.500,38.5875,54.400,29.9250,34.170,"[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ..."
3788741,27,27.2125,13.940,76.3875,58.055,72.1875,68.000,83.0375,56.865,"[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...",,"[{'teammate': True, 'actor': False, 'keeper': ..."
3788741,31,75.7750,60.690,25.9000,9.520,25.2875,11.815,24.4125,10.370,"[{'teammate': False, 'actor': False, 'keeper':...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': False, 'actor': False, 'keeper':...","[{'teammate': True, 'actor': False, 'keeper': ..."
3788741,34,79.9750,58.055,20.9125,10.710,25.7250,10.965,25.9000,9.520,"[{'teammate': False, 'actor': False, 'keeper':...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': False, 'keeper': ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
3943043,2183,82.0750,20.230,37.1000,68.000,100.7125,36.805,98.5250,33.235,"[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': True, 'keeper': F...",,"[{'teammate': True, 'actor': False, 'keeper': ..."
3943043,2186,34.3000,61.710,33.3375,0.425,91.9625,20.655,67.9000,0.000,"[{'teammate': True, 'actor': True, 'keeper': F...","[{'teammate': True, 'actor': False, 'keeper': ...","[{'teammate': True, 'actor': True, 'keeper': T...","[{'teammate': True, 'actor': True, 'keeper': F..."
3943043,2201,47.7750,31.790,35.3500,12.070,35.9625,7.055,42.6125,20.825,"[{'teammate': True, 'actor': False, 'keeper': ...",,,"[{'teammate': True, 'actor': False, 'keeper': ..."
3943043,2202,49.6125,35.700,35.3500,12.070,35.9625,7.055,42.6125,20.825,"[{'teammate': True, 'actor': False, 'keeper': ...",,,"[{'teammate': True, 'actor': False, 'keeper': ..."


In [17]:
# Label frame for the train split.
train_labels = train_dataset.labels
train_labels

Unnamed: 0,counterpress
0,False
1,False
2,False
3,True
4,True
...,...
39041,False
39042,False
39043,False
39044,False


In [18]:
# Label frame for the test split.
test_labels = test_dataset.labels
test_labels

Unnamed: 0,counterpress
0,False
1,False
2,False
3,False
4,False
...,...
12401,False
12402,False
12403,False
12404,False


In [19]:
# Class balance of the counterpress label on the train split.
counterpress_train = train_dataset.labels["counterpress"]
counterpress_train.value_counts()

False    31605
True      7441
Name: counterpress, dtype: int64

In [20]:
# Class balance of the counterpress label on the test split.
counterpress_test = test_dataset.labels["counterpress"]
counterpress_test.value_counts()

False    9954
True     2452
Name: counterpress, dtype: int64