### This is the pre-training of color representation model
- It takes 5-10mins on GPU (Tesla T4 * 1) within 100 epochs for one time training
  (with batch-size=2048 in modelConfig)

In [3]:
%cd ../

/Users/s12497/workspace/forShare/crello_color_recomm


In [17]:
import os
import math
import tensorflow as tf
import src.modeling as modeling
from src.dataGenerator import DataGenerator
from src.modelConfig import Config

from datetime import datetime


In [19]:
def calculate_pretrain_task_accuracy(mlm_predict, batch_mlm_mask, origin_x):

    batch_mlm_mask = tf.cast(batch_mlm_mask, dtype=tf.int32)
    index = tf.where(batch_mlm_mask == 1)
    x_predict = tf.math.argmax(mlm_predict, axis=-1) # top1
    x_predict = tf.gather_nd(x_predict, index)
    x_real = tf.gather_nd(origin_x, index)
    mlm_accuracy = tf.keras.metrics.Accuracy()
    mlm_accuracy.update_state(x_predict, x_real)
    mlm_accuracy = mlm_accuracy.result().numpy()

    return mlm_accuracy

In [21]:
# pretrain

# training on CPU
physical_devices = tf.config.experimental.list_physical_devices('CPU')
assert len(physical_devices) > 0, "Not enough CPU hardware devices available"

# training on GPU
# physical_devices = tf.config.experimental.list_physical_devices('GPU')
# assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
# tf.config.experimental.set_memory_growth(physical_devices[0], True)

model = modeling.Bert(Config)
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)
loss_fn = modeling.BERT_Loss()
dataset = DataGenerator(Config)
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(tf.train.latest_checkpoint(Config['Saved_Weight']))
manager = tf.train.CheckpointManager(checkpoint, directory=Config['Saved_Weight'], max_to_keep=5)
log_dir = os.path.join(Config['Log_Dir'], datetime.now().strftime("%Y-%m-%d"))
writer = tf.summary.create_file_writer(log_dir)


# create the data for validation and test
PROJECT_PATH = Config['Project_path']
Config_val = Config.copy()
Config_val['Corpus_File_Path'] = os.path.join(PROJECT_PATH, f'Data_color/color_corpus_lab_bins_16_val_sklearn.txt')
dataset_val = DataGenerator(Config_val)

Config_test = Config.copy()
Config_test['Corpus_File_Path'] = os.path.join(PROJECT_PATH, f'Data_color/color_corpus_lab_bins_16_test_sklearn.txt')
dataset_test = DataGenerator(Config_test)

patience = 30
best = math.inf
wait = 0

for n in range(1):
    EPOCH = 2 # 100 is enough
    for epoch in range(EPOCH):
    #     print(f'dataset length: {len(dataset)}')
        for step in range(len(dataset)):
            batch_x, batch_mlm_mask, batch_mcc_mask, origin_x, batch_segment, batch_padding_mask = dataset[step]
            with tf.GradientTape() as t:
                mlm_predict, sequence_output = model((batch_x, batch_mlm_mask, batch_segment), training=True)

                mlm_loss = loss_fn((mlm_predict, batch_mlm_mask, origin_x))
                mlm_loss = tf.reduce_mean(mlm_loss)

                loss = mlm_loss

            gradients = t.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            # get acc of random mask
            mlm_acc = calculate_pretrain_task_accuracy(mlm_predict, batch_mlm_mask, origin_x)

            if step == len(dataset) - 1 and epoch % 1 == 0:
                print(
                    'Epoch {}, step {}, loss {:.4f}, mlm_loss {:.4f}, mlm_acc {:.4f}'.format(
                        epoch, step, loss.numpy(),
                        mlm_loss.numpy(),
                        mlm_acc,
                        ))

        for val_step in range(len(dataset_val)):
            val_batch_x, val_batch_mlm_mask, val_batch_mcc_mask, val_origin_x, val_batch_segment, val_batch_padding_mask = dataset_val[val_step]
            val_mlm_predict, val_sequence_output = model((val_batch_x, val_batch_mlm_mask, val_batch_segment), training=False)

            val_mlm_loss = loss_fn((val_mlm_predict, val_batch_mlm_mask, val_origin_x))
            val_mlm_loss = tf.reduce_mean(val_mlm_loss)

            # get acc of random mask
            val_mlm_acc = calculate_pretrain_task_accuracy(val_mlm_predict, val_batch_mlm_mask, val_origin_x)

            val_loss = val_mlm_loss

            if val_step == len(dataset_val) - 1 and epoch % 1 == 0:
                print(
                    'Val: Epoch {}, step {}, loss {:.4f}, mlm_loss {:.4f}, mlm_acc {:.4f}'.format(
                        epoch, val_step, val_loss.numpy(),
                        val_mlm_loss.numpy(),
                        val_mlm_acc,
                        ))

        path = manager.save(checkpoint_number=epoch)

        # early stopping
        wait += 1
        if val_loss < best:
            best = val_loss
            wait = 0
        if wait >= patience:
            break

    Config['Mask_Rate'] = 0
    for test_step in range(len(dataset_test)):
        test_batch_x, test_batch_mlm_mask, test_batch_mcc_mask, test_origin_x, test_batch_segment, test_batch_padding_mask = dataset_test[test_step]
        test_mlm_predict, test_sequence_output = model((test_batch_x, test_batch_mlm_mask, test_batch_segment), training=False)

        test_mlm_loss = loss_fn((test_mlm_predict, test_batch_mlm_mask, test_origin_x))
        test_mlm_loss = tf.reduce_mean(test_mlm_loss)

        # get acc of random mask
        test_mlm_acc = calculate_pretrain_task_accuracy(test_mlm_predict, test_batch_mlm_mask, test_origin_x)

        test_loss = test_mlm_loss

        if test_step == len(dataset_test) - 1:
            print(
                'Test: Epoch {}, step {}, loss {:.4f}, mlm_loss {:.4f}, mlm_acc {:.4f}'.format(
                    epoch, test_step, test_loss.numpy(),
                    test_mlm_loss.numpy(),
                    test_mlm_acc,
                    ))

    # model.save(f'model/bert_{representation}_{n}')

Epoch 0, step 292, loss 0.6399, mlm_loss 0.6399, mlm_acc 0.1111
Val: Epoch 0, step 35, loss 0.9824, mlm_loss 0.9824, mlm_acc 0.1250
Epoch 1, step 292, loss 1.0544, mlm_loss 1.0544, mlm_acc 0.0968
Val: Epoch 1, step 35, loss 0.5287, mlm_loss 0.5287, mlm_acc 0.1250
Test: Epoch 1, step 34, loss 0.7219, mlm_loss 0.7219, mlm_acc 0.0952
