In [1]:
### Scripts to analyze audio

In [2]:
BASEDIR="/usr2/asetlur/GraphNeuralTTS/Tacotron-pytorch/training-accentdb-char-baseline-with-additional-info/"
LOGDIR="log"

### Mel-spectogram classifier

In [14]:
from torch.utils.data import Dataset
import torch
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import glob
import os
from collections import Counter
from tensorboardX import SummaryWriter
from tqdm import tqdm
import subprocess

#### Dataloader/Dataset

In [4]:
def _pad_2d(x, max_len):
    x = np.pad(x, [(0, max_len - len(x)), (0, 0)],
               mode="constant", constant_values=0)
    return x

class MelDataset(Dataset):
    
    def __init__(self, pth):
        self.mel_files = glob.glob(f"{BASEDIR}/*mel*")
        print(f"{len(self.mel_files)} mel-files found")
        # the files are supposed to be named accent_speaker_*.wav, 
        # e.g. australian_s02_362.wav
        self.labels = [os.path.basename(mel_file_pth).split("_")[:2] for mel_file_pth in self.mel_files]
        labels = [os.path.basename(mel_file_pth).split("_")[:2] for mel_file_pth in self.mel_files]
        
        # get accent labels, make a dict
        self.accent_labels = [l[0] for l in self.labels]
        self.accent_label_dict = {k: i for i, k in enumerate(sorted(Counter(self.accent_labels).keys()))}
        print(self.accent_label_dict)
        
        # same processing for the speakers
        self.speaker_labels = [" ".join(l) for l in self.labels]
        self.speaker_label_dict = {k: i for i, k in enumerate(sorted(Counter(self.speaker_labels).keys()))}
        print(self.speaker_label_dict)

    def __getitem__(self, i):
        return np.load(self.mel_files[i])
    
    def __len__(self):
        return len(self.mel_files)

    @staticmethod
    def batchify(dataset, bsz, shuffle=True):
        idx = list(range(len(dataset)))
        if shuffle:
            np.random.shuffle(idx)

        for begin in range(0, len(dataset), bsz):
            end = min(begin + bsz, len(dataset))
            num_elems = end - begin
    
            
            # read all the mels for this batch, find the max length
            mels = [dataset[idx[i]] for i in range(begin, end)]
            seq_lengths = torch.LongTensor([len(mel) for mel in mels])
            max_target_len = seq_lengths.max().item()
            
            
            b = np.array([_pad_2d(mel, max_target_len) for mel in mels],
                 dtype=np.float32)
            mel_batch = torch.FloatTensor(b)
            speaker_labels = torch.LongTensor([dataset.speaker_label_dict[dataset.speaker_labels[idx[i]]]\
                                               for i in range(begin, end)])
            accent_labels = torch.LongTensor([dataset.accent_label_dict[dataset.accent_labels[idx[i]]]\
                                              for i in range(begin, end)])
            
            
            seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
            
            
            yield mel_batch[perm_idx], speaker_labels[perm_idx], accent_labels[perm_idx], seq_lengths
            
            

In [5]:
# dataset sanity checks
dataset = MelDataset(BASEDIR)

# check shape of one mel file
print(dataset[0].shape)

# dataloader check
dataloader = MelDataset.batchify(dataset, 32)

# check two batches for correct batchification:
for _ in range(2):
    mel, speaker, accent, input_lengths = next(dataloader)
    print(mel.shape, speaker.shape, accent.shape)

14999 mel-files found
{'american': 0, 'australian': 1, 'bangla': 2, 'british': 3, 'indian': 4, 'malayalam': 5, 'odiya': 6, 'telugu': 7, 'welsh': 8}
{'american s01': 0, 'american s02': 1, 'american s03': 2, 'american s04': 3, 'american s05': 4, 'american s06': 5, 'american s07': 6, 'american s08': 7, 'australian s01': 8, 'australian s02': 9, 'bangla s01': 10, 'bangla s02': 11, 'british s01': 12, 'british s02': 13, 'indian s01': 14, 'indian s02': 15, 'malayalam s01': 16, 'malayalam s02': 17, 'malayalam s03': 18, 'odiya s01': 19, 'telugu s01': 20, 'telugu s02': 21, 'welsh s01': 22}
(154, 80)
torch.Size([32, 379, 80]) torch.Size([32]) torch.Size([32])
torch.Size([32, 437, 80]) torch.Size([32]) torch.Size([32])


#### Model

In [56]:
def get_1dconv(in_channels, out_channels, max_pool=False):
    return nn.Sequential(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1),
                      nn.ELU(),
                      nn.BatchNorm1d(out_channels),
                      nn.MaxPool1d(3, stride=2) if max_pool else nn.Identity(),
                      nn.Dropout(p=0.1))

In [67]:
class MelClassifier(nn.Module):
    def __init__(self, 
                 num_class,
                 mel_spectogram_dim: int = 80,
                 gru_hidden_size=32,
                 gru_num_layers=2):
        super(MelClassifier, self).__init__()
        self.num_class = num_class
        self.conv_blocks = nn.Sequential(
                        get_1dconv(in_channels=mel_spectogram_dim, out_channels=64),
                        get_1dconv(in_channels=64, out_channels=128),
                        get_1dconv(in_channels=128, out_channels=128, max_pool=True),
                        get_1dconv(in_channels=128, out_channels=128, max_pool=True),
                        get_1dconv(in_channels=128, out_channels=128, max_pool=True))
            
        self.gru = nn.GRU(input_size=128, hidden_size=gru_hidden_size, num_layers=gru_num_layers,\
                          bidirectional=True, batch_first=True, dropout=0.3)
        num_directions = 2
        self.mlp = nn.Linear(gru_hidden_size * gru_num_layers * num_directions, self.num_class)

    def forward(self, mel_batch, input_lengths):
        batch_size = len(mel_batch)
        # mel_batch -> (batch_size, max_time_step, 80)
        conv_output = self.conv_blocks(mel_batch.permute(0, 2, 1)).permute(0, 2, 1)
        # conv_output -> (batch_size, max_time_step, 32)

        output, h_n = self.gru(conv_output)
        # h_n -> (4, batch_size, 32)
        
        h_n = h_n.permute(1, 0, 2).reshape(batch_size, -1)
        return h_n, self.mlp(h_n)
        

#### Training loop

In [73]:
device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

In [74]:
model = MelClassifier(len(dataset.speaker_label_dict)).to(device)

In [75]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3,
                                    betas=(0.9, 0.99),
                                    eps=1e-6,
                                    weight_decay=0.01)

In [76]:
writer = SummaryWriter(LOGDIR)

In [77]:
subprocess.call(["rm", "-r", f"{LOGDIR}/*"])

1

In [83]:
num_epochs = 10
dataset = MelDataset(BASEDIR)
loss_func = nn.CrossEntropyLoss()

losses = []
accuracy = []
for epoch in range(num_epochs):
    
    dataloader = MelDataset.batchify(dataset, 32)
    
    # training
    for i, (mels, speakers, accents, input_lengths) in enumerate(dataloader):
        mels = mels.to(device)
        speakers = speakers.to(device)
        accents = accents.to(device)
        optimizer.zero_grad()
        
        h_n, logits = model(mels, input_lengths)
        loss = loss_func(logits, speakers).mean()
        accuracy.append(sum(torch.argmax(logits, dim=1) == speakers).item() * 100. / len(accents))
    
        
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if i % 50 == 0:
            print(f"Epoch = {epoch} iter = {i} Loss = {round(np.array(losses).mean(), 2)} Acc = {round(np.array(accuracy).mean(), 2)}")
            losses = []
            
    metadata_speaker = []
    metadata_accent = []    
    dataloader = MelDataset.batchify(dataset, 32)

    # extract embedding for tboard
    print("Extracting Embedding")
    for i, (mels, speakers, accents, input_lengths) in tqdm(enumerate(dataloader), total=len(dataloader)):

        metadata_accent += speakers.numpy().tolist()
        metadata_accent += accents.numpy().tolist()
        mels = mels.to(device)
        speakers = speakers.to(device)
        accents = accents.to(device)

        with torch.no_grad():
            h_n, logits = model(mels, input_lengths)
            writer.add_embedding(tag="accent",
                                mat=h_n,
                                global_step=epoch,
                                metadata=metadata_accent)
            writer.add_embedding(tag="speaker",
                    mat=h_n,
                    global_step=epoch,
                    metadata=metadata_speaker)

14999 mel-files found
{'american': 0, 'australian': 1, 'bangla': 2, 'british': 3, 'indian': 4, 'malayalam': 5, 'odiya': 6, 'telugu': 7, 'welsh': 8}
{'american s01': 0, 'american s02': 1, 'american s03': 2, 'american s04': 3, 'american s05': 4, 'american s06': 5, 'american s07': 6, 'american s08': 7, 'australian s01': 8, 'australian s02': 9, 'bangla s01': 10, 'bangla s02': 11, 'british s01': 12, 'british s02': 13, 'indian s01': 14, 'indian s02': 15, 'malayalam s01': 16, 'malayalam s02': 17, 'malayalam s03': 18, 'odiya s01': 19, 'telugu s01': 20, 'telugu s02': 21, 'welsh s01': 22}
Epoch = 0 iter = 0 Loss = 0.15 Acc = 100.0


KeyboardInterrupt: 

In [85]:
print("Extracting Embedding")
dataloader = MelDataset.batchify(dataset, 32)


for i, (mels, speakers, accents, input_lengths) in tqdm(enumerate(dataloader), total=len(dataset) // 32):
    
    metadata_accent += speakers.numpy().tolist()
    metadata_accent += accents.numpy().tolist()
    mels = mels.to(device)
    speakers = speakers.to(device)
    accents = accents.to(device)
    
    with torch.no_grad():
        h_n, logits = model(mels, input_lengths)
        
writer.add_embedding(tag="accent",
                    mat=h_n,
                    global_step=epoch,
                    metadata=metadata_accent)
writer.add_embedding(tag="speaker",
        mat=h_n,
        global_step=epoch,
        metadata=metadata_speaker)

  0%|          | 0/468 [00:00<?, ?it/s]

Extracting Embedding





AssertionError: #labels should equal with #data points