In [8]:
import gym
import numpy as np
from gym import spaces
from stable_baselines3 import SAC

class SpacecraftEnv(gym.Env):
    def __init__(self):
        super(SpacecraftEnv, self).__init__()

        # Define continuous state space (spacecraft's position)
        self.observation_space = spaces.Box(low=-1.5, high=1.5, shape=(2,), dtype=np.float32)

        # Define continuous action space (movement in x, y)
        self.action_space = spaces.Box(low=-0.5, high=0.5, shape=(2,), dtype=np.float32)  # Faster movement

        self.state = np.array([0.0, 0.0])  # Initial position
        self.target = np.array([1.0, 1.0])  # Target position
        self.history = []  # For visualization

    def step(self, action):
        self.state += action  # Move spacecraft
        distance = np.linalg.norm(self.state - self.target)

        # Reward: Closer is better
        reward = -distance  # Negative reward when far from the target

        # Episode ends if spacecraft is close to target
        done = distance < 0.05  

        self.history.append(self.state.copy())  # Store trajectory
        return self.state, reward, done, {}, {}

    def reset(self):
        self.state = np.array([0.0, 0.0])  # Reset spacecraft position
        self.history = [self.state.copy()]
        return self.state, {}

    def get_trajectory(self):
        return np.array(self.history)  # For visualization

# Initialize environment
env = SpacecraftEnv()

# Use optimized SAC settings
model = SAC(
    "MlpPolicy",
    env,
    verbose=1,
    train_freq=(1, "episode"),  # Train only once per episode (faster)
    batch_size=32,  # Smaller batch size (less memory)
    buffer_size=10000,  # Smaller replay buffer (faster training)
    learning_starts=100,  # Start learning quickly
)

# Reduce number of episodes for faster results
num_episodes = 20  

for episode in range(num_episodes):
    obs, _ = env.reset()
    done = False
    while not done:
        action, _ = model.predict(obs)  # Predict action
        obs, reward, done, _, _ = env.step(action)

trajectory = env.get_trajectory()


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




KeyboardInterrupt: 