
# Symmetric Function Embedding (PyTorch)

This notebook trains a tiny encoder–decoder MLP on a **rotationally symmetric** target function $f(\mathbf{x}) = \sin(\lVert\mathbf{x}\rVert)$, with inputs $\mathbf{x} = (x, y)$ sampled uniformly from the square $[-R, R]^2$.

It includes a low-dimensional **embedding bottleneck** you can analyze, and saves:
- `toy_symmetry_model.pt` – the trained model (state dict + metadata)
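- `embeddings_grid.npz` – a polar grid of inputs with the model's embeddings and predictions. The grid's radii run from 0 to $\sqrt{2}\,R$, out to the corners of the training square, so grid points outside $[-R, R]^2$ are extrapolations.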

> Tip: tweak `cfg.embed_dim` and the target function to explore different symmetries.


In [None]:

# --------------------
# Setup & Config
# --------------------
import math
import os
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

print("Torch:", torch.__version__)

# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

@dataclass
class Config:
    """Hyperparameters and output paths for this notebook.

    Every field has a default, and ``cfg = Config()`` below uses them all.
    Field order defines the positional ``Config(...)`` signature, so append
    new fields at the end rather than reordering existing ones.
    """

    seed: int = 42  # seed for the Python, NumPy and torch RNGs
    n_train: int = 50_000  # number of sampled training points
    n_val: int = 5_000  # number of sampled validation points
    batch_size: int = 256  # training mini-batch size (validation uses a fixed 1024)
    lr: float = 1e-3  # AdamW learning rate
    epochs: int = 25  # full passes over the training set
    hidden: int = 64  # width of the hidden layers in encoder and decoder
    embed_dim: int = 2   # width of the bottleneck embedding z to analyze
    data_range: float = 3.0  # R: inputs x, y are sampled uniformly in [-R, R]
    # Default is the module-level `device` string detected above; it is
    # evaluated once, when this class body runs.
    device: str = device
    model_path: str = "toy_symmetry_model.pt"  # where the trained checkpoint is written
    emb_grid_path: str = "embeddings_grid.npz"  # where the polar-grid embeddings are written

cfg = Config()

def set_seed(seed: int):
    """Seed every RNG the notebook uses and pin cuDNN to deterministic mode.

    Seeds Python's ``random``, NumPy's global generator, torch's CPU
    generator and all CUDA generators (in that order). It then turns on
    ``cudnn.deterministic`` and turns off ``cudnn.benchmark``
    (autotuning). Returns None.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed once, before any data is sampled or weights are initialized, so
# reruns with the same cfg.seed draw the same data and initial weights.
set_seed(cfg.seed)


In [None]:

# --------------------
# Data: rotational symmetry
# Target: f(x, y) = sin(sqrt(x^2 + y^2))
# --------------------
def sample_xy(n, R):
    """Sample ``n`` labelled points uniformly from the square ``[-R, R]^2``.

    Draws from NumPy's global RNG, so results depend on ``np.random.seed``.

    Returns
    -------
    xy : float32 array, shape (n, 2)
        The sampled input points.
    y : float32 array, shape (n, 1)
        The target ``sin(||xy||)`` per point. It is kept 2-D to match the
        model's ``[B, 1]`` output.

    Note: uniform over the square is *not* uniform over the radius, which
    is fine for training.
    """
    points = np.random.uniform(low=-R, high=R, size=(n, 2)).astype(np.float32)
    radius = np.linalg.norm(points, axis=1, keepdims=True).astype(np.float32)
    target = np.sin(radius).astype(np.float32)
    return points, target

# Draw the training set first and the validation set second, both from the
# seeded global NumPy RNG.
x_train, y_train = sample_xy(cfg.n_train, cfg.data_range)
x_val, y_val = sample_xy(cfg.n_val, cfg.data_range)

# Training: shuffled mini-batches; the ragged final batch is dropped.
train_loader = DataLoader(
    TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)),
    batch_size=cfg.batch_size,
    shuffle=True,
    drop_last=True,
)
# Validation: fixed order, large batches (evaluation needs no gradients).
val_loader = DataLoader(
    TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val)),
    batch_size=1024,
    shuffle=False,
)

# Cell output: dataset sizes.
len(x_train), len(x_val)


In [None]:

# --------------------
# Model: Encoder -> Embedding -> Decoder
# --------------------
class SymmetricToyNet(nn.Module):
    """Encoder -> linear bottleneck -> decoder MLP with a scalar output.

    ``forward(x)`` returns ``(out, z)``:
    - ``out`` is the ``[B, 1]`` prediction.
    - ``z`` is the ``[B, embed_dim]`` bottleneck embedding, taken as the raw
      output of ``self.embed`` before any nonlinearity.

    Submodule names and layer order define the state-dict keys (``enc.0``,
    ``enc.2``, ``embed``, ``dec.1``, ``dec.3``). Keep them stable so saved
    checkpoints still load.
    """

    def __init__(self, in_dim=2, hidden=64, embed_dim=2):
        super().__init__()
        # Encoder: two GELU-activated hidden layers.
        self.enc = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(),
            nn.Linear(hidden, hidden), nn.GELU(),
        )
        # Bottleneck: a plain linear map, so z is un-activated.
        self.embed = nn.Linear(hidden, embed_dim)
        # Decoder: the GELU applied to z lives here, not in the bottleneck.
        self.dec = nn.Sequential(
            nn.GELU(), nn.Linear(embed_dim, hidden),
            nn.GELU(), nn.Linear(hidden, 1),
        )
        # Re-initialise every Linear (weight, then bias) in module order.
        self.apply(self._reset_linear)

    @staticmethod
    def _reset_linear(module):
        """Re-initialise ``module`` if it is an ``nn.Linear``.

        Weights get Kaiming-uniform with ``a=sqrt(5)``; the bias gets
        ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``. This is the scheme
        PyTorch's ``nn.Linear`` documents as its default.
        """
        if not isinstance(module, nn.Linear):
            return
        nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        if module.bias is None:
            return
        fan_in = nn.init._calculate_fan_in_and_fan_out(module.weight)[0]
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(module.bias, -limit, limit)

    def forward(self, x):
        z = self.embed(self.enc(x))  # <-- embedding to analyze
        return self.dec(z), z

# Build the network from cfg and move its parameters to the chosen device.
model = SymmetricToyNet(
    hidden=cfg.hidden,
    embed_dim=cfg.embed_dim,
).to(cfg.device)
model  # cell output: the architecture summary


In [None]:

# --------------------
# Training
# --------------------
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=1e-4)
loss_fn = nn.MSELoss()


def evaluate():
    """Return the mean validation MSE of the global ``model`` over ``val_loader``.

    Each batch's mean loss is multiplied by its batch size before
    averaging, so a short final batch is weighted correctly. Runs without
    gradients and leaves ``model`` in eval mode. Raises ZeroDivisionError
    if ``val_loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    count = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(cfg.device)
            targets = targets.to(cfg.device)
            preds = model(inputs)[0]
            batch = inputs.size(0)
            loss_sum += loss_fn(preds, targets).item() * batch
            count += batch
    return loss_sum / count

# Train with AdamW and evaluate on the validation split after every epoch.
# The weights of the best epoch are snapshotted and restored at the end.
# This makes the model saved and analyzed downstream the one whose MSE is
# reported as "Best val MSE"; previously the last-epoch weights were kept
# regardless of which epoch was best.
best_val = float("inf")
best_epoch = 0
best_state = None
for epoch in range(1, cfg.epochs + 1):
    model.train()
    for xb, yb in train_loader:
        xb = xb.to(cfg.device)      # [B,2]
        yb = yb.to(cfg.device)      # [B,1]
        pred, _ = model(xb)
        loss = loss_fn(pred, yb)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

    val_loss = evaluate()
    if val_loss < best_val:
        best_val, best_epoch = val_loss, epoch
        # Clone: state_dict() returns views of the live parameters, which
        # later opt.step() calls would overwrite in place.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    if epoch % 5 == 0 or epoch == 1 or epoch == cfg.epochs:
        print(f"Epoch {epoch:02d}/{cfg.epochs} | val MSE: {val_loss:.6f}")

# No snapshot exists when there were zero epochs or every val loss was NaN.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored weights from epoch {best_epoch}")
print(f"Best val MSE: {best_val:.6f}")


In [None]:

# --------------------
# Save model + handy metadata
# --------------------
# Save the trained weights plus the metadata needed to rebuild the model.
# Tensors are copied to CPU first so the checkpoint is device-agnostic: a
# CUDA-trained file then loads on a CPU-only machine even without
# `map_location`.
payload = {
    "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
    "config": {
        "in_dim": model.enc[0].in_features,  # input width (first encoder Linear)
        "hidden": cfg.hidden,
        "embed_dim": cfg.embed_dim,
        "target_fn": "f(x,y) = sin(||[x,y]||)",
        "data_range": cfg.data_range,
        "normalization": "none",
        "seed": cfg.seed,
    },
    "model_class": "SymmetricToyNet",
}
torch.save(payload, cfg.model_path)
print(f"Saved model to: {os.path.abspath(cfg.model_path)}")


In [None]:

# --------------------
# Save embeddings for a polar grid (for downstream analysis)
# --------------------
model.eval()
with torch.no_grad():
    # Polar grid: 80 radii from 0 out to the square's corner radius
    # (cfg.data_range * sqrt(2)), times 128 angles in [0, 2*pi).
    # indexing='ij' gives R/T shape (80, 128), so every flattened array
    # below is radius-major: flat index = i_radius * 128 + i_theta.
    radii = np.linspace(0.0, cfg.data_range * math.sqrt(2), 80).astype(np.float32)
    thetas = np.linspace(0.0, 2 * math.pi, 128, endpoint=False).astype(np.float32)
    R, T = np.meshgrid(radii, thetas, indexing='ij')
    X = (R * np.cos(T)).reshape(-1)
    Y = (R * np.sin(T)).reshape(-1)
    XY = np.stack([X, Y], axis=1).astype(np.float32)

    xb = torch.from_numpy(XY).to(cfg.device)
    pred, z = model(xb)
    pred = pred.squeeze(-1).cpu().numpy()
    Z = z.cpu().numpy()

# Radii above cfg.data_range leave the training square except near the
# diagonals, so predictions and embeddings there are extrapolations.
# Flag which grid points were inside the sampled domain.
in_train_domain = np.all(np.abs(XY) <= cfg.data_range, axis=1)
# Ground-truth target on the grid (same definition as the training labels),
# for error analysis against `pred`.
target = np.sin(np.linalg.norm(XY, axis=1)).astype(np.float32)

# Save everything needed for downstream math/plots
np.savez_compressed(
    cfg.emb_grid_path,
    xy=XY,               # [N,2] positions in input space
    z=Z,                 # [N,embed_dim] embeddings
    pred=pred,           # [N] predicted f(x)
    radii=R.reshape(-1), # [N] radius used to generate XY
    thetas=T.reshape(-1),  # [N] angle used to generate XY
    target=target,       # [N] true sin(||xy||)
    in_train_domain=in_train_domain,  # [N] bool: xy inside [-data_range, data_range]^2
    grid_shape=np.array(R.shape),     # (n_radii, n_thetas) for un-flattening
)
print(f"Saved embedding grid to: {os.path.abspath(cfg.emb_grid_path)}")



## How to load and analyze later

```python
import numpy as np
import torch
import torch.nn as nn

# Recreate model and load weights
payload = torch.load('toy_symmetry_model.pt', map_location='cpu')
class SymmetricToyNet(nn.Module):
    def __init__(self, in_dim=2, hidden=payload['config']['hidden'], embed_dim=payload['config']['embed_dim']):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, hidden), nn.GELU())
        self.embed = nn.Linear(hidden, embed_dim)
        self.dec = nn.Sequential(nn.GELU(), nn.Linear(embed_dim, hidden), nn.GELU(), nn.Linear(hidden, 1))
    def forward(self, x):
        h = self.enc(x); z = self.embed(h); out = self.dec(z); return out, z

model = SymmetricToyNet()
model.load_state_dict(payload['state_dict'])
model.eval()

# Embedding grid
data = np.load('embeddings_grid.npz')
Z = data['z']  # analyze this!
```

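For symmetry probes, you can check whether `Z` is approximately invariant to rotations in input space, or measure how much the embedding collapses angular information versus radial information. The grid arrays are stored radius-major (80 radii × 128 angles, built with `indexing='ij'`). So `Zg = Z.reshape(80, 128, -1)` puts all angles for the i-th radius in `Zg[i]`, and `Zg.var(axis=1)` measures how much the embedding still depends on angle at each radius.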
