In [1]:
import os
import re
import math
import numpy as np

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
from torch.utils.data.dataset import Subset

import editdistance

In [2]:
class TextTransform:
    """Maps characters to integers and vice versa"""
    def __init__(self):
      self.char_map = { '': 0, ' ': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25, 'y': 26, 'z': 27 }
      self.index_map = { key: value for value, key in self.char_map.items() }

    def text_to_int(self, text):
      """ Use a character map and convert text to an integer sequence """
      return [self.char_map.get(c, 0) for c in text]

    def int_to_text(self, labels):
      """ Use a character map and convert integer labels to an text sequence """
      return ''.join([self.index_map[i] for i in labels])

train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100)
)

valid_audio_transforms = torchaudio.transforms.MelSpectrogram()

text_transform = TextTransform()



In [3]:
def data_processing(data, data_type="train"):
  spectrograms = []
  labels = []
  input_lengths = []
  label_lengths = []
  for (waveform, sample_rate, original_text, normalized_text, speaker_id, chapter_id, utterance_id) in data:
    if data_type == 'train':
      spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1)
    elif data_type == 'valid':
      spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)
    else:
      raise Exception('data_type should be train or valid')
    spectrograms.append(spec)
    label = torch.Tensor(text_transform.text_to_int(normalized_text.lower()))
    labels.append(label)
    input_lengths.append(spec.shape[0]//2)
    label_lengths.append(len(label))

  spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
  labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

  return spectrograms, labels, input_lengths, label_lengths

In [4]:
def GreedyDecoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True):
  arg_maxes = torch.argmax(output, dim=2)
  decodes = []
  targets = []
  for i, args in enumerate(arg_maxes):
    decode = []
    targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist()))
    for j, index in enumerate(args):
      if index != blank_label:
        if collapse_repeated and j != 0 and index == args[j -1]:
          continue
        decode.append(index.item())
    decodes.append(text_transform.int_to_text(decode))
  return decodes, targets

In [5]:
class BidirectionalLSTM(nn.Module):
  def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
    super(BidirectionalLSTM, self).__init__()

    self.BiLSTM = nn.LSTM(
      input_size=rnn_dim, hidden_size=hidden_size,
      num_layers=1, batch_first=batch_first, bidirectional=True
    )
    self.layer_norm = nn.LayerNorm(rnn_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    x = self.layer_norm(x)
    x = F.gelu(x)
    x, _ = self.BiLSTM(x)
    x = self.dropout(x)
    return x

In [6]:
class SelfAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(SelfAttention, self).__init__()

        self.hidden_dim = hidden_dim
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(True),
            nn.Linear(64, 1)
        )

    def forward(self, encoder_outputs):
        # encoder_outputs : [batch_size, seq_len, hidden_dim]
        energy = self.projection(encoder_outputs)  # [batch_size, seq_len, 1]
        weights = F.softmax(energy.squeeze(-1), dim=1)  # [batch_size, seq_len]
        # weights : [batch_size, seq_len]

        outputs = (encoder_outputs * weights.unsqueeze(-1)).sum(dim=1)  # [batch_size, hidden_dim]

        return outputs

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.d_v = d_model // n_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout)
        self.out_linear = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, mask=None):
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, v)
        return output

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        q = self.q_linear(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)
        k = self.k_linear(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)
        v = self.v_linear(v).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)

        attention_output = self.attention(q, k, v, mask=mask)
        attention_output = attention_output.transpose(1,2).contiguous().view(batch_size, -1, self.n_heads*self.d_v)
        output = self.out_linear(attention_output)
        return output


In [8]:
class SpeechRecognitionModel(nn.Module):
  def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
    super(SpeechRecognitionModel, self).__init__()
    n_feats = n_feats//2
    self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3//2)  # cnn for extracting heirachal features

    # n ~residual~ cnn layers with filter size of 32
    self.cnn_layers = nn.Sequential(*[
      # ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats) 
      nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=3//2)
      for _ in range(n_cnn_layers * 2) # Times two to make this more compatible with the original model I am modifying, in which each one in n_cnn_layers created two Conv2d layers with skip connections
    ])
    self.fully_connected = nn.Linear(n_feats*32, rnn_dim)

    self.birnn_layers = nn.Sequential(*[
      # BidirectionalGRU(rnn_dim=rnn_dim if i==0 else rnn_dim*2,
      #                   hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
      BidirectionalLSTM(rnn_dim=rnn_dim if i==0 else rnn_dim*2,
                        hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
      for i in range(n_rnn_layers)
    ])

    self.classifier = nn.Sequential(
      nn.Linear(rnn_dim*2, rnn_dim),  # birnn returns rnn_dim*2
      nn.GELU(),
      nn.Dropout(dropout),
      nn.Linear(rnn_dim, n_class)
    )

    self.attention = MultiHeadAttention(n_heads = 16, d_model=rnn_dim, dropout=dropout)  # add self-attention layer


  def forward(self, x):
    # print('Initial x:', x.size())
    x = self.cnn(x)
    x = self.cnn_layers(x)
    # print('After CNN layers:', x.size())
    sizes = x.size()
    x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
    x = x.transpose(1, 2) # (batch, time, feature)
    x = self.fully_connected(x)
    # print('After FC layer:', x.size())
    x = self.attention(x,x,x) # (batch, time, feature)
    # print('After attention:', x.size())
    x = self.birnn_layers(x)
    # print('After bilstm:', x.size())
    x = self.classifier(x)
    # print('After classifier:', x.size())
    return x

In [9]:
def train(model, device, train_loader, criterion, optimizer, scheduler, epoch):
  model.train()
  data_len = len(train_loader.dataset)
  for batch_idx, _data in enumerate(train_loader):
    spectrograms, labels, input_lengths, label_lengths = _data 
    spectrograms, labels = spectrograms.to(device), labels.to(device)

    optimizer.zero_grad()

    output = model(spectrograms)  # (batch, time, n_class)
    output = F.log_softmax(output, dim=2)
    output = output.transpose(0, 1) # (time, batch, n_class)

    loss = criterion(output, labels, input_lengths, label_lengths)
    loss.backward()

    # print('loss', loss.item())
    # print('learning_rate', scheduler.get_lr())

    optimizer.step()
    scheduler.step()
    if batch_idx % 10 == 0 or batch_idx == data_len:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(spectrograms), data_len,
        100. * batch_idx / len(train_loader), loss.item())
      )


In [10]:
def test(model, device, test_loader, criterion):
  print('\nevaluating...')
  model.eval()
  test_loss = 0
  test_char_edit_dist = []
  test_word_edit_dist = []
  with torch.no_grad():
    for data in test_loader:
      spectrograms, labels, input_lengths, label_lengths = data 
      spectrograms, labels = spectrograms.to(device), labels.to(device)

      output = model(spectrograms) # (batch, time, n_class)
      output = F.log_softmax(output, dim=2)
      output = output.transpose(0, 1) # (time, batch, n_class)

      loss = criterion(output, labels, input_lengths, label_lengths)
      test_loss += loss.item() / len(test_loader)

      decoded_preds, decoded_targets = GreedyDecoder(output.transpose(0, 1), labels, label_lengths)
      print("Pred:", decoded_preds)
      print("Actual: ", decoded_targets)
      for j in range(len(decoded_preds)):
        test_char_edit_dist.append(editdistance.eval(decoded_targets[j], decoded_preds[j]))
        test_word_edit_dist.append(editdistance.eval(decoded_targets[j].split(" "), decoded_preds[j].split(" ")))

  avg_char_edit_dist = sum(test_char_edit_dist)/len(test_char_edit_dist)
  avg_word_edit_dist = sum(test_word_edit_dist)/len(test_word_edit_dist)

  print("Test set:")
  print("Average loss: {:.4f}".format(test_loss))
  print("Average character edit distance: {:4f}".format(avg_char_edit_dist))
  print("Average word edit distance: {:.4f}".format(avg_word_edit_dist))

In [11]:
def main(learning_rate=5e-4, batch_size=20, epochs=10, train_url="train-clean-100", test_url="test-clean"):
  hparams = {
    "n_cnn_layers": 3,
    "n_rnn_layers": 5,
    "rnn_dim": 512,
    "n_class": 29,
    "n_feats": 128,
    "stride": 2,
    "dropout": 0.1,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "epochs": epochs
  }

  torch.manual_seed(7)

  # Get ideal device (CPU, GPU, or MPS for Apple Silicon)
  use_cuda = torch.cuda.is_available()
  device = torch.device("cuda" if use_cuda else "cpu")
  # if torch.backends.mps.is_available():
  #   if not torch.backends.mps.is_built():
  #     torch.backends.mps.build()
  #   device = torch.device("mps")

  if not os.path.isdir("./data"):
    os.makedirs("./data")

  train_dataset = torchaudio.datasets.LIBRITTS(root="data", url=train_url, download=True)
  test_dataset = torchaudio.datasets.LIBRITTS(root="data", url=test_url, download=True)

  indices = torch.randperm(len(train_dataset))[:10000]
  train_dataset = Subset(train_dataset, indices)
  indices = torch.randperm(len(test_dataset))[:1000]
  test_dataset = Subset(test_dataset, indices)


  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = data.DataLoader(dataset=train_dataset,
                              batch_size=hparams['batch_size'],
                              shuffle=True,
                              collate_fn=lambda x: data_processing(x, 'train'),
                              **kwargs)
  test_loader = data.DataLoader(dataset=test_dataset,
                              batch_size=hparams['batch_size'],
                              shuffle=False,
                              collate_fn=lambda x: data_processing(x, 'valid'),
                              **kwargs)

  model = SpeechRecognitionModel(
    hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
    hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
  ).to(device)

  print(model)
  print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))

  optimizer = optim.AdamW(model.parameters(), hparams['learning_rate'])
  criterion = nn.CTCLoss(blank=28).to(device)
  scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=hparams['learning_rate'], 
                                          steps_per_epoch=int(len(train_loader)),
                                          epochs=hparams['epochs'],
                                          anneal_strategy='linear')
  
  for epoch in range(1, epochs + 1):
    train(model, device, train_loader, criterion, optimizer, scheduler, epoch)
    test(model, device, test_loader, criterion)

    torch.save({
      "epoch": epoch,
      "model_state_dict": model.state_dict(),
      "optimizer_state_dict": optimizer.state_dict(),
    }, "bilstmwattention.pt")

In [12]:
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

In [13]:
learning_rate = 5e-4
batch_size = 10
epochs = 1
libri_train_set = "train-clean-100"
libri_test_set = "test-clean"

main(learning_rate, batch_size, epochs, libri_train_set, libri_test_set)

100%|██████████| 7.19G/7.19G [08:35<00:00, 15.0MB/s]
100%|██████████| 1.15G/1.15G [01:24<00:00, 14.5MB/s]


SpeechRecognitionModel(
  (cnn): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (cnn_layers): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (fully_connected): Linear(in_features=2048, out_features=512, bias=True)
  (birnn_layers): Sequential(
    (0): BidirectionalLSTM(
      (BiLSTM): LSTM(512, 512, batch_first=True, bidirectional=True)
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): BidirectionalLSTM(
      (BiLSTM): LSTM(1024, 512, bidirectional=True)
      (laye

OutOfMemoryError: ignored

In [None]:
model = torch.load('bilstmwattention.pt')

In [None]:
test_dataset = torchaudio.datasets.LIBRITTS(root="data", url="test-clean", download=True)