# Simple TF32 stability check (single-process)

This notebook avoids subprocesses and tests TF32 stability in the **current kernel** by training a small linear classifier on GPU with:
- TF32 OFF
- TF32 ON

It reports metric deltas and speed ratio.

In [20]:
from __future__ import annotations

import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch

In [21]:
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is required for TF32 testing. Start a GPU-enabled kernel.')

print('CUDA available:', torch.cuda.is_available())
print('GPU:', torch.cuda.get_device_name(0))
print('Torch:', torch.__version__)

CUDA available: True
GPU: NVIDIA GeForce RTX 4070 SUPER
Torch: 2.10.0+cu128


In [22]:
def find_repo_root(start: Path) -> Path:
    for candidate in [start, *start.parents]:
        if (candidate / 'trustME').exists():
            return candidate
    raise FileNotFoundError('Could not locate repo root')

REPO_ROOT = find_repo_root(Path.cwd())
emb_path = REPO_ROOT / 'trustME' / 'data' / 'processed' / 'imwut_tobii_finetuned_penultimate_quick' / 'avm' / 'embeddings.npz'

if emb_path.exists():
    p = np.load(emb_path, allow_pickle=False)
    X_np = p['embeddings'].astype(np.float32)
    labels = p['label_str'].astype(str)
    classes = sorted(np.unique(labels).tolist())
    c2i = {c: i for i, c in enumerate(classes)}
    y_np = np.asarray([c2i[x] for x in labels], dtype=np.int64)
    print('Using real embeddings:', emb_path)
else:
    print('Real embeddings not found, using synthetic data')
    rng = np.random.default_rng(42)
    n, d, k = 3000, 1024, 3
    X_np = rng.normal(size=(n, d)).astype(np.float32)
    y_np = rng.integers(0, k, size=n, dtype=np.int64)

print('X shape:', X_np.shape, 'y shape:', y_np.shape, 'n_classes:', int(np.max(y_np)) + 1)

Using real embeddings: /home/ppg/eyetracking/moment4ET/trustME/data/processed/imwut_tobii_finetuned_penultimate_quick/avm/embeddings.npz
X shape: (770, 1024) y shape: (770,) n_classes: 3


In [23]:
def split_indices(n: int, seed: int = 42):
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n)
    n_train = int(0.7 * n)
    n_val = int(0.15 * n)
    tr = idx[:n_train]
    va = idx[n_train:n_train+n_val]
    te = idx[n_train+n_val:]
    return tr, va, te

def run_linear_ab(tf32_enabled: bool, seed: int, epochs: int = 5, batch_size: int = 256):
    torch.manual_seed(seed)
    np.random.seed(seed)

    torch.backends.cuda.matmul.allow_tf32 = bool(tf32_enabled)
    torch.backends.cudnn.allow_tf32 = bool(tf32_enabled)

    X = torch.from_numpy(X_np).to('cuda', non_blocking=True)
    y = torch.from_numpy(y_np).to('cuda', non_blocking=True)

    tr, va, te = split_indices(X.shape[0], seed=seed)
    tr_t = torch.from_numpy(tr).to('cuda')
    va_t = torch.from_numpy(va).to('cuda')
    te_t = torch.from_numpy(te).to('cuda')

    n_classes = int(y.max().item()) + 1
    model = torch.nn.Linear(X.shape[1], n_classes).to('cuda')
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    t0 = time.perf_counter()
    model.train()
    for _ in range(epochs):
        perm = tr_t[torch.randperm(tr_t.shape[0], device='cuda')]
        for i in range(0, perm.shape[0], batch_size):
            b = perm[i:i+batch_size]
            opt.zero_grad(set_to_none=True)
            logits = model(X[b])
            loss = loss_fn(logits, y[b])
            if not torch.isfinite(loss):
                raise RuntimeError(f'Non-finite loss with tf32={tf32_enabled}')
            loss.backward()
            opt.step()
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    @torch.no_grad()
    def acc(ids):
        pred = model(X[ids]).argmax(dim=1)
        return float((pred == y[ids]).float().mean().item())

    return {
        'tf32_enabled': tf32_enabled,
        'seed': seed,
        'elapsed_sec': elapsed,
        'train_acc': acc(tr_t),
        'val_acc': acc(va_t),
        'test_acc': acc(te_t),
    }

In [24]:
SEEDS = [41, 42, 43]
rows = []
for s in SEEDS:
    rows.append(run_linear_ab(tf32_enabled=False, seed=s))
    rows.append(run_linear_ab(tf32_enabled=True, seed=s))

df = pd.DataFrame(rows).sort_values(['seed', 'tf32_enabled']).reset_index(drop=True)
df

Unnamed: 0,tf32_enabled,seed,elapsed_sec,train_acc,val_acc,test_acc
0,False,41,0.008626,0.638219,0.66087,0.672414
1,True,41,0.007372,0.638219,0.66087,0.672414
2,False,42,0.007422,0.640074,0.626087,0.663793
3,True,42,0.007672,0.640074,0.626087,0.663793
4,False,43,0.00785,0.649351,0.652174,0.637931
5,True,43,0.007338,0.649351,0.652174,0.637931


In [25]:
wide = (
    df.pivot_table(index='seed', columns='tf32_enabled', values=['val_acc', 'test_acc', 'elapsed_sec'])
    .sort_index(axis=1)
)
wide.columns = ['_'.join([str(c) for c in col]) for col in wide.columns]
wide['delta_val_acc'] = wide['val_acc_True'] - wide['val_acc_False']
wide['delta_test_acc'] = wide['test_acc_True'] - wide['test_acc_False']
wide['speedup_ratio'] = wide['elapsed_sec_False'] / wide['elapsed_sec_True']
wide

Unnamed: 0_level_0,elapsed_sec_False,elapsed_sec_True,test_acc_False,test_acc_True,val_acc_False,val_acc_True,delta_val_acc,delta_test_acc,speedup_ratio
seed,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
41,0.008626,0.007372,0.672414,0.672414,0.66087,0.66087,0.0,0.0,1.170041
42,0.007422,0.007672,0.663793,0.663793,0.626087,0.626087,0.0,0.0,0.967372
43,0.00785,0.007338,0.637931,0.637931,0.652174,0.652174,0.0,0.0,1.069672


In [26]:
summary = {
    'mean_delta_val_acc': float(wide['delta_val_acc'].mean()),
    'mean_delta_test_acc': float(wide['delta_test_acc'].mean()),
    'mean_speedup_ratio': float(wide['speedup_ratio'].mean()),
}
summary

{'mean_delta_val_acc': 0.0,
 'mean_delta_test_acc': 0.0,
 'mean_speedup_ratio': 1.0690282197910863}

## Interpretation
Treat TF32 as stable if:
- no NaN/inf occurred,
- accuracy deltas are small,
- and speedup ratio is >= 1.0.