In [None]:
!pip install nemo-toolkit['all']

In [2]:
import torch
import torch.nn as nn

In [3]:
from nemo.core.classes import Serialization, Typing, typecheck
from nemo.core.neural_types import LossType, NeuralType, LogprobsType, IntType

[NeMo W 2022-06-25 10:44:27 optimizers:55] Apex was not found. Using the lamb or fused_adam optimizer will error out.


In [4]:
# Public API of this module. The name must be the dunder `__all__` for Python
# to honour it on `from <module> import *`; the previous `_all_` was ignored.
__all__ = ['RBKDLoss']

In [5]:
class RBKDLoss(Serialization, Typing, nn.CTCLoss):
    """Temperature-softened distillation loss between a teacher and a student.

    RBKD presumably stands for "response-based knowledge distillation"
    (TODO confirm naming with the author). For every batch element the loss is
    the cross-entropy between the teacher's and the student's temperature-
    softened class distributions, summed over all time steps and classes:

        loss_b = - sum_t sum_d  p_teacher[b, t, d] * log p_student[b, t, d]

    where ``p = softmax(x / temperature)`` along the class axis.

    The class keeps ``nn.CTCLoss`` as a base (so ``blank``, ``reduction`` and
    ``zero_infinity`` are stored by the parent constructor), but ``forward`` is
    fully overridden and does not run the CTC computation.

    Args:
        num_classes: Passed to ``nn.CTCLoss`` as the ``blank`` index.
        zero_infinity: Forwarded to ``nn.CTCLoss``; not used by ``forward``.
        reduction: One of ``'mean_batch'`` (default, mean of the per-sample
            losses), ``'mean'`` (same batch mean), ``'sum'`` (sum over the
            batch) or ``'none'`` (per-sample losses of shape ``[B]``).
        temperature: Positive softening temperature; larger values flatten
            both distributions.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values or
            ``temperature`` is not strictly positive.
    """

    @property
    def input_types(self):
        """Input types definitions for RBKDLoss.

        Both inputs are ``[B, T, D]`` tensors typed as log-probabilities.
        """
        return {
            "teacher_logits": NeuralType(('B', 'T', 'D'), LogprobsType()),
            "log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()),
        }

    @property
    def output_types(self):
        """Output types definitions for RBKDLoss.

        loss:
            NeuralType(None)
        """
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(self, num_classes, zero_infinity=False, reduction='mean_batch', temperature=3):
        # Validate everything up front so a bad argument fails with a clear
        # message instead of an UnboundLocalError / ZeroDivisionError later.
        if reduction == 'mean_batch':
            # 'mean_batch' is not a PyTorch reduction: keep per-sample values
            # in the parent and average over the batch ourselves in forward().
            ctc_reduction = 'none'
            apply_batch_mean = True
        elif reduction in ('sum', 'mean', 'none'):
            ctc_reduction = reduction
            apply_batch_mean = False
        else:
            raise ValueError(
                f"Unsupported reduction {reduction!r}; expected one of "
                "'mean_batch', 'sum', 'mean', 'none'."
            )
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}.")

        # Don't forget to properly call base constructor
        super().__init__(blank=num_classes, reduction=ctc_reduction, zero_infinity=zero_infinity)
        self._blank = num_classes
        self._apply_batch_mean = apply_batch_mean
        self.temperature = temperature

    # NOTE: `typecheck` is a decorator factory and must be *called*; a bare
    # `@typecheck` would replace forward with a typecheck instance. With type
    # checking enabled NeMo expects the inputs to be passed as keyword arguments.
    @typecheck()
    def forward(self, teacher_logits, log_probs):
        """Compute the distillation loss.

        Args:
            teacher_logits: Teacher outputs, ``[B, T, D]``. Typed as
                log-probabilities; raw logits work as well because softmax is
                invariant to a per-row constant shift.
            log_probs: Student log-probabilities, ``[B, T, D]``.

        Returns:
            A scalar tensor for ``'mean_batch'``, ``'mean'`` and ``'sum'``;
            a ``[B]`` tensor of per-sample losses for ``'none'``.
        """
        # p_i^(1/T) / sum_j p_j^(1/T) with p = exp(log_p) is exactly
        # softmax(log_p / T). Raising the (non-positive) log-probabilities
        # themselves to a fractional power would produce NaNs.
        p_teacher = torch.softmax(teacher_logits / self.temperature, dim=-1)
        # log_softmax is the numerically stable form of log(softmax(...)).
        log_p_student = torch.log_softmax(log_probs / self.temperature, dim=-1)

        # Cross-entropy per time step (reduce D), then total per sample
        # (reduce T) -> shape [B]. No transpose is needed since we reduce the
        # axes by position on the original [B, T, D] layout.
        rbkd_loss = -(p_teacher * log_p_student).sum(dim=2).sum(dim=1)

        # `self.reduction` is stored by the parent constructor.
        if self._apply_batch_mean or self.reduction == 'mean':
            return torch.mean(rbkd_loss)
        if self.reduction == 'sum':
            return torch.sum(rbkd_loss)
        return rbkd_loss
