In [1]:
import physics_environments
import numpy as np
from physics_environments import CartPendulumPhysicsConstants, CartPendulumPhysicsEnvironmentParams, CartPendulumPhysicsEnvironment, \
                                 RLCartPendulumEnvironment, RLCartPendulumEnvironmentParams
from jumanji.types import TimeStep
from physics_environments.envs.rl_pendulum.types import Observation, ObservationUtils
import jax
import jax.numpy as jnp
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

In [2]:
print('Importable Modules: \n')
dir(physics_environments)

Importable Modules: 



['CartPendulumPhysicsConstants',
 'CartPendulumPhysicsEnvironment',
 'CartPendulumPhysicsEnvironmentParams',
 'RLCartPendulumEnvironment',
 'RLCartPendulumEnvironmentParams',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '__version__',
 'envs',
 'register',
 'types',
 'version']

In [3]:
n = 3
constants = CartPendulumPhysicsConstants(n    = n,         
                                         g    = 9.81,       
                                         l    = n*[0.300],
                                         r    = n*[0.200],
                                         m    = n*[0.800],
                                         I    = n*[0.011],
                                         mu   = n*[0.015])
#print(constants.to_dict())

env_type             = 'compound pendulum'
cart_pendulum_params = CartPendulumPhysicsEnvironmentParams(constants       = constants,         
                                                            env_type        = 'compound pendulum',      
                                                            save_after_init = True, 
                                                            save_filename   = f"env_{n}n_{env_type}_obj.pkl") 
cart_pendulum_params
#print(cart_pendulum_params.to_dict())

rl_env_params = RLCartPendulumEnvironmentParams(physics_env_params          = cart_pendulum_params,
                                                training_generator_type     = "DefaultGenerator",
                                                dt                          = 0.010,
                                                t_episode                   = 10.00,
                                                max_cart_acceleration       = 25.00,
                                                cart_width                  = 0.250,
                                                track_width                 = 0.800
                                                )
#rl_env_params.to_dict()

In [4]:
# env = CartPendulumPhysicsEnvironment(params = cart_pendulum_params)
# print(env)
# env.math_summary()

In [5]:
rl_env = RLCartPendulumEnvironment(params = rl_env_params)
rl_env


    ** RLCartPendulumEnvironment **
        A simple and flexible n-rodded RL cart pendulum environment.

    Environment Information:
    - action:
        - 1 continuous action a_x (x accelearation of the cart)
    - reward:
        - mean of all rod cosines (1 reward if all rods standing, -1 reward if all rods hanging)
        - additional -15 penalty if bumping into the environment boundaries
    - episode termination:
        - if at least one cart bumps into the environment boundaries
        - and on reaching the termination time t_episode
    - state:
        - Information for used for env.step, and to calculate observations.
        - only should have impact on the environment dynamics
    - observation:
        - what the agent will see and the agent's acting is based on
        - will be fully calculated from the environment state
        - 4 base features + 9n theta features + 9n rod2cart distance features = 18n + 4 features
    

In [6]:
key = jax.random.key(seed=42)
state, timestep = rl_env.reset(key)

In [7]:
next_state, next_timestep = rl_env.step(state = state, action=3)

False False


### Test viewer functions

#### One State

In [6]:
state

State(t=Array(0., dtype=float32), step_num=Array(0, dtype=int32), key=Array([4146024105,  967050713], dtype=uint32), solver_state=(Array(True, dtype=bool), Array([0., 0., 0., 0., 0.], dtype=float32)), y_solver=Array([3.1415927, 3.1415927, 0.       , 0.       , 0.       ], dtype=float32), thetas=Array([3.1415927, 3.1415927], dtype=float32), dthetas=Array([0., 0.], dtype=float32), ddthetas=Array([0., 0.], dtype=float32), s_x=Array(0., dtype=float32), v_x=Array(0., dtype=float32), a_x=Array(0., dtype=float32))

In [None]:
env.viewer._get_raw_state_data_dict([state])

{'t': array([Array(0., dtype=float32)], dtype=object),
 'step_num': array([Array(0, dtype=int32)], dtype=object),
 'key': array([[4146024105, 967050713]], dtype=object),
 'solver_state': array([[Array(True, dtype=bool),
         Array([0., 0., 0., 0., 0.], dtype=float32)]], dtype=object),
 'y_solver': array([[3.1415927410125732, 3.1415927410125732, 0.0, 0.0, 0.0]],
       dtype=object),
 'thetas': array([[3.1415927410125732, 3.1415927410125732]], dtype=object),
 'dthetas': array([[0.0, 0.0]], dtype=object),
 'ddthetas': array([[0.0, 0.0]], dtype=object),
 's_x': array([Array(0., dtype=float32)], dtype=object),
 'v_x': array([Array(0., dtype=float32)], dtype=object),
 'a_x': array([Array(0., dtype=float32)], dtype=object)}

In [None]:
env.viewer._get_prepared_state_data_dict([state])

{'$$t$$': array([Array(0., dtype=float32)], dtype=object),
 'step_num': array([Array(0, dtype=int32)], dtype=object),
 'key': [array([4146024105, 967050713], dtype=object)],
 'solver_state': [array([Array(True, dtype=bool),
         Array([0., 0., 0., 0., 0.], dtype=float32)], dtype=object)],
 'y_solver': [array([3.1415927410125732, 3.1415927410125732, 0.0, 0.0, 0.0],
        dtype=object)],
 '$$\\theta_{1}$$': array([3.1415927410125732], dtype=object),
 '$$\\theta_{2}$$': array([3.1415927410125732], dtype=object),
 '$$\\dot{\\theta_1}$$': array([0.0], dtype=object),
 '$$\\dot{\\theta_2}$$': array([0.0], dtype=object),
 '$$\\ddot{\\theta_1}$$': array([0.0], dtype=object),
 '$$\\ddot{\\theta_2}$$': array([0.0], dtype=object),
 '$$x_c$$': array([Array(0., dtype=float32)], dtype=object),
 '$$\\dot{x_c}$$': array([Array(0., dtype=float32)], dtype=object),
 '$$\\ddot{x_c}$$': array([Array(0., dtype=float32)], dtype=object)}

In [None]:
env.viewer.get_state_dataframe([state])

Unnamed: 0_level_0,step_num,key,solver_state,y_solver,$$\theta_{1}$$,$$\theta_{2}$$,$$\dot{\theta_1}$$,$$\dot{\theta_2}$$,$$\ddot{\theta_1}$$,$$\ddot{\theta_2}$$,$$x_c$$,$$\dot{x_c}$$,$$\ddot{x_c}$$
$$t$$,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0.0,0,"[4146024105, 967050713]","[True, [0.0, 0.0, 0.0, 0.0, 0.0]]","[3.1415927410125732, 3.1415927410125732, 0.0, ...",3.141593,3.141593,0.0,0.0,0.0,0.0,0.0,0.0,0.0


#### n States

In [None]:
env.viewer._get_raw_state_data_dict([state, state, state])

{'t': array([Array(0., dtype=float32), Array(0., dtype=float32),
        Array(0., dtype=float32)], dtype=object),
 'step_num': array([Array(0, dtype=int32), Array(0, dtype=int32),
        Array(0, dtype=int32)], dtype=object),
 'key': array([[4146024105, 967050713],
        [4146024105, 967050713],
        [4146024105, 967050713]], dtype=object),
 'solver_state': array([[Array(True, dtype=bool),
         Array([0., 0., 0., 0., 0.], dtype=float32)],
        [Array(True, dtype=bool),
         Array([0., 0., 0., 0., 0.], dtype=float32)],
        [Array(True, dtype=bool),
         Array([0., 0., 0., 0., 0.], dtype=float32)]], dtype=object),
 'y_solver': array([[3.1415927410125732, 3.1415927410125732, 0.0, 0.0, 0.0],
        [3.1415927410125732, 3.1415927410125732, 0.0, 0.0, 0.0],
        [3.1415927410125732, 3.1415927410125732, 0.0, 0.0, 0.0]],
       dtype=object),
 'thetas': array([[3.1415927410125732, 3.1415927410125732],
        [3.1415927410125732, 3.1415927410125732],
        [3.141

In [None]:
env.viewer._get_prepared_state_data_dict([state, state, state])

{'$$t$$': array([Array(0., dtype=float32), Array(0., dtype=float32),
        Array(0., dtype=float32)], dtype=object),
 'step_num': array([Array(0, dtype=int32), Array(0, dtype=int32),
        Array(0, dtype=int32)], dtype=object),
 'key': [array([4146024105, 967050713], dtype=object),
  array([4146024105, 967050713], dtype=object),
  array([4146024105, 967050713], dtype=object)],
 'solver_state': [array([Array(True, dtype=bool),
         Array([0., 0., 0., 0., 0.], dtype=float32)], dtype=object),
  array([Array(True, dtype=bool),
         Array([0., 0., 0., 0., 0.], dtype=float32)], dtype=object),
  array([Array(True, dtype=bool),
         Array([0., 0., 0., 0., 0.], dtype=float32)], dtype=object)],
 'y_solver': [array([3.1415927410125732, 3.1415927410125732, 0.0, 0.0, 0.0],
        dtype=object),
  array([3.1415927410125732, 3.1415927410125732, 0.0, 0.0, 0.0],
        dtype=object),
  array([3.1415927410125732, 3.1415927410125732, 0.0, 0.0, 0.0],
        dtype=object)],
 '$$\\theta_{

In [None]:
env.viewer.get_state_dataframe([state, state, state])

Unnamed: 0_level_0,step_num,key,solver_state,y_solver,$$\theta_{1}$$,$$\theta_{2}$$,$$\dot{\theta_1}$$,$$\dot{\theta_2}$$,$$\ddot{\theta_1}$$,$$\ddot{\theta_2}$$,$$x_c$$,$$\dot{x_c}$$,$$\ddot{x_c}$$
$$t$$,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0.0,0,"[4146024105, 967050713]","[True, [0.0, 0.0, 0.0, 0.0, 0.0]]","[3.1415927410125732, 3.1415927410125732, 0.0, ...",3.141593,3.141593,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0,"[4146024105, 967050713]","[True, [0.0, 0.0, 0.0, 0.0, 0.0]]","[3.1415927410125732, 3.1415927410125732, 0.0, ...",3.141593,3.141593,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0,"[4146024105, 967050713]","[True, [0.0, 0.0, 0.0, 0.0, 0.0]]","[3.1415927410125732, 3.1415927410125732, 0.0, ...",3.141593,3.141593,0.0,0.0,0.0,0.0,0.0,0.0,0.0


---

## Simulating and Rendering an Entire Trajectory

In [2]:
n = 3
constants = CartPendulumPhysicsConstants(n    = n,         
                                         g    = 9.81,        
                                         l    = n*[0.300],
                                         r    = n*[0.200],
                                         m    = n*[0.800],
                                         I    = n*[0.011],
                                         mu   = n*[0.015])
#print(constants.to_dict())

env_type             = 'compound pendulum'
cart_pendulum_params = CartPendulumPhysicsEnvironmentParams(constants       = constants,         
                                                            env_type        = 'compound pendulum',      
                                                            save_after_init = False, 
                                                            save_filename   = f"'/home/hbeyer/RL_Pendulum_Project/experiments/bu/env_{n}n_{env_type}_obj.pkl")

rl_env_params = RLCartPendulumEnvironmentParams(physics_env_params          = cart_pendulum_params,
                                                training_generator_type     = "DefaultGenerator",
                                                dt                          = 0.010,
                                                t_episode                   = 10.00,
                                                max_cart_acceleration       = 25.00,
                                                cart_width                  = 0.250,
                                                track_width                 = 0.800
                                                )

rl_env = RLCartPendulumEnvironment(params = rl_env_params)
key = jax.random.key(seed=42)
state, timestep = rl_env.reset(key)
#next_state, next_timestep = rl_env.step(state = state, action=3)

In [3]:
t = jnp.arange(0, 10 + 0.01, 0.01)
t

Array([ 0.  ,  0.01,  0.02, ...,  9.98,  9.99, 10.  ], dtype=float32)

In [4]:
def rollout(batch_size=2, rollout_length=3):
    # Note: setting device=jax.devices('cpu') as default argument actually hurts performance slightly, having device without defualt argument helps to fix this.

    #t = jnp.arange(0, 10 + 0.01, 0.01)
    def step_func(state, key):
        #action = jax.random.randint(key=key, minval=0, maxval=num_actions, shape=(1))
        action = jnp.sin(state.t)*15
        new_state, timestep = rl_env.step(state, action)
        output = timestep, new_state
        return new_state, output

    def rollout_func(state, key, n):
        random_keys = jax.random.split(key, n)
        state, output = jax.lax.scan(step_func, state, random_keys)
        return state, output

    # Constants
    num_actions = rl_env.action_spec.shape[0]

    # Create random Keys
    master_key = jax.random.PRNGKey(0)
    key1, key2 = jax.random.split(master_key) # create a separate random key for step() and reset()

    # Instantiate a batch of environment states (via vmap reset)
    keys = jax.random.split(key1, batch_size)
    state, timestep = jax.vmap(rl_env.reset)(keys)

    # Collect a batch of rollouts (via vmap rollout_func)
    keys = jax.random.split(key2, batch_size)
    output = jax.vmap(rollout_func, in_axes=(0, 0, None))(state, keys, rollout_length)
    return output

In [5]:
state_data, rollout_data=rollout(batch_size=1, rollout_length=1000)
#rollout_data

In [6]:
state_data

State(t=Array([10.0001335], dtype=float32), step_count=Array([1000], dtype=int32), key=Array([[2869462989, 2323445670]], dtype=uint32), solver_state=(Array([False], dtype=bool), Array([[ 1.7385747,  1.9778969,  2.1770155, -3.827721 , -4.304127 ,
        -5.9128604, -8.116454 ]], dtype=float32)), y_solver=Array([[ 4.0232058,  4.0768714,  4.1378994,  1.7385747,  1.9778969,
         2.1770155, -8.116454 ]], dtype=float32), thetas=Array([[4.0232058, 4.0768714, 4.1378994]], dtype=float32), dthetas=Array([[1.7385747, 1.9778969, 2.1770155]], dtype=float32), ddthetas=Array([[-3.6451578, -4.1853666, -5.8470726]], dtype=float32), s_x=Array([158.30066], dtype=float32), v_x=Array([27.625526], dtype=float32), a_x=Array([-8.035738], dtype=float32))

In [7]:
rollout_data[0].keys()

dict_keys(['step_type', 'reward', 'discount', 'observation', 'extras'])

In [8]:
type(rollout_data)

tuple

In [9]:
pd.DataFrame(rollout_data[0].observation.agent_inputs_global[0])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,48,49,50,51,52,53,54,55,56,57
0,0.000000,0.000000,0.000000,0.275000,3.141593,3.141593,3.141593,-3.078255e-08,2.326893e-09,-1.424134e-10,...,-7.471779e-14,0.3,0.600000,0.900000,-9.234767e-17,-8.536698e-17,-8.579423e-17,-9.234766e-15,-8.536698e-15,-8.579422e-15
1,0.000022,0.001500,0.149998,0.274978,3.141571,3.141594,3.141593,-4.881404e-03,3.642537e-04,-2.243835e-05,...,1.037192e-05,0.3,0.600000,0.900000,-1.464390e-11,-1.355113e-11,-1.361845e-11,-1.464309e-09,-1.355033e-09,-1.361764e-09
2,0.000082,0.004500,0.299980,0.274918,3.141474,3.141601,3.141593,-1.508038e-02,1.015747e-03,-4.139400e-05,...,1.050796e-04,0.3,0.600000,0.900000,-4.523575e-11,-4.218848e-11,-4.231266e-11,-3.058645e-09,-2.863193e-09,-2.868879e-09
3,0.000195,0.008999,0.449932,0.274805,3.141248,3.141615,3.141592,-3.044766e-02,1.775313e-03,-2.245880e-05,...,4.382589e-04,0.3,0.600000,0.900000,-9.131153e-11,-8.598548e-11,-8.605285e-11,-4.605816e-09,-4.377932e-09,-4.372251e-09
4,0.000375,0.014997,0.599840,0.274625,3.140845,3.141637,3.141592,-5.078099e-02,2.407992e-03,7.268818e-05,...,1.232509e-03,0.3,0.600000,0.900000,-1.522290e-10,-1.450047e-10,-1.447866e-10,-6.087692e-09,-5.897862e-09,-5.869318e-09
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,157.192429,27.939301,-7.522794,-156.917435,3.950920,3.994577,4.046248,1.868543e+00,2.129860e+00,2.400698e+00,...,2.059559e-01,0.3,0.599857,0.899317,7.925285e-09,1.694089e-08,2.705385e-08,-1.142710e-08,-2.606778e-08,-5.043094e-08
996,157.470673,27.862778,-7.652191,-157.195679,3.969466,4.015707,4.069992,1.839952e+00,2.095747e+00,2.347544e+00,...,-1.703483e-01,0.3,0.599840,0.899241,7.799216e-09,1.665575e-08,2.651387e-08,-1.272875e-08,-2.879593e-08,-5.444930e-08
997,157.748138,27.784969,-7.780823,-157.473145,3.987713,4.036484,4.093194,1.808802e+00,2.059039e+00,2.292399e+00,...,-5.446686e-01,0.3,0.599822,0.899164,7.659955e-09,1.634341e-08,2.593521e-08,-1.403387e-08,-3.147979e-08,-5.824209e-08
998,158.024796,27.705883,-7.908676,-157.749802,4.005635,4.056880,4.115835,1.775026e+00,2.019751e+00,2.235486e+00,...,-9.153764e-01,0.3,0.599803,0.899088,7.507522e-09,1.600441e-08,2.532017e-08,-1.533580e-08,-3.410593e-08,-6.180282e-08


In [10]:
def get_rollout_obs_dataframe(n : int, rollout_data : TimeStep):
    """ Get a DataFrame from a single observation with batched arrays. """
    data = rollout_data.observation.agent_inputs_global[0]
    columns = ObservationUtils.get_agent_inputs_index2str_mapping(n=n).values()
    df = pd.DataFrame(data=data, columns=columns)
    return df


In [11]:
df = get_rollout_obs_dataframe(n=n, rollout_data=rollout_data[0])
df

Unnamed: 0,$$x_c$$,$$\dot{{x_c}}$$,$$\ddot{{x_c}}$$,$$d_{{corner}}$$,$$\theta_{1}$$,$$\theta_{2}$$,$$\theta_{3}$$,$$\dot{\theta_{1}}$$,$$\dot{\theta_{2}}$$,$$\dot{\theta_{3}}$$,...,$$\ddot{d^y_{r_3-c}}$$,$$d_{r_1-c}$$,$$d_{r_2-c}$$,$$d_{r_3-c}$$,$$\dot{d_{r_1-c}}$$,$$\dot{d_{r_2-c}}$$,$$\dot{d_{r_3-c}}$$,$$\ddot{d_{r_1-c}}$$,$$\ddot{d_{r_2-c}}$$,$$\ddot{d_{r_3-c}}$$
0,0.000000,0.000000,0.000000,0.275000,3.141593,3.141593,3.141593,-3.078255e-08,2.326893e-09,-1.424134e-10,...,-7.471779e-14,0.3,0.600000,0.900000,-9.234767e-17,-8.536698e-17,-8.579423e-17,-9.234766e-15,-8.536698e-15,-8.579422e-15
1,0.000022,0.001500,0.149998,0.274978,3.141571,3.141594,3.141593,-4.881404e-03,3.642537e-04,-2.243835e-05,...,1.037192e-05,0.3,0.600000,0.900000,-1.464390e-11,-1.355113e-11,-1.361845e-11,-1.464309e-09,-1.355033e-09,-1.361764e-09
2,0.000082,0.004500,0.299980,0.274918,3.141474,3.141601,3.141593,-1.508038e-02,1.015747e-03,-4.139400e-05,...,1.050796e-04,0.3,0.600000,0.900000,-4.523575e-11,-4.218848e-11,-4.231266e-11,-3.058645e-09,-2.863193e-09,-2.868879e-09
3,0.000195,0.008999,0.449932,0.274805,3.141248,3.141615,3.141592,-3.044766e-02,1.775313e-03,-2.245880e-05,...,4.382589e-04,0.3,0.600000,0.900000,-9.131153e-11,-8.598548e-11,-8.605285e-11,-4.605816e-09,-4.377932e-09,-4.372251e-09
4,0.000375,0.014997,0.599840,0.274625,3.140845,3.141637,3.141592,-5.078099e-02,2.407992e-03,7.268818e-05,...,1.232509e-03,0.3,0.600000,0.900000,-1.522290e-10,-1.450047e-10,-1.447866e-10,-6.087692e-09,-5.897862e-09,-5.869318e-09
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,157.192429,27.939301,-7.522794,-156.917435,3.950920,3.994577,4.046248,1.868543e+00,2.129860e+00,2.400698e+00,...,2.059559e-01,0.3,0.599857,0.899317,7.925285e-09,1.694089e-08,2.705385e-08,-1.142710e-08,-2.606778e-08,-5.043094e-08
996,157.470673,27.862778,-7.652191,-157.195679,3.969466,4.015707,4.069992,1.839952e+00,2.095747e+00,2.347544e+00,...,-1.703483e-01,0.3,0.599840,0.899241,7.799216e-09,1.665575e-08,2.651387e-08,-1.272875e-08,-2.879593e-08,-5.444930e-08
997,157.748138,27.784969,-7.780823,-157.473145,3.987713,4.036484,4.093194,1.808802e+00,2.059039e+00,2.292399e+00,...,-5.446686e-01,0.3,0.599822,0.899164,7.659955e-09,1.634341e-08,2.593521e-08,-1.403387e-08,-3.147979e-08,-5.824209e-08
998,158.024796,27.705883,-7.908676,-157.749802,4.005635,4.056880,4.115835,1.775026e+00,2.019751e+00,2.235486e+00,...,-9.153764e-01,0.3,0.599803,0.899088,7.507522e-09,1.600441e-08,2.532017e-08,-1.533580e-08,-3.410593e-08,-6.180282e-08


In [14]:
#df.columns

In [19]:
def reward(angles):
    angles       = angles - (angles // (jnp.pi*2))*jnp.pi*2   # Convert n*[-2pi; +2pi] net angles --> to [0, 2pi], while accounting for n net angles (-inf; +inf)
    angles       = (angles + jnp.pi) % (2 * jnp.pi) - jnp.pi  # Convert [0, 2pi] --> [-pi, pi]
    rod_rewards  = jnp.cos(angles)                            # [-pi, pi] angles --> [-1, 1] rewards (standing rods have angle of 0, hanging rods hang with angle of pi)
    reward       = jnp.mean(rod_rewards)                      # ranges of [-1, 1]

    reward       = reward
    return reward

In [23]:
reward(angles = jnp.array([0,1, jnp.pi*3]))

Array(0.18010068, dtype=float32)

In [12]:
cols = ['$$x_c$$', r'$$\dot{{x_c}}$$', r'$$\ddot{{x_c}}$$', r'$$d_{{corner}}$$',
       r'$$\theta_{1}$$', r'$$\theta_{2}$$', r'$$\theta_{3}$$', '$$r$$']
rl_env.viewer.plot_episode_data_matplotlib(df = df, col_names=cols)
rl_env.viewer.plot_episode_data_plotly(df = df, col_names=cols)

KeyError: "['$$r$$'] not in index"

In [20]:
import plotly.express

In [24]:
%matplotlib widget



In [25]:
import matplotlib.pyplot as plt

plt.plot([1, 2, 3], [4, 5, 6])
plt.title("Test Plot")
plt.show()

<IPython.core.display.Javascript object>

In [12]:
s = jnp.array([2869462989, 2323445670])
s = s.at[..., -1].set(3)
s

Array([-1425504307,           3], dtype=int32)