### End-to-end SSL stepcounter training with OxWalk

This notebook trains the stepcounter that was used in:

Small SR, Chan S, Walmsley R, et al. (2023)
[Development and Validation of a Machine Learning Wrist-worn Step Detection Algorithm with Deployment in the UK Biobank](https://www.medrxiv.org/content/10.1101/2023.02.20.23285750v1).
medRxiv. DOI: 10.1101/2023.02.20.23285750

In [1]:
import re
import glob
import os
import numpy as np
import pandas as pd
import pathlib
import urllib
import shutil
import zipfile
import torch
import joblib
from tqdm import tqdm
from joblib import Parallel, delayed

from stepcount.models import StepCounter

#### Download OxWalk dataset

In [2]:
# `import urllib` alone does not guarantee that the `request` submodule is
# loaded, so import it explicitly before calling urllib.request.urlopen.
import urllib.request

url = "https://ora.ox.ac.uk/objects/uuid:19d3cb34-e2b3-4177-91b6-1bad0e0163e7/files/dcj82k7829"
zip_path = "OxWalk_Dec2022.zip"

if os.path.exists(zip_path):
    # Avoid re-downloading the archive every time the notebook is re-run.
    print(f"{zip_path} already present, skipping download.")
else:
    print("Downloading OxWalk...")
    tmp_path = zip_path + ".part"
    with urllib.request.urlopen(url) as f_src, open(tmp_path, "wb") as f_dst:
        shutil.copyfileobj(f_src, f_dst)
    # Only expose the archive under its final name once it is fully written,
    # so an interrupted download is not mistaken for a complete one.
    os.replace(tmp_path, zip_path)

print("Unzipping...")
with zipfile.ZipFile(zip_path, "r") as f:
    f.extractall(".")

Downloading OxWalk...
Unzipping...


Definitions and helper functions to process the dataset.

In [3]:
# --- Input / output locations ---
DATA_DIR = './OxWalk_Dec2022/Wrist_100Hz'  # folder holding the OxWalk wrist .csv files
DATAFILES = os.path.join(DATA_DIR, 'P*.csv')  # glob pattern for the per-participant CSV files
OUT_DIR = './data/oxwalk_30hz_w10_o0/'  # where the processed arrays are written

# --- Windowing and labelling ---
WINDOW_SEC = 10  # window length (seconds)
STEP_THRESHOLD = 4  # a window with this many annotated steps or more is labelled 'walk'

# --- Sampling ---
DEVICE_HZ = 100  # sample rate of the raw recordings (Hz)
RESAMPLE_HZ = 30  # sample rate the windows are resampled to (Hz)
WINDOW_OVERLAP_SEC = 0  # overlap between consecutive windows (seconds)

# Derived window sizes, all expressed in raw-device samples.
WINDOW_LEN = int(WINDOW_SEC * DEVICE_HZ)  # samples per window
WINDOW_OVERLAP_LEN = int(WINDOW_OVERLAP_SEC * DEVICE_HZ)  # samples shared by consecutive windows
WINDOW_STEP_LEN = WINDOW_LEN - WINDOW_OVERLAP_LEN  # stride between window starts
WINDOW_TOL = 0.01  # max relative deviation of a window's duration from WINDOW_SEC (1%)


def resize(x, length, axis=1):
    """Linearly interpolate `x` to a new temporal length.

    The temporal axis is mapped onto [0, 1] both before and after resizing, so
    the first and last samples are preserved and the rest are interpolated.

    Args:
        x: array of shape (N, M, C) (channels last; use axis=1, the default)
            or (N, C, M) (channels first; use axis=2). N is the batch size,
            M the temporal length and C the number of channels.
        length: number of samples wanted along `axis`.
        axis: index of the temporal axis.

    Returns:
        Array like `x`, but with `length` samples along `axis`.
    """
    from scipy.interpolate import interp1d

    src_grid = np.linspace(0, 1, num=x.shape[axis], endpoint=True)
    dst_grid = np.linspace(0, 1, num=length, endpoint=True)

    interpolator = interp1d(
        src_grid, x, kind="linear", axis=axis, assume_sorted=True
    )
    return interpolator(dst_grid)


def is_good_quality(w):
    """Decide whether a window is clean enough to keep.

    A window is rejected when any of these holds:
      * it contains a missing value;
      * it does not have exactly WINDOW_LEN rows;
      * the time span between its first and last index entries differs from
        WINDOW_SEC by more than WINDOW_TOL (relative).

    Returns True when none of the checks fail, False otherwise.
    """
    # Cheap structural checks first; the length check also guarantees the
    # index lookups below are safe.
    if w.isna().any().any() or len(w) != WINDOW_LEN:
        return False

    expected_span = pd.Timedelta(WINDOW_SEC, 's')
    actual_span = w.index[-1] - w.index[0]
    deviation = np.abs(actual_span - expected_span)

    if deviation > WINDOW_TOL * expected_span:
        return False
    return True


def make(datafile):
    """Cut one OxWalk recording into fixed-length, labelled windows.

    Parameters
    ----------
    datafile : str
        Path to a participant CSV file. It is read with its ``timestamp``
        column parsed as the index and must contain ``x``, ``y``, ``z`` and
        ``annotation`` columns. The participant id (``P`` followed by two
        digits, case-insensitive) is taken from the file name.

    Returns
    -------
    X : np.ndarray
        Accelerometer windows of shape (n_windows, n_samples, 3). They are
        resampled to RESAMPLE_HZ * WINDOW_SEC samples when RESAMPLE_HZ differs
        from DEVICE_HZ; otherwise each has WINDOW_LEN samples.
    Y : np.ndarray
        'walk' if the window's step count is >= STEP_THRESHOLD, else 'notwalk'.
    T : np.ndarray
        Timestamp of the first sample of each window.
    P : np.ndarray
        Participant id for each window (the same value repeated).
    y_step : np.ndarray
        Per-window sum of the ``annotation`` column, i.e. the step count.

    If the file has no window that passes ``is_good_quality``, all five arrays
    are empty, but still correctly shaped and typed for np.vstack/np.hstack.

    Raises
    ------
    ValueError
        If no participant id can be found in the file name.
    """
    # Match on the file name only, so directory names can never be picked up.
    match = re.search(r'(P\d{2})', os.path.basename(datafile), flags=re.IGNORECASE)
    if match is None:
        raise ValueError(f"Could not find a participant id (P<NN>) in {datafile!r}")
    p = match.group()

    data = pd.read_csv(datafile, parse_dates=['timestamp'], index_col='timestamp')

    X, Y, T, P, y_step = [], [], [], [], []

    for i in range(0, len(data), WINDOW_STEP_LEN):
        w = data.iloc[i:i + WINDOW_LEN]

        # Drops windows with gaps/NaNs and the trailing partial window.
        if not is_good_quality(w):
            continue

        count = int(w['annotation'].sum())

        X.append(w[['x', 'y', 'z']].values)
        Y.append('walk' if count >= STEP_THRESHOLD else 'notwalk')
        T.append(w.index[0].to_datetime64())
        P.append(p)
        y_step.append(count)

    resampling = DEVICE_HZ != RESAMPLE_HZ
    n_samples = int(RESAMPLE_HZ * WINDOW_SEC) if resampling else WINDOW_LEN

    if not X:
        # np.asarray([]) would be 1-D and break resize(..., axis=1) as well as
        # the caller's np.vstack, so return properly shaped empties instead.
        return (
            np.empty((0, n_samples, 3)),
            np.empty(0, dtype=str),
            np.empty(0, dtype='datetime64[ns]'),
            np.empty(0, dtype=str),
            np.empty(0, dtype=int),
        )

    X = np.asarray(X)
    Y = np.asarray(Y)
    T = np.asarray(T)
    P = np.asarray(P)
    y_step = np.asarray(y_step)

    if resampling:
        X = resize(X, n_samples)

    return X, Y, T, P, y_step

### Process dataset

Convert the raw 100 Hz data into non-overlapping 10 s windows `X` resampled to 30 Hz. Each window gets a step count `Y_step` and a participant group label `groups`. A binary walk/notwalk label array `Y` is also generated; it is for reference only and is not used in this notebook. The arrays are saved to disk for later reuse, so this step can be skipped if it has already been run.

In [7]:
pathlib.Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

# glob order is filesystem-dependent; sort so the window order (and hence the
# saved arrays) is the same on every machine. Parallel keeps input order.
datafiles = sorted(glob.glob(DATAFILES))
if not datafiles:
    # Without this, zip(*[]) below fails with an obscure unpacking error.
    raise FileNotFoundError(
        f"No files match {DATAFILES!r}; has OxWalk been downloaded and unzipped?"
    )

x, y, t, p, y_step = zip(
    *Parallel(n_jobs=4)(
        delayed(make)(datafile)
        for datafile in tqdm(datafiles)
    )
)

X = np.vstack(x)  # data windows
Y = np.hstack(y)  # binary labels walk/notwalk (based on Y_step >= STEP_THRESHOLD)
Y_step = np.hstack(y_step)  # step counts per window
T = np.hstack(t)  # timestamps
groups = np.hstack(p)  # group ids

assert len(X) == len(Y) == len(Y_step) == len(T) == len(groups), "misaligned outputs"

np.save(os.path.join(OUT_DIR, 'X'), X)
np.save(os.path.join(OUT_DIR, 'Y'), Y)
np.save(os.path.join(OUT_DIR, 'time'), T)
np.save(os.path.join(OUT_DIR, 'groups'), groups)
np.save(os.path.join(OUT_DIR, 'Y_step'), Y_step)

print(f"Saved in {OUT_DIR}")
print("X shape:", X.shape)
print("Y distribution:")
print(pd.Series(Y).value_counts())

100%|██████████████████████████████████████████████████████████████████████████████████| 39/39 [00:06<00:00,  5.59it/s]


Saved in ./data/oxwalk_30hz_w10_o0/
X shape: (13613, 300, 3)
Y distribution:
notwalk    9468
walk       4145
dtype: int64


### Train stepcounter

In [5]:
# Reuse OUT_DIR rather than repeating the literal path, so the load location
# cannot silently drift from where the processing cell saved the arrays.
path = OUT_DIR

X = np.load(os.path.join(path, 'X.npy'))
Y_step = np.load(os.path.join(path, 'Y_step.npy'))
groups = np.load(os.path.join(path, 'groups.npy'))

assert len(X) == len(Y_step) == len(groups), "loaded arrays have mismatched lengths"

In [None]:
# Seed the RNGs used by numpy and torch (torch.manual_seed also seeds CUDA) so
# training is repeatable. GPU kernels may still be nondeterministic, so exact
# bit-for-bit reproducibility on CUDA is not guaranteed.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

wd_params = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'batch_size': 512
}

step = StepCounter(wd_type='ssl', wd_params=wd_params,
                   steptol=STEP_THRESHOLD, verbose=True)

step.fit(X, Y_step, groups)  # training

# Adjust settings for end users before saving.
step.verbose = False  # don't log spam the user
step.wd.verbose = False
step.wd.batch_size = 64  # don't blow up user's memory
# Intended to save a device-less (cpu) model; whether this attribute alone
# moves the weights depends on StepCounter internals — TODO confirm.
step.wd.device = 'cpu'

# save trained model to disk
joblib.dump(step, 'ssl.joblib.lzma', compress=('lzma', 3))

### Usage
The trained stepcounter can now be used in a Python script:

In [3]:
model = joblib.load('ssl.joblib.lzma')

# Example data: 2000 windows, 10s at 30Hz, 3-axis (seeded random noise, not
# real movement). Uses its own name so the training `X` in the kernel is not
# overwritten; otherwise re-running the training cell would fit on noise.
rng = np.random.default_rng(0)
X_demo = rng.random((2000, 300, 3))

y_pred = model.predict(X_demo)

# y_pred is an array of step count values for each of the 2000 windows
print(y_pred.shape, y_pred)

# Optionally run the prediction on a GPU (CPU is the default). Guarded so the
# cell still runs on machines without CUDA.
if torch.cuda.is_available():
    model.wd.device = 'cuda'
    y_pred = model.predict(X_demo)

(2000,) [0. 0. 0. ... 0. 0. 0.]


Or from the command line by passing the path to the joblib file:

`stepcount sample.cwa --model-path ssl.joblib.lzma`

With GPU:

`stepcount sample.cwa --model-path ssl.joblib.lzma --pytorch-device cuda`