In [1]:
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import torch as t
import torch.nn as nn
from torch.autograd import Variable as V
import torch.nn.functional as F

In [5]:
# Generate a mask from sequence_length and max_len: 1 ---> real step, 0 ---> padding
def sequence_mask(sequence_length, max_len=None):
    """Build a padding mask from per-sequence lengths.

    Args:
        sequence_length: A Variable/tensor of integer lengths with size (batch,).
        max_len: Number of columns in the mask. Defaults to the largest value
                 found in ``sequence_length``.

    Returns:
        A (batch, max_len) mask whose entry [i, j] is 1 (true) when
        j < sequence_length[i] and 0 (false) otherwise, i.e. for padded steps.
    """
    if max_len is None:
        max_len = sequence_length.data.max()
    # ``max()`` may return a 0-dim tensor instead of a Python number depending
    # on the PyTorch version; normalise so it can safely be used as a size.
    max_len = int(max_len)
    batch_size = sequence_length.size(0)

    # positions: [0, 1, ..., max_len - 1]
    seq_range = t.arange(0, max_len).long()
    if sequence_length.is_cuda:
        # Place the range on the *same* GPU as the lengths. A bare ``.cuda()``
        # targets the current device, which may differ on multi-GPU setups and
        # would make the comparison below fail.
        seq_range = seq_range.cuda(sequence_length.get_device())

    # seq_range_expand: (batch, max_len), every row is the position index
    seq_range_expand = V(seq_range.unsqueeze(0).expand(batch_size, max_len))
    # seq_length_expand: (batch, max_len), every column repeats the row's length
    seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
    return seq_range_expand < seq_length_expand

In [6]:
# Demo batch: nine sequences whose lengths count down from 9 to 1.
seqs_length = V(t.LongTensor(list(range(9, 0, -1))))
print("sequence length: ", seqs_length)

# Build the padding mask; with no max_len given, the longest length is used.
mask = sequence_mask(seqs_length)
print("mask: ", mask)

sequence length:  Variable containing:
 9
 8
 7
 6
 5
 4
 3
 2
 1
[torch.LongTensor of size 9]

mask:  Variable containing:
    1     1     1     1     1     1     1     1     1
    1     1     1     1     1     1     1     1     0
    1     1     1     1     1     1     1     0     0
    1     1     1     1     1     1     0     0     0
    1     1     1     1     1     0     0     0     0
    1     1     1     1     0     0     0     0     0
    1     1     1     0     0     0     0     0     0
    1     1     0     0     0     0     0     0     0
    1     0     0     0     0     0     0     0     0
[torch.ByteTensor of size 9x9]



In [7]:
def masked_cross_entropy(logits, target, length):
    """
    Cross-entropy loss averaged over the non-padded steps of a batch.

    Args:
        logits: A Variable containing a FloatTensor of size (batch, max_len, num_classes)
                which contains the unnormalized scores for each class.

        target: A Variable containing a LongTensor of size (batch, max_len) which contains
                the index of the true class for each corresponding step.

        length: The length of each sequence in the batch, size (batch,). Either a plain
                sequence of ints (e.g. a list) or a Variable/tensor of integers.

    Returns:
        loss: The sum of the per-step losses over non-padded positions divided by
              the sum of ``length``.

    """
    # Wrap plain Python sequences; pass existing tensors/Variables through
    # untouched so both calling conventions work.
    if not (t.is_tensor(length) or isinstance(length, V)):
        length = V(t.LongTensor(length))
    length = length.long()
    # The mask derived from ``length`` is multiplied with the losses, so the
    # lengths have to live on the same device as the logits.
    if logits.is_cuda:
        length = length.cuda(logits.get_device())
    elif length.is_cuda:
        length = length.cpu()

    # logits_flat: (batch * max_len, num_classes)
    # ``contiguous()`` guards against inputs produced by transpose/slicing.
    logits_flat = logits.contiguous().view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    # Normalise over the class axis explicitly; the implicit-dim form is deprecated.
    log_probs_flat = F.log_softmax(logits_flat, dim=1)

    # target_flat: (batch * max_len, 1)
    target_flat = target.contiguous().view(-1, 1)

    # losses_flat: (batch * max_len, 1) -- negative log-likelihood of the true class
    losses_flat = -t.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())

    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))

    # Zero out the contribution of padded steps before averaging.
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss