Skip to content

Commit

Permalink
fix mask dtype compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
airaria committed Mar 21, 2020
1 parent 41a1389 commit dd1ecbb
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
6 changes: 6 additions & 0 deletions src/textbrewer/compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch

# Pick the dtype used for boolean masks handed to torch.masked_select.
# PyTorch releases before 1.2 get torch.uint8; 1.2 and later get torch.bool.
#
# The version is compared numerically on (major, minor). Comparing plain
# strings is lexicographic, so '1.10.0' < '1.2' would evaluate to True and
# wrongly select the legacy uint8 dtype.
try:
    # Only the first two dot-separated fields are needed; this also drops
    # suffixes such as '.0+cu113' or '.0a0+git...'.
    _torch_version = tuple(
        int(part) for part in str(torch.__version__).split('.')[:2]
    )
except ValueError:
    # Unparsable version string: assume a modern build rather than
    # crashing at import time.
    _torch_version = (1, 2)

if _torch_version < (1, 2):
    mask_dtype = torch.uint8
else:
    mask_dtype = torch.bool
6 changes: 4 additions & 2 deletions src/textbrewer/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from .presets import *
from .configurations import TrainingConfig, DistillationConfig

from .compatibility import mask_dtype

logger = logging.getLogger("Distillation")
logger.setLevel(logging.INFO)

Expand Down Expand Up @@ -796,7 +798,7 @@ def _select_logits_with_mask(logits_list, masks_list):
if len(masks_list)==len(logits_list):
for logits,mask in zip(logits_list,masks_list):
if len(logits.shape)==3:
mask = mask.unsqueeze(-1).expand_as(logits).to(torch.uint8)
mask = mask.unsqueeze(-1).expand_as(logits).to(mask_dtype)
logits_select = torch.masked_select(logits,mask).view(-1,logits.size(-1))
else:
logits_select = logits #Logits_mask has no effect on logits of shape (batch_size, logits_to_be_softmaxed)
Expand All @@ -805,7 +807,7 @@ def _select_logits_with_mask(logits_list, masks_list):
mask = masks_list[0]
for logits in logits_list:
if len(logits.shape)==3:
mask = mask.unsqueeze(-1).expand_as(logits).to(torch.uint8)
mask = mask.unsqueeze(-1).expand_as(logits).to(mask_dtype)
logits_select = torch.masked_select(logits,mask).view(-1,logits.size(-1))
else:
logits_select = logits #Logits_mask has no effect on logits of shape (batch_size, logits_to_be_softmaxed)
Expand Down
4 changes: 3 additions & 1 deletion src/textbrewer/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch
from typing import List

from .compatibility import mask_dtype

def kd_mse_loss(logits_S, logits_T, temperature=1):
'''
Calculate the mse loss between logits_S and logits_T
Expand Down Expand Up @@ -155,7 +157,7 @@ def cos_loss(state_S, state_T, mask=None):
state_S = state_S.view(-1,state_S.size(-1))
state_T = state_T.view(-1,state_T.size(-1))
else:
mask = mask.to(state_S).unsqueeze(-1).expand_as(state_S).to(torch.uint8) #(bs,len,dim)
mask = mask.to(state_S).unsqueeze(-1).expand_as(state_S).to(mask_dtype) #(bs,len,dim)
state_S = torch.masked_select(state_S, mask).view(-1, mask.size(-1)) #(bs * select, dim)
state_T = torch.masked_select(state_T, mask).view(-1, mask.size(-1)) # (bs * select, dim)

Expand Down

0 comments on commit dd1ecbb

Please sign in to comment.