In [1]:
%pip install numpy==1.24.0 stable-baselines3==1.5.0
%pip install 'shimmy>=2.0'

In [2]:
import gym
from stable_baselines3 import PPO, SAC

class ReinforcementLearningAgent:
    """Hold one Gym environment together with a PPO and a SAC model.

    Both models are constructed in ``__init__`` with the ``'MlpPolicy'``
    policy and ``verbose=1``, and both are given the same environment
    object, which stays reachable as ``self.env``.
    """

    def __init__(self, env_id: str):
        """Create the environment for *env_id* and build both models on it."""
        environment = gym.make(env_id)
        self.env = environment
        # PPO is built first, then SAC; each receives the same env instance.
        self.ppo_model = PPO(policy='MlpPolicy', env=environment, verbose=1)
        self.sac_model = SAC(policy='MlpPolicy', env=environment, verbose=1)

    def train_ppo(self, total_timesteps: int) -> None:
        """Call ``learn`` on the PPO model with *total_timesteps*."""
        ppo = self.ppo_model
        ppo.learn(total_timesteps=total_timesteps)

    def train_sac(self, total_timesteps: int) -> None:
        """Call ``learn`` on the SAC model with *total_timesteps*."""
        sac = self.sac_model
        sac.learn(total_timesteps=total_timesteps)

    def predict_ppo(self, observation):
        """Return the PPO model's ``predict`` result for *observation*, unchanged."""
        ppo_output = self.ppo_model.predict(observation)
        return ppo_output

    def predict_sac(self, observation):
        """Return the SAC model's ``predict`` result for *observation*, unchanged."""
        sac_output = self.sac_model.predict(observation)
        return sac_output

if __name__ == "__main__":
    # Demo run: train each algorithm for 10,000 timesteps on Pendulum-v1,
    # then ask each trained model for one prediction on a fresh observation.
    agent = ReinforcementLearningAgent('Pendulum-v1')
    agent.train_ppo(total_timesteps=10000)
    agent.train_sac(total_timesteps=10000)

    # reset() returns the bare observation under the old Gym API but an
    # (observation, info) pair under Gym >= 0.26 / Gymnasium. Unpack the pair
    # when present so predict() is never handed the tuple itself.
    reset_result = agent.env.reset()
    returned_obs_and_info = (
        isinstance(reset_result, tuple)
        and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    )
    observation = reset_result[0] if returned_obs_and_info else reset_result

    # Each predict_* call's return value is printed as-is.
    print('PPO Prediction:', agent.predict_ppo(observation))
    print('SAC Prediction:', agent.predict_sac(observation))