This notebook implements a Zero shot Key Word Spotting (KWS) using ImageBind.   
`Note`: This notebook assumes that it is `inside the directory of the ImageBind` repository.  

`Import` Dependencies.

In [1]:
# clone the ImageBind repository from directory:
# git clone https://github.com/facebookresearch/ImageBind
# cd ImageBind
# pip install -r requirements.txt



In [1]:
import data
import torch
import torchaudio
import os
import numpy as np
from torchaudio.datasets import SPEECHCOMMANDS
from torchaudio.datasets.speechcommands import _get_speechcommands_metadata as load_speechcommands_item
from models import imagebind_model
from models.imagebind_model import ModalityType
import IPython.display as ipd
from scipy.io.wavfile import write
import sounddevice as sd

device = "cuda:0" if torch.cuda.is_available() else "cpu"



`Call KWS test dataset` using SPEECHCOMMANDS from torchaudio.datasets.

In [2]:
path = 'data' # Edit path to the location of the dataset
if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
directory = os.path.join('data', 'SpeechCommands', 'speech_commands_v0.02')
test_dataset = torchaudio.datasets.SPEECHCOMMANDS(path, download=True, subset='testing')

100%|██████████| 2.26G/2.26G [18:07<00:00, 2.23MB/s]   


In [3]:
class SilenceDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(SilenceDataset, self).__init__(root, subset='testing')
        self.len = len(self._walker) // 35
        path = os.path.join(self._path, torchaudio.datasets.speechcommands.EXCEPT_FOLDER)
        self.paths = [os.path.join(path, p) for p in os.listdir(path) if p.endswith('.wav')]

    def __getitem__(self, index):
        index = np.random.randint(0, len(self.paths))
        filepath = self.paths[index]
        waveform, sample_rate = torchaudio.load(filepath)
        return filepath, sample_rate, "silence", 0, 0

    def __len__(self):
        return self.len

class UnknownDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(UnknownDataset, self).__init__(root, subset='testing')
        self.len = len(self._walker) // 35

    def __getitem__(self, index):
        index = np.random.randint(0, len(self._walker))
        fileid = self._walker[index]
        waveform, sample_rate, _, speaker_id, utterance_number = load_speechcommands_item(fileid, self._path)
        return fileid, sample_rate, "unknown", speaker_id, utterance_number

    def __len__(self):
        return self.len

In [4]:
# Call SilenceDataset and UnknownDataset class and concatenate them to test_dataset
silence_dataset = SilenceDataset(path)
unknown_dataset = UnknownDataset(path)
test_dataset = torch.utils.data.ConcatDataset([test_dataset, silence_dataset, unknown_dataset])

`Load Input Files`: Map a text list to image/audio paths.

In [5]:
# Call the classes in KWS dataset 
classes = ['silence', 'unknown', 'backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow',
               'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no',
               'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three',
               'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']

classToIdx = {c: i for i, c in enumerate(classes)}
audio_pth = [[] for i in range(len(classes))]

`Load audio file for classification`:   
`Randomly pick` an audio from the test dataset, or `record user voice` under "user_recording.wav".

In [6]:
class AudioPlayerRecorder:
    def __init__(self, sample_rate):
        self.sample_rate = sample_rate

    def play_audio(self, audio):
        ipd.display(ipd.Audio(audio, rate=self.sample_rate))

    def record_audio(self, duration):
        recording = sd.rec(int(duration * self.sample_rate), samplerate=self.sample_rate, channels=1)
        print("Recording Audio...")
        sd.wait()
        write("user_recording.wav", self.sample_rate, recording)
        return recording
    
def get_random_audio(directory):
    random_index = np.random.randint(0, len(test_dataset))
    waveform, sample_rate, label, speaker_id, utterance_number = test_dataset[random_index]
    if label != 'silence' and label != 'unknown':
        rdm_path = os.path.join(directory, label, '{}_nohash_{}.wav'.format(speaker_id, utterance_number))
    else:
        rdm_path = waveform # Silence/Unknown Dataset returns the path of the audio file instead of the waveform. This was done from the class instantiation.
    return waveform, sample_rate, label, speaker_id, rdm_path

Choose `random` audio from the test split

In [7]:
def get_random_audio_from_path(directory):
    waveform, sample_rate, label, speaker_id, rdm_path = get_random_audio(directory)
    audio_player_recorder = AudioPlayerRecorder(sample_rate=sample_rate)
    audio_player_recorder.play_audio(waveform)
    print('Label:', label, 'Speaker_id: ', speaker_id)
    
    return waveform, sample_rate, label, speaker_id, rdm_path

`Record user's voice` for testing 

In [14]:
user_freq = 44100
audio_player_recorder = AudioPlayerRecorder(sample_rate=user_freq) # sample rate is 44100 by default
duration = 2  # seconds to record
recording = audio_player_recorder.record_audio(duration)
audio_player_recorder.play_audio("user_recording.wav")

Recording Audio...


`Instantiate` the model and `define` inference function.

In [8]:
# Instantiate model
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)

Downloading imagebind weights to .checkpoints/imagebind_huge.pth ...


100%|██████████| 4.47G/4.47G [02:07<00:00, 37.8MB/s] 


ImageBindModel(
  (modality_preprocessors): ModuleDict(
    (vision): RGBDTPreprocessor(
      (cls_token): tensor((1, 1, 1280), requires_grad=True)
      
      (rgbt_stem): PatchEmbedGeneric(
        (proj): Sequential(
          (0): PadIm2Video()
          (1): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
        )
      )
      (pos_embedding_helper): SpatioTemporalPosEmbeddingHelper(
        (pos_embed): tensor((1, 257, 1280), requires_grad=True)
        
      )
    )
    (text): TextPreprocessor(
      (pos_embed): tensor((1, 77, 1024), requires_grad=True)
      (mask): tensor((77, 77), requires_grad=False)
      
      (token_embedding): Embedding(49408, 1024)
    )
    (audio): AudioPreprocessor(
      (cls_token): tensor((1, 1, 768), requires_grad=True)
      
      (rgbt_stem): PatchEmbedGeneric(
        (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10), bias=False)
        (norm_layer): LayerNorm((768,), eps=1e-05, elementwise_affine=

In [9]:
# Call audio x text inference
def inference(audio_path, classes=classes):
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(classes, device),
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_path, device), 
    }

    with torch.no_grad():
        embeddings = model(inputs)
        
    inference = torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1)
    top_probs, top_labels = inference.cpu().topk(1, dim=-1)
    
    return top_probs, top_labels

`Model Inference-- Solo`: Predict the class of the image on one audio path.

In [10]:
waveform, sample_rate, label, speaker_id, rdm_path = get_random_audio_from_path(directory)

Label: tree Speaker_id:  84d1e469


In [11]:
# Load audio path input data
audio_path = [rdm_path] # Call random audio

# Uncomment to use user recording
# audio_path = ["user_recording.wav"] 
# label = "zero" # Set user recording ground truth label

top_probs, top_labels = inference(audio_path)
print('Predicted label:', classes[top_labels[0][0].item()], 
      '\nProbability score:', top_probs[0][0].item())
print('Ground Truth: ', label)



Predicted label: house 
Probability score: 1.0
Ground Truth:  tree


`Model Inference-- Group`: Predict the class of the image on multiple points of audio path.

In [12]:
# value n is the number of times the user wants to play random audio. AKA no. of data points
def call_random_n_times(n):
    group_audio_path = []
    group_gt = []
    for i in range(n):
        waveform, sample_rate, label, speaker_id, rdm_path = get_random_audio(directory)
        group_audio_path.append(rdm_path)
        group_gt.append(label)
        
    return group_audio_path, group_gt, sample_rate

def evaluate(n, group_audio_path, group_gt):
    group_out = []
    correct_label = 0
    for i in range(n):
        top_probs, top_labels = inference([group_audio_path[i]])
        if classes[top_labels[0][0].item()] == group_gt[i]:
                correct_label += 1
        group_out.append(("Index:", i, "Predicted label:", classes[top_labels[0][0].item()], "Correct label:",  group_gt[i]))
    accuracy = correct_label/n
    print("------------------------------"
         "\nSummary Statistics: ", 
         "\n------------------------------",
          "\nNumber of data points: ", n,
          "\nAccuracy:", accuracy,
          "\n------------------------------")
    print("Predictions vs Labels: ", group_out)

In [13]:
n = 500 # Number of data points
# Call inference n times
group_audio_path, group_gt, sample_rate = call_random_n_times(n)
# Evaluate the inference
evaluate(n, group_audio_path, group_gt)



------------------------------
Summary Statistics:  
------------------------------ 
Number of data points:  500 
Accuracy: 0.022 
------------------------------
Predictions vs Labels:  [('Index:', 0, 'Predicted label:', 'marvin', 'Correct label:', 'off'), ('Index:', 1, 'Predicted label:', 'bed', 'Correct label:', 'silence'), ('Index:', 2, 'Predicted label:', 'bed', 'Correct label:', 'two'), ('Index:', 3, 'Predicted label:', 'dog', 'Correct label:', 'happy'), ('Index:', 4, 'Predicted label:', 'happy', 'Correct label:', 'five'), ('Index:', 5, 'Predicted label:', 'marvin', 'Correct label:', 'no'), ('Index:', 6, 'Predicted label:', 'off', 'Correct label:', 'nine'), ('Index:', 7, 'Predicted label:', 'forward', 'Correct label:', 'backward'), ('Index:', 8, 'Predicted label:', 'up', 'Correct label:', 'left'), ('Index:', 9, 'Predicted label:', 'wow', 'Correct label:', 'tree'), ('Index:', 10, 'Predicted label:', 'off', 'Correct label:', 'yes'), ('Index:', 11, 'Predicted label:', 'dog', 'Correct

`Comparison`: Show SOTA models scores.

In [14]:
# import module
from tabulate import tabulate

mydata = [
    ["M2D", 98.5, "Not Zero-Shot", "Supervised"],
    ["EAT-S", 98.15, "Not Zero-Shot", "Supervised"],
    ["Audio Spectrogram Transformer", 98.11, "Not Zero-Shot", "Supervised"],
    ["KW-MLP", 97.56, "Not Zero-Shot", "Supervised"],
      ["TripletLoss-res15", 	97.0, "Not Zero-Shot", "Supervised"],
]

head = ["Model", "Score", "Classification type", "Supervision type"]
 
# display table
print(tabulate(mydata, headers=head, tablefmt="grid"))

+-------------------------------+---------+-----------------------+--------------------+
| Model                         |   Score | Classification type   | Supervision type   |
| M2D                           |   98.5  | Not Zero-Shot         | Supervised         |
+-------------------------------+---------+-----------------------+--------------------+
| EAT-S                         |   98.15 | Not Zero-Shot         | Supervised         |
+-------------------------------+---------+-----------------------+--------------------+
| Audio Spectrogram Transformer |   98.11 | Not Zero-Shot         | Supervised         |
+-------------------------------+---------+-----------------------+--------------------+
| KW-MLP                        |   97.56 | Not Zero-Shot         | Supervised         |
+-------------------------------+---------+-----------------------+--------------------+
| TripletLoss-res15             |   97    | Not Zero-Shot         | Supervised         |
+--------------------