### Install requirements

In [None]:
!pip install pytorch-lightning

In [None]:
!pip install torchaudio-augmentations

### Import libs

In [None]:
from pathlib import Path
import librosa
import random
from IPython.display import Audio
import os
import numpy as np
import torch
import torchaudio.transforms as T
import soundfile as sf
import torchaudio
from tqdm.notebook import trange
from torchaudio_augmentations import *
# Root folder of the project data: the source recordings, the generated
# segment folders and the exported model are all addressed relative to it.
# NOTE(review): the path looks like a mounted Google Drive in Colab — confirm
# the drive is mounted before running this cell.
PATH = Path('/content/drive/MyDrive/wakeup_word')

# Show the folder contents as the cell output (sanity check that it is reachable).
os.listdir(PATH)

### Set up paths

In [None]:
# Folders holding the one-second clips of each class.
WAKEUP = PATH / 'wakeup'
NO_WAKEUP = PATH / 'no_wakeup'

# Source recordings the clips are cut from.
wakeup_file = PATH / 'thanos_message.wav'
no_wakeup_file = PATH / 'thanos_message_no_wakeup.wav'

# Sample rate (Hz) the clips are resampled to.
SAVE_SR = 16_000

# Number of clips drawn per class.
num_segments = 1_024

### Load wakeup word file

In [None]:
# Load the wakeup-word recording; torchaudio.load returns the waveform and its sample rate.
waveform, sample_rate = torchaudio.load(wakeup_file)

# Last expression of the cell: renders an inline audio player for the recording.
Audio(data=waveform, rate=sample_rate)

# Randomly sample and cut no-wakeup-word parts

In [None]:
# Draw `num_segments` random one-second windows from the no-wakeup recording
# and resample each of them to SAVE_SR.
wav, sr = torchaudio.load(no_wakeup_file)

cut_size = sr * 1  # window length in samples: one second at the native rate
max_start = wav.shape[1] - cut_size
start_times = [random.randint(0, max_start) for _ in range(num_segments)]

resampler = T.Resample(sr, SAVE_SR, dtype=wav.dtype)

for i in range(len(start_times)):
    start = start_times[i]
    end = start + cut_size
    segment = wav[:, start:end]
    resampled_waveform = resampler(segment)

    save_path = NO_WAKEUP / f"segment_{i}.wav"
    # Writing to disk is disabled; uncomment to (re)generate the clips.
    # torchaudio.save(save_path, resampled_waveform, SAVE_SR)

### Define audio augmentations

In [None]:
# Probability with which each individual augmentation is applied.
_AUG_PROBABILITY = 0.1

# The augmentations, in the order they are chained.
_augmentations = [
    PolarityInversion(),
    Noise(min_snr=0.1, max_snr=0.15),
    Gain(),
    Delay(sample_rate=SAVE_SR),
    PitchShift(n_samples=SAVE_SR // 4, sample_rate=SAVE_SR),
    Reverb(sample_rate=SAVE_SR),
]

# Wrap every augmentation in its own RandomApply, then chain them.
_rand_transforms = [
    RandomApply([augmentation], p=_AUG_PROBABILITY)
    for augmentation in _augmentations
]

transformator = Compose(transforms=_rand_transforms)

## Sample wakeup word with random cuts

In [None]:
# Number of wakeup clips to generate (same value as in the path set-up cell).
num_segments=1024

In [None]:
# Reload the wakeup recording and bring it to the working sample rate.
waveform, sample_rate = torchaudio.load(wakeup_file)
# waveform = waveform[:,:sample_rate]

resampler = T.Resample(sample_rate, SAVE_SR)
waveform = resampler(waveform)

# Random one-second windows over the resampled recording.
cut_size = SAVE_SR * 1
max_start = waveform.shape[1] - cut_size
start_times = [random.randint(0, max_start) for _ in range(num_segments)]

for i, _st in zip(trange(num_segments), start_times):
  _end = _st + cut_size
  clip = waveform[:, _st:_end]

  # Run the clip through the random augmentation chain; the length must stay one second.
  transformed_audio = transformator(clip)
  assert len(transformed_audio[0]) == SAVE_SR

  save_path = WAKEUP / f"segment_{i}.wav"
  # Writing to disk is disabled; uncomment to (re)generate the clips.
  # torchaudio.save(save_path, transformed_audio, SAVE_SR)

# Create Dataset

In [None]:
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np

class WakeupWordDataset(Dataset):
    """Dataset that turns audio files into mel-spectrogram / label pairs.

    Args:
        audio_files: sequence of paths accepted by ``torchaudio.load``.
        labels: sequence of numeric labels, aligned with ``audio_files``.
        n_mels: number of mel bands of the spectrogram (default 80).
        num_samples: only the first ``num_samples`` samples of every file
            are used (default 16000).
    """

    def __init__(self, audio_files, labels, n_mels=80, num_samples=16000):
        self.audio_files = audio_files
        self.labels = labels
        self.n_mels = n_mels
        self.num_samples = num_samples
        # One MelSpectrogram per sample rate seen so far, so the transform
        # is not rebuilt for every single item.
        self._mel_transforms = {}

    def _mel_transform(self, sample_rate):
        """Return the (cached) mel-spectrogram transform for ``sample_rate``."""
        transform = self._mel_transforms.get(sample_rate)
        if transform is None:
            transform = T.MelSpectrogram(
                sample_rate=sample_rate,
                n_mels=self.n_mels,
            )
            self._mel_transforms[sample_rate] = transform
        return transform

    def __getitem__(self, index):
        """Load item ``index`` and return ``(spectrogram, label)``.

        The spectrogram is squeezed and then given a new leading dimension;
        the label is a one-element float tensor.
        """
        audio, sample_rate = torchaudio.load(self.audio_files[index])
        mel = self._mel_transform(sample_rate)(audio[:, :self.num_samples]).squeeze()
        label = torch.Tensor([self.labels[index]])
        return mel.unsqueeze(0).float(), label

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.audio_files)

In [None]:
# Collect and shuffle the clip files of both classes.
wakeup_names = os.listdir(WAKEUP)
no_wakeup_names = os.listdir(NO_WAKEUP)

random.shuffle(wakeup_names)
random.shuffle(no_wakeup_names)

wakeup_names = [WAKEUP/name for name in wakeup_names]
no_wakeup_names = [NO_WAKEUP/name for name in no_wakeup_names]

# Per-class split sizes (67% train, the rest test).
train_size = int(num_segments*0.67)
test_size = num_segments-train_size


def _split_train_test(names):
    """Split one class into (train, test) lists that never overlap.

    The test part is taken from what is left after the train part, so a
    folder with fewer than `num_segments` files cannot leak train files
    into the test set.
    """
    train_part = names[:train_size]
    remainder = names[train_size:]
    # A slice of [-0:] would return the whole list, hence the guard.
    test_part = remainder[-test_size:] if test_size > 0 else []
    return train_part, test_part


wakeup_train, wakeup_test = _split_train_test(wakeup_names)
no_wakeup_train, no_wakeup_test = _split_train_test(no_wakeup_names)

# Labels are derived from the actual slice lengths so that they always stay
# aligned with the file lists (1 = wakeup, 0 = no wakeup).
train_names = wakeup_train + no_wakeup_train
train_labels = [1]*len(wakeup_train) + [0]*len(no_wakeup_train)

test_names = wakeup_test + no_wakeup_test
test_labels = [1]*len(wakeup_test) + [0]*len(no_wakeup_test)

In [None]:
batch_size = 16

# Options shared by both loaders.
_loader_options = dict(batch_size=batch_size, num_workers=2)

train_dataset = WakeupWordDataset(train_names, train_labels)
test_dataset = WakeupWordDataset(test_names, test_labels)

# Only the training loader shuffles its samples.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# Visualization

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(true_labels, predicted_labels):
    """Plot a 2x2 confusion matrix for binary labels (0 = negative, 1 = positive).

    Args:
        true_labels: sequence of ground-truth labels (0 or 1).
        predicted_labels: sequence of predicted labels (0 or 1).
    """
    class_names = ['Negative', 'Positive']

    # Fix the label set explicitly: otherwise the matrix shrinks to 1x1 when
    # only one class occurs, and the two tick labels below no longer fit.
    cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])

    fig, ax = plt.subplots()
    ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=class_names,
           yticklabels=class_names,
           title='Confusion matrix',
           ylabel='True label',
           xlabel='Predicted label')

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Write the count into every cell; use white text on the dark cells.
    half_max = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > half_max else "black")

    fig.tight_layout()
    plt.show()


# Define Model

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torch.nn as nn


class WakeupWordCNN(pl.LightningModule):
    """Small CNN that classifies a spectrogram as wakeup word / no wakeup word.

    ``forward`` returns a ``(batch, 2)`` tensor of sigmoid outputs. Only
    column 1 is used by the loss and by the predictions: it is treated as
    the probability of the positive (wakeup) class.

    The fully connected layer expects a 37x37 feature map after the 7x7
    convolution (no padding) and the 2x2 max pooling, i.e. an input of
    80 or 81 bins along each of the last two dimensions.
    """

    def __init__(self):
        super(WakeupWordCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=7, padding=0)
        self.bn1 = nn.BatchNorm2d(8)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(8 * 37 * 37, 2)

        # Per-sample correctness flags (1.0 / 0.0) collected during an epoch.
        self.training_step_accs = []
        self.val_step_accs = []

        # Predictions and ground truth collected during the test loop.
        self.test_preds=[]
        self.test_labels=[]

    def forward(self, x):
        """Map a ``(batch, 1, H, W)`` input to ``(batch, 2)`` sigmoid outputs."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool(x)

        x = x.view(-1,8*37 * 37)
        x = self.fc1(x)
        x = torch.sigmoid(x)
        return x

    def _shared_step(self, batch):
        """Run the model on a batch and return ``(loss, preds, targets)``.

        ``preds`` are the 0/1 decisions obtained by thresholding the
        positive-class output at 0.5; ``targets`` are the labels flattened
        to one dimension.
        """
        mfcc, labels = batch
        outputs = self(mfcc)
        probs = outputs[:,1]
        # reshape(-1) keeps the batch dimension even for a batch of one
        # (a plain squeeze() would produce a 0-dim tensor there).
        targets = labels.reshape(-1)
        loss = F.binary_cross_entropy(probs, targets)
        preds = (probs > 0.5).float()
        return loss, preds, targets

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; record per-sample accuracy."""
        loss, preds, targets = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.training_step_accs.extend((preds == targets).float().tolist())
        return loss

    def on_train_epoch_start(self):
        """Reset the accuracy buffer so every epoch is measured on its own."""
        self.training_step_accs = []

    def on_train_epoch_end(self):
        """Log the mean training accuracy of the finished epoch."""
        if not self.training_step_accs:
            return
        acc = sum(self.training_step_accs) / len(self.training_step_accs)
        self.log('train_acc_epoch', acc, on_step=False, on_epoch=True)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss; record per-sample accuracy."""
        loss, preds, targets = self._shared_step(batch)
        self.val_step_accs.extend((preds == targets).float().tolist())
        self.log('val_loss', loss, on_step=True, on_epoch=True)

    def on_validation_epoch_start(self):
        """Reset the validation accuracy buffer."""
        self.val_step_accs = []

    def on_validation_epoch_end(self):
        """Log the mean validation accuracy of the finished epoch."""
        if not self.val_step_accs:
            return
        acc = sum(self.val_step_accs) / len(self.val_step_accs)
        self.log('val_acc_epoch', acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Collect predictions and labels for the confusion matrix.

        Predictions use the same 0.5 threshold on output column 1 as the
        training and validation steps. An argmax over both columns would
        compare against column 0, which never receives a training signal.
        """
        _, preds, targets = self._shared_step(batch)
        self.test_preds.extend(preds.int().tolist())
        self.test_labels.extend(targets.int().tolist())

    def on_test_epoch_start(self):
        """Reset the buffers filled by ``test_step``."""
        self.test_preds=[]
        self.test_labels=[]

    def on_test_epoch_end(self):
        """Plot the confusion matrix of the whole test run."""
        plot_confusion_matrix(self.test_labels, self.test_preds)

    def configure_optimizers(self):
        """Use AdamW with a learning rate of 5e-4."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=5e-4)
        return optimizer

# Train

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir lightning_logs/

In [None]:
# Fresh, untrained model.
model = WakeupWordCNN()

# Train for at most 10 epochs.
# NOTE(review): test_loader is passed in the position of the validation
# dataloader of `Trainer.fit`, so the test split also serves as the
# validation set — confirm this is intended.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(
    model,
    train_loader,
    test_loader,
)

### Test

In [None]:
# Run the test loop on the test split; the model's on_test_epoch_end hook
# plots the confusion matrix of the collected predictions.
trainer.test(
    model,
    test_loader
)

### Save scripted model

In [None]:
# Export the trained model to TorchScript (method="script") and write it to
# model.pt inside the project folder.
script = model.to_torchscript(file_path=PATH/"model.pt", method="script")

### Load model

In [None]:
import torch
from pathlib import Path

# Reload the exported TorchScript model from the project folder and switch
# it to evaluation mode (affects layers such as BatchNorm).
PATH = Path('/content/drive/MyDrive/wakeup_word')
model = torch.jit.load(PATH/"model.pt")
model.eval()

# Listen to the radio and catch wakeup words

In [None]:
import torch
from pathlib import Path
import requests
import io
import time
import torchaudio
from IPython.display import Audio

# Self-contained reload of the exported TorchScript model for the streaming
# part of the notebook.
PATH = Path('/content/drive/MyDrive/wakeup_word')
model = torch.jit.load(PATH/"model.pt")
model.eval()

# Move the model to the GPU (this call requires an available CUDA device).
model.cuda()

In [None]:
def process_audio(wav, sample_rate, hop=2000, threshold=0.95):
    """Cuts wav with sliding window and makes predictions.

    Every window is `sample_rate` samples long (one second); consecutive
    windows start `hop` samples apart. For each window whose positive-class
    output exceeds `threshold`, the window index is printed and an audio
    player for the window is displayed.

    Args:
        wav: 2-D tensor, windows are cut along its last dimension.
            NOTE(review): a multi-channel `wav` yields a multi-channel
            spectrogram — confirm the model accepts it.
        sample_rate: sample rate of `wav`; also the window length in samples.
        hop: step between window starts, in samples.
        threshold: minimum positive-class output that counts as a detection.
    """
    # unfold raises when the signal is shorter than a single window.
    if wav.shape[-1] < sample_rate:
        return

    # Run on whatever device the (global) model currently lives on.
    device = next(model.parameters()).device

    windows = wav.unfold(-1, sample_rate, hop)

    mfcc_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_mels=80,
    )

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for i in range(windows.shape[1]):
            _cur_wav = windows[:, i, :]

            mfcc = mfcc_transform(_cur_wav).float().unsqueeze(0).to(device)
            out = model(mfcc)
            if out[:, 1].item() > threshold:
                print(i)
                display(Audio(data=_cur_wav, rate=sample_rate))

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# URL of the audio stream to listen to.
audio_url = 'https://***/rario_stream'

# Multiplier for the byte threshold at which the buffered stream data is
# decoded and processed (see the streaming loop).
chunk_duration = 10
# NOTE(review): stream_sample_rate is not referenced anywhere else in the
# visible notebook — confirm whether the resampler below should use it.
stream_sample_rate = 44100
# In-memory buffer that accumulates the raw bytes of the stream.
output_buffer = io.BytesIO()
# Placeholder source rate; the streaming loop overwrites `sr` with the rate
# reported by torchaudio.load.
sr=16000
# Sample rate passed to process_audio.
work_sample_rate = 16000
# NOTE(review): both rates are 16000 here, so this resampler does not change
# the sample rate — verify against the real rate of the stream.
resampler = torchaudio.transforms.Resample(sr, work_sample_rate)

### Listen to the stream and display wakeup words

In [None]:
# Endless loop: (re)connect to the stream, buffer its bytes and run the
# wakeup-word detector on every filled buffer.
# Number of buffered bytes that triggers decoding and processing.
min_buffer_bytes = 8000 * chunk_duration

while True:
    # Start every connection with an empty buffer so that leftover bytes of
    # a previous connection are not glued to the new stream.
    output_buffer = io.BytesIO()

    # The context manager closes the connection when the stream ends.
    with requests.get(audio_url, stream=True) as response:
        # Check the status before reading the body: an error page is not audio.
        if response.status_code==404:
            print('status 404. so sad.')
            time.sleep(15)
            continue
        if not response.ok:
            print(f'status {response.status_code}.')
            time.sleep(15)
            continue

        for chunk in response.iter_content(chunk_size=4096):
            output_buffer.write(chunk)
            if output_buffer.getbuffer().nbytes >= min_buffer_bytes:
                output_buffer.seek(0)
                audio_data, sr = torchaudio.load(output_buffer)

                # Resample from the rate the decoder reports to the working
                # rate, and hand the *resampled* audio to the detector.
                if sr != work_sample_rate:
                    resampled_waveform = torchaudio.functional.resample(
                        audio_data, sr, work_sample_rate
                    )
                else:
                    resampled_waveform = audio_data
                process_audio(resampled_waveform, work_sample_rate)
                output_buffer = io.BytesIO()