
Conversation

@qihqi (Collaborator) commented May 9, 2024

This way it produces more accurate results (with EOS):

{'rouge1': 36.9881, 'rouge2': 13.3464, 'rougeL': 21.7437, 'rougeLsum': 35.1489, 'gen_len': 1295948, 'gen_num': 1000}

@qihqi force-pushed the hanq_add_model branch 2 times, most recently from ab0c882 to 5db75d6 on May 9, 2024 16:19
  This way it produces more accurate results
@qihqi force-pushed the hanq_add_model branch from 5db75d6 to 3fcd49d on May 9, 2024 16:20
@qihqi requested review from FanhaiLu1 and lsy323 on May 9, 2024 16:21

if new_key != key:
  state_dict[new_key] = state_dict.pop(key)
output_ckpt_dir.mkdir(parents=True, exist_ok=True)
Collaborator:
This shouldn't be needed. The output folder is already created in _export_to_local: https://github.com/google/jetstream-pytorch/blob/811d718c1f93e5ce37182e2c1ec54d3dc0b4aed7/convert_checkpoints.py#L355

Collaborator (author):
done
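
For context on the resolution above, here is a minimal sketch of the key-renaming pass once the redundant mkdir is dropped, leaving directory creation to _export_to_local. This is not the PR's exact code; rename_state_dict_keys and key_map are illustrative names.

# Sketch only: rename checkpoint keys in place; the output directory is
# created later by _export_to_local, so no mkdir is needed here.
def rename_state_dict_keys(state_dict, key_map):
  for key in list(state_dict.keys()):
    new_key = key_map.get(key, key)
    if new_key != key:
      state_dict[new_key] = state_dict.pop(key)
  return state_dict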

]
model_config = json.loads((input_ckpt_dir / "config.json").read_text())
for key in list(state_dict.keys()):
  print(key)
Collaborator:
remove?

Collaborator (author):
done.


freqs_cis : -1 # torch.complex64 (16384, 128)
freqs_cis : null # torch.complex64 (16384, 128)
layers.*.self_attn.qkv_proj.weight: 0
Collaborator:
Is this one only for test purposes? In the Gemma model, I saw the code read wq, wk, and wv directly.

Collaborator (author):
done.
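
For context on the question above, here is a hedged sketch of how separate wq/wk/wv checkpoint weights could be fused into the single qkv_proj tensor that the sharding entry (axis 0) refers to. fuse_qkv and the exact key format are assumptions, not the PR's actual conversion code.

import torch

# Illustrative only: concatenate the q, k, and v projection weights along the
# output dimension to form a fused qkv_proj weight, so one sharding rule
# (axis 0) covers all three projections.
def fuse_qkv(state_dict, layer_prefix):
  wq = state_dict.pop(f"{layer_prefix}.wq.weight")
  wk = state_dict.pop(f"{layer_prefix}.wk.weight")
  wv = state_dict.pop(f"{layer_prefix}.wv.weight")
  state_dict[f"{layer_prefix}.qkv_proj.weight"] = torch.cat([wq, wk, wv], dim=0)
  return state_dict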

self.env.apply_sharding(output, axis=2)
return self.wo(output)
output = self.attention_kernel(xq, xk, xv, mask, cache)
output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
Collaborator:
Nice, the code is cleaner with the attention kernel refactored out.
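
To make the refactor easier to follow, here is a self-contained sketch of the call pattern being praised: the forward pass delegates to a shared attention kernel and then only reshapes and applies the output projection. The class, shapes, and the plain scaled-dot-product kernel below are illustrative; the repository's kernel also takes the cache and applies env.apply_sharding, which are omitted here.

import torch
import torch.nn as nn

class AttentionTail(nn.Module):
  """Illustrative sketch of the refactored call pattern, not the actual class."""

  def __init__(self, num_heads, head_dim):
    super().__init__()
    hidden_size = num_heads * head_dim
    self.wo = nn.Linear(hidden_size, hidden_size, bias=False)

  def attention_kernel(self, xq, xk, xv, mask):
    # Plain scaled dot-product attention as a stand-in for the shared kernel.
    scores = torch.einsum("bhqd,bhkd->bhqk", xq, xk) / xq.shape[-1] ** 0.5
    if mask is not None:
      scores = scores + mask
    return torch.einsum("bhqk,bhkd->bhqd", torch.softmax(scores, dim=-1), xv)

  def forward(self, xq, xk, xv, mask):
    bsz, _, seqlen, _ = xq.shape
    output = self.attention_kernel(xq, xk, xv, mask)
    # (bsz, heads, seqlen, head_dim) -> (bsz, seqlen, heads * head_dim)
    output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
    return self.wo(output)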

return x_out


class GemmaAttention(nn.Module):
Collaborator:
In the long term, we might want this to extend the Attention class; a large percentage of the code is similar.
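
As a rough illustration of that suggestion (only GemmaAttention appears in the diff; the base Attention stand-in and everything inside it below are assumptions), the Gemma module could inherit the shared logic and override only what differs:

import torch.nn as nn

class Attention(nn.Module):
  """Simplified stand-in for the existing shared attention module."""
  def __init__(self, num_heads, head_dim):
    super().__init__()
    hidden_size = num_heads * head_dim
    self.wo = nn.Linear(hidden_size, hidden_size, bias=False)
    # ... shared projections, KV-cache handling, attention_kernel, etc.

class GemmaAttention(Attention):
  """Sketch: reuse the base class, overriding only Gemma-specific pieces
  (e.g. how fused qkv_proj checkpoint weights map onto the module)."""
  def __init__(self, num_heads, head_dim):
    super().__init__(num_heads, head_dim)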

@qihqi force-pushed the hanq_add_model branch from 97a44d5 to a9fe13e on May 9, 2024 18:09
@qihqi merged commit 57eb0e1 into main on May 9, 2024
@qihqi deleted the hanq_add_model branch on May 9, 2024 18:24