# DDPG Training for Balancing Robot

This notebook trains a DDPG agent for the balancing robot environment using the provided configurations.

In [None]:
# Uncomment the following lines to run in Google Colab

# %cd /content
# !git clone https://github.com/EyalPorat/ddpg-balancing-robot.git
# %cd ddpg-balancing-robot
# !git checkout master-organize-python-proj
# %cd /content/ddpg-balancing-robot/python/notebooks

# import sys
# sys.path.append('/content/ddpg-balancing-robot/python')  # Add the repo's python/ directory to the Python path

In [1]:
import sys
sys.path.append('..')

import torch
import numpy as np
from pathlib import Path
import yaml
import matplotlib.pyplot as plt

from src.balancing_robot.models import Actor, Critic, ReplayBuffer, SimNet
from src.balancing_robot.environment import BalancerEnv
from src.balancing_robot.training import DDPGTrainer
from src.balancing_robot.visualization import plot_training_metrics, create_episode_animation

## Load Configurations

In [2]:
# Parse the three YAML configuration files (DDPG agent, environment, SimNet).
ddpg_config = yaml.safe_load(Path('../configs/ddpg_config.yaml').read_text())
env_config = yaml.safe_load(Path('../configs/env_config.yaml').read_text())
simnet_config = yaml.safe_load(Path('../configs/simnet_config.yaml').read_text())

# Make sure the output directory for this run exists (missing parents are
# created; an already-existing directory is not an error).
log_dir = Path('logs/ddpg_training')
log_dir.mkdir(exist_ok=True, parents=True)

## Initialize Environment and Models

In [None]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Read the seed from the config once (default 42) so torch and NumPy are
# guaranteed to be seeded with the same value.
seed = ddpg_config["training"].get("random_seed", 42)
torch.manual_seed(seed)
np.random.seed(seed)

# Create environment
env = BalancerEnv(config_path="../configs/env_config.yaml", render_mode="rgb_array")

# Build SimNet with dimensions taken from the environment's spaces and the
# hidden layer sizes from the SimNet config.
simnet = SimNet(
    state_dim=env.observation_space.shape[0],
    action_dim=env.action_space.shape[0],
    hidden_dims=simnet_config["model"]["hidden_dims"],
).to(device)

# Restore the trained SimNet weights; the checkpoint is read as a dict with a
# "state_dict" entry.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
simnet_checkpoint = torch.load("logs/simnet_training/simnet_final.pt", map_location=device)
simnet.load_state_dict(simnet_checkpoint["state_dict"])

# This notebook never trains SimNet, so switch it to evaluation mode; this
# makes dropout / batch-norm layers (if the model has any) behave
# deterministically at inference time.
simnet.eval()

# Set the simnet in the environment
env.simnet = simnet

# Initialize trainer with config
trainer = DDPGTrainer(env=env, config_path="../configs/ddpg_config.yaml")

# Print model summaries
trainer.print_model_info()

  gym.logger.warn(
  gym.logger.warn(


Actor(
  (network): Sequential(
    (0): Linear(in_features=6, out_features=8, bias=True)
    (1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=8, out_features=8, bias=True)
    (4): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    (5): ReLU()
  )
  (output_layer): Linear(in_features=8, out_features=1, bias=True)
)
Critic(
  (l1): Linear(in_features=7, out_features=256, bias=True)
  (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (hidden_layers): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
  )
  (output_layer): Linear(in_features=256, out_features=1, bias=True)
)


## Training

In [4]:
# Training hyper-parameters live under the 'training' section of the DDPG config.
train_config = ddpg_config['training']

# Pair each trainer.train() keyword with the config key that supplies its
# value, then launch training. Whatever train() returns is kept as `history`
# for the analysis cells below.
history = trainer.train(
    **{
        keyword: train_config[config_key]
        for keyword, config_key in (
            ('num_episodes', 'total_episodes'),
            ('max_steps', 'max_steps_per_episode'),
            ('batch_size', 'batch_size'),
            ('eval_freq', 'eval_frequency'),
            ('save_freq', 'save_frequency'),
        )
    },
    log_dir=log_dir,
)

  critic_loss = torch.nn.functional.mse_loss(current_Q, target_Q)
Episodes:   0%|          | 10/2000 [00:06<28:29,  1.16it/s]

Episode 10: Eval reward = 120.69


Episodes:   1%|          | 20/2000 [00:22<42:04,  1.28s/it]  

Episode 20: Eval reward = -173.78


Episodes:   2%|▏         | 30/2000 [00:27<13:52,  2.37it/s]

Episode 30: Eval reward = -2391.51


Episodes:   2%|▏         | 40/2000 [00:30<09:22,  3.49it/s]

Episode 40: Eval reward = -2155.19


Episodes:   2%|▎         | 50/2000 [00:33<11:31,  2.82it/s]

Episode 50: Eval reward = -2553.63


Episodes:   3%|▎         | 60/2000 [00:37<11:32,  2.80it/s]

Episode 60: Eval reward = -2755.43


Episodes:   4%|▎         | 70/2000 [00:40<11:00,  2.92it/s]

Episode 70: Eval reward = -2195.51


Episodes:   4%|▍         | 80/2000 [00:44<12:31,  2.55it/s]

Episode 80: Eval reward = -2518.00


Episodes:   4%|▍         | 90/2000 [00:48<10:51,  2.93it/s]

Episode 90: Eval reward = -2550.93


Episodes:   5%|▌         | 100/2000 [00:51<11:27,  2.76it/s]

Episode 100: Eval reward = -2977.53


Episodes:   6%|▌         | 110/2000 [00:55<11:59,  2.63it/s]

Episode 110: Eval reward = -2588.36


Episodes:   6%|▌         | 120/2000 [00:58<11:03,  2.83it/s]

Episode 120: Eval reward = -2623.31


Episodes:   6%|▋         | 130/2000 [01:02<11:16,  2.76it/s]

Episode 130: Eval reward = -2162.74


Episodes:   7%|▋         | 140/2000 [01:06<09:58,  3.11it/s]

Episode 140: Eval reward = -2395.77


Episodes:   8%|▊         | 150/2000 [01:09<10:56,  2.82it/s]

Episode 150: Eval reward = -2282.70


Episodes:   8%|▊         | 160/2000 [01:14<16:22,  1.87it/s]

Episode 160: Eval reward = -1909.23


Episodes:   8%|▊         | 170/2000 [01:17<12:04,  2.53it/s]

Episode 170: Eval reward = -2764.92


Episodes:   9%|▉         | 180/2000 [01:21<11:49,  2.57it/s]

Episode 180: Eval reward = -2361.71


Episodes:  10%|▉         | 190/2000 [01:25<10:11,  2.96it/s]

Episode 190: Eval reward = -2456.87


Episodes:  10%|█         | 200/2000 [01:28<09:49,  3.05it/s]

Episode 200: Eval reward = -2498.50


Episodes:  10%|█         | 210/2000 [01:32<12:50,  2.32it/s]

Episode 210: Eval reward = -2505.80


Episodes:  11%|█         | 220/2000 [01:36<09:15,  3.21it/s]

Episode 220: Eval reward = -2593.12


Episodes:  12%|█▏        | 230/2000 [01:41<20:55,  1.41it/s]

Episode 230: Eval reward = -2370.30


Episodes:  12%|█▏        | 240/2000 [01:46<11:08,  2.63it/s]

Episode 240: Eval reward = -2249.00


Episodes:  12%|█▎        | 250/2000 [01:49<10:21,  2.82it/s]

Episode 250: Eval reward = -2854.44


Episodes:  13%|█▎        | 260/2000 [01:53<11:59,  2.42it/s]

Episode 260: Eval reward = -1968.11


Episodes:  14%|█▎        | 270/2000 [01:57<11:58,  2.41it/s]

Episode 270: Eval reward = -2251.58


Episodes:  14%|█▍        | 280/2000 [02:03<17:41,  1.62it/s]

Episode 280: Eval reward = -2855.76


Episodes:  14%|█▍        | 290/2000 [02:07<09:54,  2.88it/s]

Episode 290: Eval reward = -2471.36


Episodes:  15%|█▌        | 300/2000 [02:10<10:12,  2.77it/s]

Episode 300: Eval reward = -2396.69


Episodes:  16%|█▌        | 310/2000 [02:14<11:14,  2.50it/s]

Episode 310: Eval reward = -2861.46


Episodes:  16%|█▌        | 320/2000 [02:18<09:31,  2.94it/s]

Episode 320: Eval reward = -2699.78


Episodes:  16%|█▋        | 330/2000 [02:22<11:54,  2.34it/s]

Episode 330: Eval reward = -2241.40


Episodes:  17%|█▋        | 340/2000 [02:26<09:09,  3.02it/s]

Episode 340: Eval reward = -2088.17


Episodes:  18%|█▊        | 350/2000 [02:30<12:29,  2.20it/s]

Episode 350: Eval reward = -2317.15


Episodes:  18%|█▊        | 360/2000 [02:36<14:33,  1.88it/s]

Episode 360: Eval reward = -2479.31


Episodes:  18%|█▊        | 370/2000 [02:41<12:56,  2.10it/s]

Episode 370: Eval reward = -2253.38


Episodes:  19%|█▉        | 380/2000 [02:45<11:05,  2.44it/s]

Episode 380: Eval reward = -3067.23


Episodes:  20%|█▉        | 390/2000 [02:49<10:22,  2.59it/s]

Episode 390: Eval reward = -3141.42


Episodes:  20%|██        | 400/2000 [02:54<12:14,  2.18it/s]

Episode 400: Eval reward = -3013.25


Episodes:  20%|██        | 410/2000 [02:58<11:20,  2.33it/s]

Episode 410: Eval reward = -2814.45


Episodes:  21%|██        | 420/2000 [03:02<09:39,  2.72it/s]

Episode 420: Eval reward = -2558.83


Episodes:  22%|██▏       | 430/2000 [03:06<09:14,  2.83it/s]

Episode 430: Eval reward = -2379.42


Episodes:  22%|██▏       | 440/2000 [03:09<09:23,  2.77it/s]

Episode 440: Eval reward = -2484.00


Episodes:  22%|██▎       | 450/2000 [03:13<10:02,  2.57it/s]

Episode 450: Eval reward = -2436.17


Episodes:  23%|██▎       | 460/2000 [03:17<09:30,  2.70it/s]

Episode 460: Eval reward = -2778.15


Episodes:  24%|██▎       | 470/2000 [03:21<10:37,  2.40it/s]

Episode 470: Eval reward = -2549.50


Episodes:  24%|██▍       | 480/2000 [03:24<09:03,  2.80it/s]

Episode 480: Eval reward = -2774.79


Episodes:  24%|██▍       | 490/2000 [03:30<13:54,  1.81it/s]

Episode 490: Eval reward = -2426.41


Episodes:  25%|██▌       | 500/2000 [03:35<10:37,  2.35it/s]

Episode 500: Eval reward = -2201.44


Episodes:  26%|██▌       | 510/2000 [03:39<10:38,  2.33it/s]

Episode 510: Eval reward = -2700.66


Episodes:  26%|██▌       | 520/2000 [03:43<11:39,  2.12it/s]

Episode 520: Eval reward = -2642.24


Episodes:  26%|██▋       | 530/2000 [03:46<08:00,  3.06it/s]

Episode 530: Eval reward = -2722.78


Episodes:  27%|██▋       | 540/2000 [03:50<08:13,  2.96it/s]

Episode 540: Eval reward = -2256.59


Episodes:  28%|██▊       | 550/2000 [03:54<09:37,  2.51it/s]

Episode 550: Eval reward = -2696.00


Episodes:  28%|██▊       | 560/2000 [03:58<07:49,  3.07it/s]

Episode 560: Eval reward = -2443.17


Episodes:  28%|██▊       | 570/2000 [04:02<10:21,  2.30it/s]

Episode 570: Eval reward = -2310.08


Episodes:  29%|██▉       | 580/2000 [04:06<09:06,  2.60it/s]

Episode 580: Eval reward = -2832.17


Episodes:  30%|██▉       | 590/2000 [04:10<10:24,  2.26it/s]

Episode 590: Eval reward = -2648.71


Episodes:  30%|███       | 600/2000 [04:15<13:39,  1.71it/s]

Episode 600: Eval reward = -2660.93


Episodes:  30%|███       | 610/2000 [04:19<09:05,  2.55it/s]

Episode 610: Eval reward = -3108.96


Episodes:  31%|███       | 620/2000 [04:24<11:00,  2.09it/s]

Episode 620: Eval reward = -2328.83


Episodes:  32%|███▏      | 630/2000 [04:28<09:01,  2.53it/s]

Episode 630: Eval reward = -2293.39


Episodes:  32%|███▏      | 640/2000 [04:33<09:20,  2.43it/s]

Episode 640: Eval reward = -2278.84


Episodes:  32%|███▎      | 650/2000 [04:36<08:12,  2.74it/s]

Episode 650: Eval reward = -2696.73


Episodes:  33%|███▎      | 660/2000 [04:40<09:11,  2.43it/s]

Episode 660: Eval reward = -2349.43


Episodes:  34%|███▎      | 670/2000 [04:44<08:30,  2.61it/s]

Episode 670: Eval reward = -2853.99


Episodes:  34%|███▍      | 680/2000 [04:48<08:05,  2.72it/s]

Episode 680: Eval reward = -1902.90


Episodes:  34%|███▍      | 690/2000 [04:52<08:01,  2.72it/s]

Episode 690: Eval reward = -2834.90


Episodes:  35%|███▌      | 700/2000 [04:57<12:48,  1.69it/s]

Episode 700: Eval reward = -2924.98


Episodes:  36%|███▌      | 710/2000 [05:02<10:08,  2.12it/s]

Episode 710: Eval reward = -2349.44


Episodes:  36%|███▌      | 720/2000 [05:06<08:24,  2.54it/s]

Episode 720: Eval reward = -2844.15


Episodes:  36%|███▋      | 730/2000 [05:10<07:21,  2.88it/s]

Episode 730: Eval reward = -3208.70


Episodes:  37%|███▋      | 740/2000 [05:14<09:51,  2.13it/s]

Episode 740: Eval reward = -2555.66


Episodes:  38%|███▊      | 750/2000 [05:18<08:05,  2.57it/s]

Episode 750: Eval reward = -3181.75


Episodes:  38%|███▊      | 760/2000 [05:22<08:06,  2.55it/s]

Episode 760: Eval reward = -2399.15


Episodes:  38%|███▊      | 764/2000 [05:24<08:45,  2.35it/s]


KeyboardInterrupt: 

## Analysis and Visualization

In [None]:
# Plot the metrics collected in `history`; the target path inside the log
# directory is handed to the helper via `save_path`.
fig = plot_training_metrics(
    history,
    save_path=log_dir.joinpath('training_metrics.png'),
)
plt.show()

In [None]:
# Create demonstration video
def collect_demo_episode(max_steps=500):
    """Roll out one episode with the current policy and record its states.

    Uses the notebook-level ``env`` and ``trainer``. Actions are chosen with
    ``trainer.select_action(state, training=False)``. The state observed
    *before* each step is what gets recorded, so the state returned by the
    last executed step is not part of the result.

    Args:
        max_steps: Upper bound on the number of environment steps taken.

    Returns:
        A tuple ``(states, total_reward)`` where ``states`` is the list of
        recorded states converted with ``np.array`` and ``total_reward`` is
        the sum of the per-step rewards.
    """
    visited = []
    episode_return = 0
    current = env.reset()

    steps_taken = 0
    while steps_taken < max_steps:
        steps_taken += 1

        action = trainer.select_action(current, training=False)
        # The fourth element (info) of the step result is not needed here.
        following, reward, done, _ = env.step(action)

        visited.append(current)
        episode_return += reward

        if done:
            break
        current = following

    return np.array(visited), episode_return

from IPython.display import HTML

# Roll out a handful of demo episodes and keep the trajectory with the
# highest total reward (the first one wins on ties).
num_episodes = 5
best_states = None
best_reward = float('-inf')

for _ in range(num_episodes):
    episode_states, episode_reward = collect_demo_episode()
    if not episode_reward > best_reward:
        continue
    best_states, best_reward = episode_states, episode_reward

print(f"Best episode reward: {best_reward:.2f}")

# Build the animation from the best trajectory (the demo.mp4 path in the log
# directory is passed as `save_path`) and embed it in the notebook as HTML.
anim = create_episode_animation(states=best_states, save_path=log_dir / 'demo.mp4')
HTML(anim.to_jshtml())

## Save Final Model with Metadata

In [None]:
# Gather the network weights, the training history and both configs into a
# single checkpoint dictionary.
checkpoint = {
    'actor_state_dict': trainer.actor.state_dict(),
    'critic_state_dict': trainer.critic.state_dict(),
    'training_history': history,
    'config': ddpg_config,
    'env_config': env_config,
}

# Attach metadata: the dimensions read from the environment's spaces, the
# first entry of the action-space upper bound, and the result of a final
# 10-episode evaluation.
observation_space, action_space = env.observation_space, env.action_space
checkpoint['metadata'] = {
    'state_dim': observation_space.shape[0],
    'action_dim': action_space.shape[0],
    'max_action': float(action_space.high[0]),
    'final_eval_reward': trainer.evaluate(num_episodes=10),
}

torch.save(checkpoint, log_dir / 'final_model.pt')

print("Training complete! Model saved with metadata.")