In [11]:
from file_encoder import FileEncoder
from data_manager import DataManager
from audio_utils import N_MELS
import numpy as np
import sys
import tensorflow as tf
from tensorflow.keras import layers, models, Model


# Input directory of the singing-voice database; passed to FileEncoder in load_data.
TARGET_DIR = "../thirdparty/「波音リツ」歌声データベースVer2/DATABASE"
# Output directory; passed to FileEncoder in load_data.
OUTPUT_DIR = "../master/ust/json"
    
PITCH_FEATURES = 1 # MIDI note number as a scalar (0 to 100)
LYRIC_FEATURES = 1 # lyrics are converted to indices, so a single scalar (not referenced by the code shown here)
DURATION_FEATURES = 1 # note length in ms, normalized, so a single scalar
NOTE_CHUK_INDEX_FEATURES = 1 # normalized index of which split of the note a chunk is, so a scalar (sic: "CHUK"; not used by any live code shown here)

LYRIC_INDEX_MAX_DIM = 256 # vocabulary size (input_dim) of the lyric Embedding; the largest observed index was 237, so this leaves some headroom
LYRIC_DIM = 64 # dimension of each lyric index after embedding
DURATION_DIM = 1 # normalized duration in ms (0 ms to 10000 ms, i.e. 10 s); not referenced by the code shown here

# TODO: split each note into N chunks in the future (capturing vibrato and similar effects requires subdividing notes further)
NOTE_CHUNK = 1 # one chunk per note (notes are not subdivided yet)


# Path where train_model saves the trained model.
MODEL_FILE = "../data/model.keras"

# .npy cache files written by load_data() and read back when from_storage=True.
LYRIC_INDEX_FILE = "../data/npy/lyric_indexs.npy"
DURATION_INDEX_FILE = "../data/npy/duration_indexs.npy"
NOTENUM_INDEX_FILE = "../data/npy/notenum_indexs.npy"
NAMES_FILE = "../data/json/names.json" # not referenced by the code shown here
Y_FILE = "../data/npy/y.npy"


def build_model(output_features=5, branch_units=64, final_units=128):
    """Build the (uncompiled) note-level sequence model.

    Each of the three inputs (pitch, lyric index and duration, laid out over
    NOTE_CHUNK time steps) goes through its own LSTM branch. The branch
    outputs are concatenated per step, passed through a final LSTM, and mapped
    to ``output_features`` values per step by a tanh Dense head, so every
    output value lies in [-1, 1].

    Args:
        output_features: Size of the per-step output vector. Defaults to 5,
            the value this function previously hard-coded.
        branch_units: Units of each per-input LSTM branch (previously 64).
        final_units: Units of the merging LSTM (previously 128).

    Returns:
        A Keras ``Model`` taking ``[pitch_input, lyric_input, duration_input]``
        (in that order) and producing a tensor of shape
        ``(batch, NOTE_CHUNK, output_features)``.
    """
    pitch_input = layers.Input(shape=(NOTE_CHUNK, PITCH_FEATURES), name="pitch_input")
    lyric_input = layers.Input(shape=(NOTE_CHUNK,), name="lyric_input")  # lyric indices
    duration_input = layers.Input(shape=(NOTE_CHUNK, DURATION_FEATURES), name="duration_input")
    # TODO: once NOTE_CHUNK > 1, add a "note_chunk_index_input" of shape
    # (NOTE_CHUNK, NOTE_CHUK_INDEX_FEATURES) telling the model which split of
    # the note each step is, with its own LSTM branch in the merge below.

    # Embed lyric indices into dense LYRIC_DIM-dimensional vectors.
    lyric_embedding = layers.Embedding(
        input_dim=LYRIC_INDEX_MAX_DIM, output_dim=LYRIC_DIM, name="lyric_embedding"
    )(lyric_input)

    # Per-input temporal branches.
    lstm_pitch = layers.LSTM(units=branch_units, return_sequences=True)(pitch_input)
    lstm_duration = layers.LSTM(units=branch_units, return_sequences=True)(duration_input)
    lstm_lyric_embedding = layers.LSTM(units=branch_units, return_sequences=True)(lyric_embedding)

    merged = layers.Concatenate(axis=-1, name="merged_features")(
        [lstm_pitch, lstm_lyric_embedding, lstm_duration]
    )

    # Final temporal processing over the merged per-step features.
    final_lstm = layers.LSTM(final_units, return_sequences=True, name="final_lstm")(merged)

    # Per-step regression head.
    # NOTE(review): the training log for this cell shows targets of shape
    # (N, 128, 5), but this head emits (batch, NOTE_CHUNK, output_features),
    # which is (batch, 1, 5). The MSE loss appears to broadcast the single
    # predicted step over all 128 target frames. Confirm that is intended.
    output = layers.TimeDistributed(
        layers.Dense(output_features, activation="tanh"), name="output"
    )(final_lstm)

    return Model(inputs=[pitch_input, lyric_input, duration_input], outputs=output)



def load_data(
    from_storage=False
):
    """Return ``(lyric_indexs, durations, notenums, y)`` as NumPy arrays.

    When ``from_storage`` is true, the arrays are read back from the .npy
    cache files. Otherwise the database under TARGET_DIR is encoded with
    FileEncoder, reshaped to the model's input layouts, written to the cache
    files (creating their directories if needed), and returned.

    Args:
        from_storage: Load the cached .npy files instead of re-encoding.

    Returns:
        Tuple ``(lyric_indexs, durations, notenums, y)``. After encoding, the
        first three have shapes ``(-1, NOTE_CHUNK)``,
        ``(-1, NOTE_CHUNK, DURATION_FEATURES)`` and
        ``(-1, NOTE_CHUNK, PITCH_FEATURES)``; ``y`` is ``np.array`` of the
        encoder's targets.

    Raises:
        FileNotFoundError: if ``from_storage`` is true and a cache file is
            missing.
    """
    import os

    if from_storage:
        return (
            np.load(LYRIC_INDEX_FILE),
            np.load(DURATION_INDEX_FILE),
            np.load(NOTENUM_INDEX_FILE),
            np.load(Y_FILE),
        )

    encoder = FileEncoder(TARGET_DIR, OUTPUT_DIR)
    lyric_indexs, durations, notenums, y = encoder.encode()

    lyric_indexs = np.array(lyric_indexs).reshape(-1, NOTE_CHUNK)  # lyrics
    durations = np.array(durations).reshape(-1, NOTE_CHUNK, DURATION_FEATURES)  # lengths
    notenums = np.array(notenums).reshape(-1, NOTE_CHUNK, PITCH_FEATURES)  # pitches
    y = np.array(y)

    for path, array in (
        (LYRIC_INDEX_FILE, lyric_indexs),
        (DURATION_INDEX_FILE, durations),
        (NOTENUM_INDEX_FILE, notenums),
        (Y_FILE, y),
    ):
        # np.save does not create parent directories; without this a fresh
        # checkout fails here, after the (slow) encoding has already run.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        np.save(path, array)

    return lyric_indexs, durations, notenums, y

# Model training
def train_model(*, from_storage=True, batch_size=32, epochs=20, validation_split=0.2):
    """Build, compile and train the model, save it to MODEL_FILE, return its History.

    Every keyword argument defaults to the value previously hard-coded here,
    so ``train_model()`` behaves as before.

    Args:
        from_storage: Forwarded to ``load_data``; True reads the cached .npy
            files instead of re-encoding the database.
        batch_size: Mini-batch size for ``model.fit``.
        epochs: Number of training epochs.
        validation_split: Fraction of samples Keras holds out for validation.
            Keras takes this slice from the *end* of the arrays, before any
            shuffling.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    model = build_model()
    model.compile(optimizer="adam", loss="mse", metrics=["mae"])

    lyric_indexs, durations, notenums, y = load_data(from_storage=from_storage)
    # Print shapes as a quick sanity check on the loaded data.
    for label, array in (
        ("lyric_indexs.shape", lyric_indexs),
        ("durations.shape", durations),
        ("notenums.shape", notenums),
        ("y.shape", y),
    ):
        print(label, array.shape)

    history = model.fit(
        # Input order must match build_model's inputs:
        # [pitch_input, lyric_input, duration_input].
        [notenums, lyric_indexs, durations],
        y,  # targets
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
    )

    model.save(MODEL_FILE)
    return history

def _plot_metric(metrics, train_key, val_key, train_label, val_label, title, ylabel):
    """Draw one figure for ``metrics[train_key]``, adding ``val_key`` if recorded."""
    plt.figure(figsize=(10, 5))
    plt.plot(metrics[train_key], label=train_label)
    if val_key in metrics:
        plt.plot(metrics[val_key], label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid()
    plt.show()


def plot_history(history):
    """Plot the loss curve, and the MAE curve if present, from a Keras History.

    Each metric gets its own 10x5 figure with training and (when recorded)
    validation curves over epochs.

    NOTE(review): ``plt`` is not imported among this cell's imports; it must
    already be bound (e.g. ``import matplotlib.pyplot as plt`` in an earlier
    cell), otherwise this raises NameError.
    """
    metrics = history.history

    _plot_metric(
        metrics, 'loss', 'val_loss',
        'Training Loss', 'Validation Loss',
        'Loss over Epochs', 'Loss',
    )

    # MAE is only plotted when it was tracked as a compile metric.
    if 'mae' in metrics:
        _plot_metric(
            metrics, 'mae', 'val_mae',
            'Training MAE', 'Validation MAE',
            'MAE over Epochs', 'Mean Absolute Error',
        )

# Entry point: train the model, plot its learning curves, then report completion.
if __name__ == "__main__":
    history = train_model()
    # NOTE(review): plot_history uses `plt`, which is not imported in this cell;
    # it must already be bound (e.g. matplotlib.pyplot from an earlier cell).
    plot_history(history)
    print("done!")

lyric_indexs.shape (50872, 1)
durations.shape (50872, 1, 1)
notenums.shahpe (50872, 1, 1)
y.shape (50872, 128, 5)
Epoch 1/20


I0000 00:00:1735904729.085632  488864 cuda_dnn.cc:529] Loaded cuDNN version 90600


[1m1272/1272[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 21ms/step - loss: 0.0807 - mae: 0.1042 - val_loss: 8.3822e-04 - val_mae: 0.0041
Epoch 2/20
[1m1272/1272[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 19ms/step - loss: 8.0653e-04 - mae: 0.0041 - val_loss: 8.4096e-04 - val_mae: 0.0042
Epoch 3/20
[1m1272/1272[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 19ms/step - loss: 8.4171e-04 - mae: 0.0038 - val_loss: 8.3000e-04 - val_mae: 0.0029
Epoch 4/20
[1m1272/1272[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 18ms/step - loss: 8.2647e-04 - mae: 0.0039 - val_loss: 8.0264e-04 - val_mae: 0.0040
Epoch 5/20
[1m1272/1272[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 19ms/step - loss: 5.5650e-04 - mae: 0.0050 - val_loss: 2.6163e-04 - val_mae: 0.0032
Epoch 6/20
[1m1272/1272[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 19ms/step - loss: 2.9943e-04 - mae: 0.0035 - val_loss: 2.6364e-04 - val_mae: 0.0033
Epoch 7/20
[1m1272/1272[0m [