In [1]:
pip install stable-baselines3[extra] gymnasium numpy matplotlib


Collecting opencv-python (from stable-baselines3[extra])
  Using cached opencv_python-4.11.0.86-cp37-abi3-win_amd64.whl (39.5 MB)
Collecting tensorboard>=2.9.1 (from stable-baselines3[extra])
  Using cached tensorboard-2.19.0-py3-none-any.whl (5.5 MB)
Collecting rich (from stable-baselines3[extra])
  Using cached rich-13.9.4-py3-none-any.whl (242 kB)
Collecting ale-py>=0.9.0 (from stable-baselines3[extra])
  Downloading ale_py-0.10.2-cp311-cp311-win_amd64.whl (1.5 MB)
                                              0.0/1.5 MB ? eta -:--:--
                                              0.0/1.5 MB 682.7 kB/s eta 0:00:03
                                              0.0/1.5 MB 435.7 kB/s eta 0:00:04
     -                                        0.1/1.5 MB 469.7 kB/s eta 0:00:04
     ---                                      0.1/1.5 MB 731.4 kB/s eta 0:00:02
     ---                                      0.1/1.5 MB 657.1 kB/s eta 0:00:03
     -----                                    0.2/1.5 MB


[notice] A new release of pip is available: 23.1.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces

class SpacecraftEnv(gym.Env):
    def __init__(self):
        super(SpacecraftEnv, self).__init__()
        
        # Continuous state space (position, velocity, angle, angular velocity)
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)

        # Continuous action space (thrust in X and Y directions)
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        
        self.state = np.zeros(4)  # [x, y, vx, vy]
        self.target = np.array([0.5, 0.5])  # Target position in space
        self.time_step = 0

    def step(self, action):
        thrust_x, thrust_y = action  # Continuous thrust in X and Y

        # Update spacecraft velocity with thrust
        self.state[2] += thrust_x * 0.01  # Small acceleration factor
        self.state[3] += thrust_y * 0.01

        # Update position with velocity
        self.state[0] += self.state[2]
        self.state[1] += self.state[3]

        # Compute distance to target
        distance = np.linalg.norm(self.state[:2] - self.target)
        
        # Reward function: Closer to target is better
        reward = -distance
        
        # Check if spacecraft reaches the target
        done = distance < 0.05 or self.time_step > 500

        self.time_step += 1
        return self.state, reward, done, False, {}

    def reset(self, seed=None, options=None):
        self.state = np.random.uniform(-1, 1, size=(4,))
        self.time_step = 0
        return self.state, {}

    def render(self):
        print(f"Position: {self.state[:2]}, Velocity: {self.state[2:]}")



In [6]:
from stable_baselines3 import SAC

# Create the environment
env = SpacecraftEnv()

# Train the agent using Soft Actor-Critic (SAC)
model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100000)

# Save the trained model
model.save("spacecraft_sac")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 502       |
|    ep_rew_mean     | -1.74e+05 |
| time/              |           |
|    episodes        | 4         |
|    fps             | 56        |
|    time_elapsed    | 35        |
|    total_timesteps | 2008      |
| train/             |           |
|    actor_loss      | 2.98e+03  |
|    critic_loss     | 817       |
|    ent_coef        | 0.846     |
|    ent_coef_loss   | 0.635     |
|    learning_rate   | 0.0003    |
|    n_updates       | 1907      |
----------------------------------
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 502       |
|    ep_rew_mean     | -1.31e+05 |
| time/              |           |
|    episodes        | 8         |
|    fps             | 49        |
|    time_elapsed    | 80        |
|    total_timesteps | 4016    

KeyboardInterrupt: 