Use GemmaAttention for Gemma #72
Conversation
Force-pushed from ab0c882 to 5db75d6
This way it produces more accurate results
convert_checkpoints.py (outdated diff excerpt)

    if new_key != key:
        state_dict[new_key] = state_dict.pop(key)
    output_ckpt_dir.mkdir(parents=True, exist_ok=True)
This shouldn't be needed. The output folder will be created in _export_to_local: https://github.com/google/jetstream-pytorch/blob/811d718c1f93e5ce37182e2c1ec54d3dc0b4aed7/convert_checkpoints.py#L355
done
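For illustration, a minimal sketch of why the extra mkdir is redundant, assuming the export helper creates its destination itself; the function body below is hypothetical, not the repository's actual _export_to_local:

    from pathlib import Path

    def _export_to_local(output_dir: Path, state_dict: dict) -> None:
        # Hypothetical sketch: the export helper creates its own destination folder,
        # so callers do not need a separate output_ckpt_dir.mkdir(...) beforehand.
        output_dir.mkdir(parents=True, exist_ok=True)
        # ... write the converted checkpoint files into output_dir ...

    # Caller side: no mkdir needed before the call.
    _export_to_local(Path("converted/gemma-2b"), {})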
convert_checkpoints.py (outdated diff excerpt)

    ]
    model_config = json.loads((input_ckpt_dir / "config.json").read_text())
    for key in list(state_dict.keys()):
        print(key)
remove?
done.
default_shardings/gemma.yaml (outdated diff excerpt)

    freqs_cis : -1 # torch.complex64 (16384, 128)
    freqs_cis : null # torch.complex64 (16384, 128)
    layers.*.self_attn.qkv_proj.weight: 0
Is this one only for test purposes? In the Gemma model, I saw the code directly reads wq, wk, and wv.
done.
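For context on the qkv_proj entry, here is a hedged sketch of how a fused Gemma checkpoint weight could be split into the separate wq, wk, and wv tensors that the model code reads; the helper name and the (n_heads + 2 * n_kv_heads) * head_dim layout are assumptions about the fused format, not code from this repository:

    import torch

    def split_fused_qkv(qkv_weight: torch.Tensor, n_heads: int, n_kv_heads: int, head_dim: int):
        # Assumed fused layout: q, k and v stacked along dim 0,
        # shape = ((n_heads + 2 * n_kv_heads) * head_dim, hidden_size).
        q_size = n_heads * head_dim
        kv_size = n_kv_heads * head_dim
        wq, wk, wv = torch.split(qkv_weight, [q_size, kv_size, kv_size], dim=0)
        return wq, wk, wv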
Diff excerpt:

    self.env.apply_sharding(output, axis=2)
    return self.wo(output)
    output = self.attention_kernel(xq, xk, xv, mask, cache)
    output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
Nice, the code is cleaner with the attention kernel refactored out.
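Based only on the lines in the excerpt above, the refactored flow roughly looks like the sketch below; the class name, method signature, and the rotary-embedding step are assumptions, not the repository's exact code:

    import torch.nn as nn

    class SketchAttention(nn.Module):  # illustrative name only
        def forward(self, x, mask, cache):
            bsz, seqlen, _ = x.shape
            xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  # per-model projections
            # ... reshape into heads and apply rotary embeddings ...
            output = self.attention_kernel(xq, xk, xv, mask, cache)  # shared attention math
            output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
            return self.wo(output)  # final output projection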
Diff excerpt:

    return x_out


    class GemmaAttention(nn.Module):
In the long term, we might need to extend from the Attention class; a large percentage of the code is similar.
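A hedged sketch of that longer-term direction, in which GemmaAttention inherits the shared plumbing from Attention and overrides only the Gemma-specific pieces; the class layout and attribute names below are assumptions for illustration:

    import torch.nn as nn

    class Attention(nn.Module):
        """Shared plumbing: kv-cache handling, attention_kernel, sharding hooks."""
        def attention_kernel(self, xq, xk, xv, mask, cache):
            ...  # common scaled-dot-product / cache-update path

    class GemmaAttention(Attention):
        """Overrides only what Gemma does differently, e.g. the fused qkv projection."""
        def __init__(self, hidden_size, n_heads, n_kv_heads, head_dim):
            super().__init__()
            self.qkv_proj = nn.Linear(hidden_size, (n_heads + 2 * n_kv_heads) * head_dim, bias=False)
            self.o_proj = nn.Linear(n_heads * head_dim, hidden_size, bias=False)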
This way it produces more accurate results (with EOS)
{'rouge1': 36.9881, 'rouge2': 13.3464, 'rougeL': 21.7437, 'rougeLsum': 35.1489, 'gen_len': 1295948, 'gen_num': 1000}