# Self-supervised learning tutorial 

This notebook gives a basic overview of self-supervised learning with accelerometer data. There are three main components:

1. Fine-tuning a model pre-trained with self-supervision

2. Surgical fine-tuning to enhance the downstream performance

3. Designing novel self-supervised learning tasks for representation learning

## 1. Fine-tuning a pretrained model

This notebook contains a minimal example of how to fine-tune a pretrained PyTorch model. Fine-tuning means we take a pretrained model and re-train it for a supervised learning task. The model in this example was pre-trained on 700,000 person-days of data from the UK Biobank. Details of the model development can be found in *[Self-supervised Learning for Human Activity Recognition Using 700,000 Person-days of Wearable Data](https://oxwearables.github.io/ssl-wearables/)*.

The target downstream dataset is the Capture-24 Dataset with Walmsley labels ('light' 'moderate-vigorous' 'sedentary' 'sleep'). Generate a `Y.npy` yourself with the Walmsley labels using the earlier notebooks.

Some helper functions and classes are loaded from `utils.py` and `data.py` in the utils folder.

In [2]:
import joblib
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import LabelEncoder

from utils.data import NormalDataset, resize, get_inverse_class_weights
from utils.utils import EarlyStopping

Load the raw data in NumPy array format. You need to change the paths inside this function to match your environment and dataset.

In [9]:
def load_data():
    root = '/Users/hangy/Desktop/processed_data'

    X = np.load(os.path.join(root, 'X.npy'), mmap_mode='r')  # accelerometer data
    Y = np.load(os.path.join(root, 'Y.npy'))  # true labels
    pid = np.load(os.path.join(root, 'pid.npy'))  # participant IDs
    time = np.load(os.path.join(root, 'T.npy'))  # timestamps

    print(f'X shape: {X.shape}')
    print(f'Y shape: {Y.shape}')  # same shape as pid and time
    print(f'Label distribution:\n{pd.Series(Y).value_counts()}')

    # The original labels in Y are in categorical format (e.g.: 'light', 'sleep', etc). PyTorch expects numerical labels (e.g.: 0, 1, etc).
    # LabelEncoder transforms categorical labels -> numerical.
    # After obtaining the test predictions, you can use le.inverse_transform(y) to go from numerical -> categorical (the fitted le object is returned at the end of this function)
    le = LabelEncoder()
    le.fit(np.unique(Y))

    y = le.transform(Y)
    print(f'Original labels: {le.classes_}')
    print(f'Transformed labels: {le.transform(le.classes_)}')

    # downsample if required.
    # our pre-trained model expects windows of 30s at 30Hz = 900 samples
    input_size = 900  # 30s at 30Hz

    if X.shape[1] == input_size:
        print("No need to downsample")
    else:
        X = resize(X, input_size)

    X = X.astype(
        "f4"
    )  # PyTorch defaults to float32

    # generate train/test splits
    folds = GroupShuffleSplit(
        1, test_size=0.2, random_state=42
    ).split(X, y, groups=pid)
    train_idx, test_idx = next(folds)

    x_test = X[test_idx]
    y_test = y[test_idx]
    time_test = time[test_idx]
    group_test = pid[test_idx]

    # further split train into train/val
    X = X[train_idx]
    y = y[train_idx]
    pid = pid[train_idx]
    time = time[train_idx]

    folds = GroupShuffleSplit(
        1, test_size=0.125, random_state=41
    ).split(X, y, groups=pid)
    train_idx, val_idx = next(folds)

    x_train = X[train_idx]
    x_val = X[val_idx]

    y_train = y[train_idx]
    y_val = y[val_idx]

    time_train = time[train_idx]
    time_val = time[val_idx]

    group_train = pid[train_idx]
    group_val = pid[val_idx]

    return (
        x_train, y_train, group_train, time_train,
        x_val, y_val, group_val, time_val,
        x_test, y_test, group_test, time_test,
        le
    )

In [10]:
(
    x_train, y_train, group_train, time_train,
    x_val, y_val, group_val, time_val,
    x_test, y_test, group_test, time_test,
    le
) = load_data()

X shape: (294659, 900, 3)
Y shape: (294659,)
Label distribution:
sedentary            116544
sleep                113355
light                 50822
moderate-vigorous     13938
dtype: int64
Original labels: ['light' 'moderate-vigorous' 'sedentary' 'sleep']
Transformed labels: [0 1 2 3]
No need to downsample
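The two `GroupShuffleSplit` calls in `load_data()` first hold out 20% of the participants for testing, then 12.5% of the remaining participants for validation, which works out to roughly 0.8 × 0.125 = 10% validation and 70% training. Because the split is by participant rather than by window, the window-level proportions only approximate these figures on real data. Here is a minimal sketch on synthetic data (the participant IDs and array sizes are made up for illustration) that checks that no participant ends up in more than one split:

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Synthetic stand-in (made up for illustration): 50 participants, 20 windows each
pid_demo = np.repeat(np.arange(50), 20)
X_demo = np.zeros((len(pid_demo), 30, 3), dtype="f4")
y_demo = np.zeros(len(pid_demo), dtype=int)

# Step 1: hold out 20% of participants for testing
outer = GroupShuffleSplit(1, test_size=0.2, random_state=42)
trainval_idx, test_idx = next(outer.split(X_demo, y_demo, groups=pid_demo))

# Step 2: hold out 12.5% of the remaining participants for validation
inner = GroupShuffleSplit(1, test_size=0.125, random_state=41)
tr, va = next(
    inner.split(
        X_demo[trainval_idx], y_demo[trainval_idx], groups=pid_demo[trainval_idx]
    )
)
# map the positions within the train/val subset back to the original indices
train_idx, val_idx = trainval_idx[tr], trainval_idx[va]

train_pids = set(pid_demo[train_idx])
val_pids = set(pid_demo[val_idx])
test_pids = set(pid_demo[test_idx])

# No participant should appear in more than one split
print("train/val overlap:", train_pids & val_pids)
print("train/test overlap:", train_pids & test_pids)
print("val/test overlap:", val_pids & test_pids)

for name, idx in [("train", train_idx), ("val", val_idx), ("test", test_idx)]:
    print(f"{name}: {len(idx) / len(pid_demo):.3f} of all windows")
```

Splitting by participant matters because windows from the same person are highly correlated; a window-level split would leak information from training into the test set.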


We now load the pre-trained self-supervised PyTorch model (a ResNet-18) from its GitHub repo (https://www.github.com/OxWearables/ssl-wearables).
This repo exposes a Torch Hub API, and the model can be loaded using `torch.hub.load()`. Take note of the `pretrained=True` argument: this loads the pretrained weights into the model.

Deep learning training loops benefit from batching, so-called mini-batch training. It usually converges faster and needs far less memory than passing the whole dataset at once, and the noise in mini-batch gradients can help the optimizer avoid getting stuck in poor local minima. The `NormalDataset` and `DataLoader` classes handle this process. `NormalDataset` implements a map-style dataset as described at https://pytorch.org/docs/stable/data.html. For the training dataset, we also enable augmentation by setting `transform=True`. Inspect the class to see how it works.

The resulting `DataLoader` objects expose an iterable that will return a minibatch containing the accelerometer data, the ground truth label and the participant id. We later iterate over this dataloader during the training and testing loop using `enumerate()`.
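To make the map-style dataset idea concrete, here is a minimal sketch. `ToyWindowDataset` is a hypothetical stand-in, not the `NormalDataset` from `utils/data.py`: it skips augmentation and any reshaping the real class may perform, and the data are random numbers.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyWindowDataset(Dataset):
    """Hypothetical stand-in for NormalDataset (no augmentation, no reshaping)."""

    def __init__(self, X, y, pid):
        self.X = torch.from_numpy(X)
        self.y = torch.from_numpy(y)
        self.pid = torch.from_numpy(pid)

    def __len__(self):
        # a map-style dataset reports its size ...
        return len(self.X)

    def __getitem__(self, idx):
        # ... and returns one sample for a given integer index
        return self.X[idx], self.y[idx], self.pid[idx]


rng = np.random.default_rng(0)
X_toy = rng.normal(size=(10, 30, 3)).astype("f4")  # 10 made-up windows
y_toy = rng.integers(0, 4, size=10)                # made-up labels
pid_toy = np.repeat(np.arange(5), 2)               # made-up participant IDs

toy_loader = DataLoader(
    ToyWindowDataset(X_toy, y_toy, pid_toy), batch_size=4, shuffle=False
)

batch_shapes = []
for i, (x, y, pid) in enumerate(toy_loader):
    # each iteration yields one mini-batch of windows, labels and participant IDs
    batch_shapes.append(tuple(x.shape))
    print(i, tuple(x.shape), tuple(y.shape), tuple(pid.shape))
```

With 10 samples and a batch size of 4, the loader yields two full batches and a final batch of 2; the same pattern applies to the loaders constructed below, only with 2000 windows per batch.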

In [9]:
repo = 'OxWearables/ssl-wearables'
my_device = 'cuda:0'

# load the pretrained model
sslnet: nn.Module = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=4, pretrained=True)
sslnet.to(my_device)

# construct dataloaders
train_dataset = NormalDataset(x_train, y_train, group_train, name="training", transform=True)
val_dataset = NormalDataset(x_val, y_val, group_val, name="validation")
test_dataset = NormalDataset(x_test, y_test, group_test, name="test")

train_loader = DataLoader(
    train_dataset,
    batch_size=2000,
    shuffle=True,
    num_workers=2,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=2000,
    shuffle=False,
    num_workers=0,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=2000,
    shuffle=False,
    num_workers=0,
)

Using cache found in /home/azw524/.cache/torch/hub/OxWearables_ssl-wearables_main


131 Weights loaded
training set sample count : 206503
validation set sample count : 28664
test set sample count : 59492


PyTorch models don't have `fit()` or `predict()` functions. We define the helper functions `train()`, `_validate_model()` and `predict()` ourselves. Inspect these to see what's going on.

The model is then trained and tested. Training is done with an early-stopping mechanism. If the validation loss doesn't improve for 5 consecutive epochs, training is halted and the best weights prior to early-stopping are used. Inspect the `EarlyStopping` class in `utils.py` to see how this works.
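The patience logic can be sketched in a few lines. `PatienceTracker` below is a hypothetical, simplified illustration and not the `EarlyStopping` class from `utils.py`: it only tracks the best validation loss, whereas a full implementation would also save the model weights on every improvement. The validation losses are made up.

```python
class PatienceTracker:
    """Hypothetical, simplified early-stopping logic (weights are not saved here)."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = None
        self.counter = 0
        self.early_stop = False

    def step(self, val_loss, epoch):
        if val_loss < self.best_loss:
            # improvement: remember it and reset the counter
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            # no improvement: count it, and stop once patience is used up
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


# made-up validation losses: the best value is reached at epoch 2
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.66, 0.68, 0.69, 0.60]

tracker = PatienceTracker(patience=5)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if tracker.step(val_loss, epoch):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best epoch {tracker.best_epoch}")
```

In this run the last value (0.60) is never reached: training stops after five epochs without improvement, which is the trade-off that the `patience` setting controls.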

In [4]:
def train(model, train_loader, val_loader, my_device, weights=None):
    """
    Iterate over the training dataloader and train a pytorch model.
    After each epoch, validate model and early stop when validation loss function bottoms out.

    Trained model weights will be saved to disk (state_dict.pt).

    :param nn.Module model: pytorch model
    :param train_loader: training data loader
    :param val_loader: validation data loader
    :param str my_device: pytorch map device.
    :param weights: training class weights (to enable weighted loss function)
    """

    state_dict = 'state_dict.pt'
    num_epoch = 100

    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.0001, amsgrad=True
    )

    if weights is not None:  # also works when weights is a NumPy array
        weights = torch.FloatTensor(weights).to(my_device)
        loss_fn = nn.CrossEntropyLoss(weight=weights)
    else:
        loss_fn = nn.CrossEntropyLoss()

    early_stopping = EarlyStopping(
        patience=5, path=state_dict, verbose=True
    )

    for epoch in range(num_epoch):
        model.train()
        train_losses = []
        train_acces = []
        for i, (x, y, _) in enumerate(tqdm(train_loader)):
            x.requires_grad_(True)
            x = x.to(my_device, dtype=torch.float)
            true_y = y.to(my_device, dtype=torch.long)

            optimizer.zero_grad()

            logits = model(x)
            loss = loss_fn(logits, true_y)
            loss.backward()
            optimizer.step()

            pred_y = torch.argmax(logits, dim=1)
            train_acc = torch.sum(pred_y == true_y)
            train_acc = train_acc / (pred_y.size()[0])

            train_losses.append(loss.cpu().detach())
            train_acces.append(train_acc.cpu().detach())

        val_loss, val_acc = _validate_model(model, val_loader, my_device, loss_fn)

        epoch_len = len(str(num_epoch))
        print_msg = (
            f"[{epoch:>{epoch_len}}/{num_epoch:>{epoch_len}}] | "
            + f"train_loss: {np.mean(train_losses):.3f} | "
            + f"train_acc: {np.mean(train_acces):.3f} | "
            + f"val_loss: {val_loss:.3f} | "
            + f"val_acc: {val_acc:.2f}"
        )

        early_stopping(val_loss, model)
        print(print_msg)

        if early_stopping.early_stop:
            print('Early stopping')
            print(f'SSLNet weights saved to {state_dict}')
            break


def _validate_model(model, val_loader, my_device, loss_fn):
    """ Iterate over a validation data loader and return mean model loss and accuracy. """
    model.eval()
    losses = []
    acces = []
    for i, (x, y, _) in enumerate(val_loader):
        with torch.inference_mode():
            x = x.to(my_device, dtype=torch.float)
            true_y = y.to(my_device, dtype=torch.long)

            logits = model(x)
            loss = loss_fn(logits, true_y)

            pred_y = torch.argmax(logits, dim=1)

            val_acc = torch.sum(pred_y == true_y)
            val_acc = val_acc / (list(pred_y.size())[0])

            losses.append(loss.cpu().detach())
            acces.append(val_acc.cpu().detach())
    losses = np.array(losses)
    acces = np.array(acces)
    return np.mean(losses), np.mean(acces)


def predict(model, data_loader, my_device):
    """
    Iterate over the dataloader and do inference with a pytorch model.

    :param nn.Module model: pytorch Module
    :param data_loader: pytorch dataloader
    :param str my_device: pytorch map device
    :return: true labels, model predictions, pids
    :rtype: (np.ndarray, np.ndarray, np.ndarray)
    """

    from tqdm import tqdm

    predictions_list = []
    true_list = []
    pid_list = []
    model.eval()

    for i, (x, y, pid) in enumerate(tqdm(data_loader)):
        with torch.inference_mode():
            x = x.to(my_device, dtype=torch.float)
            logits = model(x)
            true_list.append(y)
            pred_y = torch.argmax(logits, dim=1)
            predictions_list.append(pred_y.cpu())
            pid_list.extend(pid)
    true_list = torch.cat(true_list)
    predictions_list = torch.cat(predictions_list)

    return (
        torch.flatten(true_list).numpy(),
        torch.flatten(predictions_list).numpy(),
        np.array(pid_list),
    )

In [None]:
# Train the model. The trained weights will be saved in the file 'state_dict.pt'
train(sslnet, train_loader, val_loader, my_device, get_inverse_class_weights(y_train))
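The class counts printed by `load_data()` are strongly imbalanced, which is why a weighted loss is used. The exact formula inside `get_inverse_class_weights` is not shown in this notebook; one common choice, assumed in the sketch below, is to weight each class by the inverse of its relative frequency. The function name `inverse_class_weights_sketch` is hypothetical.

```python
import numpy as np


def inverse_class_weights_sketch(y, n_classes):
    """Hypothetical illustration: weight_c = 1 / (fraction of samples in class c).

    Assumes every class occurs at least once in y.
    """
    counts = np.bincount(y, minlength=n_classes).astype(float)
    freqs = counts / counts.sum()
    return (1.0 / freqs).tolist()


# label counts of the full dataset as printed by load_data(), encoded as
# 0=light, 1=moderate-vigorous, 2=sedentary, 3=sleep
counts = {0: 50822, 1: 13938, 2: 116544, 3: 113355}
y_demo = np.repeat(list(counts.keys()), list(counts.values()))

weights = inverse_class_weights_sketch(y_demo, n_classes=4)
for label, w in zip(["light", "moderate-vigorous", "sedentary", "sleep"], weights):
    print(f"{label:>18}: {w:.2f}")
```

Under this assumption the rarest class (moderate-vigorous) gets the largest weight, so mistakes on it contribute more to `nn.CrossEntropyLoss(weight=...)` than mistakes on the frequent classes.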

In [17]:
# helper function to calculate classification performance scores: precision, recall, F1 and Kappa
def classification_scores(y_test, y_test_pred):
    import sklearn.metrics as metrics

    cohen_kappa = metrics.cohen_kappa_score(y_test, y_test_pred)
    precision = metrics.precision_score(
        y_test, y_test_pred, average="macro", zero_division=0
    )
    recall = metrics.recall_score(
        y_test, y_test_pred, average="macro", zero_division=0
    )
    f1 = metrics.f1_score(
        y_test, y_test_pred, average="macro", zero_division=0
    )

    data = {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "kappa": cohen_kappa,
    }

    df = pd.DataFrame(data, index=[0])  # use a dataframe because this prints nicely later

    return df

In [18]:
# load fine tuned weights (best weights prior to early-stopping) and do inference on the test set
model_dict = torch.load('state_dict.pt', map_location=my_device)
sslnet.load_state_dict(model_dict)

y_test, y_test_pred, pid_test = predict(sslnet, test_loader, my_device)

100%|██████████| 20/20 [00:07<00:00,  2.69it/s]


In [22]:
scores = classification_scores(y_test, y_test_pred)
print(scores.round(3))

   precision  recall     f1  kappa
0      0.823   0.836  0.829  0.805
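Macro-averaged scores do not show which classes get confused with each other. Since the predictions are integer codes, `le.inverse_transform` (mentioned in the comments of `load_data()`) maps them back to label names, and a confusion matrix gives a per-class view. The sketch below uses made-up labels and predictions, not the model's output:

```python
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import LabelEncoder

# Re-create an encoder with the same four classes as in load_data()
le_demo = LabelEncoder()
le_demo.fit(["light", "moderate-vigorous", "sedentary", "sleep"])

# Made-up true labels and predictions, for illustration only
y_true_demo = np.array([3, 3, 2, 0, 0, 2])
y_pred_demo = np.array([3, 3, 2, 0, 1, 2])

# numerical -> categorical
labels_demo = le_demo.inverse_transform(y_pred_demo)
print(labels_demo)

# rows: true class, columns: predicted class (in the order of le_demo.classes_)
cm = confusion_matrix(y_true_demo, y_pred_demo, labels=[0, 1, 2, 3])
print(cm)
```

On the real test set you would pass `y_test` and `y_test_pred` instead of the made-up arrays and use the fitted `le` returned by `load_data()`.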


## 2. Surgical fine-tuning
The fine-tuning in the previous part retrained all the layers in the model. Intuitively, some of the information learned during pre-training might be forgotten if we fine-tune all the layers, especially on small datasets.

It is possible to freeze the weights of certain layers and fine-tune only selected layers. This is called surgical fine-tuning. Following the recent paper *[Surgical Fine-Tuning Improves Adaptation to Distribution Shifts](https://arxiv.org/abs/2210.11466)*, we would like to investigate which configuration works best.

Below, we demonstrate how to access the list of model parameters and how to freeze certain weights during training. 

In [22]:
repo = 'OxWearables/ssl-wearables'
my_device = 'cuda:0'

# Load the model again. This resets the model with only the pretrained weights.
sslnet: nn.Module = torch.hub.load(repo, 'harnet30', trust_repo=True, class_num=4, pretrained=True)
sslnet.to(my_device)

To freeze network weights during training, you need to know the names of the layers that you want to freeze, then set the `requires_grad` property of their parameters to `False`. Gradients are then not computed for those parameters, so the optimizer leaves them unchanged.
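The effect of `requires_grad = False` can be checked on a small model before applying it to the ResNet. `ToyNet` below is a hypothetical model made up for illustration; only its `feature_extractor` prefix mimics the parameter names of the pretrained network.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical toy model with a feature extractor and a linear classifier."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(3, 8, kernel_size=5, bias=False),
            nn.BatchNorm1d(8),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(8, 4)

    def forward(self, x):
        return self.classifier(self.feature_extractor(x))


torch.manual_seed(0)
toy = ToyNet()

# freeze every parameter whose name starts with "feature_extractor."
for name, param in toy.named_parameters():
    if name.startswith("feature_extractor."):
        param.requires_grad = False

trainable = [n for n, p in toy.named_parameters() if p.requires_grad]
print("trainable parameters:", trainable)

# one optimisation step on random data
optimizer = torch.optim.Adam(toy.parameters(), lr=1e-3)
x = torch.randn(16, 3, 100)
y = torch.randint(0, 4, (16,))

conv_before = toy.feature_extractor[0].weight.clone()
fc_before = toy.classifier.weight.clone()

optimizer.zero_grad()
loss = nn.CrossEntropyLoss()(toy(x), y)
loss.backward()
optimizer.step()

conv_unchanged = torch.equal(conv_before, toy.feature_extractor[0].weight)
fc_changed = not torch.equal(fc_before, toy.classifier.weight)
print("conv weights unchanged:", conv_unchanged)
print("classifier weights changed:", fc_changed)
```

Passing all parameters to the optimizer, as `train()` does, still works here because parameters without a gradient are skipped during the update.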

In [18]:
for name, param in sslnet.named_parameters():
    print(name)

feature_extractor.layer1.0.weight
feature_extractor.layer1.1.bn1.weight
feature_extractor.layer1.1.bn1.bias
feature_extractor.layer1.1.bn2.weight
feature_extractor.layer1.1.bn2.bias
feature_extractor.layer1.1.conv1.weight
feature_extractor.layer1.1.conv2.weight
feature_extractor.layer1.2.bn1.weight
feature_extractor.layer1.2.bn1.bias
feature_extractor.layer1.2.bn2.weight
feature_extractor.layer1.2.bn2.bias
feature_extractor.layer1.2.conv1.weight
feature_extractor.layer1.2.conv2.weight
feature_extractor.layer1.3.weight
feature_extractor.layer1.3.bias
feature_extractor.layer2.0.weight
feature_extractor.layer2.1.bn1.weight
feature_extractor.layer2.1.bn1.bias
feature_extractor.layer2.1.bn2.weight
feature_extractor.layer2.1.bn2.bias
feature_extractor.layer2.1.conv1.weight
feature_extractor.layer2.1.conv2.weight
feature_extractor.layer2.2.bn1.weight
feature_extractor.layer2.2.bn1.bias
feature_extractor.layer2.2.bn2.weight
feature_extractor.layer2.2.bn2.bias
feature_extractor.layer2.2.conv1.w

### Freezing all the conv layers but not the linear layers

In [17]:
def set_bn_eval(m):
    # keep the batch norm stats during forward pass
    # see https://discuss.pytorch.org/t/how-to-freeze-bn-layers-while-training-the-rest-of-network-mean-and-var-wont-freeze/89736
    classname = m.__class__.__name__
    if classname.find("BatchNorm1d") != -1:
        m.eval()

In [16]:
i = 0
name_idx = 0
for name, param in sslnet.named_parameters():
    if name.split(".")[name_idx] == "feature_extractor":
        param.requires_grad = False
        i += 1
sslnet.apply(set_bn_eval)

In [19]:
print("Weights being frozen: %d" % i)

Weights being frozen: 63


### Freezing all the weights in the first residual block

In [20]:
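# Note: if you already ran the previous freezing cell, reload the model first;
# otherwise the whole feature extractor stays frozen, not just layer1.
i = 0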
name_idx = 1
for name, param in sslnet.named_parameters():
    if name.split(".")[name_idx] == "layer1":
        param.requires_grad = False
        i += 1
sslnet.apply(set_bn_eval)

In [21]:
print("Weights being frozen: %d" % i)

Weights being frozen: 15
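One detail worth checking when combining the freezing cells with `train()`: `nn.Module.train()` switches every submodule, including the batch-norm layers, back to training mode, and `train()` calls `model.train()` at the start of each epoch. The sketch below uses a made-up toy model (not the ResNet) to show the effect and one possible workaround, which is to re-apply `set_bn_eval` right after `model.train()`. Whether this changes your results is something to verify experimentally.

```python
import torch.nn as nn


def set_bn_eval(m):
    # same idea as the helper above: put batch-norm layers into eval mode
    if isinstance(m, nn.BatchNorm1d):
        m.eval()


# hypothetical toy model, for illustration only
toy = nn.Sequential(nn.Conv1d(3, 8, kernel_size=5), nn.BatchNorm1d(8), nn.ReLU())

toy.apply(set_bn_eval)
bn_after_apply = toy[1].training    # batch norm is now in eval mode

toy.train()                         # what train() does at the start of every epoch
bn_after_train = toy[1].training    # batch norm is back in training mode

toy.train()
toy.apply(set_bn_eval)              # possible workaround: re-apply after model.train()
bn_reapplied = toy[1].training

print(bn_after_apply, bn_after_train, bn_reapplied)
```

Note also that `sslnet.apply(set_bn_eval)` affects every `BatchNorm1d` layer in the model, including those in blocks that are not frozen.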


Complete the following exercises by freezing parts of the model and fine-tuning it on Capture-24.

### Exercise 2-1. Does the performance change if you only fine-tune the first layer?

### Exercise 2-2. Does the performance change if you only fine-tune the middle layer?

### Exercise 2-3 [challenge]. What fine-tuning configurations might yield the best performance?  

## 3. Design novel self-supervised learning tasks for representation learning

Self-supervised learning means we train a model with self-derived labels. This is typically done on datasets that lack ground truth labels. We may not know the true ground truth label of a piece of data, but we can try to derive a label ourselves. For accelerometer signals, we can transform the signal in a certain way and then use the type of transformation as the label (for example: 'rotation'). The model is then trained to predict the transformation label. This may seem useless, and by itself it is, but when followed up with a supervised learning task (the fine-tuning) it can improve performance.

Once you know what your self-supervised learning task is, the implementation is usually easy: you apply the transformation to the inputs and replace the label vector with the transformation labels, and the rest of the training pipeline barely changes compared with training on labelled data. That is a rule of thumb many people use to tell unsupervised and self-supervised learning apart. **If your training pipeline stays the same, then your technique is *self-supervised*. If your pipeline changes, then your technique is *unsupervised*.**

In [47]:
def reverse(input_x, label):
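    # input_x: a single window of shape (time, channels), e.g. (900, 3)
    # label = 0: no reversal
    # label = 1: reversal along the time axis
    if label == 0:
        return input_x
    else:
        # flip only the time axis; np.flip without `axis` would also
        # reverse the order of the x/y/z channels
        return np.flip(input_x, axis=0)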

In [48]:
reverse_probability = 0.5
new_X = []
new_Y = []
for i in range(len(x_train)):
    current_x = x_train[i]
    if np.random.rand() < reverse_probability:  # reverse with probability reverse_probability
        current_y = 1
    else:
        current_y = 0
    new_x = reverse(current_x, current_y)
    new_X.append(new_x)
    
    new_Y.append(current_y)

In [49]:
new_X = np.array(new_X)
new_Y = np.array(new_Y)
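For large arrays, the same labelling can be done without a Python loop. Here is a minimal vectorised sketch on synthetic data (the array `x_demo` is a made-up stand-in for `x_train`), which flips only the time axis so that the x/y/z channel order is kept, and uses a seeded generator so the labels are reproducible:

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed for reproducibility

# Synthetic stand-in for x_train: 8 windows, 900 samples, 3 axes
x_demo = rng.normal(size=(8, 900, 3)).astype("f4")

reverse_probability = 0.5
# label 1 (reversed) with probability reverse_probability, else label 0
ssl_labels = (rng.random(len(x_demo)) < reverse_probability).astype(np.int64)

ssl_x = x_demo.copy()
# reverse the time axis (axis 1 of the batch) for the windows labelled 1
ssl_x[ssl_labels == 1] = x_demo[ssl_labels == 1][:, ::-1, :]

print("labels:", ssl_labels)
print("shapes:", ssl_x.shape, ssl_labels.shape)
```

The resulting `ssl_x` and `ssl_labels` play the same role as `new_X` and `new_Y`: they can be wrapped in a dataset and passed through the same training loop, with `class_num=2` for this binary task.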

### Exercise 3-1: Can you pretrain a model on Capture-24 using the reverse task above and see if it helps with activity recognition?

### Exercise 3-2: Can you implement any other self-supervised tasks that you think might help?