In [1]:
from tqdm import tqdm
import random
import math

import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import torchaudio
from torchaudio.transforms import MelSpectrogram

from PIL import Image
import Levenshtein as Lev

import matplotlib.pyplot as plt

from IPython.display import Audio, display

# Dataset

In [2]:
def play_audio(waveform, sample_rate):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    if num_channels == 1:
        display(Audio(waveform[0], rate=sample_rate))
    elif num_channels == 2:
        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")

In [3]:
train_data = torchaudio.datasets.SPEECHCOMMANDS('./SpeechCommands', download = True, subset = "training")
val_data = torchaudio.datasets.SPEECHCOMMANDS('./SpeechCommands', download = True, subset = "validation")

In [4]:
class CommandsDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data
        self.feature = MelSpectrogram(n_mels = 64)
        self.tokenizer = {chr(ord("a") + i): i + 3 for i in range(26)}
        self.tokenizer['PAD'] = 0
        self.tokenizer['SOS'] = 1
        self.tokenizer['EOS'] = 2
        
        self.inv_tokenizer = {value: key for key, value in self.tokenizer.items()}
        
    def __encode(self, word):
        result = [1]
        for w in word:
            result.append(self.tokenizer[w])
        result.append(2)
        return torch.tensor(result)
    
    def decode(self, tokens):
        word = []
        for t in tokens.tolist():
            word.append(self.inv_tokenizer[t])
        return "".join(word)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        waveform, sample_rate, target, _, _ = self.data[index]
        
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        features = self.feature(waveform)
        
        return features[0], self.__encode(target)

In [5]:
train_dataset = CommandsDataset(train_data)
val_dataset = CommandsDataset(val_data)

In [6]:
len(train_dataset.tokenizer)

29

In [7]:
def collate_fn(batch):
    audio = []
    audio_len = []
    text = []
    for sample in batch:
        audio.append(sample[0].transpose(0, 1))
        audio_len.append(audio[-1].size(0))
        text.append(sample[1])
    audio = torch.nn.utils.rnn.pad_sequence(audio, batch_first = True)
    audio_mask = torch.ones((audio.size(0), audio.size(1)))
    for k, length in enumerate(audio_len):
        audio_mask[k, :length] = 0
    
    text = torch.nn.utils.rnn.pad_sequence(text, batch_first = True)
    text_mask = torch.ones(text.size())
    text_mask[text>0] = 0
    
    return audio, audio_mask, text, text_mask

In [8]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 16, shuffle = True, collate_fn = collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 1, shuffle = False, collate_fn = collate_fn)

In [9]:
batch = next(iter(train_loader))

In [10]:
audio, audio_mask, text, text_mask = batch

In [11]:
print(audio.size())
print(audio_mask.size())
print(text.size())
print(text_mask.size())

torch.Size([16, 81, 64])
torch.Size([16, 81])
torch.Size([16, 7])
torch.Size([16, 7])


# Model

In [12]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x) :
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [13]:
class STT(nn.Module):
    def __init__(self):
        super(STT, self).__init__()
        
        self.prepare = torch.nn.Sequential(
                torch.nn.Linear(64, 128),
                torch.nn.LayerNorm(128),
                torch.nn.Dropout(0.1),
                torch.nn.ReLU()
            )
        self.embeddings = nn.Embedding(29, 128)
        self.pos_encoder_src = PositionalEncoding(128, 0.1)
        self.pos_encoder_trg = PositionalEncoding(128, 0.1)
        self.transformer = torch.nn.Transformer(d_model=128, nhead=4, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=512, dropout=0.1, batch_first=True)
        
        self.last_block = nn.Linear(128, 29)
    
    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)).bool()
        return mask
    
    def forward(self, audio, audio_padding_mask, text, text_padding_mask):
        text_mask_nopeak = self.generate_square_subsequent_mask(text.size(1)).cuda()
        
        audio = self.prepare(audio)
        text = self.embeddings(text)
        
        audio = self.pos_encoder_src(audio)
        text = self.pos_encoder_trg(text)
        
        out = self.transformer(src = audio, tgt = text, tgt_mask = text_mask_nopeak,
                               src_key_padding_mask=audio_padding_mask, tgt_key_padding_mask=text_padding_mask)
        
        out = self.last_block(out)
        
        return out
    
    def encode(self, src, src_mask):
        audio = self.prepare(src)
        audio = self.pos_encoder_src(audio)
        return self.transformer.encoder(audio)

    def decode(self, tgt, memory, tgt_mask):
        text = self.embeddings(tgt)
        text = self.pos_encoder_trg(text)
        
        return self.transformer.decoder(text, memory, tgt_mask)

In [109]:
model = STT().cuda()
criterion = nn.CrossEntropyLoss(ignore_index = 0)
optimazer = torch.optim.Adam(model.parameters(), lr = 0.001, betas = (0.9, 0.99))
scheduler = torch.optim.lr_scheduler.LinearLR(optimazer, start_factor = 1, end_factor = 0.01, total_iters=10)
writer = SummaryWriter()
global_step = 1

In [110]:
# model(audio, audio_mask, text, text_mask).size()

In [111]:
# audio = audio.cuda()
# audio_mask = audio_mask.cuda()
# text = text.cuda()
# text_mask = text_mask.cuda()

# for i in range(100):
#     optimazer.zero_grad()
    
#     out = model(audio, audio_mask, text[:,:-1], text_mask[:,:-1])
# #     print(out.size())
# #     print(text.size())
#     loss = criterion(out.transpose(1,2), text[:,1:])
#     loss.backward()
#     optimazer.step()
#     print(loss.item())
    

In [112]:
train_dataset.tokenizer["t"]

22

In [113]:
print(train_dataset.decode(text[1]))

SOSthreeEOS


In [114]:
# model.eval()
# memory = model.encode(audio[1:2], audio_mask)
# ys = torch.ones(1, 1).fill_(1).type(torch.long).cuda()
# while True:
#     print(ys.size())
#     tgt_mask = model.generate_square_subsequent_mask(ys.size(1)).cuda()
#     out = model.decode(ys, memory, tgt_mask)
#     prob = model.last_block(out[:, -1])
#     _, next_word = torch.max(prob, dim=1)
#     next_word = next_word.item()
#     ys = torch.cat([ys.cpu(),torch.ones(1, 1).type(torch.long).fill_(next_word)], dim=1).type(torch.long).cuda()
#     if next_word == 2:
#         break
# print(train_dataset.decode(ys[0]))

In [115]:
def train(model, loader, criterion, optimazer):
    global global_step
    model.train()
    for audio, audio_mask, text, text_mask in tqdm(loader):
        
        audio = audio.cuda()
        audio_mask = audio_mask.cuda()
        text = text.cuda()
        text_mask = text_mask.cuda()
        
        optimazer.zero_grad()
        
        out = model(audio, audio_mask, text[:,:-1], text_mask[:,:-1])
        loss = criterion(out.transpose(1,2), text[:,1:])
        
        loss.backward()
        optimazer.step()
        
        writer.add_scalar('train/Loss', loss.item(), global_step)
        global_step += 1

In [116]:
def predict(model, audio, audio_mask):
    step = 1
    memory = model.encode(audio, audio_mask)
    ys = torch.ones(1, 1).fill_(1).type(torch.long).cuda()
    while step < 15:
#         print(ys.size())
        tgt_mask = model.generate_square_subsequent_mask(ys.size(1)).cuda()
        out = model.decode(ys, memory, tgt_mask)
        prob = model.last_block(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys.cpu(),torch.ones(1, 1).type(torch.long).fill_(next_word)], dim=1).type(torch.long).cuda()
        step += 1
        if next_word == 2:
            break
    return train_dataset.decode(ys[0])

def val(model, loader, criterion):
    global global_step
    model.eval()
    acc = []
    for audio, audio_mask, text, text_mask in tqdm(loader):
        
        audio = audio.cuda()
        audio_mask = audio_mask.cuda()
        text = text.cuda()
        text_mask = text_mask.cuda()
        result = predict(model, audio, audio_mask)
        target = train_dataset.decode(text[0]) 
#         print(result, target)
        acc.append(Lev.distance(target, result)/len(target))
    writer.add_scalar('val/cer', sum(acc)/len(acc), global_step)
    print("Mean CER:",sum(acc)/len(acc))
    

In [117]:
# train(model, train_loader, criterion, optimazer)

In [118]:
# val(model, val_loader, criterion)

In [120]:
# val(model, val_loader, criterion)

In [None]:
for epoch in range(1, 11):
    train(model, train_loader, criterion, optimazer)
    val(model, val_loader, criterion)
    scheduler.step()

100%|██████████| 5303/5303 [01:57<00:00, 45.32it/s]
100%|██████████| 9981/9981 [01:43<00:00, 96.63it/s] 


Mean CER: 0.18468258383346955


100%|██████████| 5303/5303 [01:56<00:00, 45.63it/s]
100%|██████████| 9981/9981 [01:46<00:00, 93.96it/s] 


Mean CER: 0.14497215066045796


100%|██████████| 5303/5303 [01:56<00:00, 45.49it/s]
100%|██████████| 9981/9981 [01:45<00:00, 94.38it/s] 


Mean CER: 0.11635179268182068


100%|██████████| 5303/5303 [01:58<00:00, 44.81it/s]
100%|██████████| 9981/9981 [01:45<00:00, 94.60it/s] 


Mean CER: 0.1058312455892851


100%|██████████| 5303/5303 [01:56<00:00, 45.66it/s]
100%|██████████| 9981/9981 [01:45<00:00, 94.18it/s] 


Mean CER: 0.0839786153131509


100%|██████████| 5303/5303 [01:55<00:00, 45.74it/s]
 83%|████████▎ | 8300/9981 [01:28<00:15, 105.38it/s]

In [124]:
def ErrorAnalys(model, loader, criterion):
    pos = 0
    neg = 0
    global global_step
    model.eval()
    acc = []
    for audio, audio_mask, text, text_mask in tqdm(loader):
        
        audio = audio.cuda()
        audio_mask = audio_mask.cuda()
        text = text.cuda()
        text_mask = text_mask.cuda()
        result = predict(model, audio, audio_mask)
        target = train_dataset.decode(text[0]) 
        if result == target:
            pos += 1
        else:
            neg += 1
        acc.append(Lev.distance(target, result)/len(target))
    print("Mean CER:",sum(acc)/len(acc))
    print("Pos", pos/(pos+neg))
    print("Neg", neg/(pos+neg))

In [125]:
ErrorAnalys(model, val_loader, criterion)

100%|██████████| 9981/9981 [01:46<00:00, 93.86it/s] 

Mean CER: 0.05821697894886914
Pos 0.8478108405971345
Neg 0.15218915940286545



