# Comparison: Reinforcement Learning vs. CFE Approach

This notebook compares reinforcement learning (RL) for controlling Burgers' equation with a supervised control force estimator (CFE) approach trained with a differentiable physics loss, as proposed by Holl et al. [2020](https://ge.in.tum.de/publications/2020-iclr-holl). Both methods use the differentiable PDE solver [Φ<sub>Flow</sub>](https://github.com/tum-pbs/PhiFlow).
The RL method additionally uses the [stable_baselines3](https://github.com/DLR-RM/stable-baselines3) framework with the [PPO](https://arxiv.org/abs/1707.06347v2) learning algorithm.
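For reference, the controlled system is the 1D viscous Burgers' equation with an additive control force $F(x, t)$, where $\nu$ corresponds to the `viscosity` parameter used below. PPO, in the form given in the linked paper, maximizes a clipped surrogate objective in which $r_t(\theta)$ is the probability ratio between the updated and the previous policy, $\hat{A}_t$ is an advantage estimate derived from the critic, and $\epsilon$ is the clip range (reported as `clip_range` 0.2 in the training log).

$$
\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2} + F(x, t)
$$

$$
L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta), 1-\epsilon, 1+\epsilon\right)\hat{A}_t\right)\right], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}
$$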

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard
import sys
sys.path.append('../src')
sys.path.append('../PDE-Control/PhiFlow')
sys.path.append('../PDE-Control/src')
from phi.flow import *
import burgers_plots as bplt
import matplotlib.pyplot as plt
from envs.burgers_util import GaussianClash, GaussianForce

## Data generation

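First, we generate a suite of scenes: scenes 200 to 999 are used for CFE training, scenes 100 to 199 for validation, and scenes 0 to 99 form the test set on which both methods are evaluated.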
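Each scene is produced by repeatedly stepping the forced Burgers' equation. To make that concrete without Φ<sub>Flow</sub>, the sketch below implements one operator-split step in plain NumPy. It assumes periodic boundaries, semi-Lagrangian advection, and diffusion solved exactly in Fourier space; `burgers_step` and the hand-made initial state are illustrative assumptions, not Φ<sub>Flow</sub>'s solver or the `GaussianClash` generator.

```python
import numpy as np

def burgers_step(u, dt, viscosity, dx, force=None):
    """One operator-split step of the 1D forced viscous Burgers' equation on a
    periodic grid. Toy stand-in for illustration, not Phi_Flow's solver."""
    n = u.shape[-1]
    length = n * dx
    x = np.arange(n) * dx
    # Semi-Lagrangian advection: trace every grid point back along the velocity
    # and interpolate the old field there (remains stable for large dt)
    departure = (x - u * dt) % length
    u = np.interp(departure, x, u, period=length)
    # Diffusion solved exactly in Fourier space: u_hat *= exp(-nu * k^2 * dt)
    k = 2 * np.pi * np.fft.rfftfreq(n, d=dx)
    u = np.fft.irfft(np.fft.rfft(u) * np.exp(-viscosity * k**2 * dt), n=n)
    # Additive control force, applied explicitly at the end of the step
    if force is not None:
        u = u + dt * force
    return u

# Illustrative initial state: a positive and a negative Gaussian bump moving
# toward each other (hand-made, not the GaussianClash generator)
n, dx = 128, 1.0 / 128
x = np.arange(n) * dx
u0 = np.exp(-((x - 0.3) / 0.05) ** 2) - np.exp(-((x - 0.7) / 0.05) ** 2)

frames = [u0]
for _ in range(32):  # 32 steps with dt = 0.03 and viscosity = 0.003, as below
    frames.append(burgers_step(frames[-1], dt=0.03, viscosity=0.003, dx=dx))
frames = np.stack(frames)  # shape (33, 128): initial state plus 32 steps
```

Semi-Lagrangian advection and spectral diffusion are chosen here because an explicit scheme would be unstable at this grid spacing and time step.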

In [2]:
domain = Domain([128], box=box[0:1])
viscosity = 0.003
step_count = 32
dt = 0.03
diffusion_substeps = 1

data_path = 'forced-burgers-clash-128'
scene_count = 1000
batch_size = 100

train_range = range(200, 1000)
val_range = range(100, 200)
test_range = range(0, 100)

In [5]:
for batch_index in range(scene_count // batch_size):
    scene = Scene.create(data_path, count=batch_size)
    print(scene)
    world = World()
    u0 = BurgersVelocity(
        domain, 
        velocity=GaussianClash(batch_size, rank=domain.rank), 
        viscosity=viscosity, 
        batch_size=batch_size, 
        name='burgers'
    )
    u = world.add(u0, physics=Burgers(diffusion_substeps=diffusion_substeps))
    force = world.add(FieldEffect(GaussianForce(batch_size, rank=domain.rank), ['velocity']))
    scene.write(world.state, frame=0)
    for frame in range(1, step_count + 1):
        world.step(dt=dt)
        scene.write(world.state, frame=frame)


Failed to copy calling script to scene during Scene.create().
Cause: [Errno 2] No such file or directory: '<ipython-input-5-db5f8738772b>'

forced-burgers-clash-128/sim_000000
forced-burgers-clash-128/sim_000100
forced-burgers-clash-128/sim_000200
forced-burgers-clash-128/sim_000300
forced-burgers-clash-128/sim_000400
forced-burgers-clash-128/sim_000500
forced-burgers-clash-128/sim_000600
forced-burgers-clash-128/sim_000700
forced-burgers-clash-128/sim_000800
forced-burgers-clash-128/sim_000900


## Reinforcement Learning Initialization

In [3]:
from experiment import BurgersTraining


Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.




In [4]:
n_envs = 10 # Number of environments simulated in parallel (load balancing)
final_reward_factor = step_count # Weight of the penalty for not reaching the goal state at the end of an episode
steps_per_rollout = step_count * 10 # Steps collected per environment between agent updates
n_epochs = 10 # Optimization epochs per agent update
learning_rate = 1e-4 # Learning rate for agent updates
rl_batch_size = 128 # Minibatch size for agent updates
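The comments above suggest a reward that penalizes control effort at every step and adds a penalty, scaled by `final_reward_factor`, when the final state misses the target. The actual reward is defined in the `envs` package and is not shown in this notebook; the function below is a hypothetical sketch of that structure, and `step_reward` as well as its squared-L2 penalties are illustrative choices.

```python
import numpy as np

def step_reward(force, state, goal, is_final, final_reward_factor):
    """Hypothetical reward, not the one in envs: penalize squared control effort
    at every step and, on the final step only, the squared distance to the goal."""
    reward = -np.sum(force ** 2)
    if is_final:
        reward -= final_reward_factor * np.sum((state - goal) ** 2)
    return reward

example_force = np.ones(4)
state, goal = np.zeros(4), np.ones(4)
r_mid = step_reward(example_force, state, goal, is_final=False, final_reward_factor=32)  # -4.0
r_end = step_reward(example_force, state, goal, is_final=True, final_reward_factor=32)   # -4 - 32 * 4 = -132.0
```

A larger `final_reward_factor` shifts the trade-off toward reaching the goal at the expense of larger forces.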

To start training, we create a trainer object that manages the environment and the agent internally. It also creates a directory for storing models, logs, and hyperparameters, so training can be resumed at any later point with the same configuration. If the model folder specified in `exp_name` already exists, the agent stored there is loaded; otherwise, a new agent is created.
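The load-or-create behavior can be sketched with the standard library alone. `load_or_create_config` and the `hyperparams.json` file name are hypothetical and not part of `BurgersTraining`; the sketch only shows the pattern of writing the configuration on the first run and reusing it on later runs.

```python
import json
import os
import tempfile

def load_or_create_config(exp_dir, hyperparams):
    """Hypothetical helper: reuse the configuration stored in exp_dir if the
    experiment exists, otherwise create exp_dir and store hyperparams there."""
    config_path = os.path.join(exp_dir, 'hyperparams.json')
    if os.path.exists(config_path):
        with open(config_path) as f:
            return json.load(f), False  # existing experiment: resume with stored settings
    os.makedirs(exp_dir, exist_ok=True)
    with open(config_path, 'w') as f:
        json.dump(hyperparams, f, indent=2)
    return hyperparams, True  # new experiment

with tempfile.TemporaryDirectory() as tmp:
    exp_dir = os.path.join(tmp, 'time_bench')
    config, created = load_or_create_config(exp_dir, {'learning_rate': 1e-4, 'n_envs': 10})
    # A second call with different values returns the stored configuration instead
    resumed, created_again = load_or_create_config(exp_dir, {'learning_rate': 1e-3, 'n_envs': 4})
```

Storing the configuration next to the model keeps a resumed run from silently picking up changed notebook values.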

In [11]:
rl_trainer = BurgersTraining(
    exp_name='../networks/rl-models/time_bench',
    domain=domain,
    viscosity=viscosity,
    step_count=step_count,
    dt=dt,
    diffusion_substeps=diffusion_substeps,
    n_envs=n_envs,
    final_reward_factor=final_reward_factor,
    steps_per_rollout=steps_per_rollout,
    n_epochs=n_epochs,
    learning_rate=learning_rate,
    batch_size=rl_batch_size,
    test_path=None,#data_path,
    test_range=None,#test_range,
)

Called super constructor
Creating new agent...
Using cuda device


Now we are set up to train the agent. The training cell below takes quite some time to execute, so grab a coffee or take your dog for a walk.

`n_rollouts` sets the length of training, measured in rollouts.

`save_freq` specifies the number of epochs after which the stored model is overwritten.
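The timestep counters in the training log follow from the hyperparameters above, assuming each rollout corresponds to one logged iteration: every rollout collects `steps_per_rollout` steps in each of the `n_envs` environments, and each episode lasts `step_count` steps. The short calculation below reproduces the logged `total_timesteps` (3200 after iteration 1 and 275200 after iteration 86, where training was interrupted).

```python
step_count = 32
n_envs = 10
steps_per_rollout = step_count * 10  # 320 steps per environment per rollout

timesteps_per_rollout = n_envs * steps_per_rollout           # 3200
episodes_per_rollout = timesteps_per_rollout // step_count   # 100 complete episodes
total_timesteps_planned = 100 * timesteps_per_rollout        # n_rollouts=100 -> 320000

print(timesteps_per_rollout, episodes_per_rollout, total_timesteps_planned)
print(86 * timesteps_per_rollout)  # last logged total_timesteps before the interrupt
```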


In [12]:
# Execute to run tensorboard
%tensorboard --logdir ../networks/rl-models/time_bench/tensorboard-log

Reusing TensorBoard on port 6006 (pid 11108), started 0:08:16 ago. (Use '!kill 11108' to kill it.)

In [13]:
rl_trainer.train(n_rollouts=100, save_freq=50)

resetto
Logging to ../networks/rl-models/time_bench/tensorboard-log/PPO_1
0.0
not storing time data on first rollout
average rollout collection time: 4.071154
average update time: nan
ratio forward to backward: nan
-----------------------------------------
| forces                  | 1164.4761   |
| learning time average   | nan         |
| ratio collection to ... | nan         |
| rew_unnormalized        | -287899.25  |
| rollout collection t... | 4.07        |
| rollout/                |             |
|    ep_len_mean          | 32.0        |
|    ep_rew_mean          | -0.72167295 |
| time/                   |             |
|    fps                  | 763         |
|    iterations           | 1           |
|    time_elapsed         | 4           |
|    total_timesteps      | 3200        |
-----------------------------------------



Mean of empty slice.


invalid value encountered in double_scalars



4.071153879165649
average rollout collection time: 4.078453
average update time: 7.583037
ratio forward to backward: 0.537839
------------------------------------------
| forces                  | 1156.9114    |
| learning time average   | 7.58         |
| ratio collection to ... | 0.538        |
| rew_unnormalized        | -284517.44   |
| rollout collection t... | 4.08         |
| rollout/                |              |
|    ep_len_mean          | 32.0         |
|    ep_rew_mean          | 0.008123274  |
| time/                   |              |
|    fps                  | 403          |
|    iterations           | 2            |
|    time_elapsed         | 15           |
|    total_timesteps      | 6400         |
| train/                  |              |
|    approx_kl            | 0.0075462926 |
|    clip_fraction        | 0.0162       |
|    clip_range           | 0.2          |
|    entropy_loss         | -45.4        |
|    explained_variance   | -5.99e+06    |
|    learning_

4.147509813308716
average rollout collection time: 4.103279
average update time: 8.003593
ratio forward to backward: 0.512680
------------------------------------------
| forces                  | 1183.27      |
| learning time average   | 8            |
| ratio collection to ... | 0.513        |
| rew_unnormalized        | -297732.97   |
| rollout collection t... | 4.1          |
| rollout/                |              |
|    ep_len_mean          | 32.0         |
|    ep_rew_mean          | -0.55520266  |
| time/                   |              |
|    fps                  | 284          |
|    iterations           | 9            |
|    time_elapsed         | 101          |
|    total_timesteps      | 28800        |
| train/                  |              |
|    approx_kl            | 0.0072862506 |
|    clip_fraction        | 0.0395       |
|    clip_range           | 0.2          |
|    entropy_loss         | -45.4        |
|    explained_variance   | 0.977        |
|    learning_

4.076853036880493
average rollout collection time: 4.083103
average update time: 7.972750
ratio forward to backward: 0.512132
-----------------------------------------
| forces                  | 1121.0164   |
| learning time average   | 7.97        |
| ratio collection to ... | 0.512       |
| rew_unnormalized        | -214912.77  |
| rollout collection t... | 4.08        |
| rollout/                |             |
|    ep_len_mean          | 32.0        |
|    ep_rew_mean          | 0.4909426   |
| time/                   |             |
|    fps                  | 276         |
|    iterations           | 16          |
|    time_elapsed         | 185         |
|    total_timesteps      | 51200       |
| train/                  |             |
|    approx_kl            | 0.006124487 |
|    clip_fraction        | 0.0473      |
|    clip_range           | 0.2         |
|    entropy_loss         | -45.5       |
|    explained_variance   | 0.98        |
|    learning_rate        | 0.0001

4.007997274398804
average rollout collection time: 4.058409
average update time: 7.877952
ratio forward to backward: 0.515160
----------------------------------------
| forces                  | 1079.7203  |
| learning time average   | 7.88       |
| ratio collection to ... | 0.515      |
| rew_unnormalized        | -154709.77 |
| rollout collection t... | 4.06       |
| rollout/                |            |
|    ep_len_mean          | 32.0       |
|    ep_rew_mean          | 1.2291064  |
| time/                   |            |
|    fps                  | 275        |
|    iterations           | 23         |
|    time_elapsed         | 266        |
|    total_timesteps      | 73600      |
| train/                  |            |
|    approx_kl            | 0.00680285 |
|    clip_fraction        | 0.0634     |
|    clip_range           | 0.2        |
|    entropy_loss         | -45.5      |
|    explained_variance   | 0.981      |
|    learning_rate        | 0.0001     |
|    loss    

3.9911489486694336
average rollout collection time: 4.059899
average update time: 7.901537
ratio forward to backward: 0.513811
-----------------------------------------
| forces                  | 1042.7782   |
| learning time average   | 7.9         |
| ratio collection to ... | 0.514       |
| rew_unnormalized        | -110745.16  |
| rollout collection t... | 4.06        |
| rollout/                |             |
|    ep_len_mean          | 32.0        |
|    ep_rew_mean          | 1.7675034   |
| time/                   |             |
|    fps                  | 273         |
|    iterations           | 30          |
|    time_elapsed         | 351         |
|    total_timesteps      | 96000       |
| train/                  |             |
|    approx_kl            | 0.008543223 |
|    clip_fraction        | 0.111       |
|    clip_range           | 0.2         |
|    entropy_loss         | -45.4       |
|    explained_variance   | 0.974       |
|    learning_rate        | 0.000

4.2192816734313965
average rollout collection time: 4.058688
average update time: 7.917558
ratio forward to backward: 0.512619
------------------------------------------
| forces                  | 1029.3018    |
| learning time average   | 7.92         |
| ratio collection to ... | 0.513        |
| rew_unnormalized        | -81915.58    |
| rollout collection t... | 4.06         |
| rollout/                |              |
|    ep_len_mean          | 32.0         |
|    ep_rew_mean          | 2.0790033    |
| time/                   |              |
|    fps                  | 271          |
|    iterations           | 37           |
|    time_elapsed         | 435          |
|    total_timesteps      | 118400       |
| train/                  |              |
|    approx_kl            | 0.0054564024 |
|    clip_fraction        | 0.0833       |
|    clip_range           | 0.2          |
|    entropy_loss         | -45.4        |
|    explained_variance   | 0.967        |
|    learning

4.053824424743652
average rollout collection time: 4.069026
average update time: 7.950445
ratio forward to backward: 0.511799
------------------------------------------
| forces                  | 1006.32086   |
| learning time average   | 7.95         |
| ratio collection to ... | 0.512        |
| rew_unnormalized        | -51469.434   |
| rollout collection t... | 4.07         |
| rollout/                |              |
|    ep_len_mean          | 32.0         |
|    ep_rew_mean          | 2.5353262    |
| time/                   |              |
|    fps                  | 270          |
|    iterations           | 44           |
|    time_elapsed         | 521          |
|    total_timesteps      | 140800       |
| train/                  |              |
|    approx_kl            | 0.0061504305 |
|    clip_fraction        | 0.088        |
|    clip_range           | 0.2          |
|    entropy_loss         | -45.4        |
|    explained_variance   | 0.957        |
|    learning_

4.039433002471924
average rollout collection time: 4.066724
average update time: 7.938277
ratio forward to backward: 0.512293
-----------------------------------------
| forces                  | 1019.0291   |
| learning time average   | 7.94        |
| ratio collection to ... | 0.512       |
| rew_unnormalized        | -40707.793  |
| rollout collection t... | 4.07        |
| rollout/                |             |
|    ep_len_mean          | 32.0        |
|    ep_rew_mean          | 2.5704074   |
| time/                   |             |
|    fps                  | 270         |
|    iterations           | 51          |
|    time_elapsed         | 604         |
|    total_timesteps      | 163200      |
| train/                  |             |
|    approx_kl            | 0.008712062 |
|    clip_fraction        | 0.113       |
|    clip_range           | 0.2         |
|    entropy_loss         | -45.3       |
|    explained_variance   | 0.953       |
|    learning_rate        | 0.0001

4.075203895568848
average rollout collection time: 4.058731
average update time: 7.933061
ratio forward to backward: 0.511622
-----------------------------------------
| forces                  | 1021.2356   |
| learning time average   | 7.93        |
| ratio collection to ... | 0.512       |
| rew_unnormalized        | -30374.803  |
| rollout collection t... | 4.06        |
| rollout/                |             |
|    ep_len_mean          | 32.0        |
|    ep_rew_mean          | 2.6493168   |
| time/                   |             |
|    fps                  | 269         |
|    iterations           | 58          |
|    time_elapsed         | 687         |
|    total_timesteps      | 185600      |
| train/                  |             |
|    approx_kl            | 0.005810341 |
|    clip_fraction        | 0.116       |
|    clip_range           | 0.2         |
|    entropy_loss         | -45.4       |
|    explained_variance   | 0.931       |
|    learning_rate        | 0.0001

4.066716432571411
average rollout collection time: 4.054453
average update time: 7.909863
ratio forward to backward: 0.512582
-----------------------------------------
| forces                  | 1018.24634  |
| learning time average   | 7.91        |
| ratio collection to ... | 0.513       |
| rew_unnormalized        | -27864.365  |
| rollout collection t... | 4.05        |
| rollout/                |             |
|    ep_len_mean          | 32.0        |
|    ep_rew_mean          | 2.5600004   |
| time/                   |             |
|    fps                  | 270         |
|    iterations           | 65          |
|    time_elapsed         | 769         |
|    total_timesteps      | 208000      |
| train/                  |             |
|    approx_kl            | 0.011810575 |
|    clip_fraction        | 0.122       |
|    clip_range           | 0.2         |
|    entropy_loss         | -45.3       |
|    explained_variance   | 0.897       |
|    learning_rate        | 0.0001

3.922801971435547
average rollout collection time: 4.053519
average update time: 7.910181
ratio forward to backward: 0.512443
-----------------------------------------
| forces                  | 1010.54193  |
| learning time average   | 7.91        |
| ratio collection to ... | 0.512       |
| rew_unnormalized        | -25226.418  |
| rollout collection t... | 4.05        |
| rollout/                |             |
|    ep_len_mean          | 32.0        |
|    ep_rew_mean          | 2.4983768   |
| time/                   |             |
|    fps                  | 269         |
|    iterations           | 72          |
|    time_elapsed         | 853         |
|    total_timesteps      | 230400      |
| train/                  |             |
|    approx_kl            | 0.011305554 |
|    clip_fraction        | 0.122       |
|    clip_range           | 0.2         |
|    entropy_loss         | -45.2       |
|    explained_variance   | 0.908       |
|    learning_rate        | 0.0001

4.066629886627197
average rollout collection time: 4.051788
average update time: 7.874072
ratio forward to backward: 0.514573
-----------------------------------------
| forces                  | 992.9084    |
| learning time average   | 7.87        |
| ratio collection to ... | 0.515       |
| rew_unnormalized        | -20908.352  |
| rollout collection t... | 4.05        |
| rollout/                |             |
|    ep_len_mean          | 32.0        |
|    ep_rew_mean          | 2.5045664   |
| time/                   |             |
|    fps                  | 270         |
|    iterations           | 79          |
|    time_elapsed         | 934         |
|    total_timesteps      | 252800      |
| train/                  |             |
|    approx_kl            | 0.011603858 |
|    clip_fraction        | 0.127       |
|    clip_range           | 0.2         |
|    entropy_loss         | -45.2       |
|    explained_variance   | 0.874       |
|    learning_rate        | 0.0001

4.157096862792969
average rollout collection time: 4.056193
average update time: 7.869388
ratio forward to backward: 0.515440
------------------------------------------
| forces                  | 966.5408     |
| learning time average   | 7.87         |
| ratio collection to ... | 0.515        |
| rew_unnormalized        | -16952.451   |
| rollout collection t... | 4.06         |
| rollout/                |              |
|    ep_len_mean          | 32.0         |
|    ep_rew_mean          | 2.5185854    |
| time/                   |              |
|    fps                  | 270          |
|    iterations           | 86           |
|    time_elapsed         | 1017         |
|    total_timesteps      | 275200       |
| train/                  |              |
|    approx_kl            | 0.0122506535 |
|    clip_fraction        | 0.122        |
|    clip_range           | 0.2          |
|    entropy_loss         | -45.1        |
|    explained_variance   | 0.937        |
|    learning_

KeyboardInterrupt: 

## CFE Chain Initialization

To put the results of the reinforcement learning method into context, we compare them to a supervised control force estimator (CFE) approach trained with a differentiable physics loss. This comparison seems fair, as both algorithms learn by optimizing through trial and error.

The CFE approach has access to the gradients provided by the differentiable solver. This makes it possible to trace the loss back over multiple timesteps, which lets the model better capture the long-term effects of the forces it generates.
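To illustrate what tracing the loss over multiple timesteps means, the sketch below uses a linear toy simulator $u_{t+1} = A u_t + \Delta t\, f_t$ with a loss on the final state only. The periodic diffusion matrix `A` is an assumption standing in for the Burgers solver, chosen so the backward (adjoint) pass can be written by hand. Backpropagating through all steps yields the gradient with respect to every intermediate force, including early ones whose effect only shows up at the end; a central finite-difference check confirms the result.

```python
import numpy as np

n, steps, toy_dt, toy_nu = 16, 8, 0.1, 0.01
dx = 1.0 / n
# Periodic 1D diffusion step u -> A u as a linear stand-in for the simulator
lap = (np.roll(np.eye(n), 1, axis=1) - 2 * np.eye(n) + np.roll(np.eye(n), -1, axis=1)) / dx**2
A = np.eye(n) + toy_dt * toy_nu * lap

def simulate(u0, forces):
    u = u0
    for f in forces:  # u_{t+1} = A u_t + dt * f_t
        u = A @ u + toy_dt * f
    return u

def loss(u0, forces, target):
    # Loss on the final state only, as when matching a target at the end
    return 0.5 * np.sum((simulate(u0, forces) - target) ** 2)

def loss_grad(u0, forces, target):
    # Reverse pass: start from dL/du_N and propagate the adjoint back with A^T
    adjoint = simulate(u0, forces) - target
    grads = np.zeros_like(forces)
    for t in reversed(range(len(forces))):
        grads[t] = toy_dt * adjoint  # dL/df_t = dt * (A^T)^(N-1-t) (u_N - target)
        adjoint = A.T @ adjoint
    return grads

rng = np.random.default_rng(0)
u_init, target = rng.normal(size=n), np.zeros(n)
forces = rng.normal(size=(steps, n))
g = loss_grad(u_init, forces, target)

# Central finite-difference check on a force entry of the very first step
eps = 1e-6
f_plus, f_minus = forces.copy(), forces.copy()
f_plus[0, 3] += eps
f_minus[0, 3] -= eps
fd = (loss(u_init, f_plus, target) - loss(u_init, f_minus, target)) / (2 * eps)
```

In the notebook, Φ<sub>Flow</sub> and the underlying autodiff framework perform this backward pass automatically through the nonlinear Burgers steps.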

The reinforcement learning approach, on the other hand, uses a dedicated value estimator network (the critic) to predict the expected sum of future rewards from a given state. These predictions are then used to update a policy network (the actor), which, analogously to the control force estimator network, predicts the forces that control the simulation. The RL algorithm is not limited by the size of a training dataset, since new training samples are generated on-policy. However, generating them adds simulation overhead during training, which can increase training time.
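The interplay between critic and actor can be sketched in NumPy. The code computes generalized advantage estimates (GAE) for one trajectory from rewards and critic predictions, which is the estimator stable_baselines3's PPO uses by default (its `gamma` and `gae_lambda` default to 0.99 and 0.95), and then evaluates PPO's clipped surrogate objective for the actor. The function names and numbers are illustrative and not the stable_baselines3 API.

```python
import numpy as np

def gae(rewards, values, last_value, gamma=0.99, lam=0.95):
    """Generalized advantage estimation for one trajectory.
    values[t] is the critic's prediction V(s_t); last_value is V(s_T), 0 if the episode ended."""
    advantages = np.zeros(len(rewards))
    next_value, running = last_value, 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]  # one-step TD error
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    returns = advantages + np.asarray(values)  # regression targets for the critic
    return advantages, returns

def ppo_clip_objective(ratio, advantages, clip_range=0.2):
    """Clipped surrogate objective for the actor (to be maximized);
    ratio = pi_new(a|s) / pi_old(a|s) for each sampled action."""
    clipped = np.clip(ratio, 1 - clip_range, 1 + clip_range)
    return np.mean(np.minimum(ratio * advantages, clipped * advantages))

# With gamma = lam = 1 and a zero critic, advantages are plain reward-to-go sums
adv, ret = gae(np.array([1.0, 1.0, 1.0]), np.zeros(3), last_value=0.0, gamma=1.0, lam=1.0)
# adv == [3, 2, 1]

objective = ppo_clip_objective(np.array([1.5, 0.5]), np.array([1.0, 1.0]))
# min(1.5, 1.2) and min(0.5, 0.8) -> mean(1.2, 0.5) == 0.85
```

The clipping keeps each policy update close to the policy that collected the data, which is why fresh on-policy rollouts must be simulated after every update.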

In [None]:
from control.pde.burgers import BurgersPDE
from control.control_training import ControlTraining
from control.sequences import StaggeredSequence

In [None]:
cfe_app = ControlTraining(
    step_count,
    BurgersPDE(domain, viscosity, dt),
    datapath=data_path,
    val_range=val_range,
    train_range=train_range,
    trace_to_channel=lambda trace: 'burgers_velocity',
    obs_loss_frames=[],
    trainable_networks=['CFE'],
    sequence_class=StaggeredSequence,
    batch_size=100,
    view_size=20,
    learning_rate=1e-3,
    learning_rate_half_life=1000,
    dt=dt
).prepare()

In [None]:
import time

cfe_training_eval_data = []

start_time = time.time()

for epoch in range(20000):
    cfe_app.progress()
    # Check the amount of forces the current model is producing
    if epoch % 10 == 0:
        # Divide by time delta to retrieve forces from L1 distances
        f = cfe_app.infer_scalars(test_range)['Total Force'] / dt
        cfe_training_eval_data.append((time.time() - start_time, epoch * 320, f))
        print('Forces: %f' % f)

In [None]:
import csv
import os

# Store the information about the training progress
cfe_store_path = '../networks/cfe-models/bench'
os.makedirs(cfe_store_path, exist_ok=True)
csv_path = os.path.join(cfe_store_path, 'test_forces.csv')
# Only write the header for a new file, so appending never repeats it
write_header = not os.path.exists(csv_path)
with open(csv_path, 'a', newline='') as file:
    logger = csv.DictWriter(file, fieldnames=('time', 'iteration', 'forces'))
    if write_header:
        logger.writeheader()
    for (t, i, f) in cfe_training_eval_data:
        logger.writerow({'time': t, 'iteration': i, 'forces': f})

cfe_checkpoint = cfe_app.save_model()
cfe_checkpoint

Run the cell below to load a trained model:

In [None]:
cfe_app.load_checkpoints({net: 'cfe-models/bench/checkpoint_00010000' for net in ['OP2', 'OP4', 'OP8', 'OP16', 'OP32']})

In [None]:
path_template = os.path.join('..', 'networks', 'cfe-models', 'bench', 'checkpoint_%08i')

cp_forces = []

#for i in range(10000 // 50):
#    cp_path = path_template % (i * 50)
#    cfe_app.load_checkpoints({net: cp_path for net in ['OP2', 'OP4', 'OP8', 'OP16', 'OP32']})
#    cp_forces.append(cfe_app.infer_scalars(test_range)['Total Force'] / dt)
        
#plt.plot(cp_forces)
#plt.show()

cp_path = path_template % 9950
cfe_app.load_checkpoints({net: cp_path for net in ['OP2', 'OP4', 'OP8', 'OP16', 'OP32']})

## Trajectory Comparison

Now we are set up to compare the different methods head to head.

In [None]:
rl_frames, gt_frames, pass_frames = rl_trainer.infer_test_set_frames()

cfe_frames = cfe_app.infer_all_frames(test_range)
cfe_frames = [s.burgers.velocity.data for s in cfe_frames]

In [None]:
index_in_set = 8
fig, axs = plt.subplots(2, 2, figsize=(12.8, 9.6))

axs[0, 0].set_title('Ground Truth')
axs[0, 1].set_title('Uncontrolled')
axs[1, 0].set_title('Reinforcement Learning')
axs[1, 1].set_title('Supervised Control Force Estimator')

for subplot_idcs in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    axs[subplot_idcs].set_ylim(-2, 2)
    axs[subplot_idcs].set_xlabel('x')
    axs[subplot_idcs].set_ylabel('u(x)')

def plot_w_label(xy, field):
    # Plot the initial state (frame 0) with a label so each subplot gets a legend entry
    color = bplt.gradient_color(0, step_count + 1)
    axs[xy].plot(field[0][index_in_set].squeeze(), color=color, linewidth=0.8, label='Initial state in dark red, final state in dark blue')
    axs[xy].legend()
    
# Plot the first states and create a legend for each plot
plot_w_label((0, 0), gt_frames)
plot_w_label((0, 1), pass_frames)
plot_w_label((1, 0), rl_frames)
plot_w_label((1, 1), cfe_frames)
    
# Plot the remaining states
for frame in range(1, step_count + 1):
    color = bplt.gradient_color(frame, step_count+1)
    plot = lambda xy, field: axs[xy].plot(field[frame][index_in_set].squeeze(), color=color, linewidth=0.8)
    plot((0,0), gt_frames)
    plot((0,1), pass_frames)
    plot((1,0), rl_frames)
    plot((1,1), cfe_frames)

In [None]:
def infer_forces(frames):
    frames = np.array(frames)
    
    b = Burgers(diffusion_substeps=diffusion_substeps)
    to_state = lambda v: BurgersVelocity(domain, velocity=v, viscosity=viscosity) 
        
    # Simulate all timesteps of all trajectories at once
    # => concatenate all frames in batch dimension
    prv = to_state(frames[:-1].reshape((-1,) + frames.shape[2:]))
    prv_sim = b.step(prv, dt=dt)
    
    forces = (frames[1:] - prv_sim.velocity.data.reshape(step_count, -1, *frames.shape[2:])) / dt
    
    # Sanity check, should be able to reconstruct goal state with forces
    s = to_state(frames[0])
    for i in range(step_count):
        f = forces[i].reshape(s.velocity.data.shape)
        effect = FieldEffect(CenteredGrid(f, box=domain.box), ['velocity'])
        s = b.step(s, dt, (effect,))
    diff = frames[-1] - s.velocity.data
    print('Maximum deviation from target state: %f' % np.abs(diff).max())
    return forces

    
gt_forces = np.abs(infer_forces(gt_frames)).sum(axis=(0, 2)).squeeze()
cfe_forces = np.abs(infer_forces(cfe_frames)).sum(axis=(0, 2)).squeeze()
rl_forces = rl_trainer.infer_test_set_forces()
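`infer_forces` recovers the control force of each step as the residual between the recorded next frame and an unforced simulation step, $f_t = (u_{t+1} - \mathcal{P}(u_t)) / \Delta t$, and then re-simulates with those forces as a sanity check. The same idea is shown below with a toy periodic diffusion step standing in for the Φ<sub>Flow</sub> Burgers solver, assuming the force is added after the physics step; `physics_step` and `recover_forces` are illustrative names.

```python
import numpy as np

def physics_step(u, dt=0.03, nu=0.003, dx=1.0 / 32):
    # Toy stand-in for an unforced solver step: explicit periodic diffusion
    return u + dt * nu * (np.roll(u, 1) - 2 * u + np.roll(u, -1)) / dx**2

def recover_forces(frames, dt=0.03):
    """frames: array of shape (steps + 1, n). Returns forces of shape (steps, n)
    such that frames[t + 1] == physics_step(frames[t]) + dt * forces[t]."""
    frames = np.asarray(frames)
    predicted = np.array([physics_step(u, dt) for u in frames[:-1]])
    return (frames[1:] - predicted) / dt

# Build a trajectory with known forces, then recover them from the frames alone
rng = np.random.default_rng(1)
n, steps = 32, 32
true_forces = rng.normal(scale=0.1, size=(steps, n))
toy_frames = [np.sin(2 * np.pi * np.arange(n) / n)]
for f in true_forces:
    toy_frames.append(physics_step(toy_frames[-1]) + 0.03 * f)

toy_forces = recover_forces(toy_frames)
total_force = np.abs(toy_forces).sum()  # scalar summary, like the per-scene totals above

# Sanity check as in infer_forces: re-simulating with the recovered forces
# should reproduce the final frame
u_check = toy_frames[0]
for f in toy_forces:
    u_check = physics_step(u_check) + 0.03 * f
max_deviation = np.abs(u_check - toy_frames[-1]).max()
```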

In [None]:
plt.figure(figsize=(12.8, 9.6))
plt.scatter(gt_forces, cfe_forces, label='CFE')
plt.scatter(gt_forces, rl_forces, label='RL')
plt.plot([x * 100 for x in range(15)], [x * 100 for x in range(15)], label='Same forces as original')
plt.xlabel('ground truth')
plt.xlim(0, 1500)
plt.ylim(0, 1500)
plt.ylabel('reconstruction')
plt.grid()
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(12.8, 9.6))
plt.scatter(rl_forces, cfe_forces)
plt.xlabel('reinforcement learning')
plt.ylabel('control force estimator')
plt.plot([x * 100 for x in range(15)], [x * 100 for x in range(15)], label='Same forces cfe rl')
plt.xlim(0, 1500)
plt.ylim(0, 1500)
plt.grid()
plt.legend()
plt.show()

In [None]:
w=0.25
plot_count=20
plt.figure(figsize=(12.8, 9.6))
plt.bar([i - w for i in range(plot_count)], rl_forces[:plot_count], width=w, align='center', label='RL')
plt.bar([i + w for i in range(plot_count)], cfe_forces[:plot_count], width=w, align='center', label='CFE')
plt.bar([i for i in range(plot_count)], gt_forces[:plot_count], width=w, align='center', label='GT')
plt.xlabel('Scenes')
plt.xticks(range(plot_count))
plt.ylabel('Forces')
plt.legend()
plt.show()

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import os
import pandas as pd

def get_rl_test_set_forces(experiment_name):
    # stable_baselines3 numbers its TensorBoard runs PPO_1, PPO_2, ... (starting at 1)
    path_template = os.path.join('..', 'networks', 'rl-models', experiment_name, 'tensorboard-log', 'PPO_%i')
    i = 1
    w_times, step_nums, forces = [], [], []
    while os.path.exists(path_template % i):
        event_acc = EventAccumulator(path_template % i)
        event_acc.Reload()
        new_w_times, new_step_nums, new_forces = zip(*event_acc.Scalars('test_set_forces'))
        new_w_times = np.array(new_w_times) - new_w_times[0]
        if w_times:
            # Continue wall time and step count from where the previous run ended
            new_w_times += w_times[-1]
            new_step_nums = np.array(new_step_nums) + step_nums[-1]
        w_times += list(new_w_times)
        step_nums += list(new_step_nums)
        forces += list(new_forces)
        i += 1
    return w_times, step_nums, forces
      
def get_cfe_test_set_forces(experiment_name):
    path = os.path.join('..', 'networks', 'cfe-models', experiment_name, 'test_forces.csv')
    table = pd.read_csv(path, skiprows=[])
    print(table.keys())
    return list(table['time']), list(table['iteration']), list(table['forces'])
    

#rl_w_times, rl_step_nums, rl_forces_during_training = zip(*event_acc.Scalars('test_set_forces'))
rl_w_times, rl_step_nums, rl_forces_during_training = get_rl_test_set_forces('bench')
cfe_w_times, cfe_step_nums, cfe_forces_during_training = get_cfe_test_set_forces('bench')

print(cfe_step_nums[-1])

fig, axs = plt.subplots(2, 1, figsize=(12.8, 9.6))

axs[0].plot(np.array(rl_step_nums) / 320, rl_forces_during_training, label='RL')
axs[0].plot(np.array(cfe_step_nums) / 320, cfe_forces_during_training, label='CFE')
axs[0].set_xlabel('Epochs')
axs[0].set_ylabel('Forces')
axs[0].set_ylim(0, 1500)
axs[0].grid()
axs[0].legend()
axs[1].plot(np.array(rl_w_times) / 3600, rl_forces_during_training, label='RL')
axs[1].plot(np.array(cfe_w_times) / 3600, cfe_forces_during_training, label='CFE')
axs[1].set_xlabel('Wall time (hours)')
axs[1].set_ylabel('Forces')
axs[1].set_ylim(0, 1500)
axs[1].grid()
axs[1].legend()

fig.savefig('convergence_time_comparison.pdf')

plt.show()