
# Symmetric Function Embedding (PyTorch)

This notebook trains a tiny encoder–decoder MLP on a **rotationally symmetric** target function $f(\mathbf{x}) = \sin(|\mathbf{x}|)$.

It includes a low-dimensional **embedding bottleneck** you can analyze. Each run writes its outputs (alongside the TensorBoard logs) to a fresh directory `runs/symmetry_run_<N>/`:
- `model.pt` – the trained model (state dict + metadata)
- `embeddings_grid.npz` – a polar grid of inputs with the model's embeddings and predictions

> Tip: tweak `cfg.embed_dim` and the target function to explore different symmetries.


In [13]:
# --------------------
# Setup & Config
# --------------------
import math
import os
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

print("Torch:", torch.__version__)

# Device preference: Apple MPS first, then CUDA, and CPU as the fallback.
# `is_built` is looked up defensively and treated as True when the attribute is absent.
_mps = getattr(torch.backends, "mps", None)
_use_mps = (
    _mps is not None
    and _mps.is_available()
    and getattr(_mps, 'is_built', lambda: True)()
)
if _use_mps:
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# Probe runs/symmetry_run_0, runs/symmetry_run_1, ... and keep the first path
# that does not exist yet, so every run gets its own output directory.
_run_path_number = 0
_run_path = os.path.join('runs', f'symmetry_run_{_run_path_number}')
while os.path.exists(_run_path):
    _run_path_number += 1
    _run_path = os.path.join('runs', f'symmetry_run_{_run_path_number}')

@dataclass
class Config:
    """Hyperparameters, sampling settings, and output locations for one run.

    The defaults for ``device``, ``model_path`` and ``emb_grid_path`` are read
    from the module-level ``device`` and ``_run_path`` values at the moment this
    class is defined.
    """
    seed: int = 42  # passed to set_seed()
    n_train: int = 50_000  # training samples drawn per epoch
    n_val: int = 5_000  # size of the pre-sampled validation set
    batch_size: int = 256  # training batch size
    lr: float = 1e-3  # optimizer learning rate
    epochs: int = 25  # number of training epochs
    hidden: int = 64  # hidden-layer width of the model
    embed_dim: int = 4   # embedding space size to analyze
    data_range: float = 3.0  # sample x,y uniformly in [-R, R]
    device: torch.device = device  # device the model and batches are moved to
    model_path: str = os.path.join(_run_path, "model.pt")  # where the trained model is saved
    emb_grid_path: str = os.path.join(_run_path, "embeddings_grid.npz")  # where the embedding grid is saved

# Single shared configuration instance used by the rest of the notebook.
cfg = Config()

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG, and PyTorch with ``seed``.

    When CUDA is available, all CUDA generators are seeded as well and cuDNN is
    switched to deterministic mode with benchmarking turned off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every RNG up front, before the validation data and model weights are created.
set_seed(cfg.seed)

# TensorBoard writer logging into this run's directory; a short summary of the
# configuration is stored as text so it shows up next to the curves.
writer = SummaryWriter(log_dir=_run_path)
writer.add_text(
    'config',
    str({
        'hidden': cfg.hidden,
        'embed_dim': cfg.embed_dim,
        'data_range': cfg.data_range,
        'seed': cfg.seed,
    }),
)

# Running count of optimizer steps; the training loop uses it as the step index
# for per-batch TensorBoard entries.
global_step = 0

Torch: 2.8.0
Device: mps


In [14]:
# --------------------
# Data: rotational symmetry (online training)
# Target: f(x, y) = sin(sqrt(x^2 + y^2))
# --------------------
class RandomXYDataset(torch.utils.data.IterableDataset):
    """Stream of freshly sampled training pairs for the target sin(||(x, y)||).

    Each pass over the dataset draws ``n`` new points uniformly from [-R, R]^2
    with NumPy's global RNG and yields ``(point, target)`` tensor pairs, where
    ``point`` is float32 of shape (2,) and ``target`` is float32 of shape (1,).
    """
    def __init__(self, n, R):
        super().__init__()
        self.n, self.R = int(n), float(R)

    def __iter__(self):
        # Draw the whole epoch at once: points first, then their targets.
        half_width = self.R
        points = np.random.uniform(-half_width, half_width, size=(self.n, 2)).astype(np.float32)
        radius = np.linalg.norm(points, axis=1, keepdims=True).astype(np.float32)
        targets = np.sin(radius).astype(np.float32)

        # Visit the samples in a random order.
        order = np.arange(self.n)
        np.random.shuffle(order)
        for point, target in zip(points[order], targets[order]):
            yield torch.from_numpy(point), torch.from_numpy(target)

# The validation set is sampled once, so every epoch is scored on the same points.
x_val = np.random.uniform(
    -cfg.data_range, cfg.data_range, size=(cfg.n_val, 2)
).astype(np.float32)
r_val = np.linalg.norm(x_val, axis=1, keepdims=True).astype(np.float32)
y_val = np.sin(r_val).astype(np.float32)

# Training data is re-sampled on every pass; the iterable dataset does its own
# shuffling, so the loader is created with shuffle=False.
train_dataset = RandomXYDataset(cfg.n_train, cfg.data_range)
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=False,
    drop_last=True,
)

val_dataset = TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val))
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)

# Report the dataset sizes as a quick sanity check.
print(f"Training: online sampling, {cfg.n_train} samples/epoch; Validation: {len(x_val)} samples")

Training: online sampling, 50000 samples/epoch; Validation: 5000 samples


In [15]:

# --------------------
# Model: Encoder -> Embedding -> Decoder
# --------------------
class SymmetricToyNet(nn.Module):
    """Encoder -> linear bottleneck -> decoder MLP with one scalar output per input.

    Args:
        in_dim: number of input features.
        hidden: width of the hidden layers in the encoder and decoder.
        embed_dim: size of the bottleneck embedding.

    ``forward`` returns ``(out, z)``: the prediction and the bottleneck output
    ``z`` (taken before the decoder's leading GELU).
    """

    def __init__(self, in_dim=2, hidden=64, embed_dim=2):
        super().__init__()

        # Encoder: two Linear + GELU stages.
        self.enc = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(),
            nn.Linear(hidden, hidden), nn.GELU(),
        )

        # Bottleneck: plain linear projection to the embedding space.
        self.embed = nn.Linear(hidden, embed_dim)

        # Decoder: activation on the embedding, one hidden layer, scalar head.
        self.dec = nn.Sequential(
            nn.GELU(),
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
        )

        # Re-draw every Linear layer's parameters: Kaiming-uniform weights
        # (a = sqrt(5)) and biases uniform in +/- 1/sqrt(fan_in). For a Linear
        # layer fan_in equals in_features.
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
            if layer.bias is None:
                continue
            limit = 1 / math.sqrt(layer.in_features)
            nn.init.uniform_(layer.bias, -limit, limit)

    def forward(self, x):
        z = self.embed(self.enc(x))
        return self.dec(z), z

# Build the network with the configured widths and move it to the selected device.
model = SymmetricToyNet(hidden=cfg.hidden, embed_dim=cfg.embed_dim).to(cfg.device)
model  # bare last expression so the notebook displays the module structure

SymmetricToyNet(
  (enc): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): GELU(approximate='none')
  )
  (embed): Linear(in_features=64, out_features=4, bias=True)
  (dec): Sequential(
    (0): GELU(approximate='none')
    (1): Linear(in_features=4, out_features=64, bias=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [None]:
# --------------------
# Training
# --------------------
# AdamW over all model parameters; the learning rate comes from the config and
# the weight decay is fixed at 1e-4.
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=1e-4)
# Mean-squared error between predictions and targets.
loss_fn = nn.MSELoss()

def evaluate():
    """Return the average of ``loss_fn`` over ``val_loader``.

    Each batch's loss is weighted by its number of samples, so the result is a
    per-sample mean. Uses the notebook-level ``model``, ``val_loader``,
    ``loss_fn`` and ``cfg``; the model is left in eval mode.
    """
    model.eval()
    weighted_sum = 0.0
    seen = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(cfg.device)
            targets = targets.to(cfg.device)
            outputs, _ = model(inputs)
            batch_loss = loss_fn(outputs, targets)
            weighted_sum += batch_loss.item() * inputs.size(0)
            seen += inputs.size(0)
    return weighted_sum / seen

# Lowest validation MSE seen so far. It is only reported at the end; no
# checkpoint of the corresponding weights is kept.
best_val = float("inf")
for epoch in range(1, cfg.epochs + 1):
    model.train()  # evaluate() leaves the model in eval mode, so switch back each epoch
    running_loss = 0.0  # sum over the epoch of (batch loss * batch size)
    batch_count = 0     # NOTE: despite the name, this counts samples, not batches
    for xb, yb in train_loader:
        xb = xb.to(cfg.device)      # [B,2]
        yb = yb.to(cfg.device)      # [B,1]
        pred, z = model(xb)
        loss = loss_fn(pred, yb)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        # Read the scalar loss once per step: every .item() call copies the
        # value from the device to the host, so avoid doing it twice.
        loss_value = loss.item()
        n_in_batch = xb.size(0)

        global_step += 1
        running_loss += loss_value * n_in_batch
        batch_count += n_in_batch
        writer.add_scalar('train/batch_loss', loss_value, global_step)
        # Log the distribution of embedding values every 100 optimizer steps.
        if global_step % 100 == 0:
            writer.add_histogram('embed/batch', z.detach().cpu().numpy(), global_step)

    # Per-sample mean training loss for the epoch (0.0 if no batch was produced).
    epoch_loss = running_loss / batch_count if batch_count else 0.0
    writer.add_scalar('train/epoch_loss', epoch_loss, epoch)

    val_loss = evaluate()
    best_val = min(best_val, val_loss)

    writer.add_scalar('val/epoch_loss', val_loss, epoch)
    for i, param_group in enumerate(opt.param_groups):
        writer.add_scalar(f'lr/group_{i}', param_group.get('lr', 0.0), epoch)

    # Print on the first epoch, every fifth epoch, and the last epoch.
    if epoch % 5 == 0 or epoch == 1 or epoch == cfg.epochs:
        print(f"Epoch {epoch:02d}/{cfg.epochs} | val MSE: {val_loss:.6f}")

print(f"Best val MSE: {best_val:.6f}")

[[0.76660204 0.397386   1.4832675  0.7541363 ]
 [1.5072536  2.7673707  0.5099053  0.3163967 ]
 [0.74721813 0.38037446 1.404191   0.7703118 ]
 ...
 [0.78408617 0.7962326  1.3488846  0.75545484]
 [0.72979206 0.01133008 1.4043319  0.73903817]
 [0.90415114 1.1851424  1.6042862  0.7739054 ]]
Epoch 01/25 | val MSE: 0.003577
[[2.3480918  2.5691285  2.056397   0.97382075]
 [1.0164411  0.47309586 1.3806063  0.9433281 ]
 [2.2531257  2.803103   1.8752717  1.0170593 ]
 ...
 [2.591055   2.3546014  1.5087079  1.0067024 ]
 [1.5683738  1.2432367  1.5308281  1.0022707 ]
 [2.0241208  1.7844038  1.5353068  1.0841305 ]]
Epoch 01/25 | val MSE: 0.003577
[[2.3480918  2.5691285  2.056397   0.97382075]
 [1.0164411  0.47309586 1.3806063  0.9433281 ]
 [2.2531257  2.803103   1.8752717  1.0170593 ]
 ...
 [2.591055   2.3546014  1.5087079  1.0067024 ]
 [1.5683738  1.2432367  1.5308281  1.0022707 ]
 [2.0241208  1.7844038  1.5353068  1.0841305 ]]
[[2.8172402  2.8726158  2.4462614  1.3752297 ]
 [3.1162019  3.6734495  2

In [None]:
# --------------------
# Save model + handy metadata
# --------------------
# Metadata stored next to the weights so the file describes how the model was
# built and trained.
_model_meta = {
    "hidden": cfg.hidden,
    "embed_dim": cfg.embed_dim,
    "target_fn": "f(x,y) = sin(||[x,y]||)",
    "data_range": cfg.data_range,
    "normalization": "none",
    "seed": cfg.seed,
}

# Checkpoint layout: weights, the metadata above, and the model class name.
payload = {
    "state_dict": model.state_dict(),
    "config": _model_meta,
    "model_class": "SymmetricToyNet",
}

torch.save(payload, cfg.model_path)
print(f"Saved model to: {os.path.abspath(cfg.model_path)}")

Saved model to: /Users/noah-everett/Documents/Research/Embedding-Analysis/runs/symmetry_run_1/model.pt


In [None]:
# --------------------
# Save embeddings for a polar grid (for downstream analysis)
# --------------------
model.eval()

# Polar sampling grid: 80 radii from 0 to data_range * sqrt(2), crossed with
# 128 equally spaced angles in [0, 2*pi).
radii = np.linspace(0.0, cfg.data_range * math.sqrt(2), 80).astype(np.float32)
thetas = np.linspace(0.0, 2 * math.pi, 128, endpoint=False).astype(np.float32)
R, T = np.meshgrid(radii, thetas, indexing='ij')
radii_flat = R.reshape(-1)

# Convert to Cartesian coordinates, one row per grid point.
X = (R * np.cos(T)).reshape(-1)
Y = (R * np.sin(T)).reshape(-1)
XY = np.stack([X, Y], axis=1).astype(np.float32)

# One forward pass over the whole grid, without gradient tracking.
with torch.no_grad():
    xb = torch.from_numpy(XY).to(cfg.device)
    pred, z = model(xb)
    pred = pred.squeeze(-1).cpu().numpy()
    Z = z.cpu().numpy()

# TensorBoard logging is best-effort: a failure here is reported but does not
# prevent the .npz file below from being written.
try:
    point_labels = [str(float(r)) for r in radii_flat]
    writer.add_embedding(
        torch.from_numpy(Z),
        metadata=point_labels,
        global_step=0,
        tag='embeddings/grid',
    )
    # A scatter plot is only logged when the embedding is two-dimensional.
    if cfg.embed_dim == 2:
        fig, ax = plt.subplots(figsize=(6, 6))
        sc = ax.scatter(Z[:, 0], Z[:, 1], c=radii_flat, s=5, cmap='viridis')
        ax.set_title('Embeddings (colored by radius)')
        ax.set_xlabel('z0')
        ax.set_ylabel('z1')
        plt.colorbar(sc, ax=ax, label='radius')
        writer.add_figure('embeddings/2d_scatter', fig)
        plt.close(fig)
except Exception as e:
    print('Warning: failed to log embeddings to TensorBoard:', e)

# Persist everything needed for downstream math/plots.
np.savez_compressed(
    cfg.emb_grid_path,
    xy=XY,                 # [N,2] positions in input space
    z=Z,                   # [N,embed_dim] embeddings
    pred=pred,             # [N] predicted f(x)
    radii=radii_flat,      # [N] radius used to generate XY
    thetas=T.reshape(-1),  # [N] angle used to generate XY
)
print(f"Saved embedding grid to: {os.path.abspath(cfg.emb_grid_path)}")

writer.close()

Saved embedding grid to: /Users/noah-everett/Documents/Research/Embedding-Analysis/runs/symmetry_run_1/embeddings_grid.npz
