## InfoGCN++ Training Pipeline
This notebook mirrors the original InfoGCN++ data processing and optimisation setup for 2D skeleton `.npy` clips, making it easy to fine-tune the SODE backbone on custom datasets.

In [12]:
from pathlib import Path

import pandas as pd
import torch
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from act_rec.datasets import SkeletonNpyDataset
from act_rec.model.losses import LabelSmoothingCrossEntropy, masked_recon_loss
from act_rec.model.sode import SODE
from act_rec.training import TrainConfig, evaluate, train_one_epoch

In [2]:
# Paths and label mapping
data_root = Path("../data/")
csv_path = data_root / "skeleton_labels.csv"

df = pd.read_csv(csv_path).dropna()
label_to_idx = {label: idx for idx, label in enumerate(sorted(df["label"].unique()))}
df["label_idx"] = df["label"].map(label_to_idx)
df["skeleton_path"] = df["skeleton_path"].apply(lambda p: str((data_root / p).resolve()))
print(f"Total samples: {len(df)} | classes: {len(label_to_idx)}")

Total samples: 2216 | classes: 14


In [3]:
# Train/validation split
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["label_idx"],
    random_state=42,
)
print(f"Train: {len(train_df)} | Val: {len(val_df)}")

Train: 1772 | Val: 444


In [None]:
# Optimiser and schedule configuration mirroring InfoGCN++ defaults
train_hparams = {
    "epochs": 80,
    "base_lr": 1e-2,
    "optimizer": "SGD",
    "weight_decay": 1e-4,
    "warmup_epochs": 5,
    "lr_steps": [30, 45, 60],
    "lr_decay": 0.1,
    "grad_clip": 1.0,
    "batch_size": 32,
    "test_batch_size": 64,
    "num_workers": 4,
    "prefetch_factor": 2,
    "pin_memory": bool(torch.cuda.is_available()),
    "p_interval_train": (0.5, 1.0),
    "p_interval_val": (0.95,),
    "random_rotation": True,
    "use_velocity": False,
    "preload": True,
    "preload_to_tensor": True,
    "lambda_cls": 1.0,
    "lambda_recon": 0.1,
    "lambda_feature": 0.1,
    "lambda_kl": 0.0,
    "smoothing": 0.1,
    "checkpoint_path": "sode_best.pt",
}


def adjust_learning_rate(epoch: int, optimizer: torch.optim.Optimizer, cfg: dict) -> float:
    """Warm-up followed by step decay, as in the InfoGCN++ training script."""
    warmup = cfg["warmup_epochs"]
    base_lr = cfg["base_lr"]
    if warmup > 0 and epoch < warmup:
        lr = base_lr * float(epoch + 1) / float(warmup)
    else:
        steps = cfg["lr_steps"]
        decay = cfg["lr_decay"]
        num_decays = sum(epoch >= step for step in steps)
        lr = base_lr * (decay**num_decays)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr


In [5]:
# Dataset and dataloaders mirroring InfoGCN++ feeder logic
window_size = 64
num_workers = train_hparams["num_workers"]
train_dataset = SkeletonNpyDataset(
    train_df["skeleton_path"].tolist(),
    labels=train_df["label_idx"].tolist(),
    window_size=window_size,
    p_interval=train_hparams["p_interval_train"],
    random_rotation=train_hparams["random_rotation"],
    use_velocity=train_hparams["use_velocity"],
    preload=train_hparams["preload"],
    preload_to_tensor=train_hparams["preload_to_tensor"],
    repeat=1,
)
val_dataset = SkeletonNpyDataset(
    val_df["skeleton_path"].tolist(),
    labels=val_df["label_idx"].tolist(),
    window_size=window_size,
    p_interval=train_hparams["p_interval_val"],
    random_rotation=False,
    use_velocity=train_hparams["use_velocity"],
    preload=train_hparams["preload"],
    preload_to_tensor=train_hparams["preload_to_tensor"],
)

loader_kwargs = {
    "num_workers": num_workers,
    "pin_memory": train_hparams["pin_memory"],
}
if num_workers > 0:
    loader_kwargs["persistent_workers"] = True
    loader_kwargs["prefetch_factor"] = train_hparams.get("prefetch_factor", 2)

train_loader = DataLoader(
    train_dataset,
    batch_size=train_hparams["batch_size"],
    shuffle=True,
    drop_last=True,
    **loader_kwargs,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=train_hparams["test_batch_size"],
    shuffle=False,
    drop_last=False,
    **loader_kwargs,
)


### Training

In [8]:
# Model and optimisation setup
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
model = SODE(
    num_class=len(label_to_idx),
    num_point=17,
    num_person=1,
    graph="act_rec.graph.coco.Graph",
    in_channels=3,
    T=window_size,
    n_step=3,
    num_cls=4,
).to(device)

opt_name = train_hparams["optimizer"].lower()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_hparams["base_lr"],
    weight_decay=train_hparams["weight_decay"],
)

config = TrainConfig(
    device=device,
    cls_loss=LabelSmoothingCrossEntropy(smoothing=train_hparams["smoothing"]),
    lambda_cls=train_hparams["lambda_cls"],
    lambda_recon=train_hparams["lambda_recon"],
    lambda_feature=train_hparams["lambda_feature"],
    lambda_kl=train_hparams["lambda_kl"],
    n_step=model.n_step,
    recon_loss_fn=masked_recon_loss,
    feature_loss_fn=masked_recon_loss,
)

num_epochs = train_hparams["epochs"]
best_top1 = 0.0
history = []
ckpt_path = Path(train_hparams["checkpoint_path"])
ckpt_path.parent.mkdir(parents=True, exist_ok=True)


In [None]:
for epoch in range(num_epochs):
    lr = adjust_learning_rate(epoch, optimizer, train_hparams)
    train_metrics = train_one_epoch(model, train_loader, optimizer, config)
    val_metrics = evaluate(model, val_loader, config)
    metrics = {**train_metrics, **val_metrics, "lr": lr}
    history.append(metrics)

    msg = (
        f"Epoch {epoch + 1}/{num_epochs} | lr={lr:.4e} | "
        f"train_tot={train_metrics['train_total_loss']:.4f} "
        f"train_cls={train_metrics['train_cls_loss']:.4f} "
        f"train_recon={train_metrics['train_recon_loss']:.4f} "
        f"train_feat={train_metrics['train_feature_loss']:.4f} | "
        f"val_tot={val_metrics['val_total_loss']:.4f} "
        f"val_cls={val_metrics['val_cls_loss']:.4f} "
        f"val_recon={val_metrics['val_recon_loss']:.4f} "
        f"val_feat={val_metrics['val_feature_loss']:.4f} "
        f"val_top1={val_metrics['val_top1']:.3f} "
        f"val_top5={val_metrics['val_top5']:.3f}"
    )
    print(msg)

    if val_metrics["val_top1"] > best_top1:
        best_top1 = val_metrics["val_top1"]
        torch.save({"model": model.state_dict(), "label_to_idx": label_to_idx}, ckpt_path)
        print(f"  -> New best checkpoint saved (top1={best_top1:.3f}).")
print("Training finished.")


## Evaluation

Load the best checkpoint and report validation metrics.

In [13]:
if not ckpt_path.exists():
    raise FileNotFoundError(f"Checkpoint not found at {ckpt_path.resolve()} â€“ run the training cell first.")

checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model"])
label_map = checkpoint.get("label_to_idx", label_to_idx)
idx_to_label = {idx: label for label, idx in label_map.items()}
ordered_indices = sorted(idx_to_label.keys())
target_names = [idx_to_label[idx] for idx in ordered_indices]

model.eval()
all_targets = []
all_preds = []

with torch.no_grad():
    for data, labels, mask, _ in val_loader:
        data = data.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.long)

        logits, *_ = model(data)
        batch_size = data.size(0)
        num_class = logits.size(1)
        num_frames = logits.size(-1)
        n_cls = max(logits.size(0) // batch_size, 1)
        logits_view = logits.view(n_cls, batch_size, num_class, num_frames)
        logits_last = logits_view[-1, :, :, -1]

        preds = logits_last.argmax(dim=1)

        all_targets.extend(labels.cpu().tolist())
        all_preds.extend(preds.cpu().tolist())

print(f"Evaluated {len(all_targets)} validation samples.")
report = classification_report(
    all_targets,
    all_preds,
    labels=ordered_indices,
    target_names=target_names,
    digits=4,
    zero_division=0,
)
print(report)


print(f"Macro Precision: {precision_score(all_targets, all_preds, average='macro')}")
print(f"Macro Recall: {recall_score(all_targets, all_preds, average='macro')}")
print(f"Macro F1: {f1_score(all_targets, all_preds, average='macro')}")


Evaluated 444 validation samples.
                       precision    recall  f1-score   support

          bench_press     0.8571    0.8000    0.8276        15
       clean_and_jerk     0.8636    0.9268    0.8941        41
    handstand_pushups     0.8750    0.9333    0.9032        30
    handstand_walking     1.0000    0.8519    0.9200        27
            jump_rope     0.9691    1.0000    0.9843        94
        jumping_jacks     0.9130    0.9130    0.9130        23
               lunges     0.9623    0.9273    0.9444        55
               pullup     0.9500    0.8837    0.9157        43
               pushup     0.9375    0.9677    0.9524        31
 running_on_treadmill     1.0000    0.6667    0.8000         3
                situp     0.9000    0.9000    0.9000        10
snatch_weight_lifting     0.7500    0.7500    0.7500         8
                squat     0.8974    0.9211    0.9091        38
         wall_pushups     0.9630    1.0000    0.9811        26

             accura