# End-to-end training of the Actinet model using Capture-24

This notebook fine-tunes the final layers of a self-supervised ResNet-18 on the Capture-24 dataset, using the Walmsley label annotations.

In [1]:
import numpy as np
import joblib
import urllib.request
import os
from tqdm.auto import tqdm
import zipfile
from glob import glob
import sys
from datetime import datetime
import hashlib

sys.path.append("src")
from actinet.models import ActivityClassifier
from actinet.prepare import load_all_and_make_windows

WINSEC = 30 # seconds
SAMPLE_RATE = 100 # Hz
N_JOBS = 8  # Number of parallel workers: higher is faster, but do not exceed the number of available CPU cores.

First we download and unzip the Capture-24 dataset. 

In [2]:
os.makedirs("data", exist_ok=True)
capture24_path = "data/capture24.zip"

# Download the archive (about 6.9 GB) only if it is not already present
if not os.path.exists(capture24_path):
    url = "https://ora.ox.ac.uk/objects/uuid:99d7c092-d865-4a19-b096-cc16440cd001/files/rpr76f381b"

    with tqdm(unit='B', unit_scale=True, desc='Downloading Capture-24: ',
              unit_divisor=1024, miniters=1, ascii=True, total=6900000000) as pbar:
        urllib.request.urlretrieve(url, filename=capture24_path,
                                   reporthook=lambda b, bsize, tsize: pbar.update(bsize))

# Extract separately, so an archive from an earlier download is still unzipped
if not os.path.exists("data/capture24"):
    with zipfile.ZipFile(capture24_path, "r") as f:
        for member in tqdm(f.namelist(), desc="Unzipping: "):
            try:
                f.extract(member, "data")
            except zipfile.BadZipFile:
                pass  # Skip members that cannot be extracted
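The progress bar above uses a hard-coded total and advances by a full block on every callback, so it can drift from the real download size. A minimal sketch of a more accurate hook, adapted from the `update_to` pattern in tqdm's documentation, is below. It assumes only that the progress object exposes `total`, `n` and `update()` as a `tqdm` bar does; `make_reporthook` is a hypothetical helper name.

```python
def make_reporthook(pbar):
    """Build a urllib.request.urlretrieve reporthook that drives `pbar`.

    `pbar` is assumed to behave like a tqdm bar: it has a `total`
    attribute, a running count `n`, and an `update(delta)` method.
    """
    def hook(block_num, block_size, total_size):
        # urlretrieve passes total_size = -1 when the server sends no Content-Length
        if total_size > 0:
            pbar.total = total_size
        downloaded = block_num * block_size
        if total_size > 0:
            # The last block is usually partial, so never report more than the file size
            downloaded = min(downloaded, total_size)
        # Advance by the bytes received since the previous callback
        pbar.update(downloaded - pbar.n)

    return hook


# Hypothetical usage inside the download cell:
# with tqdm(unit="B", unit_scale=True, unit_divisor=1024) as pbar:
#     urllib.request.urlretrieve(url, filename=capture24_path,
#                                reporthook=make_reporthook(pbar))
```

Reading the size from the server avoids hard-coding the archive size, which would otherwise need updating if the hosted file ever changes.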

We then split the data into windows of WINSEC seconds, each paired with its label from the chosen annotation scheme.

In [3]:
DATAFILES = "data/capture24/P[0-9][0-9][0-9].csv.gz"
ANNOFILE = "data/capture24/annotation-label-dictionary.csv"
SAVEFOLDER = f"data/capture24/{WINSEC}s"

if len(glob(f"{SAVEFOLDER}/*.npy")) == 4:
    X = np.load(f"{SAVEFOLDER}/X.npy")
    Y = np.load(f"{SAVEFOLDER}/Y.npy")
    T = np.load(f"{SAVEFOLDER}/T.npy")
    pid = np.load(f"{SAVEFOLDER}/pid.npy")

else:
    X, Y, T, pid = load_all_and_make_windows(
        datafiles=glob(DATAFILES), 
        annofile=ANNOFILE, 
        out_dir=SAVEFOLDER, 
        anno_label="Walmsley2020", # Choose between WillettsSpecific2018, WillettsMET2018, DohertySpecific2018,...
        sample_rate=SAMPLE_RATE,
        winsec=WINSEC,
        n_jobs=N_JOBS,
    )
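load_all_and_make_windows handles this step, and its internals are not shown here. Purely as an illustration, the sketch below splits one continuous recording into non-overlapping windows of `winsec * sample_rate` samples (30 s × 100 Hz = 3000 samples), skips windows containing NaNs, and labels each kept window with its most frequent annotation. The helper name `make_labelled_windows` and the majority-vote rule are assumptions, not necessarily what actinet does.

```python
from collections import Counter

import numpy as np


def make_labelled_windows(xyz, labels, winsec=30, sample_rate=100):
    """Split a (n_samples, n_axes) signal into (n_windows, winsec*sample_rate, n_axes).

    Hypothetical helper: trailing samples that do not fill a whole window are
    dropped, windows containing NaNs are skipped, and each kept window gets
    the most common label among its samples.
    """
    win_len = winsec * sample_rate
    n_windows = len(xyz) // win_len
    X, Y = [], []
    for i in range(n_windows):
        start, stop = i * win_len, (i + 1) * win_len
        window = xyz[start:stop]
        if np.isnan(window).any():
            continue
        # Majority vote over the per-sample annotations in this window
        majority_label = Counter(labels[start:stop]).most_common(1)[0][0]
        X.append(window)
        Y.append(majority_label)
    X = np.asarray(X, dtype=float).reshape(-1, win_len, xyz.shape[1])
    return X, np.asarray(Y)
```

Non-overlapping windows keep every sample in exactly one window, which matches the one-label-per-window shape of X and Y used in the rest of the notebook.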

We now initialise the ActivityClassifier with the class labels, batch size and device to use for training.

In [4]:
classifier = ActivityClassifier(
    labels=np.unique(Y),
    batch_size=1000,
    device="cuda:0",
    verbose=True
)

print(classifier)

Activity Classifier
class_labels: ['light' 'moderate-vigorous' 'sedentary' 'sleep']
window_length: 30
batch_size: 1000
device: cuda:0
hmm: Hidden Markov Model
prior: None
emission: None
transition: None
labels: None
model: Model has not been loaded.
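np.unique returns the labels sorted alphabetically, which is why the classes are listed as light, moderate-vigorous, sedentary, sleep; the later `labels: [0 1 2 3]` output is consistent with integer codes following this order. Assuming the classifier encodes string labels by their position in this sorted array (the source does not state this explicitly), the mapping can be reproduced with `return_inverse=True`:

```python
import numpy as np

# A few string labels in arbitrary order (illustrative, not the real Y)
y_example = np.array(["sleep", "light", "sedentary", "light", "moderate-vigorous"])

# classes is sorted alphabetically; y_int[i] is the index of y_example[i] in classes
classes, y_int = np.unique(y_example, return_inverse=True)

print(classes)  # ['light' 'moderate-vigorous' 'sedentary' 'sleep']
print(y_int)    # [3 0 2 0 1]

# Integer predictions map back to strings by plain fancy indexing
print(classes[np.array([0, 3])])  # ['light' 'sleep']
```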


We call the fit method to train the classifier on the windowed data. Whenever the validation loss improves, the current weights are saved to the given path, so the file ends up holding the best weights found during training.

In [5]:
classifier.fit(X, Y, pid, T, "models/c24_rw.pt")

Training SSL
Using local c:\Users\AidenA\Documents\Projects\actinet\src\actinet\torch_hub_cache\OxWearables_ssl-wearables_v1.0.0
131 Weights loaded


100%|██████████| 253/253 [03:18<00:00,  1.28it/s]


Validation loss decreased (inf --> 0.469236). Saving model ...
[  0/100] | train_loss: 0.650 | train_acc: 0.836 | val_loss: 0.469 | val_acc: 0.83


100%|██████████| 253/253 [03:04<00:00,  1.37it/s]


Validation loss decreased (0.469236 --> 0.460606). Saving model ...
[  1/100] | train_loss: 0.381 | train_acc: 0.870 | val_loss: 0.461 | val_acc: 0.83


100%|██████████| 253/253 [03:04<00:00,  1.37it/s]


Validation loss decreased (0.460606 --> 0.437871). Saving model ...
[  2/100] | train_loss: 0.358 | train_acc: 0.876 | val_loss: 0.438 | val_acc: 0.84


100%|██████████| 253/253 [03:05<00:00,  1.37it/s]


Validation loss decreased (0.437871 --> 0.435937). Saving model ...
[  3/100] | train_loss: 0.322 | train_acc: 0.887 | val_loss: 0.436 | val_acc: 0.85


100%|██████████| 253/253 [03:05<00:00,  1.36it/s]


Validation loss decreased (0.435937 --> 0.423534). Saving model ...
[  4/100] | train_loss: 0.306 | train_acc: 0.891 | val_loss: 0.424 | val_acc: 0.85


100%|██████████| 253/253 [03:05<00:00,  1.36it/s]


EarlyStopping counter: 1/5
[  5/100] | train_loss: 0.294 | train_acc: 0.896 | val_loss: 0.432 | val_acc: 0.85


100%|██████████| 253/253 [03:05<00:00,  1.36it/s]


EarlyStopping counter: 2/5
[  6/100] | train_loss: 0.284 | train_acc: 0.899 | val_loss: 0.457 | val_acc: 0.84


100%|██████████| 253/253 [03:05<00:00,  1.37it/s]


EarlyStopping counter: 3/5
[  7/100] | train_loss: 0.273 | train_acc: 0.903 | val_loss: 0.445 | val_acc: 0.85


100%|██████████| 253/253 [03:05<00:00,  1.36it/s]


EarlyStopping counter: 4/5
[  8/100] | train_loss: 0.260 | train_acc: 0.908 | val_loss: 0.461 | val_acc: 0.84


100%|██████████| 253/253 [03:05<00:00,  1.36it/s]


EarlyStopping counter: 5/5
[  9/100] | train_loss: 0.254 | train_acc: 0.910 | val_loss: 0.483 | val_acc: 0.84
Early stopping
SSLNet weights saved to models/c24_rw.pt
Training HMM
Classifying windows...


100%|██████████| 67/67 [00:16<00:00,  4.15it/s]


<actinet.models.ActivityClassifier at 0x16860e28eb0>
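The log shows patience-based early stopping: weights are saved whenever the validation loss improves, and training stops after five consecutive epochs without improvement (here at epoch 9, keeping the weights from epoch 4). A minimal sketch of that logic follows; it is not actinet's own implementation, and `EarlyStopping` and `run_with_early_stopping` are hypothetical names.

```python
class EarlyStopping:
    """Patience-based early stopping on validation loss (hypothetical sketch)."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float("inf")
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if it is a new best."""
        if val_loss < self.best_loss:
            print(f"Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model ...")
            self.best_loss = val_loss
            self.counter = 0
            return True
        self.counter += 1
        print(f"EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.should_stop = True
        return False


def run_with_early_stopping(val_losses, patience=5):
    """Replay a sequence of validation losses; return (stopping epoch, best epoch)."""
    stopper = EarlyStopping(patience)
    epoch, best_epoch = -1, None
    for epoch, loss in enumerate(val_losses):
        if stopper.step(loss):
            best_epoch = epoch  # a real training loop would save the weights here
        if stopper.should_stop:
            print("Early stopping")
            break
    return epoch, best_epoch


# Validation losses from the log above (epochs 5-9 are only shown to 3 decimal places there)
val_losses = [0.469236, 0.460606, 0.437871, 0.435937, 0.423534,
              0.432, 0.457, 0.445, 0.461, 0.483]
stopped_at, best_epoch = run_with_early_stopping(val_losses)
print(stopped_at, best_epoch)  # 9 4
```

Keeping the best checkpoint rather than the last one is what makes the rising validation loss in epochs 5 to 9 harmless: those weights are never written to the output file.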

We now save the classifier as a .joblib.lzma file so that it can be uploaded and used externally.

In [6]:
classifier_file_name = f"models/ssl_ukb_c24_rw_{datetime.now().strftime('%Y%m%d')}.joblib.lzma"

In [7]:
classifier.save(classifier_file_name)
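The internals of classifier.save are not shown here, but the file is read back with joblib.load below. For reference, joblib.dump picks the compression method from the file extension, so a `.lzma` suffix yields an LZMA-compressed file. A minimal round-trip sketch with a stand-in object (the dictionary contents are arbitrary):

```python
import os
import tempfile

import joblib
import numpy as np

# Stand-in object; any picklable Python object can be stored this way
example = {"weights": np.linspace(0.0, 1.0, 4), "window_sec": 30}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example.joblib.lzma")
    # No `compress` argument: joblib chooses LZMA from the ".lzma" extension
    joblib.dump(example, path)
    restored = joblib.load(path)

print(restored["window_sec"], restored["weights"])
```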

We load the saved classifier back to check that it behaves as expected.

In [8]:
loaded_classifier: ActivityClassifier = joblib.load(classifier_file_name)

print(loaded_classifier)

Activity Classifier
class_labels: ['light' 'moderate-vigorous' 'sedentary' 'sleep']
window_length: 30
batch_size: 512
device: cpu
hmm: Hidden Markov Model
prior: [0.19747338 0.04906083 0.39729389 0.35617191]
emission: [[6.9318348e-01 7.3956624e-02 2.2419752e-01 8.6639896e-03]
 [4.3404096e-01 4.8588941e-01 7.6402329e-02 3.6681890e-03]
 [1.3217883e-01 7.2224061e-03 7.8012472e-01 8.0476902e-02]
 [5.3161164e-03 3.9428155e-04 6.0614076e-02 9.3367684e-01]]
transition: [[9.65415651e-01 7.39807192e-03 2.71862776e-02 0.00000000e+00]
 [2.88154378e-02 9.66816611e-01 4.36795097e-03 0.00000000e+00]
 [1.25400384e-02 4.16399744e-04 9.87043562e-01 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00]]
labels: [0 1 2 3]
model: Model has not been loaded.


Note that the model itself has not been loaded yet, although the classifier does store the dictionary of best weights.
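The HMM is summarised by a prior over the four classes, an emission matrix and a transition matrix between consecutive windows. Each emission row sums to 1, which is consistent with row i giving the probability of each predicted class when the true class is i. Assuming the HMM smooths the per-window predictions by Viterbi decoding (the standard use of such a model; actinet's exact code is not shown here), a minimal log-space sketch is below. `viterbi` is a hypothetical helper, and the raw prediction sequence is made up for illustration.

```python
import numpy as np


def viterbi(observations, prior, emission, transition):
    """Most likely hidden-state sequence for a sequence of integer observations.

    prior[i]         = P(first state = i)
    emission[i, k]   = P(observed class k | hidden state i)
    transition[i, j] = P(next state = j | current state = i)
    Works in log space; zero probabilities become -inf and are never chosen
    unless every path is impossible.
    """
    observations = np.asarray(observations)
    with np.errstate(divide="ignore"):
        log_prior = np.log(prior)
        log_emit = np.log(emission)
        log_trans = np.log(transition)

    n_steps, n_states = len(observations), len(prior)
    score = np.empty((n_steps, n_states))
    backptr = np.zeros((n_steps, n_states), dtype=int)

    score[0] = log_prior + log_emit[:, observations[0]]
    for t in range(1, n_steps):
        # candidates[i, j]: best score ending in state i at t-1, then moving to j
        candidates = score[t - 1][:, None] + log_trans
        backptr[t] = candidates.argmax(axis=0)
        score[t] = candidates.max(axis=0) + log_emit[:, observations[t]]

    # Trace the best path backwards from the best final state
    path = np.empty(n_steps, dtype=int)
    path[-1] = score[-1].argmax()
    for t in range(n_steps - 1, 0, -1):
        path[t - 1] = backptr[t, path[t]]
    return path


# HMM parameters as printed above (assuming rows index the true class:
# light, moderate-vigorous, sedentary, sleep)
prior = np.array([0.19747338, 0.04906083, 0.39729389, 0.35617191])
emission = np.array([
    [6.9318348e-01, 7.3956624e-02, 2.2419752e-01, 8.6639896e-03],
    [4.3404096e-01, 4.8588941e-01, 7.6402329e-02, 3.6681890e-03],
    [1.3217883e-01, 7.2224061e-03, 7.8012472e-01, 8.0476902e-02],
    [5.3161164e-03, 3.9428155e-04, 6.0614076e-02, 9.3367684e-01],
])
transition = np.array([
    [9.65415651e-01, 7.39807192e-03, 2.71862776e-02, 0.0],
    [2.88154378e-02, 9.66816611e-01, 4.36795097e-03, 0.0],
    [1.25400384e-02, 4.16399744e-04, 9.87043562e-01, 0.0],
    [0.0, 0.0, 0.0, 1.0],
])

# Hypothetical raw predictions: one isolated "light" (0) window among "sedentary" (2)
raw = [2, 2, 0, 2, 2]
smoothed = viterbi(raw, prior, emission, transition)
print(smoothed)  # [2 2 2 2 2]
```

The isolated window is relabelled because leaving and re-entering the sedentary state costs far more probability than one unlikely emission.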

We can now load the model. If no local repository is specified, the SSL model repository is loaded from GitHub or from the local cache.

In [9]:
loaded_classifier.load_model()

Using local c:\Users\AidenA\Documents\Projects\actinet\src\actinet\torch_hub_cache\OxWearables_ssl-wearables_v1.0.0


We can now confirm that the model has been loaded and is ready to predict activity labels.

In [10]:
print(loaded_classifier)

Activity Classifier
class_labels: ['light' 'moderate-vigorous' 'sedentary' 'sleep']
window_length: 30
batch_size: 512
device: cpu
hmm: Hidden Markov Model
prior: [0.19747338 0.04906083 0.39729389 0.35617191]
emission: [[6.9318348e-01 7.3956624e-02 2.2419752e-01 8.6639896e-03]
 [4.3404096e-01 4.8588941e-01 7.6402329e-02 3.6681890e-03]
 [1.3217883e-01 7.2224061e-03 7.8012472e-01 8.0476902e-02]
 [5.3161164e-03 3.9428155e-04 6.0614076e-02 9.3367684e-01]]
transition: [[9.65415651e-01 7.39807192e-03 2.71862776e-02 0.00000000e+00]
 [2.88154378e-02 9.66816611e-01 4.36795097e-03 0.00000000e+00]
 [1.25400384e-02 4.16399744e-04 9.87043562e-01 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00]]
labels: [0 1 2 3]
model: Resnet(
  (feature_extractor): Sequential(
    (layer1): Sequential(
      (0): Conv1d(3, 64, kernel_size=(5,), stride=(1,), padding=(2,), bias=False, padding_mode=circular)
      (1): ResBlock(
        (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affi

This is a useful way to compute the MD5 hash of the classifier file.

In [11]:
from pathlib import Path
hashlib.md5(Path(classifier_file_name).read_bytes()).hexdigest()  # read_bytes closes the file

'd1c0e25764b8ffc7b15cd850225bbfe2'
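Reading the whole file at once loads it entirely into memory. A chunked read keeps memory use bounded regardless of file size and also gives a simple way to verify a downloaded copy against the hash above. A minimal sketch using only hashlib; `file_md5` and `verify_file` are hypothetical helper names.

```python
import hashlib


def file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it in chunks (1 MiB by default)."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read until it returns b"" (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_file(path, expected_md5):
    """True if the file's MD5 matches the expected hex digest (case-insensitive)."""
    return file_md5(path) == expected_md5.lower()


# Hypothetical usage with the hash printed above:
# verify_file(classifier_file_name, "d1c0e25764b8ffc7b15cd850225bbfe2")
```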

In [13]:
np.unique(Y)

array(['light', 'moderate-vigorous', 'sedentary', 'sleep'], dtype='<U17')

In [14]:
np.full((900, 3), fill_value=np.nan)

array([[nan, nan, nan],
       [nan, nan, nan],
       [nan, nan, nan],
       ...,
       [nan, nan, nan],
       [nan, nan, nan],
       [nan, nan, nan]])

In [18]:
my_data = np.full((2, 900, 3), fill_value=np.nan)  # Two all-NaN windows of 900 samples x 3 axes

In [19]:
from actinet import sslmodel
from torch.utils.data import DataLoader

dataset = sslmodel.NormalDataset(my_data)
dataloader = DataLoader(
    dataset,
    batch_size=100,
    shuffle=False,
    num_workers=0,
)

model = sslmodel.get_sslnet(
    tag="v1.0.0",
    local_repo_path="C:/Users/AidenA/Documents/Python/ssl-wearables/",
    pretrained_weights="models/model_30s.pt",
    window_sec=30,
    num_labels=4,
)
model.to("cpu")

_, y_pred, _ = sslmodel.predict(
    model, dataloader, 'cpu', output_logits=False
)

Classifying windows...


100%|██████████| 1/1 [00:00<00:00,  6.67it/s]


In [20]:
y_pred

array([0, 0], dtype=int64)
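In this run both all-NaN windows were still assigned class 0, which suggests the model does not flag missing data by itself. One option, sketched below under the assumption that predictions are integer class indices, is to predict only on NaN-free windows and mark the rest with a sentinel. `predict_valid_windows`, `predict_fn` and the sentinel value -1 are hypothetical choices, not part of actinet.

```python
import numpy as np


def predict_valid_windows(windows, predict_fn, missing_label=-1):
    """Run predict_fn only on windows without NaNs; fill the rest with missing_label.

    windows:    array of shape (n_windows, n_samples, n_axes)
    predict_fn: callable mapping a (k, n_samples, n_axes) array to k integer labels
    """
    windows = np.asarray(windows, dtype=float)
    valid = ~np.isnan(windows).any(axis=(1, 2))
    y_pred = np.full(len(windows), missing_label, dtype=int)
    if valid.any():
        y_pred[valid] = predict_fn(windows[valid])
    return y_pred


# Dummy predictor standing in for the real model: labels every window as class 2
demo = np.zeros((3, 900, 3))
demo[1] = np.nan
print(predict_valid_windows(demo, lambda w: np.full(len(w), 2)))  # [ 2 -1  2]
```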

In [22]:
import actipy
sample, info = actipy.read_device(
    "data/sample.cwa.gz",
    resample_hz=30,
    verbose=True,
    lowpass_hz=None,
    calibrate_gravity=True,
    detect_nonwear=True,
)

Decompressing... Done! (1.03s)
Reading file... Done! (11.90s)
Converting to dataframe... Done! (3.57s)
Getting stationary points... Done! (8.28s)
Gravity calibration... Done! (7.80s)
Nonwear detection... Done! (8.16s)
Resampling... Done! (3.26s)


In [51]:
import pandas as pd

def is_good_window(x, window_len, relevant_columns=("x", "y", "z")):
    """
    Check if a window is considered good based on its length and the presence of NaN values.

    Args:
        x (pd.DataFrame): Window data.
        window_len (int): The expected number of samples (rows) in the window.
        relevant_columns (sequence of str): Columns that must be complete.

    Returns:
        bool: True if the window is considered good, False otherwise.

    """
    window = x[list(relevant_columns)]

    # Check window length is correct
    if len(window) != window_len:
        return False

    # Check no NaNs
    if window.isna().any(axis=None):
        return False

    return True

def make_windows(data, window_sec, sample_rate=30, return_index=False, verbose=True):
    """Split data into non-overlapping windows of window_sec seconds, keeping only good windows."""

    if verbose:
        print("Defining windows...")

    window_len = int(window_sec * sample_rate)  # e.g. 30 s x 30 Hz = 900 samples

    X, T = [], []
    for t, x in tqdm(
        data.resample(f"{window_sec}s", origin="start"),
        mininterval=5,
        disable=not verbose,
    ):
        if not is_good_window(x, window_len):
            print(x)  # Show rejected windows for inspection
            continue

        X.append(x[["x", "y", "z"]].to_numpy())  # Keep only the accelerometer axes
        T.append(t)

    X = np.asarray(X)

    if return_index:
        T = pd.DatetimeIndex(T, name=data.index.name)
        return X, T

    return X

X, T = make_windows(
    sample, 30, sample_rate=30, return_index=True, verbose=True
)

Defining windows...


  0%|          | 0/16842 [00:00<?, ?it/s]

                                      x         y         z  temperature  \
time                                                                       
2014-05-07 16:27:50.430000000 -0.513421 -0.071519 -0.886475    21.049999   
2014-05-07 16:27:50.463333333 -0.497678 -0.040279 -0.870555    21.049999   
2014-05-07 16:27:50.496666666 -0.544908  0.084680 -0.838715    21.049999   
2014-05-07 16:27:50.530000000 -0.513421  0.147160 -0.950156    21.049999   
2014-05-07 16:27:50.563333333 -0.812545  0.334600 -0.775034    21.049999   
...                                 ...       ...       ...          ...   
2014-05-07 16:28:20.263333333       NaN       NaN       NaN          NaN   
2014-05-07 16:28:20.296666666       NaN       NaN       NaN          NaN   
2014-05-07 16:28:20.330000000       NaN       NaN       NaN          NaN   
2014-05-07 16:28:20.363333333       NaN       NaN       NaN          NaN   
2014-05-07 16:28:20.396666666       NaN       NaN       NaN          NaN   

           

In [50]:
len(X)

16837

In [53]:
sample.loc["2014-05-08 04:00:00":"2014-05-08 05:00:00"]

                                      x         y         z  temperature     light
time
2014-05-08 04:00:00.030000000 -0.985922  0.007037  0.209710        20.60  3.259841
2014-05-08 04:00:00.063333333 -0.985922  0.007037  0.209710        20.60  3.259841
2014-05-08 04:00:00.096666666 -0.970179  0.007037  0.209710        20.60  3.259841
2014-05-08 04:00:00.130000000 -0.985922  0.007037  0.209710        20.60  3.259841
2014-05-08 04:00:00.163333333 -0.970179  0.007037  0.209710        20.60  3.259841
...                                 ...       ...       ...          ...       ...
2014-05-08 05:00:00.863333333 -0.971179  0.009319  0.245954        18.35  3.281928
2014-05-08 05:00:00.896666666 -0.971179  0.009319  0.245954        18.35  3.281928
2014-05-08 05:00:00.930000000 -0.955435  0.009319  0.245954        18.35  3.281928
2014-05-08 05:00:00.963333333 -0.971179  0.009319  0.245954        18.35  3.281928
