diff --git a/scripts/bert/bert_train.py b/scripts/bert/bert_train.py index 2ddf22f0d041b6..497f0f345dfd03 100644 --- a/scripts/bert/bert_train.py +++ b/scripts/bert/bert_train.py @@ -34,4 +34,6 @@ def tokenize_function(examples): ) print("==================================== Evaluating Model =================================") -model.fit(train_tf_dataset, validation_data=eval_tf_dataset, epochs=3) +model.fit(train_tf_dataset, validation_data=eval_tf_dataset, epochs=1) +info = model.evaluate(eval_tf_dataset, verbose=2) + diff --git a/scripts/bert/bert_train.sh b/scripts/bert/bert_train.sh index 2756c74a70b895..325e174b38d19f 100644 --- a/scripts/bert/bert_train.sh +++ b/scripts/bert/bert_train.sh @@ -4,7 +4,8 @@ set -x pip3 install transformers datasets -cd ~ && git clone --branch bert-tf2 https://github.com/ROCmSoftwarePlatform/transformers +cd ~ && git clone https://github.com/ROCmSoftwarePlatform/transformers +cd ~ # Script to train the small 117M model python3 transformers/scripts/bert/bert_train.py > log.txt cat log.txt | tail -n 1