# 3DDFA Evaluation on 300W-LP

This notebook evaluates a 3DDFA model (outputs 3DMM parameters) on the **300W-LP** dataset using **ONNX Runtime (GPU)**. Metrics included:

- NME (Normalized Mean Error) — primary metric, normalized by inter-ocular distance
- RMSE (pixel-space)
- MAE (pixel-space)
- FPS (inference speed)
- Chamfer Distance (optional, if dense ground-truth meshes are available)

**Before running:** place your ONNX model, BFM `.pkl`, and the 300W-LP dataset in accessible paths and update the configuration cells below.

In [None]:
# Install (uncomment and run if needed)
# !pip install onnxruntime-gpu==1.15.1 numpy scipy matplotlib tqdm imageio opencv-python trimesh==3.21.11
# If you don't have a GPU (or cannot install onnxruntime-gpu), install the CPU build instead: !pip install onnxruntime


In [None]:
# Configuration: edit these paths and flags before running the cells below.
ONNX_MODEL_PATH = 'models/3ddfa.onnx'   # path to your 3DDFA ONNX model that outputs 3DMM params
BFM_PKL_PATH = 'bfm/bfm_noneck_v3.pkl' # path to BFM pkl used by your model (matching training)
# The example loader in this notebook looks for `annotations/*.pts` and `images/<stem>.jpg` under this root.
DATASET_ROOT = 'datasets/300W_LP'      # root dir for 300W-LP images and annotations
USE_GPU = True                         # set False to use CPU provider
# NOTE(review): NUM_WORKERS is not referenced by any cell visible in this notebook - confirm it is still needed.
NUM_WORKERS = 4


## Imports and utility functions

In [None]:
import os, time, math, pickle
from pathlib import Path
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

# ONNX runtime
import onnxruntime as ort

# Optional dependency: record in HAS_TRIMESH whether trimesh could be imported,
# so later cells can check the flag instead of importing it again.
try:
    import trimesh
except Exception:
    HAS_TRIMESH = False
    print('trimesh not available - Chamfer will be skipped if requested')
else:
    HAS_TRIMESH = True


## Helpers: load BFM (3DMM) and reconstruct vertices from parameters

In [None]:
def load_bfm(bfm_pkl_path):
    """Load the 3DMM (BFM) arrays needed for reconstruction from a pickle file.

    Args:
        bfm_pkl_path: Path to a pickled dict holding at least the keys
            ``'u'``, ``'w_shp'`` and ``'w_exp'``. ``'tri'`` and
            ``'keypoints'`` are optional.

    Returns:
        dict with keys ``'u'``, ``'w_shp'``, ``'w_exp'`` (cast to float32),
        plus ``'tri'`` and ``'keypoints'`` passed through unchanged
        (``None`` when absent from the pickle).

    Raises:
        KeyError: if one of the required keys is missing from the pickle.
    """
    required = ('u', 'w_shp', 'w_exp')

    # SECURITY: unpickling can execute arbitrary code - only load BFM files
    # that come from a source you trust.
    with open(bfm_pkl_path, 'rb') as f:
        # latin1 lets Python 3 read pickles that were written by Python 2.
        bfm = pickle.load(f, encoding='latin1')

    # Fail with a readable message instead of "'NoneType' has no attribute 'astype'".
    missing = [key for key in required if bfm.get(key) is None]
    if missing:
        raise KeyError(
            f'BFM pickle {str(bfm_pkl_path)!r} is missing required key(s): {missing}'
        )

    model = {key: np.asarray(bfm[key]).astype(np.float32) for key in required}
    model['tri'] = bfm.get('tri')
    model['keypoints'] = bfm.get('keypoints', None)
    return model

def _as_vertex_basis(basis, n_coeff, label):
    """Return *basis* as an (N, 3, k) array and check that k equals *n_coeff*."""
    basis = np.asarray(basis)
    if basis.ndim == 2:
        # Flattened (3N, k) storage: regroup rows into per-vertex triples using
        # the basis's own column count, so a coefficient-count mismatch cannot
        # silently scramble the data.
        basis = basis.reshape(-1, 3, basis.shape[1])
    if basis.shape[-1] != n_coeff:
        raise ValueError(
            f'{label} basis has {basis.shape[-1]} components but '
            f'{n_coeff} coefficients were given'
        )
    return basis


def reconstruct_vertices(bfm, alpha_shp, alpha_exp, R=None, offset=None):
    """Reconstruct dense vertices from 3DMM parameters.

    Args:
        bfm: dict with ``'u'`` (mean shape), ``'w_shp'`` and ``'w_exp'``
            (shape / expression bases). Each may be stored per vertex
            (``(N, 3)`` / ``(N, 3, k)``) or flattened (``(3N,)``, ``(3N, 1)`` /
            ``(3N, k)``); flattened data is regrouped into consecutive triples.
        alpha_shp: (shape_dim,) or (shape_dim,1)
        alpha_exp: (exp_dim,) or (exp_dim,1)
        R: optional (3, 3) matrix, applied as ``verts @ R.T``.
        offset: optional 3-vector added to every vertex.

    Returns:
        (N, 3) float32 array of vertices.

    Raises:
        ValueError: if the number of coefficients does not match the number
            of components in the corresponding basis.
    """
    alpha_shp = np.asarray(alpha_shp).reshape(-1)
    alpha_exp = np.asarray(alpha_exp).reshape(-1)

    # The mean shape must be (N, 3) like the deltas below; a flattened 'u'
    # would otherwise fail to broadcast against them. A no-op for (N, 3) input.
    mean_shape = np.asarray(bfm['u']).reshape(-1, 3)
    w_shp = _as_vertex_basis(bfm['w_shp'], alpha_shp.shape[0], 'shape')
    w_exp = _as_vertex_basis(bfm['w_exp'], alpha_exp.shape[0], 'expression')

    delta_shp = np.tensordot(w_shp, alpha_shp, axes=([2], [0]))
    delta_exp = np.tensordot(w_exp, alpha_exp, axes=([2], [0]))
    verts = mean_shape + delta_shp + delta_exp

    if R is not None:
        verts = verts @ np.asarray(R).T
    if offset is not None:
        verts = verts + np.asarray(offset).reshape(1, 3)
    return verts.astype(np.float32)


## ONNX Runtime: create session (GPU if available) and inference wrapper

In [None]:
def make_ort_session(onnx_path, use_gpu=True):
    """Create an ONNX Runtime inference session for the model at *onnx_path*.

    When ``use_gpu`` is true the CUDA provider is requested first with the CPU
    provider listed after it; otherwise only the CPU provider is requested.
    The providers reported by the created session are printed, and the
    session is returned.
    """
    if use_gpu:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        providers = ['CPUExecutionProvider']

    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    session = ort.InferenceSession(onnx_path, options, providers=providers)
    print('ONNX session providers:', session.get_providers())
    return session

def run_inference(sess, img, input_size=(224, 224)):
    """Run the model on a single image and return the raw ONNX outputs.

    The preprocessing is a placeholder (resize, scale to [0, 1], HWC -> NCHW
    with a leading batch axis). Update it and the output parsing for your model.

    Args:
        sess: ONNX Runtime session; the image is fed to its first input.
        img: H x W x C image array.
        input_size: ``(width, height)`` handed to ``cv2.resize``. Defaults to
            ``(224, 224)``; set it to whatever resolution your model expects.

    Returns:
        The list of output arrays produced by ``sess.run``.
    """
    input_name = sess.get_inputs()[0].name
    resized = cv2.resize(img, tuple(input_size))
    # NOTE(review): no BGR->RGB swap or mean/std normalisation is applied here -
    # confirm this matches how the model was trained.
    blob = resized.astype(np.float32) / 255.0
    # HWC -> 1xCxHxW; made contiguous because transpose only returns a view.
    blob = np.ascontiguousarray(np.transpose(blob, (2, 0, 1))[None, ...])
    return sess.run(None, {input_name: blob})


## Metric functions: NME, RMSE, MAE, Chamfer (optional)

In [None]:
def inter_ocular_distance(gt_landmarks):
    """Return the Euclidean distance between landmarks 36 and 45.

    Presumably these are the outer eye corners of a 68-point layout - verify
    against your annotation scheme.
    """
    eye_a, eye_b = gt_landmarks[36], gt_landmarks[45]
    return np.linalg.norm(eye_a - eye_b)

def compute_nme(pred, gt, normalize_by='interocular'):
    """Mean point-to-point error over the x/y columns, divided by a reference length.

    With ``normalize_by='interocular'`` the reference length comes from
    ``inter_ocular_distance`` on the ground-truth points; any other value uses
    the diagonal of the ground-truth bounding box.
    """
    pred_xy, gt_xy = pred[:, :2], gt[:, :2]

    if normalize_by == 'interocular':
        ref_len = inter_ocular_distance(gt_xy)
    else:
        bbox_extent = gt_xy.max(axis=0) - gt_xy.min(axis=0)
        ref_len = np.linalg.norm(bbox_extent)

    per_point_err = np.linalg.norm(pred_xy - gt_xy, axis=1)
    return np.mean(per_point_err) / ref_len

def compute_rmse(pred, gt):
    """Square root of the mean squared point-to-point distance (x/y columns only)."""
    diff = pred[:, :2] - gt[:, :2]
    sq_dist = np.square(diff).sum(axis=1)
    return np.sqrt(sq_dist.mean())

def compute_mae(pred, gt):
    """Mean absolute difference over every x and y coordinate."""
    abs_err = np.abs(pred[:, :2] - gt[:, :2])
    return abs_err.mean()

def chamfer_distance(pc1, pc2):
    """Symmetric Chamfer distance between two point clouds.

    The result is the mean nearest-neighbour distance from every point of
    ``pc2`` to ``pc1`` plus the mean nearest-neighbour distance from every
    point of ``pc1`` to ``pc2`` (plain Euclidean distances, not squared).

    Only SciPy is needed for this computation, so it does not depend on
    trimesh being installed.

    Args:
        pc1: (N, D) array-like of points.
        pc2: (M, D) array-like of points.

    Returns:
        The Chamfer distance as a float.

    Raises:
        ValueError: if either point cloud is empty.
    """
    from scipy.spatial import cKDTree

    pc1 = np.asarray(pc1, dtype=np.float64)
    pc2 = np.asarray(pc2, dtype=np.float64)
    if pc1.size == 0 or pc2.size == 0:
        # An empty cloud would otherwise yield inf/NaN instead of an error.
        raise ValueError('chamfer_distance requires two non-empty point clouds')

    dist_2_to_1, _ = cKDTree(pc1).query(pc2)
    dist_1_to_2, _ = cKDTree(pc2).query(pc1)
    return np.mean(dist_2_to_1) + np.mean(dist_1_to_2)


## Dataset loader (300W-LP) — adjust to your annotation format

This cell shows an example loader assuming annotations contain 2D landmarks per image. Edit as needed to match your dataset.

In [None]:
def _read_pts_file(path):
    """Parse one landmark file into an (N, D) float64 array.

    Accepts plain whitespace-separated numeric rows, and also the ``.pts``
    layout that wraps the rows in ``{`` / ``}`` after ``key: value`` header
    lines such as ``version: 1`` and ``n_points: 68``.
    """
    rows = []
    with open(path, 'r') as fh:
        for raw in fh:
            line = raw.split('#', 1)[0].strip()  # drop '#' comments
            if not line or line in ('{', '}') or ':' in line:
                continue  # blank line, brace, or header line
            rows.append([float(tok) for tok in line.split()])
    return np.array(rows, dtype=np.float64)


def load_300w_lp_annotations(root_dir):
    """Collect (image path, landmarks) pairs from ``root_dir``.

    Every ``annotations/*.pts`` file is paired with ``images/<stem>.jpg``.
    Files are visited in sorted order so that runs, and any cap on the number
    of samples applied by the caller, are reproducible.

    Args:
        root_dir: Dataset root containing ``annotations/`` and ``images/``.

    Returns:
        list of dicts with keys ``'image'`` (path string; existence is not
        checked here) and ``'landmarks'`` ((N, D) float array).
        Annotation files that cannot be parsed are skipped, and a one-line
        summary of how many were skipped is printed.
    """
    root = Path(root_dir)
    ann_dir = root / 'annotations'
    img_dir = root / 'images'

    samples = []
    skipped = []
    for ann_fp in sorted(ann_dir.glob('*.pts')):
        try:
            pts = _read_pts_file(ann_fp)
        except (OSError, ValueError) as err:
            skipped.append((ann_fp.name, str(err)))
            continue
        if pts.ndim != 2 or pts.shape[0] == 0:
            skipped.append((ann_fp.name, 'no landmark rows found'))
            continue
        img_fp = img_dir / (ann_fp.stem + '.jpg')
        samples.append({'image': str(img_fp), 'landmarks': pts})

    if skipped:
        # Report rather than hide parse failures: an all-skipped dataset would
        # otherwise just look like an empty result.
        name, reason = skipped[0]
        print(f'Skipped {len(skipped)} annotation file(s) that could not be parsed '
              f'(first: {name}: {reason})')
    return samples


## Evaluation loop
This cell runs inference on the dataset and computes metrics. It also measures FPS.

In [None]:
def evaluate(onnx_path, bfm_pkl, dataset_root, use_gpu=True, max_samples=None):
    """Run the model over the dataset and aggregate landmark metrics and FPS.

    Args:
        onnx_path: Path to the ONNX model.
        bfm_pkl: Path to the BFM pickle passed to ``load_bfm``.
        dataset_root: Dataset root passed to ``load_300w_lp_annotations``.
        use_gpu: Forwarded to ``make_ort_session``.
        max_samples: If truthy, only the first ``max_samples`` dataset entries
            are visited (unreadable images among them still count).

    Returns:
        ``(results, nmes)``: ``results`` is a dict with ``NME_mean``,
        ``NME_median``, ``RMSE_mean``, ``MAE_mean``, ``FPS`` (each ``None``
        when nothing was measured) and ``N_samples``. ``nmes`` is the list of
        per-sample NME values.
    """
    sess = make_ort_session(onnx_path, use_gpu)
    bfm = load_bfm(bfm_pkl)
    samples = load_300w_lp_annotations(dataset_root)
    if max_samples:
        # Slice up front so the progress bar total reflects the cap.
        samples = samples[:max(int(max_samples), 0)]

    # Loop-invariant: convert the landmark indices once, not per sample.
    keypoints = bfm.get('keypoints')
    kps = np.array(keypoints).astype(int) if keypoints is not None else None

    nmes, rmses, maes, times = [], [], [], []

    for s in tqdm(samples):
        img = cv2.imread(s['image'])
        if img is None:
            continue  # unreadable or missing image

        # perf_counter is monotonic and high resolution, unlike time.time().
        t0 = time.perf_counter()
        outputs = run_inference(sess, img)
        times.append(time.perf_counter() - t0)

        # Parse outputs depending on your ONNX model outputs
        # Example assume outputs = [pose_vec, alpha_shp, alpha_exp]
        try:
            # NOTE(review): pose_vec is parsed but never applied, so the
            # vertices below are not transformed by it. Confirm that the
            # predictions and the ground-truth landmarks share a coordinate
            # frame before trusting the metrics.
            pose_vec = outputs[0].reshape(-1)
            alpha_shp = outputs[1].reshape(-1)
            alpha_exp = outputs[2].reshape(-1)
        except Exception:
            print('Model outputs:', [o.name for o in sess.get_outputs()])
            raise

        verts = reconstruct_vertices(bfm, alpha_shp, alpha_exp)

        if 'landmarks' in s:
            gt = s['landmarks']
            if kps is not None:
                # NOTE(review): this treats 'keypoints' as vertex row indices -
                # confirm against how your BFM pickle stores them.
                pred_landmarks_2d = verts[kps][:, :2]
            else:
                pred_landmarks_2d = verts[:len(gt), :2]

            nmes.append(compute_nme(pred_landmarks_2d, gt))
            rmses.append(compute_rmse(pred_landmarks_2d, gt))
            maes.append(compute_mae(pred_landmarks_2d, gt))

    results = {
        'NME_mean': float(np.mean(nmes)) if nmes else None,
        'NME_median': float(np.median(nmes)) if nmes else None,
        'RMSE_mean': float(np.mean(rmses)) if rmses else None,
        'MAE_mean': float(np.mean(maes)) if maes else None,
        # Timed span covers run_inference as a whole, preprocessing included.
        'FPS': float(1.0 / np.mean(times)) if times else None,
        'N_samples': len(nmes),
    }
    return results, nmes


In [None]:
# Example run (uncomment to execute):
# results, nmEs = evaluate(ONNX_MODEL_PATH, BFM_PKL_PATH, DATASET_ROOT, use_gpu=USE_GPU, max_samples=200)
# print(results)


## Visualization helpers
Plot the NME CDF (an example-overlay helper is not included in the cell below).

In [None]:
def plot_nme_cdf(nmes, ax=None):
    """Draw the empirical CDF of the given NME values and return the axes.

    A new figure and axes are created when ``ax`` is not supplied.
    """
    if ax is None:
        _, ax = plt.subplots()

    sorted_nme = np.sort(np.array(nmes))
    count = len(sorted_nme)
    # Fraction of samples whose NME is <= each sorted value.
    fraction = np.arange(1, count + 1) / count

    ax.plot(sorted_nme, fraction)
    ax.set_xlabel('NME')
    ax.set_ylabel('CDF')
    ax.grid(True)
    return ax

# Example usage:
# ax = plot_nme_cdf(nmEs)


## Notes and Interpretation

- **NME**: primary metric. Lower is better. Use inter-ocular normalization.
- **RMSE/MAE**: raw pixel errors — useful for intuition.
- **FPS**: shows inference speed; measure on target hardware.
- **Chamfer**: if you have dense GT meshes, Chamfer measures 3D similarity.

Interpretation guidance:

- Report mean and median NME; median is robust to outliers.
- Provide CDF plot of NME (x-axis NME, y-axis proportion ≤ that NME). Many papers show the curve.
- Compare FPS separately (model accuracy vs speed tradeoff).

Edit dataset loading and model input/output parsing to match your exact ONNX model and dataset format.