-
Notifications
You must be signed in to change notification settings - Fork 19
/
walkthrough_rgb_mapping_ppo.py
130 lines (116 loc) · 4.44 KB
/
walkthrough_rgb_mapping_ppo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import Dict, Any, cast
import gym
import torch
from allenact.algorithms.onpolicy_sync.losses import PPO
from allenact.algorithms.onpolicy_sync.losses.ppo import PPOConfig
from allenact.base_abstractions.sensor import SensorSuite
from allenact.embodiedai.mapping.mapping_losses import (
BinnedPointCloudMapLoss,
SemanticMapFocalLoss,
)
from allenact.utils.experiment_utils import LinearDecay, PipelineStage
from allenact_plugins.ithor_plugin.ithor_sensors import (
RelativePositionChangeTHORSensor,
ReachableBoundsTHORSensor,
BinnedPointCloudMapTHORSensor,
SemanticMapTHORSensor,
)
from allenact_plugins.robothor_plugin.robothor_sensors import DepthSensorThor
from baseline_configs.walkthrough.walkthrough_rgb_base import (
WalkthroughBaseExperimentConfig,
)
from rearrange.baseline_models import WalkthroughActorCriticResNetWithPassiveMap
from rearrange.constants import (
FOV,
PICKUPABLE_OBJECTS,
OPENABLE_OBJECTS,
)
class WalkthroughRGBMappingPPOExperimentConfig(WalkthroughBaseExperimentConfig):
    """Walkthrough experiment config trained with PPO plus two map losses.

    Extends ``WalkthroughBaseExperimentConfig`` by appending a relative
    position-change sensor, a map-range sensor, an un-normalized depth sensor,
    a binned point-cloud map sensor and a semantic map sensor to the base
    sensor list. The model is a ``WalkthroughActorCriticResNetWithPassiveMap``
    and training combines a PPO loss with ``BinnedPointCloudMapLoss`` and
    ``SemanticMapFocalLoss``.
    """

    # Sorted so the object-type ordering is deterministic; its length is also
    # used as the number of semantic map channels in `create_model`.
    ORDERED_OBJECT_TYPES = sorted(PICKUPABLE_OBJECTS + OPENABLE_OBJECTS)

    # Sensor instance shared by both map sensors below (via MAP_INFO) and also
    # registered in SENSORS on its own.
    MAP_RANGE_SENSOR = ReachableBoundsTHORSensor(margin=1.0)

    # Keyword arguments common to both map sensors. The map size depends on
    # which kind of range sensor is configured above; with the current
    # ReachableBoundsTHORSensor the 1050 branch is the one selected.
    MAP_INFO = dict(
        map_range_sensor=MAP_RANGE_SENSOR,
        vision_range_in_cm=40 * 5,
        map_size_in_cm=1050
        if isinstance(MAP_RANGE_SENSOR, ReachableBoundsTHORSensor)
        else 2200,
        resolution_in_cm=5,
    )

    # Base sensors followed by the sensors needed for mapping.
    SENSORS = WalkthroughBaseExperimentConfig.SENSORS + [
        RelativePositionChangeTHORSensor(),
        MAP_RANGE_SENSOR,
        DepthSensorThor(
            height=WalkthroughBaseExperimentConfig.SCREEN_SIZE,
            width=WalkthroughBaseExperimentConfig.SCREEN_SIZE,
            use_normalization=False,
            uuid="depth",
        ),
        BinnedPointCloudMapTHORSensor(fov=FOV, **MAP_INFO),
        SemanticMapTHORSensor(
            fov=FOV,
            **MAP_INFO,
            ordered_object_types=ORDERED_OBJECT_TYPES,
        ),
    ]

    @classmethod
    def tag(cls) -> str:
        """Return the identifying tag string for this experiment."""
        return "WalkthroughRGBMappingPPO"

    @classmethod
    def num_train_processes(cls) -> int:
        """Return five training processes per visible CUDA device, minimum one."""
        return max(1, torch.cuda.device_count() * 5)

    @classmethod
    def create_model(cls, **kwargs) -> WalkthroughActorCriticResNetWithPassiveMap:
        """Build the actor-critic model with a passive map.

        The map geometry (vision range, resolution, map size) is read from the
        first ``BinnedPointCloudMapTHORSensor`` found in ``cls.SENSORS`` so the
        model and the sensor stay in agreement.

        Args:
            **kwargs: May contain ``sensor_preprocessor_graph``. When present
                and not ``None``, its ``observation_spaces`` are used;
                otherwise the observation spaces come from a ``SensorSuite``
                built over ``cls.SENSORS``.

        Returns:
            The constructed ``WalkthroughActorCriticResNetWithPassiveMap``.

        Raises:
            ValueError: If ``cls.SENSORS`` contains no
                ``BinnedPointCloudMapTHORSensor``.
        """
        # Use a default with `next` so a missing sensor produces a clear error
        # rather than a bare StopIteration escaping from this method.
        found_sensor = next(
            (s for s in cls.SENSORS if isinstance(s, BinnedPointCloudMapTHORSensor)),
            None,
        )
        if found_sensor is None:
            raise ValueError(
                f"{cls.__name__}.SENSORS must contain a"
                " BinnedPointCloudMapTHORSensor to create the model."
            )
        map_sensor = cast(BinnedPointCloudMapTHORSensor, found_sensor)

        # NOTE(review): 224 presumably matches the SCREEN_SIZE used for the
        # depth sensor in SENSORS -- confirm before changing either value.
        map_kwargs = dict(
            frame_height=224,
            frame_width=224,
            vision_range_in_cm=map_sensor.vision_range_in_cm,
            resolution_in_cm=map_sensor.resolution_in_cm,
            map_size_in_cm=map_sensor.map_size_in_cm,
        )

        observation_space = (
            SensorSuite(cls.SENSORS).observation_spaces
            if kwargs.get("sensor_preprocessor_graph") is None
            else kwargs["sensor_preprocessor_graph"].observation_spaces
        )

        return WalkthroughActorCriticResNetWithPassiveMap(
            action_space=gym.spaces.Discrete(len(cls.actions())),
            observation_space=observation_space,
            rgb_uuid=cls.EGOCENTRIC_RGB_UUID,
            unshuffled_rgb_uuid=cls.UNSHUFFLED_RGB_UUID,
            semantic_map_channels=len(cls.ORDERED_OBJECT_TYPES),
            height_map_channels=3,
            map_kwargs=map_kwargs,
        )

    @classmethod
    def _training_pipeline_info(cls, **kwargs) -> Dict[str, Any]:
        """Define how the model trains.

        Returns:
            A dict holding the named losses, a single pipeline stage that
            runs for ``cls.TRAINING_STEPS`` steps, and the optimisation
            settings (``num_steps``, ``num_mini_batch``, ``update_repeats``,
            ``use_lr_decay``, ``lr``).
        """
        training_steps = cls.TRAINING_STEPS
        return dict(
            named_losses=dict(
                # The PPO clip parameter decays linearly over the whole run.
                ppo_loss=PPO(clip_decay=LinearDecay(training_steps), **PPOConfig),
                binned_map_loss=BinnedPointCloudMapLoss(
                    binned_pc_uuid="binned_pc_map",
                    map_logits_uuid="ego_height_binned_map_logits",
                ),
                semantic_map_loss=SemanticMapFocalLoss(
                    semantic_map_uuid="semantic_map",
                    map_logits_uuid="ego_semantic_map_logits",
                ),
            ),
            pipeline_stages=[
                PipelineStage(
                    loss_names=["ppo_loss", "binned_map_loss", "semantic_map_loss"],
                    # Presumably weights pair positionally with `loss_names`,
                    # giving the semantic map loss a weight of 100 -- confirm
                    # against the PipelineStage documentation.
                    loss_weights=[1.0, 1.0, 100.0],
                    max_stage_steps=training_steps,
                )
            ],
            num_steps=32,
            num_mini_batch=1,
            update_repeats=3,
            use_lr_decay=True,
            lr=3e-4,
        )