# ST-GCN Triplet Training

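This notebook fine-tunes a pretrained ST-GCN backbone into a 128-dimensional embedding model using a triplet margin loss. The first three ST-GCN blocks are frozen. The data are split 80/20 into training and validation sets. After each epoch, the notebook reports the mean anchor–positive and anchor–negative cosine similarities on the validation set. Progress is shown with `tqdm`, and the fine-tuned weights are saved at the end.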

In [None]:

import random
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
# tqdm.auto renders the notebook widget under Jupyter and falls back to a plain
# text bar elsewhere (e.g. when the notebook is executed headlessly).
from tqdm.auto import tqdm

from utils.config_loader import load_config
from utils.stgcn_backbone import load_pretrained
# NOTE(review): a project-local `datasets` module can collide with the Hugging Face
# `datasets` package if both are importable — confirm which one resolves here.
from datasets import TripletDataset

# Pretrained ST-GCN checkpoint used to initialise the backbone.
# NOTE(review): relative path, resolved against the kernel's working directory.
STGCN_ROOT = Path('model/stgcn_ntu_init.pth')


In [None]:

cfg = load_config()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: seed the global `random` and torch RNGs (sampler order and any
# random initialisation) so Restart & Run All repeats the same run.
# NOTE(review): uses cfg.SEED when the config defines one, otherwise 0.
SEED = getattr(cfg, 'SEED', 0)
random.seed(SEED)
torch.manual_seed(SEED)

TRAIN_FRACTION = 0.8   # share of samples used for training; the rest is validation
N_FROZEN_BLOCKS = 3    # leading ST-GCN blocks whose parameters are not fine-tuned
TRIPLET_MARGIN = 0.2   # margin of the triplet loss (p=2, Euclidean distance)
WEIGHT_DECAY = 1e-5    # Adam weight decay

if not STGCN_ROOT.is_file():
    raise FileNotFoundError(f"Pretrained weights not found: {STGCN_ROOT}")

# NOTE(review): the checkpoint path is passed as both of the first two arguments;
# confirm this matches load_pretrained's signature. num_class=128 is presumably the
# embedding dimension, since the model output is used directly as the embedding.
model = load_pretrained(STGCN_ROOT, STGCN_ROOT, in_channels=16, num_class=128)

# Freeze the first N_FROZEN_BLOCKS ST-GCN blocks; the optimiser below only receives
# parameters that still require gradients.
for block in list(model.st_gcn_networks)[:N_FROZEN_BLOCKS]:
    for param in block.parameters():
        param.requires_grad = False
model.to(device)

# dataset: random train/validation split over sample indices. A dedicated RNG keeps
# the split dependent only on SEED, even if the dataset constructor consumes `random`.
ds = TripletDataset(cfg.PROCESSED_DATA_DIR, cfg)
indices = list(range(len(ds)))
random.Random(SEED).shuffle(indices)
split = int(TRAIN_FRACTION * len(indices))
train_idx = indices[:split]
val_idx = indices[split:]
assert train_idx and val_idx, (
    f'need at least one training and one validation sample, got {len(ds)} samples')

train_dl = DataLoader(ds, batch_size=cfg.TRIPLET_BATCH_SIZE,
                      sampler=torch.utils.data.SubsetRandomSampler(train_idx))
val_dl = DataLoader(ds, batch_size=cfg.TRIPLET_BATCH_SIZE,
                    sampler=torch.utils.data.SubsetRandomSampler(val_idx))

trainable_params = [param for param in model.parameters() if param.requires_grad]
opt = torch.optim.Adam(trainable_params, lr=cfg.TRIPLET_LR, weight_decay=WEIGHT_DECAY)
criterion = torch.nn.TripletMarginLoss(margin=TRIPLET_MARGIN, p=2)


In [None]:

def embed_batch(batch):
    """Return L2-normalised model embeddings for one batch of sequences.

    The batch is moved to `device` and reshaped with permute(0, 2, 1) followed by two
    trailing unsqueezes before being passed through `model`; the output is normalised
    along dim 1.
    NOTE(review): assumes batches are (N, T, C), giving (N, C, T, 1, 1) — i.e. one
    joint and one person in ST-GCN's (N, C, T, V, M) layout; confirm against
    TripletDataset.
    """
    x = batch.to(device).permute(0, 2, 1).unsqueeze(-1).unsqueeze(-1)
    return F.normalize(model(x), dim=1)


for epoch in range(cfg.TRIPLET_EPOCHS):
    # NOTE(review): model.train() also puts BatchNorm layers inside the frozen blocks
    # into training mode, so their running statistics keep updating even though their
    # weights are frozen — confirm this is intended.
    model.train()
    train_loss_sum = 0.0
    n_train = 0
    for anchor, positive, negative in tqdm(train_dl, desc=f'Train {epoch+1}/{cfg.TRIPLET_EPOCHS}'):
        opt.zero_grad()
        loss = criterion(embed_batch(anchor), embed_batch(positive), embed_batch(negative))
        loss.backward()
        opt.step()
        # The criterion averages over the batch; re-weight by batch size so the
        # epoch loss is an exact per-sample mean.
        train_loss_sum += loss.item() * anchor.size(0)
        n_train += anchor.size(0)

    # Validation: average cosine similarities per *sample*, not per batch, so a
    # smaller final batch is not over-weighted.
    model.eval()
    cos_ap_sum = 0.0
    cos_an_sum = 0.0
    n_val = 0
    with torch.no_grad():
        for anchor, positive, negative in tqdm(val_dl, desc='Validate'):
            z_a = embed_batch(anchor)
            z_p = embed_batch(positive)
            z_n = embed_batch(negative)
            # Embeddings are unit-norm, so the row-wise dot product is the cosine similarity.
            cos_ap_sum += (z_a * z_p).sum(-1).sum().item()
            cos_an_sum += (z_a * z_n).sum(-1).sum().item()
            n_val += z_a.size(0)

    train_loss = train_loss_sum / max(1, n_train)
    cos_ap = cos_ap_sum / max(1, n_val)
    cos_an = cos_an_sum / max(1, n_val)
    print(f'Epoch {epoch+1}/{cfg.TRIPLET_EPOCHS} train loss={train_loss:.4f} '
          f'val cos+={cos_ap:.3f} cos-={cos_an:.3f}')


In [None]:

# Save trained model
out_path = cfg.PROCESSED_DATA_DIR / 'stgcn_triplet_suemd.pth'
torch.save(model.state_dict(), out_path)
print(f'Saved model to {out_path}')


To export this notebook to HTML, run the following command from the repository root:

```bash
jupyter nbconvert --to html notebooks/training/triplet_training.ipynb --output-dir notebooks/exported
```
