# Transducer ASR implementation in PyTorch



In [8]:
import torch
import random
import string
import unidecode
import numpy as np
import itertools
from itertools import cycle, islice
from collections import Counter
from tqdm import tqdm
from torch.utils.data import BufferedShuffleDataset, IterableDataset, DataLoader


# Transducer Module


<img src="trans_module.png" width="400">

In [9]:
# Index of the null (blank) label; also used as the predictor start symbol and as the padding value.
NULL_INDEX = 0

# Layer sizes used by the encoder, predictor and joiner modules below.
encoder_dim = 1024  # hidden size of the encoder GRU (per direction)
embedding_dim = 32  # size of the input-symbol and label embeddings
predictor_dim = 1024  # hidden size of the predictor GRUCell
joiner_dim = 1024  # common size the encoder and predictor outputs are projected to

In [10]:
class TransEncoder(torch.nn.Module):
  """Encoder: symbol embedding -> 3-layer bidirectional GRU -> linear projection to joiner_dim."""

  def __init__(self, num_inputs):
    """num_inputs: number of distinct input symbols (rows of the embedding table)."""
    super().__init__()
    self.embed = torch.nn.Embedding(num_inputs, embedding_dim)
    self.rnn = torch.nn.GRU(
        input_size=embedding_dim,
        hidden_size=encoder_dim,
        num_layers=3,
        batch_first=True,
        bidirectional=True,
        dropout=0.1,
    )
    # The GRU is bidirectional, so each timestep carries 2 * encoder_dim features.
    self.linear = torch.nn.Linear(2 * encoder_dim, joiner_dim)

  def forward(self, x):
    """Return the projected GRU outputs for a batch of symbol indices x."""
    embedded = self.embed(x)
    hidden, _ = self.rnn(embedded)  # the final hidden state is not needed
    return self.linear(hidden)

In [11]:
class TransPredictor(torch.nn.Module):
  """Predictor: label embedding -> GRUCell -> linear projection to joiner_dim."""

  def __init__(self, num_outputs):
    """num_outputs: number of labels, including the null label."""
    super().__init__()
    self.embed = torch.nn.Embedding(num_outputs, embedding_dim)
    self.rnn = torch.nn.GRUCell(input_size=embedding_dim, hidden_size=predictor_dim)
    self.linear = torch.nn.Linear(predictor_dim, joiner_dim)

    # Learned initial recurrent state; the null label doubles as the start symbol.
    self.initial_state = torch.nn.Parameter(torch.randn(predictor_dim))
    self.start_symbol = NULL_INDEX

  def forward_one_step(self, input, previous_state):
    """Advance the recurrence by one label; return (projected output, new state)."""
    state = self.rnn.forward(self.embed(input), previous_state)
    return self.linear(state), state

  def forward(self, y):
    """Run the predictor over the label tensor y of shape (batch, U); output has U+1 steps."""
    batch_size, U = y.shape[0], y.shape[1]
    state = torch.stack([self.initial_state] * batch_size).to(y.device)

    # Step inputs: the start symbol, then each label column of y. This yields
    # U+1 outputs, the last one being needed for the null output of the final step.
    start = torch.tensor([self.start_symbol] * batch_size).to(y.device)
    step_inputs = [start] + [y[:, u] for u in range(U)]

    outs = []
    for decoder_input in step_inputs:
      out, state = self.forward_one_step(decoder_input, state)
      outs.append(out)
    return torch.stack(outs, dim=1)

In [12]:
class TransJoiner(torch.nn.Module):
  """Joiner: adds encoder and predictor outputs, applies ReLU, and projects to label scores."""

  def __init__(self, num_outputs):
    """num_outputs: number of labels scored by the final linear layer."""
    super().__init__()
    self.linear = torch.nn.Linear(joiner_dim, num_outputs)

  def forward(self, encoder_out, predictor_out):
    """Return unnormalised label scores for the (broadcast) sum of the two inputs."""
    combined = torch.relu(encoder_out + predictor_out)
    return self.linear(combined)

# Transducer model + loss function

Using the encoder, predictor, and joiner, we will implement the Transducer model and its loss function.

<img src="https://lorenlugosch.github.io/images/transducer/forward-messages.png" width="25%">

We will use a simple PyTorch implementation of the loss function, relying on automatic differentiation to give us gradients. 

(It's more efficient to write the forward() and backward() in C/CUDA. See https://github.com/HawkAaron/warp-transducer for this. There's also going to be an efficient implementation of the Transducer loss function by my colleague Abdel Heba in the [SpeechBrain](https://speechbrain.github.io/) toolkit, to be released soon.)

In [13]:
class TransducerASR(torch.nn.Module):
  """Transducer model: encoder + predictor + joiner, with a pure-PyTorch forward-algorithm loss."""

  def __init__(self, num_inputs, num_outputs):
    """
    num_inputs: number of input symbols (passed to the encoder)
    num_outputs: number of output labels, including the null label
    """
    super().__init__()
    self.encoder = TransEncoder(num_inputs)
    self.predictor = TransPredictor(num_outputs)
    self.joiner = TransJoiner(num_outputs)

    # Place the model on the GPU when one is available.
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.to(self.device)

  def compute_forward_prob(self, joiner_out, T, U, y):
    """
    Compute log p(y|x) for every sequence in the batch with the forward algorithm.

    joiner_out: tensor of log-probabilities, shape (B, T_max, U_max+1, #labels)
    T: list of input lengths
    U: list of output lengths
    y: label tensor (B, U_max)

    Returns a tensor of shape (B,) holding one log-probability per sequence.
    """
    B = joiner_out.shape[0]
    T_max = joiner_out.shape[1]
    U_max = joiner_out.shape[2] - 1
    # log_alpha[b, t, u]: log-probability of having emitted the first u labels
    # by the time frame t is reached. log_alpha[:, 0, 0] stays at log(1) = 0.
    log_alpha = torch.zeros(B, T_max, U_max+1).to(y.device)
    for t in range(T_max):
      for u in range(U_max+1):
          if u == 0:
            if t > 0:
              # Only reachable by emitting null at the previous frame.
              log_alpha[:, t, u] = log_alpha[:, t-1, u] + joiner_out[:, t-1, 0, NULL_INDEX]

          else: #u > 0
            # Log-probability of emitting label y[u-1] at lattice node (t, u-1).
            label_log_prob = torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1)).reshape(-1)
            if t == 0:
              # Only reachable by emitting a label within the first frame.
              log_alpha[:, t, u] = log_alpha[:, t, u-1] + label_log_prob

            else: #t > 0
              # Reachable either by a null from (t-1, u) or a label from (t, u-1).
              log_alpha[:, t, u] = torch.logsumexp(torch.stack([
                  log_alpha[:, t-1, u] + joiner_out[:, t-1, u, NULL_INDEX],
                  log_alpha[:, t, u-1] + label_log_prob
              ]), dim=0)

    # Each sequence ends at its own (T-1, U) node followed by a final null emission.
    log_probs = torch.stack([
        log_alpha[b, T[b]-1, U[b]] + joiner_out[b, T[b]-1, U[b], NULL_INDEX]
        for b in range(B)
    ])
    # Return the whole batch; returning only the last element would make the
    # loss depend on a single sequence per batch.
    return log_probs

  def compute_loss(self, x, y, T, U):
    """Return the mean negative log-likelihood of the labels y given the inputs x."""
    encoder_out = self.encoder.forward(x)
    predictor_out = self.predictor.forward(y)
    # Broadcast encoder (B, T, 1, D) against predictor (B, 1, U+1, D), then normalise over labels.
    joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
    loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()
    return loss



<img src="alignments.png" width="75%">

In [9]:
def greedy_search(self, x, T):
  """
  Greedily decode every sequence in the batch.

  x: batch of encoded inputs, as passed to the encoder
  T: input lengths, one per sequence
  Returns a list with one list of predicted labels per sequence (start symbol removed).

  At each step the highest-scoring label is taken: a null moves on to the next
  input frame, any other label is appended to the output.
  """
  y_batch = []
  U_max = 200  # cap on the number of labels emitted per sequence
  with torch.no_grad():  # decoding needs no gradients
    encoder_out = self.encoder.forward(x)
    for b in range(len(x)):
      t = 0
      u = 0
      y = [self.predictor.start_symbol]
      predictor_state = self.predictor.initial_state.unsqueeze(0)
      while t < T[b] and u < U_max:
        predictor_input = torch.tensor([y[-1]]).to(x.device)
        g_u, next_state = self.predictor.forward_one_step(predictor_input, predictor_state)
        f_t = encoder_out[b, t]
        h_t_u = self.joiner.forward(f_t, g_u)

        argmax = h_t_u.max(-1)[1].item()

        if argmax == NULL_INDEX:
          # Null consumes a frame but emits nothing, so the predictor state must
          # stay put: it depends only on the labels emitted so far.
          t += 1
        else: # argmax == a label
          u += 1
          y.append(argmax)
          predictor_state = next_state
      y_batch.append(y[1:]) # remove start symbol

  return y_batch

# Attach the decoder as a method of the model class, which is named TransducerASR.
TransducerASR.greedy_search = greedy_search

# Utilities - Dataloader and labels

Here we will add a bit of boilerplate code for training and loading data.

In [14]:
# Supported characters: space, the lowercase letters a-z, and the apostrophe.
char_labels = list(" abcdefghijklmnopqrstuvwxyz'")

# Lookup tables: character -> index and index -> character.
char_dict = {label: index for index, label in enumerate(char_labels)}
inv_char_dict = dict(enumerate(char_labels))

class IterableDataSetWithShuffle(IterableDataset):
    """Endless stream of (noisy input, clean target) string pairs read from a transcript TSV.

    The transcript text is taken from the third tab-separated column of each line and
    cut into pieces of at most max_len words. Each piece yields one pair.
    """

    def __init__(self, filename, max_len=10):
        """
        filename: path of the tab-separated transcript file
        max_len: maximum number of words per generated example
        """
        self.filename = filename
        self.max_len = max_len

    def parse_file(self, filename):
        """Yield (line_x, line_y) pairs for every piece of every usable line in filename."""

        def gen_features(line):
            # Remove linefeed, transliterate to ASCII and convert to lower case
            line = line.replace("\n", "")
            line = unidecode.unidecode(line)
            line = line.lower()

            # Target: only the characters that have a label.
            line_y = "".join(c for c in line if c in char_labels)
            # Input: randomly repeat characters to mimic the phonemes in a speech segment,
            # and randomly drop 'g', 'j' and 'p' to simulate an auditory effect.
            line_x = "".join(
                c * random.randint(0, 1) if c in "gjp" else c * random.randint(1, 2)
                for c in line_y
            )
            line_x = " ".join(line_x.split())
            return (line_x, line_y)

        # Explicit encoding so that reading does not depend on the platform default.
        with open(filename, 'r', encoding='utf-8') as file_obj:
            for line in file_obj:
                fields = line.split('\t')
                if len(fields) < 3:
                    # Blank or malformed line without a transcript column: skip it
                    # instead of failing with an IndexError.
                    continue

                # Restrict the length of sentences to max_len words
                short_t_lines = []
                temp = fields[2].split(' ')
                count, resid = divmod(len(temp), self.max_len)

                for chunk_index in range(count):
                    k = chunk_index * self.max_len
                    short_line = ' '.join(temp[k:k + self.max_len])
                    short_t_lines.append(gen_features(short_line))

                if resid > 0:
                    # Leftover words: normalise whitespace and drop pieces that are too short.
                    short_line = ' '.join(temp[count * self.max_len:])
                    short_line = ' '.join(short_line.split())
                    if len(short_line) > 1:
                        short_t_lines.append(gen_features(short_line))

                yield from short_t_lines

    def __iter__(self):
        # cycle() makes the stream endless. After the first pass it replays the
        # examples it saved, so the file is read only once per iterator.
        return cycle(self.parse_file(self.filename))
    
    
# Encode and decode functions that map characters to integer labels and back.
# Label 0 is reserved for the blank (null) symbol used for alignments, so character indices are shifted up by one.

def encode_string(s):
  """Return the labels of the characters in s: each char_dict index shifted up by one."""
  labels = []
  for ch in s:
    labels.append(char_dict[ch] + 1)
  return labels

def decode_labels(l):
  """Return the string for a sequence of labels, undoing the shift-by-one of the encoding."""
  chars = []
  for label in l:
    chars.append(inv_char_dict[label - 1])
  return "".join(chars)


class Collate:
  def __call__(self, batch):
    """
    batch: list of tuples (input string, output string)
    Returns (x, y, T, U): the encoded inputs and outputs, each padded with NULL_INDEX
    to a common length and stacked into one tensor, plus tensors of the original lengths.
    """
    x = [encode_string(x_) for x_, _ in batch]
    y = [encode_string(y_) for _, y_ in batch]

    # Original lengths, recorded before padding.
    T = [len(seq) for seq in x]
    U = [len(seq) for seq in y]
    T_max, U_max = max(T), max(U)

    # Pad each sequence on the right up to the longest one, then stack.
    x = torch.stack([
        torch.tensor(seq + [NULL_INDEX] * (T_max - len(seq))) for seq in x
    ])
    y = torch.stack([
        torch.tensor(seq + [NULL_INDEX] * (U_max - len(seq))) for seq in y
    ])

    return (x, y, torch.tensor(T), torch.tensor(U))


In [None]:
# Locations of the transcript files.
filename = "../LibriSpeech/train-clean-360/transcripts.tsv"
train_filename = "../LibriSpeech/train-clean-100/transcripts_460.tsv"
eval_filename = "../LibriSpeech/dev-clean/transcripts.tsv"
test_filename = "../LibriSpeech/test-clean/transcripts.tsv"
collate = Collate()


def _build_loader(path, batch_size):
    """Return (dataset, shuffle buffer, loader) for the transcript file at path."""
    dataset = IterableDataSetWithShuffle(path, max_len=10)
    buffer = BufferedShuffleDataset(dataset, buffer_size=1024)
    loader = DataLoader(buffer, batch_size=batch_size, collate_fn=collate)
    return dataset, buffer, loader


train_dataset, train_buffer, train_loader = _build_loader(train_filename, 32)
eval_dataset, eval_buffer, eval_loader = _build_loader(eval_filename, 8)
test_dataset, test_buffer, test_loader = _build_loader(test_filename, 8)

In [12]:
class Trainer:
  """Runs training and evaluation epochs for a model that exposes compute_loss and greedy_search."""

  def __init__(self, model, lr=0.0001, dr=0.9):
    """
    model: the model to optimise
    lr: initial learning rate for Adam
    dr: multiplicative learning-rate decay applied once per epoch
    """
    self.model = model
    self.lr = lr
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
    self.dr = dr
    self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=self.dr)

  def train(self, dataset, print_interval=100, num_batches=9200):
    """
    Train on at most num_batches batches drawn from dataset.

    print_interval: the batch number is printed every print_interval batches
    Returns the per-sample average training loss.
    """
    train_loss = 0
    num_samples = 0
    self.model.train()
    # The data stream may be endless, so the epoch length is fixed by num_batches.
    for idx, batch in enumerate(islice(dataset, num_batches)):
      x,y,T,U = batch
      x = x.to(self.model.device); y = y.to(self.model.device)
      batch_size = len(x)
      num_samples += batch_size
      loss = self.model.compute_loss(x,y,T,U)
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()
      train_loss += loss.item() * batch_size
      if idx % print_interval == 0:
         print("Batch_num :", idx)
    train_loss /= num_samples
    return train_loss


  def test(self, dataset, print_interval=100, num_batches=1000):
    """
    Evaluate on at most num_batches batches drawn from dataset.

    Every print_interval batches the first example of the batch is printed as
    input, decoded guess and ground truth.
    Returns the per-sample average loss.
    """
    test_loss = 0
    num_samples = 0
    self.model.eval()
    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
      for idx, batch in enumerate(islice(dataset, num_batches)):
        x,y,T,U = batch
        # Keep the CPU copies x and y for printing; xx and yy live on the model device.
        xx = x.to(self.model.device); yy = y.to(self.model.device)
        batch_size = len(x)
        num_samples += batch_size
        loss = self.model.compute_loss(xx,yy,T,U)
        test_loss += loss.item() * batch_size
        if idx % print_interval == 0:
          print("\n")
          print("input:", decode_labels(x[0,:T[0]].numpy()))
          print("guess:", decode_labels(self.model.greedy_search(xx,T)[0]))
          print("truth:", decode_labels(y[0,:U[0]].numpy()))
          print("")
    test_loss /= num_samples
    return test_loss

  def fit(self, train_dataset, eval_dataset, num_epochs=10, print_interval = 100):
    """
    Alternate training and evaluation for num_epochs epochs.

    The learning rate is decayed after each training epoch and the model weights
    are saved to ./model_epoch_<epoch>.pt after each evaluation.
    Returns (list of training losses, list of evaluation losses), one entry per epoch.
    """
    train_loss = []
    eval_loss = []

    for epochs in range(num_epochs):
        print("Epoch No:{}".format(epochs))
        t_loss = self.train(train_dataset, print_interval)
        train_loss.append(t_loss)
        self.lr_scheduler.step()
        e_loss = self.test(eval_dataset, print_interval)
        eval_loss.append(e_loss)
        print("train loss {:.2f} eval_loss {:.2f}".format(t_loss, e_loss))
        save_file = "./model_epoch_{}.pt".format(epochs)
        # Save the model owned by this trainer, not whatever global happens to be named `model`.
        torch.save(self.model.state_dict(), save_file)

    return train_loss, eval_loss


 
    

# Training the model

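Now we will train a model. During training, the batch number is printed every `print_interval` (default 100) batches; during evaluation, a sample input, the decoded guess, and the ground truth are printed at the same interval.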

In [None]:
num_chars = len(char_labels)
# One extra input/output symbol because label 0 is reserved for the null symbol.
# The model class is named TransducerASR.
model = TransducerASR(num_inputs=num_chars+1, num_outputs=num_chars+1)
trainer = Trainer(model=model, lr=0.0001)
# Resume from an existing checkpoint. map_location puts the weights on the device
# the model lives on, so a checkpoint saved on GPU also loads on a CPU-only machine.
load_file = './model_1024_join_dim.pt'
model.load_state_dict(torch.load(load_file, map_location=model.device))
train = True
if train:
    train_loss, eval_loss = trainer.fit(train_loader, eval_loader, num_epochs=10)
    save_file = './asr_model.pt'
    print("Train Loss :", train_loss)
    print("Eval Loss :", eval_loss)
    torch.save(model.state_dict(), save_file)



Epoch No:0
Batch_num : 0
Batch_num : 100
Batch_num : 200
Batch_num : 300
Batch_num : 400
Batch_num : 500
Batch_num : 600
Batch_num : 700
Batch_num : 800
Batch_num : 900
Batch_num : 1000
Batch_num : 1100
Batch_num : 1200
Batch_num : 1300
Batch_num : 1400
Batch_num : 1500
Batch_num : 1600
Batch_num : 1700
Batch_num : 1800
Batch_num : 1900
Batch_num : 2000
Batch_num : 2100
Batch_num : 2200
Batch_num : 2300
Batch_num : 2400
Batch_num : 2500
Batch_num : 2600
Batch_num : 2700
Batch_num : 2800
Batch_num : 2900
Batch_num : 3000
Batch_num : 3100
Batch_num : 3200
Batch_num : 3300
Batch_num : 3400
Batch_num : 3500
Batch_num : 3600
Batch_num : 3700
Batch_num : 3800
Batch_num : 3900
Batch_num : 4000
Batch_num : 4100
Batch_num : 4200
Batch_num : 4300
Batch_num : 4400
Batch_num : 4500
Batch_num : 4600
Batch_num : 4700
Batch_num : 4800
Batch_num : 4900
Batch_num : 5000
Batch_num : 5100
Batch_num : 5200
Batch_num : 5300
Batch_num : 5400
Batch_num : 5500
Batch_num : 5600
Batch_num : 5700
Batch_num : 580

In [None]:

num_chars = len(char_labels)
load_file = './asr_model.pt'
# The model class is named TransducerASR. One extra symbol is reserved for the null label.
model = TransducerASR(num_inputs=num_chars+1, num_outputs=num_chars+1)
# map_location puts the weights on the device the model lives on, so a checkpoint
# saved on GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load(load_file, map_location=model.device))
tester = Trainer(model=model, lr=0.0003)

In [None]:
from copy import deepcopy
def beam_search(self, x, T, k=1):
  """
  Decode every sequence in the batch with a beam search of width k.

  x: batch of encoded inputs, as passed to the encoder
  T: input lengths, one per sequence
  k: beam width, i.e. the number of hypotheses kept after every expansion step
  Returns a list with one list of predicted labels per sequence (start symbol removed).

  Each hypothesis carries its own frame index, label count, last label, predictor
  state and label sequence, so a surviving candidate always continues from the
  state of the hypothesis it was expanded from. Candidate scores are normalised
  over the k best labels of each expansion only. The search for a sequence stops
  as soon as any kept hypothesis has consumed all frames or reached the label cap,
  and the best-scoring hypothesis is returned.
  """
  if k < 1:
    raise ValueError("beam width k must be at least 1")

  U_max = 200  # cap on the number of labels emitted per hypothesis
  start_symbol = self.predictor.start_symbol
  y_beam_batch = []
  with torch.no_grad():  # decoding needs no gradients
    encoder_out = self.encoder.forward(x)
    for b in range(len(x)):
      T_b = int(T[b])
      # Hypothesis layout: (log score, t, u, last label, predictor state, label sequence).
      beams = [(0.0, 0, 0, start_symbol, self.predictor.initial_state.unsqueeze(0), [start_symbol])]
      while all(t < T_b and u < U_max for _, t, u, _, _, _ in beams):
        candidates = []
        for score, t, u, last_label, state, sequence in beams:
          predictor_input = torch.tensor([last_label]).to(x.device)
          g_u, next_state = self.predictor.forward_one_step(predictor_input, state)
          h_t_u = self.joiner.forward(encoder_out[b, t], g_u)

          # Keep the k best labels and normalise their scores among themselves.
          top_logits, top_labels = torch.topk(h_t_u[0], min(k, h_t_u[0].numel()))
          top_log_probs = torch.log_softmax(top_logits, dim=0)

          for log_prob, label in zip(top_log_probs.tolist(), top_labels.tolist()):
            if label == NULL_INDEX:
              # Null consumes a frame; labels and predictor state are unchanged.
              candidates.append((score + log_prob, t + 1, u, last_label, state, sequence))
            else:
              # A label extends the sequence and commits the new predictor state.
              candidates.append((score + log_prob, t, u + 1, label, next_state, sequence + [label]))

        # Log scores are summed instead of multiplying probabilities to avoid underflow.
        candidates.sort(key=lambda hyp: hyp[0], reverse=True)
        beams = candidates[:k]

      y_beam_batch.append(beams[0][5][1:])  # best hypothesis, start symbol removed

  return y_beam_batch


# Replace the model's decoder with beam search, so that callers of
# model.greedy_search (such as Trainer.test) use it. The model class is named TransducerASR.
TransducerASR.greedy_search = beam_search

In [None]:
from copy import deepcopy
# Softmax module that normalises along dimension 0.
softmax_mod = torch.nn.Softmax(dim=0)
def beam_search(self, x, T, k=3):
  """
  Decode every sequence in the batch with a beam search of width k.

  x: batch of encoded inputs, as passed to the encoder
  T: input lengths, one per sequence
  k: beam width, i.e. the number of hypotheses kept after every expansion step
  Returns a list with one list of predicted labels per sequence (start symbol removed).

  Each hypothesis carries its own frame index, label count, last label, predictor
  state and label sequence, so a surviving candidate always continues from the
  state of the hypothesis it was expanded from. Candidate scores are normalised
  over all labels. The search for a sequence stops as soon as any kept hypothesis
  has consumed all frames or reached the label cap, and the best-scoring
  hypothesis is returned.
  """
  if k < 1:
    raise ValueError("beam width k must be at least 1")

  U_max = 200  # cap on the number of labels emitted per hypothesis
  start_symbol = self.predictor.start_symbol
  y_beam_batch = []
  with torch.no_grad():  # decoding needs no gradients
    encoder_out = self.encoder.forward(x)
    for b in range(len(x)):
      T_b = int(T[b])
      # Hypothesis layout: (log score, t, u, last label, predictor state, label sequence).
      beams = [(0.0, 0, 0, start_symbol, self.predictor.initial_state.unsqueeze(0), [start_symbol])]
      while all(t < T_b and u < U_max for _, t, u, _, _, _ in beams):
        candidates = []
        for score, t, u, last_label, state, sequence in beams:
          predictor_input = torch.tensor([last_label]).to(x.device)
          g_u, next_state = self.predictor.forward_one_step(predictor_input, state)
          h_t_u = self.joiner.forward(encoder_out[b, t], g_u)

          # Normalise over all labels, then expand only the k most probable ones.
          log_probs = torch.log_softmax(h_t_u[0], dim=0)
          top_log_probs, top_labels = torch.topk(log_probs, min(k, log_probs.numel()))

          for log_prob, label in zip(top_log_probs.tolist(), top_labels.tolist()):
            if label == NULL_INDEX:
              # Null consumes a frame; labels and predictor state are unchanged.
              candidates.append((score + log_prob, t + 1, u, last_label, state, sequence))
            else:
              # A label extends the sequence and commits the new predictor state.
              candidates.append((score + log_prob, t, u + 1, label, next_state, sequence + [label]))

        # Log scores are summed instead of multiplying probabilities to avoid underflow.
        candidates.sort(key=lambda hyp: hyp[0], reverse=True)
        beams = candidates[:k]

      y_beam_batch.append(beams[0][5][1:])  # best hypothesis, start symbol removed

  return y_beam_batch

# Replace the model's decoder with beam search, so that callers of
# model.greedy_search (such as Trainer.test) use it. The model class is named TransducerASR.
TransducerASR.greedy_search = beam_search