More robust Online DPO changes for RL update #1664

Open · wants to merge 16 commits into main

Changes from 1 commit
Update rl.py removed function that did not need to be patched
pluesclues authored Feb 8, 2025
commit 7503c716911dc5bd1a09b156625e16d84b2b6f99
49 changes: 4 additions & 45 deletions unsloth/models/rl.py
@@ -113,7 +113,9 @@ def unsloth_get_reward(
else:
    IS_SAGEMAKER_MP_POST_1_10 = False
Contributor
Are these primarily for Amazon SageMaker?

Author
This is just for Amazon SageMaker. I wanted to keep the original behavior of the function, but unfortunately it is buried deep inside transformers and transformers.trainer, and it only matters during the saving steps. The function had to be overwritten because it performs evaluation outside of the trainer's own code path. We could theoretically remove this if you want; I just did not want to strip it out of the function unless we had to. There are a lot of SMP forward passes in the function, which is why I decided to keep it and tried to make it compatible.
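
For context, a minimal sketch of the guard being kept here, assuming it mirrors the SageMaker model-parallel check in transformers.trainer (the exact imports in the patched function may differ):

from packaging import version
from transformers.utils import is_sagemaker_mp_enabled

# Sketch of the pattern from transformers.trainer (assumption): the flag and the
# smdistributed imports only exist when SageMaker model parallelism is enabled,
# so the patched function has to define the else-branch itself.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
else:
    IS_SAGEMAKER_MP_POST_1_10 = False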

Contributor
https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py#L1956-L1961
Wasn't unsloth already setting this to False? Or were we missing something so far?

Author
So actually yes, they do set it there. This is more about patching the trainer function itself: when the Hugging Face trainer wants to save at, say, 500 steps, it calls this function, and if we do not include the Amazon SageMaker checks here it errors out with dependency issues. I also wanted to keep most of the function's original features intact.
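
As an illustration of the patching approach being described, a hedged sketch of how such a replacement is typically wired up; the target class/method and names here are assumptions, not the exact assignment in unsloth/models/rl.py:

from transformers import Trainer

def unsloth_prediction_step(self, model, inputs, prediction_loss_only, ignore_keys):
    ...  # SageMaker-safe version, defined in full in the diff below

# Hypothetical illustration of the monkey-patch: the Hugging Face Trainer method
# is replaced so the periodic save/evaluation path (e.g. at save_steps=500) works
# even though smdistributed is not installed.
Trainer.prediction_step = unsloth_prediction_step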

Contributor
Ah cool. So no more force-disabling SageMaker after these changes, I presume.

Author
SageMaker is still disabled by default; this function just needs to check whether SageMaker is present so it can also bring in the other dependencies the function needs. I did not want to touch much of the internal transformers code beyond what Unsloth needed to patch to get things working.
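
To make the "bring the other dependencies with it" point concrete, a sketch of the conditional usage, assuming the same helper names transformers uses for SageMaker model parallelism (smp_forward_only, smp_nested_concat); treat the exact calls and the wrapper below as an assumption:

from transformers.utils import is_sagemaker_mp_enabled

if is_sagemaker_mp_enabled():
    # Only importable when smdistributed is installed, which is why the check has
    # to stay in the patched function even though SageMaker is off by default.
    from transformers.trainer_pt_utils import smp_forward_only, smp_nested_concat

def forward_for_eval(model, inputs):
    # Illustrative branch only: SageMaker model-parallel forward when enabled,
    # plain forward pass otherwise.
    if is_sagemaker_mp_enabled():
        raw_outputs = smp_forward_only(model, inputs)
        return smp_nested_concat(raw_outputs)
    return model(**inputs)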

from transformers.trainer_pt_utils import nested_detach
breakpoint()

#breakpoint()

def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_keys,):
"""
Perform an evaluation step on `model` using `inputs`.
@@ -214,50 +216,7 @@ def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_key
logits = logits[0]

return (loss, logits, labels)
# start_time defaults to None to allow compatibility with transformers<=4.46
def unsloth_maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
    if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
        logs: dict[str, float] = {}

        # all_gather + mean() to get average loss over all processes
        tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

        # reset tr_loss to zero
        tr_loss -= tr_loss

        logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
        if grad_norm is not None:
            logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
        logs["learning_rate"] = self._get_learning_rate()

        # Add our metrics
        for key, val in self.stats.items():
            logs[key] = sum(val) / len(val)
        self.stats = {key: [] for key in self.stats}  # reset stats

        self._total_loss_scalar += tr_loss_scalar
        self._globalstep_last_logged = self.state.global_step
        self.store_flos()

        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            self.log(logs, start_time)
        else:  # transformers<=4.46
            self.log(logs)

    # we need to edit
    # "/home/kt828/miniconda3/envs/trl_env/lib/python3.11/site-packages/transformers/trainer.py", line 4471 needs to be modified
    """
    metrics = None
    if self.control.should_evaluate:
        metrics = self._evaluate(trial, ignore_keys_for_eval)
        is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

        if self.args.save_strategy == "best":
            self.control.should_save = is_new_best_metric
    """
    if self.control.should_save:
        self._save_checkpoint(model, trial)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)

import trl.trainer
trainers = dir(trl.trainer)