Dependency Installation

In [1]:
!pip install stable-baselines3 tensorboard gym pygame


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


Environment Loading

In [2]:
import importlib

# Import the project's environment module under an alias so that the `env`
# variable created below (an environment *instance*) never shadows the module.
import env as drone_env_module

# Reload the module so edits to env.py are picked up without restarting the kernel.
importlib.reload(drone_env_module)

# Import the class after the reload so it refers to the freshly loaded definition.
from env import droneEnv

# Environment instance created with render_mode="none".
env = droneEnv(render_mode="none")


pygame 2.6.1 (SDL 2.28.4, Python 3.10.19)
Hello from the pygame community. https://www.pygame.org/contribute.html


Setting up the Stable-Baselines3 environment for training both A2C and DQN

In [3]:
from stable_baselines3.common.monitor import Monitor

# Location handed to the Monitor wrapper for its log output.
logdir = "./logs/a2c"

# Wrap a render_mode="none" drone environment in SB3's Monitor wrapper.
env = Monitor(
    env=droneEnv(render_mode="none"),
    filename=logdir,
)


A2C Training

In [29]:
from stable_baselines3 import A2C

# Train an A2C agent with an MLP policy on the current `env`.
# A fixed seed is passed so that repeated runs of this cell are reproducible.
model_a2c = A2C(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    gamma=0.99,
    seed=42,                      # fixed RNG seed for reproducible training
    tensorboard_log="./DroneLog"  # same directory the TensorBoard cell reads
)

model_a2c.learn(total_timesteps=300_000)

# Persist the trained model; the demo cells load it back by this name.
model_a2c.save("A2C_drone")


In [4]:
from env import droneEnv
from stable_baselines3 import A2C

# Build an environment in "human" render mode and attach the saved
# A2C model ("A2C_drone") to it for the interactive demo below.
env = droneEnv(render_mode="human")
model_a2c = A2C.load(path="A2C_drone", env=env)


In [None]:
# Roll out a single episode with the loaded A2C policy, rendering every step.
obs, info = env.reset()
terminated = False
truncated = False

try:
    # An episode ends when the environment reports either flag.
    while not (terminated or truncated):
        action, _ = model_a2c.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()
finally:
    # Always release the environment, even if the rollout raises or the
    # cell is interrupted.
    env.close()

A2C Demo Video Generation

In [None]:
import imageio
import numpy as np
from stable_baselines3 import A2C
from env import droneEnv

# "rgb_array" mode makes env.render() return the frame, which is what the
# video writer consumes.
env = droneEnv(render_mode="rgb_array")

model = A2C.load("A2C_drone", env=env)

video_path = "a2c_drone_demo.mp4"
fps = 30

try:
    # The context manager closes the writer even if the rollout raises.
    with imageio.get_writer(video_path, fps=fps) as writer:
        obs, _ = env.reset()
        done = False
        truncated = False

        while not (done or truncated):
            # Predict action
            action, _ = model.predict(obs, deterministic=True)

            # Step environment
            obs, reward, done, truncated, info = env.step(action)

            # Capture the frame produced by this step
            frame = env.render()
            writer.append_data(frame)
finally:
    # Release the environment whether or not recording succeeded.
    env.close()

print(f"Video saved at: {video_path}")


DQN Training

In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from env import droneEnv

# Build a dedicated training environment for DQN instead of reusing whatever
# `env` happens to hold: when the notebook is run top to bottom, `env` at this
# point is the "rgb_array" demo environment that the previous cell already
# closed, and it is not wrapped in Monitor.
dqn_logdir = "./logs/dqn"
env = Monitor(droneEnv(render_mode="none"), dqn_logdir)

model_dqn = DQN(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    gamma=0.99,
    buffer_size=100_000,        # Replay buffer size
    batch_size=64,              # Mini-batch size
    learning_starts=1_000,      # Start learning after some experience
    target_update_interval=500, # How often to update target network
    exploration_initial_eps=1.0,
    exploration_final_eps=0.05,
    exploration_fraction=0.3,
    seed=42,                    # fixed RNG seed for reproducible training
    tensorboard_log="./DroneLog"
)

model_dqn.learn(total_timesteps=300_000)

# Persist the trained model; the demo cells load it back by this name.
model_dqn.save("DQN_drone")


In [None]:
from stable_baselines3 import DQN
from env import droneEnv

# Create an environment in "human" render mode for the interactive demo.
env = droneEnv(render_mode="human")

# Load the trained DQN model
model_dqn = DQN.load("DQN_drone", env=env)

# Run a single episode
obs, _ = env.reset()
done = False
truncated = False

try:
    # Stop on termination OR truncation; checking only `done` would keep
    # stepping an episode the environment has already truncated.
    while not (done or truncated):
        action, _ = model_dqn.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        env.render()
finally:
    # Always release the environment, even if the rollout raises or the
    # cell is interrupted.
    env.close()


DQN Demo Video Generation

In [None]:
# Imported here so this cell does not depend on earlier cells having run.
import imageio
from stable_baselines3 import DQN
from env import droneEnv

# "rgb_array" mode makes env.render() return the frame, which is what the
# video writer consumes.
env = droneEnv(render_mode="rgb_array")

# Load trained DQN model
model = DQN.load("DQN_drone", env=env)

video_path = "dqn_drone_demo.mp4"
fps = 30

try:
    # The context manager closes the writer even if the rollout raises.
    with imageio.get_writer(video_path, fps=fps) as writer:
        obs, _ = env.reset()
        done = False
        truncated = False

        while not (done or truncated):
            # Predict action
            action, _ = model.predict(obs, deterministic=True)

            # Step environment
            obs, reward, done, truncated, info = env.step(action)

            # Capture the frame produced by this step
            frame = env.render()
            writer.append_data(frame)
finally:
    # Release the environment whether or not recording succeeded.
    env.close()

print(f"Video saved at: {video_path}")

Data Monitoring through TensorBoard

In [24]:
%load_ext tensorboard
%tensorboard --logdir DroneLog


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6007 (pid 20521), started 0:56:49 ago. (Use '!kill 20521' to kill it.)