# Pressure2Pose: Multi-Architecture Training & Comparison

This notebook trains and evaluates **5 model architectures** for predicting SMPL body parameters from plantar pressure sequences:

| Model | Temporal Method | Description |
|-------|----------------|-------------|
| A | CNN Baseline | Single frame, no temporal context |
| B | CNN + BiGRU | Bidirectional GRU (h=256, 2 layers) |
| C | CNN + BiLSTM | Bidirectional LSTM (h=256, 2 layers) |
| D | CNN + TCN | Dilated Conv1d (d=1,2,4,8) |
| E | CNN + Transformer | TransformerEncoder (d=512, 8 heads, 4 layers) |

## 1. Setup & Configuration

In [2]:
import sys
from pathlib import Path

# Put the directory above the current working directory at the front of the
# import path so the project packages imported below (`models`, `datasets`,
# `utils`) can be resolved.
PROJECT_ROOT = Path('.').resolve().parent
if sys.path.count(str(PROJECT_ROOT)) == 0:
    sys.path.insert(0, str(PROJECT_ROOT))

import os
# NOTE(review): presumably a workaround for a duplicate OpenMP runtime being
# loaded by the numeric libraries — confirm it is still required.
# It is set here, before `torch` is imported below.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time
import copy

from models import build_model, MODEL_REGISTRY
from models.pressure_to_smpl import SMPLLoss
from datasets import PressureSequenceDataset, create_sequence_dataloaders
from utils.metrics import (
    compute_mpjpe, compute_pa_mpjpe,
    compute_vertex_error, compute_bone_length_error
)

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Device: {device}')
print(f'Available models: {list(MODEL_REGISTRY.keys())}')

Device: cuda
Available models: ['cnn_baseline', 'cnn_bigru', 'cnn_bilstm', 'cnn_tcn', 'cnn_transformer']


In [None]:
# ---- Paths ----
DATA_ROOT = PROJECT_ROOT / 'data'
SMPL_PATH = PROJECT_ROOT.joinpath(
    'smpl_models', 'SMPL_python_v.1.1.0', 'SMPL_python_v.1.1.0', 'smpl', 'models'
)
CHECKPOINT_DIR = PROJECT_ROOT / 'checkpoints'
# Created here so that saving a checkpoint later never hits a missing folder.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# ---- Hyper-parameters ----
SEQ_LEN = 32          # frames per temporal window
BATCH_SIZE = 16
EPOCHS = 80
LR = 1e-4
WEIGHT_DECAY = 1e-4
PATIENCE = 20         # epochs without a better val MPJPE before early stopping
GRAD_CLIP = 1.0       # gradient-norm clip; only passed to the RNN models
WARMUP_EPOCHS = 5     # LR warm-up epochs; only passed to the Transformer

# Pressure sensor grid per foot (rows, columns)
PRESSURE_H = 33
PRESSURE_W = 15

# ---- Data split ----
# walking1_physics.pkl is the only recording available, so training and
# validation are both drawn from it; the 80/20 split itself is done in the
# data-loading cell.
TRAIN_SEQUENCES = ['walking1']
VAL_SEQUENCES = ['walking1']  # same file as training; split handled below

## 2. Data Loading & Visualization

In [None]:
import pickle

# ---- Build train / val split from a single sequence ----
csv_file = DATA_ROOT / 'walking1_cleaned.csv'
pkl_file = DATA_ROOT / 'smpl_params' / 'walking1_physics.pkl'

print(f'CSV: {csv_file} (exists={csv_file.exists()})')
print(f'PKL: {pkl_file} (exists={pkl_file.exists()})')

# Load the full sequence dataset once
full_ds = PressureSequenceDataset(
    [csv_file], [pkl_file],
    seq_len=SEQ_LEN, stride=1,
    pressure_shape=(2, PRESSURE_H, PRESSURE_W),
    normalize=True
)

# Chronological 80/20 split.
# Windows are built with stride=1, so neighbouring windows share almost all of
# their frames. Shuffling windows into train/val would therefore put nearly
# identical samples on both sides and make the validation metrics optimistic.
# Instead the first 80% of the windows train the model and the tail validates
# it, with SEQ_LEN - 1 windows dropped in between.
# NOTE(review): the gap assumes window k covers SEQ_LEN consecutive frames
# offset by k (stride=1), so windows SEQ_LEN apart share no frame — confirm
# against PressureSequenceDataset.
n_total = len(full_ds)
n_train = int(0.8 * n_total)
split_gap = max(SEQ_LEN - 1, 0)
val_start = min(n_train + split_gap, n_total)
n_val = n_total - val_start

if n_train == 0 or n_val == 0:
    raise ValueError(
        f'Sequence too short for a leak-free split: {n_total} windows, '
        f'seq_len={SEQ_LEN} (train={n_train}, val={n_val})'
    )

train_ds = torch.utils.data.Subset(full_ds, list(range(n_train)))
val_ds = torch.utils.data.Subset(full_ds, list(range(val_start, n_total)))

# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine just produces a warning.
use_pinned_memory = device.type == 'cuda'

train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
    pin_memory=use_pinned_memory)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
    pin_memory=use_pinned_memory)

print(f'\nTrain: {n_train} windows  |  Val: {n_val} windows')
print(f'Train batches: {len(train_loader)}  |  Val batches: {len(val_loader)}')
print(f'Windows dropped between train and val: {val_start - n_train}')

In [None]:
# ---- Visualize a sample pressure heatmap ----
# Show both feet for the middle frame of the first window in the dataset.
sample = full_ds[0]
pressure_seq = sample['pressure']  # (T, 2, H, W)
center = SEQ_LEN // 2

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for i, title in enumerate(['Left Foot', 'Right Foot']):
    ax = axes[i]
    foot_frame = pressure_seq[center, i].numpy()
    im = ax.imshow(foot_frame, cmap='hot', interpolation='bilinear', aspect='auto')
    ax.set_title(f'{title} Pressure (center frame)', fontsize=12)
    ax.set_xlabel('Sensor Column')
    ax.set_ylabel('Sensor Row')
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

print(f'Pressure sequence shape: {pressure_seq.shape}')  # (T, 2, H, W)

# One "name=shape" entry per SMPL target tensor stored in the sample.
smpl_dims = ', '.join(
    f'{key}={sample[key].shape}'
    for key in ('betas', 'body_pose', 'global_orient', 'transl')
)
print(f'SMPL target dims: {smpl_dims}')

## 3. Model Instantiation & Parameter Counts

In [None]:
# One entry per architecture, keyed by the display name used in logs, plots and
# tables. Each value is the config dict handed to `build_model`; its
# 'model' -> 'type' field selects the architecture.
MODEL_CONFIGS = {
    'A: CNN Baseline': {
        'model': {
            'type': 'cnn_baseline',
            'feature_dim': 512,
        },
    },
    'B: CNN + BiGRU': {
        'model': {
            'type': 'cnn_bigru',
            'feature_dim': 512,
            'hidden_dim': 256,
            'num_layers': 2,
        },
    },
    'C: CNN + BiLSTM': {
        'model': {
            'type': 'cnn_bilstm',
            'feature_dim': 512,
            'hidden_dim': 256,
            'num_layers': 2,
        },
    },
    'D: CNN + TCN': {
        'model': {
            'type': 'cnn_tcn',
            'feature_dim': 512,
            'num_blocks': 4,
            'kernel_size': 3,
        },
    },
    'E: CNN + Transformer': {
        'model': {
            'type': 'cnn_transformer',
            'feature_dim': 512,
            'nhead': 8,
            'num_layers': 4,
            'dim_feedforward': 1024,
        },
    },
}

def count_params(model):
    """Return the number of trainable scalars in `model`.

    Only parameters with `requires_grad=True` are counted, so frozen weights
    do not contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Build each architecture once just to tabulate how many trainable parameters it has.
param_counts = {
    name: count_params(build_model(cfg))
    for name, cfg in MODEL_CONFIGS.items()
}

rows = [
    {'Model': name, 'Parameters': n, 'Params (M)': f'{n/1e6:.2f}M'}
    for name, n in param_counts.items()
]

param_df = pd.DataFrame(rows)
print(param_df.to_string(index=False))

In [None]:
# Quick forward-pass sanity check
# Feeds a random batch of two windows through every architecture and prints the
# shapes of the four outputs (betas, pose, orient, transl).
dummy_seq = torch.randn(2, SEQ_LEN, 2, PRESSURE_H, PRESSURE_W).to(device)

for name, cfg in MODEL_CONFIGS.items():
    m = build_model(cfg).to(device)
    # This is only a shape check: switch to inference behaviour and skip
    # autograd so no graph is built and no training-mode layer statistics are
    # updated.
    m.eval()
    with torch.no_grad():
        betas, pose, orient, transl = m(dummy_seq)
    print(f'{name:25s}  betas={betas.shape}  pose={pose.shape}  '
          f'orient={orient.shape}  transl={transl.shape}')
    del m

print('All models passed forward-pass check.')

## 4. Unified Training Function

In [None]:
import smplx

# SMPL body model, used to turn predicted and target parameters into joints
# and vertices for the loss and for the metrics.
smpl_layer = smplx.SMPL(
    model_path=str(SMPL_PATH), gender='neutral', batch_size=1, create_transl=False
)
smpl_layer = smpl_layer.to(device)

# Weights of the individual loss terms.
criterion = SMPLLoss(
    lambda_pose=0.001,
    lambda_betas=0.01,
    lambda_vertices=0.5,
    lambda_joints=1.0,
)

print('SMPL layer and loss function ready.')

In [None]:
def _smpl_batch_forward(model, batch, smpl_layer, criterion, device):
    """Run one batch through the network, the SMPL layer and the loss.

    Shared by the training and validation loops so both score a batch in
    exactly the same way.

    Returns:
        (loss, pred_out, tgt_out): the scalar loss from `criterion` and the
        SMPL outputs for the predicted and the ground-truth parameters.
    """
    pressure = batch['pressure'].to(device)
    targets = {
        key: batch[key].to(device)
        for key in ('betas', 'body_pose', 'global_orient', 'transl')
    }

    betas, body_pose, global_orient, transl = model(pressure)

    # The layer was built with batch_size=1; keep its attribute in step with the
    # actual batch, as the original notebook did.
    smpl_layer.batch_size = pressure.shape[0]
    pred_out = smpl_layer(
        betas=betas, body_pose=body_pose,
        global_orient=global_orient, transl=transl
    )

    # Targets are constants: never track gradients through their SMPL pass.
    with torch.no_grad():
        tgt_out = smpl_layer(**targets)

    target_dict = {'joints': tgt_out.joints, 'vertices': tgt_out.vertices}
    _, loss = criterion(
        pred_out, target_dict,
        pred_params=(betas, body_pose, global_orient, transl)
    )
    return loss, pred_out, tgt_out


def train_model(model_name, config, train_loader, val_loader,
                smpl_layer, criterion, device,
                epochs=EPOCHS, lr=LR, patience=PATIENCE,
                grad_clip=None, warmup_epochs=0):
    """
    Train a single model with early stopping on validation MPJPE.

    The model is built from `config`, optimised with Adam and a StepLR schedule
    (halved every 20 scheduler steps), and the weights of the epoch with the
    lowest validation MPJPE are restored and saved to
    `CHECKPOINT_DIR/<model type>_best.pth`.

    Args:
        model_name: Display name used in the log output.
        config: Config dict passed to `build_model`; `config['model']['type']`
            also names the checkpoint file.
        train_loader, val_loader: Loaders yielding dicts with the keys
            'pressure', 'betas', 'body_pose', 'global_orient' and 'transl'.
        smpl_layer: SMPL layer used for both predictions and targets.
        criterion: Loss callable returning a pair whose second item is the
            scalar loss.
        device: Device the model and batches are moved to.
        epochs: Maximum number of epochs.
        lr: Base learning rate.
        patience: Epochs without a better val MPJPE before stopping.
        grad_clip: Max gradient norm, or None / 0 to disable clipping.
        warmup_epochs: Number of initial epochs with a linearly increasing
            learning rate; the scheduler is not stepped during them.

    Returns:
        model: Best model (by val MPJPE)
        history: dict with train_loss, val_loss, val_mpjpe lists (plain floats,
            one entry per epoch)

    Raises:
        ValueError: If either loader yields no batches.
    """
    # Fail early with a clear message instead of a ZeroDivisionError after the
    # first epoch.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError('train_loader and val_loader must each yield at least one batch')

    model = build_model(config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    history = {'train_loss': [], 'val_loss': [], 'val_mpjpe': []}
    best_mpjpe = float('inf')
    best_state = None
    wait = 0

    print(f'\n{"="*60}')
    print(f'Training: {model_name}')
    print(f'{"="*60}')

    for epoch in range(1, epochs + 1):
        # ---- LR warmup for Transformer ----
        # Ramps linearly up to `lr`, reaching it on the last warm-up epoch.
        if warmup_epochs > 0 and epoch <= warmup_epochs:
            warmup_lr = lr * epoch / warmup_epochs
            for pg in optimizer.param_groups:
                pg['lr'] = warmup_lr

        # ---- Train ----
        model.train()
        epoch_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            loss, _, _ = _smpl_batch_forward(model, batch, smpl_layer, criterion, device)

            loss.backward()
            if grad_clip:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

            epoch_loss += loss.item()

        epoch_loss /= len(train_loader)

        # ---- Validate ----
        model.eval()
        val_loss_sum = 0.0
        mpjpe_list = []

        with torch.no_grad():
            for batch in val_loader:
                loss, pred_out, tgt_out = _smpl_batch_forward(
                    model, batch, smpl_layer, criterion, device)
                val_loss_sum += loss.item()

                # float() accepts Python/NumPy scalars and 0-dim tensors alike.
                mpjpe_list.append(float(compute_mpjpe(pred_out.joints, tgt_out.joints)))

        # Both averages are means over batches (a smaller final batch counts as
        # much as a full one), matching the original notebook.
        val_loss = val_loss_sum / len(val_loader)
        # Plain float so the history and the checkpoint hold no NumPy scalars.
        val_mpjpe = float(np.mean(mpjpe_list))

        history['train_loss'].append(epoch_loss)
        history['val_loss'].append(val_loss)
        history['val_mpjpe'].append(val_mpjpe)

        # Don't step the scheduler during warmup.
        if epoch > warmup_epochs:
            scheduler.step()

        # ---- Early stopping ----
        if val_mpjpe < best_mpjpe:
            best_mpjpe = val_mpjpe
            best_state = copy.deepcopy(model.state_dict())
            wait = 0
        else:
            wait += 1

        if epoch % 10 == 0 or epoch == 1 or wait == 0:
            cur_lr = optimizer.param_groups[0]['lr']
            print(f'  Epoch {epoch:3d}  train_loss={epoch_loss:.4f}  '
                  f'val_loss={val_loss:.4f}  val_MPJPE={val_mpjpe:.1f}mm  '
                  f'lr={cur_lr:.1e}  {"*best" if wait==0 else ""}')

        if wait >= patience:
            print(f'  Early stopping at epoch {epoch} (patience={patience})')
            break

    # Restore best weights
    if best_state is not None:
        model.load_state_dict(best_state)

    # Save checkpoint
    ckpt_name = config['model']['type'] + '_best.pth'
    ckpt_path = CHECKPOINT_DIR / ckpt_name
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
        'best_mpjpe': best_mpjpe,
    }, ckpt_path)
    print(f'  Saved best model to {ckpt_path}  (MPJPE={best_mpjpe:.1f}mm)')

    return model, history

## 5. Train All 5 Models

In [None]:
trained_models = {}
histories = {}

for name, cfg in MODEL_CONFIGS.items():
    model_type = cfg['model']['type']
    is_rnn = model_type in ('cnn_bigru', 'cnn_bilstm')
    is_transformer = model_type == 'cnn_transformer'

    # Re-seed before every run so the comparison is reproducible and each
    # architecture starts from the same RNG state (weight initialisation and
    # batch shuffling draw from these generators).
    torch.manual_seed(42)
    np.random.seed(42)

    model, hist = train_model(
        model_name=name,
        config=cfg,
        train_loader=train_loader,
        val_loader=val_loader,
        smpl_layer=smpl_layer,
        criterion=criterion,
        device=device,
        epochs=EPOCHS,
        lr=LR,
        patience=PATIENCE,
        # Gradient clipping is only enabled for the recurrent models and LR
        # warm-up only for the Transformer.
        grad_clip=GRAD_CLIP if is_rnn else None,
        warmup_epochs=WARMUP_EPOCHS if is_transformer else 0,
    )

    trained_models[name] = model
    histories[name] = hist

In [None]:
# ---- Plot loss curves ----
# One panel per tracked quantity, one line per architecture.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for name, hist in histories.items():
    axes[0].plot(hist['train_loss'], label=name)
    axes[1].plot(hist['val_loss'], label=name)
    axes[2].plot(hist['val_mpjpe'], label=name)

axes[0].set_title('Train Loss'); axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Loss')
axes[1].set_title('Val Loss');   axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('Loss')
axes[2].set_title('Val MPJPE (mm)'); axes[2].set_xlabel('Epoch'); axes[2].set_ylabel('MPJPE (mm)')

for ax in axes:
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

plt.tight_layout()

# The output folder is otherwise only created by the later results cell; make
# sure it exists here so a fresh Restart & Run All does not fail on savefig.
loss_curve_path = PROJECT_ROOT / 'output' / 'loss_curves.png'
loss_curve_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(str(loss_curve_path), dpi=150, bbox_inches='tight')
plt.show()

## 6. Evaluation & Comparison Table

In [None]:
def evaluate_model(model, val_loader, smpl_layer, device):
    """Compute all metrics on the validation set.

    Args:
        model: Trained network; called on the pressure batch and expected to
            return (betas, body_pose, global_orient, transl).
        val_loader: Loader yielding dicts with the keys 'pressure', 'betas',
            'body_pose', 'global_orient' and 'transl'.
        smpl_layer: SMPL layer used for both predictions and targets.
        device: Device the batches are moved to.

    Returns:
        dict mapping metric name to its mean over validation batches:
        'MPJPE (mm)', 'PA-MPJPE (mm)', 'Vertex Err (mm)', 'Bone Len Err (mm)'
        and 'Inference (ms)' (network forward time per sample, SMPL excluded).
    """
    model.eval()
    mpjpe_all, pa_mpjpe_all, ve_all, ble_all = [], [], [], []
    timings = []

    # CUDA kernels run asynchronously: without a synchronize around the timed
    # region the clock would only measure how long it takes to queue the work.
    sync_cuda = torch.device(device).type == 'cuda' and torch.cuda.is_available()

    with torch.no_grad():
        for batch in val_loader:
            pressure = batch['pressure'].to(device)
            target_betas = batch['betas'].to(device)
            target_pose = batch['body_pose'].to(device)
            target_orient = batch['global_orient'].to(device)
            target_transl = batch['transl'].to(device)

            if sync_cuda:
                # Finish pending transfers so they are not billed to the model.
                torch.cuda.synchronize(device)
            t0 = time.perf_counter()
            betas, body_pose, global_orient, transl = model(pressure)
            if sync_cuda:
                # Wait for the forward pass to actually complete.
                torch.cuda.synchronize(device)
            # Milliseconds per sample for this batch.
            elapsed = (time.perf_counter() - t0) * 1000 / pressure.shape[0]
            timings.append(elapsed)

            bs = pressure.shape[0]
            smpl_layer.batch_size = bs

            pred_out = smpl_layer(
                betas=betas, body_pose=body_pose,
                global_orient=global_orient, transl=transl)
            tgt_out = smpl_layer(
                betas=target_betas, body_pose=target_pose,
                global_orient=target_orient, transl=target_transl)

            mpjpe_all.append(compute_mpjpe(pred_out.joints, tgt_out.joints))
            pa_mpjpe_all.append(compute_pa_mpjpe(pred_out.joints, tgt_out.joints))
            ve_all.append(compute_vertex_error(pred_out.vertices, tgt_out.vertices))
            ble_all.append(compute_bone_length_error(pred_out.joints, tgt_out.joints))

    return {
        'MPJPE (mm)': np.mean(mpjpe_all),
        'PA-MPJPE (mm)': np.mean(pa_mpjpe_all),
        'Vertex Err (mm)': np.mean(ve_all),
        'Bone Len Err (mm)': np.mean(ble_all),
        'Inference (ms)': np.mean(timings),
    }

In [None]:
# Evaluate every trained model and collect one pre-formatted row per model.
results_rows = []
for name, model in trained_models.items():
    metrics = evaluate_model(model, val_loader, smpl_layer, device)
    results_rows.append({
        'Model': name,
        'Params (M)': f'{count_params(model)/1e6:.2f}',
        **{metric: f'{value:.1f}' for metric, value in metrics.items()},
    })

results_df = pd.DataFrame(results_rows)

# Banner, table and closing rule, one per line.
print(
    '\n' + '=' * 90,
    'MODEL COMPARISON',
    '=' * 90,
    results_df.to_string(index=False),
    '=' * 90,
    sep='\n',
)

# Save to CSV
output_dir = PROJECT_ROOT / 'output'
output_dir.mkdir(parents=True, exist_ok=True)
comparison_csv = output_dir / 'model_comparison.csv'
results_df.to_csv(comparison_csv, index=False)
print(f'\nSaved to {comparison_csv}')

## 7. Showcase Visualization

In [None]:
# Pick the model with the lowest validation MPJPE in the comparison table and
# show its prediction for one validation window next to the ground truth.
mpjpe_column = results_df['MPJPE (mm)'].astype(float)  # stored as formatted strings
best_name = results_df.iloc[mpjpe_column.idxmin()]['Model']
best_model = trained_models[best_name]
print(f'Best model: {best_name}')

# First validation sample, with a batch dimension added to every tensor.
sample = val_ds[0]
pressure = sample['pressure'].unsqueeze(0).to(device)
gt_params = {
    key: sample[key].unsqueeze(0).to(device)
    for key in ('betas', 'body_pose', 'global_orient', 'transl')
}

best_model.eval()
with torch.no_grad():
    betas, body_pose, global_orient, transl = best_model(pressure)
    smpl_layer.batch_size = 1
    pred_out = smpl_layer(
        betas=betas, body_pose=body_pose,
        global_orient=global_orient, transl=transl
    )
    # Ground truth
    tgt_out = smpl_layer(**gt_params)

pred_joints = pred_out.joints[0].cpu().numpy()
gt_joints = tgt_out.joints[0].cpu().numpy()

# Side-by-side 3D scatter of the first 24 joints. The joints' Y coordinate is
# drawn on the plot's vertical axis and Z on its depth axis.
fig = plt.figure(figsize=(12, 5))
panels = [
    (121, gt_joints, 'blue', 'Ground Truth Joints'),
    (122, pred_joints, 'red', 'Predicted Joints'),
]
for position, joints, colour, title in panels:
    ax = fig.add_subplot(position, projection='3d')
    ax.scatter(joints[:24, 0], joints[:24, 2], joints[:24, 1], c=colour, s=30)
    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Z')
    ax.set_zlabel('Y')

plt.tight_layout()
plt.show()

mpjpe_val = compute_mpjpe(pred_out.joints, tgt_out.joints)
print(f'Sample MPJPE: {mpjpe_val:.1f} mm')

In [None]:
# Closing summary: where the comparison table and the checkpoints were written.
print(
    'Training and evaluation complete!',
    f'\nResults saved to: {output_dir}',
    f'Checkpoints saved to: {CHECKPOINT_DIR}',
    sep='\n',
)