# End-to-end training of ActiNet model using Capture-24

This notebook shows how to fine-tune the final layer of the modified self-supervised ResNet-18 model on the Capture-24 dataset for the Walmsley label annotations.

In [None]:
import hashlib
import os
import sys
import urllib
import urllib.request
import zipfile
from datetime import datetime
from glob import glob

import joblib
import numpy as np
from tqdm.auto import tqdm

sys.path.append("src")
from actinet.actinet import read
from actinet.accPlot import plotTimeSeries
from actinet.models import ActivityClassifier
from actinet.prepare import load_all_and_make_windows

# Notebook-wide settings, reused by the data-preparation and classifier cells below.
WINSEC = 30 # window length in seconds
SAMPLE_RATE = 100 # sampling rate in Hz
N_JOBS = 8 # Parallel jobs for window preparation. Set to higher number for quicker execution, but don't exceed max (presumably the number of available CPU cores - confirm).

First we download and unzip the Capture-24 dataset. 

In [None]:
# Fetch the Capture-24 archive once, then extract it under data/.
os.makedirs("data", exist_ok=True)
capture24_path = "data/capture24.zip"

if not os.path.exists(capture24_path):
    url = "https://ora.ox.ac.uk/objects/uuid:99d7c092-d865-4a19-b096-cc16440cd001/files/rpr76f381b"

    # Download under a temporary name and rename only on success, so that an
    # interrupted download is never mistaken for a complete archive on re-run.
    partial_path = capture24_path + ".part"

    with tqdm(unit='B', unit_scale=True, desc='Downloading Capture-24: ',
            unit_divisor=1024, miniters=1, ascii=True, total=6900000000) as pbar:

        def report_progress(block_num, block_size, total_size):
            """urlretrieve report hook: move the bar to the bytes received so far."""
            # Prefer the size reported by the server over the hard-coded
            # estimate; urlretrieve passes -1 when the size is unknown.
            if total_size > 0:
                pbar.total = total_size
            # The hook is first called with block_num == 0 before any data has
            # arrived, so set the absolute position instead of adding a block.
            pbar.update(block_num * block_size - pbar.n)

        urllib.request.urlretrieve(url, filename=partial_path,
                                   reporthook=report_progress)

    os.replace(partial_path, capture24_path)

# Extract whenever the dataset folder is missing, not only right after a fresh
# download, so a manually placed archive is also unpacked.
# NOTE(review): assumes the archive unpacks into data/capture24/, which is the
# location the later cells of this notebook read from - confirm.
if not os.path.isdir("data/capture24"):
    failed_members = []

    with zipfile.ZipFile(capture24_path, "r") as archive:
        for member in tqdm(archive.namelist(), desc="Unzipping: "):
            try:
                archive.extract(member, "data")
            except zipfile.error as err:
                # Stay best-effort (keep extracting the remaining members),
                # but record the failure instead of hiding it.
                failed_members.append((member, err))

    for member, err in failed_members:
        print(f"Warning: could not extract {member}: {err}")

We then break the data into the expected shape of WINSEC windows and specified labels.

In [None]:
# Input locations inside the extracted Capture-24 folder, and the folder that
# caches the windowed arrays for the current window length.
DATAFILES = "data/capture24/P[0-9][0-9][0-9].csv.gz"
ANNOFILE = "data/capture24/annotation-label-dictionary.csv"
SAVEFOLDER = f"data/capture24/{WINSEC}s"

# Reuse the cached arrays only when every one of the four required files is
# present; counting "*.npy" files would be fooled by unrelated or extra files.
# NOTE(review): assumes load_all_and_make_windows writes these four files into
# out_dir, as the cache lookup implies - confirm.
CACHED_ARRAYS = ("X", "Y", "T", "pid")

if all(os.path.exists(f"{SAVEFOLDER}/{name}.npy") for name in CACHED_ARRAYS):
    X = np.load(f"{SAVEFOLDER}/X.npy")
    Y = np.load(f"{SAVEFOLDER}/Y.npy")
    T = np.load(f"{SAVEFOLDER}/T.npy")
    pid = np.load(f"{SAVEFOLDER}/pid.npy")

else:
    # glob() returns files in filesystem order; sort so the input order is
    # the same on every machine and every run.
    datafiles = sorted(glob(DATAFILES))
    if not datafiles:
        raise FileNotFoundError(
            f"No participant files match {DATAFILES!r}; "
            "download and unzip Capture-24 first."
        )

    X, Y, T, pid = load_all_and_make_windows(
        datafiles=datafiles,
        annofile=ANNOFILE,
        out_dir=SAVEFOLDER,
        anno_label="Walmsley2020", # Choose between WillettsSpecific2018, WillettsMET2018, DohertySpecific2018,...
        sample_rate=SAMPLE_RATE,
        winsec=WINSEC,
        n_jobs=N_JOBS,
    )

We now initialise the Activity Classifier, with appropriate properties for training.

In [None]:
# Build the classifier that is trained in the cells below.
classifier = ActivityClassifier(
    labels = np.unique(Y),  # sorted unique values found in Y
    batch_size=1000,
    device="cuda:0",  # NOTE(review): hard-coded to the first CUDA device; presumably needs changing on machines without a GPU - confirm
    window_sec=WINSEC,  # same window length that was used to build X
    verbose=True
)

# Show the classifier's text summary.
print(classifier)

Optionally, we subset the data to smoke-test the pipeline.

In [None]:
def subset_data(X, Y, pid, T, n_samples=1000, n_participants=20, seed=42):
    """Return a reproducible random subset of the data for smoke testing.

    Participants are sub-sampled first; rows are then drawn from the rows
    belonging to the kept participants. Rows come back in the randomly drawn
    order, not in their original order.

    A private random generator is used, so the global NumPy random state is
    left untouched. For a given seed the selection is identical to what
    ``np.random.seed(seed)`` followed by ``np.random.choice`` would produce.

    Parameters
    ----------
    X, Y, pid, T : np.ndarray
        Parallel arrays that share their first axis (one entry per row).
    n_samples : int or None
        Maximum number of rows to keep; ``None`` keeps every row. Capped at
        the number of rows available instead of raising.
    n_participants : int or None
        Maximum number of distinct ``pid`` values to keep; ``None`` keeps all
        participants. Capped at the number of distinct participants.
    seed : int
        Seed for the private random generator.

    Returns
    -------
    tuple of np.ndarray
        ``(X, Y, pid, T)`` restricted to the selected rows.
    """
    rng = np.random.RandomState(seed)

    if n_participants is not None:
        unique_pids = np.unique(pid)
        # Cap the request so a small dataset does not raise in choice().
        n_keep = min(n_participants, len(unique_pids))
        pids = rng.choice(unique_pids, n_keep, replace=False)
        idx = np.isin(pid, pids)

        X, Y, pid, T = X[idx], Y[idx], pid[idx], T[idx]

    if n_samples is not None:
        # Same cap for rows: never ask for more than what is left.
        idx = rng.choice(len(X), min(n_samples, len(X)), replace=False)
        X, Y, pid, T = X[idx], Y[idx], pid[idx], T[idx]

    return X, Y, pid, T

# X, Y, pid, T = subset_data(X, Y, pid, T, None, 10)

We call the fit function to train the classifier on the training data. This will save the best weights at the provided location during the training folds.

In [None]:
# Weights file for this run; the name embeds the window length and today's date (YYYYMMDD).
classifier_save_path = f"models/c24_rw_{WINSEC}s_{datetime.now().strftime('%Y%m%d')}.pt"

# Nothing earlier in this notebook creates models/; make sure it exists so
# the weights can be written on a fresh checkout.
os.makedirs(os.path.dirname(classifier_save_path), exist_ok=True)

# NOTE(review): the positional True is opaque here - check the signature of
# ActivityClassifier.fit for what it enables.
classifier.fit(X, Y, pid, T, True, classifier_save_path, n_splits=1)

We now need to save the classifier as a .lzma file, which can then be uploaded to the internet for external use.

In [None]:
# Output path for the packaged classifier; the name embeds the window length and today's date (YYYYMMDD).
classifier_file_name = f"models/ssl-ukb-c24-rw-{WINSEC}s-{datetime.now().strftime('%Y%m%d')}.joblib.lzma"

In [None]:
# Save the trained classifier to classifier_file_name (a .joblib.lzma path).
classifier.save(classifier_file_name)

We load the saved classifier to ensure the expected behaviour.

In [None]:
# Reload the file that was just written to check it round-trips.
# joblib.load is pickle-based and can execute arbitrary code, so only use it on files you trust.
loaded_classifier: ActivityClassifier = joblib.load(classifier_file_name)

# Show the reloaded classifier's text summary.
print(loaded_classifier)

Note that the model has not been loaded yet; however, the classifier does have the best-weights dictionary saved.

We can load the model. When not specified, the model will load the ssl repository from github/cache.

In [None]:
# Called without arguments; per the notebook text, the ssl repository is then taken from GitHub or a local cache
# (presumably needs network access on first use - confirm).
loaded_classifier.load_model()

Now we can confirm that the model has been loaded, and is ready to predict activity labels.

In [None]:
# Print the summary again to compare with the output from before load_model() was called.
print(loaded_classifier)

In [None]:
# Show the classifier's HMM (presumably its fitted parameters, to 3 decimal places - confirm).
# Note ordering is sleep, sedentary, light, MVPA
loaded_classifier.hmm.display(precision=3)

This is a useful snippet to get the MD5 hash of the classifier file.

In [None]:
# Hash the file in 1 MiB chunks inside a context manager, so the handle is
# always closed and the whole file never has to sit in memory at once.
md5_digest = hashlib.md5()
with open(classifier_file_name, "rb") as classifier_file:
    for chunk in iter(lambda: classifier_file.read(1 << 20), b""):
        md5_digest.update(chunk)

# Last expression of the cell: displays the hex digest.
md5_digest.hexdigest()

In [None]:
# Read a sample accelerometer file to try the reloaded classifier on.
# NOTE(review): data/sample.cwa.gz is not created by any cell shown in this notebook; it must already exist.
# resample_hz=None presumably keeps the file's original sampling rate - confirm.
data, info = read("data/sample.cwa.gz", resample_hz=None, verbose=True)

In [None]:
# Predict activity labels for the sample recording.
# NOTE(review): 100 is presumably the sample rate in Hz (same value as SAMPLE_RATE) and True an opaque flag -
# confirm against ActivityClassifier.predict_from_frame.
y = loaded_classifier.predict_from_frame(data, 100, True)

In [None]:
# Plot the predictions; the return value is kept in `p` rather than echoed as the cell output.
p = plotTimeSeries(y)