Skip to content

Commit

Permalink
🏎 DeepSpeed Optimizer indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
Zasder3 committed Jun 12, 2021
1 parent baebc38 commit e18b7ca
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions models/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def training_step(self, train_batch, idx):
acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum()
self.log_dict({'loss': loss, 'acc': (acc_i + acc_t) / 2 / len(image)}, prog_bar=True)

if isinstance(optimizer, list):
optimizer = optimizer[0]
optimizer.zero_grad()

# image loss
Expand Down Expand Up @@ -207,6 +209,8 @@ def training_step(self, train_batch, idx):
loss += (F.kl_div(image_logits_notemp * self.sink_temp, img_target) + F.kl_div(image_logits_notemp.t() * self.sink_temp, txt_target)) / 2 * self.kl_coeff
self.log_dict({'loss': loss, 'acc': (acc_i + acc_t) / 2 / len(image)}, prog_bar=True)

if isinstance(optimizer, list):
optimizer = optimizer[0]
optimizer.zero_grad()

# image loss
Expand Down

0 comments on commit e18b7ca

Please sign in to comment.