From 036e251a7cdaf1f3b0601cd5d2f93d164aba75c7 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Fri, 19 Jul 2024 02:33:51 +0000
Subject: [PATCH 1/2] set accumulate type to bf16

---
 jetstream_pt/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index 3ecd875a..7554509a 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -148,7 +148,7 @@ def forward(self, inputs):
           self.weight,
           (((2,), (1)), ((), ())),
           None,
-          jnp.int32.dtype,
+          jnp.bfloat16.dtype,
       )
     result = result * self.weight_scaler
     if self.quantize_activation:

From df9fd382a662bf7533e74efd616917a3e69d44a3 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Fri, 19 Jul 2024 03:39:37 +0000
Subject: [PATCH 2/2] fix comment

---
 jetstream_pt/layers.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index 7554509a..20deb4da 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -139,7 +139,8 @@ def forward(self, inputs):
     if not self.quantize_activation:
       result = F.linear(inputs, self.weight)
     else:
-      # We have to call jax because we need to do dot(int8, int8)->int32.
+      # We have to call jax because we need to specify the output dtype of dot
+      # dot(int8, int8)->bf16.
       # This semantic cannot be represented in torch. The inferred output dtype
       # will be int8 in torch, causing the dot result to overflow.
       result = torchjax.call_jax(
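
Note (not part of the patch): a minimal standalone sketch of the behavior the change relies on, using jax.lax.dot_general directly rather than through torchjax.call_jax. The shapes and values below are made up for illustration; the patch itself only swaps the preferred element type of the dot from int32 to bfloat16.

    # Illustrative sketch, assuming activations of shape (batch, seq, in_features)
    # and a per-channel-quantized weight of shape (out_features, in_features).
    import jax
    import jax.numpy as jnp

    lhs = jnp.ones((2, 4, 128), dtype=jnp.int8)   # quantized activations
    rhs = jnp.ones((256, 128), dtype=jnp.int8)    # quantized weight

    # Same dimension_numbers as in the patch: contract lhs dim 2 with rhs dim 1,
    # no batch dimensions.
    dim_nums = (((2,), (1,)), ((), ()))

    # preferred_element_type requests the accumulation/output dtype of the dot,
    # so the int8 x int8 products do not overflow an int8 result. A plain torch
    # int8 matmul cannot express this, which is why the layer calls into jax.
    result = jax.lax.dot_general(
        lhs, rhs, dim_nums, precision=None, preferred_element_type=jnp.bfloat16
    )
    print(result.shape, result.dtype)  # (2, 4, 256) bfloat16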