Skip to content

Commit

Permalink
fix bug in mean() to compile with pytorch version 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
OctoberChang committed Oct 3, 2017
1 parent f2cda1f commit b15c988
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mmd_gan.py
Expand Up @@ -63,8 +63,8 @@ def __init__(self):

def forward(self, input):
output = self.main(-input)
output = -output.mean(1)
return output.view(1)
output = -output.mean()
return output


# Get argument
Expand Down Expand Up @@ -189,6 +189,8 @@ def forward(self, input):
mmd2_D = F.relu(mmd2_D)

# compute rank hinge loss
#print('f_enc_X_D:', f_enc_X_D.size())
#print('f_enc_Y_D:', f_enc_Y_D.size())
one_side_errD = one_sided(f_enc_X_D.mean(0) - f_enc_Y_D.mean(0))

# compute L2-loss of AE
Expand Down

0 comments on commit b15c988

Please sign in to comment.