Skip to content

Commit

Permalink
fix kd_loss in general distiller
Browse files Browse the repository at this point in the history
  • Loading branch information
airaria committed Dec 17, 2020
1 parent 290fe52 commit c7ef92e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/textbrewer/distiller_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ def compute_loss(self,results_S,results_T):
temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
else:
temperature = self.d_config.temperature
total_kd_loss += self.kd_loss(l_S, l_T, temperature)
total_kd_loss += self.kd_loss(l_S, l_T, temperature)
else:
for l_T,l_S in zip(logits_list_T,logits_list_S):
if self.d_config.temperature_scheduler is not None:
temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
else:
temperature = self.d_config.temperature
total_kd_loss = self.kd_loss(l_S, l_T, temperature)
total_kd_loss += self.kd_loss(l_S, l_T, temperature)
total_loss += total_kd_loss * self.d_config.kd_loss_weight
losses_dict['unweighted_kd_loss'] = total_kd_loss

Expand Down

0 comments on commit c7ef92e

Please sign in to comment.