In [None]:
# Install the third-party packages this notebook depends on (IPython shell escapes).
# hydra-core is pinned to the 1.1 release line.
!pip install nemo_toolkit['all']
!pip install hydra-core==1.1
!pip install import-ipynb

In [None]:
import torch
import torch.nn as nn
import sys

from nemo.core.classes import Serialization, Typing, typecheck
from nemo.core.neural_types import LossType, NeuralType, LogprobsType, IntType

In [3]:
# Replace the process's command-line arguments with a single empty entry.
# NOTE(review): presumably done so argument parsers used downstream (e.g. Hydra)
# do not see the Jupyter kernel's own flags -- confirm.
sys.argv = ['']
# Remove the `sys` name from the notebook namespace; later cells cannot use it
# without importing it again.
del sys

In [4]:
# Public API of this notebook when it is imported as a module. The name must be
# spelled `__all__` (double underscores) for `from ... import *` to honour it.
__all__ = ['EBKDLoss']

In [5]:
class EBKDLoss(Serialization, Typing, nn.CTCLoss):
    """EBKD distillation loss between a teacher and a student network.

    For each network an attention map is built as
    ``ReLU(d log P_greedy / d feature_map * feature_map)``, where ``P_greedy``
    is the probability of the greedy (arg-max per frame) prediction.  Both
    attention maps are L2-normalised per sample and the loss is the L2 distance
    between them divided by the number of time steps.

    Args:
        num_classes: Index of the blank label, forwarded to ``nn.CTCLoss``.
        zero_infinity: Forwarded to ``nn.CTCLoss``.
        reduction: One of ``'mean_batch'`` (default, average over the batch),
            ``'mean'`` (average over the batch), ``'sum'`` (sum over the batch)
            or ``'none'`` (per-sample loss of shape ``[B]``).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    @property
    def input_types(self):
        """Input types definitions for EBKDLoss.
        """
        # NOTE(review): the feature maps are declared as LogprobsType although
        # they are intermediate activations; a representation type would
        # describe them better -- confirm before changing the declared types.
        return {
            "teacher_logits": NeuralType(('B', 'T', 'D'), LogprobsType()),
            "log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()),
            "teacher_feature_map": NeuralType(('B', 'T', 'D'), LogprobsType()),
            "student_feature_map": NeuralType(('B', 'T', 'D'), LogprobsType())
        }

    @property
    def output_types(self):
        """Output types definitions for EBKDLoss.
        loss:
            NeuralType(None)
        """
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(self, num_classes, zero_infinity=False, reduction='mean_batch'):
        # 'mean_batch' is not known to nn.CTCLoss, so the base class is
        # configured with 'none' and the batch mean is applied in forward().
        if reduction == 'mean_batch':
            ctc_reduction = 'none'
            apply_batch_mean = True
        elif reduction in ('sum', 'mean', 'none'):
            ctc_reduction = reduction
            apply_batch_mean = False
        else:
            raise ValueError(
                "`reduction` must be one of ['mean_batch', 'sum', 'mean', 'none'], "
                f"got {reduction!r}"
            )
        # Initialise the nn.Module machinery before attaching our own state.
        super().__init__(blank=num_classes, reduction=ctc_reduction, zero_infinity=zero_infinity)
        self._blank = num_classes
        self._apply_batch_mean = apply_batch_mean

    @staticmethod
    def _attention_map(log_probs, feature_map):
        """Return the gradient-weighted attention map for one network.

        Args:
            log_probs: ``[B, T, D]`` log-probabilities computed from
                ``feature_map``.
            feature_map: ``[B, T, D]`` activation tensor that is part of the
                autograd graph producing ``log_probs``.

        Returns:
            ``[B, T, D]`` tensor ``ReLU(importance * feature_map)``.

        Raises:
            RuntimeError: Raised by ``torch.autograd.grad`` when ``feature_map``
                does not require grad or was not used to compute ``log_probs``.
        """
        # log of the greedy prediction probability, per sample ([B]):
        # log prod_t max_d p = sum_t max_d log p.  Staying in the log domain
        # avoids the underflow of multiplying T probabilities.
        log_greedy_prob = log_probs.max(dim=2).values.sum(dim=1)
        # autograd.grad needs a scalar output; summing over the batch yields
        # the per-sample gradients as long as samples do not interact inside
        # the network.  The graph is retained because the caller still has to
        # back-propagate through it.
        (importance_map,) = torch.autograd.grad(
            log_greedy_prob.sum(), feature_map, retain_graph=True
        )
        return torch.relu(importance_map * feature_map)

    @typecheck()
    def forward(self, teacher_logits, log_probs, teacher_feature_map, student_feature_map):
        """Compute the EBKD loss.

        All tensors are expected in ``[B, T, D]`` layout.  The feature maps must
        be the very tensors from which the corresponding log-probabilities were
        computed, otherwise the gradients cannot be evaluated.

        Args:
            teacher_logits: Teacher log-probabilities.
            log_probs: Student log-probabilities.
            teacher_feature_map: Teacher activations.
            student_feature_map: Student activations.

        Returns:
            A scalar for the ``'mean_batch'``, ``'mean'`` and ``'sum'``
            reductions, or a ``[B]`` tensor for ``'none'``.
        """
        teacher_attention_map = self._attention_map(teacher_logits, teacher_feature_map)
        student_attention_map = self._attention_map(log_probs, student_feature_map)

        # Per-sample L2 normalisation over all (T, D) entries; `normalize`
        # clamps the norm away from zero so an all-zero map does not yield NaN.
        teacher_normalized = nn.functional.normalize(teacher_attention_map.flatten(1), p=2, dim=1)
        student_normalized = nn.functional.normalize(student_attention_map.flatten(1), p=2, dim=1)

        # loss (row of length B): distance between the normalised maps divided
        # by the number of time steps.
        num_steps = teacher_attention_map.shape[1]
        ebkd_loss = (teacher_normalized - student_normalized).norm(p=2, dim=1) / num_steps

        if self._apply_batch_mean or self.reduction == 'mean':
            return ebkd_loss.mean()
        if self.reduction == 'sum':
            return ebkd_loss.sum()
        return ebkd_loss
