
Added global finetuning & validation loss early stopping & Gemma support #50

Merged
3 commits merged into main from improved_finetuning on Mar 19, 2024

Conversation

Godofnothing (Collaborator)

In this pull request the following features are added:

  • full finetuning with teacher logits (see the sketch after this list)
  • the validation loss itself is tracked for early stopping, instead of the rate of change of the loss on the validation set
  • support for Gemma models
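
A minimal sketch of what finetuning against teacher logits typically looks like; the distillation_loss helper and the temperature parameter are illustrative, not taken from this repository:

import torch
import torch.nn.functional as F

def distillation_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
                      temperature: float = 1.0) -> torch.Tensor:
    # KL divergence between the teacher and student token distributions,
    # computed over the vocabulary dimension and averaged over the batch.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2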

convert_to_hf.py (Outdated)

try:
    import safetensors
except:
Collaborator

Let's narrow down the exception case with except ModuleNotFoundError:
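
A minimal sketch of the suggested narrowing; the has_safetensors flag is illustrative and not part of convert_to_hf.py:

try:
    import safetensors
    has_safetensors = True
except ModuleNotFoundError:
    has_safetensors = False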

Previous:
.view(outs_batch.shape[0], -1)
.mean(dim=1)
.sqrt()
New:
(outs_batch - outs_tensor[j].to(device)).float().square().view(batch_size, -1).mean(dim=-1)
Collaborator

Minor: the new formula looks similar to the previous one, but is subtly different: it divides by the squared norm instead of the variance/std (without subtracting the mean).
This looks like a minor change, but let's double-check: is this intentional?

Collaborator Author

I think this makes more sense: it is a signal-to-noise ratio.
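
A minimal sketch of the normalization being discussed, assuming the per-sample error is divided by the mean squared reference activations; the function and variable names are illustrative, not the PR's code:

import torch

def relative_output_error(outs_batch: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    # Per-sample mean squared error normalized by the mean squared reference
    # activations, i.e. an inverse signal-to-noise ratio rather than an
    # std-normalized error.
    batch_size = outs_batch.shape[0]
    noise = (outs_batch - reference).float().square().view(batch_size, -1).mean(dim=-1)
    signal = reference.float().square().view(batch_size, -1).mean(dim=-1)
    return noise / signal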

@@ -724,6 +766,12 @@ def update_outs_parallel(
default=None,
help="(finetuning only) Per-device and per-forward-pass batch size used to accumulate global --batch_size",
)
parser.add_argument(
Collaborator

Do we have a pair of apples-to-apples evaluations where validation early stopping improves the final model, as opposed to the naive (previous) early stopping? If yes, please attach links to the wandb experiments.

Collaborator Author

There is no apples-to-apples comparison.
I observed that with this option the training is generally more robust to the choice of learning rate.
The improvement is typically on the order of ~0.02-0.05 compared to our best runs.
In my opinion, this option is better than relative_mse_tolerance.
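
A minimal sketch of validation-loss early stopping as described above; the train_epoch/eval_loss callables and the patience parameter are placeholders, not the PR's actual interface:

from typing import Callable

def finetune_with_early_stopping(train_epoch: Callable[[int], None],
                                 eval_loss: Callable[[], float],
                                 max_epochs: int,
                                 patience: int = 1) -> float:
    # Track the validation loss itself and stop once it no longer improves,
    # instead of stopping on the rate of change of the loss.
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        train_epoch(epoch)
        loss = eval_loss()
        if loss < best_loss:
            best_loss, epochs_without_improvement = loss, 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_loss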


print("Loading quantized model ...")
model = load_quantized_model(model, load_quantized)
# TODO works only for Llama
Collaborator

If this is still true, please make it into an assert statement (e.g. assert config["model_type"] == "llama").
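
A minimal sketch of the suggested guard applied to the snippet above; the error message is illustrative:

assert config["model_type"] == "llama", "load_quantized_model currently supports only Llama models"
print("Loading quantized model ...")
model = load_quantized_model(model, load_quantized)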

@Godofnothing merged commit 45733ef into main on Mar 19, 2024
2 checks passed
@Vahe1994 deleted the improved_finetuning branch on May 29, 2024 at 11:12