In [40]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense,LayerNormalization, BatchNormalization, Dropout, Conv1D, Flatten, Reshape
from tensorflow.keras.models import Model

# define transformer block
class TransformerBlock(Layer):
    """Transformer encoder block: self-attention then a feed-forward network.

    Each sub-layer is followed by dropout, a residual connection and layer
    normalization (post-norm ordering).

    Args:
        embed_dim: Output width of the feed-forward network. Because of the
            residual additions, the last dimension of ``inputs`` must equal
            this value. It is also used as ``key_dim`` for every attention
            head (i.e. it is NOT divided by ``num_heads``).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied after attention and after the
            feed-forward network.
        **kwargs: Standard ``Layer`` keyword arguments (e.g. ``name``,
            ``dtype``), forwarded to the base class.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        # Forward base-layer options so callers can name / configure the block.
        super(TransformerBlock, self).__init__(**kwargs)
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)

        # Position-wise feed-forward network: expand to ff_dim, project back.
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None, mask=None):
        """Apply the block to ``inputs``.

        Args:
            inputs: Tensor used as query, key and value for self-attention.
            training: Forwarded to both dropout layers. Defaults to ``None``
                so the block can be invoked without passing it explicitly.
            mask: Optional tensor forwarded as ``attention_mask`` to the
                attention layer.

        Returns:
            Tensor produced by the second layer normalization.
        """
        # Self-attention: the same tensor is query, value and key.
        attn_output = self.att(inputs, inputs, inputs, attention_mask=mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

In [41]:
class PatchEmbedding(Layer):
    """Project flattened patches into the embedding space with one Dense layer.

    Args:
        num_patches: Stored on the instance; not used by ``call``.
        embedding_dim: Number of units of the Dense projection.
    """

    def __init__(self, num_patches, embedding_dim):
        super().__init__()
        self.projection = Dense(embedding_dim)
        self.embedding_dim = embedding_dim
        self.num_patches = num_patches

    def call(self, patches):
        """Return the Dense projection of ``patches``."""
        embedded = self.projection(patches)
        return embedded

In [42]:
class VisionTransformer(tf.keras.Model):
    """Image classifier: patch embedding -> transformer blocks -> softmax head.

    Args:
        num_patches: Forwarded to ``PatchEmbedding``.
        embedding_dim: Width of the patch embeddings and transformer blocks.
        num_heads: Number of attention heads per transformer block.
        ff_dim: Hidden width of each block's feed-forward network.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        num_classes: Number of units of the softmax output layer.
    """

    def __init__(self, num_patches, embedding_dim, num_heads, ff_dim, num_layers, num_classes):
        super(VisionTransformer, self).__init__()
        # Define patch embedding
        self.patch_embed = PatchEmbedding(num_patches, embedding_dim)
        # Define transformer layers using TransformerBlock
        self.transformer_layers = [
            TransformerBlock(embedding_dim, num_heads, ff_dim) for _ in range(num_layers)
        ]

        self.flatten = Flatten()
        self.dense = Dense(num_classes, activation="softmax")

    def call(self, images, training=None):
        """Classify a batch of images.

        Args:
            images: Image batch passed to ``extract_patches``.
            training: Forwarded to every transformer block (controls dropout).

        Returns:
            Softmax output of the final Dense layer.
        """
        patches = self.extract_patches(images)
        x = self.patch_embed(patches)

        # Run the embedded patches through EVERY transformer block. The
        # flatten / classification head must sit outside this loop; inside it,
        # the model would return after the first block (or return None when
        # there are no blocks).
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, training=training)

        # The head sees the whole flattened patch sequence, so its input size
        # depends on the number of patches per image.
        x = self.flatten(x)
        return self.dense(x)

    # extract patches from images for the transformer model
    def extract_patches(self, images):
        """Split images into non-overlapping 16x16 patches and flatten each one.

        Args:
            images: 4-D image batch. The reshape below hard-codes 3 channels,
                so RGB-like input is assumed -- TODO confirm against callers.

        Returns:
            Tensor reshaped to ``[batch_size, -1, 16*16*3]``.
        """
        patch_size = 16
        channels = 3
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, patch_size, patch_size, 1],
            strides=[1, patch_size, patch_size, 1],  # stride == size: no overlap
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Collapse the patch grid into a single sequence axis.
        patches = tf.reshape(patches, [batch_size, -1, patch_size * patch_size * channels])
        return patches
        

#### Define SpeechTransformer model

In [45]:
class SpeechTransformer(Model):
    """Spectrogram classifier: Conv1D front-end -> transformer blocks -> softmax head.

    Args:
        num_mel_bins: Accepted for interface symmetry with callers; not used
            inside the model (Conv1D reads the feature size from its input).
        embedding_dim: Number of Conv1D filters and width of the transformer blocks.
        num_heads: Number of attention heads per transformer block.
        ff_dim: Hidden width of each block's feed-forward network.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        num_classes: Number of units of the softmax output layer.
    """

    def __init__(self, num_mel_bins, embedding_dim, num_heads, ff_dim, num_layers, num_classes):
        super(SpeechTransformer, self).__init__()

        # Convolution over the time axis; "same" padding and stride 1 keep the
        # number of time steps unchanged.
        self.conv1 = Conv1D(filters=embedding_dim, kernel_size=3, strides=1, padding="same", activation="relu")

        self.batch_norm = BatchNormalization()
        self.reshape = Reshape((-1, embedding_dim))
        self.transformer_layers = [TransformerBlock(embedding_dim, num_heads, ff_dim)
                                   for _ in range(num_layers)]
        self.flatten = Flatten()
        self.dense = Dense(num_classes, activation="softmax")

    def call(self, spectrograms, training=None):
        """Classify a batch of spectrograms.

        Args:
            spectrograms: Input batch fed to the Conv1D layer.
            training: Forwarded to batch normalization and to every
                transformer block, so dropout is only active when the caller
                asks for training behaviour (it was previously forced on).

        Returns:
            Softmax output of the final Dense layer.
        """
        x = self.conv1(spectrograms)
        x = self.batch_norm(x, training=training)
        x = self.reshape(x)

        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, training=training)

        # The head sees the whole flattened sequence, so its input size depends
        # on the number of time steps.
        x = self.flatten(x)
        return self.dense(x)

In [46]:
 # Example usage

# Hyperparameters for the demo model
num_mel_bins = 80
embedding_dim = 128
num_heads = 4
ff_dim = 512
num_layers = 6
num_classes = 30  # e.g. phoneme classification

# Build the SpeechTransformer model
st = SpeechTransformer(
    num_mel_bins=num_mel_bins,
    embedding_dim=embedding_dim,
    num_heads=num_heads,
    ff_dim=ff_dim,
    num_layers=num_layers,
    num_classes=num_classes,
)

# Random stand-in input: a batch of 32 spectrograms, 100 time frames each,
# with num_mel_bins features per frame
spectrograms = tf.random.uniform(shape=(32, 100, num_mel_bins))

# Forward pass in training mode, then report the output shape
output = st(spectrograms, training=True)
print(output.shape)

(32, 30)
