In [1]:
import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[1], 'GPU')
tf.config.experimental.set_memory_growth(gpus[1], True)

AUTOTUNE = tf.data.AUTOTUNE

import os
from datetime import datetime
import matplotlib.pyplot as plt
import sys

sys.path.insert(0, '../training')
from shared_funcs import multi_label_binary_encode_tensor, multi_label_binary_decode_tensor, get_waveform, split_into_windows, split_into_sequences 

2024-12-01 21:31:50.162187: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-01 21:31:50.175792: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-01 21:31:50.180031: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-01 21:31:50.190886: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
I0000 00:00:1733113914.928538  506318 cuda_executor.c

In [2]:
ALL_LABELS = tf.constant(['Rctrl', 'p', 'esc', 'g', 'slash', 'down', '7', 'equal', 'w', 'a', 'dash', 'caps', 'l', 'd', 'backspace', 'bracketclose', 'z', '1', 'end', 'Rshift', 'comma', 'c', 'tab', 'b', 'j', 'right', 'Lctrl', 'n', 't', 'f', 'm', 'o', 'apostrophe', 'y', '8', 'space', 'backslash', 's', '9', 'i', 'r', 'bracketopen', 'semicolon', 'q', '5', 'k', '3', 'x', '4', '6', '2', 'Lshift', 'left', 'backtick', 'enter', 'fullstop', 'e', '0', 'h', 'v', 'up', 'u', 'delete'], dtype=tf.string)

In [3]:
def get_MKA_waveforms_and_labels():
    datasetPath = os.path.dirname(os.getcwd()) + "/data-manipulation/MKA datasets"
    # excluded_classes = ["start", "menu", "pgdn", "pgup", "home", "neshanay xwarawa", "neshanay sarawa", "neshanay lay rast", "neshanay lay chap", "fn", "cmd", 'altL', 'altR', 'lcmd', 'Lalt']

    mkaWaveforms= []
    mkaLabels= []
    unique_cases = []

    for manufacturer in os.listdir(datasetPath):
        wavChildFolderPath = datasetPath + "/" + manufacturer + "/Sound Segment(wav)"
        caseFolders = os.listdir(wavChildFolderPath)
        
        for case in caseFolders:
            if tf.reduce_any(tf.equal(ALL_LABELS, case)):
                for file in os.listdir(wavChildFolderPath + "/" + case):
                    wf = get_waveform(wavChildFolderPath + "/" + case + "/" + file)
                    mkaWaveforms.append(wf)
                    mkaLabels.append(case)
                    # print(wavChildFolderPath + "/" + case + "/" + file)
                
                if case not in unique_cases:
                    unique_cases.append(case)
    
    if len(mkaWaveforms) != len(mkaLabels):
        print("Filepaths and Labels do not match")
        return None, None

    print("Cases: ", unique_cases)
    return mkaWaveforms, mkaLabels

# Filters out frames with insignificant signals and labels them as "no_keypress".
def filter_insignificant_frames(frames, labels, threshold=0.05):
    # Compute the max amplitude for each frame
    frame_amplitudes = tf.reduce_max(tf.abs(frames), axis=-1)

    # Identify frames with insignificant signals
    insignificant_mask = frame_amplitudes <= threshold

    # Assign "no_keypress" to insignificant frames
    no_keypress_label = tf.zeros_like(labels[0], dtype=labels.dtype)

    # Assign all-zero binary tensor to insignificant frames
    labels = tf.where(insignificant_mask[:, tf.newaxis], no_keypress_label, labels)
    
    return frames, labels

# Split and label data according to 50ms sample amplitude on clean data
def preprocess_waveform_and_label(waveform, label, threshold=0.05):

    tf.debugging.assert_equal(
        tf.shape(waveform)[0], 
        tf.constant(44100, dtype=tf.int32),
        message="Waveform must have 44100 samples"
    )

    # Split waveform into frames
    frames = split_into_windows(waveform)

    # Detect keypress in each frame
    def detect_keypress(frame):
        return tf.reduce_max(tf.abs(frame)) > threshold

    keypress_mask = tf.map_fn(
        detect_keypress,
        frames,
        fn_output_signature=tf.bool
    )

    # Assign labels per frame
    binary_label = multi_label_binary_encode_tensor(label)  # Binary tensor for the given label
    labels = tf.where(
        keypress_mask[:, tf.newaxis],  # Broadcast the mask to match binary_label dimensions
        tf.tile(binary_label[tf.newaxis, :], [keypress_mask.shape[0], 1]),  # Tile the binary label for all frames
        tf.zeros((keypress_mask.shape[0], len(ALL_LABELS)), dtype=tf.int32)  # All-zero tensor for no_keypress
    )

    # Relabel insignificant frames as "no_keypress"
    frames, labels = filter_insignificant_frames(frames, labels, threshold=0.15)

    # Create sequence of 5 frames for temporal differences
    frame_sequences = split_into_sequences(frames, 3)
    label_sequences = split_into_sequences(labels, 3)

    # Debugging
    print("Frames shape:", frame_sequences.shape)
    print("Labels shape:", label_sequences.shape)

    return frame_sequences, label_sequences

In [4]:
# Load MKA
mka_waveforms, mka_labels = get_MKA_waveforms_and_labels()

# Verify length of audios is 1 second at 44100 Hz
for i, tr_f in enumerate(mka_waveforms):
    num_samples = tf.shape(tr_f)[0]  # Length along the first dimension (number of samples)
    # num_channels = tf.shape(tr_f)[1] # Length along the second dimension (number of channels)
    if (num_samples != 44100):
        print(f"Waveform {i}: num_samples = {num_samples}")

# Map the preprocessing function to the dataset
mka_waveform_and_label_ds = tf.data.Dataset.from_tensor_slices((mka_waveforms, mka_labels))

mka_dataset = mka_waveform_and_label_ds.map(
    lambda waveform, label: tf.data.Dataset.from_tensor_slices(preprocess_waveform_and_label(waveform, label)),
    num_parallel_calls=tf.data.AUTOTUNE
).flat_map(lambda x: x)

mka_dataset.cache()

Cases:  ['Rctrl', 'p', 'esc', 'g', 'slash', 'down', '7', 'equal', 'w', 'a', 'dash', 'caps', 'l', 'd', 'backspace', 'bracketclose', 'z', '1', 'end', 'Rshift', 'comma', 'c', 'tab', 'b', 'j', 'right', 'Lctrl', 'n', 't', 'f', 'm', 'o', 'apostrophe', 'y', '8', 'space', 'backslash', 's', '9', 'i', 'r', 'bracketopen', 'semicolon', 'q', '5', 'k', '3', 'x', '4', '6', '2', 'Lshift', 'left', 'backtick', 'enter', 'fullstop', 'e', '0', 'h', 'v', 'up', 'u', 'delete']
Frames shape: (37, 3, 2205)
Labels shape: (37, 3, 63)


<CacheDataset element_spec=(TensorSpec(shape=(3, 2205), dtype=tf.float32, name=None), TensorSpec(shape=(3, 63), dtype=tf.int32, name=None))>

In [5]:
path = os.path.dirname(os.getcwd()) + "/data-manipulation/mka_dataset"

mka_dataset.save(path)