Added global finetuning & validation loss early stopping & gemma support #50
Conversation
convert_to_hf.py
Outdated
try:
    import safetensors
except:
Let's narrow down the exception case with except ModuleNotFoundError:
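A minimal sketch of the suggested pattern (the `HAS_SAFETENSORS` flag is illustrative, not part of the PR): catching only `ModuleNotFoundError` means a genuinely missing package is handled gracefully, while unrelated errors raised during import are not silently swallowed by a bare `except`.

```python
# Catch only the missing-package case; any other ImportError or
# exception raised inside the module still propagates.
try:
    import safetensors
    HAS_SAFETENSORS = True
except ModuleNotFoundError:
    HAS_SAFETENSORS = False
```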
.view(outs_batch.shape[0], -1)
.mean(dim=1)
.sqrt()
(outs_batch - outs_tensor[j].to(device)).float().square().view(batch_size, -1).mean(dim=-1)
Minor: the new formula looks similar to the previous one, but subtly different: it divides by squared norm instead of variance/std (without subtracting mean).
This looks like a minor change, but let's double-check: is it intentional?
I think this makes more sense: it is a signal-to-noise ratio.
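An illustrative NumPy sketch (not the repo's exact code) contrasting the two per-sample normalizations discussed above: the "old" style divides the mean squared error by the variance of the reference activations, while the "new" style divides by their mean squared value, i.e. an inverse signal-to-noise ratio with no mean subtraction. The function name and shapes are assumptions for demonstration.

```python
import numpy as np

def relative_error_variants(outs, ref):
    """Per-sample relative error under the two normalizations."""
    err = ((outs - ref) ** 2).reshape(outs.shape[0], -1).mean(axis=1)
    ref_flat = ref.reshape(ref.shape[0], -1)
    by_variance = err / ref_flat.var(axis=1)            # "old"-style: variance-normalized
    by_sq_norm = err / (ref_flat ** 2).mean(axis=1)     # "new"-style: SNR-normalized
    return by_variance, by_sq_norm
```

For zero-mean activations the two coincide; they diverge when the reference has a large mean component, which is exactly the subtle difference flagged above.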
@@ -724,6 +766,12 @@ def update_outs_parallel(
    default=None,
    help="(finetuning only) Per-device and per-forward-pass batch size used to accumulate global --batch_size",
)
parser.add_argument(
Do we have a pair of apples-to-apples evaluations where validation early stopping improves the final model, as opposed to the naive (previous) early stopping? If yes, please attach links to wandb experiments.
There is no apples-to-apples comparison.
I observed that training is generally more robust to the choice of learning rate with this option.
The improvement is typically on the order of ~0.02 to 0.05 compared to our best runs.
In my opinion, this option is better than relative_mse_tolerance.
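A minimal sketch of validation-loss early stopping with patience, assuming hypothetical `step_fn`/`evaluate` callables and a `patience` parameter (none of these names are the PR's actual API): training stops once the held-out loss fails to improve for several consecutive evaluations, rather than when a relative-MSE threshold is hit.

```python
def train_with_early_stopping(step_fn, evaluate, max_steps, patience=3):
    """Run up to max_steps, stopping when validation loss stops improving."""
    best_loss, bad_evals = float("inf"), 0
    for step in range(max_steps):
        step_fn(step)              # one optimization step
        val_loss = evaluate()      # loss on held-out calibration data
        if val_loss < best_loss:
            best_loss, bad_evals = val_loss, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                break              # no improvement for `patience` evals
    return best_loss
```

Compared to a fixed tolerance, this keeps training as long as it helps and cuts it off once overfitting to the calibration set begins.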
This run
https://wandb.ai/rock-and-roll/PPL_LLAMA_2/runs/anhk6jv6?nw=nwuserspiridon_sun_rotator
wikitext2 ppl = 6.22
is better than what we had in the paper
https://wandb.ai/rock-and-roll/PPL_LLAMA_2/runs/whmdskj8?nw=nwuserspiridon_sun_rotator
wikitext2 ppl = 6.31
src/modelutils.py
Outdated
print("Loading quantized model ...")
model = load_quantized_model(model, load_quantized)
# TODO works only for Llama
If this is still true, please make it into an assert statement (e.g. assert config["model_type"] == "llama").
In this pull request, the following features are added: