Copyright 2022 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [15]:
# @title License
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Action-Angle Networks



In [16]:
# Show the kernel's working directory; the relative paths used in this
# notebook ("..", "../notebook_outputs") are resolved against it.
%pwd

'/Users/ameyad/Documents/google-research/action_angle_networks'

In [17]:
# Create the output folder one level above the working directory; `-p` keeps
# this from failing when the folder already exists.
%mkdir -p ../notebook_outputs

In [1]:
# @title Base Imports
from typing import *
import functools
import sys
import tempfile
import os

from absl import logging
import collections
import chex
from clu import checkpoint
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import frozen_dict
from flax.training import train_state
import optax
import distrax
import tensorflow as tf
import ml_collections
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib
import pysr
import yaml

# Matplotlib style names (presumably consumed via plt.style.context by later
# cells -- TODO confirm).
PLT_STYLE_CONTEXT = ["science", "ieee", "grid"]

# Make the parent directory importable from this notebook.
sys.path.append("..")

# Render matplotlib animations with the interactive JS/HTML player.
matplotlib.rcParams["animation.html"] = "jshtml"

# Show absl log records at INFO level and send them to stdout.
logging.set_verbosity(logging.INFO)
logging.get_absl_handler().python_handler.stream = sys.stdout

In [6]:
%load_ext autoreload

In [7]:
# @title Source Imports
%autoreload 2
import harmonic_motion_simulation
import models
import train
import analysis
from configs.harmonic_motion import (action_angle_flow, action_angle_mlp, euler_update_flow, euler_update_mlp)

## Loading a Pre-Trained Model

In [63]:
# @title Location of Pretrained Model
# config_name = "euler_update_flow"
config_name = "action_angle_flow"
k_pair = "0."
num_samples = "500"
workdir = f"/Users/ameyad/Documents/google-research/workdirs/no_linear/action_angle_networks/configs/harmonic_motion/{config_name}/k_pair={k_pair}"

In [64]:
# Choose the fallback config from `config_name` so it always matches the model
# stored in `workdir`, whichever of the imported configs is selected.
# NOTE(review): assumes every config module exposes get_config() like
# action_angle_flow does -- confirm.
default_config_modules = {
    "action_angle_flow": action_angle_flow,
    "action_angle_mlp": action_angle_mlp,
    "euler_update_flow": euler_update_flow,
    "euler_update_mlp": euler_update_mlp,
}

# Restores the saved config, data scaler, train state and auxiliary data.
config, scaler, state, aux = analysis.load_from_workdir(
    workdir,
    default_config=default_config_modules[config_name].get_config(),
)

INFO:absl:Saved config found. Loading...
INFO:absl:Restoring checkpoint: /Users/ameyad/Documents/google-research/workdirs/no_linear/action_angle_networks/configs/harmonic_motion/action_angle_flow/k_pair=0./checkpoints/ckpt-1
INFO:absl:Restored save_counter=1 restored_checkpoint=/Users/ameyad/Documents/google-research/workdirs/no_linear/action_angle_networks/configs/harmonic_motion/action_angle_flow/k_pair=0./checkpoints/ckpt-1


In [65]:
# Pull the stored arrays, simulation parameters and metrics out of `aux`,
# one split at a time.
(
    train_positions,
    train_momentums,
    train_simulation_parameters,
    all_train_metrics,
) = (
    aux["train"][key]
    for key in ("positions", "momentums", "simulation_parameters", "metrics")
)

(
    test_positions,
    test_momentums,
    test_simulation_parameters,
    all_test_metrics,
) = (
    aux["test"][key]
    for key in ("positions", "momentums", "simulation_parameters", "metrics")
)

In [73]:
# Setup.
jump = 1
curr_positions, curr_momentums, *_ = train.get_coordinates_for_time_jump(
    train_positions, train_momentums, jump
)

In [74]:
def compute_actions(curr_positions, curr_momentums):
    """Returns the "actions" entry of the model's auxiliary predictions.

    The restored model (`state.apply_fn` with `state.params`) is applied to
    the given coordinates. It returns three values; the first two are ignored
    and the third is indexed with the "actions" key.

    Args:
      curr_positions: positions passed straight through to the model.
      curr_momentums: momentums passed straight through to the model.

    Returns:
      The value stored under "actions" in the model's third output.
    """
    # NOTE(review): the trailing 0 is presumably the time offset of the
    # prediction -- confirm against the model's call signature.
    _first, _second, aux_outputs = state.apply_fn(
        state.params, curr_positions, curr_momentums, 0
    )
    return aux_outputs["actions"]

In [75]:
# Evaluate the model's actions on the coordinates selected above and display
# them (last expression of the cell).
actions = compute_actions(curr_positions, curr_momentums)
actions

DeviceArray([[1.7688025, 1.4246038],
             [1.7685826, 1.4250891],
             [1.7684244, 1.4255072],
             [1.7683202, 1.4258581],
             [1.7682613, 1.426145 ],
             [1.7682396, 1.4263697],
             [1.7682483, 1.4265366],
             [1.7682804, 1.4266493],
             [1.7683296, 1.4267136],
             [1.7683905, 1.4267342],
             [1.7684575, 1.4267173],
             [1.7685273, 1.4266689],
             [1.7685955, 1.4265945],
             [1.7686595, 1.4265007],
             [1.768717 , 1.4263927],
             [1.768766 , 1.4262767],
             [1.7688048, 1.4261578],
             [1.7688333, 1.4260402],
             [1.7688508, 1.4259286],
             [1.7688571, 1.4258262],
             [1.7688527, 1.4257367],
             [1.7688378, 1.4256616],
             [1.7688136, 1.425603 ],
             [1.7687805, 1.4255615],
             [1.76874  , 1.4255378],
             [1.7686934, 1.4255316],
             [1.7686421, 1.4255414],
 

In [81]:
def compute_action(curr_position, curr_momentum, index=1):
    """Returns a single action component for one unbatched sample.

    `compute_actions` is fed inputs with a leading batch axis, so an axis of
    size 1 is added to both arguments and row 0 of the result is read back.
    Returning one element keeps the output scalar, which is what `jax.grad`
    needs in order to differentiate this function.

    Args:
      curr_position: position of one sample, without a batch axis.
      curr_momentum: momentum of one sample, without a batch axis.
      index: column of the actions output to return.

    Returns:
      Element `[0, index]` of `compute_actions` evaluated on this sample.
    """
    curr_position = jnp.expand_dims(curr_position, axis=0)
    curr_momentum = jnp.expand_dims(curr_momentum, axis=0)
    return compute_actions(curr_position, curr_momentum)[0, index]
   

# Convert both inputs to JAX arrays. jax.tree_util.tree_map is used because
# the top-level jax.tree_map alias is deprecated and missing from newer JAX
# releases, while this spelling works on old and new versions alike.
curr_positions, curr_momentums = jax.tree_util.tree_map(
    jnp.asarray, (curr_positions, curr_momentums)
)

# Per-sample gradient of the selected action. jax.grad differentiates with
# respect to the first argument (the position) by default, jax.vmap maps the
# computation over the leading axis of both inputs, and jax.jit compiles it.
grad_actions = jax.jit(jax.vmap(jax.grad(compute_action)))(
    curr_positions, curr_momentums
)
grad_actions

(2,)


DeviceArray([[-2.37277316e-04,  9.56650376e-01],
             [ 1.11943664e-04,  9.53731716e-01],
             [ 4.60843381e-04,  9.46089506e-01],
             [ 8.07723729e-04,  9.33763027e-01],
             [ 1.15089142e-03,  9.16813135e-01],
             [ 1.48866861e-03,  8.95322680e-01],
             [ 1.81940221e-03,  8.69394720e-01],
             [ 2.14146450e-03,  8.39154661e-01],
             [ 2.45326664e-03,  8.04747880e-01],
             [ 2.75326404e-03,  7.66340792e-01],
             [ 3.03996541e-03,  7.24118352e-01],
             [ 3.31194419e-03,  6.78285778e-01],
             [ 3.56784207e-03,  6.29066348e-01],
             [ 3.80638288e-03,  5.76700926e-01],
             [ 4.02637711e-03,  5.21446168e-01],
             [ 4.22673579e-03,  4.63575006e-01],
             [ 4.40646894e-03,  4.03373599e-01],
             [ 4.56470158e-03,  3.41141492e-01],
             [ 4.70067235e-03,  2.77188689e-01],
             [ 4.81374469e-03,  2.11835250e-01],
             [ 4.903

In [82]:
# Symbolic regressor fitted further below; niterations=100 sets PySR's search
# budget.
model = pysr.PySRRegressor(niterations=100)

In [83]:
# NOTE(review): un-scaling of the coordinates is currently disabled, so the
# arrays used below stay in whatever scale they already have.
# curr_positions, curr_momentums = train.inverse_transform_with_scaler(curr_positions, curr_momentums, scaler)

# Repeat the row of masses once per sample so it lines up with curr_positions.
masses = np.tile(
    train_simulation_parameters["m"][None, :],
    (len(curr_positions), 1),
)

In [84]:
# Feature matrix: position columns first, then momentum columns.
X = np.concatenate((curr_positions, curr_momentums), axis=1)
# Regression targets: the action gradients computed above.
y = grad_actions
# Display both shapes as a quick sanity check.
(X.shape, y.shape)

((79, 4), (79, 2))

In [85]:
# Run the symbolic-regression search. `y` has one column per target, and the
# progress log reports the best equations separately for each output.
model.fit(X=X, y=y)



Started!

Cycles per second: 3.380e+05
Head worker occupation: 8.4%
Progress: 933 / 3000 total iterations (31.100%)
Best equations for output 1
Hall of Fame:
-----------------------------------------
Complexity  loss       Score     Equation
1           1.367e-05  6.641e-08  0.00031428755
3           1.693e-07  2.196e+00  (x3 * -0.0036849491)
5           9.580e-08  2.847e-01  ((x3 * -0.0036817777) - -0.00027114965)
7           7.805e-08  1.025e-01  (((x1 * 0.08204619) + x3) * -0.003712801)
9           7.395e-09  1.178e+00  (((x1 * -0.0800334) - (x3 - 0.0726791)) * 0.0037089095)
11          7.374e-09  1.454e-03  ((((x1 * -0.16222996) + 0.1430803) - (x3 + x3)) * 0.0018546504)
13          7.372e-09  7.541e-05  (((((x1 + 0.20125447) * -0.16174024) - (x3 + x3)) * 0.0018546225) + 0.00032627647)
15          2.124e-09  6.222e-01  (((x1 * (-0.0800334 * (1.0879052 + (0.28728202 * x2)))) - (x3 - 0.0726791)) * 0.0037089095)
17          2.119e-09  1.119e-03  (((x1 * (-0.0800334 * (1.0879052 + ((0.2