Composer work -- incomplete and not working
Coriana committed Aug 13, 2022
1 parent 0b9f6d3 commit 46cf94f
Showing 2 changed files with 57 additions and 5 deletions.
45 changes: 43 additions & 2 deletions agent_composer.py
@@ -189,8 +189,49 @@ def _env_action_to_agent(self, minerl_action_transformed, to_torch=False, check_
 
     def forward(self, batch):  # batch is the output of the dataloader
         # specify how batches are passed through the model
-        inputs, _ = batch
-        return self.model(inputs)
+        batch_images = batch[0]
+        batch_actions = batch[1]
+        batch_episode_id = batch[2]
+        # for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(batch):
+        batch_loss = []
+        for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
+            if image is None and action is None:
+                # A work-item was done. Remove hidden state
+                if episode_id in episode_hidden_states:
+                    removed_hidden_state = episode_hidden_states.pop(episode_id)
+                    del removed_hidden_state
+                continue
+
+            agent_action = self._env_action_to_agent(action, to_torch=True, check_if_null=True)
+            if agent_action is None:
+                # Action was null
+                continue
+
+            agent_obs = self._env_obs_to_agent({"pov": image})
+            if episode_id not in episode_hidden_states:
+                episode_hidden_states[episode_id] = policy.initial_state(1)
+            agent_state = episode_hidden_states[episode_id]
+
+            pi_distribution, v_prediction, new_agent_state = policy.get_output_for_observation(
+                agent_obs,
+                agent_state,
+                dummy_first
+            )
+
+            log_prob = policy.get_logprob_of_action(pi_distribution, agent_action)
+
+            # Make sure we do not try to backprop through sequence
+            # (fails with current accumulation)
+            new_agent_state = tree_map(lambda x: x.detach(), new_agent_state)
+            episode_hidden_states[episode_id] = new_agent_state
+
+            # Finally, update the agent to increase the probability of the
+            # taken action.
+            # Remember to take mean over batch losses
+            #loss = -log_prob / BATCH_SIZE
+            batch_loss.append(-log_prob)
+
+        return self.model(batch_loss)
 
     def loss(self, outputs, batch):
         # pass batches and `forward` outputs to the loss
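The commit message flags this as incomplete, and the new forward has two loose ends: it references episode_hidden_states, policy, dummy_first, and tree_map, which live in behavioural_cloning_composer.py (or are never imported) rather than in this file, and its final return self.model(batch_loss) feeds the list of losses back through the policy. Below is a minimal sketch of how the forward/loss pair might be completed for Composer's ComposerModel interface; it is not the commit's actual fix, and it assumes, hypothetically, that the hidden-state dict, the policy, and the dummy "first" tensor are stored on the agent as attributes, and that tree_map is importable from the VPT code.

# Sketch only. Assumed (hypothetical) attributes on the agent:
# self.episode_hidden_states, self.policy, self.dummy_first.
from lib.tree_util import tree_map  # VPT utility, assumed available
import torch as th

def forward(self, batch):  # meant as a drop-in method for the MineRLAgent class above
    batch_images, batch_actions, batch_episode_id = batch
    batch_loss = []
    for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
        if image is None and action is None:
            # End-of-episode marker: drop that episode's recurrent state.
            self.episode_hidden_states.pop(episode_id, None)
            continue

        agent_action = self._env_action_to_agent(action, to_torch=True, check_if_null=True)
        if agent_action is None:
            continue  # Null action, nothing to imitate.

        agent_obs = self._env_obs_to_agent({"pov": image})
        if episode_id not in self.episode_hidden_states:
            self.episode_hidden_states[episode_id] = self.policy.initial_state(1)
        agent_state = self.episode_hidden_states[episode_id]

        pi_distribution, v_prediction, new_agent_state = self.policy.get_output_for_observation(
            agent_obs, agent_state, self.dummy_first
        )
        log_prob = self.policy.get_logprob_of_action(pi_distribution, agent_action)

        # Detach the recurrent state so gradients never cross batch boundaries.
        self.episode_hidden_states[episode_id] = tree_map(lambda x: x.detach(), new_agent_state)
        batch_loss.append(-log_prob)

    if not batch_loss:
        # Every item was a null action or an end-of-episode marker.
        return th.zeros(1, device=self.dummy_first.device, requires_grad=True)
    # Return per-sample losses; `loss` below does the reduction.
    return th.stack(batch_loss)

def loss(self, outputs, batch):
    # Behavioural cloning objective: mean negative log-probability of the demonstrated actions.
    return outputs.mean()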
17 changes: 14 additions & 3 deletions behavioural_cloning_composer.py
@@ -19,6 +19,7 @@
 import torch as th
 import numpy as np
 
+from composer.core import DataSpec
 from composer import Trainer
 from agent_composer import PI_HEAD_KWARGS, MineRLAgent
 from data_loader import DataLoader
@@ -46,7 +47,10 @@ def load_model_parameters(path_to_model_file):
     pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
     pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
     return policy_kwargs, pi_head_kwargs
 
 
+def num_samples_fn(batch):
+    return BATCH_SIZE
+
 def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
     agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)
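num_samples_fn exists because Composer tracks progress in samples, and the batches produced by the VPT DataLoader are tuples of Python lists rather than tensors, so the Trainer cannot infer a batch size on its own; the DataSpec built further down passes this callable as get_num_samples_in_batch. Since some entries in a batch are (None, None, episode_id) end-of-episode markers, a slightly stricter counter could skip them. A small sketch of that variant, offered as an alternative rather than what this commit does:

def num_samples_fn(batch):
    # Count only real work items; finished episodes arrive with a None image.
    batch_images, _actions, _episode_ids = batch
    return sum(1 for image in batch_images if image is not None)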

@@ -56,6 +60,9 @@ def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
     policy = agent.policy
     trainable_parameters = policy.parameters()
 
+    episode_hidden_states = {}
+    dummy_first = th.from_numpy(np.array((False,))).to(DEVICE)
+
     # Parameters taken from the OpenAI VPT paper
     optimizer = th.optim.Adam(
         trainable_parameters,
@@ -69,13 +76,17 @@ def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
         batch_size=BATCH_SIZE,
         n_epochs=EPOCHS
     )
 
 
+    data_spec = DataSpec(data_loader, get_num_samples_in_batch=num_samples_fn,)
 
     trainer = Trainer(
         model=agent,
         optimizers=optimizer,
-        train_dataloader=data_loader,
+        train_dataloader=data_spec,
         max_duration='10ep'
     )
 
 
     trainer.fit()
 
     state_dict = policy.state_dict()
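The diff is cut off right after state_dict = policy.state_dict(). A plausible ending, following the pattern of the original VPT behavioural_cloning.py rather than anything shown in this commit, would write the fine-tuned weights to the requested path once training finishes:

    state_dict = policy.state_dict()
    th.save(state_dict, out_weights)

The script would then be driven by calling behavioural_cloning_train(data_dir, in_model, in_weights, out_weights) from a command-line entry point, as in the original repository.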
