add dynamic MML, sampling #17
base: master
Conversation
Sampling logic seems fine to me. You need to change `_take_first_step` too, though. Does the code run as expected now?
```python
# this is the only change required to sample from current log probabilities instead of beam search
if sample:
    categorical_dist = Categorical(logits=curr_log_probs)
    _, group_index, log_prob, action_embedding, action = batch_states[categorical_dist.sample()]
```
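For context, here is a minimal, self-contained sketch of what `Categorical` does with log probabilities in this pattern (the tensor size and contents are hypothetical, not taken from the PR):

```python
import torch
from torch.distributions import Categorical

# Hypothetical stand-in for curr_log_probs: one score per candidate state.
curr_log_probs = torch.log_softmax(torch.randn(5), dim=-1)

# Categorical accepts logits (log probabilities work, since they differ from
# logits only by a constant); sample() returns an integer index into the
# candidate list, which can then be used to pick an entry of batch_states.
dist = Categorical(logits=curr_log_probs)
index = dist.sample()
assert 0 <= index.item() < 5
```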
Looks like you're sampling only one action here. Was that intentional?
I guess it was. So the idea is that you sample one action at each time step instead of keeping a sorted beam, correct?
Yes. I want a sampler function that can sample a whole decoded sequence, and then call *that* multiple times. So at the end I'll have `max_num_decoded_sequences` sequences for each batch instance.
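A rough sketch of that idea, with entirely hypothetical names (`sample_sequence`, `step_logits_fn`, and the toy step function are illustrative, not the PR's actual code): sample one action per time step to build a full sequence, then call the sampler once per desired sequence.

```python
import torch
from torch.distributions import Categorical

def sample_sequence(step_logits_fn, max_steps):
    """Sample one action per time step, returning the decoded sequence
    and its total log probability."""
    actions, total_log_prob = [], 0.0
    for _ in range(max_steps):
        logits = step_logits_fn(actions)      # scores for the next action
        dist = Categorical(logits=logits)
        action = dist.sample()
        total_log_prob += dist.log_prob(action).item()
        actions.append(action.item())
    return actions, total_log_prob

# Toy step function: 4 possible actions with fixed random scores.
torch.manual_seed(0)
fixed_scores = torch.randn(4)
# Calling the sampler multiple times, as described above
# (e.g. max_num_decoded_sequences = 5).
sequences = [sample_sequence(lambda prev: fixed_scores, max_steps=3)
             for _ in range(5)]
assert len(sequences) == 5
assert all(len(seq) == 3 for seq, _ in sequences)
```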
```python
# (group_size, num_start_type)
start_action_logits = self._start_type_predictor(hidden_state)
log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1)
sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)
```
You'll have to make changes here too, to sample the first action from `log_probs` above.
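A sketch of what that first-step change could look like, with hypothetical shapes (a `(group_size, num_start_types)` tensor of start-action scores, here 2x3 random values standing in for the predictor output):

```python
import torch
from torch.distributions import Categorical

# Hypothetical stand-in for the start-type predictor output:
# (group_size, num_start_types) = (2, 3).
start_action_logits = torch.randn(2, 3)
log_probs = torch.log_softmax(start_action_logits, dim=-1)

# Instead of sorting for beam search, sample one start action per group
# element directly from log_probs; Categorical treats the last dimension
# as the category axis, so sample() returns a (group_size,) tensor.
sampled_actions = Categorical(logits=log_probs).sample()
assert sampled_actions.shape == (2,)
```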
Oh, I didn't see this. Why is the first step handled differently from the rest of the steps?
OK, so I'm going to add the same sampling logic to `_take_first_step` as well.
…y finished states...
Minor comment
```python
all_actions = all_actions.detach().cpu().numpy().tolist()
if sample_states:
    # (group_size,) one action per group element
    sampler = Categorical(logits=sorted_log_probs)
```
Why do you have to use `sorted_log_probs` here? You could have just passed `log_probs`, correct?
I wanted the sampler to be able to return indices into `all_actions`, which works consistently only if we use `sorted_log_probs`.
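The subtlety here, in toy form: an index sampled from a sorted tensor refers to the sorted order, so it lines up with a list built in that same sorted order but not with the original order. The values below are made up for illustration.

```python
import torch

log_probs = torch.log_softmax(torch.tensor([0.1, 2.0, -1.0]), dim=-1)
sorted_log_probs, sorted_indices = log_probs.sort(dim=-1, descending=True)

# Position i in the sorted tensor corresponds to original position
# sorted_indices[i]; the two orderings agree only if the input was
# already sorted. Here the largest score sat at original index 1.
assert sorted_indices.tolist() == [1, 0, 2]
assert sorted_log_probs[0] == log_probs[sorted_indices[0]]
```

So if `all_actions` is assembled in sorted order, indices sampled from `sorted_log_probs` index it directly; indices sampled from the unsorted `log_probs` would need to be mapped through the sort first.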
merging master
The logic for sampling seems simple, but I feel I'm missing something here. Can you please take a look at line 305 in `basic_transition_function.py`, @pdasigi?