From 96af85d30a9ac0f79fac12ea5e141852f1d6e5f1 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 12 Mar 2024 16:21:25 +0000 Subject: [PATCH 1/2] Addition of non_quantized flag to enable non-quantised encoding for LDM training on the VQVAE. --- generative/networks/nets/vqvae.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 74173067..dd93541a 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -442,8 +442,10 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return reconstruction, quantization_losses - def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + def encode_stage_2_inputs(self, x: torch.Tensor, non_quantized: bool = False) -> torch.Tensor: z = self.encode(x) + if non_quantized: + return z e, _ = self.quantize(z) return e From 6b132e12c01dbe94f9c30637dbca56ba7cfeba73 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 12 Mar 2024 16:26:24 +0000 Subject: [PATCH 2/2] non_quantized > quantized Default true. --- generative/networks/nets/vqvae.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index dd93541a..524184c1 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -442,12 +442,12 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return reconstruction, quantization_losses - def encode_stage_2_inputs(self, x: torch.Tensor, non_quantized: bool = False) -> torch.Tensor: + def encode_stage_2_inputs(self, x: torch.Tensor, quantized: bool = True) -> torch.Tensor: z = self.encode(x) - if non_quantized: - return z e, _ = self.quantize(z) - return e + if quantized: + return e + return z def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: e, _ = self.quantize(z)