## Load replay buffer

In [6]:
import pickle
from src.utils import load_replay_buffer

TASK_NAME="sac_circle_rotation_task_0"
N_SAMPLES:int=100_000
KERNEL_DIM=1
EPSILON_BALL = 0.025
EPSILON_LEVEL_SET = 0.0025

LEARN_KERNEL_BASES: bool=True


replay_buffer_name:str=TASK_NAME+"_replay_buffer.pkl"
kernel_bases_name:str=TASK_NAME+"_kernel_bases.pkl"


replay_buffer_task_1= load_replay_buffer(replay_buffer_name, N_steps=N_SAMPLES)

## Pointwise kernel bases of reward function

- We have $R(p)=n$ where $p=(s,a)$ and $n$ is the real-valued reward value.
- Learn the component of $G$ that acts on $S$ and the component of $G$ that acts on $A$ independently.

In [7]:
from src.learning.symmetry_discovery.differential.kernel_approx import pointwise_kernel_approx


ps=replay_buffer_task_1["observations"]
ns=replay_buffer_task_1["rewards"]

print("Shape of ps: ", ps.shape, " (should be (N_steps, |S|))")
print("Shape of ns: ", ns.shape, " (should be (N_steps))")

if LEARN_KERNEL_BASES:
    kernel_samples=pointwise_kernel_approx(p=ps, n=ns, kernel_dim=KERNEL_DIM, epsilon_ball=EPSILON_BALL, epsilon_level_set=EPSILON_LEVEL_SET)
    with open(kernel_bases_name, 'wb') as f:
        pickle.dump(kernel_samples, f)
else:
    with open(kernel_bases_name, 'rb') as f:
        kernel_samples = pickle.load(f)  

INFO:root:Computing neighborhood of samples via kdtree...


Shape of ps:  torch.Size([100000, 2])  (should be (N_steps, |S|))
Shape of ns:  torch.Size([100000])  (should be (N_steps))


Locate samples in neighborhood...: 100%|██████████| 100000/100000 [00:32<00:00, 3073.26it/s]
Compute pointwise kernel samples...: 100%|██████████| 100000/100000 [00:09<00:00, 10002.81it/s]
Compute Point-Wise Bases via PCA...: 100%|██████████| 100000/100000 [00:09<00:00, 10113.73it/s]
INFO:root:Computed kernel bases from:
  - multiple tangent vectors for 92.94% of samples (good)
  - one tangent vector for 2.01% of samples (okay)
  - no tangent vector for 5.05% of samples (not good, no basis).


## Learn Grassman Subspace

In [13]:
import torch
import torch.nn.functional as F

def kernel_weights(p_query, ps, bandwidth):
    # Gaussian kernel weights between p_query and each p_i in ps
    dists = torch.norm(ps - p_query, dim=1)
    weights = torch.exp(-dists**2 / (2 * bandwidth**2))
    weights = weights / weights.sum()
    return weights

def log_map_grassmann(V_ref, V):
    # V_ref: (d, k), V: (d, k), both orthonormal
    M = V_ref.T @ V  # (k, k)
    U, S, Vt = torch.linalg.svd(M)
    # Numerical stability for arccos
    S_clamped = torch.clamp(S, -1.0 + 1e-6, 1.0 - 1e-6)
    Theta = torch.arccos(S_clamped)
    sin_Theta = torch.sin(Theta)
    sin_Theta[sin_Theta == 0] = 1e-6
    A = (V - V_ref @ M) @ torch.diag(Theta / sin_Theta)
    return A  # Tangent vector in T_{V_ref} Gr(k, d)

def exp_map_grassmann(V_ref, A):
    # A: tangent vector in T_{V_ref} Gr(k, d), shape (d, k)
    U, S, Vt = torch.linalg.svd(A, full_matrices=False)
    Theta = S
    sin_Theta = torch.sin(Theta)
    cos_Theta = torch.cos(Theta)
    term1 = V_ref @ (Vt.T @ torch.diag(cos_Theta) @ Vt)
    term2 = U @ torch.diag(sin_Theta) @ Vt
    V_new = term1 + term2
    return torch.linalg.qr(V_new, mode='reduced').Q  # Ensure orthonormal output

def smooth_subspace(p_query, ps, kernel_samples, bandwidth):
    N = ps.shape[0]
    
    # Select nearest kernel_samples indices
    available_indices = list(kernel_samples.keys())
    V_bases = [kernel_samples[i] for i in available_indices]
    ps_subset = ps[available_indices]  # (n_subset, d)

    # Kernel weights
    weights = kernel_weights(p_query, ps_subset, bandwidth)  # (n_subset,)
    
    # Reference basis (e.g., nearest neighbor)
    ref_idx = torch.argmin(torch.norm(ps_subset - p_query, dim=1)).item()
    V_ref = V_bases[ref_idx]  # (d, k)
    
    # Compute log maps and weighted sum
    tangent_sum = torch.zeros_like(V_ref)
    for i, V_i in enumerate(V_bases):
        W = log_map_grassmann(V_ref, V_i)  # (d, k)
        tangent_sum += weights[i] * W
    
    # Project back using exp map
    V_smooth = exp_map_grassmann(V_ref, tangent_sum)  # (d, k)
    return V_smooth

In [14]:
p_query=torch.tensor([0.5, 0.7])

V_smooth = smooth_subspace(p_query, ps, kernel_samples, bandwidth=0.25)  # (d, k)



  M = V_ref.T @ V  # (k, k)


RuntimeError: linalg.svd: The input tensor A must have at least 2 dimensions.

In [8]:
import torch as th
from src.learning.symmetry_discovery.differential.diff_generator import DiffGenerator

g_0=th.randn(KERNEL_DIM, ps.shape[1], ps.shape[1], requires_grad=True)
optimizer= th.optim.Adam([g_0], lr=5e-5)
N_steps=10_000

linear_kernel=DiffGenerator(g_0=g_0, p=ps, bases=kernel_samples, n_steps=N_steps, optimizer=optimizer, batch_size=256)
linear_kernel.optimize()

Loss: 2.0795e+02: 100%|██████████| 10000/10000 [00:21<00:00, 470.47it/s]


tensor([[[-1.0462, -0.0700],
         [-0.6525,  0.9691]]], requires_grad=True)

In [5]:
linear_kernel.g

tensor([[[ 0.0775,  0.2187],
         [-0.5813, -0.4310]]], requires_grad=True)