In [1]:
import os
import re
import random
import pandas as pd
from glob import glob

def load_tracks(data_dir):
    """
    Reads all CSV files from the given directory and loads them into a dictionary.
    The dictionary keys are labels (extracted from the filename) and the values
    are lists of track dictionaries containing 'filename' and 'data' keys.
    
    File name example: '12345_3.csv' (where '3' is the label)
    """
    tracks = {}
    # Get all CSV files in the directory
    filepaths = glob(os.path.join(data_dir, "*.csv"))
    
    for path in filepaths:
        filename = os.path.basename(path)
        # Extract the label from the filename using regex:
        # Searches for one of the digits 0-3 immediately preceding ".csv"
        match = re.search(r'([0-3])(?=\.csv$)', filename)
        if match:
            label = match.group(1)
            if label not in tracks:
                tracks[label] = []
            # Load CSV file as a DataFrame
            df = pd.read_csv(path)
            tracks[label].append({'filename': filename, 'data': df})
        else:
            print(f"Warning: Could not extract label from filename: {filename}")
    return tracks

def drop_trailing_missing(track_df):
    """
    Drops trailing rows from the track DataFrame where ALL columns starting with
    'x0' or 'y0' have missing values. Iterates from the bottom upward and stops
    when a row with any non-missing value in these columns is encountered.
    """
    # Identify columns that start with 'x0' or 'y0'
    pattern = re.compile(r'^(x0|y0)')
    cols = [col for col in track_df.columns if pattern.match(col)]
    
    # Work from the last row upward
    last_valid_index = None
    for idx in reversed(track_df.index):
        # Check only the relevant columns in the row
        row_values = track_df.loc[idx, cols]
        # If at least one value is not missing, mark this row as valid
        if not row_values.isnull().all():
            last_valid_index = idx
            break
    # If no valid row was found, return an empty DataFrame
    if last_valid_index is None:
        return pd.DataFrame(columns=track_df.columns)
    
    # Return DataFrame up to the last valid index (inclusive) and reset index
    return track_df.loc[:last_valid_index].reset_index(drop=True)

def count_missing_in_middle(track_df):
    """
    Counts the number of missing values in columns that start with 'x0' or 'y0'
    in the provided track DataFrame. This is intended for assessing missing values
    "in the middle" of a track (after trailing rows with missing values have been dropped).
    """
    pattern = re.compile(r'^(x0|y0)')
    cols = [col for col in track_df.columns if pattern.match(col)]
    # Count all missing values in the selected columns
    missing_count = track_df[cols].isnull().sum().sum()
    return missing_count

def balance_dataset(tracks):
    """
    Balances the dataset across labels. Given a dictionary where each key is a label
    and its value is a list of track dictionaries (each with a 'data' key containing
    a DataFrame), this function first determines the minimum number of tracks among
    all labels. Then, for each label, it selects the tracks with the highest number
    of available trajectory points. Here, we consider the number of rows in the
    DataFrame (after any processing such as dropping trailing missing rows) as the
    number of available trajectory points.
    
    Returns a new dictionary with balanced label keys.
    """
    # Calculate the minimum track count among all labels
    min_count = min(len(track_list) for track_list in tracks.values())
    
    balanced = {}
    for label, track_list in tracks.items():
        # Sort the track list by the number of available trajectory points in descending order.
        # Here we assume that each row in the DataFrame corresponds to one trajectory point.
        sorted_tracks = sorted(track_list, key=lambda track: track['data'].shape[0] - count_missing_in_middle(track['data']), reverse=True)
        # Select the top 'min_count' tracks for each label.
        balanced[label] = sorted_tracks[:min_count]
    
    return balanced

def filter_tracks_by_valid_rows(tracks, min_valid_rows):
    """
    Filters out tracks that do not have at least `min_valid_rows` valid rows.
    
    A valid row is defined as one in which all columns starting with 'x0' or 'y0'
    contain a non-empty value.
    
    Args:
        tracks: Dictionary mapping labels to lists of track dictionaries.
        min_valid_rows: Minimum required number of valid rows for a track to be kept.
    
    Returns:
        A new dictionary with tracks that meet the valid row threshold.
        Also prints the number of tracks dropped for each label and total.
    """
    pattern = re.compile(r'^(x0|y0)')
    filtered_tracks = {}
    total_dropped = 0
    for label, track_list in tracks.items():
        valid_tracks = []
        dropped_count = 0
        for track in track_list:
            df = track['data']
            # Identify columns that match the pattern (x0 and y0 columns)
            cols = [col for col in df.columns if pattern.match(col)]
            # A valid row: all values in the selected columns are not missing
            valid_rows = df[cols].notnull().all(axis=1)
            num_valid = valid_rows.sum()
            if num_valid >= min_valid_rows:
                valid_tracks.append(track)
            else:
                dropped_count += 1
        filtered_tracks[label] = valid_tracks
        print(f"For label {label}, dropped {dropped_count} tracks out of {len(track_list)} because they had less than {min_valid_rows} valid rows.")
        total_dropped += dropped_count
    print(f"Total tracks dropped: {total_dropped}")
    return filtered_tracks


import os
import random

def split_and_save(tracks,
                   train_ratio=0.8,
                   val_ratio=0.1,
                   train_folder='train',
                   val_folder='val',
                   test_folder='test'):
    """
    Splits the dataset into train/val/test sets while maintaining label balance.
    For each label in the `tracks` dict, the tracks are shuffled and then split
    according to the specified ratios. Each track is written to CSV in the
    appropriate folder with a filename "<label>_XXX.csv".

    Args:
        tracks (dict): label → list of {'filename', 'data'} dicts
        train_ratio (float): fraction for training set
        val_ratio (float):   fraction for validation set
        train_folder (str):  output dir for train
        val_folder (str):    output dir for val
        test_folder (str):   output dir for test
    """
    # compute test ratio implicitly
    test_ratio = 1.0 - train_ratio - val_ratio
    if test_ratio < 0:
        raise ValueError("train_ratio + val_ratio must be ≤ 1.0")

    # make directories
    os.makedirs(train_folder, exist_ok=True)
    os.makedirs(val_folder,   exist_ok=True)
    os.makedirs(test_folder,  exist_ok=True)

    train_count = val_count = test_count = 0

    for label, track_list in tracks.items():
        random.shuffle(track_list)
        n = len(track_list)
        n_train = int(n * train_ratio)
        n_val   = int(n * val_ratio)
        # note: remaining = n - n_train - n_val goes to test

        for idx, track in enumerate(track_list, start=1):
            new_filename = f"{label}_{idx:03}.csv"
            if idx <= n_train:
                out_dir = train_folder
                train_count += 1
            elif idx <= n_train + n_val:
                out_dir = val_folder
                val_count += 1
            else:
                out_dir = test_folder
                test_count += 1

            out_path = os.path.join(out_dir, new_filename)
            track['data'].to_csv(out_path, index=False)
            print(f"Saved {new_filename} to {out_path}")

    # summary
    print(f"\nTotal trajectories saved:")
    print(f"  Train:      {train_count}")
    print(f"  Validation: {val_count}")
    print(f"  Test:       {test_count}")





In [2]:
# Set the path to the directory containing the CSV files
data_dir = "../datasets/insects-tracks/Insect_Trajectory_Dataset/Extracted_Tracks/Insect_Tracks_CSV"
TARGET_DIR = "../datasets/it-trajs_2/"
os.makedirs(TARGET_DIR, exist_ok=True)
# Step 1: Load tracks from CSV files
tracks = load_tracks(data_dir)

# Step 2 and 3: Process each track by dropping trailing missing rows
# and count the missing values in the middle of the track.
for label in tracks:
    for track in tracks[label]:
        original_df = track['data']
        # Drop trailing rows that are missing the x0/y0 columns
        processed_df = drop_trailing_missing(original_df)
        track['data'] = processed_df  # update with processed data
        missing_count = count_missing_in_middle(processed_df)
        print(f"File '{track['filename']}' (Label {label}) has {missing_count} missing x0/y0 values in the middle.")

min_valid_rows = 10
tracks = filter_tracks_by_valid_rows(tracks, min_valid_rows)

tracks
#min_count = min(len(track_list) for track_list in tracks.values())



File '11715463302.csv' (Label 2) has 0 missing x0/y0 values in the middle.
File '51214325702.csv' (Label 2) has 84 missing x0/y0 values in the middle.
File '31514124402.csv' (Label 2) has 218 missing x0/y0 values in the middle.
File '41714574102.csv' (Label 2) has 0 missing x0/y0 values in the middle.
File '61714100402.csv' (Label 2) has 0 missing x0/y0 values in the middle.
File '61711480302.csv' (Label 2) has 0 missing x0/y0 values in the middle.
File '21711041902.csv' (Label 2) has 188 missing x0/y0 values in the middle.
File '91714304102.csv' (Label 2) has 6 missing x0/y0 values in the middle.
File '61712394802.csv' (Label 2) has 160 missing x0/y0 values in the middle.
File '41014294802.csv' (Label 2) has 62 missing x0/y0 values in the middle.
File '31514135202.csv' (Label 2) has 36 missing x0/y0 values in the middle.
File '81513053002.csv' (Label 2) has 38 missing x0/y0 values in the middle.
File '61714110302.csv' (Label 2) has 16 missing x0/y0 values in the middle.
File '61712394

{'2': [{'filename': '11715463302.csv',
   'data':     Unnamed: 0  nframe_11715463302  absframe_11715463302  x0_11715463302  \
   0            0                   0                550938          1644.0   
   1            1                   1                550939          1624.0   
   2            2                   2                550940          1586.0   
   3            3                   3                550941          1538.0   
   4            4                   4                550942          1483.0   
   5            5                   5                550943          1423.0   
   6            6                   6                550944          1365.0   
   7            7                   7                550945          1338.0   
   8            8                   8                550946          1313.0   
   9            9                   9                550947          1304.0   
   10          10                  10                550948          1300.0   
   11

In [6]:
# Step 4: Balance the dataset so that each label has the same number of tracks
balanced_tracks = balance_dataset(tracks)

# Step 5: Split the balanced tracks into training and test sets and save to CSV files
split_and_save(balanced_tracks, train_ratio=0.7, val_ratio=0.15, train_folder=f'{TARGET_DIR}/train', val_folder=f'{TARGET_DIR}/val', test_folder=f'{TARGET_DIR}/test')

Saved 2_001.csv to ../datasets/it-trajs_2//train/2_001.csv
Saved 2_002.csv to ../datasets/it-trajs_2//train/2_002.csv
Saved 2_003.csv to ../datasets/it-trajs_2//train/2_003.csv
Saved 2_004.csv to ../datasets/it-trajs_2//train/2_004.csv
Saved 2_005.csv to ../datasets/it-trajs_2//train/2_005.csv
Saved 2_006.csv to ../datasets/it-trajs_2//train/2_006.csv
Saved 2_007.csv to ../datasets/it-trajs_2//train/2_007.csv
Saved 2_008.csv to ../datasets/it-trajs_2//train/2_008.csv
Saved 2_009.csv to ../datasets/it-trajs_2//train/2_009.csv
Saved 2_010.csv to ../datasets/it-trajs_2//train/2_010.csv
Saved 2_011.csv to ../datasets/it-trajs_2//train/2_011.csv
Saved 2_012.csv to ../datasets/it-trajs_2//train/2_012.csv
Saved 2_013.csv to ../datasets/it-trajs_2//train/2_013.csv
Saved 2_014.csv to ../datasets/it-trajs_2//train/2_014.csv
Saved 2_015.csv to ../datasets/it-trajs_2//train/2_015.csv
Saved 2_016.csv to ../datasets/it-trajs_2//train/2_016.csv
Saved 2_017.csv to ../datasets/it-trajs_2//train/2_017.c