# General Function Embedding (PyTorch)

This notebook learns an arbitrary target function f: R^{in_dim} -> R^{out_dim} through a compact bottleneck (latent) representation.

What you can customize quickly:
- Input size (in_dim), output size (out_dim), and latent size (embed_dim)
- Target function via a simple Python callable
- Model depth/width and activations
- Sampling domain and dataset sizes

Artifacts saved per run:
- model.pt – trained weights + config metadata
- embeddings_probe.npz – random probe inputs with embeddings and predictions

Tip: Edit `target_fn` in the setup cell to learn any function you like, and keep `out_dim` in `Config` equal to the width of the function's output.

In [None]:
# --------------------
# Setup & Config
# --------------------
import math
import os
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from lib import (
    BottleneckMLP,
    OnlineFunctionDataset,
    make_fixed_val_dataset,
    train_model,
    evaluate,
    get_device,
    set_seed,
    next_run_path,
)

# Show the installed PyTorch version and the device returned by get_device().
# `device` is also used below as the default for Config.device.
print("Torch:", torch.__version__)
device = get_device()
print("Device:", device)

# Create a unique run folder (keeps your old results intact)
# NOTE(review): the prefix still says 'symmetry' although the notebook now targets
# arbitrary functions — confirm whether it should be renamed.
_run_path = next_run_path(base_dir='runs', prefix='symmetry_run_')

@dataclass
class Config:
    """Settings for a single run: function/data shape, model size, training, and output paths.

    The defaults for `device`, `model_path`, and `probe_path` are computed from the
    module-level `device` and `_run_path` at the moment this class body executes.
    """

    # Data/function
    in_dim: int = 2           # input width; passed to both datasets and the model
    out_dim: int = 1          # output width; passed to the model
    embed_dim: int = 4        # size of the embedding (bottleneck) layer
    data_range: float = 3.0   # sample x in [-R, R]^in_dim

    # Model
    hidden: int = 64          # `hidden` argument of BottleneckMLP
    enc_layers: int = 2       # `enc_layers` argument of BottleneckMLP
    dec_layers: int = 2       # `dec_layers` argument of BottleneckMLP

    # Training
    seed: int = 42            # passed to set_seed() right after the config is built
    n_train: int = 50_000     # `n` of the online training dataset (reported as samples/epoch)
    n_val: int = 5_000        # size of the validation dataset
    batch_size: int = 256     # batch size of the training DataLoader
    lr: float = 1e-3          # forwarded to train_model
    weight_decay: float = 1e-4  # forwarded to train_model
    epochs: int = 25          # forwarded to train_model

    # IO
    device: torch.device = device  # where the model is moved and where training runs
    model_path: str = os.path.join(_run_path, "model.pt")  # torch.save destination
    probe_path: str = os.path.join(_run_path, "embeddings_probe.npz")  # probe .npz destination

cfg = Config()
set_seed(cfg.seed)

# Define your target function here. It must map [N, in_dim] -> [N, out_dim].
# Example: f(x) = sin(||x||) returning a single output regardless of in_dim.
def target_fn(x: torch.Tensor) -> torch.Tensor:
    """Radially symmetric example target: f(x) = sin(||x||_2).

    Args:
        x: Inputs whose last axis holds the in_dim coordinates, typically
            shaped [N, in_dim]. Extra leading batch axes (or none) are accepted.

    Returns:
        Tensor with the same leading shape as `x` and a trailing axis of size 1,
        i.e. [N, 1] for [N, in_dim] input.
    """
    # Reduce over the last axis rather than a fixed axis 1: identical for
    # [N, in_dim] input, and also correct for [in_dim] or [B, N, in_dim].
    radius = torch.linalg.norm(x, dim=-1, keepdim=True)
    return torch.sin(radius)

# TensorBoard writer rooted at this run's folder; the selected hyperparameters are
# logged once as text under the 'config' tag.
writer = SummaryWriter(log_dir=_run_path)
writer.add_text(
    'config',
    str(
        {
            name: getattr(cfg, name)
            for name in (
                'in_dim',
                'out_dim',
                'embed_dim',
                'hidden',
                'enc_layers',
                'dec_layers',
                'data_range',
                'seed',
            )
        }
    ),
)


Torch: 2.8.0
Device: mps


In [None]:
# --------------------
# Data: training set from OnlineFunctionDataset, validation set from make_fixed_val_dataset
# --------------------
# Arguments common to both datasets: input width, the function to fit, sampling range.
fn_spec = dict(in_dim=cfg.in_dim, target_fn=target_fn, data_range=cfg.data_range)

train_dataset = OnlineFunctionDataset(n=cfg.n_train, **fn_spec)
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=False,
    drop_last=True,  # a trailing partial batch is discarded
)

val_dataset = make_fixed_val_dataset(n_val=cfg.n_val, **fn_spec)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)

print(f"Training: online sampling, {cfg.n_train} samples/epoch; Validation: {cfg.n_val} samples")


Training: online sampling, 50000 samples/epoch; Validation: 5000 samples


In [None]:
# --------------------
# Model: encoder -> embedding bottleneck -> decoder, sized from the config
# --------------------
arch_kwargs = {
    field: getattr(cfg, field)
    for field in ("in_dim", "out_dim", "embed_dim", "hidden", "enc_layers", "dec_layers")
}
model = BottleneckMLP(**arch_kwargs).to(cfg.device)
model  # last expression: display the module structure

SymmetricToyNet(
  (enc): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): GELU(approximate='none')
  )
  (embed): Linear(in_features=64, out_features=4, bias=True)
  (dec): Sequential(
    (0): GELU(approximate='none')
    (1): Linear(in_features=4, out_features=64, bias=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [None]:
# --------------------
# Fit the model
# --------------------
# Optimisation and logging settings handed to train_model as keyword arguments.
train_kwargs = dict(
    epochs=cfg.epochs,
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
    writer=writer,
    device=cfg.device,
    log_hist_every=100,
)
result = train_model(model, train_loader, val_loader, **train_kwargs)

# The value returned by train_model is indexed by 'best_val' for the summary line.
print(f"Best val MSE: {result['best_val']:.6f}")


[[0.76660204 0.397386   1.4832675  0.7541363 ]
 [1.5072536  2.7673707  0.5099053  0.3163967 ]
 [0.74721813 0.38037446 1.404191   0.7703118 ]
 ...
 [0.78408617 0.7962326  1.3488846  0.75545484]
 [0.72979206 0.01133008 1.4043319  0.73903817]
 [0.90415114 1.1851424  1.6042862  0.7739054 ]]
Epoch 01/25 | val MSE: 0.003577
[[2.3480918  2.5691285  2.056397   0.97382075]
 [1.0164411  0.47309586 1.3806063  0.9433281 ]
 [2.2531257  2.803103   1.8752717  1.0170593 ]
 ...
 [2.591055   2.3546014  1.5087079  1.0067024 ]
 [1.5683738  1.2432367  1.5308281  1.0022707 ]
 [2.0241208  1.7844038  1.5353068  1.0841305 ]]
Epoch 01/25 | val MSE: 0.003577
[[2.3480918  2.5691285  2.056397   0.97382075]
 [1.0164411  0.47309586 1.3806063  0.9433281 ]
 [2.2531257  2.803103   1.8752717  1.0170593 ]
 ...
 [2.591055   2.3546014  1.5087079  1.0067024 ]
 [1.5683738  1.2432367  1.5308281  1.0022707 ]
 [2.0241208  1.7844038  1.5353068  1.0841305 ]]
[[2.8172402  2.8726158  2.4462614  1.3752297 ]
 [3.1162019  3.6734495  2

In [None]:
# --------------------
# Persist the weights together with the settings used to build the model
# --------------------
payload = {
    "state_dict": model.state_dict(),
    "config": {
        # Config fields copied verbatim, in this order.
        **{
            key: getattr(cfg, key)
            for key in (
                "in_dim",
                "out_dim",
                "hidden",
                "embed_dim",
                "enc_layers",
                "dec_layers",
                "data_range",
                "seed",
            )
        },
        # Only the function's name is stored (or its str() if it has no __name__).
        "target_fn": getattr(target_fn, "__name__", str(target_fn)),
    },
    "model_class": "BottleneckMLP",
}
torch.save(payload, cfg.model_path)
print(f"Saved model to: {os.path.abspath(cfg.model_path)}")


Saved model to: /Users/noah-everett/Documents/Research/Embedding-Analysis/runs/symmetry_run_1/model.pt


In [None]:
# --------------------
# Save probe embeddings for downstream analysis (works for any in/out/embed sizes)
# --------------------
N_probe = 10240

# Random probe inputs in the training domain. They come from a dedicated generator
# seeded with cfg.seed instead of NumPy's global RNG, so the probe set is the same
# every time this cell runs, regardless of what consumed random numbers before it.
probe_rng = np.random.default_rng(cfg.seed)
X = probe_rng.uniform(-cfg.data_range, cfg.data_range, size=(N_probe, cfg.in_dim)).astype(np.float32)

model.eval()
with torch.no_grad():
    xb = torch.from_numpy(X).to(cfg.device)
    # The model is called once on the whole probe batch; its output is unpacked
    # as (prediction, embedding).
    pred, z = model(xb)
    pred = pred.cpu().numpy()
    Z = z.cpu().numpy()

try:
    try:
        # Log embeddings to TensorBoard projector (works for any dimensionality)
        writer.add_embedding(torch.from_numpy(Z), global_step=0, tag='embeddings/probe')
        if cfg.embed_dim == 2:
            # Imported lazily: only needed for the optional 2-D scatter, and an
            # import failure is reported by the handler below instead of aborting.
            import matplotlib.pyplot as plt
            r = np.linalg.norm(X, axis=1)
            fig, ax = plt.subplots(figsize=(6, 6))
            sc = ax.scatter(Z[:, 0], Z[:, 1], c=r, s=5, cmap='viridis')
            ax.set_title('Embeddings (colored by ||x||)')
            ax.set_xlabel('z0')
            ax.set_ylabel('z1')
            plt.colorbar(sc, ax=ax, label='||x||')
            writer.add_figure('embeddings/2d_scatter', fig)
            plt.close(fig)
    except Exception as e:
        # TensorBoard logging is best-effort; the .npz below is the primary artifact.
        print('Warning: failed to log embeddings to TensorBoard:', e)

    # Save everything needed for downstream math/plots (general)
    np.savez_compressed(
        cfg.probe_path,
        x=X,                # [N,in_dim] inputs in input space
        z=Z,                # [N,embed_dim] embeddings
        pred=pred           # [N,out_dim] predicted f(x)
    )
    print(f"Saved probe embeddings to: {os.path.abspath(cfg.probe_path)}")
finally:
    # Close the writer even if saving the probe file raised.
    writer.close()


Saved embedding grid to: /Users/noah-everett/Documents/Research/Embedding-Analysis/runs/symmetry_run_1/embeddings_grid.npz
