In [5]:
import time
import datetime
import torch_ac
import tensorboardX
import sys

import utils
from utils import device
from model import ACModel

# Run configuration, keyed by argument name.
args = dict(
    algo="a2c",                        # "a2c" or "ppo"; anything else raises ValueError below
    env="MiniGrid-DoorKey-5x5-v0",
    model=None,                        # None -> a name is generated from env/algo/seed/date
    seed=1,
    log_interval=1,                    # log every N updates
    save_interval=10,                  # save status every N updates (only when > 0)
    procs=16,                          # number of environments created
    frames=10 ** 7,                    # training loop stops once this many frames are seen
    epochs=4,                          # passed to PPOAlgo only
    batch_size=256,                    # passed to PPOAlgo only
    frames_per_proc=None,
    discount=0.99,
    lr=0.001,
    gae_lambda=0.95,
    entropy_coef=0.01,
    value_loss_coef=0.5,
    max_grad_norm=0.5,
    optim_eps=1e-8,
    optim_alpha=0.99,                  # passed to A2CAlgo only
    clip_eps=0.2,                      # passed to PPOAlgo only
    recurrence=1,
    text=False,
)

# Set run dir

# A second-resolution timestamp is embedded in the generated run name.
date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
default_model_name = "{}_{}_seed{}_{}".format(args['env'], args['algo'], args['seed'], date)

# An explicitly configured model name wins; otherwise fall back to the generated one.
model_name = args['model'] if args['model'] else default_model_name
model_dir = utils.get_model_dir(model_name)

# Load loggers and Tensorboard writer
# (all three write into the run directory resolved above)

txt_logger = utils.get_txt_logger(model_dir)
csv_file, csv_logger = utils.get_csv_logger(model_dir)
tb_writer = tensorboardX.SummaryWriter(model_dir)

# Log command and all script arguments

txt_logger.info(f"{' '.join(sys.argv)}\n")
txt_logger.info(f"{args}\n")

# Set seed for all randomness sources

utils.seed(args['seed'])

# Set device
# (`device` is imported from utils; it is only reported here)

txt_logger.info("Device: {}\n".format(device))

# Load environments

# Build args['procs'] environments; each one gets its own seed, spaced 10000
# apart starting from the base seed.
envs = [
    utils.make_env(args['env'], args['seed'] + 10000 * proc_idx)
    for proc_idx in range(args['procs'])
]
txt_logger.info("Environments loaded\n")

# Load training status

try:
    status = utils.get_status(model_dir)
except OSError:
    # Nothing could be read from the run directory: start counting from zero.
    status = dict(num_frames=0, update=0)
txt_logger.info("Training status loaded\n")

# Load observations preprocessor

obs_space, preprocess_obss = utils.get_obss_preprocessor(envs[0].observation_space)
# Restore the vocabulary only when the loaded status carries one.
if "vocab" in status:
    preprocess_obss.vocab.load_vocab(status["vocab"])
txt_logger.info("Observations preprocessor loaded")

# Load model

# NOTE(review): the third positional argument receives args['recurrence'] (1 in
# this configuration). If ACModel treats that parameter as a boolean
# "use memory" flag, any non-zero recurrence enables it -- confirm against
# model.ACModel whether `args['recurrence'] > 1` was intended.
acmodel = ACModel(
    obs_space,
    envs[0].action_space,
    args['recurrence'],
    args['text'],
)
# Restore weights when the loaded status contains them, then move to the device.
if "model_state" in status:
    acmodel.load_state_dict(status["model_state"])
acmodel.to(device)
txt_logger.info("Model loaded\n")
txt_logger.info(f"{acmodel}\n")

# Load algo

# Reject unknown algorithm names up front.
if args['algo'] not in ("a2c", "ppo"):
    raise ValueError("Incorrect algorithm name: {}".format(args['algo']))

# Leading positional arguments that both algorithm constructors receive, in
# the same order.
_shared_algo_args = (
    envs,
    acmodel,
    device,
    args['frames_per_proc'],
    args['discount'],
    args['lr'],
    args['gae_lambda'],
    args['entropy_coef'],
    args['value_loss_coef'],
    args['max_grad_norm'],
    args['recurrence'],
)

if args['algo'] == "a2c":
    algo = torch_ac.A2CAlgo(*_shared_algo_args,
                            args['optim_alpha'], args['optim_eps'], preprocess_obss)
else:
    algo = torch_ac.PPOAlgo(*_shared_algo_args,
                            args['optim_eps'], args['clip_eps'], args['epochs'], args['batch_size'],
                            preprocess_obss)

# Resume the optimizer only when the loaded status carries its state.
if "optimizer_state" in status:
    algo.optimizer.load_state_dict(status["optimizer_state"])
txt_logger.info("Optimizer loaded\n")

# Train model
#
# Alternates experience collection and parameter updates until the frame
# budget args['frames'] is reached. Every args['log_interval'] updates a line
# is written to the text log, the CSV file and Tensorboard; every
# args['save_interval'] updates (when > 0) the training status is saved.

num_frames = status["num_frames"]
update = status["update"]
start_time = time.time()

# The CSV header must be written exactly once, and only when training starts
# from scratch. Decide that here, from the status loaded at start-up: `status`
# is only rebound when a checkpoint is saved, so testing it inside the loop
# repeats the header on every log line until the first save (or forever when
# saving is disabled), and skips it entirely if a save precedes the first log.
csv_header_pending = status["num_frames"] == 0

while num_frames < args['frames']:
    # Update model parameters
    update_start_time = time.time()
    exps, logs1 = algo.collect_experiences()
    logs2 = algo.update_parameters(exps)
    logs = {**logs1, **logs2}
    update_end_time = time.time()

    num_frames += logs["num_frames"]
    update += 1

    # Print logs

    if update % args['log_interval'] == 0:
        fps = logs["num_frames"] / (update_end_time - update_start_time)
        duration = int(time.time() - start_time)
        return_per_episode = utils.synthesize(logs["return_per_episode"])
        rreturn_per_episode = utils.synthesize(logs["reshaped_return_per_episode"])
        num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

        # `header` and `data` are built in lockstep; their order must match the
        # positional placeholders of the text-log format string below.
        header = ["update", "frames", "FPS", "duration"]
        data = [update, num_frames, fps, duration]
        header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
        data += rreturn_per_episode.values()
        header += ["num_frames_" + key for key in num_frames_per_episode.keys()]
        data += num_frames_per_episode.values()
        header += ["entropy", "value", "policy_loss", "value_loss", "grad_norm"]
        data += [logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"], logs["grad_norm"]]

        txt_logger.info(
            "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}"
            .format(*data))

        # The raw (non-reshaped) returns are appended only after the text line
        # has been formatted, so they reach the CSV/Tensorboard output without
        # shifting the positional fields above.
        header += ["return_" + key for key in return_per_episode.keys()]
        data += return_per_episode.values()

        if csv_header_pending:
            csv_logger.writerow(header)
            csv_header_pending = False
        csv_logger.writerow(data)
        csv_file.flush()

        for field, value in zip(header, data):
            tb_writer.add_scalar(field, value, num_frames)

    # Save status

    if args['save_interval'] > 0 and update % args['save_interval'] == 0:
        status = {"num_frames": num_frames, "update": update,
                  "model_state": acmodel.state_dict(), "optimizer_state": algo.optimizer.state_dict()}
        if hasattr(preprocess_obss, "vocab"):
            status["vocab"] = preprocess_obss.vocab.vocab
        utils.save_status(status, model_dir)
        txt_logger.info("Status saved")

/home/ben/miniconda3/envs/minigrid/lib/python3.7/site-packages/ipykernel_launcher.py -f /home/ben/.local/share/jupyter/runtime/kernel-7800215d-0ad8-4f85-8c57-55d871c7f51d.json

{'algo': 'a2c', 'env': 'MiniGrid-DoorKey-5x5-v0', 'model': None, 'seed': 1, 'log_interval': 1, 'save_interval': 10, 'procs': 16, 'frames': 10000000, 'epochs': 4, 'batch_size': 256, 'frames_per_proc': None, 'discount': 0.99, 'lr': 0.001, 'gae_lambda': 0.95, 'entropy_coef': 0.01, 'value_loss_coef': 0.5, 'max_grad_norm': 0.5, 'optim_eps': 1e-08, 'optim_alpha': 0.99, 'clip_eps': 0.2, 'recurrence': 1, 'text': False}

Device: cuda

Environments loaded

Training status loaded

Observations preprocessor loaded
Model loaded

ACModel(
  (image_conv): Sequential(
    (0): Conv2d(3, 16, kernel_size=(2, 2), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1))
    (4): ReLU()
    (5): Conv2d(32, 6

U 48 | F 006144 | FPS 2677 | D 2 | rR:μσmM 0.10 0.21 0.00 0.66 | F:μσmM 227.8 47.7 94.0 250.0 | H 1.942 | V -0.055 | pL 0.001 | vL 0.000 | ∇ 0.006
U 49 | F 006272 | FPS 2641 | D 2 | rR:μσmM 0.10 0.21 0.00 0.66 | F:μσmM 227.8 47.7 94.0 250.0 | H 1.942 | V -0.049 | pL 0.013 | vL 0.000 | ∇ 0.035
U 50 | F 006400 | FPS 2642 | D 2 | rR:μσmM 0.07 0.19 0.00 0.66 | F:μσmM 234.4 42.7 94.0 250.0 | H 1.943 | V 0.048 | pL -0.005 | vL 0.000 | ∇ 0.011
Status saved
U 51 | F 006528 | FPS 2436 | D 2 | rR:μσmM 0.04 0.16 0.00 0.66 | F:μσmM 240.2 37.8 94.0 250.0 | H 1.941 | V 0.020 | pL 0.001 | vL 0.000 | ∇ 0.006
U 52 | F 006656 | FPS 2682 | D 2 | rR:μσmM 0.00 0.00 0.00 0.00 | F:μσmM 250.0 0.0 250.0 250.0 | H 1.941 | V 0.030 | pL -0.008 | vL 0.000 | ∇ 0.023
U 53 | F 006784 | FPS 2886 | D 2 | rR:μσmM 0.00 0.00 0.00 0.00 | F:μσmM 250.0 0.0 250.0 250.0 | H 1.940 | V -0.028 | pL -0.003 | vL 0.000 | ∇ 0.009
U 54 | F 006912 | FPS 2305 | D 2 | rR:μσmM 0.00 0.00 0.00 0.00 | F:μσmM 250.0 0.0 250.0 250.0 | H 1.942 |

U 104 | F 013312 | FPS 2883 | D 5 | rR:μσmM 0.06 0.17 0.00 0.58 | F:μσmM 236.4 37.1 116.0 250.0 | H 1.944 | V -0.025 | pL -0.000 | vL 0.000 | ∇ 0.002
U 105 | F 013440 | FPS 2759 | D 5 | rR:μσmM 0.06 0.17 0.00 0.58 | F:μσmM 236.4 37.1 116.0 250.0 | H 1.944 | V -0.027 | pL 0.001 | vL 0.000 | ∇ 0.005
U 106 | F 013568 | FPS 2712 | D 5 | rR:μσmM 0.06 0.17 0.00 0.58 | F:μσmM 236.4 37.1 116.0 250.0 | H 1.945 | V -0.014 | pL -0.000 | vL 0.000 | ∇ 0.002
U 107 | F 013696 | FPS 2782 | D 5 | rR:μσmM 0.06 0.17 0.00 0.58 | F:μσmM 236.4 37.1 116.0 250.0 | H 1.945 | V -0.014 | pL 0.004 | vL 0.000 | ∇ 0.012
U 108 | F 013824 | FPS 2565 | D 5 | rR:μσmM 0.06 0.17 0.00 0.58 | F:μσmM 236.4 37.1 116.0 250.0 | H 1.944 | V 0.023 | pL -0.004 | vL 0.000 | ∇ 0.010
U 109 | F 013952 | FPS 2730 | D 5 | rR:μσmM 0.03 0.10 0.00 0.40 | F:μσmM 244.8 20.3 166.0 250.0 | H 1.944 | V -0.011 | pL -0.001 | vL 0.000 | ∇ 0.004
U 110 | F 014080 | FPS 2589 | D 5 | rR:μσmM 0.06 0.16 0.00 0.53 | F:μσmM 237.2 34.3 130.0 250.0 | H 1.9

U 159 | F 020352 | FPS 2639 | D 8 | rR:μσmM 0.04 0.12 0.00 0.47 | F:μσmM 242.8 25.1 146.0 250.0 | H 1.928 | V 0.332 | pL -0.026 | vL 0.000 | ∇ 0.072
U 160 | F 020480 | FPS 2414 | D 8 | rR:μσmM 0.04 0.12 0.00 0.47 | F:μσmM 242.8 25.1 146.0 250.0 | H 1.929 | V 0.113 | pL -0.017 | vL 0.000 | ∇ 0.046
Status saved
U 161 | F 020608 | FPS 2560 | D 8 | rR:μσmM 0.04 0.12 0.00 0.47 | F:μσmM 242.8 25.1 146.0 250.0 | H 1.931 | V -0.033 | pL -0.003 | vL 0.000 | ∇ 0.007
U 162 | F 020736 | FPS 2733 | D 8 | rR:μσmM 0.04 0.12 0.00 0.47 | F:μσmM 242.8 25.1 146.0 250.0 | H 1.934 | V -0.055 | pL 0.005 | vL 0.000 | ∇ 0.015
U 163 | F 020864 | FPS 2530 | D 8 | rR:μσmM 0.04 0.12 0.00 0.47 | F:μσmM 242.8 25.1 146.0 250.0 | H 1.937 | V -0.010 | pL 0.001 | vL 0.000 | ∇ 0.003
U 164 | F 020992 | FPS 2594 | D 8 | rR:μσmM 0.04 0.12 0.00 0.47 | F:μσmM 242.8 25.1 146.0 250.0 | H 1.937 | V -0.001 | pL 0.001 | vL 0.000 | ∇ 0.004
U 165 | F 021120 | FPS 2884 | D 8 | rR:μσmM 0.04 0.12 0.00 0.47 | F:μσmM 242.8 25.1 146.0 25

Process Process-48:
Process Process-55:
Process Process-56:
Process Process-54:


KeyboardInterrupt: 

Process Process-51:
Process Process-47:
Process Process-53:
Process Process-50:
Traceback (most recent call last):
Process Process-46:
Process Process-57:
Process Process-52:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-60:
Process Process-59:
Process Process-49:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Process Process-58:
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.

  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/site-packages/torch_ac/utils/penv.py", line 9, in worker
    cmd, data = conn.recv()
  File "/home/ben/miniconda3/envs/minigrid/lib/python3.7/site-packages/gymnasium/wrappers/env_checker.py", line 49, in step
    return