## Load Data

In [2]:
import pickle
import os

from src.utils import load_replay_buffer
from src.learning.symmetry_discovery.differential.kernel_approx import KernelFrameEstimator

# --- Configuration ----------------------------------------------------------
# Directory containing, for every task, "<task>_replay_buffer.pkl" and
# "<task>_kernel_bases.pkl".
FOLDER_NAME: str = "data/local/experiment/circle_rotation"

# Four tasks, numbered 0..3.
TASK_NAMES = [f"sac_circle_rotation_task_{task_idx}" for task_idx in range(4)]

# Replay-buffer field to analyse: "observations", "actions" or "next_observations".
LOAD_WHAT: str = "next_observations"
# Forwarded as `kernel_dim` to the frame estimator.
KERNEL_DIM = 1
# Forwarded as `N_steps` to the replay-buffer loader.
N_SAMPLES = 50_000


def load_replay_buffer_and_kernel(task_name: str, load_what: str, kernel_dim: int, n_samples: int, folder_name):
    """Load one task's samples together with a frame estimator for them.

    The samples are the ``load_what`` entry of the buffer returned by
    ``load_replay_buffer`` for ``<folder_name>/<task_name>_replay_buffer.pkl``.
    A ``KernelFrameEstimator`` is then built on those samples and given the
    frame unpickled from ``<folder_name>/<task_name>_kernel_bases.pkl``.

    Args:
        task_name: Prefix of the task's two pickle files.
        load_what: Buffer field to use; one of "observations", "actions",
            "next_observations".
        kernel_dim: Passed through to ``KernelFrameEstimator``.
        n_samples: Passed to ``load_replay_buffer`` as ``N_steps``.
        folder_name: Directory containing the pickle files.

    Returns:
        Tuple ``(ps, frameestimator)``: the selected samples and the estimator
        whose frame has been set.

    Raises:
        AssertionError: If ``load_what`` is not one of the allowed fields.
    """
    allowed_fields = ("observations", "actions", "next_observations")
    assert load_what in allowed_fields, "Learn hereditary geometry for states, actions or next states."

    def task_file(suffix):
        # Both artefacts of a task share the "<task_name>_<suffix>.pkl" pattern.
        return os.path.join(folder_name, f"{task_name}_{suffix}.pkl")

    replay_path = task_file("replay_buffer")
    frame_path = task_file("kernel_bases")

    samples = load_replay_buffer(replay_path, N_steps=n_samples)[load_what]
    print(f"Loaded {load_what} from {replay_path} with shape {samples.shape}")

    # Attach the stored kernel bases to a fresh estimator.
    estimator = KernelFrameEstimator(ps=samples, kernel_dim=kernel_dim)
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(frame_path, 'rb') as handle:
        stored_frame = pickle.load(handle)
    estimator.set_frame(frame=stored_frame)

    return samples, estimator

# Per-task samples and frame estimators, in the order of TASK_NAMES.
tasks_ps = []
tasks_frameestimators = []
for task_name in TASK_NAMES:
    ps, frameestimator = load_replay_buffer_and_kernel(
        task_name=task_name,
        load_what=LOAD_WHAT,
        kernel_dim=KERNEL_DIM,
        n_samples=N_SAMPLES,
        folder_name=FOLDER_NAME,
    )
    tasks_ps.append(ps)
    tasks_frameestimators.append(frameestimator)



Loaded next_observations from data/local/experiment/circle_rotation/sac_circle_rotation_task_0_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


Loaded next_observations from data/local/experiment/circle_rotation/sac_circle_rotation_task_1_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


Loaded next_observations from data/local/experiment/circle_rotation/sac_circle_rotation_task_2_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


Loaded next_observations from data/local/experiment/circle_rotation/sac_circle_rotation_task_3_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


In [14]:
%load_ext autoreload
%autoreload 2

import torch
from tqdm import tqdm
from src.learning.symmetry_discovery.differential.hereditary_geometry_discovery import HereditaryGeometryDiscovery
# Skew-symmetric 2x2 matrix (generator of planar rotations) with a leading
# axis of size 1, i.e. shape (1, 2, 2).
ORACLE_GENERATOR = torch.tensor(
    [[0, -1],
     [1, 0]],
    dtype=torch.float32,
    requires_grad=False,
).unsqueeze(0)

# One {'goal': 2-vector} specification per entry.
# NOTE(review): presumably listed in the same order as the loaded tasks — confirm.
train_goal_locations = [
    {'goal': torch.tensor(goal_xy)}
    for goal_xy in (
        [-0.70506063, 0.70914702],
        [0.95243384, -0.30474544],
        [-0.11289421, -0.99360701],
        [-0.81394263, -0.58094525],
    )
]

class Affine2D(torch.nn.Module):
    """Affine map ``x -> W x + b`` implemented by a single linear layer with bias."""

    def __init__(self, input_dim, output_dim):
        """Create the layer mapping ``input_dim`` features to ``output_dim`` features."""
        super().__init__()
        self.linear = torch.nn.Linear(
            in_features=input_dim,
            out_features=output_dim,
            bias=True,
        )

    def forward(self, x):
        """Apply the affine map to ``x`` and return the result."""
        mapped = self.linear(x)
        return mapped

# Run-level switches; they are forwarded to HereditaryGeometryDiscovery below.
SEED = 42
LEARN_LEFT_ACTIONS = LEARN_GENERATOR = False

# Two independent affine maps with 2 input and 2 output features each.
ENCODER = Affine2D(2, 2)
DECODER = Affine2D(2, 2)

def _stack_task_samples(tasks_ps: list, device=None) -> torch.Tensor:
    """Concatenates per-task [n_i, dim] sample tensors into one float32 [sum(n_i), dim] tensor."""
    if len(tasks_ps) == 0:
        raise ValueError("tasks_ps must contain the samples of at least one task.")
    if any(task_ps.ndim != 2 for task_ps in tasks_ps):
        raise ValueError("Every entry of tasks_ps must be a 2-D tensor [n_samples, ambient_dim].")
    ambient_dim = tasks_ps[0].shape[1]
    if any(task_ps.shape[1] != ambient_dim for task_ps in tasks_ps):
        raise ValueError("All tasks must share the same ambient dimension.")

    def _prepare(task_ps: torch.Tensor) -> torch.Tensor:
        # The samples are regression targets only, so they are cut from any autograd graph.
        task_ps = task_ps.detach()
        if device is None:
            return task_ps.to(torch.float32)
        return task_ps.to(device=device, dtype=torch.float32)

    # Concatenating along dim 0 keeps task order (task 0 first) and, unlike a
    # pre-allocated [n_tasks, n_samples, dim] buffer, allows differing sample counts.
    return torch.cat([_prepare(task_ps) for task_ps in tasks_ps], dim=0)


def identity_initialization_encoder_decoder(encoder: torch.nn.Module,
                                            decoder: torch.nn.Module,
                                            tasks_ps: list,
                                            n_steps: int = 5_000):
    """Initializes encoder and decoder to identity map on all tasks via gradient flow.

    Both modules are trained in place and independently of each other: each one
    is fitted with Adam (lr=1e-3) to reproduce its input on the pooled samples
    of all tasks, using full-batch MSE.

    Args:
        encoder: Module trained so that ``encoder(p) ~= p``.
        decoder: Module trained so that ``decoder(p) ~= p``.
        tasks_ps: Non-empty list of 2-D sample tensors ``[n_samples_i, ambient_dim]``,
            one per task. Sample counts may differ between tasks.
        n_steps: Number of full-batch optimisation steps.

    Returns:
        The same ``(encoder, decoder)`` objects that were passed in, after training.

    Raises:
        ValueError: If ``tasks_ps`` is empty, contains a tensor that is not 2-D,
            or mixes ambient dimensions.
    """
    # Put the samples where the encoder lives so the forward pass has matching devices.
    first_param = next(encoder.parameters(), None)
    target_device = None if first_param is None else first_param.device

    ps = _stack_task_samples(tasks_ps, device=target_device)
    print(f"Stacked samples shape: {ps.shape}")

    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=1e-3)
    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=1e-3)

    pbar = tqdm(range(n_steps), desc="Initializing to identity")

    for step in pbar:
        encoder_opt.zero_grad()
        decoder_opt.zero_grad()

        # Both maps see the raw samples; neither is composed with the other here.
        encoded_ps = encoder(ps)
        decoded_ps = decoder(ps)

        enc_loss = torch.nn.functional.mse_loss(encoded_ps, ps)
        dec_loss = torch.nn.functional.mse_loss(decoded_ps, ps)
        total_loss = enc_loss + dec_loss
        total_loss.backward()

        encoder_opt.step()
        decoder_opt.step()

        # Refresh the progress-bar read-out only occasionally to keep it cheap.
        if step % 100 == 0:
            pbar.set_postfix({
                "enc_loss": f"{enc_loss.item():.4e}",
                "dec_loss": f"{dec_loss.item():.4e}",
                "total": f"{total_loss.item():.4e}"
            })
    return encoder, decoder
    

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
# Fit both maps to the identity on the pooled task samples. The function hands
# back the very objects it was given, so the *_init names alias ENCODER/DECODER.
encoder_init, decoder_init = identity_initialization_encoder_decoder(
    encoder=ENCODER,
    decoder=DECODER,
    tasks_ps=tasks_ps,
)

Stacked samples shape: torch.Size([400000, 2])


Initializing to identity: 100%|██████████| 5000/5000 [00:18<00:00, 268.43it/s, enc_loss=9.7672e-14, dec_loss=1.4319e-12, total=1.5295e-12]


In [None]:
# Set up hereditary geometry discovery on the loaded tasks, starting from the
# identity-initialised encoder/decoder, then run the optimisation.
her_geo_dis = HereditaryGeometryDiscovery(
    tasks_ps=tasks_ps,
    tasks_frameestimators=tasks_frameestimators,
    kernel_dim=KERNEL_DIM,
    batch_size=128,
    seed=SEED,
    bandwidth=0.5,
    learn_encoder_decoder=True,
    task_specifications=train_goal_locations,
    # Passed exactly once, from the config constant: repeating a keyword in a
    # call is a SyntaxError ("keyword argument repeated").
    learn_left_actions=LEARN_LEFT_ACTIONS,
    learn_generator=LEARN_GENERATOR,
    oracle_generator=ORACLE_GENERATOR,
    encoder=encoder_init,
    decoder=decoder_init,
)
her_geo_dis.optimize(n_steps=250_000)