
Vladimetr/ASR-Knowledge-Transferring


Knowledge Transferring (KT) module for training ASR models

This repository provides a PyTorch implementation of the paper "Improving CTC-based speech recognition via knowledge transferring from pre-trained language models".
Knowledge Transferring (KT) uses the knowledge of a pre-trained Language Model (LM) to train the acoustic encoder (Wav2Vec).


There are two methods for training the ASR encoder:

  • Representation Learning (RL)
  • Classification Learning (CL)

Representation Learning (RL)

It uses BERT and provides two mechanisms:


Classification Learning (CL)

It applies a cross-attention mechanism between the encoder outputs and GPT-2 embeddings. The attention outputs are passed through a linear layer to produce logits, which are trained with a cross-entropy loss.
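A minimal PyTorch sketch of such a classification head follows. It rests on assumptions this README does not state: the queries are token embeddings and the keys/values are encoder outputs, a randomly initialised `nn.Embedding` stands in for the GPT-2 embedding table, and the targets are the next token ids (teacher forcing). `CrossAttentionClassifier` and its parameters are hypothetical names, not the `kt_module` API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossAttentionClassifier(nn.Module):
    """Hypothetical CL head: token embeddings attend over encoder outputs."""

    def __init__(self, enc_dim: int, emb_dim: int, vocab_size: int, num_heads: int = 4):
        super().__init__()
        # Stand-in for the GPT-2 token embedding table (assumption: random init here)
        self.token_emb = nn.Embedding(vocab_size, emb_dim)
        # kdim/vdim let keys and values come straight from the encoder space
        self.attn = nn.MultiheadAttention(emb_dim, num_heads, kdim=enc_dim,
                                          vdim=enc_dim, batch_first=True)
        # Linear layer mapping attention outputs to vocabulary logits
        self.proj = nn.Linear(emb_dim, vocab_size)

    def forward(self, enc_out, enc_mask, tokens, token_mask):
        # enc_out: (B, L, H) encoder outputs; enc_mask: (B, L) bool, True = valid frame
        # tokens: (B, T) token ids; token_mask: (B, T) bool, True = real token
        inputs, targets = tokens[:, :-1], tokens[:, 1:]  # assumed next-token targets
        queries = self.token_emb(inputs)                  # (B, T-1, E)
        # key_padding_mask expects True at positions to ignore
        attended, _ = self.attn(queries, enc_out, enc_out,
                                key_padding_mask=~enc_mask)
        logits = self.proj(attended)                      # (B, T-1, V)
        # Padded target positions are excluded from the loss
        targets = targets.masked_fill(~token_mask[:, 1:], -100)
        return F.cross_entropy(logits.transpose(1, 2), targets, ignore_index=-100)


if __name__ == "__main__":
    torch.manual_seed(0)
    B, L, H, V = 3, 10, 768, 50
    enc_out = torch.rand(B, L, H, requires_grad=True)
    enc_mask = torch.ones(B, L, dtype=torch.bool)
    tokens = torch.randint(0, V, (B, 6))
    token_mask = torch.ones(B, 6, dtype=torch.bool)
    head = CrossAttentionClassifier(enc_dim=H, emb_dim=64, vocab_size=V)
    loss = head(enc_out, enc_mask, tokens, token_mask)
    loss.backward()
    print(loss.item())
```

Which side provides the queries, and whether the real GPT-2 embeddings are frozen, are design choices this sketch does not settle; using the actual GPT-2 table would also require its tokenizer's vocabulary.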


Example of usage

import torch
from kt_module import from_yaml

train_module = from_yaml('config.yaml')

# Encoder outputs (from Wav2Vec)
# B - batch size
# L - max sequence length (frames)
# H - encoder output dimension
B, L, H = 3, 10, 768
encoder_outputs = torch.rand(B, L, H, dtype=torch.float32,
                             requires_grad=True)
mask = torch.ones(B, L, dtype=torch.bool)  # False (0) marks invalid (padded) frames

# Target transcripts (Russian: "hi how are you", "what's new", "bye for now")
target_sentences = [
    'привет как дела',
    'что нового',
    'давай пока'
]
assert B == len(target_sentences)

losses = train_module(encoder_outputs, mask, target_sentences)
# loss = ctc_loss + [weighted sum of these losses]
# (see eq. 4 and eq. 7 in the paper)

# loss.backward()
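The comments in the example state that the training objective is the CTC loss plus a weighted sum of the KT losses (eq. 4 and eq. 7 in the paper), i.e. roughly $\mathcal{L} = \mathcal{L}_{\text{CTC}} + \sum_i \lambda_i \mathcal{L}_i$. A minimal sketch of that combination, assuming the KT losses arrive as a dict of scalar tensors; `total_loss`, the dict keys and the weight values are hypothetical, not taken from the repository or the paper.

```python
import torch
import torch.nn.functional as F


def total_loss(log_probs, targets, input_lengths, target_lengths, kt_losses, weights):
    """Hypothetical helper: CTC loss + sum_i weights[i] * kt_losses[i]."""
    # log_probs: (T, N, C) log-softmax outputs of the CTC head; blank index assumed 0
    ctc = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                     blank=0, zero_infinity=True)
    # Weighted sum of the auxiliary KT losses (e.g. RL and/or CL terms)
    aux = sum(weights[name] * value for name, value in kt_losses.items())
    return ctc + aux


if __name__ == "__main__":
    torch.manual_seed(0)
    T, N, C, S = 20, 3, 30, 5
    logits = torch.randn(T, N, C, requires_grad=True)
    log_probs = logits.log_softmax(dim=2)
    targets = torch.randint(1, C, (N, S))  # labels avoid the blank index 0
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.full((N,), S, dtype=torch.long)
    kt_losses = {"rl": torch.tensor(0.5), "cl": torch.tensor(2.0)}
    weights = {"rl": 0.1, "cl": 0.3}
    loss = total_loss(log_probs, targets, input_lengths, target_lengths,
                      kt_losses, weights)
    loss.backward()
```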
