In [3]:
import os
import wandb
import torch
from datetime import datetime

from src.learning.symmetry.hereditary_geometry_discovery import HereditaryGeometryDiscovery
from argparser import get_argparser, get_non_default_args
from src.utils import load_replay_buffer_and_kernel, Affine2D

FOLDER_NAME: str="data/local/experiment/circle_rotation"
TASK_NAMES=["sac_circle_rotation_task_0", "sac_circle_rotation_task_1", "sac_circle_rotation_task_2", "sac_circle_rotation_task_3"]

parser = get_argparser()
# Only --log_wandb is overridden here; every other option keeps the parser's default.
args = parser.parse_args(["--log_wandb", "false"])

# Seed torch's global RNG before any module is built or data is loaded, so the
# random initialisation of ENCODER/DECODER (and any random subsampling done while
# loading) is reproducible across Restart & Run All.
torch.manual_seed(args.seed)


# Goal location of each training task; entry i is assumed to belong to TASK_NAMES[i].
train_goal_locations=[
    {'goal': torch.tensor([-0.70506063,  0.70914702])},
    {'goal': torch.tensor([ 0.95243384, -0.30474544])},
    {'goal': torch.tensor([-0.11289421, -0.99360701])},
    {'goal': torch.tensor([-0.81394263, -0.58094525])}]
assert len(train_goal_locations) == len(TASK_NAMES), "expected exactly one goal location per task"

LOAD_WHAT:str="next_observations"
N_SAMPLES=50_000
# Learnable chart (encoder/decoder) and its hand-built oracle counterpart.
ENCODER=Affine2D(input_dim=2, output_dim=2)
DECODER=Affine2D(input_dim=2, output_dim=2)
ORACLE_ENCODER=Affine2D(input_dim=2, output_dim=2)
ORACLE_DECODER=Affine2D(input_dim=2, output_dim=2)

# so(2) generator (infinitesimal 2-D rotation), shape (1, 2, 2) = (n_generators, 2, 2).
ORACLE_GENERATOR=torch.tensor([[0, -1], [1,0]], dtype=torch.float32, requires_grad=False).unsqueeze(0)

# Oracle chart: identity weights with bias -goal_0 / +goal_0. Assuming Affine2D applies
# its `linear` layer, the encoder translates the first task's goal to the origin and the
# decoder translates it back.
with torch.no_grad():
    ORACLE_ENCODER.linear.weight.copy_(torch.eye(2))
    ORACLE_DECODER.linear.weight.copy_(torch.eye(2))
    ORACLE_ENCODER.linear.bias.copy_(-train_goal_locations[0]["goal"])
    ORACLE_DECODER.linear.bias.copy_(train_goal_locations[0]["goal"])

# 1. Load replay buffers and frame estimators.
tasks_ps, tasks_frameestimators=[], []
for task_name in TASK_NAMES:
    ps, frameestimator = load_replay_buffer_and_kernel(task_name, LOAD_WHAT, args.kernel_dim, N_SAMPLES, FOLDER_NAME)
    tasks_ps.append(ps)
    tasks_frameestimators.append(frameestimator)


# 2. Either fix the generator to the oracle rotation or (None) let it be learned.
oracle_generator=ORACLE_GENERATOR if not args.learn_generator else None


# 3. Train.
her_geo_dis=HereditaryGeometryDiscovery(tasks_ps=tasks_ps,tasks_frameestimators=tasks_frameestimators,
                                        oracle_generator=oracle_generator, encoder=ENCODER, decoder=DECODER,

                                        kernel_dim=args.kernel_dim, n_steps_pretrain_geo=args.n_steps_pretrain_geo,
                                        update_chart_every_n_steps=args.update_chart_every_n_steps, eval_span_how=args.eval_span_how,
                                        log_lg_inits_how=args.log_lg_inits_how,

                                        batch_size=args.batch_size,
                                        lr_lgs=args.lr_lgs,lr_gen=args.lr_gen,lr_chart=args.lr_chart,
                                        lasso_coef_lgs=args.lasso_coef_lgs, lasso_coef_generator=args.lasso_coef_generator, lasso_coef_encoder_decoder=args.lasso_coef_encoder_decoder,

                                        seed=args.seed, log_wandb=args.log_wandb, log_wandb_gradients=args.log_wandb_gradients, save_every=args.save_every,
                                        bandwidth=args.bandwidth,

                                        task_specifications=train_goal_locations,
                                        use_oracle_rotation_kernel=args.use_oracle_rotation_kernel,
                                        save_dir=None,

                                        eval_sym_in_follower=args.eval_sym_in_follower,
                                        oracle_encoder=ORACLE_ENCODER, oracle_decoder=ORACLE_DECODER
                                        )



Loaded next_observations from data/local/experiment/circle_rotation/sac_circle_rotation_task_0_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


Loaded next_observations from data/local/experiment/circle_rotation/sac_circle_rotation_task_1_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


Loaded next_observations from data/local/experiment/circle_rotation/sac_circle_rotation_task_2_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


Loaded next_observations from data/local/experiment/circle_rotation/sac_circle_rotation_task_3_replay_buffer.pkl with shape torch.Size([100000, 2])


INFO:root:Setup kernel frame evaluation.


In [5]:
def rotation_vector_field(p_batch: torch.Tensor, center: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Unit-length rotation vector field about ``center``, evaluated at a batch of points.

    For each point p the tangent vector G @ (p - center) is computed with the
    so(2) generator G = [[0, -1], [1, 0]] and then normalised to unit length.

    Args:
        p_batch: Points of shape (batch, 2).
        center: Rotation centre, broadcastable against ``p_batch`` (e.g. shape (2,)).
        eps: Lower bound on the norm used for normalisation. A point that coincides
            with ``center`` has a zero tangent vector; without this bound it would
            become NaN (0/0), with it the result is the zero vector.

    Returns:
        Tensor of shape (batch, 1, 2): one unit tangent vector per point for the
        single generator.
    """
    # Build the generator in the input's dtype/device so einsum does not fail on
    # float64 or non-CPU inputs. Shape (1, 2, 2) = (n_generators, 2, 2).
    generator = torch.tensor(
        [[0.0, -1.0], [1.0, 0.0]], dtype=p_batch.dtype, device=p_batch.device
    ).unsqueeze(0)
    projected_state = p_batch - center
    gradients = torch.einsum("dmn,bn->bdm", generator, projected_state)
    # Clamp the norm so the degenerate point p == center yields 0 instead of NaN.
    norms = gradients.norm(dim=-1, keepdim=True).clamp_min(eps)
    return gradients / norms

# Sanity-check the rotation field on random points around the first task's goal.
# A dedicated seeded generator keeps the sample reproducible regardless of how much
# of the global RNG earlier cells consumed.
ps = torch.randn(100, 2, generator=torch.Generator().manual_seed(0))
center = train_goal_locations[0]["goal"]
gradients = rotation_vector_field(ps, center)
assert gradients.shape == (ps.shape[0], 1, 2)
assert torch.allclose(gradients.norm(dim=-1), torch.ones(ps.shape[0], 1)), "field should be unit length"

In [8]:
# Expected (n_points, n_generators=1, 2): one 2-D tangent vector per point.
gradients.size()

torch.Size([100, 1, 2])