## Define Function and its Symmetry

In [1]:
import random
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
RANDOM_SEED = 41

from constants import DTYPE, LABEL_FONT_SIZE

random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

%load_ext autoreload
%autoreload 2

N_RUNS = 3
_random_seeds = [random.randint(0, 10000) for _ in range(N_RUNS)]

## Generate Dataset

In [2]:
N_SAMPLES = 100_000
n_features = 2
LAMBDA = 1
RHO = 1

SYMMETRY_DIM = 1

cost_matrix = torch.tensor([[LAMBDA, 0], [0, RHO]], dtype=DTYPE)

def f(state, cost_matrix=cost_matrix):
    """Evaluates f:M \rightarrow N for a batch of states s and a cost_matrix.
    Args:
        state: torch.Tensor of shape (batch_size, n_features)
    Returns:
        torch.Tensor of shape (batch_size,)
    """
    return torch.einsum('bi,ij,jb->b', state, cost_matrix, state.T)

p = torch.randn(N_SAMPLES, n_features, dtype=DTYPE)
n = f(p)

## Differential Symmetry Discovery

### Step 1: Learn Basis of Kernel at each Point

In [3]:
KERNEL_DIM=1
EPSILON_BALL = 0.05
EPSILON_LEVEL_SET = 0.005

from src.learning.symmetry_discovery.differential.kernel_pprox import pointwise_kernel_approx

kernel_bases = pointwise_kernel_approx(p=p, n=n, kernel_dim=KERNEL_DIM, epsilon_ball=EPSILON_BALL, epsilon_level_set=EPSILON_LEVEL_SET)

Compute Kernel Samples...: 100000it [00:02, 37972.65it/s]
Compute Point-Wise Bases...: 100%|██████████| 100000/100000 [00:06<00:00, 14536.25it/s]
INFO:root:Computed kernel bases from:
  - multiple tangent vectors for 66.66% of samples (good)
  - one tangent vector for 13.72% of samples (okay)
  - no tangent vector for 19.62% of samples (not good, no basis).


## Compare Both Approaches

In [13]:
diff_func_generators = []
N_STEPS = 30_000

In [4]:
from src.experiments.diff_vs_func.compare_generators import DiffFuncGenerator

g_init_all = torch.rand((1,n_features,n_features), dtype=DTYPE)
g_0 = torch.nn.Parameter(g_init_all)
g_oracle = torch.tensor([[0, -1], [1, 0]], dtype=DTYPE).unsqueeze(0)

diff_func_generator = DiffFuncGenerator(
    g_0=g_0,
    p=p,
    bases=kernel_bases,
    func=f,
    batch_size=128,
    n_steps = 25_000,
    g_oracle=g_oracle
)
diff_func_generator.optimize()

Learning Differential and Functional Generator:   0%|          | 64/25000 [00:00<02:35, 160.61it/s, Diff. loss=190.81, Func. loss=190.97]

In [161]:
diff_func_generator.g_0_diff

Parameter containing:
tensor([[[ 0.0002, -0.0225],
         [ 0.0224,  0.0002]]], requires_grad=True)

In [162]:
diff_func_generator.g_0_func

Parameter containing:
tensor([[[ 3.5174e-05, -2.5424e-01],
         [ 2.5431e-01,  1.7940e-05]]], requires_grad=True)

In [None]:
from src.experiments.diff_vs_func.compare_generators import DiffFuncGenerator

for idx_run in range(N_RUNS):
    print(f"Run {idx_run+1} out of {N_RUNS}")

    g_init_all = torch.rand((1,n_features,n_features), dtype=DTYPE)
    g_0 = torch.nn.Parameter(g_init_all)
    g_oracle = torch.tensor([[0, -1], [1, 0]], dtype=DTYPE).unsqueeze(0)

    diff_func_generator = DiffFuncGenerator(
        g_0=g_0,
        p=p,
        bases=kernel_bases,
        func=f,
        batch_size=128,
        n_steps = 25_000,
        g_oracle=g_oracle
    )
    diff_func_generator.optimize()
    diff_func_generators.append(diff_func_generator)