In [1]:
import os
import sys

import numpy as np

# Put the parent of the current working directory on the import path so that
# packages living one level above the notebook can be imported.
# `os.pardir` is ".."; a three-dot string would instead name a literal
# directory called "..." inside the working directory.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from instructor.environment import FindAllShapesEnv
from instructor.callback import IoUCallback
from datasets.shapes import generate_image

%load_ext autoreload
%autoreload 2

In [5]:
from datasets.shapes import Triangle

# Generate five sample images whose shape/colour combinations are restricted
# to triangles in four colours, and save each one as a PNG in the working
# directory.
triangle_colors = ("red", "green", "blue", "purple")
for i in range(5):
    # Build a fresh list on every pass so each call receives its own object.
    triangle_specs = [(Triangle, color) for color in triangle_colors]
    entry = generate_image(
        8,
        create_mask=False,
        combinations=triangle_specs,
        scale=60,
    )
    entry.image.save(f"triangles_{i}.png")

In [2]:
import random
import torch

SEED = 57


def set_random_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Deterministic kernels alone are not enough: the cuDNN autotuner may pick
    # a different algorithm from run to run, so it is disabled as well.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    # Seeds every CUDA device rather than only the current one; this call is
    # silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Hash randomisation of the running interpreter is fixed at start-up, so
    # this only influences processes launched after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)

In [None]:
%%time
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import DQN, PPO, A2C
from stable_baselines3.common.vec_env import SubprocVecEnv

# Seed all generators before the environments and the network are created.
set_random_seed(SEED)


def make_env():
    """Build a FindAllShapesEnv fed by a callable returning the shapes of a newly generated image."""
    return FindAllShapesEnv(lambda: generate_image(30, False, scale=1).shapes)


# 1024 copies of the environment behind a single vectorised interface.
vec_env = make_vec_env(make_env, n_envs=1024)

model = DQN(
    "MlpPolicy",
    vec_env,
    verbose=0,
    # gradient_steps=-1,
    device="cpu",
    policy_kwargs={"net_arch": [64, 64]},
    tensorboard_log="tb_test",
)
# NOTE(review): the TensorBoard run is named 'sparse_ppo' although the
# algorithm is DQN — confirm the label is intentional.
model = model.learn(
    450_000,
    progress_bar=False,
    tb_log_name="sparse_ppo",
    callback=IoUCallback(),
)

In [None]:
# Run the trained model deterministically for 30,000 steps on a single
# environment from make_env(), recording every step reward and, at the end of
# each episode, the IoU when the environment reports one.
ious, rewards = [], []
env = make_env()
set_random_seed(SEED)
obs = env.reset()
for _ in range(30_000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    rewards.append(reward)
    if not done:
        continue
    # Episode finished: start a new one, then log the IoU from the final step.
    obs = env.reset()
    if 'iou' in info:
        ious.append(info['iou'])
print(f"Mean IoU: {np.mean(ious)}\nMean reward: {np.mean(rewards)}")

In [None]:
import matplotlib.pyplot as plt

# IoU values gathered by the evaluation loop, one point per finished episode
# that reported an IoU, in the order the episodes ended.
fig, ax = plt.subplots()
ax.plot(ious)
ax.set(title="IoU per evaluation episode", xlabel="Episode", ylabel="IoU")
plt.show()

In [None]:
# Step-by-step trace of the trained agent on one newly built environment,
# capped at 70 steps.
env = FindAllShapesEnv(lambda: generate_image(8, False, scale=1).shapes)
obs = env.reset()

# List the shapes present in this episode.
print(", ".join(str(shape) for shape in env.shapes))

n_steps = 70
for step in range(n_steps):
    action, _ = model.predict(obs, deterministic=True)
    print(f"Step {step + 1}")
    action_name = env.action_dict[action.item()]
    print("Action: ", action_name)
    obs, reward, done, info = env.step(action)
    print('obs=', obs, 'reward=', reward, 'done=', done)
    env.render(mode='console')
    if not done:
        continue
    # This is a plain environment, not a VecEnv, so nothing resets it
    # automatically; the trace simply stops at the first finished episode.
    # NOTE(review): the message is printed on any `done`; whether `done`
    # always means success is not visible here — confirm.
    print("Goal reached!", "reward=", reward)
    break

In [None]:
# Display the trained model's policy object as the cell output.
model.policy