# Introduction

Goal of this notebook: test the training of the Implicit Q-Learning (IQL), Conservative Q-Learning (CQL), Behavior Cloning (BC) and TD3 with Behavior Cloning (TD3+BC) algorithms on the `D4RL/pen/expert-v2` dataset.

The "pen" task requires a robotic hand (Adroit Hand) to manipulate a pen so that it reaches a target orientation in space.

# Loading the dataset

In [42]:
import minari
import time
import numpy as np
from d3rlpy.algos import IQLConfig, CQLConfig, BCConfig, TD3PlusBCConfig
from d3rlpy.datasets import MDPDataset
from d3rlpy.constants import ActionSpace
from d3rlpy.metrics import EnvironmentEvaluator

In [2]:
dataset = minari.load_dataset("D4RL/pen/expert-v2")

In [3]:
print("Total episodes:", dataset.total_episodes)
print("Observation space:", dataset.observation_space)
print("Action space:", dataset.action_space)

Total episodes: 4958
Observation space: Box(-inf, inf, (45,), float64)
Action space: Box(-1.0, 1.0, (24,), float32)


In [4]:
episode = next(dataset.iterate_episodes())
print(episode)

#print(f"Observations: \n{episode.observations[0]}")
#print(f"Actions: \n{episode.actions[0]}")
#print(f"Rewards: \n{episode.rewards[0]}")
#print(f"Terminations: \n{episode.terminations[0]}")

EpisodeData(id=0, total_steps=100, observations=ndarray of shape (101, 45) and dtype float64, actions=ndarray of shape (100, 24) and dtype float32, rewards=ndarray of 100 floats, terminations=ndarray of 100 bools, truncations=ndarray of 100 bools, infos=dict with the following keys: ['success'])


The task consists of repositioning the blue pen so that it matches the orientation of the green target. The base of the hand is fixed. The target is randomized to cover all configurations. The task is considered successful when the two orientations match within a tolerance.
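The success criterion can be illustrated by treating the pen and target orientations as 3D direction vectors. Their normalized dot product, which is the cosine of the angle between them, approaches 1 as they align. The following is a minimal sketch under those assumptions. The helper names and the 0.95 threshold are hypothetical and chosen only for illustration. The environment's actual tolerance and reward shaping are not shown in this notebook.

```python
import math


def normalize(v):
    """Return v scaled to unit length."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def orientation_similarity(pen_dir, target_dir):
    """Cosine of the angle between two orientation vectors (1.0 = aligned)."""
    a, b = normalize(pen_dir), normalize(target_dir)
    return sum(x * y for x, y in zip(a, b))


def orientations_match(pen_dir, target_dir, threshold=0.95):
    # threshold is a hypothetical tolerance, not taken from the environment
    return orientation_similarity(pen_dir, target_dir) >= threshold


print(orientations_match([0, 0, 1], [0.05, 0, 1]))  # True: nearly aligned
print(orientations_match([0, 0, 1], [1, 0, 0]))     # False: perpendicular
```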

# Dataset preparation

d3rlpy expects the dataset to consist of transitions. Each transition contains a state, an action, a reward, the next state and a terminal flag. The arrays are aligned so that the state and action at position i correspond to the transition toward the state at position i+1. For this purpose, the library provides the `MDPDataset` class, which builds a dataset object in the required format from flat arrays. These arrays are not split into episodes: the steps of all episodes are concatenated into a single array. In this notebook, episode boundaries are conveyed only through the terminal flags.
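The alignment can be made concrete on a toy episode. An episode with T actions stores T+1 observations, so pairing observation i with observation i+1 yields exactly one (state, action, reward, next state, terminal) tuple per step. The following stdlib sketch uses made-up placeholder values. The helper `to_transitions` is hypothetical and is not part of d3rlpy.

```python
def to_transitions(observations, actions, rewards, terminations):
    """Turn one episode (T+1 observations, T actions) into T aligned transitions."""
    assert len(observations) == len(actions) + 1
    transitions = []
    for i in range(len(actions)):
        # the next state is the observation at position i + 1 of the same episode
        transitions.append(
            (observations[i], actions[i], rewards[i], observations[i + 1], terminations[i])
        )
    return transitions


# toy episode with 3 steps, hence 4 observations
obs = ["s0", "s1", "s2", "s3"]
acts = ["a0", "a1", "a2"]
rews = [0.1, 0.2, 0.3]
terms = [False, False, True]

for t in to_transitions(obs, acts, rews, terms):
    print(t)
```

The last observation (`"s3"`) only ever appears as a next state. This is why the notebook drops it from the per-step observation array.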

In [5]:
observations = []
actions = []
rewards = []
terminals = []

for episode in dataset.iterate_episodes():
    # drop the last observation: an episode stores T+1 observations but only T actions,
    # so the final observation has no action associated with it
    obs = episode.observations[:-1]
    actions_ep = episode.actions
    rewards_ep = episode.rewards
    # real terminations and time-limit truncations are both marked as terminal here;
    # d3rlpy's MDPDataset also accepts a separate `timeouts` array for time-limit ends
    dones = np.array(episode.terminations) | np.array(episode.truncations)

    observations.append(obs)
    actions.append(actions_ep)
    rewards.append(rewards_ep)
    terminals.append(dones)

# observations is now a list of 4958 arrays (one per episode), each with about 100 rows
# (steps) of 45-dimensional observations; the same holds for the other lists

# concatenate the per-episode arrays so that each row holds the observation, action,
# reward and terminal flag of one step of the dataset
observations = np.concatenate(observations)
actions = np.concatenate(actions)
rewards = np.concatenate(rewards)
terminals = np.concatenate(terminals)

# observations is now a single array of shape (499206, 45): one row per step across
# the whole dataset; the same holds for the other arrays
print(observations.shape)
print(actions.shape)
print(rewards.shape)
print(terminals.shape)

d3_dataset = MDPDataset(observations, actions, rewards, terminals, action_space=ActionSpace.CONTINUOUS)

(499206, 45)
(499206, 24)
(499206,)
(499206,)
2025-04-07 09:34.43 [info     ] Signatures have been automatically determined. action_signature=Signature(dtype=[dtype('float32')], shape=[(24,)]) observation_signature=Signature(dtype=[dtype('float64')], shape=[(45,)]) reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)])
2025-04-07 09:34.43 [info     ] Action size has been automatically determined. action_size=24
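The cell above merges `terminations` and `truncations` into a single terminal flag. This choice matters because offline Q-learning methods usually drop the bootstrap term at terminal steps. The following plain-Python sketch shows the one-step target r + γ(1 − done)·V(s′) with made-up numbers. The helper `td_target` is hypothetical. Marking a time-limit truncation as terminal discards the future value, even though the task did not actually end there. d3rlpy's `MDPDataset` also accepts a separate `timeouts` array for time-limit ends. Whether to use it here is a design choice that this notebook does not explore.

```python
def td_target(reward, next_value, done, gamma=0.99):
    """One-step TD target: bootstrap from next_value unless the step is terminal."""
    return reward + gamma * (1.0 - float(done)) * next_value


reward, next_value = 50.0, 3000.0  # made-up values

# last step of an episode that was cut by the time limit
as_terminal = td_target(reward, next_value, done=True)   # future value discarded
as_timeout = td_target(reward, next_value, done=False)   # future value kept

print(as_terminal, as_timeout)
```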


# Implicit Q-Learning
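IQL fits its state-value function V(s) with an expectile regression loss on u = Q(s, a) − V(s), using the asymmetric weight |τ − 1(u < 0)|·u². This quantity is likely what the `v_loss` entry in the training log below refers to. Minimizing the loss over a constant V shows its effect: τ = 0.5 recovers the mean of the Q-values, while a larger τ pushes V toward their upper end without evaluating actions outside the dataset. The following plain-Python sketch uses made-up Q-values and a crude grid search. The function names and τ values are illustrative and are not taken from d3rlpy.

```python
def expectile_loss(q_values, v_values, tau=0.7):
    """Mean of |tau - 1(u < 0)| * u**2 over a batch, with u = Q(s, a) - V(s)."""
    total = 0.0
    for q, v in zip(q_values, v_values):
        u = q - v
        weight = tau if u >= 0 else 1.0 - tau  # asymmetric weighting
        total += weight * u * u
    return total / len(q_values)


# Q-values of dataset actions at one state (made-up numbers)
qs = [0.0, 1.0, 2.0, 10.0]
grid = [i / 100 for i in range(1001)]  # candidate constant V values in [0, 10]


def best_v(tau):
    """Constant V that minimizes the expectile loss on qs (crude grid search)."""
    return min(grid, key=lambda v: expectile_loss(qs, [v] * len(qs), tau))


print(best_v(0.5))  # tau = 0.5 recovers the mean of qs
print(best_v(0.9))  # larger tau moves V toward the largest Q-values
```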

In [6]:
iql = IQLConfig().create(device="cpu")

In [7]:
iql.build_with_dataset(d3_dataset)

In [9]:
env = dataset.recover_environment()

iql.fit(
    dataset=d3_dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={"env": EnvironmentEvaluator(env)},
)

2025-04-07 09:35.04 [info     ] dataset info                   dataset_info=DatasetInfo(observation_signature=Signature(dtype=[dtype('float64')], shape=[(45,)]), action_signature=Signature(dtype=[dtype('float32')], shape=[(24,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.CONTINUOUS: 1>, action_size=24)
2025-04-07 09:35.04 [info     ] Directory is created at d3rlpy_logs/IQL_20250407093504
2025-04-07 09:35.04 [info     ] Parameters                     params={'observation_shape': [45], 'action_size': 24, 'config': {'type': 'iql', 'params': {'batch_size': 256, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.0003, 'actor_opt

2025-04-07 09:35.10 [info     ] IQL_20250407093504: epoch=1 step=1000 epoch=1 metrics={'time_sample_batch': 0.0015404837131500243, 'time_algorithm_update': 0.004327301740646362, 'critic_loss': 2675.132789794922, 'q_loss': 2659.8955614318847, 'v_loss': 15.23722476863861, 'actor_loss': 58.79665441417694, 'time_step': 0.005905513525009155, 'env': 2895.1029235642613} step=1000
2025-04-07 09:35.10 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_1000.d3


2025-04-07 09:35.17 [info     ] IQL_20250407093504: epoch=2 step=2000 epoch=2 metrics={'time_sample_batch': 0.0015875694751739501, 'time_algorithm_update': 0.004651618719100952, 'critic_loss': 6148.825401062012, 'q_loss': 6121.134932830811, 'v_loss': 27.690466495513917, 'actor_loss': 47.209079043388364, 'time_step': 0.006278772592544555, 'env': 3243.914709515445} step=2000
2025-04-07 09:35.17 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_2000.d3


2025-04-07 09:35.24 [info     ] IQL_20250407093504: epoch=3 step=3000 epoch=3 metrics={'time_sample_batch': 0.0015845096111297607, 'time_algorithm_update': 0.004515581607818604, 'critic_loss': 10893.418867370605, 'q_loss': 10849.978895690918, 'v_loss': 43.439971202850344, 'actor_loss': 41.11088452291489, 'time_step': 0.0061383905410766606, 'env': 3545.555285391757} step=3000
2025-04-07 09:35.24 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_3000.d3


2025-04-07 09:35.30 [info     ] IQL_20250407093504: epoch=4 step=4000 epoch=4 metrics={'time_sample_batch': 0.0015312747955322265, 'time_algorithm_update': 0.004358739852905274, 'critic_loss': 15834.384896728516, 'q_loss': 15775.871246154786, 'v_loss': 58.513670345306394, 'actor_loss': 35.86555141592026, 'time_step': 0.005928244352340698, 'env': 2017.4625059093275} step=4000
2025-04-07 09:35.30 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_4000.d3


2025-04-07 09:35.36 [info     ] IQL_20250407093504: epoch=5 step=5000 epoch=5 metrics={'time_sample_batch': 0.0015296976566314697, 'time_algorithm_update': 0.004357388019561767, 'critic_loss': 21307.02997314453, 'q_loss': 21231.473981201172, 'v_loss': 75.5559903755188, 'actor_loss': 34.97971667945385, 'time_step': 0.005925191164016724, 'env': 3401.0270595033753} step=5000
2025-04-07 09:35.36 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_5000.d3


2025-04-07 09:35.43 [info     ] IQL_20250407093504: epoch=6 step=6000 epoch=6 metrics={'time_sample_batch': 0.0015251328945159912, 'time_algorithm_update': 0.004342862129211426, 'critic_loss': 26954.06423779297, 'q_loss': 26854.962735717774, 'v_loss': 99.10148280334472, 'actor_loss': 33.99680823934078, 'time_step': 0.00590580677986145, 'env': 3455.4211854682712} step=6000
2025-04-07 09:35.43 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_6000.d3


2025-04-07 09:35.49 [info     ] IQL_20250407093504: epoch=7 step=7000 epoch=7 metrics={'time_sample_batch': 0.001564725399017334, 'time_algorithm_update': 0.004535144567489624, 'critic_loss': 32913.338779052734, 'q_loss': 32783.045158203124, 'v_loss': 130.29362879180908, 'actor_loss': 34.158071663856504, 'time_step': 0.006138598203659058, 'env': 3723.874593359303} step=7000
2025-04-07 09:35.49 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_7000.d3


2025-04-07 09:35.56 [info     ] IQL_20250407093504: epoch=8 step=8000 epoch=8 metrics={'time_sample_batch': 0.0015206351280212403, 'time_algorithm_update': 0.00432393741607666, 'critic_loss': 36687.09009033203, 'q_loss': 36537.14253442383, 'v_loss': 149.9475538635254, 'actor_loss': 33.591320398807525, 'time_step': 0.005881454706192017, 'env': 2922.41588492131} step=8000
2025-04-07 09:35.56 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_8000.d3


2025-04-07 09:36.02 [info     ] IQL_20250407093504: epoch=9 step=9000 epoch=9 metrics={'time_sample_batch': 0.0015833544731140136, 'time_algorithm_update': 0.004581630229949951, 'critic_loss': 42668.74803686523, 'q_loss': 42504.558054931644, 'v_loss': 164.1900009765625, 'actor_loss': 32.31873790705204, 'time_step': 0.006203950166702271, 'env': 3990.6853455955543} step=9000
2025-04-07 09:36.02 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_9000.d3


2025-04-07 09:36.09 [info     ] IQL_20250407093504: epoch=10 step=10000 epoch=10 metrics={'time_sample_batch': 0.0015127980709075928, 'time_algorithm_update': 0.004278781175613403, 'critic_loss': 45794.06400708008, 'q_loss': 45618.16446240235, 'v_loss': 175.89949571990968, 'actor_loss': 32.171769808888435, 'time_step': 0.005828266382217407, 'env': 2531.518444096145} step=10000
2025-04-07 09:36.09 [info     ] Model parameters are saved to d3rlpy_logs/IQL_20250407093504/model_10000.d3


[(1,
  {'time_sample_batch': 0.0015404837131500243,
   'time_algorithm_update': 0.004327301740646362,
   'critic_loss': 2675.132789794922,
   'q_loss': 2659.8955614318847,
   'v_loss': 15.23722476863861,
   'actor_loss': 58.79665441417694,
   'time_step': 0.005905513525009155,
   'env': 2895.1029235642613}),
 (2,
  {'time_sample_batch': 0.0015875694751739501,
   'time_algorithm_update': 0.004651618719100952,
   'critic_loss': 6148.825401062012,
   'q_loss': 6121.134932830811,
   'v_loss': 27.690466495513917,
   'actor_loss': 47.209079043388364,
   'time_step': 0.006278772592544555,
   'env': 3243.914709515445}),
 (3,
  {'time_sample_batch': 0.0015845096111297607,
   'time_algorithm_update': 0.004515581607818604,
   'critic_loss': 10893.418867370605,
   'q_loss': 10849.978895690918,
   'v_loss': 43.439971202850344,
   'actor_loss': 41.11088452291489,
   'time_step': 0.0061383905410766606,
   'env': 3545.555285391757}),
 (4,
  {'time_sample_batch': 0.0015312747955322265,
   'time_algorit

In [25]:
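env = dataset.recover_environment(render_mode="human", camera_id=2)
obs, _ = env.reset()
total_reward = 0

# Note: the loop stops only on `terminated` and ignores `truncated`, so it keeps stepping
# past the environment's time limit (without a reset) for up to 1000 steps. The total is
# therefore not directly comparable with the per-episode `env` score logged during training.
for _ in range(1000):
    action = iql.predict(obs[None])[0]
    obs, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated:
        break

env.close()
print(f"Total reward: {total_reward}")

Total reward: 5152.226734340886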
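The cell above stops only when `terminated` is true. Under the Gymnasium step API, reaching the time limit sets `truncated` instead. A loop that ignores `truncated` therefore keeps stepping without a reset and sums rewards across what would otherwise be several episodes. The following sketch shows a single-episode loop that stops on either flag. It is written against a hypothetical toy environment (`ToyTimeLimitEnv`) and a constant policy, so it runs without MuJoCo. With the real `env`, only the policy argument would change.

```python
class ToyTimeLimitEnv:
    """Hypothetical stand-in env: reward 1.0 per step, truncated after max_steps."""

    def __init__(self, max_steps=100):
        self.max_steps = max_steps
        self.t = 0

    def reset(self):
        self.t = 0
        return 0.0, {}  # (observation, info), as in the Gymnasium API

    def step(self, action):
        self.t += 1
        truncated = self.t >= self.max_steps
        # (observation, reward, terminated, truncated, info)
        return float(self.t), 1.0, False, truncated, {}


def run_episode(env, policy, max_steps=1000):
    """Return of a single episode, stopping on termination OR truncation."""
    obs, _ = env.reset()
    total_reward = 0.0
    for _ in range(max_steps):
        obs, reward, terminated, truncated, _ = env.step(policy(obs))
        total_reward += reward
        if terminated or truncated:
            break
    return total_reward


print(run_episode(ToyTimeLimitEnv(max_steps=100), policy=lambda obs: 0.0))  # 100.0
```

With the notebook's objects, the call would read `run_episode(env, lambda o: iql.predict(o[None])[0])`.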


# Conservative Q-Learning
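CQL adds a conservative regularizer to the usual TD loss. For each state, it penalizes a soft maximum (log-sum-exp) of Q over sampled actions and rewards the Q-value of the dataset action, scaled by a weight α. The effect is to push down Q for actions the data does not support. The following plain-Python sketch computes that basic penalty for a single state, with made-up Q-values and a fixed α. The function names are hypothetical. How d3rlpy samples actions and sets α is not reproduced here. In the training log below, `alpha` changes across epochs, which suggests that it is adjusted automatically.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def conservative_penalty(q_sampled, q_data, alpha=1.0):
    """alpha * (soft maximum of Q over sampled actions - Q of the dataset action)."""
    return alpha * (logsumexp(q_sampled) - q_data)


# Q-values at one state for a few sampled actions (made-up numbers)
q_sampled = [1.0, 2.0, 8.0]

# the highest-valued action is the one in the dataset: small penalty
print(conservative_penalty(q_sampled, q_data=8.0))
# the highest-valued action is NOT supported by the data: large penalty
print(conservative_penalty(q_sampled, q_data=1.0))
```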

In [14]:
cql = CQLConfig().create(device="cpu")

In [16]:
cql.build_with_dataset(d3_dataset)

In [17]:
env = dataset.recover_environment()

cql.fit(
    dataset=d3_dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={"env": EnvironmentEvaluator(env)},
)

2025-04-07 09:39.05 [info     ] dataset info                   dataset_info=DatasetInfo(observation_signature=Signature(dtype=[dtype('float64')], shape=[(45,)]), action_signature=Signature(dtype=[dtype('float32')], shape=[(24,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.CONTINUOUS: 1>, action_size=24)
2025-04-07 09:39.05 [info     ] Directory is created at d3rlpy_logs/CQL_20250407093905
2025-04-07 09:39.05 [info     ] Parameters                     params={'observation_shape': [45], 'action_size': 24, 'config': {'type': 'cql', 'params': {'batch_size': 256, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0001, 'critic_learning_rate': 0.0003, 'temp_lear

2025-04-07 09:39.41 [info     ] CQL_20250407093905: epoch=1 step=1000 epoch=1 metrics={'time_sample_batch': 0.0018825025558471679, 'time_algorithm_update': 0.033327054500579836, 'critic_loss': 850.7936302185059, 'conservative_loss': 3.736338350892067, 'alpha': 1.0130486673116683, 'actor_loss': -91.61962862110138, 'temp': 0.9556084146499634, 'temp_loss': 29.52512632369995, 'time_step': 0.0352595739364624, 'env': 185.83549634159868} step=1000
2025-04-07 09:39.41 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_1000.d3


2025-04-07 09:40.20 [info     ] CQL_20250407093905: epoch=2 step=2000 epoch=2 metrics={'time_sample_batch': 0.0019532697200775146, 'time_algorithm_update': 0.03672051596641541, 'critic_loss': 2269.0222620391846, 'conservative_loss': -39.814375631332396, 'alpha': 0.9383266568779945, 'actor_loss': -244.5562688140869, 'temp': 0.8765679097175598, 'temp_loss': 21.79221344947815, 'time_step': 0.03872582507133484, 'env': 241.60821961057314} step=2000
2025-04-07 09:40.20 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_2000.d3


2025-04-07 09:40.54 [info     ] CQL_20250407093905: epoch=3 step=3000 epoch=3 metrics={'time_sample_batch': 0.0017448365688323975, 'time_algorithm_update': 0.031003796339035033, 'critic_loss': 6058.591243255615, 'conservative_loss': -31.02224608898163, 'alpha': 0.8535002691745758, 'actor_loss': -411.08404415893557, 'temp': 0.8073461208939552, 'temp_loss': 17.309443717002868, 'time_step': 0.03279752731323242, 'env': 563.6578634575772} step=3000
2025-04-07 09:40.54 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_3000.d3


2025-04-07 09:41.28 [info     ] CQL_20250407093905: epoch=4 step=4000 epoch=4 metrics={'time_sample_batch': 0.0017871615886688233, 'time_algorithm_update': 0.03166832685470581, 'critic_loss': 11751.692608215331, 'conservative_loss': -9.207659454524517, 'alpha': 0.8023225919008256, 'actor_loss': -585.5004121398925, 'temp': 0.7465361017584801, 'temp_loss': 13.238296122550965, 'time_step': 0.033505909204483035, 'env': 193.09222956213483} step=4000
2025-04-07 09:41.28 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_4000.d3


2025-04-07 09:42.02 [info     ] CQL_20250407093905: epoch=5 step=5000 epoch=5 metrics={'time_sample_batch': 0.001770390272140503, 'time_algorithm_update': 0.03140075993537903, 'critic_loss': 17966.91237109375, 'conservative_loss': 31.700336929917334, 'alpha': 0.8337149719595909, 'actor_loss': -780.149253479004, 'temp': 0.6942932325601577, 'temp_loss': 9.425389189243317, 'time_step': 0.03322071266174317, 'env': 579.0848532741566} step=5000
2025-04-07 09:42.02 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_5000.d3


2025-04-07 09:42.41 [info     ] CQL_20250407093905: epoch=6 step=6000 epoch=6 metrics={'time_sample_batch': 0.0020388431549072266, 'time_algorithm_update': 0.03641819334030151, 'critic_loss': 27500.07182458496, 'conservative_loss': 115.7870230178833, 'alpha': 0.9788761110901832, 'actor_loss': -1001.5450241088868, 'temp': 0.6517980610132217, 'temp_loss': 5.8995378651618955, 'time_step': 0.038510204315185544, 'env': 35.91653338683333} step=6000
2025-04-07 09:42.41 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_6000.d3


2025-04-07 09:43.15 [info     ] CQL_20250407093905: epoch=7 step=7000 epoch=7 metrics={'time_sample_batch': 0.0018669300079345703, 'time_algorithm_update': 0.031592557430267336, 'critic_loss': 37798.14590039063, 'conservative_loss': 232.01046014404298, 'alpha': 1.1505273463726045, 'actor_loss': -1247.7033227539061, 'temp': 0.6181558321714401, 'temp_loss': 3.3790261780023574, 'time_step': 0.0335113799571991, 'env': 207.9944095984064} step=7000
2025-04-07 09:43.15 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_7000.d3


2025-04-07 09:43.50 [info     ] CQL_20250407093905: epoch=8 step=8000 epoch=8 metrics={'time_sample_batch': 0.0019152917861938476, 'time_algorithm_update': 0.03220466494560242, 'critic_loss': 51627.58175048828, 'conservative_loss': 383.3483046722412, 'alpha': 1.3312392839193343, 'actor_loss': -1512.3378660888673, 'temp': 0.5947071011066437, 'temp_loss': 1.3631523427758365, 'time_step': 0.034172234773635866, 'env': 22.479925475155135} step=8000
2025-04-07 09:43.50 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_8000.d3


2025-04-07 09:44.25 [info     ] CQL_20250407093905: epoch=9 step=9000 epoch=9 metrics={'time_sample_batch': 0.0019985873699188235, 'time_algorithm_update': 0.03291210985183716, 'critic_loss': 71410.32231298828, 'conservative_loss': 598.3275304870606, 'alpha': 1.5261711584329605, 'actor_loss': -1784.3933461914062, 'temp': 0.5873868223428726, 'temp_loss': -0.08099927717074752, 'time_step': 0.03496393513679504, 'env': 8.080350396701569} step=9000
2025-04-07 09:44.25 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_9000.d3


2025-04-07 09:45.00 [info     ] CQL_20250407093905: epoch=10 step=10000 epoch=10 metrics={'time_sample_batch': 0.0019095804691314697, 'time_algorithm_update': 0.03276324963569641, 'critic_loss': 93127.96612402344, 'conservative_loss': 906.7369437255859, 'alpha': 1.748260232925415, 'actor_loss': -2063.5706540527344, 'temp': 0.5994024122357369, 'temp_loss': -0.845141943179071, 'time_step': 0.034725390672683716, 'env': 30.865614627095464} step=10000
2025-04-07 09:45.00 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20250407093905/model_10000.d3


[(1,
  {'time_sample_batch': 0.0018825025558471679,
   'time_algorithm_update': 0.033327054500579836,
   'critic_loss': 850.7936302185059,
   'conservative_loss': 3.736338350892067,
   'alpha': 1.0130486673116683,
   'actor_loss': -91.61962862110138,
   'temp': 0.9556084146499634,
   'temp_loss': 29.52512632369995,
   'time_step': 0.0352595739364624,
   'env': 185.83549634159868}),
 (2,
  {'time_sample_batch': 0.0019532697200775146,
   'time_algorithm_update': 0.03672051596641541,
   'critic_loss': 2269.0222620391846,
   'conservative_loss': -39.814375631332396,
   'alpha': 0.9383266568779945,
   'actor_loss': -244.5562688140869,
   'temp': 0.8765679097175598,
   'temp_loss': 21.79221344947815,
   'time_step': 0.03872582507133484,
   'env': 241.60821961057314}),
 (3,
  {'time_sample_batch': 0.0017448365688323975,
   'time_algorithm_update': 0.031003796339035033,
   'critic_loss': 6058.591243255615,
   'conservative_loss': -31.02224608898163,
   'alpha': 0.8535002691745758,
   'actor_lo

In [30]:
env = dataset.recover_environment(render_mode="human", camera_id=2)
obs, _ = env.reset()
total_reward = 0

# Note: the loop stops only on `terminated` and ignores `truncated`, so it keeps stepping
# past the environment's time limit (without a reset) for up to 1000 steps.
for _ in range(1000):
    action = cql.predict(obs[None])[0]
    obs, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated:
        break

env.close()
print(f"Total reward: {total_reward}")

Total reward: -4405.210054516312


# Behavior Cloning
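Behavior Cloning ignores rewards and fits the policy to the dataset actions by supervised regression. The training log for this section reports `'policy_type': 'deterministic'`. For a deterministic policy, a natural loss is the mean squared error between predicted and dataset actions. The following sketch assumes that loss and uses plain Python, made-up 2-dimensional actions and a hypothetical function name. It is not d3rlpy's exact implementation.

```python
def mse_bc_loss(predicted_actions, dataset_actions):
    """Mean squared error over all action dimensions of a batch."""
    total, count = 0.0, 0
    for pred, target in zip(predicted_actions, dataset_actions):
        for p, t in zip(pred, target):
            total += (p - t) ** 2
            count += 1
    return total / count


# batch of 2 transitions with 2-dimensional actions (the pen task uses 24)
dataset_actions = [[0.5, -0.2], [0.1, 0.0]]
perfect = [[0.5, -0.2], [0.1, 0.0]]
off = [[0.6, -0.2], [0.1, 0.2]]

print(mse_bc_loss(perfect, dataset_actions))  # 0.0
print(mse_bc_loss(off, dataset_actions))
```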

In [34]:
bc = BCConfig().create(device="cpu")

In [35]:
bc.build_with_dataset(d3_dataset)

In [36]:
env = dataset.recover_environment()

bc.fit(
    dataset=d3_dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={"env": EnvironmentEvaluator(env)},
)

2025-04-07 09:50.31 [info     ] dataset info                   dataset_info=DatasetInfo(observation_signature=Signature(dtype=[dtype('float64')], shape=[(45,)]), action_signature=Signature(dtype=[dtype('float32')], shape=[(24,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.CONTINUOUS: 1>, action_size=24)
2025-04-07 09:50.31 [info     ] Directory is created at d3rlpy_logs/BC_20250407095031
2025-04-07 09:50.31 [info     ] Parameters                     params={'observation_shape': [45], 'action_size': 24, 'config': {'type': 'bc', 'params': {'batch_size': 100, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'learning_rate': 0.001, 'policy_type': 'deterministic', 'optim_factory': {'

2025-04-07 09:50.33 [info     ] BC_20250407095031: epoch=1 step=1000 epoch=1 metrics={'time_sample_batch': 0.0006606235504150391, 'time_algorithm_update': 0.000628838062286377, 'loss': 0.10639423998445273, 'time_step': 0.0013068647384643554, 'env': 2617.8516269407746} step=1000
2025-04-07 09:50.33 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_1000.d3


2025-04-07 09:50.34 [info     ] BC_20250407095031: epoch=2 step=2000 epoch=2 metrics={'time_sample_batch': 0.0005909712314605713, 'time_algorithm_update': 0.0005823712348937988, 'loss': 0.09810689052194357, 'time_step': 0.0011873888969421388, 'env': 3518.544428844261} step=2000
2025-04-07 09:50.34 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_2000.d3


2025-04-07 09:50.36 [info     ] BC_20250407095031: epoch=3 step=3000 epoch=3 metrics={'time_sample_batch': 0.0005932431221008301, 'time_algorithm_update': 0.0006242425441741943, 'loss': 0.09756314815580845, 'time_step': 0.001231632947921753, 'env': 2028.498268216287} step=3000
2025-04-07 09:50.36 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_3000.d3


2025-04-07 09:50.38 [info     ] BC_20250407095031: epoch=4 step=4000 epoch=4 metrics={'time_sample_batch': 0.000613410472869873, 'time_algorithm_update': 0.0006431434154510499, 'loss': 0.09678063146024943, 'time_step': 0.0012755467891693115, 'env': 3487.7231959085757} step=4000
2025-04-07 09:50.38 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_4000.d3


2025-04-07 09:50.39 [info     ] BC_20250407095031: epoch=5 step=5000 epoch=5 metrics={'time_sample_batch': 0.0005931143760681153, 'time_algorithm_update': 0.000586411714553833, 'loss': 0.09684043920785189, 'time_step': 0.0011939258575439453, 'env': 2906.980817689061} step=5000
2025-04-07 09:50.39 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_5000.d3


2025-04-07 09:50.41 [info     ] BC_20250407095031: epoch=6 step=6000 epoch=6 metrics={'time_sample_batch': 0.0005937471389770508, 'time_algorithm_update': 0.0005896284580230713, 'loss': 0.09647191342711449, 'time_step': 0.0011992058753967285, 'env': 3430.1356376729104} step=6000
2025-04-07 09:50.41 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_6000.d3


2025-04-07 09:50.42 [info     ] BC_20250407095031: epoch=7 step=7000 epoch=7 metrics={'time_sample_batch': 0.0005920298099517822, 'time_algorithm_update': 0.0005809595584869385, 'loss': 0.09625664103776216, 'time_step': 0.001187568187713623, 'env': 2969.9219700931403} step=7000
2025-04-07 09:50.42 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_7000.d3


2025-04-07 09:50.44 [info     ] BC_20250407095031: epoch=8 step=8000 epoch=8 metrics={'time_sample_batch': 0.0006138720512390137, 'time_algorithm_update': 0.0006170496940612793, 'loss': 0.09597478502988815, 'time_step': 0.0012484796047210693, 'env': 3041.2249805341603} step=8000
2025-04-07 09:50.44 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_8000.d3


2025-04-07 09:50.46 [info     ] BC_20250407095031: epoch=9 step=9000 epoch=9 metrics={'time_sample_batch': 0.0005875816345214844, 'time_algorithm_update': 0.000578500509262085, 'loss': 0.09610408852249384, 'time_step': 0.0011793177127838136, 'env': 2316.8861195542713} step=9000
2025-04-07 09:50.46 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_9000.d3


2025-04-07 09:50.47 [info     ] BC_20250407095031: epoch=10 step=10000 epoch=10 metrics={'time_sample_batch': 0.0006136603355407715, 'time_algorithm_update': 0.0006180744171142578, 'loss': 0.09613627146184445, 'time_step': 0.0012479488849639893, 'env': 2495.5037871943528} step=10000
2025-04-07 09:50.47 [info     ] Model parameters are saved to d3rlpy_logs/BC_20250407095031/model_10000.d3


[(1,
  {'time_sample_batch': 0.0006606235504150391,
   'time_algorithm_update': 0.000628838062286377,
   'loss': 0.10639423998445273,
   'time_step': 0.0013068647384643554,
   'env': 2617.8516269407746}),
 (2,
  {'time_sample_batch': 0.0005909712314605713,
   'time_algorithm_update': 0.0005823712348937988,
   'loss': 0.09810689052194357,
   'time_step': 0.0011873888969421388,
   'env': 3518.544428844261}),
 (3,
  {'time_sample_batch': 0.0005932431221008301,
   'time_algorithm_update': 0.0006242425441741943,
   'loss': 0.09756314815580845,
   'time_step': 0.001231632947921753,
   'env': 2028.498268216287}),
 (4,
  {'time_sample_batch': 0.000613410472869873,
   'time_algorithm_update': 0.0006431434154510499,
   'loss': 0.09678063146024943,
   'time_step': 0.0012755467891693115,
   'env': 3487.7231959085757}),
 (5,
  {'time_sample_batch': 0.0005931143760681153,
   'time_algorithm_update': 0.000586411714553833,
   'loss': 0.09684043920785189,
   'time_step': 0.0011939258575439453,
   'env'

In [41]:
env = dataset.recover_environment(render_mode="human", camera_id=2)
obs, _ = env.reset()
total_reward = 0

# Note: the loop stops only on `terminated` and ignores `truncated`, so it keeps stepping
# past the environment's time limit (without a reset) for up to 1000 steps.
for _ in range(1000):
    action = bc.predict(obs[None])[0]
    obs, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated:
        break

env.close()
print(f"Total reward: {total_reward}")

Total reward: 59847.596286878754


# TD3 + BC
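TD3+BC keeps the TD3 critic and adds a behavior-cloning term to the actor objective. The actor maximizes λ·Q(s, π(s)) while staying close to the dataset action, where λ = α / mean|Q| normalizes the scale of Q. The original TD3+BC paper uses α = 2.5. When all Q-values in a batch are positive, λ·mean(Q) equals α, so the actor loss reduces to `bc_loss − α`. This is consistent with the training log below, where `actor_loss` stays close to `bc_loss − 2.5`. The following plain-Python sketch computes this objective for one batch with made-up numbers and a hypothetical function name. It shows the structure of the objective, not d3rlpy's exact code.

```python
def td3_plus_bc_actor_loss(q_values, policy_actions, dataset_actions, alpha=2.5):
    """Return (actor_loss, bc_term) for one batch.

    actor_loss = -lambda * mean(Q) + MSE(pi(s), a), with lambda = alpha / mean(|Q|).
    In a real implementation lambda is treated as a constant (no gradient).
    """
    n = len(q_values)
    lam = alpha / (sum(abs(q) for q in q_values) / n)
    mean_q = sum(q_values) / n
    sq_errors = [
        (p - a) ** 2
        for pred, target in zip(policy_actions, dataset_actions)
        for p, a in zip(pred, target)
    ]
    bc_term = sum(sq_errors) / len(sq_errors)
    return -lam * mean_q + bc_term, bc_term


# made-up batch of 3 states with 2-dimensional actions (the pen task uses 24)
q = [900.0, 1100.0, 1000.0]  # critic values Q(s, pi(s))
pi_actions = [[0.1, 0.2], [0.0, 0.0], [0.3, -0.1]]
data_actions = [[0.0, 0.2], [0.0, 0.1], [0.3, -0.1]]

loss, bc = td3_plus_bc_actor_loss(q, pi_actions, data_actions)
print(loss, bc)  # with all Q > 0, loss is approximately bc - alpha
```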

In [43]:
td3bc = TD3PlusBCConfig().create(device="cpu")

In [44]:
td3bc.build_with_dataset(d3_dataset)

In [45]:
env = dataset.recover_environment()

td3bc.fit(
    dataset=d3_dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={"env": EnvironmentEvaluator(env)},
)

2025-04-07 09:54.34 [info     ] dataset info                   dataset_info=DatasetInfo(observation_signature=Signature(dtype=[dtype('float64')], shape=[(45,)]), action_signature=Signature(dtype=[dtype('float32')], shape=[(24,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.CONTINUOUS: 1>, action_size=24)
2025-04-07 09:54.34 [info     ] Directory is created at d3rlpy_logs/TD3PlusBC_20250407095434
2025-04-07 09:54.34 [info     ] Parameters                     params={'observation_shape': [45], 'action_size': 24, 'config': {'type': 'td3_plus_bc', 'params': {'batch_size': 256, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'compile_graph': False, 'actor_learning_rate': 0.0003, 'critic_learning_rate': 0.00

Epoch 1/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:54.39[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016341910362243652, 'time_algorithm_update': 0.003619870901107788, 'critic_loss': 677.8931263122558, 'actor_loss': -2.057920249223709, 'bc_loss': 0.44182866632938383, 'time_step': 0.005290372610092163, 'env': 25.081717289435858}[0m [36mstep[0m=[35m1000[0m
[2m2025-04-07 09:54.39[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_1000.d3[0m


Epoch 2/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:54.46[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016969678401947022, 'time_algorithm_update': 0.004318328142166138, 'critic_loss': 812.1644007720947, 'actor_loss': -2.307586257457733, 'bc_loss': 0.18955203637480736, 'time_step': 0.0060540554523468015, 'env': 254.08299282699954}[0m [36mstep[0m=[35m2000[0m
[2m2025-04-07 09:54.46[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_2000.d3[0m


Epoch 3/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:54.52[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016006903648376465, 'time_algorithm_update': 0.0037991657257080078, 'critic_loss': 1875.3501327514648, 'actor_loss': -2.339436363220215, 'bc_loss': 0.1587618647515774, 'time_step': 0.0054336261749267575, 'env': 517.566570717139}[0m [36mstep[0m=[35m3000[0m
[2m2025-04-07 09:54.52[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_3000.d3[0m


Epoch 4/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:54.57[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0015545170307159424, 'time_algorithm_update': 0.0035663602352142333, 'critic_loss': 3618.996870025635, 'actor_loss': -2.347216776371002, 'bc_loss': 0.15163418546319007, 'time_step': 0.0051541953086853024, 'env': 189.48289926769175}[0m [36mstep[0m=[35m4000[0m
[2m2025-04-07 09:54.57[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_4000.d3[0m


Epoch 5/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:55.03[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0015469982624053955, 'time_algorithm_update': 0.0034670670032501223, 'critic_loss': 5997.99176171875, 'actor_loss': -2.3539725036621095, 'bc_loss': 0.14530336844921113, 'time_step': 0.0050465683937072755, 'env': 1424.536695643879}[0m [36mstep[0m=[35m5000[0m
[2m2025-04-07 09:55.03[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_5000.d3[0m


Epoch 6/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:55.09[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016023187637329102, 'time_algorithm_update': 0.0038863098621368408, 'critic_loss': 8941.362687469482, 'actor_loss': -2.3555412368774413, 'bc_loss': 0.14386986243724822, 'time_step': 0.005524063110351562, 'env': 709.6473905673253}[0m [36mstep[0m=[35m6000[0m
[2m2025-04-07 09:55.09[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_6000.d3[0m


Epoch 7/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:55.16[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0017428624629974366, 'time_algorithm_update': 0.00450113582611084, 'critic_loss': 12277.063099182129, 'actor_loss': -2.3561470046043396, 'bc_loss': 0.14320240366458892, 'time_step': 0.006277882337570191, 'env': 1217.860199508645}[0m [36mstep[0m=[35m7000[0m
[2m2025-04-07 09:55.16[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_7000.d3[0m


Epoch 8/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:55.22[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001634397268295288, 'time_algorithm_update': 0.004120930910110474, 'critic_loss': 15581.037137268066, 'actor_loss': -2.3575564465522767, 'bc_loss': 0.14193287935853005, 'time_step': 0.0057871718406677245, 'env': 1111.8460099816907}[0m [36mstep[0m=[35m8000[0m
[2m2025-04-07 09:55.22[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_8000.d3[0m


Epoch 9/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:55.28[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016840426921844482, 'time_algorithm_update': 0.004356842279434204, 'critic_loss': 20406.089936340333, 'actor_loss': -2.3565680599212646, 'bc_loss': 0.14288178727030754, 'time_step': 0.006075936317443848, 'env': 212.47088908348474}[0m [36mstep[0m=[35m9000[0m
[2m2025-04-07 09:55.28[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_9000.d3[0m


Epoch 10/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2025-04-07 09:55.34[0m [[32m[1minfo     [0m] [1mTD3PlusBC_20250407095434: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001544571876525879, 'time_algorithm_update': 0.003459683656692505, 'critic_loss': 25075.283450073242, 'actor_loss': -2.3554778652191164, 'bc_loss': 0.14416065487265586, 'time_step': 0.005035207986831665, 'env': 856.5322604740846}[0m [36mstep[0m=[35m10000[0m
[2m2025-04-07 09:55.34[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/TD3PlusBC_20250407095434/model_10000.d3[0m


[(1,
  {'time_sample_batch': 0.0016341910362243652,
   'time_algorithm_update': 0.003619870901107788,
   'critic_loss': 677.8931263122558,
   'actor_loss': -2.057920249223709,
   'bc_loss': 0.44182866632938383,
   'time_step': 0.005290372610092163,
   'env': 25.081717289435858}),
 (2,
  {'time_sample_batch': 0.0016969678401947022,
   'time_algorithm_update': 0.004318328142166138,
   'critic_loss': 812.1644007720947,
   'actor_loss': -2.307586257457733,
   'bc_loss': 0.18955203637480736,
   'time_step': 0.0060540554523468015,
   'env': 254.08299282699954}),
 (3,
  {'time_sample_batch': 0.0016006903648376465,
   'time_algorithm_update': 0.0037991657257080078,
   'critic_loss': 1875.3501327514648,
   'actor_loss': -2.339436363220215,
   'bc_loss': 0.1587618647515774,
   'time_step': 0.0054336261749267575,
   'env': 517.566570717139}),
 (4,
  {'time_sample_batch': 0.0015545170307159424,
   'time_algorithm_update': 0.0035663602352142333,
   'critic_loss': 3618.996870025635,
   'actor_loss':
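In the training log, `critic_loss` grows from about 678 to about 25,000 over the ten epochs, while `actor_loss` settles near -2.36. This pattern is consistent with the TD3+BC actor objective of Fujimoto and Gu (2021). There, the Q term is scaled by λ = α / mean|Q| (α = 2.5 in the paper), and a mean squared error behavior cloning term is added. When the Q estimates in a batch are positive, the scaled Q term equals -α whatever the magnitude of Q. The logged value should then be close to -2.5 + `bc_loss`, and the numbers agree: epoch 1 gives -2.5 + 0.442 ≈ -2.058, and epoch 10 gives -2.5 + 0.144 ≈ -2.356.

The NumPy sketch below reproduces that objective for one batch. It assumes that d3rlpy's `TD3PlusBCConfig` follows the paper (α = 2.5, MSE term), which is not verified against the library source here. `td3_plus_bc_actor_loss` is a hypothetical helper.

```python
import numpy as np


def td3_plus_bc_actor_loss(
    q_values: np.ndarray,
    policy_actions: np.ndarray,
    dataset_actions: np.ndarray,
    alpha: float = 2.5,
) -> tuple[float, float]:
    """Return (actor_loss, bc_loss) for one batch, following the TD3+BC paper.

    q_values: Q(s, pi(s)) for each state in the batch, shape (batch,)
    policy_actions, dataset_actions: shape (batch, action_dim)
    """
    # lambda normalizes the Q term by the batch's mean absolute Q value;
    # an autodiff implementation would treat it as a constant (no gradient)
    lam = alpha / np.abs(q_values).mean()
    # Behavior cloning term: mean squared error to the dataset actions
    bc_loss = float(((policy_actions - dataset_actions) ** 2).mean())
    actor_loss = float(-lam * q_values.mean()) + bc_loss
    return actor_loss, bc_loss


rng = np.random.default_rng(0)
pi_actions = rng.uniform(-1.0, 1.0, size=(256, 24))  # shapes match the pen task
data_actions = rng.uniform(-1.0, 1.0, size=(256, 24))
for scale in (10.0, 10_000.0):
    # Strictly positive Q estimates at two very different magnitudes
    q = rng.uniform(0.5, 1.5, size=256) * scale
    loss, bc = td3_plus_bc_actor_loss(q, pi_actions, data_actions)
    print(f"Q scale {scale:>8}: actor_loss={loss:.4f} bc_loss={bc:.4f} Q term={loss - bc:.4f}")
```

The Q term stays at -α for both scales, which would explain why `actor_loss` carries almost no information about the growing Q values. The rising `critic_loss` is better tracked directly, or together with the `env` score.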

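In [50]:
# Roll out the trained TD3+BC policy in the recovered environment with on-screen rendering
env = dataset.recover_environment(render_mode="human", camera_id=2)
obs, _ = env.reset()
total_reward = 0

# Only `terminated` is checked (not `truncated`), so the loop can keep
# stepping past the environment's time limit, up to 1000 steps
for _ in range(1000):
    # predict() works on batches, hence obs[None] and the trailing [0]
    action = td3bc.predict(obs[None])[0]
    obs, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated:
        break

env.close()
print(f"Total reward: {total_reward}")

Total reward: 866.3860965817254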
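The per-epoch `env` score in the TD3+BC log swings widely, from about 25 to about 1425. The last epoch (about 857) is not the best one. Since d3rlpy saved a checkpoint after every epoch (`model_1000.d3` through `model_10000.d3`), the best-scoring checkpoint can be selected afterwards.

The sketch below does this over the `(epoch, metrics)` list that `fit` returns. For a self-contained example, the `env` values are copied from the log above. In the notebook, the list returned by `td3bc.fit(...)` could be passed instead. `best_epoch` is a hypothetical helper, and the checkpoint name assumes the `model_<step>.d3` pattern seen in the log.

```python
from typing import Dict, List, Tuple

# (epoch, metrics) pairs in the shape returned by fit(); only the 'env'
# scores from the TD3+BC training log above are copied here
history: List[Tuple[int, Dict[str, float]]] = [
    (1, {"env": 25.081717289435858}),
    (2, {"env": 254.08299282699954}),
    (3, {"env": 517.566570717139}),
    (4, {"env": 189.48289926769175}),
    (5, {"env": 1424.536695643879}),
    (6, {"env": 709.6473905673253}),
    (7, {"env": 1217.860199508645}),
    (8, {"env": 1111.8460099816907}),
    (9, {"env": 212.47088908348474}),
    (10, {"env": 856.5322604740846}),
]


def best_epoch(
    history: List[Tuple[int, Dict[str, float]]],
    metric: str = "env",
    n_steps_per_epoch: int = 1000,
) -> Tuple[int, float, int]:
    """Return (epoch, score, global step) of the epoch with the highest metric."""
    epoch, metrics = max(history, key=lambda item: item[1][metric])
    return epoch, metrics[metric], epoch * n_steps_per_epoch


epoch, score, step = best_epoch(history)
# One checkpoint per epoch, named after the global step (as in the log)
checkpoint = f"d3rlpy_logs/TD3PlusBC_20250407095434/model_{step}.d3"
print(epoch, round(score, 1), checkpoint)
```

If I read the d3rlpy 2.x API correctly, such a file can be reloaded with `d3rlpy.load_learnable(checkpoint)`. Each `env` score comes from only a few evaluation episodes, so choosing the maximum is somewhat optimistic. Re-evaluating the selected checkpoint on fresh episodes gives a fairer estimate.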
