# DDPG Training for Balancing Robot

This notebook trains a DDPG agent for the balancing robot environment using the provided configurations.

In [1]:
# Uncomment the following lines to run in Google Colab

# %cd /content
# !git clone https://github.com/EyalPorat/ddpg-balancing-robot.git
# %cd ddpg-balancing-robot
# !git checkout master-organize-python-proj
# %cd /content/ddpg-balancing-robot/python/notebooks

# import sys
# sys.path.append('/content/ddpg-balancing-robot/python')  # Add the repo's python/ directory to the Python path

In [2]:
import sys
sys.path.append('..')

import torch
import numpy as np
from pathlib import Path
import yaml
import matplotlib.pyplot as plt

from src.balancing_robot.models import Actor, Critic, ReplayBuffer, SimNet
from src.balancing_robot.environment import BalancerEnv
from src.balancing_robot.training import DDPGTrainer
from src.balancing_robot.visualization import plot_training_metrics, create_episode_animation

## Load Configurations

In [3]:
# Parse the DDPG, environment and SimNet YAML configs that live in ../configs
config_dir = Path('../configs')
ddpg_config = yaml.safe_load((config_dir / 'ddpg_config.yaml').read_text())
env_config = yaml.safe_load((config_dir / 'env_config.yaml').read_text())
simnet_config = yaml.safe_load((config_dir / 'simnet_config.yaml').read_text())

# Make sure the output directory for this run exists (no error if already present)
log_dir = Path('logs/ddpg_training')
log_dir.mkdir(parents=True, exist_ok=True)

## Initialize Environment and Models

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch and NumPy with the configured value (42 when the key is absent)
seed = ddpg_config["training"].get("random_seed", 42)
torch.manual_seed(seed)
np.random.seed(seed)

# Build the environment from its YAML config
env = BalancerEnv(config_path="../configs/env_config.yaml", render_mode="rgb_array")

# Size the SimNet from the environment's spaces and the configured hidden layers
simnet = SimNet(
    state_dim=env.observation_space.shape[0],
    action_dim=env.action_space.shape[0],
    hidden_dims=simnet_config["model"]["hidden_dims"],
)
simnet = simnet.to(device)

# Restore the SimNet weights; the checkpoint keeps them under "state_dict"
simnet_checkpoint = torch.load("logs/simnet_training/simnet_final.pt", map_location=device)
simnet.load_state_dict(simnet_checkpoint["state_dict"])

# Hand the restored SimNet to the environment
env.simnet = simnet

# Build the trainer from the DDPG config and show the resulting networks
trainer = DDPGTrainer(env=env, config_path="../configs/ddpg_config.yaml")
trainer.print_model_info()

  gym.logger.warn(
  gym.logger.warn(


state_dim 6
action_dim 1
Actor(
  (network): Sequential(
    (0): Linear(in_features=6, out_features=8, bias=True)
    (1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=8, out_features=8, bias=True)
    (4): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    (5): ReLU()
  )
  (output_layer): Linear(in_features=8, out_features=1, bias=True)
)
Critic(
  (l1): Linear(in_features=7, out_features=256, bias=True)
  (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (hidden_layers): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
  )
  (output_layer): Linear(in_features=256, out_features=1, bias=True)
)


## Training

In [5]:
# Training hyperparameters live under the "training" section of the DDPG config
train_config = ddpg_config['training']

# Map the config entries onto the keyword arguments trainer.train() takes
train_kwargs = {
    'num_episodes': train_config['total_episodes'],
    'max_steps': train_config['max_steps_per_episode'],
    'batch_size': train_config['batch_size'],
    'eval_freq': train_config['eval_frequency'],
    'save_freq': train_config['save_frequency'],
    'log_dir': log_dir,
}

# Run training and keep the returned history for the analysis cells below
history = trainer.train(**train_kwargs)

Training:   0%|          | 0/2000 [00:00<?, ?it/s]

Training:   0%|          | 5/2000 [00:00<04:11,  7.94it/s]

states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])

Training:   0%|          | 7/2000 [00:02<15:02,  2.21it/s]

states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])

Training:   0%|          | 7/2000 [00:05<25:14,  1.32it/s]

states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])
states torch.Size([512, 6])
actions torch.Size([512, 1])
current_Q torch.Size([512, 1])
target_Q torch.Size([512, 1])





KeyboardInterrupt: 

## Analysis and Visualization

In [None]:
# Draw the training curves, passing a save path inside the run's log directory
metrics_path = log_dir / 'training_metrics.png'
fig = plot_training_metrics(history, save_path=metrics_path)
plt.show()

In [None]:
# Collect one evaluation episode for the demonstration video
def collect_demo_episode(max_steps=500):
    """Roll out a single episode with the current policy and record it.

    Relies on the notebook-level ``env`` and ``trainer`` objects. Actions are
    requested from the trainer with ``training=False``.

    Args:
        max_steps: Upper bound on the number of environment steps to take.

    Returns:
        A tuple ``(states, actions, total_reward)``: ``states`` and ``actions``
        are NumPy arrays holding the state observed and the action taken at
        each step, and ``total_reward`` is the sum of the per-step rewards.
    """
    state, _ = env.reset()
    states = []
    actions = []
    total_reward = 0

    for _ in range(max_steps):
        action = trainer.select_action(state, training=False)
        next_state, reward, done, truncated, info = env.step(action)
        states.append(state)
        actions.append(action)
        total_reward += reward

        # The episode is over on termination OR truncation; stepping an
        # environment past either flag would record transitions (and reward)
        # that do not belong to this episode.
        if done or truncated:
            break

        state = next_state

    return np.array(states), np.array(actions), total_reward

from IPython.display import HTML

# Roll out a handful of episodes and remember the highest-scoring one
num_episodes = 5
best_reward, best_states, best_actions = float('-inf'), None, None

for _ in range(num_episodes):
    states, actions, reward = collect_demo_episode()
    if reward > best_reward:
        # Strictly better than anything seen so far: keep this rollout
        best_reward, best_states, best_actions = reward, states, actions

print(f"Best episode reward: {best_reward:.2f}")

# Animate the best rollout and embed it in the notebook
demo_path = log_dir / 'demo.mp4'
anim = create_episode_animation(
    states=best_states,
    actions=best_actions,
    save_path=demo_path,
)
HTML(anim.to_jshtml())

## Save Final Model with Metadata

In [None]:
# Assemble the final checkpoint: network weights, training history, configs and metadata
checkpoint = {}
checkpoint['actor_state_dict'] = trainer.actor.state_dict()
checkpoint['critic_state_dict'] = trainer.critic.state_dict()
checkpoint['training_history'] = history
checkpoint['config'] = ddpg_config
checkpoint['env_config'] = env_config
checkpoint['metadata'] = {
    'state_dim': env.observation_space.shape[0],
    'action_dim': env.action_space.shape[0],
    'max_action': float(env.action_space.high[0]),
    # Result of a 10-episode evaluation run, stored as returned by the trainer
    'final_eval_reward': trainer.evaluate(num_episodes=10),
}

# Write everything into a single file inside the run's log directory
torch.save(checkpoint, log_dir / 'final_model.pt')

print("Training complete! Model saved with metadata.")