# Data Preprocessing: From Skeletal Coordinates to Spatio-Temporal Graphs

**Project:** Unsupervised Motor Signatures in ASD

**Paper Reference:** *Unsupervised Deep Learning Framework for Quantifying Atypical Motor Signatures in ASD*

## Overview
This notebook outlines the pipeline for transforming raw 3D skeletal data (extracted via MediaPipe) into the specific graph tensors required by the STGCN-AE model.

**Key Steps:**
1.  **Graph Construction:** Converting each frame's joint coordinates into a fully connected graph ($N=24$ nodes).
2.  **Temporal Trimming:** Retaining only the final 120 frames (approx. 4 seconds) of the action to isolate the core motor behavior.
3.  **Sliding Window Segmentation:** Slicing the sequence into overlapping 2-second windows (60 frames) with a 0.5-second (15-frame) stride.
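The trimming and windowing steps can be checked on plain frame indices before any graphs are built. This is a simplified sketch with hypothetical helper names (`trim_to_last`, `sliding_windows`) that mirror the graph-based `trim_by_ending_frames` and `create_sliding_windows` defined later, assuming a 150-frame recording.

```python
def trim_to_last(frames, num_frames_to_keep):
    """Keep only the final `num_frames_to_keep` items (the whole list if it is shorter)."""
    if len(frames) <= num_frames_to_keep:
        return frames
    return frames[len(frames) - num_frames_to_keep:]


def sliding_windows(frames, window_size, step):
    """Overlapping windows of `window_size` items; a trailing partial window is dropped."""
    return [frames[i:i + window_size] for i in range(0, len(frames) - window_size + 1, step)]


frame_ids = list(range(150))                     # a hypothetical 150-frame recording
kept = trim_to_last(frame_ids, 120)              # frames 30..149
windows = sliding_windows(kept, window_size=60, step=15)

print(len(kept), kept[0], kept[-1])              # 120 30 149
print(len(windows), [w[0] for w in windows])     # 5 [30, 45, 60, 75, 90]
# Closed form for the window count: (T - W) // S + 1 = (120 - 60) // 15 + 1 = 5
```

Five windows per 150-frame file is consistent with the window counts reported for the synthetic data in Section 4.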

In [None]:
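import os

# 1. Install PyTorch 2.8.0 built for CUDA 12.6, plus the torchvision/torchaudio releases that match it.
#    Pinning exact versions makes pip replace whatever build Colab preloads, and pinning the
#    companion packages stops pip from backtracking through incompatible releases.
!pip install torch==2.8.0+cu126 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu126 --upgrade

# 2. Verify the version before proceeding.
#    If torch was already imported in this session, restart the runtime first:
#    re-importing does not load the newly installed version.
import torch
print(f"Installed PyTorch version: {torch.__version__}")

# 3. Set the environment variable used for the PyG (PyTorch Geometric) installation
#    This tells pip explicitly which binary wheels to grab
os.environ['TORCH'] = "2.8.0+cu126"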


In [None]:
!pip install torch_geometric

# Optional compiled extensions; ${TORCH} is expanded by the shell from the variable set in the previous cell
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-geometric-temporal


In [None]:
import pandas as pd
import numpy as np
import os
import shutil

def generate_dummy_dataset(output_dir='sample_data', num_subjects=4, num_frames=150):
    """
    Generates a synthetic dataset to mimic the structure required by the ASD Motor pipeline.

    Args:
        output_dir (str): Directory to save the dummy files.
        num_subjects (int): Number of dummy subjects to create.
        num_frames (int): Number of frames per video (should be >120 for trimming logic).
    """
    # 1. Setup Directory
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)
    print(f"Created directory: {output_dir}")

    # 2. Generate Meta Data
    # IDs starting from 1001
    video_ids = [1001 + i for i in range(num_subjects)]
    # Assign 0 (control) to the first half and 1 (autism) to the rest (also handles odd counts)
    subject_types = [0] * (num_subjects // 2) + [1] * (num_subjects - num_subjects // 2)

    df_meta = pd.DataFrame({
        'video_id': video_ids,
        'subject_type': subject_types
    })

    meta_path = os.path.join(output_dir, 'meta_data.csv')
    df_meta.to_csv(meta_path, index=False)
    print(f"Generated metadata: {meta_path}")

    # 3. Generate Skeletal CSVs
    # Each file mimics the [video_id]_[action].csv format
    action_code = "BP"  # Example action code

    for vid_id in video_ids:
        # Create columns: video_id, action, Frame
        data = {
            'video_id': [vid_id] * num_frames,
            'action': [action_code] * num_frames,
            'Frame': list(range(num_frames))
        }

        # Create Joint Columns 0 to 23
        # Format: String representation of [x, y, z] e.g., "[0.5, 0.2, 0.1]"
        for j in range(24):  # 24 joints, as described in the manuscript
            # Generate random 3D coordinates for this joint across all frames
            # Uniform random floats in [-1, 1) to simulate a normalized coordinate space
            coords = np.random.uniform(-1.0, 1.0, size=(num_frames, 3)).round(4).tolist()
            # Convert list to string format "[x, y, z]" to match ast.literal_eval expectation
            data[str(j)] = [str(c) for c in coords]

        df_skeleton = pd.DataFrame(data)

        # Save file
        filename = f"{vid_id}_{action_code}.csv"
        file_path = os.path.join(output_dir, filename)
        df_skeleton.to_csv(file_path, index=False)
        print(f"Generated skeletal file: {file_path}")

    print("\nDummy dataset generation complete.")
    print(f"Example usage in pipeline: \n  control_train_files = ['{output_dir}/1001_BP.csv', ...]")

# --- Execute Generation ---
generate_dummy_dataset()

Created directory: sample_data
Generated metadata: sample_data/meta_data.csv
Generated skeletal file: sample_data/1001_BP.csv
Generated skeletal file: sample_data/1002_BP.csv
Generated skeletal file: sample_data/1003_BP.csv
Generated skeletal file: sample_data/1004_BP.csv

Dummy dataset generation complete.
Example usage in pipeline: 
  control_train_files = ['sample_data/1001_BP.csv', ...]


## 1. Core Processing Logic
The following functions parse the string-formatted joint coordinates in each CSV into per-frame graphs, trim each sequence to its final frames, and segment it into overlapping sliding windows.
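A stricter version of the per-joint parsing, sketched in plain Python: besides catching unparseable strings, it rejects values that are not exactly three numbers. `parse_joint` and `parse_row` are hypothetical helpers, not part of the original pipeline, and assume the `"[x, y, z]"` cell format written by the dummy data generator.

```python
import ast


def parse_joint(cell, expected_dims=3):
    """Parse a '[x, y, z]' string into a list of floats, or return None if malformed."""
    try:
        value = ast.literal_eval(cell)
    except (ValueError, SyntaxError):
        return None
    # Reject anything that is not a flat sequence of `expected_dims` numbers
    if not isinstance(value, (list, tuple)) or len(value) != expected_dims:
        return None
    if not all(isinstance(v, (int, float)) for v in value):
        return None
    return [float(v) for v in value]


def parse_row(row, num_joints=24):
    """Parse joint columns '0'..'23' of one frame; return None if any joint is malformed."""
    joints = []
    for j in range(num_joints):
        coords = parse_joint(row[str(j)])
        if coords is None:
            return None  # drop the whole frame rather than keep a partial skeleton
        joints.append(coords)
    return joints


row = {str(j): "[0.1, 0.2, 0.3]" for j in range(24)}
bad_row = dict(row, **{"5": "[0.1, 0.2"})  # truncated string for joint 5
print(len(parse_row(row)), parse_row(bad_row))  # 24 None
```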

In [None]:
import ast
import os
from itertools import permutations

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data

def transform_to_graphs(csv_path, meta_data_path):
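    """
    Transforms time-series joint data from a CSV file into a sequence of
    PyTorch Geometric graphs, one per frame.

    Args:
        csv_path (str): Path to the skeletal CSV file (one row per frame,
            joint columns '0' to '23' holding "[x, y, z]" strings).
        meta_data_path (str): Path to the metadata CSV mapping `video_id`
            to `subject_type` (0 = control, 1 = autism).

    Returns:
        list | str: A list of torch_geometric.data.Data objects, one per
            successfully parsed frame, or an error message string if the
            file, the metadata, or the `video_id` cannot be resolved.
    """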
    # --- 1. Validate paths and read data ---
    if not os.path.exists(csv_path):
        return f"Error: The file '{csv_path}' was not found."

    try:
        df = pd.read_csv(csv_path)
        df_metadata = pd.read_csv(meta_data_path)
    except Exception as e:
        return f"Error reading the CSV file: {e}"

    # --- 2. Extract video_id and get the corresponding label ---
    if 'video_id' not in df.columns:
        return "Error: 'video_id' column not found in the input CSV."

    # Assume the video_id is the same for all rows in the file
    video_id = df['video_id'].iloc[0]

    # Find the matching row in the metadata DataFrame
    meta_row = df_metadata[df_metadata['video_id'] == video_id]

    if meta_row.empty:
        return f"Error: video_id '{video_id}' not found in '{meta_data_path}'."

    # Get the subject_type
    subject_type = meta_row['subject_type'].iloc[0]
    y_value = subject_type

    # The label `y` is a graph-level target
    y = torch.tensor([y_value], dtype=torch.long)

    # --- 3. Process frames and create graphs ---
    graph_collection = []
    num_nodes = 24  # 24 joints

    # Create the edge_index for a fully connected graph
    # This can be created once and reused for all graphs since the number of nodes is constant.
    perm = torch.tensor(list(permutations(range(num_nodes), 2)), dtype=torch.long)
    edge_index = perm.t().contiguous()

    # Iterate over each row in the DataFrame (each row is a frame)
    for index, row in df.iterrows():
        # Extract joint coordinates (columns '0' to '23')
        # The joint coordinates are stored as strings, so we need to parse them.
        node_features = []
        for i in range(num_nodes):
            joint_str = row[str(i)]
            try:
                # Safely evaluate the string to a list of coordinates
                joint_coords = ast.literal_eval(joint_str)
                node_features.append(joint_coords)
            except (ValueError, SyntaxError):
                # Stop parsing this frame; the length check below then skips the whole row
                print(f"Warning: Could not parse joint {i} in row {index} of video {video_id}. Skipping row.")
                break

        # Ensure we have the correct number of features before creating a tensor
        if len(node_features) != num_nodes:
            continue

        # Convert the list of node features to a PyTorch tensor
        x = torch.tensor(node_features, dtype=torch.float)

        # Get the frame number to store as a graph-level attribute
        frame = row['Frame'] if 'Frame' in row else index

        # Create a PyTorch Geometric Data object
        data = Data(x=x, edge_index=edge_index, y=y, frame=frame)

        # Add the graph to our collection
        graph_collection.append(data)

    return graph_collection

def create_sliding_windows(graph_sequence, window_size, step):
    """
    Takes a sequence of graphs and creates overlapping windows.

    Args:
        graph_sequence (list): A list of torch_geometric.data.Data objects.
        window_size (int): The number of graphs in each window (timesteps).
        step (int): The stride or step size between windows.

    Returns:
        list: A list of windows, where each window is a list of graphs.
    """
    windows = []
    for i in range(0, len(graph_sequence) - window_size + 1, step):
        window = graph_sequence[i: i + window_size]
        windows.append(window)
    return windows

def trim_by_ending_frames(graph_sequence, num_frames_to_keep):
    """
    Trims a sequence of graphs to keep only the final N frames.

    This is based on the assumption that the video is cut shortly after
    the main action is completed, making the end frames the most relevant.

    Args:
        graph_sequence (list): A list of torch_geometric.data.Data objects.
        num_frames_to_keep (int): The number of frames to keep from the end.

    Returns:
        list: The trimmed list of graphs containing up to the last N frames.
    """
    # Get the total number of frames in the sequence
    total_frames = len(graph_sequence)

    # If the video is already shorter than or equal to the desired number of frames,
    # we don't need to do anything. Just return the whole sequence.
    if total_frames <= num_frames_to_keep:
        return graph_sequence

    # Otherwise, calculate the starting index for the slice.
    # For example, if we have 200 frames and want to keep 150,
    # the start index will be 200 - 150 = 50. The slice will be from 50 to the end.
    start_index = total_frames - num_frames_to_keep

    # Return the slice from the calculated start index to the end of the list.
    return graph_sequence[start_index:]
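The `edge_index` built in `transform_to_graphs` from `permutations(range(num_nodes), 2)` can be reproduced in plain Python to check its size: a complete directed graph on $N$ nodes without self-loops has $N(N-1)$ edges, i.e. 552 for $N=24$, with both directions of every joint pair present. The helper name below is hypothetical; its two-row `[sources, targets]` layout mirrors `perm.t()`.

```python
from itertools import permutations


def fully_connected_edges(num_nodes):
    """Return [sources, targets] for a complete directed graph without self-loops."""
    pairs = list(permutations(range(num_nodes), 2))  # every ordered pair (i, j) with i != j
    sources = [i for i, _ in pairs]
    targets = [j for _, j in pairs]
    return [sources, targets]


edge_index = fully_connected_edges(24)
num_edges = len(edge_index[0])
print(num_edges)  # 552 = 24 * 23
```

Because every joint is linked to every other joint, spatial structure (e.g. which joints share a bone) is left for the model to learn rather than encoded in the graph.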

## 2. Custom STGCN Dataset Class
We define a custom `torch.utils.data.Dataset` that wraps the preprocessing logic. This class:
1. Iterates through a list of file paths.
2. Applies the graph transformation and trimming, then segments each sequence into sliding windows.
3. Stacks the graph snapshots of each window into a 3D tensor of shape `[Time, Nodes, Features]`.
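`__getitem__` stacks the per-frame feature matrices (`graph.x`, shape `[24, 3]`) along a new last axis and then permutes, which yields a `[Time, Nodes, Features]` tensor. The sketch below uses NumPy arrays as stand-ins for the `torch` tensors and random placeholder data to check that this is identical to stacking along a new first axis.

```python
import numpy as np

rng = np.random.default_rng(0)
# 60 frames, each a (24 joints x 3 coordinates) feature matrix, like graph.x
frames = [rng.standard_normal((24, 3)) for _ in range(60)]

# Stack on a new last axis -> [N, F, T], then permute to [T, N, F]
x_nft = np.stack(frames, axis=2)
x_tnf = np.transpose(x_nft, (2, 0, 1))

# Equivalent and more direct: stack on a new first axis -> [T, N, F]
x_direct = np.stack(frames, axis=0)

print(x_nft.shape, x_tnf.shape)          # (24, 3, 60) (60, 24, 3)
print(np.array_equal(x_tnf, x_direct))   # True
```

In PyTorch the same identity holds for `torch.stack(..., dim=2).permute(2, 0, 1)` versus `torch.stack(..., dim=0)`.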

In [None]:
import os

import torch
from torch.utils.data import Dataset, DataLoader

class STGCN_Dataset(Dataset):
    """
    A PyTorch Dataset for loading spatio-temporal graph windows.

    Args:
        file_paths (list[str]): Paths to the skeletal CSV files to include.
        meta_data_path (str): Path to the metadata CSV file.
        window_size (int): The number of time steps (frames) in each window.
        step (int): The stride between consecutive windows.
    """
    def __init__(self, file_paths, meta_data_path, window_size, step):
        super(STGCN_Dataset, self).__init__()
        self.window_size = window_size
        self.step = step
        self.samples = []
        self._edge_index = None

        print("Processing data files...")

        # Iterate over the provided list of CSV file paths
        for filepath in file_paths:
            # Use the file stem (e.g. "1001_BP") as the sample identifier
            video_id = os.path.splitext(os.path.basename(filepath))[0]

            # 1. Transform the entire CSV into a sequence of graphs
            graph_sequence = transform_to_graphs(filepath, meta_data_path)

            # transform_to_graphs returns an error string on failure; skip it before slicing
            if not isinstance(graph_sequence, list):
                print(f"  - WARNING: Skipping file '{os.path.basename(filepath)}'. Reason: {graph_sequence}")
                continue

            # 2. Keep only the final 120 frames (approx. 4 seconds)
            graph_sequence = trim_by_ending_frames(graph_sequence, 120)

            if len(graph_sequence) >= self.window_size:
                # 3. Create sliding windows from the graph sequence
                windows = create_sliding_windows(graph_sequence, self.window_size, self.step)
                for window in windows:
                    self.samples.append({'window': window, 'video_id': video_id})

                # Store the edge_index (it is the same for every graph)
                if self._edge_index is None and graph_sequence:
                    self._edge_index = graph_sequence[0].edge_index
            else:
                print(f"  - WARNING: Skipping file '{os.path.basename(filepath)}'. "
                      f"Reason: only {len(graph_sequence)} valid frames, need {self.window_size}.")

        print(f"Finished processing. Found {len(self.samples)} total training windows for this dataset.")

    def __len__(self):
        """Returns the total number of samples (windows)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Generates one sample of data.

        Returns:
            tuple: A tuple containing:
                - X (torch.Tensor): Node features for the window, with shape
                                    [window_size, num_nodes, num_features].
                - y (torch.Tensor): The label for the window, with shape [1].
                - video_id (str): The source file stem, e.g. "1001_BP".
        """
        # Get the window (a list of Data objects)
        sample = self.samples[idx]
        window = sample['window']
        video_id = sample['video_id']

        # Stack node features along a new leading dimension to create the [T, N, F] tensor
        # T=timesteps (window_size), N=num_nodes, F=num_features
        X = torch.stack([graph.x for graph in window], dim=0)

        # The label is the same for all graphs in the window
        y = window[0].y

        return X, y, video_id

    def get_edge_index(self):
        """Helper function to get the constant edge_index."""
        return self._edge_index

## 3. Data Loading & Configuration

**Note regarding data privacy:** The raw skeletal data files are not included in this repository. The paths below point to the synthetic files created by `generate_dummy_dataset()`. To run this on real data, populate `control_train_files`, `control_val_files`, and `autism_test_files` with the paths to your MediaPipe `.csv` exports.

**Hyperparameters:**
* **Window Size:** 60 frames (2 seconds)
* **Step Size:** 15 frames (0.5 seconds)
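Both hyperparameters are set in frames, so their durations depend on the capture frame rate. The sketch below assumes 30 fps, which is what the overview's "120 frames (approx. 4 seconds)" implies; `FPS` is a hypothetical constant, not a value read from the data.

```python
FPS = 30  # assumption: implied by "120 frames (approx. 4 seconds)"


def frames_to_seconds(num_frames, fps=FPS):
    """Convert a frame count to seconds at the given frame rate."""
    return num_frames / fps


window_size, step = 60, 15
print(frames_to_seconds(window_size))  # 2.0
print(frames_to_seconds(step))         # 0.5

# Fraction of frames shared by consecutive windows
overlap = 1 - step / window_size
print(overlap, window_size - step)     # 0.75 45
```

A 75% overlap means each frame appears in up to four windows, which multiplies the number of training samples at the cost of strong correlation between neighbouring windows.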

In [None]:
# --- Configuration ---
META_PATH = '/content/sample_data/meta_data.csv'  # Written by generate_dummy_dataset(); replace with your real metadata file
WINDOW_SIZE = 60
STEP = 15

# --- Placeholder File Lists ---
# In a real scenario, you would populate these lists using glob or os.listdir
# e.g., control_train_files = glob.glob("data/control_train/*.csv")

control_train_files = ['/content/sample_data/1001_BP.csv']  # control subject (subject_type = 0)
control_val_files = ['/content/sample_data/1002_BP.csv']    # control subject (subject_type = 0)
autism_test_files = ['/content/sample_data/1003_BP.csv', '/content/sample_data/1004_BP.csv']  # ASD subjects (subject_type = 1)

print(f"Configuration Set: Window={WINDOW_SIZE}, Step={STEP}")

Configuration Set: Window=60, Step=15
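The configuration comment suggests populating the file lists with `glob`. A minimal sketch, assuming the `{video_id}_{action}.csv` naming and the `meta_data.csv` columns produced by `generate_dummy_dataset`, that groups files by `subject_type` instead of listing them by hand. `load_labels` and `split_by_label` are hypothetical helpers; dividing the control files into training and validation sets remains a separate, subject-level decision.

```python
import csv
import os


def load_labels(meta_path):
    """Map video_id (as a string) -> subject_type (0 = control, 1 = autism) from meta_data.csv."""
    with open(meta_path, newline="") as f:
        return {row["video_id"]: int(row["subject_type"]) for row in csv.DictReader(f)}


def split_by_label(paths, labels):
    """Group '{video_id}_{action}.csv' paths by subject_type; files without metadata are skipped."""
    groups = {}
    for path in sorted(paths):
        video_id = os.path.basename(path).split("_")[0]
        if video_id in labels:
            groups.setdefault(labels[video_id], []).append(path)
    return groups
```

For example, `split_by_label(glob.glob('/content/sample_data/*_BP.csv'), load_labels(META_PATH))` (after `import glob`) would place `1001_BP.csv` and `1002_BP.csv` under key `0` and the other two synthetic files under key `1`.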


## 4. Instantiating Datasets
Here each file list is processed into a dataset of sliding windows: the control training and validation sets and the ASD test set.

In [None]:
# Only run this if the training file list has been populated
if len(control_train_files) > 0:
    print("--- Processing Control Training Set ---")
    control_dataset_train = STGCN_Dataset(
        file_paths=control_train_files,
        meta_data_path=META_PATH,
        window_size=WINDOW_SIZE,
        step=STEP
    )

    print("\n--- Processing Control Validation Set ---")
    control_dataset_val = STGCN_Dataset(
        file_paths=control_val_files,
        meta_data_path=META_PATH,
        window_size=WINDOW_SIZE,
        step=STEP
    )

    print("\n--- Processing ASD Test Set ---")
    autism_dataset_test = STGCN_Dataset(
        file_paths=autism_test_files,
        meta_data_path=META_PATH,
        window_size=WINDOW_SIZE,
        step=STEP
    )
else:
    print("No data files found. Please populate 'control_train_files' list.")

--- Processing Control Training Set ---
Processing data files...
Finished processing. Found 5 total training windows for this dataset.

--- Processing Control Validation Set ---
Processing data files...
Finished processing. Found 5 total training windows for this dataset.

--- Processing ASD Test Set ---
Processing data files...
Finished processing. Found 10 total training windows for this dataset.


## 5. Shape Verification
Before training, verify that the output tensors match the input shape expected by the STGCN-AE: `[Time, Nodes, Features]`.

In [None]:
if 'control_dataset_train' in locals() and len(control_dataset_train) > 0:
    sample_x, sample_y, video_id = control_dataset_train[0]

    print(f"Video ID: {video_id}")
    print(f"Window Shape: {sample_x.shape}")
    # Expected: [60, 24, 3] (Time, Nodes, Features)
    # Note: a DataLoader will prepend the batch dimension, e.g. [batch_size, 60, 24, 3].

Video ID: 1001_BP
Window Shape: torch.Size([60, 24, 3])
