
refactor: remove orchestrator abstraction from API #289

Merged
jon-tow merged 7 commits into CarperAI:main from remove-orchs on Feb 10, 2023

Conversation

@jon-tow (Collaborator) commented Feb 8, 2023

This PR removes all of the orchestrator components for reasons outlined in #278.

Highlights for reviewer(s):

  • Adds a new method to AcceleratePPOTrainer called add_prompt_pipeline that mimics the prompt pipeline loading and device placement of the removed PPO orchestrator. This is sort of awkward because it requires users to manually call the method before running make_experience (the same thing you have to do with add_eval_pipeline); see the sketch after this list for the resulting call order. Open to suggestions; I'd prefer to pass the pipeline to the trainer constructor, but that breaks the config-first approach currently implemented (should we discuss this for a future refactor?).

  • Removes dead code related to MagiCARP. Removes unused utils.topk_mask.
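
A minimal sketch of the resulting call order, assuming the trainer picks up the rollout collection that used to live on the PPO orchestrator; apart from add_prompt_pipeline and the get_pipeline call (both taken from the diff below), the names here are illustrative:

```python
# Sketch only: how experience collection is wired once the orchestrator is gone.
pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)

# New step introduced by this PR: attach the prompt pipeline before collecting
# rollouts, mirroring the existing add_eval_pipeline convention.
trainer.add_prompt_pipeline(pipeline)

# Assumed: the rollout-collection entry point now lives on the trainer itself,
# taking the num_rollouts argument the PPO orchestrator used to take.
trainer.make_experience(config.method.num_rollouts)
```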

Reproduction reports:

@jon-tow jon-tow marked this pull request as ready for review February 8, 2023 07:33
@cat-state (Collaborator)

Thanks Jon! I think you might have to also change something for the NeMo trainer?

@cat-state (Collaborator) left a comment

I think this mostly looks good except for the NeMo part; in future work we can try to refactor and break up the big make_experience functions.

@@ -89,8 +89,12 @@ def train(  # noqa: C901
     eval_prompts = prompts[:batch_size]

     pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
-    orch = get_orchestrator(config.train.orchestrator)(trainer, pipeline, chunk_size=config.method.chunk_size)
-    orch.make_experience(config.method.num_rollouts)
+    trainer.add_prompt_pipeline(pipeline)
@cat-state (Collaborator) commented Feb 8, 2023

Re add_prompt_pipeline: yeah, that is a bit awkward. It should be possible to pass it via args if you move the get_trainer call into the PPO part of the branch? But if it's too messy then np.

You could also probably replace the get_pipeline call with PromptPipeline, since all the models use the same pipeline.
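
A rough sketch of what that restructuring could look like; the prompt_pipeline keyword, the branch check, and the config field names are hypothetical illustrations, not existing trlX API:

```python
# Hypothetical sketch of passing the pipeline through the trainer constructor
# inside the PPO-specific branch of train(), per the suggestion above.
pipeline = PromptPipeline(prompts, max_prompt_length, tokenizer)

trainer_cls = get_trainer(config.train.trainer)  # config field name assumed
if issubclass(trainer_cls, AcceleratePPOTrainer):
    # `prompt_pipeline` is an illustrative kwarg, not part of the current constructor.
    trainer = trainer_cls(config, prompt_pipeline=pipeline)
else:
    trainer = trainer_cls(config)
```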

@jon-tow (Collaborator, Author) replied

  • Yeah I wasn't sure if it'd be too messy - let me give it a go :)

  • I agree on replacing the get_pipeline stuff with just PromptPipeline; I originally did that but reverted before creating the PR to limit the scope to the issue being addressed. I can't think of any other prompt pipelines so not sure what the _DATAPIPELINE registry is even intended for?

A collaborator replied:

> so not sure what the _DATAPIPELINE registry is even intended for?

It's from the time when every pipeline was expected to be a specific dataset, each one registered deliberately.
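
For context, the registry pattern being described is roughly the following; this is a generic illustration rather than trlX's exact code:

```python
# Generic sketch of a name -> pipeline-class registry where each dataset-specific
# pipeline is registered deliberately and later looked up by name.
_DATAPIPELINE = {}

def register_datapipeline(cls):
    _DATAPIPELINE[cls.__name__] = cls
    return cls

def get_pipeline(name):
    return _DATAPIPELINE[name]

@register_datapipeline
class PromptPipeline:
    def __init__(self, prompts, max_prompt_length, tokenizer):
        self.prompts = prompts
        self.max_prompt_length = max_prompt_length
        self.tokenizer = tokenizer

# get_pipeline("PromptPipeline") now returns the class to instantiate.
```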

@@ -95,3 +142,73 @@ def save_pretrained(self, directory: Optional[str] = None):
             "`AccelerateILQLTrainer` does not currently support automatic saving "
             "with `transformers.PreTrainedModel.save_pretrained`."
         )
+
+    def make_experience(self, samples, rewards, max_length=2048):
@cat-state (Collaborator) commented Feb 8, 2023

Maybe pull this out into somewhere it can be shared with the NeMo impl?
I guess this could mean it's worth also passing in the rollout store as an arg, like the prompt pipeline for PPO.

@jon-tow (Collaborator, Author) replied

I've duplicated make_experience into NeMo for now. There was a subtle difference in logging whereby NeMo couldn't recognize the global rank in the RANK == 0 checks, forcing each rank to write tables to stdout (the fix is to just use their global rank check util).

I think it might be best to push this off to another PR because we'll need to re-visit this abstraction again for PPO. What do you think?

A collaborator replied:

Yeah, I think that makes sense for now. Maybe torch.distributed.get_rank() will work for both but we can revisit
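
A minimal sketch of such a launcher-agnostic rank-0 guard, assuming torch.distributed has been initialized by whichever launcher is in use:

```python
import torch.distributed as dist

def is_global_rank_zero() -> bool:
    # Falls back to True when torch.distributed was never initialized
    # (e.g. a single-process run without a launcher).
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True

if is_global_rank_zero():
    print("only the global rank-0 process writes rollout tables to stdout")
```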

@cat-state (Collaborator) left a comment

Thanks Jon! I think this looks good to me; we can excise and clean up get_pipeline in future work and revisit the sharing between Accelerate and NeMo.

@jon-tow jon-tow merged commit 81e935a into CarperAI:main Feb 10, 2023
@jon-tow jon-tow deleted the remove-orchs branch February 10, 2023 16:54