From e22b8d0efbfd0a15194446170b193ec450443666 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Mon, 26 Jul 2021 19:38:39 +0100 Subject: [PATCH] Set the dtype correctly for vision GPT model (#694) * Set the dtype correctly * Add changelog --- CHANGELOG.md | 3 +++ pl_bolts/models/vision/image_gpt/gpt2.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a247597da4..749bef4180 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed momentum updating from val step and add separate val queue ([#631](https://github.com/PyTorchLightning/lightning-bolts/pull/631)) +- Fixed FP16 support with vision GPT model ([#694](https://github.com/PyTorchLightning/lightning-bolts/pull/694)) + + ## [0.3.4] - 2021-06-17 ### Changed diff --git a/pl_bolts/models/vision/image_gpt/gpt2.py b/pl_bolts/models/vision/image_gpt/gpt2.py index 9588a82920..37999d1af2 100644 --- a/pl_bolts/models/vision/image_gpt/gpt2.py +++ b/pl_bolts/models/vision/image_gpt/gpt2.py @@ -94,7 +94,7 @@ def forward(self, x, classify=False): h = self.token_embeddings(x.long()) # prepend sos token - sos = torch.ones(1, batch, self.hparams.embed_dim, device=x.device) * self.sos + sos = torch.ones(1, batch, self.hparams.embed_dim, device=x.device, dtype=x.dtype) * self.sos h = torch.cat([sos, h[:-1, :, :]], axis=0) # add positional embeddings