In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import mlflow
import mlflow.tensorflow

# Load the MNIST handwritten-digit dataset through the Keras datasets API.
# load_data() hands back two (images, labels) pairs: the training split and the test split.
mnist = tf.keras.datasets.mnist
train_data, test_data = mnist.load_data()
x_train, y_train = train_data
x_test, y_test = test_data

In [2]:
# Normalize pixel values to [0, 1].
# - load_data() returns uint8 pixel arrays (0-255, per the Keras docs). Casting to float32
#   before dividing halves memory versus the float64 that `uint8 / 255.0` would produce,
#   and float32 matches the dtype the Keras layers compute in.
# - The dtype guard keeps this cell idempotent: re-running it must not divide by 255 twice.
if x_train.dtype == "uint8":
    x_train = x_train.astype("float32") / 255.0
if x_test.dtype == "uint8":
    x_test = x_test.astype("float32") / 255.0

In [3]:
# Split the data into training, validation and test sets.
# The first `validation_split` fraction of the training images/labels becomes the
# validation set, and the remainder stays as the training set. MNIST's own test
# split (x_test / y_test) is left untouched for the final evaluation.
# NOTE: this cell is not idempotent; re-running it carves a new validation slice
# out of the already-reduced training set.
validation_split = 0.2
num_val_samples = int(len(x_train) * validation_split)

# The right-hand sides are evaluated before assignment, so both slices come from the original arrays.
x_val, x_train = x_train[:num_val_samples], x_train[num_val_samples:]
y_val, y_train = y_train[:num_val_samples], y_train[num_val_samples:]

In [4]:
# Define the model: a small fully-connected classifier on flattened 28x28 images.
# Seed TensorFlow's global RNG first so that weight initialisation and dropout masks
# are repeatable across Restart & Run All. Bit-for-bit determinism (especially on GPU)
# may additionally require tf.config.experimental.enable_op_determinism().
SEED = 42
tf.random.set_seed(SEED)

model = tf.keras.Sequential([
    layers.Flatten(input_shape=(28, 28)),  # flatten each 28 x 28 image into a 784-dim vector
    layers.Dense(128, activation='relu'),  # fully-connected hidden layer with 128 units
    layers.Dropout(0.2),  # drop 20% of activations during training to reduce overfitting
    layers.Dense(10)  # 10 output units with no activation: raw logits, one per class
])

# Loss and optimizer. from_logits=True because the final Dense layer emits raw logits
# (no softmax); "Sparse" means labels are integer class ids rather than one-hot vectors.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Compile the model, tracking accuracy alongside the loss.
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

In [5]:
# Training hyperparameters, consumed by model.fit in the training cell:
# at most `epochs` passes over the data (early stopping may end sooner), `batch_size` samples per step.
epochs, batch_size = 10, 32

In [6]:
# Configure MLflow tracking: record runs under the "mnist_train" experiment and enable
# TensorFlow autologging so the training parameters and metrics are recorded automatically.
# log_models=False: the training cell already saves the model explicitly with
# mlflow.keras.log_model. Leaving autolog's model logging on would store the model twice per run.
mlflow.set_experiment("mnist_train")
mlflow.tensorflow.autolog(log_models=False)

In [7]:
# Train the model, using the validation set to decide when to stop early.
# NOTE: this cell trains the existing `model` object in place; re-running it without
# re-running the model-definition cell continues from the already-trained weights.
with mlflow.start_run() as run:
    mlflow.set_tag("model_type", "v1")

    # Stop once val_loss has not improved for 3 epochs. restore_best_weights=True rolls
    # the model back to its best-val_loss weights when early stopping fires. Without it,
    # the model evaluated and logged below would carry the weights of the last,
    # non-improving epoch.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=3, restore_best_weights=True
    )
    history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
                        validation_data=(x_val, y_val), callbacks=[early_stopping])

    # Log the last epoch's training/validation accuracy with MLflow.
    # Pass the epoch index as `step`: log_metric defaults to step 0, which would
    # stack this final value on top of autolog's epoch-0 point for the same key.
    final_epoch = history.epoch[-1]
    mlflow.log_metric("accuracy", history.history["accuracy"][-1], step=final_epoch)
    mlflow.log_metric("val_accuracy", history.history["val_accuracy"][-1], step=final_epoch)

    # Evaluate on the held-out test split and log the results.
    test_loss, test_accuracy = model.evaluate(x_test, y_test)
    mlflow.log_metric("test_loss", test_loss)
    mlflow.log_metric("test_accuracy", test_accuracy)

    # Log the trained model as a run artifact under "model".
    mlflow.keras.log_model(model, "model")


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
INFO:tensorflow:Assets written to: /tmp/tmp_r7vfa3q/model/data/model/assets
INFO:tensorflow:Assets written to: /tmp/tmp4db43r27/model/data/model/assets
