In [1]:
## Imports and seed
import os
import glob
import random
import copy
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Seed the PyTorch, NumPy and Python RNGs (in that order) with one shared value
# so that notebook runs are repeatable.
SEED = 22
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)

ModuleNotFoundError: No module named 'numpy'

In [None]:
## Custom Dataset for Right-Leg Gait Phase Estimation (CSV version)
class GaitPhaseDataset(Dataset):
    """Right-leg gait-phase dataset built from per-trial treadmill CSV files.

    One item per IMU trial file. ``__getitem__`` draws a random window of
    ``sequence_length`` samples from the shank and thigh IMU channels and
    pairs it with the right-leg HeelStrike value (scaled from 0-100 to 0-1)
    at the centre of that window.
    """

    def __init__(self, root_dir, sequence_length=128, subjects=None, transform=None):
        """
        Args:
            root_dir (str): Root directory of the dataset. IMU files are expected under
                <root_dir>/<subject>/.../treadmill/imu/*.csv (zero or more intermediate
                directories, e.g. a session folder, are accepted) and the matching gait
                cycle files, with the same file name, under the sibling
                .../treadmill/gcRight/ directory.
            sequence_length (int): Number of time steps in each sample window (>= 1).
            subjects (list of str or None): List of subject IDs (e.g., ['AB09', 'AB10']).
                If None, all subjects are used.
            transform (callable, optional): Optional transform applied to the IMU window
                (a NumPy array). Without it, the window is converted to a float32 tensor.

        Raises:
            ValueError: if sequence_length < 1.
            RuntimeError: if no IMU files are found.
        """
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.root_dir = root_dir
        self.sequence_length = sequence_length
        self.transform = transform

        # Escape so that glob metacharacters in the path ('[', '*', '?') are literal.
        escaped_root = glob.escape(root_dir)
        self.imu_files = self._glob_imu(os.path.join(escaped_root, '**'))

        if subjects is not None:
            print(f"Before filtering - Number of files: {len(self.imu_files)}")
            filtered_files = set()
            for subject in subjects:
                print(f"Filtering for subject: {subject}")
                subject_root = os.path.join(escaped_root, glob.escape(subject), '**')
                filtered_files.update(self._glob_imu(subject_root))
            self.imu_files = sorted(filtered_files)
            print(f"After filtering - Number of files: {len(self.imu_files)}")
            print("Sample paths:")
            for f in self.imu_files[:2]:  # Print first 2 paths as examples
                print(f"- {f}")
                # The subject ID is the first path component below root_dir.
                subject_id = os.path.relpath(f, root_dir).split(os.sep)[0]
                print(f"  Subject: {subject_id}")

        if not self.imu_files:
            raise RuntimeError("No IMU files found. Please check your dataset directory and folder structure.")

    @staticmethod
    def _glob_imu(prefix):
        """Return the sorted, de-duplicated IMU CSVs matching <prefix>/treadmill/imu/*.csv."""
        # With recursive=True, a '**' component in prefix matches zero or more
        # directories, so both <subject>/treadmill and <subject>/<session>/treadmill work.
        pattern = os.path.join(prefix, 'treadmill', 'imu', '*.csv')
        return sorted(set(glob.glob(pattern, recursive=True)))

    @staticmethod
    def _gc_right_path(imu_path):
        """Return the gcRight CSV path for an IMU CSV path (same file name, sibling folder)."""
        # Build the path from the IMU file's parent directories instead of replacing
        # substrings, so an 'imu' directory elsewhere in root_dir is left untouched.
        treadmill_dir = os.path.dirname(os.path.dirname(imu_path))
        return os.path.join(treadmill_dir, 'gcRight', os.path.basename(imu_path))

    def __len__(self):
        return len(self.imu_files)

    def __getitem__(self, idx):
        """Return (imu_window, target) for a random window of trial ``idx``.

        Raises:
            ValueError: if the trial has too few columns, or fewer synchronized
                samples than ``sequence_length``.
        """
        imu_path = self.imu_files[idx]
        gc_right_path = self._gc_right_path(imu_path)

        # Load CSV files (header row skipped) and drop the timestamp column (first column).
        imu_data = self._load_csv_file(imu_path)[:, 1:]
        gc_right_data = self._load_csv_file(gc_right_path)[:, 1:]

        # CSV column order (after dropping timestamp) is:
        # [foot_Accel_X, foot_Accel_Y, foot_Accel_Z,
        #  foot_Gyro_X, foot_Gyro_Y, foot_Gyro_Z,
        #  shank_Accel_X, shank_Accel_Y, shank_Accel_Z,
        #  shank_Gyro_X, shank_Gyro_Y, shank_Gyro_Z,
        #  thigh_Accel_X, thigh_Accel_Y, thigh_Accel_Z,
        #  thigh_Gyro_X, thigh_Gyro_Y, thigh_Gyro_Z,
        #  trunk_Accel_X, trunk_Accel_Y, trunk_Accel_Z,
        #  trunk_Gyro_X, trunk_Gyro_Y, trunk_Gyro_Z]
        # We keep shank (columns 6 to 11) and thigh (columns 12 to 17).
        if imu_data.shape[1] < 18:
            raise ValueError(f"{imu_path}: expected at least 18 IMU channels, got {imu_data.shape[1]}")
        if gc_right_data.shape[1] < 1:
            raise ValueError(f"{gc_right_path}: no gait-cycle columns after the timestamp")
        imu_selected = np.concatenate([imu_data[:, 6:12], imu_data[:, 12:18]], axis=1)  # (N, 12)

        # Synchronize lengths: truncate both signals to the shorter one.
        min_length = min(imu_selected.shape[0], gc_right_data.shape[0])
        if min_length < self.sequence_length:
            raise ValueError(
                f"{imu_path}: only {min_length} synchronized samples, "
                f"fewer than sequence_length={self.sequence_length}"
            )
        imu_selected = imu_selected[:min_length]
        gc_right_data = gc_right_data[:min_length]

        # Random crop: any start that keeps the whole window inside the trial (inclusive bounds).
        start_idx = random.randint(0, min_length - self.sequence_length)
        end_idx = start_idx + self.sequence_length
        imu_window = imu_selected[start_idx:end_idx, :]  # (sequence_length, 12)

        # Use the HeelStrike value (0-100) from gcRight at the centre of the window,
        # normalized to [0, 1].
        center_idx = start_idx + self.sequence_length // 2
        heel_strike = gc_right_data[center_idx, 0]
        target = np.array([heel_strike / 100.0], dtype=np.float32)

        # Optionally apply a transform; otherwise, convert to a torch tensor.
        if self.transform is not None:
            imu_window = self.transform(imu_window)
        else:
            imu_window = torch.tensor(imu_window, dtype=torch.float32)
        return imu_window, torch.tensor(target, dtype=torch.float32)

    def _load_csv_file(self, file_path):
        """Load a numeric CSV file with NumPy, skipping the header row.

        ndmin=2 keeps single-row files two-dimensional so that column slicing works.
        """
        return np.loadtxt(file_path, delimiter=',', skiprows=1, ndmin=2)