In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense, Embedding, LSTM, Bidirectional, Concatenate, Conv1D, Conv2D, MaxPooling2D, BatchNormalization, Dropout
from tensorflow.keras.models import Model

In [None]:
class LocationSensitiveAttention(Layer):
    """Attention-style projection of a decoder state against encoder outputs.

    Called as ``layer([hidden_states, encoder_outputs])``; returns
    ``tanh(hidden_states @ W1) @ transpose(encoder_outputs) @ W2`` where the
    transpose swaps the last two axes of ``encoder_outputs``.

    NOTE(review): despite the name, this layer computes no location features
    and applies no softmax over encoder steps, so it is not the Tacotron 2
    location-sensitive attention. The arithmetic is kept unchanged so existing
    weights and outputs stay valid.

    Args:
        units: Width of the ``W1``/``W2`` projections. The ``W1`` output is
            contracted against the last axis of ``encoder_outputs``, so
            ``units`` must equal that axis' size.
    """

    def __init__(self, units, **kwargs):
        super(LocationSensitiveAttention, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create ``W1``, ``W2`` and ``V`` from ``[hidden_shape, encoder_shape]``."""
        input_shape = list(input_shape)
        hidden_shape = input_shape[0]
        encoder_shape = input_shape[-1]
        # W1: last axis of hidden_states -> units.
        self.W1 = self.add_weight(name='W1',
                                  shape=(hidden_shape[-1], self.units),
                                  initializer='glorot_uniform',
                                  trainable=True)
        # W2: axis 1 of encoder_outputs -> units. Presumably axis 1 is the
        # encoder time axis (TODO confirm), which would tie the layer to the
        # encoder length seen at build time.
        self.W2 = self.add_weight(name='W2',
                                  shape=(encoder_shape[1], self.units),
                                  initializer='glorot_uniform',
                                  trainable=True)
        # NOTE(review): V is never used in call(); kept so the weight list
        # (and any saved checkpoints) stay compatible.
        self.V = self.add_weight(name='V',
                                  shape=(self.units, 1),
                                  initializer='glorot_uniform',
                                  trainable=True)
        super(LocationSensitiveAttention, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Compute the context tensor for ``inputs = [hidden_states, encoder_outputs]``.

        Args:
            inputs: Two-element sequence ``[hidden_states, encoder_outputs]``.

        Returns:
            ``tanh(hidden_states @ W1) @ swap_last_two(encoder_outputs) @ W2``.
        """
        hidden_states, encoder_outputs = inputs
        # Project and squash the decoder state.
        projected = tf.tanh(tf.matmul(hidden_states, self.W1))
        # transpose_b swaps the two innermost axes of encoder_outputs, i.e. the
        # same as tf.transpose(encoder_outputs, perm=[0, 2, 1]) for rank-3 input.
        context_vectors = tf.matmul(projected, encoder_outputs, transpose_b=True)
        # Map the encoder-axis scores back to `units` features.
        return tf.matmul(context_vectors, self.W2)

    def get_config(self):
        config = super(LocationSensitiveAttention, self).get_config()
        config.update({'units': self.units})
        return config

In [None]:
class CBHG(Layer):
    """Simplified Convolutional-Bank Highway Network (CBHG) layer.

    Pipeline:
      1. A bank of ``K`` ReLU ``Conv1D`` layers (kernel sizes 1..K, 128
         filters each), all applied to the same input and concatenated on the
         channel axis.
      2. Max pooling along the time axis with stride 1 ('same' padding), so
         sequence length and channel count are preserved.
      3. A ReLU ``Conv1D`` projection (kernel 3) to ``projection_dim`` channels.
      4. Four ``Highway`` layers of width ``projection_dim``.

    No batch normalisation, residual connection or recurrent layer is applied.

    Args:
        K: Number of convolutions in the bank (kernel sizes 1..K).
        projection_dim: Channels after the projection convolution; also the
            width of every highway layer.
    """

    def __init__(self, K=16, projection_dim=128, **kwargs):
        super(CBHG, self).__init__(**kwargs)
        self.K = K
        self.projection_dim = projection_dim

    def build(self, input_shape):
        self.conv_banks = [
            Conv1D(filters=128, kernel_size=k, activation='relu', padding='same')
            for k in range(1, self.K + 1)
        ]
        # Pool over time (stride 1). The previous MaxPooling2D(pool_size=(1, 2))
        # on an expand_dims(-1) tensor pooled across the *channel* axis with
        # stride 2, silently halving the bank features instead.
        self.max_pool = tf.keras.layers.MaxPooling1D(pool_size=2, strides=1, padding='same')
        self.conv_projection = Conv1D(filters=self.projection_dim, kernel_size=3, activation='relu', padding='same')
        # Highway is defined later in the same cell; it is only resolved when
        # build() runs, by which time the whole cell has executed.
        self.highway_layers = [Highway(self.projection_dim) for _ in range(4)]
        super(CBHG, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Apply bank -> time max-pool -> projection -> highway stack."""
        bank_outputs = tf.concat([conv(inputs) for conv in self.conv_banks], axis=-1)
        pooled = self.max_pool(bank_outputs)
        outputs = self.conv_projection(pooled)
        for highway_layer in self.highway_layers:
            outputs = highway_layer(outputs)
        return outputs

    def get_config(self):
        config = super(CBHG, self).get_config()
        config.update({'K': self.K, 'projection_dim': self.projection_dim})
        return config

class Highway(Layer):
    """Highway layer computing ``y = T * H + (1 - T) * x``.

    ``H = relu(x @ W_H + b_H)`` is the candidate transform and
    ``T = sigmoid(x @ W_T + b_T)`` is the gate. Because the carry term
    multiplies ``x`` element-wise with a ``units``-wide gate, the input's last
    dimension must equal ``units`` (or be broadcast-compatible with it).

    Args:
        units: Width of the transform and gate projections.
    """

    def __init__(self, units, **kwargs):
        super(Highway, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        def make(weight_name, weight_shape, init):
            return self.add_weight(name=weight_name,
                                   shape=weight_shape,
                                   initializer=init,
                                   trainable=True)

        kernel_shape = (input_shape[-1], self.units)
        bias_shape = (self.units,)
        # Creation order (W_H, W_T, b_H, b_T) is kept stable for checkpoints.
        self.W_H = make('W_H', kernel_shape, 'glorot_uniform')
        self.W_T = make('W_T', kernel_shape, 'glorot_uniform')
        self.b_H = make('b_H', bias_shape, 'zeros')
        self.b_T = make('b_T', bias_shape, 'zeros')
        super(Highway, self).build(input_shape)

    def call(self, inputs, **kwargs):
        transformed = tf.nn.relu(tf.matmul(inputs, self.W_H) + self.b_H)
        gate = tf.sigmoid(tf.matmul(inputs, self.W_T) + self.b_T)
        carried = (1 - gate) * inputs
        return gate * transformed + carried

    def get_config(self):
        return {**super(Highway, self).get_config(), 'units': self.units}

class Prenet(Layer):
    """Two-layer ReLU network applied to decoder inputs.

    The first Dense layer outputs ``units // 2`` features and the second
    ``units`` features. No dropout is applied.

    Args:
        units: Output width; the hidden layer uses half of it (floor division).
    """

    def __init__(self, units, **kwargs):
        super(Prenet, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        hidden_units = self.units // 2
        self.dense1 = Dense(units=hidden_units, activation='relu')
        self.dense2 = Dense(units=self.units, activation='relu')
        super(Prenet, self).build(input_shape)

    def call(self, inputs, **kwargs):
        return self.dense2(self.dense1(inputs))

    def get_config(self):
        return {**super(Prenet, self).get_config(), 'units': self.units}

class Postnet(Layer):
    """Convolutional post-net.

    Five ReLU ``Conv1D`` layers (kernel size 5, ``num_channels`` filters,
    'same' padding) followed by a linear kernel-size-1 ``Conv1D`` projection
    to ``output_dim`` channels.

    Args:
        num_channels: Filters in each of the five hidden convolutions.
        output_dim: Channels of the final linear projection (default 80, the
            value previously hard-coded).
    """

    def __init__(self, num_channels, output_dim=80, **kwargs):
        super(Postnet, self).__init__(**kwargs)
        self.num_channels = num_channels
        self.output_dim = output_dim

    def build(self, input_shape):
        def hidden_conv():
            return Conv1D(filters=self.num_channels, kernel_size=5, activation='relu', padding='same')

        # Explicit attribute names are kept so weight tracking stays unchanged.
        self.conv1 = hidden_conv()
        self.conv2 = hidden_conv()
        self.conv3 = hidden_conv()
        self.conv4 = hidden_conv()
        self.conv5 = hidden_conv()
        self.linear_projection = Conv1D(filters=self.output_dim, kernel_size=1, activation='linear', padding='same')
        super(Postnet, self).build(input_shape)

    def call(self, inputs, **kwargs):
        outputs = inputs
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            outputs = conv(outputs)
        return self.linear_projection(outputs)

    def get_config(self):
        config = super(Postnet, self).get_config()
        config.update({'num_channels': self.num_channels,
                       'output_dim': self.output_dim})
        return config


class Tacotron2(Model):
    """Tacotron-2-style text-to-mel model (simplified).

    Text ids -> Embedding -> BiLSTM -> CBHG (encoder), then an autoregressive
    decoder (Prenet -> LSTM -> attention -> Dense(80)) followed by a Postnet.

    Args:
        vocab_size: Number of token ids for the embedding.
        embedding_dim: Embedding width.
        encoder_lstm_units: Units per direction of the encoder BiLSTM.
        decoder_lstm_units: Units of the decoder LSTM. The attention layer is
            built with ``2 * decoder_lstm_units`` units.
        cbhg_units: Projection/highway width of the CBHG encoder block.
        postnet_channels: Filters in each hidden post-net convolution.

    NOTE(review): the attention layer contracts its projection against the
    last axis of the encoder output, which appears to require
    ``2 * decoder_lstm_units == cbhg_units`` — confirm.
    """

    def __init__(self, vocab_size, embedding_dim, encoder_lstm_units, decoder_lstm_units, cbhg_units, postnet_channels, **kwargs):
        super(Tacotron2, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.encoder_lstm_units = encoder_lstm_units
        self.decoder_lstm_units = decoder_lstm_units
        self.cbhg_units = cbhg_units
        self.postnet_channels = postnet_channels

    def build(self, input_shape):
        # Embedding: token ids -> dense vectors.
        self.embedding = Embedding(input_dim=self.vocab_size, output_dim=self.embedding_dim)

        # Encoder.
        self.encoder_lstm = Bidirectional(LSTM(units=self.encoder_lstm_units, return_sequences=True))
        self.cbhg = CBHG(projection_dim=self.cbhg_units)

        # Decoder.
        self.prenet = Prenet(units=256)
        self.decoder_lstm = LSTM(units=self.decoder_lstm_units, return_sequences=True, return_state=True)
        self.attention = LocationSensitiveAttention(units=self.decoder_lstm_units*2)

        # Projection to the mel spectrogram.
        self.linear_projection = Dense(units=80, activation='linear')

        # Post-net.
        self.postnet = Postnet(num_channels=self.postnet_channels)

        super(Tacotron2, self).build(input_shape)

    @tf.function
    def call(self, inputs, **kwargs):
        """Encode ``inputs`` and run the decoder once per input position.

        Args:
            inputs: Integer token ids; assumed (batch, text_len) — TODO confirm.

        Returns:
            outputs: (batch, steps, 80 + postnet channels) — the decoder's mel
                projection concatenated on the last axis with the post-net
                output.
            attention_weights: (batch, steps, encoder_len) per-step scores
                ``reduce_sum(context_vector * encoder_outputs, -1)``.
        """
        batch_size = tf.shape(inputs)[0]
        # NOTE(review): the decoder runs exactly one step per input position
        # (axis 1 of `inputs`); mel length is therefore tied to text length.
        num_steps = tf.shape(inputs)[1]

        # Text embedding.
        embedded_inputs = self.embedding(inputs)

        # Encoder.
        encoder_outputs = self.encoder_lstm(embedded_inputs)
        encoder_outputs = self.cbhg(encoder_outputs)

        # Decoder: all-zero initial input frame and LSTM state. Subsequent
        # steps feed back the decoder LSTM output (not a predicted mel frame).
        decoder_inputs = tf.zeros((batch_size, 1, self.decoder_lstm_units))
        decoder_hidden_state = tf.zeros((batch_size, self.decoder_lstm_units))
        decoder_cell_state = tf.zeros((batch_size, self.decoder_lstm_units))

        # TensorArrays rather than Python lists: under tf.function the
        # tf.range loop becomes a tf.while_loop, which cannot carry a growing
        # Python list. element_shape keeps the static channel size so the
        # post-net convolutions can build after stacking.
        mel_steps = tf.TensorArray(dtype=tf.float32, size=num_steps, dynamic_size=False,
                                   element_shape=tf.TensorShape([None, 1, self.linear_projection.units]))
        attention_steps = tf.TensorArray(dtype=tf.float32, size=num_steps, dynamic_size=False,
                                         element_shape=tf.TensorShape([None, None]))

        for i in tf.range(num_steps):
            # Pre-net.
            prenet_outputs = self.prenet(decoder_inputs)

            # Decoder LSTM (single step, state threaded explicitly).
            decoder_outputs, decoder_hidden_state, decoder_cell_state = self.decoder_lstm(
                prenet_outputs, initial_state=[decoder_hidden_state, decoder_cell_state])

            # Attention.
            context_vector = self.attention([decoder_outputs, encoder_outputs])

            # Join context and decoder output, then project to mel channels.
            concat_outputs = tf.concat([context_vector, decoder_outputs], axis=-1)
            mel_output = self.linear_projection(concat_outputs)

            # Record per-step attention scores and mel frame.
            attention_steps = attention_steps.write(
                i, tf.reduce_sum(context_vector * encoder_outputs, axis=-1))
            mel_steps = mel_steps.write(i, mel_output)

            # Feed the decoder output back as the next step's input.
            decoder_inputs = decoder_outputs

        # stack() gives (steps, batch, 1, mel); drop the singleton axis and
        # make it batch-major. (TensorArray.concat() would instead merge steps
        # into the batch axis.)
        mel_outputs = tf.transpose(tf.squeeze(mel_steps.stack(), axis=2), perm=[1, 0, 2])
        attention_weights = tf.transpose(attention_steps.stack(), perm=[1, 0, 2])

        # Post-net.
        postnet_outputs = self.postnet(mel_outputs)

        # Concatenate decoder mels and post-net output on the channel axis.
        # NOTE(review): Tacotron 2 adds the post-net output as a residual
        # (mel + postnet); concatenation is kept to preserve the output width.
        outputs = tf.concat([mel_outputs, postnet_outputs], axis=-1)

        return outputs, attention_weights