In [8]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from pathlib import Path

from tensorflow import Tensor
from tqdm.notebook import tqdm
from typing import Callable, List, Tuple, Union, Optional, Dict, Any, Sequence, Iterable, TypeVar

In [2]:
class DataPipeFactory:
    """Build ``tf.data`` pipelines pairing student recordings with reference audio.

    The factory reads a GZIP-compressed TFRecord file of student answers, joins each
    record with per-passage reference audio and word timing files found on disk, and
    exposes the result as:

    * a "raw" dataset (one element per student record, see ``get_raw_data``),
    * a "pair" dataset in which every element combines a *main* record with a
      randomly drawn *counter* record (``get_pair_data``),
    * padded batches of pairs (``get_batch_data``) and k-fold splits (``k_fold``).

    The raw dataset can be saved to / restored from an on-disk cache (``pre_save``).
    """

    def __init__(self,
                 tfrecord_path: Union[str, Path],
                 ref_audio_path: Union[str, Path],
                 word_information_path: Union[str, Path],
                 cache: Optional[Union[str, Path]] = None):
        """Validate the input locations and build the raw dataset.

        Args:
            tfrecord_path: GZIP-compressed TFRecord file with the student records.
            ref_audio_path: Directory holding ``<passage_id>.tfs`` reference audio tensors.
            word_information_path: Directory holding ``<passage_id>_word.tfs`` and
                ``<passage_id>_ref.tfs`` tensors.
            cache: Optional directory used by ``pre_save`` / ``get_raw_data`` to store
                and reload the raw dataset. ``None`` disables caching.

        Raises:
            FileNotFoundError: If any of the three input paths does not exist.
        """
        self.tfrecord_path: Path = Path(tfrecord_path)
        self.ref_audio_path: Path = Path(ref_audio_path)
        self.word_information_path: Path = Path(word_information_path)
        # True once the raw dataset has been swapped for the on-disk cached copy.
        self.__cache_status = False
        if not self.tfrecord_path.exists():
            raise FileNotFoundError(f"tfrecord_path {tfrecord_path} not found")
        if not self.ref_audio_path.exists():
            raise FileNotFoundError(f"ref_audio_path {ref_audio_path} not found")
        if not self.word_information_path.exists():
            raise FileNotFoundError(f"word_information_path {word_information_path} not found")
        # Keep ``None`` as ``None``: ``str(None)`` would silently become a path named "None".
        self.__cache: Optional[str] = None if cache is None else str(cache)
        # Number of records zipped together into one sample (main + counter).
        self.__pairs: int = 2
        # Size of the leading "voice" axis of the reference tensors.
        self.__available_voice = 4
        # Number of mel bins produced by ``get_mfcc`` and asserted on the feature shapes.
        self.__mel_bins = 80
        self.__raw_data: tf.data.Dataset = self.__generate_raw_data()

    @staticmethod
    def parse_function(serialized_example: tf.Tensor) -> Dict[str, tf.Tensor]:
        """Parse one serialized ``tf.Example`` into a dict of tensors.

        Besides decoding the serialized tensor fields, a ``passage_id`` string is
        derived from the fourth ``_``-separated token of ``RecordName`` (taken
        modulo 100000).

        Args:
            serialized_example: Scalar string tensor holding one serialized example.

        Returns:
            Dict with the parsed features plus the derived ``passage_id``.
        """
        # Names and types expected in the serialized example; tensor-valued fields
        # are stored as serialized tensors (strings) and decoded below.
        features = {
            'RecordName': tf.io.FixedLenFeature([], tf.string),
            'AudioSegment': tf.io.FixedLenFeature([], tf.string),
            'SampleRate': tf.io.FixedLenFeature([], tf.int64),
            'Sentence': tf.io.FixedLenFeature([], tf.string),
            'WordStart': tf.io.FixedLenFeature([], tf.string),
            'WordDuration': tf.io.FixedLenFeature([], tf.string),
            'MatchSegment': tf.io.FixedLenFeature([], tf.string),
            'MatchReference': tf.io.FixedLenFeature([], tf.string),
        }
        e = tf.io.parse_single_example(serialized_example, features)
        # Decode the serialized tensors into typed tensors.
        e['AudioSegment'] = tf.io.parse_tensor(e['AudioSegment'], out_type=tf.int16)
        e['Sentence'] = tf.io.parse_tensor(e['Sentence'], out_type=tf.int64)
        e['WordStart'] = tf.io.parse_tensor(e['WordStart'], out_type=tf.float32)
        e['WordDuration'] = tf.io.parse_tensor(e['WordDuration'], out_type=tf.float32)
        e['MatchSegment'] = tf.io.parse_tensor(e['MatchSegment'], out_type=tf.int64)
        e['MatchReference'] = tf.io.parse_tensor(e['MatchReference'], out_type=tf.int64)
        # RecordName token #3 -> int -> modulo 100000 -> back to string, so it can be
        # used to build the reference file names.
        passage_id = tf.strings.split(e['RecordName'], sep='_')[3]
        passage_id = tf.strings.to_number(passage_id, out_type=tf.int32) % 100000
        e['passage_id'] = tf.strings.as_string(passage_id)
        return e

    def __first_map_builder(self) -> Callable[[Dict[str, tf.Tensor]], Dict[str, tf.Tensor]]:
        """Return the map function that turns a parsed record into a raw element.

        Everything the mapped function needs is copied into locals so that the
        closure does not capture ``self``.
        """
        get_mfcc = self.get_mfcc
        ref_audio_path = str(self.ref_audio_path.absolute())
        word_information_path = str(self.word_information_path.absolute())
        available_voice = self.__available_voice
        mel_bins = self.__mel_bins

        def created_map(e: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
            a = {'stu_mfcc': get_mfcc(e['AudioSegment'], e['SampleRate'], num_mel_bins=mel_bins)}
            ref_audio = tf.io.parse_tensor(
                tf.io.read_file(ref_audio_path + '/' + e['passage_id'] + '.tfs'), out_type=tf.int16)
            # NOTE(review): the reference audio is featurised with the *student's*
            # sample rate - confirm both recordings share the same rate.
            a['ref_mfcc'] = get_mfcc(ref_audio, e['SampleRate'], num_mel_bins=mel_bins)
            passage_word = tf.io.parse_tensor(
                tf.io.read_file(word_information_path + '/' + e['passage_id'] + '_word.tfs'), out_type=tf.int64)
            reference_time = tf.io.parse_tensor(
                tf.io.read_file(word_information_path + '/' + e['passage_id'] + '_ref.tfs'), out_type=tf.float32)

            # Student timings of the matched words, one row per reference voice.
            a['valid_stu_start'] = tf.gather(e['WordStart'], e['MatchSegment'])
            a['valid_stu_duration'] = tf.gather(e['WordDuration'], e['MatchSegment'])

            # Reference word ids / timings of the matched words, gathered per voice
            # (last axis of reference_time: index 0 -> start, index 1 -> duration).
            a['valid_ref_word'] = tf.gather(passage_word, e['MatchReference'], batch_dims=1)
            a['valid_ref_start'] = tf.gather(reference_time[..., 0], e['MatchReference'], batch_dims=1)
            a['valid_ref_duration'] = tf.gather(reference_time[..., 1], e['MatchReference'], batch_dims=1)

            a['RecordName'] = e['RecordName']
            a['passage_id'] = e['passage_id']
            a['MatchSegment'] = e['MatchSegment']
            a['MatchReference'] = e['MatchReference']

            # parse_tensor loses static shapes; restore them so padded_batch and
            # downstream consumers see a usable element_spec.
            a['stu_mfcc'].set_shape([None, mel_bins])
            a['ref_mfcc'].set_shape([available_voice, None, mel_bins])
            for key in ('valid_stu_start', 'valid_stu_duration', 'valid_ref_word',
                        'valid_ref_start', 'valid_ref_duration', 'MatchSegment', 'MatchReference'):
                a[key].set_shape([available_voice, None])
            return a

        return created_map

    def __generate_raw_data(self) -> tf.data.Dataset:
        """Build the uncached raw dataset: TFRecord -> parse -> feature extraction."""
        return tf.data.TFRecordDataset(str(self.tfrecord_path), compression_type='GZIP')\
            .map(self.parse_function, num_parallel_calls=tf.data.AUTOTUNE)\
            .map(self.__first_map_builder(), num_parallel_calls=tf.data.AUTOTUNE)\
            .prefetch(tf.data.AUTOTUNE)

    @staticmethod
    @tf.function
    def get_mfcc(pcm: tf.Tensor,
                 sample_rate: Union[int, tf.Tensor] = 16000,
                 frame_length: int = 1024,
                 num_mel_bins: int = 80) -> tf.Tensor:
        """Compute log-mel spectrogram features from 16-bit PCM audio.

        Despite the name, no DCT is applied: the returned features are the log of
        the mel-warped magnitude spectrogram, not cepstral coefficients.

        Args:
            pcm: Audio samples; cast to float32 and scaled by ``tf.int16.max``.
            sample_rate: Sample rate handed to the mel filter bank.
            frame_length: STFT window and FFT length; the hop is ``frame_length // 8``.
            num_mel_bins: Number of mel bands in the output (defaults to 80).

        Returns:
            float32 tensor whose last axis has ``num_mel_bins`` entries.
        """
        pcm = tf.cast(pcm, tf.float32) / tf.int16.max
        st_fft = tf.signal.stft(pcm, frame_length=frame_length, frame_step=frame_length // 8, fft_length=frame_length)
        spectrograms = tf.abs(st_fft)
        # Warp the linear-scale spectrograms onto the mel scale (80 Hz - 7600 Hz).
        num_spectrogram_bins = frame_length // 2 + 1
        lower_edge_hertz, upper_edge_hertz = 80.0, 7600.0
        linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, upper_edge_hertz)
        mel_spectrograms = tf.einsum('...t,tb->...b', spectrograms, linear_to_mel_weight_matrix)
        # Small offset keeps the log finite for silent frames.
        log_mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)
        return log_mel_spectrograms

    @staticmethod
    def __pair_mapping(main: Dict[str, tf.Tensor], counter: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
        """Combine a main record with a counter record into one training sample.

        One reference voice is drawn at random for the main record and, independently,
        one for the counter record. As many counter words as the main record has
        matched words are then sampled from the counter record.

        Returns:
            Dict with the main features for the chosen voice, ``counter_*`` features,
            ``counter_word_match`` (1. where the sampled counter word equals the main
            word at the same position, otherwise -1.) and ``counter_pool_index``
            (the sampled counter word indices).
        """
        sample_dict = {}
        # Draw each voice index from the range of the record it is used to index.
        random_ref_voice_id = tf.random.uniform(
            shape=[], minval=0, maxval=tf.shape(main['ref_mfcc'])[0], dtype=tf.int32)
        counter_random_ref_voice_id = tf.random.uniform(
            shape=[], minval=0, maxval=tf.shape(counter['ref_mfcc'])[0], dtype=tf.int32)
        sample_dict['stu_mfcc'] = main['stu_mfcc']
        sample_dict['ref_mfcc'] = main['ref_mfcc'][random_ref_voice_id]
        sample_dict['valid_stu_start'] = main['valid_stu_start'][random_ref_voice_id]
        sample_dict['valid_stu_duration'] = main['valid_stu_duration'][random_ref_voice_id]
        sample_dict['valid_ref_word'] = main['valid_ref_word'][random_ref_voice_id]
        sample_dict['valid_ref_start'] = main['valid_ref_start'][random_ref_voice_id]
        sample_dict['valid_ref_duration'] = main['valid_ref_duration'][random_ref_voice_id]

        sample_dict['counter_ref_mfcc'] = counter['ref_mfcc'][counter_random_ref_voice_id]

        counter_words = counter['valid_ref_word'][counter_random_ref_voice_id]
        counter_word_count = tf.shape(counter_words)[0]
        # Shape (1-D) of the main word list: that many counter words are sampled.
        main_word_range = tf.shape(sample_dict['valid_ref_word'])

        # Prefer sampling without replacement: a random permutation of the counter words.
        shuffled_index = tf.random.shuffle(tf.range(counter_word_count))
        if tf.shape(shuffled_index)[0] > main_word_range[0]:
            # Enough counter words: take a prefix of the permutation.
            counter_word_index = shuffled_index[:main_word_range[0]]
        else:
            # Too few counter words: start from indices drawn with replacement, then
            # overwrite the leading entries with the permutation so every counter
            # word appears at least once.
            # NOTE(review): an int uniform with maxval == 0 (counter without any
            # matched word) is rejected by TensorFlow - confirm such records cannot occur.
            counter_word_index = tf.random.uniform(
                shape=main_word_range, minval=0, maxval=counter_word_count, dtype=tf.int32)
            counter_word_index = \
                tf.tensor_scatter_nd_update(
                    counter_word_index,
                    tf.range(tf.shape(shuffled_index)[0])[..., tf.newaxis],
                    shuffled_index)
        # Gather the counter features at the sampled word positions.
        sample_dict['counter_valid_ref_word'] = tf.gather(counter_words, counter_word_index)
        sample_dict['counter_valid_ref_start'] = \
            tf.gather(counter['valid_ref_start'][counter_random_ref_voice_id], counter_word_index)
        sample_dict['counter_valid_ref_duration'] = \
            tf.gather(counter['valid_ref_duration'][counter_random_ref_voice_id], counter_word_index)
        # 1. where the counter word equals the main word at the same position, else -1.
        sample_dict['counter_word_match'] = tf.where(tf.equal(sample_dict['counter_valid_ref_word'],
                                                              sample_dict['valid_ref_word']), 1., -1.)
        sample_dict['counter_pool_index'] = counter_word_index
        return sample_dict

    def pre_save(self) -> None:
        """Save the raw dataset to the cache directory and switch to the cached copy.

        Raises:
            ValueError: If the factory was created without a ``cache`` path.
        """
        if self.__cache is None:
            raise ValueError("No cache path configured; pass `cache=` when creating DataPipeFactory")
        self.__raw_data.save(self.__cache, compression='GZIP')
        self.__cache_status = True
        # Reload with the same compression that was used for saving.
        self.__raw_data = tf.data.Dataset.load(self.__cache, compression='GZIP')
        print(f'Cache saved to {self.__cache}')

    def get_raw_data(self) -> tf.data.Dataset:
        """Return the raw dataset, loading it from the cache on first use if one exists."""
        if not self.__cache_status and self.__cache is not None and Path(self.__cache).exists():
            self.__cache_status = True
            print(f'Load cache from {self.__cache}')
            self.__raw_data = tf.data.Dataset.load(self.__cache, compression='GZIP')
        return self.__raw_data

    def get_pair_data(self) -> tf.data.Dataset:
        """Return the unbatched dataset of main/counter pair samples."""
        return self.get_raw_data().apply(self.__pair_map_handle(self.__pairs))

    def __batching_handle(self, batch_size: int) -> Callable[[tf.data.Dataset], tf.data.Dataset]:
        """Return a dataset transformation that pads and batches dict elements.

        Non-string components are padded with -1 (cast to the component dtype),
        string components with the empty string.
        """
        def handle(ds):
            padding_values = {
                k: tf.cast(-1, v.dtype) if v.dtype != tf.string else ''
                for k, v in ds.element_spec.items()
            }
            return ds\
                .padded_batch(batch_size, padding_values=padding_values)\
                .prefetch(tf.data.AUTOTUNE)
        return handle

    def __pair_map_handle(self, pairs: int,
                          deterministic: bool = True)\
            -> Callable[[tf.data.Dataset], tf.data.Dataset]:
        """Return a dataset transformation that produces main/counter pair samples.

        The input dataset is shuffled independently ``pairs`` times, the shuffled
        copies are zipped, tuples whose two records share a ``RecordName`` are
        dropped, and the rest is mapped through ``__pair_mapping``.

        Note: the filter and ``__pair_mapping`` both take exactly two records, so
        only ``pairs == 2`` is supported.
        """
        def handle(ds):
            tuple_of_pairs = tuple(ds.shuffle(20, reshuffle_each_iteration=True) for _ in range(pairs))
            comb_data = tf.data.Dataset.zip(tuple_of_pairs).filter(lambda x, y: x["RecordName"] != y["RecordName"])
            return comb_data.map(self.__pair_mapping, num_parallel_calls=tf.data.AUTOTUNE,
                                 deterministic=deterministic)\
                .shuffle(buffer_size=10, reshuffle_each_iteration=True)
        return handle

    def k_fold(self, total_fold: int,
               fold_index: int,
               batch_size: int,
               deterministic: bool = False)\
            -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        """Split the raw data into train/test folds and return batched pair datasets.

        Records are assigned to folds by their position modulo ``total_fold``; fold
        ``fold_index`` becomes the test split, all other folds the train split.

        Args:
            total_fold: Number of folds.
            fold_index: Fold used for testing, ``0 <= fold_index < total_fold``.
            batch_size: Batch size of both returned datasets.
            deterministic: Forwarded to the ``map`` calls of the pipeline.

        Returns:
            ``(train_data, test_data)``.

        Raises:
            ValueError: If ``fold_index`` is outside ``[0, total_fold)``.
        """
        # A negative fold_index used to yield a silently empty test split.
        if not 0 <= fold_index < total_fold:
            raise ValueError(
                f"fold_index must satisfy 0 <= fold_index < total_fold, "
                f"got fold_index={fold_index}, total_fold={total_fold}")
        indexed_data = self.get_raw_data().enumerate()

        def in_test_fold(index, _):
            return index % total_fold == fold_index

        def in_train_folds(index, _):
            return index % total_fold != fold_index

        def drop_index(_, data):
            return data

        def build_split(predicate) -> tf.data.Dataset:
            # Select the records of the split, strip the enumerate index, then pair and batch.
            return indexed_data\
                .filter(predicate)\
                .map(drop_index, num_parallel_calls=tf.data.AUTOTUNE, deterministic=deterministic)\
                .apply(self.__pair_map_handle(self.__pairs, deterministic=deterministic))\
                .apply(self.__batching_handle(batch_size))

        train_data = build_split(in_train_folds)
        test_data = build_split(in_test_fold)
        return train_data, test_data

    def get_batch_data(self,
                       batch_size: int,
                       deterministic: bool = False) -> tf.data.Dataset:
        """Return padded batches of pair samples built from the whole raw dataset."""
        return self.get_raw_data()\
            .apply(self.__pair_map_handle(self.__pairs, deterministic=deterministic))\
            .apply(self.__batching_handle(batch_size))

In [3]:
#####
#####

In [4]:
# Build the pipeline factory from the student-answer TFRecord, the reference
# audio folder and the word-index folder; `cache` is the directory where a
# previously saved copy of the raw dataset is looked up.
ds = DataPipeFactory(
    tfrecord_path='../DataFolder/Tensorflow_DataRecord/Student_Answer_Record.tfrecord',
    ref_audio_path='../DataFolder/Siri_Related/Siri_Reference_Sample',
    word_information_path='../DataFolder/Siri_Related/Siri_Dense_Index',
    cache='../DataFolder/cache/datapipe/cached',
)
# dsp = ds.get_batch_data(10)
# it = iter(dsp)
# Display the raw dataset (and its element_spec) as the cell output.
ds.get_raw_data()

Metal device set to: Apple M1 Max


2022-12-12 18:45:06.579982: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-12-12 18:45:06.580193: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Load cache from ../DataFolder/cache/datapipe/cached


<_LoadDataset element_spec={'stu_mfcc': TensorSpec(shape=(None, 80), dtype=tf.float32, name=None), 'ref_mfcc': TensorSpec(shape=(4, None, 80), dtype=tf.float32, name=None), 'valid_stu_start': TensorSpec(shape=(4, None), dtype=tf.float32, name=None), 'valid_stu_duration': TensorSpec(shape=(4, None), dtype=tf.float32, name=None), 'valid_ref_word': TensorSpec(shape=(4, None), dtype=tf.int64, name=None), 'valid_ref_start': TensorSpec(shape=(4, None), dtype=tf.float32, name=None), 'valid_ref_duration': TensorSpec(shape=(4, None), dtype=tf.float32, name=None), 'RecordName': TensorSpec(shape=(), dtype=tf.string, name=None), 'passage_id': TensorSpec(shape=(), dtype=tf.string, name=None), 'MatchSegment': TensorSpec(shape=(4, None), dtype=tf.int64, name=None), 'MatchReference': TensorSpec(shape=(4, None), dtype=tf.int64, name=None)}>

In [5]:
#ds.pre_save()

In [6]:
# Smoke-test the batched pair pipeline: pull batches of 10 and stop at index 30.
# `d` is left bound to the last batch so it can be inspected in a later cell.
for i, d in tqdm(enumerate(ds.get_batch_data(10))):
    print(i)
    if i == 30:
        break

0it [00:00, ?it/s]

2022-12-12 18:45:10.220305: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30


In [7]:
# Show the last batch that the loop above left in `d`.
d

{'stu_mfcc': <tf.Tensor: shape=(10, 9905, 80), dtype=float32, numpy=
 array([[[ 0.9713775 ,  1.3845315 ,  0.77603424, ..., -1.7692082 ,
          -2.0082936 , -2.622317  ],
         [ 0.97667575,  1.2969316 ,  0.7504074 , ..., -1.7754831 ,
          -1.938354  , -2.7207448 ],
         [ 1.017461  ,  1.0680547 ,  0.47336727, ..., -1.8805484 ,
          -1.946756  , -2.6055958 ],
         ...,
         [-1.        , -1.        , -1.        , ..., -1.        ,
          -1.        , -1.        ],
         [-1.        , -1.        , -1.        , ..., -1.        ,
          -1.        , -1.        ],
         [-1.        , -1.        , -1.        , ..., -1.        ,
          -1.        , -1.        ]],
 
        [[-6.3380704 , -5.1513643 , -5.266494  , ..., -4.8287344 ,
          -4.7598214 , -4.9528627 ],
         [-6.0297246 , -5.2150583 , -5.2498055 , ..., -4.9620986 ,
          -4.9917226 , -4.8773026 ],
         [-5.690273  , -5.471487  , -5.457563  , ..., -5.1897154 ,
          -5.11

In [22]:
# 5-fold split with fold 0 held out: `a` is the train split, `b` the test split,
# both batched in groups of 10.
a, b = ds.k_fold(total_fold=5, fold_index=0, batch_size=10)

2022-12-09 13:42:05.599881: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


In [23]:
# Iterate the train split `a` and stop at batch index 30.
for i, d in tqdm(enumerate(a)):
    print(i)
    if i == 30:
        break

0it [00:00, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30


In [24]:
# Iterate the test split `b` and stop at batch index 30.
for i, d in tqdm(enumerate(b)):
    print(i)
    if i == 30:
        break

0it [00:00, ?it/s]

2022-12-09 13:42:56.611151: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 6 of 10
2022-12-09 13:42:58.202708: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:415] Shuffle buffer filled.


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30


In [25]:
# Keep a handle on the raw (unpaired, unbatched) dataset for the experiment below.
dsr = ds.get_raw_data()

In [26]:
# Peek at the first element of shuffle(5).window(2): every dict value comes back
# as a nested dataset rather than a tensor.
next(iter(dsr.shuffle(5).window(2)))

{'stu_mfcc': <_VariantDataset element_spec=TensorSpec(shape=(None, 80), dtype=tf.float32, name=None)>,
 'ref_mfcc': <_VariantDataset element_spec=TensorSpec(shape=(4, None, 80), dtype=tf.float32, name=None)>,
 'valid_stu_start': <_VariantDataset element_spec=TensorSpec(shape=(4, None), dtype=tf.float32, name=None)>,
 'valid_stu_duration': <_VariantDataset element_spec=TensorSpec(shape=(4, None), dtype=tf.float32, name=None)>,
 'valid_ref_word': <_VariantDataset element_spec=TensorSpec(shape=(4, None), dtype=tf.int64, name=None)>,
 'valid_ref_start': <_VariantDataset element_spec=TensorSpec(shape=(4, None), dtype=tf.float32, name=None)>,
 'valid_ref_duration': <_VariantDataset element_spec=TensorSpec(shape=(4, None), dtype=tf.float32, name=None)>,
 'RecordName': <_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>,
 'passage_id': <_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>,
 'MatchSegment': <_VariantDataset element_spec=Tens

In [6]:
###
#
# main_word_range = tf.range(tf.shape(i[0]['valid_ref_word'])[1])

In [12]:
# str(None) is the text 'None', so wrapping it in Path gives a path literally named "None".
Path(str(None))

PosixPath('None')

In [34]:
# group_by_window demo: bucket the numbers 0..29 by their value modulo 3 and
# emit every bucket in batches of `window_size` elements.
window_size = 5


def key_func(x):
    # Bucket key: remainder of the element modulo 3.
    return x % 3


def reduce_func(key, dataset):
    # Turn one window of same-key elements into a single batch.
    return dataset.batch(window_size)


dataset = tf.data.Dataset.range(30).group_by_window(
    key_func=key_func,
    reduce_func=reduce_func,
    window_size=window_size,
)
for elem in dataset.as_numpy_iterator():
    print(elem)

[ 0  3  6  9 12]
[ 1  4  7 10 13]
[ 2  5  8 11 14]
[15 18 21 24 27]
[16 19 22 25 28]
[17 20 23 26 29]
