# Self-supervised learning (SSL)
Learning useful representations by recognising augmentations and matching augmented views.

In [None]:
# imports
import random
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F

import utils.ssl_utils as ssl

from tqdm import tqdm
from dataclasses import dataclass, fields
from typing import List, Tuple, Optional, Dict
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

## Loading data and exploring augmentations

In [None]:
# Prepare capture24 data and split into train, val and test splits
dm = ssl.Capture24DataManager(srv_root="/srv", local_root="..")
# Seed before splitting so the split below is reproducible
# (presumably set_seed seeds Python/NumPy/torch RNGs - TODO confirm in ssl_utils)
ssl.set_seed(42)
dm.prepare() # checks everything is downloaded
(
    x_tr, y_tr, pid_tr, # x_tr: accel segments, y_tr: labels, pid_tr: participant ID
    x_val, y_val, pid_val, 
    x_te, y_te, pid_te,
    le # label encoder - maps from integer labels to strings
) = dm.train_val_test_split(prop=0.1)  # NOTE(review): meaning of `prop` is defined in ssl_utils - confirm

In [None]:
@dataclass
class AugmentConfig:
    jitter: float = 0.5
    scaling: float = 0.5
    time_flip: float = 0.5
    axis_swap: float = 0.2
    time_mask: float = 0.3

class Augmenter:
    """
    Composable time-series augs tailored to wrist accelerometer windows.
    """
    def __init__(self, cfg: AugmentConfig | None = None):
        # if None, fall back to defaults
        self.cfg = cfg or AugmentConfig()

    @classmethod
    def available_ops(cls) -> list[str]:
        """Ops = config fields that have a same-named augmentation method."""
        names = [f.name for f in fields(AugmentConfig)]  # preserves declaration order
        return [n for n in names if hasattr(cls, n) and callable(getattr(cls, n))]

    def probs(self) -> dict[str, float]:
        """Current op -> probability mapping from the config."""
        return {n: getattr(self.cfg, n) for n in self.available_ops()}

    # ---- primitive ops: (C, L) -> (C, L) ----
    @staticmethod
    def jitter(x: torch.Tensor, sigma: float = 0.01) -> torch.Tensor:
        # add small Gaussian noise
        return x + torch.randn_like(x) * sigma

    @staticmethod
    def scaling(x: torch.Tensor, sigma: float = 0.1) -> torch.Tensor:
        # per-sample scalar scale ~ N(1, sigma^2)
        s = torch.randn((), device=x.device) * sigma + 1.0  # shape ()
        return x * s

    @staticmethod
    def time_flip(x: torch.Tensor) -> torch.Tensor:
        # reverse along time axis
        return torch.flip(x, dims=[-1])

    @staticmethod
    def axis_swap(x: torch.Tensor) -> torch.Tensor:
        # swap y and z (requires >=3 channels); no-op otherwise
        if x.size(0) >= 3:
            return x[[0, 2, 1], :]
        return x

    @staticmethod
    def time_mask(x: torch.Tensor, max_frac: float = 0.1) -> torch.Tensor:
        # zero a contiguous span of the series
        L = x.size(-1)
        w = max(1, int(L * max_frac))
        start = random.randint(0, L - w)
        y = x.clone()
        y[:, start:start + w] = 0
        return y

    # --- pipelines ---
    def view(self, x: torch.Tensor) -> torch.Tensor:
        """
        Stochastic augmentation pipeline.
        """
        if random.random() < self.cfg.jitter:
            x = self.jitter(x)
        if random.random() < self.cfg.scaling:
            x = self.scaling(x)
        if random.random() < self.cfg.axis_swap:
            x = self.axis_swap(x)
        if random.random() < self.cfg.time_mask:
            x = self.time_mask(x)
        if random.random() < self.cfg.time_flip:
            x = self.time_flip(x)
        return x

    def two_views(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.view(x), self.view(x)

### Visualising augmentations
Below we sample an accelerometer segment. To ensure the segment is interesting, we pick one with a high standard deviation. Your job is to go through each of the augmentations, apply them, and document how each augmentation changes the accelerometer data.

Then, re-examine what happens if you augment a segment with a much lower standard deviation. Is it still easy to detect the difference between the augmented and unaugmented signal?

In [None]:
# Compute SD of each segment (over its last two axes) and pick the segment
# whose SD is closest to the 90th percentile
sds = np.std(x_tr, axis=(1,2)) 
p90 = np.percentile(sds, 90)
argp90 = np.argmin(np.abs(sds - p90))  # index of the segment nearest to the 90th-percentile SD
x_org = x_tr[argp90]

# Visualise the original segment
fig, ax = ssl.visualize_segment(x_org, title="Original")

# Apply augmentations to it to figure out what each is doing
aug = Augmenter()

# Convert segment to tensor (torch.from_numpy shares memory with the NumPy array)
x_org =  torch.from_numpy(x_org)

# Jitter: adds Gaussian noise with standard deviation `sigma`
# (0.2 here, well above the op's 0.01 default, so the effect is visible)
x_jitter = aug.jitter(x_org, sigma=0.2)
fig, ax = ssl.visualize_segment(x_jitter.numpy(), title="Jittered")

# Scaling: multiplies the whole segment by a single random scalar ~ N(1, sigma^2)
x_scale = aug.scaling(x_org, sigma=0.5)
fig, ax = ssl.visualize_segment(x_scale.numpy(), title="Scaled")

# Time flip: reverses the segment along its last (time) axis
x_flip = aug.time_flip(x_org)
fig, ax = ssl.visualize_segment(x_flip.numpy(), title="Time Flipped")

# Axis Swap: exchanges channels 1 and 2 (y and z) when there are at least 3 channels
x_swap = aug.axis_swap(x_org)
fig, ax = ssl.visualize_segment(x_swap.numpy(), title="Axis Swapped")

# Time Mask: zeroes one contiguous span of int(L * max_frac) samples at a random start
x_mask = aug.time_mask(x_org, max_frac=0.1)
fig, ax = ssl.visualize_segment(x_mask.numpy(), title="Time Masked")


## Augmentation recognition pretraining
So far, we have split the data into training, validation and test splits, and we have explored different ways of augmenting segments of accelerometer data.
Let's now use these augmentations to implement an SSL method: augmentation recognition pretraining. In this method, we train a model to recognise which augmentations have been applied to each segment.

- Implement pretraining
- Train model and plot loss trajectory
- Fine-tune pretrained model and assess performance

In [None]:
# Prepare the data-set!
 
# Implement the AugRec __getitem__ method.
class AugRecDataset(ssl.BaseWearableDataset):
    """Dataset for augmentation-recognition pretraining.

    Each item is a randomly augmented segment together with a float target
    recording which augmentations were applied to it.

    NOTE(review): relies on the base class providing ``self.aug`` (an
    Augmenter-like object) and ``self._get_x(idx)`` - confirm in ssl_utils.
    """

    def __init__(self, *args, multi_label: bool = True, **kwargs):
        """Forward all dataset arguments to the base class.

        multi_label: when True the target has one entry per op; when False it
        is a single "was anything applied" flag.
        """
        super().__init__(*args, **kwargs)
        self.multi_label = multi_label
        # Op names (in config declaration order) and their probabilities are
        # snapshotted once at construction time.
        self.ops = self.aug.available_ops()
        self._op_probs = self.aug.probs()

    def __getitem__(self, idx):
        """Return ``(augmented_x, target)`` for the segment at ``idx``."""
        x = self._get_x(idx)
        applied = torch.zeros(len(self.ops), dtype=torch.float32)

        for slot, name in enumerate(self.ops):
            threshold = self._op_probs[name]
            # One coin flip per op; ops that lose the flip are skipped.
            if random.random() >= threshold:
                continue
            applied[slot] = 1.0
            x = getattr(self.aug, name)(x)

        if self.multi_label:
            return x, applied
        # Binary target: 1 if at least one augmentation was applied.
        return x, applied.max().unsqueeze(0)
    
# Define the train dataset, visualise one of the (X, y) pairs
aug_cfg = AugmentConfig() # Make any changes to the augmentation config here
aug = Augmenter(cfg=aug_cfg) 
train_dataset = AugRecDataset(
    X=x_tr,
    y=y_tr,
    augmenter=aug,
)

# Let's visualise one of the data points before and after augmentation
x_org = x_tr[argp90]
fig, ax = ssl.visualize_segment(x_org, title="Original")

# Indexing the dataset draws augmentations at random, so re-running this cell
# can give a different result; y_train marks which ops were applied
x_train, y_train = train_dataset[argp90]
# Names of the applied ops (empty title if no augmentation happened to be drawn)
appl_augs = ", ".join([aug.available_ops()[k] for k, b in enumerate(y_train) if b > 0])
# NOTE(review): x_train is passed as returned by the dataset (presumably a tensor),
# whereas other cells call .numpy() first - confirm visualize_segment accepts both
fig, ax = ssl.visualize_segment(x_train, title=appl_augs)

In [None]:
# Define a train config!
@dataclass
class TrainConfig:
    # Training hyper-parameters; batch_size, num_workers, lr and device are read by
    # the data loaders and the overfit check, and the whole config is handed to ssl.Trainer.
    # NOTE(review): how ssl.Trainer interprets seed / weight_decay / max_epochs / patience
    # is defined in ssl_utils - the field notes below are presumptions to confirm.
    seed: int = 42
    batch_size: int = 8  # samples per batch for every DataLoader
    num_workers: int = 2  # DataLoader worker processes
    lr: float = 1e-3  # learning rate
    weight_decay: float = 1e-4
    max_epochs: int = 3
    patience: int = 1  # presumably early-stopping patience in epochs - TODO confirm
    # Evaluated once, when the class body runs: "cuda" if a GPU is visible at that time
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

train_cfg = TrainConfig()

In [None]:
# Prepare data-loaders that efficiently sample and batch the data for model training
# Sampling weights are derived from each segment's standard deviation
def make_weights(values: np.ndarray, alpha: float = 1.0) -> torch.Tensor:
    """Turn raw scores into positive float32 sampling weights.

    The scores are shifted so the minimum is 0, divided by the resulting
    maximum when that maximum is positive, offset by 1e-6 so no weight is
    exactly zero, and raised to the power ``alpha``.
    """
    shifted = values - values.min()
    peak = shifted.max()
    if peak > 0:
        shifted = shifted / peak
    # The small epsilon keeps the lowest-scoring item sampleable.
    weighted = (shifted + 1e-6) ** alpha
    as_float32 = weighted.astype(np.float32)
    return torch.from_numpy(as_float32)

# assume x_tr, x_val are (N,C,L) and exist already
# Per-segment standard deviation; non-finite values are replaced by 0 so they
# end up with the smallest sampling weight
sds = np.std(x_tr, axis=(1, 2))
sds = np.nan_to_num(sds, nan=0.0, posinf=0.0, neginf=0.0)
weights = make_weights(sds, alpha=2.0)

# Draw (with replacement) one epoch's worth of indices, favouring high-SD segments
sampler = WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)

aug_train = DataLoader(
    dataset=train_dataset,
    batch_size=train_cfg.batch_size,
    sampler=sampler, # comment out this line to see the impact of sampling!
    shuffle=False,  # must stay False while a sampler is supplied
    num_workers=train_cfg.num_workers,
    # Pinned host memory only helps host-to-GPU copies; requesting it without a
    # CUDA device just triggers a warning
    pin_memory=train_cfg.device.startswith("cuda"),
)

aug_val = DataLoader(
    dataset=AugRecDataset(X=x_val, augmenter=Augmenter(cfg=aug_cfg)),
    batch_size=train_cfg.batch_size,
    shuffle=False,
    num_workers=train_cfg.num_workers,
)

# Visualise samples in one batch - here you can compare the difference including the sampler makes!
check_batch = next(iter(aug_train))
for x_train, y_train in zip(*check_batch):
    # Title lists the augmentations flagged in this sample's label vector
    appl_augs = ", ".join([aug.available_ops()[k] for k, b in enumerate(y_train) if b > 0])
    fig, ax = ssl.visualize_segment(x_train, title=appl_augs)

In [None]:
# Build the model! 
model_cfg = ssl.SSLConfig(
    hub=ssl.HubConfig(pretrained=True),
    model=ssl.ModelConfig(
        in_channels=3,
        input_len=900,
        proj_dim=128,
        num_classes=4,  # presumably the number of activity classes for the supervised head - TODO confirm against `le`
        k_labels=len(train_dataset.ops),  # one aug-recognition output per augmentation op
        freeze_backbone=False),
)

model = ssl.SSLNet(model_cfg)

# Sanity-check forward pass on the batch drawn earlier (no gradients are tracked)
# NOTE(review): the model is not switched to eval() here - confirm that is intended
x_batch, y_batch = check_batch
with torch.inference_mode():
    aug_pred = model(x_batch, head="aug")   # logits, expected shape (batch_size, k_labels)

print("Input shape:", tuple(x_batch.shape))
print("Model aug_pred shape:", tuple(aug_pred.shape))
print("Ground truth label shape:", tuple(y_batch.shape))

print("\nRaw aug_pred logits:\n", aug_pred)
pred_bin = (aug_pred > 0).int()  # naive threshold at 0 (logit 0 == probability 0.5)
print("\nNaive binarized predictions:\n", pred_bin)
print("\nGround truth labels:\n", y_batch)
# An untrained model's predictions are not expected to match the labels

In [None]:
# Check model can overfit to one batch
def overfit_one_batch_augrec(model, dl, steps=100, lr=1e-2, device="cpu"):
    """Check that ``model`` can drive the aug-rec loss down on a single batch.

    The first batch of ``dl`` is fetched once and reused for every step, so a
    correctly wired model/loss should show a clearly decreasing loss.

    Args:
        model: module called as ``model(x, head="aug")`` and returning logits.
        dl: iterable of batches whose first two entries are inputs and targets.
        steps: number of Adam updates to run; ``0`` runs none.
        lr: Adam learning rate.
        device: device the model and the batch are moved to.

    Returns:
        List with the loss value of every step (empty when ``steps <= 0``).

    Side effects: the model is left in train mode, on ``device``, with its
    weights updated by the optimisation steps.
    """
    model.train().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    batch = next(iter(dl))
    x, y = batch[0].to(device), batch[1].to(device)

    print(f"[Overfit-check] batch shape: x={tuple(x.shape)}, y={tuple(y.shape)}")
    report_every = max(1, steps // 5)
    losses = []
    for t in range(steps):
        logits = model(x, head="aug")
        # BCE-with-logits needs targets in the logits' floating dtype;
        # the cast is a no-op when they already match.
        loss = F.binary_cross_entropy_with_logits(logits, y.to(logits.dtype))
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
        if (t + 1) % report_every == 0:
            print(f" step {t+1:03d} | loss {losses[-1]:.4f}")
    if losses:
        print(f" start loss={losses[0]:.4f}  end loss={losses[-1]:.4f}")
    else:
        # steps <= 0: nothing to summarise (indexing losses[0] would fail)
        print(" no optimisation steps were run")
    return losses

# NOTE: this check updates `model`'s weights in place and moves it to train_cfg.device,
# so the pretraining cell that follows starts from the post-check weights.
# The learning rate comes from train_cfg rather than the function's 1e-2 default.
_ = overfit_one_batch_augrec(model, aug_train, lr=train_cfg.lr, device=train_cfg.device)

In [None]:
# Finally, do model pretraining! ~ 7 minutes on GPU node
trainer = ssl.Trainer(train_cfg)
# hist_ar is indexed below by "train_loss" and "val_loss"
hist_ar = trainer.fit_augrec(model, aug_train, aug_val)

# Plot train, val loss trajectories
fig, ax = plt.subplots()
ax.plot(hist_ar["train_loss"], label="Train")
ax.plot(hist_ar["val_loss"], label="Val")
ax.set_xlabel("Epoch")
ax.set_ylabel("CE loss")
ax.set_title("AugRec training trajectory")
# Without this call the label= arguments above are never shown
ax.legend()

plt.show()


In [None]:
# Reload the helper module so edits made to it on disk are picked up without
# restarting the kernel.
import importlib

importlib.reload(ssl)

# Build a new trainer after the reload.
trainer = ssl.Trainer(train_cfg)

In [None]:
# Labelled loaders for supervised fine-tuning: raw segments, no augmenter.
# Only the training loader is shuffled.
sup_train = DataLoader(
    ssl.BaseWearableDataset(X=x_tr, y=y_tr, augmenter=None),
    batch_size=train_cfg.batch_size,
    shuffle=True,
    num_workers=train_cfg.num_workers,
)
sup_val = DataLoader(
    ssl.BaseWearableDataset(X=x_val, y=y_val, augmenter=None),
    batch_size=train_cfg.batch_size,
    shuffle=False,
    num_workers=train_cfg.num_workers,
)

# Fine-tune the pretrained model; class names are passed along when a label
# encoder is available.
label_names = None if le is None else list(le.classes_)
finetune_out = trainer.finetune(model, sup_train, sup_val, label_names=label_names)

## Finetuning a pretrained SSL model
- Explore different fine-tuning methods
- Explore classification performance