
add dynamic MML, sampling #17

Open: wants to merge 67 commits into master

Conversation

MurtyShikhar (Contributor)
The logic for sampling seems simple, but I feel I'm missing something here. Can you please take a look at line 305 in basic_transition_function.py, @pdasigi?

@pdasigi (Member) left a comment

Sampling logic seems fine to me. You need to change _take_first_step too, though. Does the code run as expected now?

# This is the only change required to sample from the current log
# probabilities instead of doing beam search.
if sample:
    categorical_dist = Categorical(logits=curr_log_probs)
    _, group_index, log_prob, action_embedding, action = batch_states[categorical_dist.sample()]
pdasigi (Member):

Looks like you're sampling only one action here. Was that intentional?

pdasigi (Member):

I guess it was. So the idea is that you sample one action at each time step instead of keeping a sorted beam, correct?

MurtyShikhar (Contributor, author):

Yes. I want a sampler function that can sample a whole decoded sequence, and then call that multiple times, so at the end I'll have max_num_decoded_sequences sequences for each batch instance.
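A minimal sketch of that idea, assuming a hypothetical `step_log_probs_fn` standing in for the transition function (the actual decoder state and action bookkeeping are omitted):

```python
import torch
from torch.distributions import Categorical


def sample_sequence(step_log_probs_fn, max_steps):
    """Sample one decoded sequence by drawing a single action per time step.

    step_log_probs_fn is a hypothetical stand-in for the transition
    function: given the current decoder state, it returns (log_probs, state).
    """
    actions = []
    state = None
    for _ in range(max_steps):
        log_probs, state = step_log_probs_fn(state)
        # Sample one action from the current log probabilities instead of
        # keeping a sorted beam.
        action = Categorical(logits=log_probs).sample().item()
        actions.append(action)
    return actions
```

Calling `sample_sequence` max_num_decoded_sequences times would then give the desired set of sampled sequences for each batch instance.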

# (group_size, num_start_type)
start_action_logits = self._start_type_predictor(hidden_state)
log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1)
sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)
pdasigi (Member):

You'll have to make changes here too, to sample the first action from log_probs above.

MurtyShikhar (Contributor, author):

Oh, I didn't see this. Why is the first step handled differently from the rest of the steps?


MurtyShikhar (Contributor, author):

OK, so I'm going to add the same sampling logic to _take_first_step as well.
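For the first step, the change would look roughly like this (a sketch with made-up shapes, not the actual _take_first_step code): instead of sorting the start-action log probabilities for the beam, draw one start action per group element.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical shapes: (group_size, num_start_types)
start_action_logits = torch.randn(4, 5)
log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1)

# Beam search keeps a sorted list of candidate start actions ...
sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)

# ... while sampling just draws one start action per group element.
first_actions = Categorical(logits=log_probs).sample()  # shape: (group_size,)
```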

@pdasigi (Member) left a comment

Minor comment

all_actions = all_actions.detach().cpu().numpy().tolist()
if sample_states:
    # (group_size,) one action per group element
    sampler = Categorical(logits=sorted_log_probs)
pdasigi (Member):

Why do you have to use sorted_log_probs here? You could have just passed log_probs, correct?

MurtyShikhar (Contributor, author):

I wanted the sampler to return indices into all_actions, which works consistently only if we use sorted_log_probs.
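A toy illustration of that point (not the actual decoder code; logits are chosen so the sample is effectively deterministic):

```python
import torch
from torch.distributions import Categorical

# One group element, three actions.
log_probs = torch.tensor([[-100.0, 100.0, -100.0]])
sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)
all_actions = sorted_actions.detach().cpu().numpy().tolist()

# An index sampled from the sorted distribution points into all_actions,
# which was built in the same (sorted) order.
index = Categorical(logits=sorted_log_probs).sample()  # shape: (1,)
chosen_action = all_actions[0][index[0].item()]

# Sampling from the unsorted log_probs would give an index into the
# original ordering, which no longer lines up with all_actions.
```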
