In [1]:
import warprnnt_pytorch
import torch

In [2]:
# First CUDA device; every *_cuda tensor in this notebook is placed on it.
device = torch.device('cuda:0')

In [3]:
# Shared float32 input for every RNN-T loss implementation under comparison.
# Its shape is (1, 2, 3, 5); with the length tensors below this reads as
# (batch, frames, target_len + 1, vocab) -- TODO confirm axis naming against
# the loss implementations' docs.
logits = torch.tensor(
    [[[[0.1, 0.6, 0.1, 0.1, 0.1],
       [0.1, 0.1, 0.6, 0.1, 0.1],
       [0.1, 0.1, 0.2, 0.8, 0.1]],
      [[0.1, 0.6, 0.1, 0.1, 0.1],
       [0.1, 0.1, 0.2, 0.1, 0.1],
       [0.7, 0.1, 0.2, 0.1, 0.1]]]],
    dtype=torch.float32,
)

# One independent CPU copy per library so that gradients never mix.
warp_transducer_logits, torchaudio_logits, optimized_transducer_logits = (
    logits.clone() for _ in range(3)
)

# Same again on the GPU.
logits_cuda = logits.to(device)
(
    warp_transducer_logits_cuda,
    torchaudio_logits_cuda,
    optimized_transducer_logits_cuda,
) = (logits_cuda.clone() for _ in range(3))


# A single utterance: target sequence [1, 2], 2 logit frames, 2 target labels.
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# GPU counterparts of the label and length tensors.
targets_cuda, logit_lengths_cuda, target_lengths_cuda = (
    cpu_tensor.to(device)
    for cpu_tensor in (targets, logit_lengths, target_lengths)
)

In [4]:
# Sanity check of the input dimensions.
print(logits.size())

torch.Size([1, 2, 3, 5])


In [5]:
# Turn every per-library copy into a leaf that records gradients, so each
# implementation's backward pass can be inspected on its own tensor.
# (`requires_grad_()` defaults to True.)

# warp-transducer inputs
warp_transducer_logits.requires_grad_()
warp_transducer_logits_cuda.requires_grad_()

# torchaudio inputs
torchaudio_logits.requires_grad_()
torchaudio_logits_cuda.requires_grad_()

# optimized_transducer inputs; the CUDA tensor stays last so it remains the
# value the cell displays.
optimized_transducer_logits.requires_grad_()
optimized_transducer_logits_cuda.requires_grad_()

tensor([[[[0.1000, 0.6000, 0.1000, 0.1000, 0.1000],
          [0.1000, 0.1000, 0.6000, 0.1000, 0.1000],
          [0.1000, 0.1000, 0.2000, 0.8000, 0.1000]],

         [[0.1000, 0.6000, 0.1000, 0.1000, 0.1000],
          [0.1000, 0.1000, 0.2000, 0.1000, 0.1000],
          [0.7000, 0.1000, 0.2000, 0.1000, 0.1000]]]], device='cuda:0',
       requires_grad=True)

In [6]:
# RNN-T loss from warp-transducer, evaluated once on the CPU tensors and once
# on their CUDA copies with identical settings: blank label index 0, 'mean'
# reduction, and fastemit_lambda set to 0 (presumably turning FastEmit
# regularisation off -- confirm against the installed warprnnt_pytorch build).
warp_transducer_cpu_loss = warprnnt_pytorch.rnnt_loss(
    warp_transducer_logits,
    targets,
    logit_lengths,
    target_lengths,
    blank=0,
    reduction='mean',
    fastemit_lambda=0,
)

warp_transducer_cuda_loss = warprnnt_pytorch.rnnt_loss(
    warp_transducer_logits_cuda,
    targets_cuda,
    logit_lengths_cuda,
    target_lengths_cuda,
    blank=0,
    reduction='mean',
    fastemit_lambda=0,
)

In [7]:
# Show the warp-transducer loss from both devices side by side.
print(f'warp_transducer, cpu_loss: {warp_transducer_cpu_loss}, cuda_loss: {warp_transducer_cuda_loss}')

# Fail loudly on a CPU/CUDA mismatch instead of relying on a visual comparison
# of the printed values. Each loss holds a single element, so `.item()` yields
# a plain Python float that can be compared across devices.
assert abs(warp_transducer_cpu_loss.item() - warp_transducer_cuda_loss.item()) < 1e-4, (
    'warp_transducer: CPU and CUDA RNN-T losses disagree'
)

warp_transducer, cpu_loss: tensor([4.4957], grad_fn=<_RNNTBackward>), cuda_loss: tensor([4.4957], device='cuda:0', grad_fn=<_RNNTBackward>)


In [8]:
# Back-propagate each warp-transducer loss; this fills `.grad` on the
# corresponding CPU / CUDA logits leaf tensors, which the next cells print.
warp_transducer_cpu_loss.backward()
warp_transducer_cuda_loss.backward()


In [9]:
# Gradient of the warp-transducer loss w.r.t. the CPU logits, followed by the
# device that gradient lives on.
print(warp_transducer_logits.grad, warp_transducer_logits.grad.device, sep='\n')

tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],
          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],
          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],

         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],
          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],
          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]])
cpu


In [10]:
# Gradient of the warp-transducer loss w.r.t. the CUDA logits, followed by the
# device that gradient lives on.
print(warp_transducer_logits_cuda.grad, warp_transducer_logits_cuda.grad.device, sep='\n')

tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],
          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],
          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],

         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],
          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],
          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]], device='cuda:0')
cuda:0


In [18]:
!pip install --upgrade pip 
# !pip uninstall torchaudio -y
!pip install torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [11]:
import torchaudio



In [12]:
# RNN-T loss from torchaudio on the same inputs and with the same blank index
# and reduction as the warp-transducer calls: first on CPU, then on CUDA.
torchaudio_cpu_loss = torchaudio.functional.rnnt_loss(
    torchaudio_logits,
    targets,
    logit_lengths,
    target_lengths,
    blank=0,
    reduction='mean',
)

torchaudio_cuda_loss = torchaudio.functional.rnnt_loss(
    torchaudio_logits_cuda,
    targets_cuda,
    logit_lengths_cuda,
    target_lengths_cuda,
    blank=0,
    reduction='mean',
)

In [13]:
# Show the torchaudio loss from both devices side by side.
print(f'torchaudio, cpu_loss: {torchaudio_cpu_loss}, cuda_loss: {torchaudio_cuda_loss}')

# Fail loudly on a CPU/CUDA mismatch instead of relying on a visual comparison
# of the printed values. `.item()` converts each single-element loss to a
# Python float so the two devices can be compared directly; float32 round-off
# between the two back-ends is far below the tolerance used here.
assert abs(torchaudio_cpu_loss.item() - torchaudio_cuda_loss.item()) < 1e-4, (
    'torchaudio: CPU and CUDA RNN-T losses disagree'
)

torchaudio, cpu_loss: 4.495666980743408, cuda_loss: 4.49566650390625


In [14]:
# Back-propagate each torchaudio loss; this fills `.grad` on the
# corresponding CPU / CUDA logits leaf tensors, which the next cells print.
torchaudio_cpu_loss.backward()
torchaudio_cuda_loss.backward()

In [15]:
# Gradient of the torchaudio loss w.r.t. the CPU logits; compare with the
# warp-transducer CPU gradient printed earlier.
print(torchaudio_logits.grad)

tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],
          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],
          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],

         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],
          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],
          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]])


In [16]:
# Gradient of the torchaudio loss w.r.t. the CUDA logits; compare with the
# warp-transducer CUDA gradient printed earlier.
print(torchaudio_logits_cuda.grad)

tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],
          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],
          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],

         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],
          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],
          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]], device='cuda:0')
