# Estructura del modelo

### Imports

In [None]:
import os
import librosa
import numpy as np
import soundfile as sf
import tensorflow as tf
import matplotlib.pyplot as plt

## Gestión del Dataset

### Funciones

In [None]:
def audio_to_spectrogram(file_path, sr = 20500, n_fft = 2048, hop_length = 512):
    """Load an audio file and return its log-scaled STFT magnitude.

    Args:
        file_path: Path of the audio file handed to ``librosa.load``.
        sr: Sample rate requested from ``librosa.load``.
        n_fft: FFT size passed to ``librosa.stft``.
        hop_length: Hop between successive STFT frames.

    Returns:
        The STFT magnitude converted with ``librosa.amplitude_to_db`` using
        ``ref = np.max``.
    """
    # Only the waveform is used; the sample rate returned by the loader is discarded.
    waveform, _ = librosa.load(file_path, sr = sr)

    # Taking the absolute value of the complex STFT drops the phase.
    magnitude = np.abs(librosa.stft(waveform, n_fft = n_fft, hop_length = hop_length))

    # Log scaling (optional; intended to help visualisation and learning).
    return librosa.amplitude_to_db(magnitude, ref = np.max)

def spectrogram_to_audio_sin_fase(magnitud, sr, n_fft = 2048, hop_length = 512, n_iter = 32):
    """Rebuild a waveform from a phase-less spectrogram with Griffin-Lim.

    Args:
        magnitud: Spectrogram passed as first argument to ``librosa.griffinlim``.
        sr: Accepted for interface compatibility; not used by this function.
        n_fft: FFT size forwarded to ``librosa.griffinlim``.
        hop_length: Hop length forwarded to ``librosa.griffinlim``.
        n_iter: Number of Griffin-Lim iterations.

    Returns:
        The value returned by ``librosa.griffinlim``.
    """
    griffinlim_options = {
        "n_iter": n_iter,
        "hop_length": hop_length,
        "n_fft": n_fft,
    }
    return librosa.griffinlim(magnitud, **griffinlim_options)

def visualize_spectrogram(spectrogram, title = "Spectrogram"):
    """Show a spectrogram as a colour-mapped image.

    Args:
        spectrogram: 2-D array-like handed to ``imshow``.
        title: Title drawn above the plot.
    """
    _, axis = plt.subplots(figsize=(10, 4))

    # Row 0 is drawn at the bottom (origin = 'lower') using the viridis colour map.
    image = axis.imshow(spectrogram, aspect = 'auto', origin = 'lower', cmap = 'viridis')
    plt.colorbar(image, ax = axis, label = "Decibels (dB)")

    axis.set_title(title)
    axis.set_xlabel("Time (frames)")
    axis.set_ylabel("Frequency (bins)")

    plt.tight_layout()
    plt.show()

def pad_or_trim(spectrogram, max_length = 188, pad_value = 0.0):
    """Force a spectrogram to exactly ``max_length`` columns.

    Columns beyond ``max_length`` are cut off; shorter inputs are extended
    on the right with constant-valued columns.

    Args:
        spectrogram: 2-D array; the second axis is the one trimmed or padded.
        max_length: Number of columns in the result.
        pad_value: Value used for the added columns. The default of 0 keeps
            the historical behaviour. NOTE: for decibel spectrograms whose
            reference is the maximum, 0 dB is the loudest level, so callers
            may prefer to pass the silence floor instead (for example -80.0).

    Returns:
        A slice of the input when trimming, otherwise a new array with the
        padding appended.
    """
    n_frames = spectrogram.shape[1]

    if n_frames > max_length:  # Trim
        return spectrogram[:, :max_length]

    # Pad; float64 padding matches what np.zeros produced before this change.
    padding = np.full(
        (spectrogram.shape[0], max_length - n_frames),
        pad_value,
        dtype = np.float64,
    )
    return np.hstack((spectrogram, padding))


### Carga, clasificación y estructuración de los datos

In [None]:

# Root folder of the dataset: one sub-directory per word, each holding audio clips.
path = "path"

vocab = []
word_to_index = {}
index_to_word = {}
data = []
flat_data = []
num = 0

batch_size = 32

# Spectrogram size used throughout the notebook (frequency bins x time frames).
max_length = 188
max_height = 1025

# The previous loop stopped once the file index exceeded 100, i.e. it read
# at most 101 clips per word; the same cap is kept here under a name.
max_files_per_word = 101

# os.listdir returns entries in arbitrary order, so the listing is sorted to
# keep the word -> index mapping reproducible between runs.
for word in sorted(os.listdir(path)):
    word_path = os.path.join(path, word)

    # Skip stray files at the top level; only directories are word classes.
    if not os.path.isdir(word_path):
        continue

    print(num)
    vocab.append(word)

    audio_files = sorted(os.listdir(word_path))[:max_files_per_word]

    # Every clip is converted to a spectrogram and forced to a fixed width.
    data.append([
        pad_or_trim(audio_to_spectrogram(os.path.join(word_path, audio)), max_length)
        for audio in audio_files
    ])
    num += 1


word_to_index = {word: index for index, word in enumerate(vocab)}
index_to_word = {idx: word for word, idx in word_to_index.items()}

# Flatten into (label, spectrogram) pairs; the label is the word's position in vocab.
for label, word_spectrograms in enumerate(data):
    flat_data.extend((label, item) for item in word_spectrograms)

### Creación del Dataset

In [None]:
# Split the (label, spectrogram) pairs into two parallel lists
labels = [label for label, _ in flat_data]
spectrograms = [spec for _, spec in flat_data]

# Convert both lists to tensors
spectrogram_tensor = tf.convert_to_tensor(spectrograms, dtype=tf.float32)
label_tensor = tf.convert_to_tensor(labels, dtype=tf.int32)

# Build the dataset of (label, spectrogram) elements
dataset = tf.data.Dataset.from_tensor_slices((label_tensor, spectrogram_tensor))

# Shuffle with a buffer covering every sample, then group into batches
dataset = dataset.shuffle(len(labels))
dataset = dataset.batch(batch_size)

In [None]:
# Dump the vocabulary and its word -> index mapping
for inspected in (vocab, word_to_index):
    print(inspected)

# Dump the raw spectrogram list
print(spectrograms)

# Sanity check: the label range should line up with the vocabulary size
lowest_label = min(labels)
highest_label = max(labels)
print(f"Valores mínimos y máximos en labels: {lowest_label}, {highest_label}")
print(f"Vocab size: {len(vocab)}")

# Modelo

### Estructura

In [None]:
class TextEncoder(tf.keras.Model):
    """Model that maps integer token ids to embedding vectors."""

    def __init__(self, vocab_size, embed_dim):
        """Create the embedding table.

        Args:
            vocab_size: Number of rows of the embedding table.
            embed_dim: Size of each embedding vector.
        """
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(
            vocab_size,
            embed_dim,
        )

    def call(self, x):
        """Return the embedding layer applied to ``x``."""
        embedded = self.embedding(x)
        return embedded

class SpectrogramDecoder(tf.keras.Model):
    """Two-layer dense network producing a flat vector of size ``target_dim``."""

    def __init__(self, target_dim):
        """Build the dense stack.

        Args:
            target_dim: Number of units of the final dense layer.
        """
        super().__init__()
        hidden_layer = tf.keras.layers.Dense(256, activation='relu')
        # Linear output: the target spectrograms are not normalised.
        output_layer = tf.keras.layers.Dense(target_dim, activation=None)
        self.dense = tf.keras.Sequential([hidden_layer, output_layer])

    def call(self, x):
        """Return the dense stack applied to ``x``."""
        decoded = self.dense(x)
        return decoded

class TextToSpectrogram(tf.keras.Model):
    """Chains a ``TextEncoder`` and a ``SpectrogramDecoder``."""

    def __init__(self, vocab_size, embed_dim, target_dim):
        """Create both sub-models.

        Args:
            vocab_size: Vocabulary size forwarded to the encoder.
            embed_dim: Embedding size forwarded to the encoder.
            target_dim: Output size forwarded to the decoder.
        """
        super().__init__()
        self.encoder = TextEncoder(vocab_size, embed_dim)
        self.decoder = SpectrogramDecoder(target_dim)

    def call(self, x):
        """Encode ``x`` and decode the result."""
        return self.decoder(self.encoder(x))
    
def create_model(vocab_size, embed_dim, target_dim, spectrogram_shape=None):
    """Build the functional text-to-spectrogram model.

    The network embeds a single token id, passes it through a 256-unit ReLU
    layer and a linear layer of ``target_dim`` units, and reshapes the result
    into a 2-D spectrogram.

    Args:
        vocab_size: Number of rows of the embedding table.
        embed_dim: Size of each embedding vector.
        target_dim: Number of units of the flat output layer.
        spectrogram_shape: ``(height, length)`` of the reshaped output. When
            omitted, the module-level ``max_height`` and ``max_length`` are
            used, as before.

    Returns:
        A ``tf.keras.Model`` named ``"TextToSpectrogram"``.

    Raises:
        ValueError: If ``target_dim`` does not equal ``height * length``.
    """
    if spectrogram_shape is None:
        spectrogram_shape = (max_height, max_length)
    height, length = spectrogram_shape

    # Fail early with a readable message instead of an error inside Reshape.
    if height * length != target_dim:
        raise ValueError(
            f"target_dim ({target_dim}) must equal height * length "
            f"({height} * {length} = {height * length})"
        )

    # Encoder: one token id per sample
    input_text = tf.keras.Input(shape=(1,), name="text_input")
    embedding = tf.keras.layers.Embedding(vocab_size, embed_dim)(input_text)

    # Decoder
    dense_1 = tf.keras.layers.Dense(256, activation='relu')(embedding)
    output_flat = tf.keras.layers.Dense(target_dim, activation=None, name="output")(dense_1)

    # Reshape the flat output to (height, length)
    output_spectrogram = tf.keras.layers.Reshape((height, length))(output_flat)

    # Assemble the model
    model = tf.keras.Model(inputs=input_text, outputs=output_spectrogram, name="TextToSpectrogram")
    return model


# Entrenamiento del Modelo

### Funciones

In [None]:
def plot_losses(history):
    """Plot the loss and MAE curves stored in a training history.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value lists. ``'loss'`` is required; ``'val_loss'``,
            ``'mae'`` and ``'val_mae'`` are drawn only when present.
    """
    # Sets the default figure size through matplotlib's global rcParams
    plt.rcParams['figure.figsize'] = [20, 5]
    _, (loss_axis, mae_axis) = plt.subplots(1, 2, sharex=True)
    logs = history.history

    # Left panel: losses
    loss_axis.set_title('Loss')
    loss_axis.set_xlabel('Epoch')
    loss_axis.set_ylabel('Loss')
    loss_axis.grid()
    loss_axis.plot(logs['loss'], label='Training Loss', color='blue')
    if 'val_loss' in logs:
        loss_axis.plot(logs['val_loss'], label='Validation Loss', color='orange')
    loss_axis.legend(loc="upper right")

    # Right panel: metrics
    mae_axis.set_title('Mean Absolute Error (MAE)')
    mae_axis.set_xlabel('Epoch')
    mae_axis.set_ylabel('MAE')
    mae_axis.grid()
    mae_series = (
        ('mae', 'Training MAE', 'green'),
        ('val_mae', 'Validation MAE', 'red'),
    )
    for key, label, color in mae_series:
        if key in logs:
            mae_axis.plot(logs[key], label=label, color=color)
    mae_axis.legend(loc="upper right")

    # Display both panels
    plt.show()


### Configuración

In [None]:
# Hyperparameters
embed_dim = 128  # Size of each embedding vector
vocab_size = len(vocab)  # One embedding row per word
target_dim = max_height * max_length  # Flat size of one spectrogram

# Build the network and print its layer summary
model = create_model(
    vocab_size,
    embed_dim,
    target_dim,
)
model.summary()


### Compilación con Funciones de Pérdida

In [None]:
# Adam with a fixed learning rate of 1e-4
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Mean squared error as the loss, mean absolute error as an extra metric
model.compile(optimizer=adam_optimizer, loss="mse", metrics=["mae"])

### Entrenamiento

In [None]:
epochs = 5

# The dataset is already batched, and the Keras documentation says not to
# pass batch_size when the input is a tf.data.Dataset, so it is omitted here.
history = model.fit(dataset, epochs = epochs)

In [None]:
# Plot the loss and MAE curves recorded in the training history
plot_losses(history)

# Obtención de los espectrogramas

### Obtener la predicción

In [None]:
# Hacer la predicción
predicted_spectrogram = model.predict([3])

# Como 'predicted_spectrogram' tendrá la forma (1, 1025, 94), accedemos al primer elemento
spectrogram = predicted_spectrogram[0]  # El espectrograma de la palabra

### Visualizar el espectrograma predicho

In [None]:
# Asume que el modelo devuelve un espectrograma con dimensiones (1025, 94)
# First (and only) element of the predicted batch
spectrogram = predicted_spectrogram[0]

# Draw the predicted spectrogram
figure, axis = plt.subplots(figsize=(10, 6))
image = axis.imshow(spectrogram, aspect='auto', origin='lower', cmap='viridis')
figure.colorbar(image, ax=axis, label='Amplitude')
axis.set_title('Predicted Spectrogram')
axis.set_xlabel('Time')
axis.set_ylabel('Frequency')
plt.show()

### Obtener Audio

In [None]:
# Same sample rate used as the default when the training audio was loaded
sample_rate = 20500

# The model is trained on decibel-scaled spectrograms, while Griffin-Lim
# works on linear STFT magnitudes, so the prediction is converted back first.
# The training data was scaled relative to each clip's maximum, so the
# absolute level cannot be recovered; 0 dB maps to an amplitude of 1.0.
magnitude = librosa.db_to_amplitude(spectrogram)

audio_generado = spectrogram_to_audio_sin_fase(magnitude, sample_rate)

sf.write("audio_generado.wav", audio_generado, samplerate = sample_rate, format = "wav")