diff --git a/generative/networks/nets/vqvae.py b/generative/networks/nets/vqvae.py index 74173067..524184c1 100644 --- a/generative/networks/nets/vqvae.py +++ b/generative/networks/nets/vqvae.py @@ -442,10 +442,12 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return reconstruction, quantization_losses - def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor: + def encode_stage_2_inputs(self, x: torch.Tensor, quantized: bool = True) -> torch.Tensor: z = self.encode(x) e, _ = self.quantize(z) - return e + if quantized: + return e + return z def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor: e, _ = self.quantize(z)