diff --git a/ptp/components/models/vqa/attention.py b/ptp/components/models/vqa/attention.py index 196b26e..e9d3af7 100644 --- a/ptp/components/models/vqa/attention.py +++ b/ptp/components/models/vqa/attention.py @@ -151,7 +151,7 @@ def forward(self, data_dict): outputs = attention_enc_img elif(self.output_mode == 'Fusion'): # Fusion -- Concatenate attention-weighted image encodings and question encodings. - outputs = torch.cat([attention_enc_img, latent_q], dim=1) + outputs = torch.cat([attention_enc_img, enc_q], dim=1) # print("outputs", outputs.shape) # Add predictions to datadict. data_dict.extend({self.key_outputs: outputs})