In [47]:
import tensorflow as tf
import os
from datetime import datetime

In [48]:
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[1], 'GPU')
tf.config.experimental.set_memory_growth(gpus[1], True)

AUTOTUNE = tf.data.AUTOTUNE

In [49]:
INPUT_AUDIO_DIR = os.path.dirname(os.getcwd()) + "/data-manipulation/input/audio/"
INPUT_META_DIR = os.path.dirname(os.getcwd()) + "/data-manipulation/input/metadata/"

if not os.path.exists(INPUT_AUDIO_DIR):
    os.makedirs(INPUT_AUDIO_DIR)
if not os.path.exists(INPUT_META_DIR):
    os.makedirs(INPUT_META_DIR)

## Define functions
Mainly pulled from model.ipynb

In [None]:
ALL_LABELS = tf.constant(['Rctrl', 'p', 'esc', 'g', 'slash', 'down', '7', 'equal', 'w', 'a', 'dash', 'caps', 'l', 'd', 'backspace', 'bracketclose', 'z', '1', 'end', 'Rshift', 'comma', 'c', 'tab', 'b', 'j', 'right', 'Lctrl', 'n', 't', 'f', 'm', 'o', 'apostrophe', 'y', '8', 'space', 'backslash', 's', '9', 'i', 'r', 'bracketopen', 'semicolon', 'q', '5', 'k', '3', 'x', '4', '6', '2', 'Lshift', 'left', 'backtick', 'enter', 'fullstop', 'e', '0', 'h', 'v', 'up', 'u', 'delete'], dtype=tf.string)

# Encodes a list of instance labels into a binary vector
def multi_label_binary_encode_tensor(instance_labels):
    # Ensure instance_labels is a list (to handle both single and multi-label cases)
    if isinstance(instance_labels, tf.Tensor) and instance_labels.shape == ():
        instance_labels = tf.expand_dims(instance_labels, axis=0)

    # Ensure instance_labels is a list
    if isinstance(instance_labels, str):
        instance_labels = [instance_labels]
        
    # print(f"Encoding labels: {instance_labels}")  # Debug

    # Create a tensor of zeros with the same length as all_labels
    binary_vector = tf.zeros(len(ALL_LABELS), dtype=tf.int32)

    # Iterate through instance_labels and set corresponding indices to 1
    for label in instance_labels:
        # Find the index of the label in ALL_LABELS using TensorFlow string matching
        matches = tf.equal(ALL_LABELS, label)
        indices = tf.where(matches)  # Indices where matches occur
        if tf.size(indices) > 0:  # Ensure the label exists in ALL_LABELS
            index = indices[0][0]
            binary_vector = tf.tensor_scatter_nd_update(
                binary_vector, indices=[[index]], updates=[1]
            )
    
    return binary_vector

def multi_label_binary_decode_tensor(binary_vector):
    binary_vector = binary_vector.numpy()
    decoded_labels = [ALL_LABELS[i] for i, val in enumerate(binary_vector) if val == 1]
    return decoded_labels

def get_waveform(filepath):
    audio_binary = tf.io.read_file(filepath)
    audio = tf.squeeze(audio_binary)
    waveform, samplerate = tf.audio.decode_wav(audio)

    # Reduce to 1 channel by averaging
    waveform = tf.reduce_mean(waveform, axis=-1)
    
    if (samplerate != 44100):
        print("Incorrect sample rate: " + filepath)
    
    return waveform

def split_into_windows(waveform, frame_length=2205, frame_step=1102): # 50ms windows with 50% overlap
    print("Waveform shape before framing:", waveform.shape)
    frames = tf.signal.frame(waveform, frame_length=frame_length, frame_step=frame_step)
    print("Frames shape after framing (with overlap):", frames.shape)
    return frames

def split_into_sequences(frames, sequence_length=3):
    num_frames = tf.shape(frames)[0]
    sequence_step = 1
    start_indices = tf.range(0, num_frames - sequence_length + 1, sequence_step)
    sequences = tf.map_fn(
        lambda start: frames[start:start + sequence_length],
        start_indices,
        fn_output_signature=tf.TensorSpec(shape=(sequence_length, frames.shape[1]), dtype=frames.dtype)
    )
    print("Frames grouped into sequences:", sequences.shape)
    return sequences

In [51]:
def load_audio_data():
    data = []
    for file in os.listdir(INPUT_AUDIO_DIR):
        split_fn = file.split("-")
        split_fn[1] = split_fn[1][:-4] # Remove .wav extension
        split_fn[0] = float(split_fn[0])
        split_fn[1] = float(split_fn[1])
        data.append({"start_time": datetime.fromtimestamp(split_fn[0]),
                      "end_time": datetime.fromtimestamp(split_fn[1]),
                        "waveform": get_waveform(INPUT_AUDIO_DIR + file)
        })
    
    return data

def loadMetadata():
    metadata = []
    for file in os.listdir(INPUT_META_DIR):
        with open(INPUT_META_DIR + file, "r") as file:
            fdata = file.readlines()

        for line in fdata:
            line = line.strip("\n").split(",")
            metadata.append({"label": multi_label_binary_encode_tensor(line[0]), "time": datetime.fromtimestamp(float(line[1]))})
    return metadata

# Remove keypresses keylogged outside of recording time
def filterMetadata(audio_data, metadata):
    filtered_metadata = []
    for mdata in metadata:
        for audio_ts in audio_data:
            start_time = audio_ts["start_time"]
            end_time = audio_ts["end_time"]
            timestamp = mdata["time"]
            if start_time <= timestamp <= end_time:
                filtered_metadata.append(mdata)
    return filtered_metadata


In [52]:
# Label frames
def preprocess_waveform_and_label(audioData, label):
    waveform = audioData["waveform"]

    tf.debugging.assert_equal(
        tf.shape(waveform)[0], 
        tf.constant(44100, dtype=tf.int32),
        message="Waveform must have 44100 samples"
    )

    frames = split_into_windows(waveform)

    # def detect_keypress(frame):
    #     return tf.reduce_max(tf.abs(frame)) > threshold

    keypress_mask = tf.map_fn(
        detect_keypress,
        frames,
        fn_output_signature=tf.bool
    )

    # Assign labels per frame
    binary_label = multi_label_binary_encode_tensor(label)  # Binary tensor for the given label
    labels = tf.where(
        keypress_mask[:, tf.newaxis],  # Broadcast the mask to match binary_label dimensions
        tf.tile(binary_label[tf.newaxis, :], [keypress_mask.shape[0], 1]),  # Tile the binary label for all frames
        tf.zeros((keypress_mask.shape[0], len(ALL_LABELS)), dtype=tf.int32)  # All-zero tensor for no_keypress
    )

    frame_sequences = split_into_sequences(frames, 3)
    label_sequences = split_into_sequences(labels, 3)

    # Debugging
    print("Frames shape:", frame_sequences.shape)
    print("Labels shape:", label_sequences.shape)

    return frame_sequences, label_sequences

## Preprocess Data

In [53]:
audio_data = load_audio_data()
metadata = filterMetadata(audio_data, loadMetadata())


for adata in audio_data:
     tf.data.Dataset.from_tensor_slices(preprocess_waveform_and_label())

print(metadata)

dataset = mka_waveform_and_label_ds.map(
    lambda waveform, label: tf.data.Dataset.from_tensor_slices(preprocess_waveform_and_label(waveform, label)),
    num_parallel_calls=tf.data.AUTOTUNE
).flat_map(lambda x: x)

Encoding labels: ['a']
Encoding labels: ['b']
Encoding labels: ['c']
Encoding labels: ['d']
Encoding labels: ['e']
Encoding labels: ['f']
Encoding labels: ['g']
Encoding labels: ['h']
Encoding labels: ['i']
Encoding labels: ['j']
Encoding labels: ['k']
Encoding labels: ['l']
Encoding labels: ['m']
Encoding labels: ['n']
Encoding labels: ['o']
Encoding labels: ['p']
Encoding labels: ['q']
Encoding labels: ['r']
Encoding labels: ['s']
Encoding labels: ['t']
Encoding labels: ['u']
Encoding labels: ['v']
Encoding labels: ['w']
Encoding labels: ['x']
Encoding labels: ['y']
Encoding labels: ['z']
Encoding labels: ['Lshift']
Encoding labels: ['h']
Encoding labels: ['e']
Encoding labels: ['l']
Encoding labels: ['l']
Encoding labels: ['o']
Encoding labels: ['space']
Encoding labels: ['f']
Encoding labels: ['r']
Encoding labels: ['o']
Encoding labels: ['m']
Encoding labels: ['space']
Encoding labels: ['t']
Encoding labels: ['h']
Encoding labels: ['e']
Encoding labels: ['space']
Encoding labels: 

TypeError: preprocess_waveform_and_label() missing 2 required positional arguments: 'audioData' and 'label'

## Convert Supplementary data to Dataset

## Save new dataset