In [None]:
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "SM2_project.ipynb",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "dzLKpmZICaWN"
      },
      "source": [
        "import os\n",
        "import pathlib\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.keras.layers.experimental import preprocessing\n",
        "from tensorflow.keras import layers\n",
        "from tensorflow.keras import models\n",
        "from IPython import display\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2-rayb7-3Y0I"
      },
      "source": [
        "# Download and extract the mini Speech Commands archive only when the\n",
        "# dataset directory is not already present, so re-runs skip the download.\n",
        "data_dir = pathlib.Path('data/mini_speech_commands')\n",
        "if not data_dir.exists():\n",
        "  tf.keras.utils.get_file(\n",
        "      'mini_speech_commands.zip',\n",
        "      # Fetched over HTTPS so the archive cannot be altered in transit\n",
        "      # before it is extracted locally.\n",
        "      origin=\"https://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip\",\n",
        "      extract=True,\n",
        "      cache_dir='.', cache_subdir='data')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "70IBxSKxA1N9"
      },
      "source": [
        "# Every entry of the dataset directory except the README is treated as a\n",
        "# command label.\n",
        "dir_entries = np.array(tf.io.gfile.listdir(str(data_dir)))\n",
        "is_command = dir_entries != 'README.md'\n",
        "commands = dir_entries[is_command]\n",
        "print('Commands:', commands)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hlX685l1wD9k"
      },
      "source": [
        "# Seed TensorFlow (globally and for the shuffle op) so that the shuffle --\n",
        "# and the train/validation/test split derived from it -- is reproducible\n",
        "# across kernel restarts and cell re-runs.\n",
        "seed = 42\n",
        "tf.random.set_seed(seed)\n",
        "\n",
        "# Sort before shuffling: a seeded shuffle is only reproducible when its\n",
        "# input order is fixed, and the glob order may depend on the filesystem.\n",
        "filenames = sorted(tf.io.gfile.glob(str(data_dir) + '/*/*'))\n",
        "filenames = tf.random.shuffle(filenames, seed=seed)\n",
        "num_samples = len(filenames)\n",
        "print('Number of total examples:', num_samples)\n",
        "print('Number of examples per label:',\n",
        "      len(tf.io.gfile.listdir(str(data_dir/commands[0]))))\n",
        "print('Example file tensor:', filenames[0])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cv_wts-l3KgD"
      },
      "source": [
        "# Derive an 80/10/10 split from the actual number of files instead of\n",
        "# hard-coding 6400/800/800. With 8000 files the result is identical, and\n",
        "# for any other count the three subsets stay disjoint and cover every file.\n",
        "num_files = len(filenames)\n",
        "num_train = num_files * 8 // 10\n",
        "num_val = num_files // 10\n",
        "\n",
        "train_files = filenames[:num_train]\n",
        "val_files = filenames[num_train: num_train + num_val]\n",
        "test_files = filenames[num_train + num_val:]\n",
        "\n",
        "print('Training set size', len(train_files))\n",
        "print('Validation set size', len(val_files))\n",
        "print('Test set size', len(test_files))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9PjJ2iXYwftD"
      },
      "source": [
        "def decode_audio(audio_binary):\n",
        "  \"\"\"Decode WAV file contents and drop the trailing axis of the audio.\n",
        "\n",
        "  Args:\n",
        "    audio_binary: WAV file contents, as accepted by `tf.audio.decode_wav`.\n",
        "\n",
        "  Returns:\n",
        "    The decoded audio tensor with its last axis squeezed out.\n",
        "  \"\"\"\n",
        "  decoded = tf.audio.decode_wav(audio_binary)\n",
        "  # Element 0 is the audio tensor; the second element is not needed here.\n",
        "  samples = decoded[0]\n",
        "  return tf.squeeze(samples, axis=-1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8VTtX1nr3YT-"
      },
      "source": [
        "def get_label(file_path):\n",
        "  \"\"\"Return the second-to-last component of `file_path`.\n",
        "\n",
        "  Args:\n",
        "    file_path: Path string (or string tensor), split on `os.path.sep`.\n",
        "\n",
        "  Returns:\n",
        "    The path component just before the file name, i.e. the name of the\n",
        "    directory that directly contains the file.\n",
        "  \"\"\"\n",
        "  path_components = tf.strings.split(input=file_path, sep=os.path.sep)\n",
        "  # Plain indexing (rather than tuple unpacking) keeps this usable inside a\n",
        "  # TensorFlow graph.\n",
        "  return path_components[-2]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WdgUD5T93NyT"
      },
      "source": [
        "def get_waveform_and_label(file_path):\n",
        "  \"\"\"Read one audio file and pair its waveform with its label.\n",
        "\n",
        "  Args:\n",
        "    file_path: Path of the audio file to read.\n",
        "\n",
        "  Returns:\n",
        "    A tuple `(waveform, label)` produced by `decode_audio` and `get_label`.\n",
        "  \"\"\"\n",
        "  raw_bytes = tf.io.read_file(file_path)\n",
        "  return decode_audio(raw_bytes), get_label(file_path)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0SQl8yXl3kNP"
      },
      "source": [
        "# Let tf.data pick the level of parallelism for the map calls.\n",
        "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
        "\n",
        "# Training file paths -> (waveform, label) pairs.\n",
        "files_ds = tf.data.Dataset.from_tensor_slices(train_files)\n",
        "waveform_ds = files_ds.map(\n",
        "    map_func=get_waveform_and_label,\n",
        "    num_parallel_calls=AUTOTUNE)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8yuX6Nqzf6wT"
      },
      "source": [
        "rows, cols = 3, 3\n",
        "n = rows * cols\n",
        "fig, axes = plt.subplots(rows, cols, figsize=(10, 12))\n",
        "# `axes.flat` walks the grid row by row, pairing one subplot with each of\n",
        "# the first `n` training waveforms.\n",
        "for ax, (audio, label) in zip(axes.flat, waveform_ds.take(n)):\n",
        "  ax.plot(audio.numpy())\n",
        "  ax.set_yticks(np.arange(-1.2, 1.2, 0.2))\n",
        "  ax.set_title(label.numpy().decode('utf-8'))\n",
        "\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_4CK75DHz_OR"
      },
      "source": [
        "def get_spectrogram(waveform):\n",
        "  \"\"\"Convert a waveform into a fixed-size magnitude spectrogram.\n",
        "\n",
        "  The waveform is truncated or zero-padded to exactly 16000 samples so that\n",
        "  every clip yields a spectrogram of the same shape.\n",
        "\n",
        "  Args:\n",
        "    waveform: 1-D audio tensor.\n",
        "\n",
        "  Returns:\n",
        "    The magnitude of the STFT (frame_length=255, frame_step=128) of the\n",
        "    length-normalized waveform.\n",
        "  \"\"\"\n",
        "  input_len = 16000\n",
        "  # Truncate first: without this, a clip longer than `input_len` would ask\n",
        "  # `tf.zeros` for a negative-sized padding tensor and fail.\n",
        "  waveform = tf.cast(waveform[:input_len], tf.float32)\n",
        "\n",
        "  # Pad clips shorter than `input_len` with trailing zeros so that all\n",
        "  # audio clips end up with the same length.\n",
        "  zero_padding = tf.zeros([input_len] - tf.shape(waveform), dtype=tf.float32)\n",
        "  equal_length = tf.concat([waveform, zero_padding], 0)\n",
        "\n",
        "  spectrogram = tf.signal.stft(\n",
        "      equal_length, frame_length=255, frame_step=128)\n",
        "\n",
        "  # Keep only the magnitude of the complex STFT output.\n",
        "  return tf.abs(spectrogram)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4Mu6Y7Yz3C-V"
      },
      "source": [
        "# Pull a single training example and compute its spectrogram.\n",
        "waveform, label = next(iter(waveform_ds.take(1)))\n",
        "label = label.numpy().decode('utf-8')\n",
        "spectrogram = get_spectrogram(waveform)\n",
        "\n",
        "print('Label:', label)\n",
        "print('Waveform shape:', waveform.shape)\n",
        "print('Spectrogram shape:', spectrogram.shape)\n",
        "print('Audio playback')\n",
        "audio_widget = display.Audio(waveform, rate=16000)\n",
        "display.display(audio_widget)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e62jzb36-Jog"
      },
      "source": [
        "def plot_spectrogram(spectrogram, ax):\n",
        "  \"\"\"Draw a log-magnitude spectrogram on `ax` with time along the x-axis.\n",
        "\n",
        "  Args:\n",
        "    spectrogram: 2-D NumPy array indexed as (time frame, frequency bin).\n",
        "    ax: Matplotlib axes to draw on.\n",
        "  \"\"\"\n",
        "  # Transpose so that time runs along the x-axis (columns). The epsilon keeps\n",
        "  # zero-magnitude bins (e.g. from zero padding) from becoming -inf.\n",
        "  log_spec = np.log(spectrogram.T + np.finfo(float).eps)\n",
        "  height, width = log_spec.shape\n",
        "  # One x coordinate per time frame, spread over the 16000-sample clip. The\n",
        "  # count comes from `width` (not `height`) so X always matches the columns\n",
        "  # of `log_spec`, whatever the STFT settings are.\n",
        "  X = np.linspace(0, 16000, num=width, dtype=int)\n",
        "  Y = range(height)\n",
        "  ax.pcolormesh(X, Y, log_spec)\n",
        "\n",
        "\n",
        "# Waveform on top, its spectrogram below, for the example loaded above.\n",
        "fig, axes = plt.subplots(2, figsize=(12, 8))\n",
        "timescale = np.arange(waveform.shape[0])\n",
        "axes[0].plot(timescale, waveform.numpy())\n",
        "axes[0].set_title('Waveform')\n",
        "axes[0].set_xlim([0, 16000])\n",
        "plot_spectrogram(spectrogram.numpy(), axes[1])\n",
        "axes[1].set_title('Spectrogram')\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "43IS2IouEV40"
      },
      "source": [
        "def get_spectrogram_and_label_id(audio, label):\n",
        "  \"\"\"Map a (waveform, label) pair to a (spectrogram, label_id) pair.\n",
        "\n",
        "  Args:\n",
        "    audio: Waveform passed to `get_spectrogram`.\n",
        "    label: Label compared element-wise against the global `commands`.\n",
        "\n",
        "  Returns:\n",
        "    A tuple `(spectrogram, label_id)`: the spectrogram with a trailing axis\n",
        "    of size 1 appended, and the index in `commands` where `label` matches.\n",
        "  \"\"\"\n",
        "  spectrogram = tf.expand_dims(get_spectrogram(audio), axis=-1)\n",
        "  matches = label == commands\n",
        "  return spectrogram, tf.argmax(matches)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yEVb_oK0oBLQ"
      },
      "source": [
        "# Turn each (waveform, label) pair into a (spectrogram, label_id) pair.\n",
        "spectrogram_ds = waveform_ds.map(\n",
        "    map_func=get_spectrogram_and_label_id,\n",
        "    num_parallel_calls=AUTOTUNE)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QUbHfTuon4iF"
      },
      "source": [
        "rows, cols = 3, 3\n",
        "n = rows * cols\n",
        "fig, axes = plt.subplots(rows, cols, figsize=(10, 10))\n",
        "# `axes.flat` walks the grid row by row, one subplot per spectrogram.\n",
        "for ax, (spectrogram, label_id) in zip(axes.flat, spectrogram_ds.take(n)):\n",
        "  # np.squeeze removes the size-1 axis appended for the model input.\n",
        "  plot_spectrogram(np.squeeze(spectrogram.numpy()), ax)\n",
        "  ax.set_title(commands[label_id.numpy()])\n",
        "  ax.axis('off')\n",
        "\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "10UI32QH_45b"
      },
      "source": [
        "def preprocess_dataset(files):\n",
        "  \"\"\"Build a dataset of (spectrogram, label_id) pairs from audio file paths.\n",
        "\n",
        "  Args:\n",
        "    files: File paths accepted by `tf.data.Dataset.from_tensor_slices`.\n",
        "\n",
        "  Returns:\n",
        "    A `tf.data.Dataset` obtained by mapping `get_waveform_and_label` and then\n",
        "    `get_spectrogram_and_label_id` over `files`.\n",
        "  \"\"\"\n",
        "  return (\n",
        "      tf.data.Dataset.from_tensor_slices(files)\n",
        "      .map(get_waveform_and_label, num_parallel_calls=AUTOTUNE)\n",
        "      .map(get_spectrogram_and_label_id, num_parallel_calls=AUTOTUNE)\n",
        "  )"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HNv4xwYkB2P6"
      },
      "source": [
        "# The training spectrograms were already built above; the validation and\n",
        "# test files go through the same preprocessing pipeline.\n",
        "train_ds = spectrogram_ds\n",
        "val_ds, test_ds = (\n",
        "    preprocess_dataset(split_files) for split_files in (val_files, test_files))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UgY9WYzn61EX"
      },
      "source": [
        "batch_size = 64\n",
        "# Batch from the un-batched source datasets rather than rebinding\n",
        "# `train_ds = train_ds.batch(...)`: re-running this cell then gives the same\n",
        "# result instead of batching data that is already batched.\n",
        "train_ds = spectrogram_ds.batch(batch_size)\n",
        "val_ds = preprocess_dataset(val_files).batch(batch_size)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fdZ6M-F5_QzY"
      },
      "source": [
        "# Cache the training and validation datasets and prefetch from them, with\n",
        "# the prefetch buffer size left to tf.data (AUTOTUNE).\n",
        "train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
        "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ALYz7PFCHblP"
      },
      "source": [
        "# The first training spectrogram defines the model input shape.\n",
        "spectrogram, _ = next(iter(spectrogram_ds.take(1)))\n",
        "input_shape = spectrogram.shape\n",
        "print('Input shape:', input_shape)\n",
        "num_labels = len(commands)\n",
        "\n",
        "# Fit the normalization statistics on the training spectrograms, with the\n",
        "# labels dropped.\n",
        "spectrograms_only = spectrogram_ds.map(lambda spec, _label: spec)\n",
        "norm_layer = preprocessing.Normalization()\n",
        "norm_layer.adapt(spectrograms_only)\n",
        "\n",
        "model = models.Sequential([\n",
        "    # Input pipeline: resize to 32x32, then normalize.\n",
        "    layers.Input(shape=input_shape),\n",
        "    preprocessing.Resizing(32, 32),\n",
        "    norm_layer,\n",
        "    # Convolutional feature extractor.\n",
        "    layers.Conv2D(32, 3, activation='relu'),\n",
        "    layers.Conv2D(64, 3, activation='relu'),\n",
        "    layers.MaxPooling2D(),\n",
        "    layers.Dropout(0.25),\n",
        "    # Classifier head; the last layer has no activation, one unit per label.\n",
        "    layers.Flatten(),\n",
        "    layers.Dense(128, activation='relu'),\n",
        "    layers.Dropout(0.5),\n",
        "    layers.Dense(num_labels),\n",
        "])\n",
        "\n",
        "model.summary()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "poJfoXyEbVGe"
      },
      "source": [
        "# Inspection only: shows the label-free view of the training dataset.\n",
        "# The result is not assigned to any name.\n",
        "spectrogram_ds.map(lambda spec, _label: spec)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wFjj7-EmsTD-"
      },
      "source": [
        "# Labels are integer ids, and the loss is told to expect logits.\n",
        "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "model.compile(\n",
        "    optimizer=tf.keras.optimizers.Adam(),\n",
        "    loss=loss_fn,\n",
        "    metrics=['accuracy'],\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ttioPJVMcGtq"
      },
      "source": [
        "EPOCHS = 10\n",
        "# Early stopping with a patience of 2 epochs; verbose=1 reports when it\n",
        "# triggers.\n",
        "early_stopping = tf.keras.callbacks.EarlyStopping(verbose=1, patience=2)\n",
        "history = model.fit(\n",
        "    train_ds,\n",
        "    validation_data=val_ds,\n",
        "    epochs=EPOCHS,\n",
        "    callbacks=[early_stopping],\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nzhipg3Gu2AY"
      },
      "source": [
        "metrics = history.history\n",
        "# Plot both curves against the epoch index explicitly. Passing val_loss as a\n",
        "# third positional argument makes matplotlib plot it against its own\n",
        "# implicit index rather than `history.epoch`.\n",
        "plt.plot(history.epoch, metrics['loss'], label='loss')\n",
        "plt.plot(history.epoch, metrics['val_loss'], label='val_loss')\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('Loss')\n",
        "plt.legend()\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "biU2MwzyAo8o"
      },
      "source": [
        "# Materialize the test set as NumPy arrays in a single pass over the\n",
        "# dataset.\n",
        "test_pairs = [(audio.numpy(), label.numpy()) for audio, label in test_ds]\n",
        "\n",
        "test_audio = np.array([audio for audio, _ in test_pairs])\n",
        "test_labels = np.array([label for _, label in test_pairs])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ktUanr9mRZky"
      },
      "source": [
        "# Predicted class = index of the largest model output for each example.\n",
        "model_outputs = model.predict(test_audio)\n",
        "y_pred = np.argmax(model_outputs, axis=1)\n",
        "y_true = test_labels\n",
        "\n",
        "num_correct = np.count_nonzero(y_pred == y_true)\n",
        "test_acc = num_correct / len(y_true)\n",
        "print(f'Test set accuracy: {test_acc:.0%}')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LvoSAOiXU3lL"
      },
      "source": [
        "confusion_mtx = tf.math.confusion_matrix(y_true, y_pred)\n",
        "\n",
        "# Labelled as rows = true label, columns = predicted label.\n",
        "fig, ax = plt.subplots(figsize=(10, 8))\n",
        "sns.heatmap(\n",
        "    confusion_mtx,\n",
        "    xticklabels=commands,\n",
        "    yticklabels=commands,\n",
        "    annot=True,\n",
        "    fmt='g',\n",
        "    ax=ax)\n",
        "ax.set_xlabel('Prediction')\n",
        "ax.set_ylabel('Label')\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zRxauKMdhofU"
      },
      "source": [
        "sample_file = data_dir/'no/01bb6a2a_nohash_0.wav'\n",
        "\n",
        "sample_ds = preprocess_dataset([str(sample_file)])\n",
        "\n",
        "# The dataset holds a single file; batch(1) adds the leading batch axis.\n",
        "spectrogram, label = next(iter(sample_ds.batch(1)))\n",
        "prediction = model(spectrogram)\n",
        "# Softmax turns the model outputs into per-command probabilities.\n",
        "probabilities = tf.nn.softmax(prediction[0])\n",
        "plt.bar(commands, probabilities)\n",
        "plt.title(f'Predictions for \"{commands[label[0]]}\"')\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "370QW0-CnVW1"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}