## Load Data

In [1]:
import pickle

from src.utils import load_replay_buffer
from src.learning.symmetry_discovery.differential.kernel_approx import KernelFrameEstimator

# Task name prefixes; the loader below derives "<name>_replay_buffer.pkl" and
# "<name>_kernel_bases.pkl" from each entry.
TASK_NAMES = [
    "sac_circle_rotation_task_0",
    "sac_circle_rotation_task_1",
    "sac_circle_rotation_task_2",
    "sac_circle_rotation_task_3",
]

# Replay-buffer entry used as the sample set ("observations", "actions" or "next_observations").
LOAD_WHAT: str = "next_observations"
# Forwarded to KernelFrameEstimator (and later HereditaryGeometryDiscovery) as `kernel_dim`.
KERNEL_DIM = 1
# Forwarded to load_replay_buffer as `N_steps`.
N_SAMPLES = 50_000


def load_replay_buffer_and_kernel(task_name:str, load_what:str, kernel_dim: int, n_samples:int):
    """Load one task's samples plus a frame estimator initialised with its stored kernel bases.

    Args:
        task_name: Prefix of the task's files. ``{task_name}_replay_buffer.pkl`` is handed to
            ``load_replay_buffer`` and ``{task_name}_kernel_bases.pkl`` is unpickled directly.
        load_what: Replay-buffer entry to use as samples; must be one of
            ``"observations"``, ``"actions"`` or ``"next_observations"``.
        kernel_dim: Forwarded to ``KernelFrameEstimator`` as ``kernel_dim``.
        n_samples: Forwarded to ``load_replay_buffer`` as ``N_steps``.

    Returns:
        A tuple ``(samples, estimator)``: the selected replay-buffer entry and the
        ``KernelFrameEstimator`` built on it, after ``set_frame`` was called with the
        unpickled kernel bases.

    Raises:
        AssertionError: If ``load_what`` is not one of the three supported entries.
    """
    supported_entries = ("observations", "actions", "next_observations")
    assert load_what in supported_entries, "Learn hereditary geometry for states, actions or next states."

    replay_file = f"{task_name}_replay_buffer.pkl"
    bases_file = f"{task_name}_kernel_bases.pkl"

    # NOTE(review): the recorded run printed 100000 rows although n_samples=50_000 was
    # passed - confirm how load_replay_buffer interprets N_steps.
    samples = load_replay_buffer(replay_file, N_steps=n_samples)[load_what]
    print(f"Loaded {load_what} from {replay_file} with shape {samples.shape}")

    # Build the estimator on the samples first, then attach the precomputed kernel bases.
    estimator = KernelFrameEstimator(ps=samples, kernel_dim=kernel_dim)
    # NOTE(review): pickle.load executes arbitrary code from the file - only use trusted files.
    with open(bases_file, 'rb') as bases_handle:
        stored_bases = pickle.load(bases_handle)
    estimator.set_frame(frame=stored_bases)

    return samples, estimator

# Collect, in TASK_NAMES order, each task's samples and its frame estimator.
tasks_ps = []
tasks_frameestimators = []
for task_name in TASK_NAMES:
    ps, frameestimator = load_replay_buffer_and_kernel(
        task_name=task_name,
        load_what=LOAD_WHAT,
        kernel_dim=KERNEL_DIM,
        n_samples=N_SAMPLES,
    )
    tasks_ps += [ps]
    tasks_frameestimators += [frameestimator]

2025-07-07 16:50:17.335821: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Loaded next_observations from sac_circle_rotation_task_0_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


Loaded next_observations from sac_circle_rotation_task_1_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


Loaded next_observations from sac_circle_rotation_task_2_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


Loaded next_observations from sac_circle_rotation_task_3_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


In [11]:
%load_ext autoreload
%autoreload 2

import torch
from src.learning.symmetry_discovery.differential.hereditary_geometry_discovery import HereditaryGeometryDiscovery
# 2x2 skew-symmetric matrix (the generator of planar rotations) with a leading axis of
# size 1, i.e. shape (1, 2, 2). Passed to HereditaryGeometryDiscovery as `oracle_generator`.
ORACLE_GENERATOR = torch.tensor([[0.0, -1.0], [1.0, 0.0]], dtype=torch.float32).unsqueeze(0)

# One goal per task, in the same order as the loaded tasks.
# The first goal is written with float literals so every goal tensor has the same
# floating dtype; `torch.tensor([1, 0])` would silently produce an int64 tensor.
train_goal_locations = [
    {'goal': torch.tensor([1.0, 0.0])},
    # Unused alternative goal, kept for reference:
    # {'goal': torch.tensor([-0.70506063,  0.70914702])},
    {'goal': torch.tensor([0.95243384, -0.30474544])},
    {'goal': torch.tensor([-0.11289421, -0.99360701])},
    {'goal': torch.tensor([-0.81394263, -0.58094525])},
]


class Affine2D(torch.nn.Module):
    """Affine map ``x -> x @ W.T + b`` implemented by a single ``torch.nn.Linear`` layer."""

    def __init__(self, input_dim, output_dim):
        """Create the underlying linear layer (with bias).

        Args:
            input_dim: Number of input features (size of the last axis of ``x``).
            output_dim: Number of output features.
        """
        super().__init__()
        # The attribute name is part of the state_dict keys ("linear.weight", "linear.bias").
        self.linear = torch.nn.Linear(in_features=input_dim, out_features=output_dim, bias=True)

    def forward(self, x):
        """Apply the affine map to ``x`` and return the result."""
        affine_out = self.linear(x)
        return affine_out

SEED = 42

# Seed the global torch RNG *before* building the networks: torch.nn.Linear draws its
# initial weights at construction time, and `seed=SEED` is only handed to
# HereditaryGeometryDiscovery afterwards, so without this the encoder/decoder
# initialisation differs from run to run.
torch.manual_seed(SEED)

# Learnable affine maps used as encoder / decoder on the 2-D samples.
ENCODER = Affine2D(input_dim=2, output_dim=2)
DECODER = Affine2D(input_dim=2, output_dim=2)

her_geo_dis = HereditaryGeometryDiscovery(
    tasks_ps=tasks_ps,
    tasks_frameestimators=tasks_frameestimators,
    kernel_dim=KERNEL_DIM,
    batch_size=128,
    seed=SEED,
    bandwidth=0.5,
    learn_left_actions=False,
    learn_encoder_decoder=True,
    task_specifications=train_goal_locations,
    oracle_generator=ORACLE_GENERATOR,
    encoder=ENCODER,
    decoder=DECODER,
)
# NOTE(review): the recorded run of this cell was stopped by hand (KeyboardInterrupt) at
# roughly 95k of 250k steps - confirm whether the full step budget is actually needed.
her_geo_dis.optimize(n_steps=250_000)



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Symmetry Loss: 5.63e-03, Reconstruction Loss: 3.00e-01:  38%|███▊      | 95454/250000 [02:53<04:40, 550.36it/s]


KeyboardInterrupt: 

In [12]:
# Show the encoder's parameters (its state_dict) as the cell output.
# NOTE(review): the recorded training run above ended with a KeyboardInterrupt, so the
# displayed values come from a partially optimised model.
her_geo_dis.encoder.state_dict()

OrderedDict([('linear.weight',
              tensor([[ 0.1944,  0.1592],
                      [-0.0635, -0.0523]])),
             ('linear.bias', tensor([-0.5660, -0.0261]))])