In [1]:
import torch 
from torch import nn

import ray
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override

#from models import VisualEncoder
from train import *
from wrappers_2 import *



In [2]:
class VisualEncoder(nn.Module):
    """CNN that turns an RGB frame into a flat feature vector.

    Six convolutions with kernel 2 / stride 2 / no padding, each followed by an
    ELU, halve the spatial size at every stage while the channel count grows
    3 -> 32 -> 32 -> 64 -> 128 -> 256 -> 512; the output is then flattened.
    A 64x64 input therefore ends at 1x1 spatially, i.e. 512 features.

    Layers are kept in ``self.cnn`` with convolutions at indices 0, 2, ..., 10
    and the Flatten at index 12, so saved ``state_dict`` keys such as
    ``cnn.0.weight`` line up with this module.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) of each conv stage, in forward order.
        channel_plan = [(3, 32), (32, 32), (32, 64), (64, 128), (128, 256), (256, 512)]

        stack = []
        for c_in, c_out in channel_plan:
            stack += [nn.Conv2d(c_in, c_out, kernel_size=2, stride=2, padding=0), nn.ELU()]
        stack.append(nn.Flatten())

        self.cnn = nn.Sequential(*stack)

    def forward(self, x):
        """Encode a batch ``x`` of shape (N, 3, H, W) into flattened features."""
        return self.cnn(x)

In [3]:
from torch.nn.functional import one_hot

class MyModelClass(TorchModelV2, nn.Module):
    """RLlib Torch model: pretrained VisualEncoder + target-grid encoder + MLP.

    * ``obs['pov']`` is permuted to channels-first, scaled by 1/255 and encoded
      by ``VisualEncoder`` under ``torch.no_grad()`` (no gradients reach the
      encoder; its weights are loaded from disk at construction).
    * ``obs['target_grid']`` is one-hot encoded over 7 classes and reduced to a
      single channel by a 1x1x1 ``Conv3d`` + ELU, then flattened.
    * Both feature vectors are concatenated and fed through an MLP; the final
      linear layer's ``num_outputs`` values are returned from ``forward``.

    The encoder weights path defaults to ``ENCODER_WEIGHTS_PATH`` and may be
    overridden with ``custom_model_config={"encoder_weights_path": ...}``.
    """

    # Default location of the pretrained VisualEncoder weights.
    ENCODER_WEIGHTS_PATH = "/IGLU-Minecraft/models/AngelaCNN/encoder_weigths.pth"

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        visual_features_dim = 512           # VisualEncoder output size for a 64x64x3 pov
        target_features_dim = 9 * 11 * 11   # one value per cell of the (9, 11, 11) target grid

        self.visual_encoder = VisualEncoder()
        custom_config = (model_config or {}).get("custom_model_config") or {}
        weights_path = custom_config.get("encoder_weights_path", self.ENCODER_WEIGHTS_PATH)
        self.visual_encoder.load_state_dict(
            torch.load(weights_path, map_location=torch.device('cpu'))
        )

        # 7 one-hot channels -> 1 channel per voxel.
        self.target_encoder = nn.Sequential(
            nn.Conv3d(7, 1, kernel_size=1, stride=1, padding=0),
            nn.ELU(),
        )

        policy_hidden_dim = 256
        self.policy_network = nn.Sequential(
            nn.Linear(visual_features_dim + target_features_dim, 1024),
            nn.ELU(),
            nn.Linear(1024, 512),
            nn.ELU(),
            nn.Linear(512, policy_hidden_dim),
            nn.ELU(),
            nn.Linear(policy_hidden_dim, policy_hidden_dim),
            nn.ELU(),
        )
        self.qvalue_head = nn.Linear(policy_hidden_dim, num_outputs)

        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.visual_encoder.cuda()
            self.target_encoder.cuda()
            self.policy_network.cuda()
            self.qvalue_head.cuda()

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute the model output from ``input_dict['obs']``.

        Args:
            input_dict: RLlib input dict; ``obs`` must contain ``pov`` and
                ``target_grid`` tensors (presumably (B, 64, 64, 3) and
                (B, 9, 11, 11) per the env's observation space — confirm).
            state: recurrent state, returned unchanged.
            seq_lens: unused.

        Returns:
            ``(qvalues, state)`` where ``qvalues`` has ``num_outputs`` columns.
        """
        obs = input_dict['obs']
        # Tensor.cuda()/.to() return a NEW tensor, so the result must be
        # reassigned. Follow the weights' actual device instead of the use_cuda
        # flag, in case the model was moved after __init__.
        device = self.qvalue_head.weight.device
        pov = (obs['pov'].permute(0, 3, 1, 2).float() / 255.0).to(device)
        target_grid = obs['target_grid'].to(device).long()
        target = one_hot(target_grid, num_classes=7).permute(0, 4, 1, 2, 3).float()

        # No gradients flow into the pretrained visual encoder.
        with torch.no_grad():
            visual_features = self.visual_encoder(pov)

        target_features = self.target_encoder(target)
        target_features = target_features.reshape(target_features.shape[0], -1)
        features = torch.cat([visual_features, target_features], dim=1)
        features = self.policy_network(features)
        qvalues = self.qvalue_head(features)
        return qvalues, state
    

In [4]:
# Make MyModelClass selectable from a trainer config via
# "model": {"custom_model": "my_torch_model"}.
ModelCatalog.register_custom_model("my_torch_model", MyModelClass)

In [5]:
class VisualObservationWrapper(ObsWrapper):
    """Convert raw IGLU observations into a ``gym.spaces.Dict`` observation.

    Emitted keys: ``pov`` (64x64x3, cast to float32), ``inventory`` (6,),
    ``compass`` (1,) and, only when ``include_target`` is True,
    ``target_grid`` (9x11x11). The emitted keys always match
    ``observation_space``.

    Args:
        env: environment to wrap.
        include_target: whether to add the task's target grid to observations.
    """

    def __init__(self, env, include_target=False):
        super().__init__(env)
        self.include_target = include_target
        spaces = {
            'pov': gym.spaces.Box(low=0, high=255, shape=(64, 64, 3)),
            'inventory': gym.spaces.Box(low=0.0, high=20.0, shape=(6,)),
            'compass': gym.spaces.Box(low=-180.0, high=180.0, shape=(1,)),
        }
        if include_target:
            spaces['target_grid'] = gym.spaces.Box(low=0, high=6, shape=(9, 11, 11))
        self.observation_space = gym.spaces.Dict(spaces)

    def _pop_target_grid(self, info):
        """Return the current target grid, removing it from ``info`` if present.

        If ``info`` is given but lacks ``target_grid``, the info dict is logged
        as an error, the unwrapped env is asked to reset (when it exposes
        ``should_reset``), and the grid is read from the current task instead.
        """
        if info is None:
            return self.env.unwrapped.tasks.current.target_grid
        if 'target_grid' in info:
            return info.pop('target_grid')
        logger.error(f'info: {info}')
        if hasattr(self.unwrapped, 'should_reset'):
            self.unwrapped.should_reset(True)
        return self.env.unwrapped.tasks.current.target_grid

    def observation(self, obs, reward=None, done=None, info=None):
        """Build the observation dict; ``info`` (if given) loses ``target_grid``."""
        # Always consume the grid so ``info`` is scrubbed and a missing grid is
        # still reported, even when the target is not part of the observation.
        target_grid = self._pop_target_grid(info)
        converted = {
            'pov': obs['pov'].astype(np.float32),
            'inventory': obs['inventory'],
            'compass': np.array([obs['compass']['angle'].item()]),
        }
        # Only emit keys declared in observation_space.
        if self.include_target:
            converted['target_grid'] = target_grid
        return converted

In [6]:
import os
# Enumerate GPUs in PCI bus order and expose only GPU 0 (see issue #152).
# NOTE(review): CUDA reads these when a process initialises CUDA, so they
# only affect processes that do so after this cell runs.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

class RewardWrapper(gym.RewardWrapper):
    """Shape environment rewards for training.

    * A reward of exactly 0 is replaced by ``step_penalty``.
    * A reward whose absolute value is exactly 1 is divided by
      ``unit_reward_divisor``.
    * Any other reward passes through unchanged.

    Defaults (-0.01 and 10) reproduce the original hard-coded shaping.

    Args:
        env: environment to wrap.
        step_penalty: value substituted for a zero reward.
        unit_reward_divisor: divisor applied to rewards of +1 / -1.
    """

    def __init__(self, env, step_penalty=-0.01, unit_reward_divisor=10):
        super().__init__(env)
        self.step_penalty = step_penalty
        self.unit_reward_divisor = unit_reward_divisor

    def reward(self, rew):
        if rew == 0:
            rew = self.step_penalty
        # Evaluated after the substitution above (original ordering preserved).
        if abs(rew) == 1:
            rew /= self.unit_reward_divisor
        return rew
    
def env_creator(env_config):
    """Build the wrapped IGLU environment registered with RLlib as "my_env".

    Wrapper order: VisualObservationWrapper (with target grid) ->
    SelectAndPlace -> Discretization('human-level') -> RewardWrapper.

    Args:
        env_config: RLlib env config. Optional keys (defaults reproduce the
            original setup):
            ``max_steps`` — episode step limit (default 250);
            ``task_preset`` — task ids for the TaskSet
            (default ['C3', 'C17', 'C32', 'C8']).

    Returns:
        The fully wrapped environment.
    """
    env_config = env_config or {}
    max_steps = env_config.get('max_steps', 250)
    task_preset = env_config.get('task_preset', ['C3', 'C17', 'C32', 'C8'])

    env = gym.make('IGLUSilentBuilder-v0', max_steps=max_steps)
    env.update_taskset(TaskSet(preset=list(task_preset)))
    env = VisualObservationWrapper(env, include_target=True)
    env = SelectAndPlace(env)
    env = Discretization(env, flat_action_space('human-level'))
    env = RewardWrapper(env)
    return env

from ray.tune.registry import register_env
# Register env_creator with Tune under the id referenced by "env": "my_env"
# in the trainer config below.
register_env("my_env", env_creator)

from ray import tune
from ray.rllib.agents.dqn import ApexTrainer

In [None]:
from ray.tune.integration.wandb import WandbLogger

# Launch one Ape-X (distributed DQN) trial on the wrapped IGLU env using the
# custom model registered above.
# NOTE(review): "logger_config" -> "wandb" looks inert while
# `loggers=[WandbLogger]` stays commented out — confirm W&B logging is meant
# to be off.
# NOTE(review): v_min / v_max are presumably only used by distributional
# Q-learning (num_atoms > 1), which this config does not set — confirm.
analysis = tune.run(
    ApexTrainer,
    config=dict(
        env="my_env",
        framework="torch",
        # gamma=0.99,
        num_gpus=1,
        num_workers=2,
        # Replay buffer and learner schedule.
        buffer_size=50_000,
        learning_starts=2_000,
        train_batch_size=1000,
        target_network_update_freq=2000,
        # prioritized_replay_alpha=0.5,
        # final_prioritized_replay_beta=1.0,
        min_iter_time_s=30,
        rollout_fragment_length=4,
        collect_metrics_timeout=1800,
        v_min=-10.0,
        v_max=100.0,
        # Epsilon-greedy: initial_epsilon -> final_epsilon over epsilon_timesteps.
        exploration_config=dict(
            initial_epsilon=1,
            epsilon_timesteps=50_000,
            final_epsilon=0.05,
        ),
        # Custom model registered via ModelCatalog; no extra constructor kwargs.
        model=dict(
            custom_model="my_torch_model",
            custom_model_config={},
        ),
        logger_config=dict(
            wandb=dict(
                project="IGLU-Minecraft",
                name="APEX MultiTask (C3, C17, C32, C8) pretrained (AngelaCNN) (3 noops after placement) r: -0.01 div10",
            ),
        ),
        # training_intensity=50,
        lr=1e-5,
        # Evaluation: one worker, one episode, every iteration, no exploration.
        evaluation_num_workers=1,
        evaluation_interval=1,
        evaluation_num_episodes=1,
        evaluation_config=dict(
            # input="sampler",
            explore=False,
        ),
    ),
    # loggers=[WandbLogger],
    # local_dir="/IGLU-Minecraft/checkpoints/4_tasks",
    # keep_checkpoints_num=50,
    # checkpoint_freq=5,
    # checkpoint_at_end=True,
    # restore="/IGLU-Minecraft/checkpoints/4_tasks/PPO_2021-11-08_20-28-45/PPO_my_env_78cf0_00000_0_2021-11-08_20-28-45/checkpoint_000050/checkpoint-50"
)



Trial name,status,loc
APEX_my_env_5ad0b_00000,PENDING,


Trial name,status,loc
APEX_my_env_5ad0b_00000,PENDING,




Trial name,status,loc
APEX_my_env_5ad0b_00000,PENDING,


Trial name,status,loc
APEX_my_env_5ad0b_00000,PENDING,




Trial name,status,loc
APEX_my_env_5ad0b_00000,PENDING,


Trial name,status,loc
APEX_my_env_5ad0b_00000,PENDING,




Trial name,status,loc
APEX_my_env_5ad0b_00000,PENDING,
