In [1]:
import torch 
from torch import nn

import ray
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override

#from models import VisualEncoder
from train import *



In [2]:
class VisualEncoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4, padding=0),  
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), 
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), 
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(), 
            nn.Conv2d(64, 64, kernel_size=2, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.cnn(x)

In [3]:
class MyModelClass(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        features_dim = 64
        self.encoder = VisualEncoder()
        self.encoder.load_state_dict(
            torch.load("/IGLU-Minecraft/models/AlinaCNN/encoder_weigths.pth", map_location=torch.device('cpu'))
        )
        self.gru
        self.action_head = nn.Linear(features_dim, action_space.n)
        self.value_head = nn.Linear(features_dim, 1)
        self.last_value = None
        
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.encoder.cuda()
            self.action_head.cuda()
            self.value_head.cuda()
        
    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict['obs'].permute(0, 3, 1, 2).float() / 255.0
        if self.use_cuda:
            obs.cuda()
            
        features = self.encoder(obs)
        #action = self.action_head(features)
        #self.last_value = self.value_head(features).squeeze(1)
        return features, state
    
    @override(TorchModelV2)
    def value_function(self):
        assert self.last_value is not None, "must call forward() first"
        return self.last_value

In [4]:
ModelCatalog.register_custom_model("my_torch_model", MyModelClass)

In [5]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

def env_creator(env_config):
    env = gym.make('IGLUSilentBuilder-v0', max_steps=1000)
    env.update_taskset(TaskSet(preset=['C17']))
    env = PovOnlyWrapper(env)
    env = IgluActionWrapper(env)
    return env

from ray.tune.registry import register_env
register_env("my_env", env_creator)

from ray import tune
from ray.rllib.agents.ppo import PPOTrainer

In [6]:
from ray.tune.integration.wandb import WandbLogger

tune.run(PPOTrainer, 
         config={
             "env": "my_env", 
             "framework": "torch",
             "num_gpus": 1,
             "num_workers": 1,
             "sgd_minibatch_size": 256,
             "clip_param": 0.2,
             "entropy_coeff": 0.01,
             "lambda": 0.95,
             "train_batch_size": 1000,
             "model": {
                     # Auto-wrap the custom(!) model with an LSTM.
                    "use_lstm": True,
                    # To further customize the LSTM auto-wrapper.
                    "lstm_cell_size": 64,
                    # Specify our custom model from above.
                    "custom_model": "my_torch_model",
                    # Extra kwargs to be passed to your model's c'tor.
                    "custom_model_config": {},
              },
             "logger_config": {
                  "wandb": {
                      "project": "IGLU-Minecraft",
                      "name": "PPO C17 pretrained (AlinaCNN)"
                  }
              }

        },
        loggers=[WandbLogger])



Trial name,status,loc
PPO_my_env_c7a10_00000,PENDING,


2021-09-27 11:26:38,439	INFO wandb.py:170 -- Already logged into W&B.
2021-09-27 11:26:38,448	ERROR syncer.py:72 -- Log sync requires rsync to be installed.
[34m[1mwandb[0m: Currently logged in as: [33mlinar[0m (use `wandb login --relogin` to force relogin)


[2m[36m(pid=165)[0m 2021-09-27 11:26:41,855	INFO ppo.py:159 -- In multi-agent mode, policies will be optimized sequentially by the multi-GPU optimizer. Consider setting simple_optimizer=True if this doesn't work for you.
[2m[36m(pid=165)[0m 2021-09-27 11:26:41,855	INFO trainer.py:728 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
2021-09-27 11:26:45,427	ERROR trial_runner.py:773 -- Trial PPO_my_env_c7a10_00000: Error processing event.
Traceback (most recent call last):
  File "/root/miniconda/envs/py37/lib/python3.7/site-packages/ray/tune/trial_runner.py", line 739, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/root/miniconda/envs/py37/lib/python3.7/site-packages/ray/tune/ray_trial_executor.py", line 746, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/root/miniconda/envs/py37/lib/python3.7/site-packages/ray/_private/client_mode_hook.p

Result for PPO_my_env_c7a10_00000:
  {}
  


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Trial name,status,loc
PPO_my_env_c7a10_00000,ERROR,

Trial name,# failures,error file
PPO_my_env_c7a10_00000,1,/root/ray_results/PPO_2021-09-27_11-26-38/PPO_my_env_c7a10_00000_0_2021-09-27_11-26-38/error.txt


Trial name,status,loc
PPO_my_env_c7a10_00000,ERROR,

Trial name,# failures,error file
PPO_my_env_c7a10_00000,1,/root/ray_results/PPO_2021-09-27_11-26-38/PPO_my_env_c7a10_00000_0_2021-09-27_11-26-38/error.txt


TuneError: ('Trials did not complete', [PPO_my_env_c7a10_00000])

In [15]:
rnn = nn.GRU(64, 64, 1)

In [16]:
sum(p.numel() for p in rnn.parameters())

24960

In [12]:
net = VisualEncoder()
sum(p.numel() for p in net.parameters())

129312