Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hwang1996 committed Mar 26, 2019
1 parent bcd37d7 commit c912fc1
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions train.py
Expand Up @@ -42,7 +42,6 @@
text_discriminator = torch.nn.DataParallel(text_emb_discriminator().cuda(), device_ids=device)
netsD = torch.nn.DataParallel(D_NET128().cuda(), device_ids=device)


## load loss functions
triplet_loss = TripletLoss(device, margin=0.3)
img2text_criterion = nn.MultiLabelMarginLoss().cuda()
Expand All @@ -52,8 +51,6 @@
class_criterion = nn.CrossEntropyLoss(weight=weights_class).cuda()

GAN_criterion = nn.BCELoss().cuda()
cosine_crit = nn.CosineEmbeddingLoss(0.1).cuda()


nz = opts.Z_DIM
noise = Variable(torch.FloatTensor(opts.batch_size, nz)).cuda()
Expand Down Expand Up @@ -468,4 +465,4 @@ def update(self, val, n=1):
self.avg = self.sum / self.count

if __name__ == '__main__':
main()
main()

0 comments on commit c912fc1

Please sign in to comment.