# Split Data

Split the dataset into train, validation, and test sets.

In [None]:
import os
import glob
import shutil
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from __init__ import data_path

In [None]:
# Remove split directories left over from a previous run so that the copy
# steps further down start from a clean state. The paths are derived from
# `data_path` (the same root the later cells write to) instead of a
# hard-coded, user-specific absolute path, so the cleanup always targets the
# directories this notebook actually populates.
for stale_split in ("train", "val", "test"):
    # ignore_errors=True mirrors `rm -rf`: a missing directory is not an error.
    shutil.rmtree(
        os.path.join(data_path, "tenniset", "shot_labels", stale_split),
        ignore_errors=True,
    )

In [None]:
# Directory that holds the shot label files, located under the project data root.
labels_subdirs = ("tenniset", "shot_labels")
labels_path = os.path.join(data_path, *labels_subdirs)

# Names of the entries present in that directory when this cell is run.
labels_files = os.listdir(labels_path)

In [None]:
# Test set: every file whose name starts with the held-out video id (V010_*).
test_vid = "V010"
test_dir = os.path.join(labels_path, "test")
os.makedirs(test_dir, exist_ok=True)

test_pattern = os.path.join(labels_path, f"{test_vid}*")
test_files = glob.glob(test_pattern)
for test_file in tqdm(test_files):
    # Copying into a directory keeps the source file's base name.
    shutil.copy(test_file, test_dir)

In [None]:
# Train/val splits => remaining videos, each split per video with scikit-learn.
val_frac = 0.2
train_frac = 1 - val_frac  # Informational only: the split is driven by val_frac.
train_dir = os.path.join(labels_path, "train")
os.makedirs(train_dir, exist_ok=True)
val_dir = os.path.join(labels_path, "val")
os.makedirs(val_dir, exist_ok=True)
vids = ["V006", "V007", "V008", "V009"]

# Suffix of the per-sample info file that identifies one sample.
INFO_SUFFIX = "_info.json"


def _sample_files(src_dir, sample_id):
    """Return the sorted files in ``src_dir`` that belong to ``sample_id``.

    A bare ``sample_id*`` glob also matches longer ids that share the prefix
    (e.g. ``V006_1*`` matches ``V006_10_info.json``), which would copy another
    sample's files into this split. Only names in which the id is followed by
    a non-alphanumeric separator (or by nothing at all) are kept.
    """
    candidates = sorted(glob.glob(os.path.join(src_dir, f"{sample_id}*")))
    return [
        path
        for path in candidates
        if not os.path.basename(path)[len(sample_id):][:1].isalnum()
    ]


def _copy_samples(info_files, src_dir, dst_dir):
    """Copy all files of the samples named by ``info_files`` into ``dst_dir``."""
    for info_file in tqdm(sorted(info_files)):
        # Sample id = info file name without its "_info.json" suffix.
        sample_id = os.path.basename(info_file)[: -len(INFO_SUFFIX)]
        for sample_file in _sample_files(src_dir, sample_id):
            # Copying into a directory keeps the source file's base name.
            shutil.copy(sample_file, dst_dir)


for vid in vids:
    # One info file per sample; sorted so the seeded split is reproducible.
    vid_files = sorted(glob.glob(os.path.join(labels_path, f"{vid}*{INFO_SUFFIX}")))

    # Split per video so each video contributes to both train and validation.
    train_files, val_files = train_test_split(
        vid_files, test_size=val_frac, random_state=42
    )

    _copy_samples(train_files, labels_path, train_dir)
    _copy_samples(val_files, labels_path, val_dir)

In [None]:
# Report the number of entries in each split directory (train, val, test).
for split_dir in (train_dir, val_dir, test_dir):
    split_entries = os.listdir(split_dir)
    print(len(split_entries))