In [32]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from glob import glob

from collections import Counter

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset
import torchaudio

import plotly.graph_objects as go
from scipy import signal

from sklearn.utils.class_weight import compute_class_weight

%config Completer.use_jedi = False

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

# Transform

In [3]:
validation_paths = {'train/audio/' + i.strip() for i in open('train/validation_list.txt', 'r').readlines()}
test_paths = {'train/audio/' + i.strip() for i in open('train/testing_list.txt', 'r').readlines()}
train_paths = set(glob('train/audio/*/*')) - validation_paths - test_paths

In [4]:
def pad(waveform, target=16000):
    size = waveform.shape[1]
    if size < target:
        yield torch.cat([waveform[:, :target-size], waveform], dim=1)
    elif size > 2 * target:
        for i in range(size//target):
            yield waveform[:, i*target:(i+1)*target]
    elif size > target:
        yield waveform[:, :target]
    else:
        yield waveform


In [72]:
def transform(waveform, sample_rate):
    mel_specgram = torchaudio.transforms.MelSpectrogram(
        sample_rate, normalized=True, win_length=400, hop_length=252, n_mels=128
    )(waveform)
    

    mel_specgram = mel_specgram[0,:,:].log2().detach().numpy()
    mel_specgram = np.clip(mel_specgram, -20, 20)
    mean = mel_specgram.mean()
    mel_specgram = (mel_specgram - mean) / mel_specgram.std() + 1e-6
    return mel_specgram.T


def transform(waveform, sample_rate):
    mel_specgram = torchaudio.transforms.MelSpectrogram(
        sample_rate, normalized=True, center=True, f_max=8000,
        win_length=400, hop_length=252, n_mels=128
        )(waveform)
    mel_specgram = mel_specgram[0,:,:]#.log().detach().numpy()
    mel_specgram = np.clip(mel_specgram, -20, 20)
    return mel_specgram.T
    
    
# taken from https://www.kaggle.com/prabhavsingh/midas-task1-final
def transform(audio, sample_rate, window_size=20, step_size=10, eps=1e-10):
    nperseg = int(round(window_size * sample_rate / 1e3))
    noverlap = int(round(step_size * sample_rate / 1e3))
    freqs, times, spec = signal.spectrogram(audio,
                                    fs=sample_rate,
                                    window='hamming',
                                    nperseg=nperseg,
                                    noverlap=noverlap,
                                    detrend=False)
    
    return np.squeeze(np.log(spec.T.astype(np.float32) + eps))


audio, sr = torchaudio.load('train/audio/cat/004ae714_nohash_0.wav')

spec = transform(audio, sr)
spec.shape

(99, 161)

In [73]:
go.Figure(go.Heatmap(z=spec.T))

In [75]:
validation_data = []
test_data = []
train_data = []
new_train_paths, new_validation_paths, new_test_paths = [], [], []

for data, paths, new_path in zip(
    [train_data, validation_data, test_data], 
    [train_paths, validation_paths, test_paths],
    [new_train_paths, new_validation_paths, new_test_paths]
):
    for path in tqdm(paths):
        try:
            waveform, sample_rate = torchaudio.load(path)
            for wave_slice in pad(waveform):
                tr = transform(wave_slice, sample_rate)
                if tr.shape != (99, 161):
                    print(path)
                    continue
                tr = torch.nan_to_num(torch.tensor(tr), nan=0, posinf=10, neginf=-10)
                data.append(tr)
                new_path.append(path)
        except Exception as err:
            print(err)
            pass


  0%|          | 0/51095 [00:00<?, ?it/s]

train/audio/on/1df99a8a_nohash_1.wav
train/audio/go/ecef25ba_nohash_0.wav
train/audio/bird/4954abe8_nohash_0.wav
train/audio/no/82c6d220_nohash_0.wav
train/audio/eight/a40c62f1_nohash_0.wav
train/audio/stop/69953f48_nohash_0.wav
train/audio/go/b01c8f61_nohash_0.wav
train/audio/one/36746d7f_nohash_0.wav
train/audio/left/b01c8f61_nohash_0.wav
train/audio/stop/82c6d220_nohash_0.wav
train/audio/seven/99081f4d_nohash_3.wav
train/audio/dog/82c6d220_nohash_0.wav
train/audio/right/a40c62f1_nohash_0.wav
train/audio/tree/c0fb6812_nohash_1.wav
train/audio/three/4254621e_nohash_0.wav
train/audio/zero/a40c62f1_nohash_0.wav
train/audio/no/a40c62f1_nohash_1.wav
Error loading audio file: failed to open file.
train/audio/nine/99081f4d_nohash_0.wav
train/audio/eight/82c6d220_nohash_0.wav
train/audio/bed/36746d7f_nohash_0.wav
train/audio/three/4954abe8_nohash_0.wav
train/audio/left/a40c62f1_nohash_0.wav
train/audio/up/c0fb6812_nohash_0.wav
train/audio/tree/a40c62f1_nohash_0.wav
train/audio/go/4954abe8_no

  0%|          | 0/6798 [00:00<?, ?it/s]

train/audio/on/22aa3665_nohash_0.wav
train/audio/stop/90804775_nohash_0.wav
train/audio/eight/90804775_nohash_0.wav
train/audio/two/794cdfc5_nohash_0.wav
train/audio/wow/22aa3665_nohash_0.wav
train/audio/dog/90804775_nohash_0.wav


  0%|          | 0/6835 [00:00<?, ?it/s]

train/audio/down/4a0e2c16_nohash_0.wav


In [76]:
open('data/train_paths.txt', 'w').writelines(
    ['/'.join(i.split('/')[2:]) + '\n' for i in list(new_train_paths)])
open('data/test_paths.txt', 'w').writelines(
    ['/'.join(i.split('/')[2:]) + '\n' for i in list(new_test_paths)])
open('data/validation_paths.txt', 'w').writelines(
    ['/'.join(i.split('/')[2:]) + '\n' for i in list(new_validation_paths)])

In [77]:
# train_data = np.nan_to_num(np.load('data/train_data.npy'), nan=0, posinf=10, neginf=-10)
# validation_data = np.nan_to_num(np.load('data/validation_data.npy'), 0, posinf=10, neginf=-10)
# test_data = np.nan_to_num(np.load('data/test_data.npy'), 0, posinf=10, neginf=-10)

train_labels = [i.strip().split('/')[0] for i in open('data/train_paths.txt').readlines()]
validation_labels = [i.strip().split('/')[0] for i in open('data/validation_paths.txt').readlines()]
test_labels = [i.strip().split('/')[0] for i in open('data/test_paths.txt').readlines()]

print(*sorted(Counter(train_labels).items(), key=lambda x: x[1]), sep='\n', end='\n\n')
print(*sorted(Counter(validation_labels).items(), key=lambda x: x[1]), sep='\n', end='\n\n')
print(*sorted(Counter(test_labels).items(), key=lambda x: x[1]), sep='\n', end='\n\n')

('_background_noise_', 398)
('bed', 1339)
('sheila', 1372)
('tree', 1372)
('happy', 1373)
('dog', 1395)
('cat', 1399)
('bird', 1410)
('wow', 1414)
('marvin', 1424)
('house', 1427)
('left', 1837)
('three', 1839)
('off', 1839)
('four', 1839)
('down', 1842)
('up', 1842)
('five', 1844)
('eight', 1850)
('no', 1851)
('right', 1851)
('go', 1858)
('yes', 1860)
('on', 1863)
('six', 1863)
('zero', 1865)
('two', 1873)
('nine', 1874)
('seven', 1874)
('stop', 1883)
('one', 1891)

('marvin', 160)
('bird', 162)
('wow', 165)
('tree', 166)
('cat', 168)
('dog', 169)
('house', 173)
('sheila', 176)
('happy', 189)
('bed', 197)
('one', 230)
('nine', 230)
('two', 235)
('eight', 242)
('five', 242)
('stop', 245)
('left', 247)
('three', 248)
('off', 256)
('on', 256)
('right', 256)
('go', 260)
('zero', 260)
('up', 260)
('yes', 261)
('six', 262)
('seven', 263)
('down', 264)
('no', 270)
('four', 280)

('house', 150)
('bird', 158)
('marvin', 162)
('wow', 165)
('cat', 166)
('bed', 176)
('dog', 180)
('happy', 180)
('

In [78]:
labels_codes = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']
def map_labels(labels):
    return [
        10 if i=='_background_noise_' 
        else labels_codes.index(i) if i in labels_codes 
        else 11
        for i in labels]

train_labels = map_labels(train_labels)
validation_labels = map_labels(validation_labels)
test_labels = map_labels(test_labels)

In [79]:
sorted(Counter(train_labels).items(), key=lambda x: x[0])

[(0, 1860),
 (1, 1851),
 (2, 1842),
 (3, 1842),
 (4, 1837),
 (5, 1851),
 (6, 1863),
 (7, 1839),
 (8, 1883),
 (9, 1858),
 (10, 398),
 (11, 32537)]

In [97]:
class SpeechDataset(Dataset):
    def __init__(self, data, label, transform=None):
#         super().__init__()
        self.data = data
        self.label = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label = self.label[idx]
        inp = self.data[idx]
        if transform:
            inp = transform(inp)
        return inp, label

In [98]:
all([i.shape[1]==128 for i in train_data])

False

In [99]:
# mask classes 10 and 11

mask = np.array(train_labels) < 10
train_data = [train_data[i] for i in range(len(train_data)) if mask[i]]
train_labels = np.array(train_labels)[mask]

mask = np.array(validation_labels) < 10
validation_data = [validation_data[i] for i in range(len(validation_data)) if mask[i]]
validation_labels = np.array(validation_labels)[mask]

mask = np.array(test_labels) < 10
test_data = [test_data[i] for i in range(len(test_data)) if mask[i]]
test_labels = np.array(test_labels)[mask]


In [100]:
weights = compute_class_weight(class_weight='balanced', classes=np.array(list(range(10))), y=train_labels)


In [84]:
trainset = SpeechDataset(data=train_data, label=train_labels, transform=pytorch_transform)
validationset = SpeechDataset(data=validation_data, label=validation_labels, transform=pytorch_transform)
testset = SpeechDataset(data=test_data, label=test_labels, transform=pytorch_transform)

In [85]:
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=32, shuffle=True, num_workers=2)
validaionloader = torch.utils.data.DataLoader(
    validationset, batch_size=32, shuffle=False, num_workers=2)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=32, shuffle=False, num_workers=2)

In [86]:
def test(net, dataloader, device='cuda:0'):
    net = net.to(device)
    conf_matrix = np.zeros((10, 10))
    for i, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        with torch.no_grad():
            pred = net(X)
            pred_cls = pred.argmax(1)
            for j in range(len(pred_cls)):
                conf_matrix[y[j]][pred_cls[j]] += 1
    return conf_matrix

In [87]:
def train(net, dataloader, validaionloader, optimizer, loss_func, epochs=100, device='cuda:0'):
    n = len(trainloader.dataset)
    net.to(device)
    train_acc = np.zeros(epochs)
    test_acc = np.zeros(epochs)
    train_loss = np.zeros(epochs)
    test_loss = np.zeros(epochs)
    
    for epoch in tqdm(range(epochs)):
        epoch_ok = 0
        for i, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            pred = net(X)
            loss = loss_func(pred, y)
            if torch.any(torch.isnan(loss)):
                breakpoint()
            train_loss[epoch] += loss.item()
            pred_cls = pred.argmax(1)
            epoch_ok += sum(pred_cls == y)
            
            loss.backward()
            optimizer.step()
        
        conf_matrix = test(net, validaionloader)    
        test_acc[epoch] = np.diag(conf_matrix).sum() / np.sum(conf_matrix)
        train_acc[epoch] = epoch_ok / n
        
        print(f'Epoch {epoch}: train loss: {train_loss[epoch]:.4f}, train acc: {train_acc[epoch]:.3f}, '
              f'test acc: {test_acc[epoch]}')
    
    return train_acc, test_acc, train_loss, test_loss

In [94]:
class SimpleLSTM(nn.Module):
    
    def __init__(self):
        super(SimpleLSTM, self).__init__()
        self.conv1 = nn.Conv1d(161, 128, kernel_size=5, padding=2)
        self.conv = nn.Conv1d(128, 128, kernel_size=5, padding=2)
        self.bn = nn.BatchNorm1d(128)
        
        self.lstm = nn.GRU(input_size=128, hidden_size=128, num_layers=2, dropout=0.5, bidirectional=True)
        self.linear = nn.Sequential(
            nn.Linear(256, 32),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(32, 12)
        )
        
    
    def forward(self, x):
        x = x.to('cuda')
        x = torch.swapaxes(x, 1, 2) # (B, L, C) -> (B, C, L) ; L = length of the sequence, C = num of channels
        x = F.relu(self.bn(self.conv1(x)))
        x = F.relu(self.bn(self.conv(x)))
        x = F.relu(self.bn(self.conv(x)))
        x = F.relu(self.bn(self.conv(x)))
        x = torch.swapaxes(x, 1, 2)
    
        x, _ = self.lstm(x)
        x = x[:, -1, :]  # get logit for last time interval
        x = self.linear(x)
        return x
    

In [95]:
# loss_func = nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float).to(device))
loss_func = nn.CrossEntropyLoss()
net = SimpleLSTM().to(device)
sgd = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# sgd = optim.SGD(net.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(sgd, step_size=30, gamma=0.5, verbose=True)

Adjusting learning rate of group 0 to 1.0000e-02.


In [96]:
train_acc, test_acc, train_loss, test_loss = train(
    net, trainloader, validaionloader, sgd, loss_func, 300)

  0%|          | 0/300 [00:00<?, ?it/s]

Epoch 0: train loss: 1357.5384, train acc: 0.096, test acc: 0.09514563106796116
Epoch 1: train loss: 1343.9469, train acc: 0.098, test acc: 0.10097087378640776
Epoch 2: train loss: 1340.3894, train acc: 0.101, test acc: 0.10174757281553398


KeyboardInterrupt: 

In [None]:
go.Figure(
    [go.Scatter(y=test_acc), go.Scatter(y=train_acc)]
)