Using a test Model

In [1]:
from stable_baselines3 import PPO
import gymnasium as gym
import assembly_game
from assembly_game.processor import PROCESSOR_ACTIONS, actions_to_asm

In [2]:
MAX_STEPS = 30
env = gym.make("MinGame", max_episode_steps=MAX_STEPS, size = 4)

In [15]:
from stable_baselines3.common.callbacks import BaseCallback
import numpy as np

class BestTrajectoryCallback(BaseCallback):
    def __init__(self, save_path, verbose=0):
        super(BestTrajectoryCallback, self).__init__(verbose)
        self.save_path = save_path
        self.best_reward = -np.inf
        self.current_ep_actions = []
        self.current_ep_rewards = []
        self.results = []

    def _on_step(self) -> bool:
        # Save obs and action at every step
        self.current_ep_actions.append(self.locals['actions'])

        # Info contains "episode" key at the end of an episode
        infos = self.locals['infos']
        rewards = self.locals['rewards']
        self.current_ep_rewards.append(rewards)

        for info in infos:
            if 'episode' in info:
                ep_reward = sum(self.current_ep_rewards)
                if ep_reward > self.best_reward:
                    self.best_reward = ep_reward
                    self._save_trajectory()

                self.current_ep_rewards = []
                self.current_ep_actions = []
                break

        return True
        
    def _save_trajectory(self):
        print(f"New best trajectory found with reward: {self.best_reward}")
        self.results.append((self.num_timesteps, [action[0] for action in self.current_ep_actions], self.best_reward))
    
    def _on_training_end(self):
        # Save the best trajectory to a file
        with open(self.save_path, "w") as f:
            for (timestep, actions, reward) in self.results:
                f.write(f"Timestep: {timestep}, Len: {len(actions)}, Reward: {reward}\n")
                f.write(actions_to_asm(actions))
                f.write("\n\n")
    
callback = BestTrajectoryCallback(verbose=1, save_path="results.txt")

In [None]:
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100000, callback=callback)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
New best trajectory found with reward: [0.]
New best trajectory found with reward: [30.]
New best trajectory found with reward: [60.]
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 30       |
|    ep_rew_mean     | 20.6     |
|    success_rate    | 0        |
| time/              |          |
|    fps             | 2627     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 30          |
|    ep_rew_mean          | 21.2        |
|    success_rate         | 0           |
| time/                   |             |
|    fps                  | 1947        |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps

<stable_baselines3.ppo.ppo.PPO at 0x7fcd2a7b6510>

In [50]:
state, _ = env.reset()
cumreward = 0
for i in range(MAX_STEPS):
  action,_ = model.predict(state)
  state, reward, terminated, truncated, info = env.step(action)
  cumreward +=reward
  print(PROCESSOR_ACTIONS[action], info, reward)
  if terminated or truncated:
    print(terminated)
    print(truncated)
    print(f"Episode finished after {i+1} timestamps")
    break
print(f"total reward {cumreward}")

(<Instruction.MOV: 0>, <Operand.RSI: 1>, <Operand.RAX: 4>) {'example_0': 'rdi=1 rsi=2 rax=2 rdx=3 rcx=4 cmp_res=0', 'example_1': 'rdi=1 rsi=2 rax=2 rdx=4 rcx=3 cmp_res=0', 'example_2': 'rdi=1 rsi=3 rax=3 rdx=2 rcx=4 cmp_res=0', 'example_3': 'rdi=1 rsi=3 rax=3 rdx=4 rcx=2 cmp_res=0', 'example_4': 'rdi=1 rsi=4 rax=4 rdx=2 rcx=3 cmp_res=0', 'example_5': 'rdi=1 rsi=4 rax=4 rdx=3 rcx=2 cmp_res=0', 'example_6': 'rdi=2 rsi=1 rax=1 rdx=3 rcx=4 cmp_res=0', 'example_7': 'rdi=2 rsi=1 rax=1 rdx=4 rcx=3 cmp_res=0', 'example_8': 'rdi=2 rsi=3 rax=3 rdx=1 rcx=4 cmp_res=0', 'example_9': 'rdi=2 rsi=3 rax=3 rdx=4 rcx=1 cmp_res=0', 'example_10': 'rdi=2 rsi=4 rax=4 rdx=1 rcx=3 cmp_res=0', 'example_11': 'rdi=2 rsi=4 rax=4 rdx=3 rcx=1 cmp_res=0', 'example_12': 'rdi=3 rsi=1 rax=1 rdx=2 rcx=4 cmp_res=0', 'example_13': 'rdi=3 rsi=1 rax=1 rdx=4 rcx=2 cmp_res=0', 'example_14': 'rdi=3 rsi=2 rax=2 rdx=1 rcx=4 cmp_res=0', 'example_15': 'rdi=3 rsi=2 rax=2 rdx=4 rcx=1 cmp_res=0', 'example_16': 'rdi=3 rsi=4 rax=4 rdx=1