### Pre-training the text-aware masked color model for color palette completion
- One training run takes about 50 minutes on a single Tesla T4 GPU with early stopping (patience=30).
- For a quick start, run pre-training with the pre-created color corpus files and text embedding files.
    - To create the color and text data for training yourself, see preprocess.ipynb.
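
Before starting a run, it can help to confirm that the pre-created files are in place. The following is a minimal sketch that mirrors the validation and test path patterns the training cell builds from `Config`. The helper names `split_paths` and `missing_files` and every value in `demo_config` are hypothetical and only serve as an illustration; they are not part of the repository.

```python
import os


def split_paths(config, split):
    """Build the five input paths used for one split ('val' or 'test')."""
    root = config['project_path']
    rep = config['representation']
    text_model = config['text_model']
    emb_file = config['emb_file']
    db_tag = config['db_tag']
    lang = config['langType']
    kmeans = config['kmeansType']
    return {
        'Corpus_File_Path': os.path.join(
            root, f'Data_color/color_corpus_{rep}_{split}{kmeans}.txt'),
        'Text_Contents_File_Path': os.path.join(
            root, f'Data_text/text_contents{db_tag}_{split}{lang}.txt'),
        'Image_Labels_File_Path': os.path.join(
            root, f'Data_text/image_labels{db_tag}_{split}{lang}.txt'),
        'Text_Contents_Emb_File_Path': os.path.join(
            root, f'Data_text/{emb_file}/text_contents_emb{text_model}_{split}.txt'),
        'Image_Labels_Emb_File_Path': os.path.join(
            root, f'Data_text/{emb_file}/image_labels_emb{text_model}_{split}.txt'),
    }


def missing_files(paths):
    """Return the config keys whose file does not exist on disk."""
    return sorted(key for key, path in paths.items() if not os.path.isfile(path))


# Hypothetical values for illustration only; the real ones come from model_config.Config.
demo_config = {
    'project_path': 'no_such_project_dir',
    'representation': 'lab_bins_16',
    'text_model': '_clip',
    'emb_file': 'emb_clip',
    'db_tag': '',
    'langType': '',
    'kmeansType': '',
}

for split in ('val', 'test'):
    print(split, 'missing:', missing_files(split_paths(demo_config, split)))
```

With the real `Config`, an empty list for both splits would indicate that the quick-start files are available and preprocess.ipynb does not need to be run first.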

In [1]:
import tensorflow as tf

print(len(tf.config.experimental.list_physical_devices('GPU')))

2023-09-29 06:25:58.095672: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.11.0


1


2023-09-29 06:26:00.034805: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2023-09-29 06:26:00.667292: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-29 06:26:00.667954: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:00:04.0 name: Tesla T4 computeCapability: 7.5
coreClock: 1.59GHz coreCount: 40 deviceMemorySize: 14.75GiB deviceMemoryBandwidth: 298.08GiB/s
2023-09-29 06:26:00.667995: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.11.0
2023-09-29 06:26:00.677587: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.11
2023-09-29 06:26:00.679843: I tensorflow/stream_executor/platform/default/d

/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2023-09-29 06:26:00.690920: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2023-09-29 06:26:00.692408: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.11
2023-09-29 06:26:00.693314: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.8
2023-09-29 06:26:00.693444: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-29 06:26:00.694227: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-29 06:26:00.694781: I tensorflow/core/common_ru

In [2]:
import pandas as pd
from collections import defaultdict  # For word frequency
import os
import math
import random
import ast
from datetime import datetime
from collections import Counter
import numpy as np

import sys
sys.path.append('../src')

import color_palette_completion.text_color_model.color_model as Model
from color_palette_completion.text_color_model.input_data_generator import DataGenerator
from color_palette_completion.text_color_model.model_config import Config

In [3]:
def calculate_pretrain_task_accuracy(mlm_predict, batch_mlm_mask, origin_x):
    """Return the prediction accuracy over the positions flagged in the MLM mask."""
    batch_mlm_mask = tf.cast(batch_mlm_mask, dtype=tf.int32)
    # indices of the positions whose mask value is 1
    index = tf.where(batch_mlm_mask == 1)
    # most likely token at every position, then keep only the flagged positions
    x_predict = tf.math.argmax(mlm_predict, axis=-1)
    x_predict = tf.gather_nd(x_predict, index)
    x_real = tf.gather_nd(origin_x, index)
    # update_state expects (y_true, y_pred)
    mlm_accuracy = tf.keras.metrics.Accuracy()
    mlm_accuracy.update_state(x_real, x_predict)

    return mlm_accuracy.result().numpy()
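
To make the metric above concrete, here is a dependency-free sketch of the same idea on a toy batch: take the argmax at each position and count matches only where the mask is 1. The function name `masked_accuracy`, the nested-list inputs, and returning 0.0 when nothing is masked are choices made for this illustration, not the notebook's implementation.

```python
def masked_accuracy(logits, mask, targets):
    """Accuracy of argmax predictions, counted only where mask == 1.

    logits:  [batch][position][vocab] scores
    mask:    [batch][position], 1 marks a position to score
    targets: [batch][position] true token ids
    """
    correct = 0
    total = 0
    for seq_logits, seq_mask, seq_targets in zip(logits, mask, targets):
        for scores, flag, target in zip(seq_logits, seq_mask, seq_targets):
            if flag != 1:
                continue  # positions outside the mask do not count
            predicted = max(range(len(scores)), key=scores.__getitem__)
            correct += int(predicted == target)
            total += 1
    # Sketch choice: report 0.0 when no position is masked.
    return correct / total if total else 0.0


# Toy batch: 1 sequence, 3 positions, vocabulary of 3 tokens.
logits = [[[0.1, 0.7, 0.2],     # argmax 1
           [0.9, 0.05, 0.05],   # argmax 0 (not masked, ignored)
           [0.2, 0.3, 0.5]]]    # argmax 2
mask = [[1, 0, 1]]
targets = [[1, 2, 0]]

print(masked_accuracy(logits, mask, targets))  # 0.5: one of the two masked positions is correct
```

The wrong prediction at the unmasked middle position has no effect on the result, which is the point of gathering by the mask before comparing.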

In [4]:
# pretrain

for n in range(1):
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
    # physical_devices = tf.config.experimental.list_physical_devices('CPU')
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

    model = Model.Text2Palettes(Config)
    optimizer = tf.keras.optimizers.Adam(Config['Learning_Rate'])
    loss_fn = Model.Text2Palettes_Loss()
    dataset = DataGenerator(Config)

    # create the data for validation and test
    PROJECT_PATH = Config['project_path']
    representation = Config['representation']
    text_model = Config['text_model']
    emb_file = Config['emb_file']
    db_tag = Config['db_tag']
    langType = Config['langType']
    kmeansType = Config['kmeansType']

    Config_val = Config.copy()
    Config_val['Corpus_File_Path'] = os.path.join(PROJECT_PATH, f'Data_color/color_corpus_{representation}_val{kmeansType}.txt')
    Config_val['Text_Contents_File_Path'] = os.path.join(PROJECT_PATH, f'Data_text/text_contents{db_tag}_val{langType}.txt')
    Config_val['Image_Labels_File_Path'] = os.path.join(PROJECT_PATH, f'Data_text/image_labels{db_tag}_val{langType}.txt')
    Config_val['Text_Contents_Emb_File_Path'] = os.path.join(PROJECT_PATH, f'Data_text/{emb_file}/text_contents_emb{text_model}_val.txt')
    Config_val['Image_Labels_Emb_File_Path'] = os.path.join(PROJECT_PATH, f'Data_text/{emb_file}/image_labels_emb{text_model}_val.txt')

    Config_test = Config.copy()
    Config_test['Corpus_File_Path'] = os.path.join(PROJECT_PATH, f'Data_color/color_corpus_{representation}_test{kmeansType}.txt')
    Config_test['Text_Contents_File_Path'] = os.path.join(PROJECT_PATH, f'Data_text/text_contents{db_tag}_test{langType}.txt')
    Config_test['Image_Labels_File_Path'] = os.path.join(PROJECT_PATH, f'Data_text/image_labels{db_tag}_test{langType}.txt')
    Config_test['Text_Contents_Emb_File_Path'] = os.path.join(PROJECT_PATH, f'Data_text/{emb_file}/text_contents_emb{text_model}_test.txt')
    Config_test['Image_Labels_Emb_File_Path'] = os.path.join(PROJECT_PATH, f'Data_text/{emb_file}/image_labels_emb{text_model}_test.txt')

    dataset_val = DataGenerator(Config_val)
    dataset_test = DataGenerator(Config_test)

    patience = 30 # baseline:10
    best = math.inf
    wait = 0

    Config['Saved_Weight'] = os.path.join(PROJECT_PATH, f'Saved_Weight{text_model}_{Config["Embedding_Size"]}d_{representation}_{Config["Mask_Rate"]}_{Config["Mask_Token_Rate"]}_v{n}')
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint.restore(tf.train.latest_checkpoint(Config['Saved_Weight']))
    manager = tf.train.CheckpointManager(checkpoint, directory=Config['Saved_Weight'], max_to_keep=5)
    log_dir = os.path.join(Config['Log_Dir'], datetime.now().strftime("%Y-%m-%d"))
    writer = tf.summary.create_file_writer(log_dir)
    
    EPOCH = 2000 # 2000 for training
    for epoch in range(EPOCH):
        for step in range(len(dataset)):
            batch_x, batch_mlm_mask, batch_mcc_mask, origin_x, batch_segment, batch_padding_mask, batch_text_contents_embed, batch_image_labels_embed = dataset[step]
      
            with tf.GradientTape() as t:
                mlm_predict, sequence_output = model((batch_x, batch_mlm_mask, batch_segment, batch_text_contents_embed, batch_image_labels_embed), training=True)

                mlm_loss = loss_fn((mlm_predict, batch_mlm_mask, origin_x))
                mlm_loss = tf.reduce_mean(mlm_loss)

                loss = mlm_loss

            gradients = t.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            mlm_acc = calculate_pretrain_task_accuracy(mlm_predict, batch_mlm_mask, origin_x)

            if step == len(dataset) - 1 and epoch % 10 == 0:
                print(
                    'Epoch {}, step {}, loss {:.4f}, mlm_loss {:.4f}, mlm_acc {:.4f}'.format(
                        epoch, step, loss.numpy(),
                        mlm_loss.numpy(),
                        mlm_acc,
                        ))

        for val_step in range(len(dataset_val)):
            val_batch_x, val_batch_mlm_mask, val_batch_mcc_mask, val_origin_x, val_batch_segment, val_batch_padding_mask, val_batch_text_contents_embed, val_batch_image_labels_embed = dataset_val[val_step]

            val_mlm_predict, val_sequence_output = model((val_batch_x, val_batch_mlm_mask, val_batch_segment, val_batch_text_contents_embed, val_batch_image_labels_embed), training=False)

            val_mlm_loss = loss_fn((val_mlm_predict, val_batch_mlm_mask, val_origin_x))
            val_mlm_loss = tf.reduce_mean(val_mlm_loss)
            
            val_mlm_acc = calculate_pretrain_task_accuracy(val_mlm_predict, val_batch_mlm_mask, val_origin_x)

            val_loss = val_mlm_loss

            if val_step == len(dataset_val) - 1 and epoch % 10 == 0:
                print(
                    'Val: Epoch {}, step {}, loss {:.4f}, mlm_loss {:.4f}, mlm_acc {:.4f}'.format(
                        epoch, val_step, val_loss.numpy(),
                        val_mlm_loss.numpy(),
                        val_mlm_acc,
                        ))
        
        path = manager.save(checkpoint_number=epoch)

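        # early stopping: stop once val_loss has not improved for `patience` epochs
        # (val_loss here is the loss of the last validation batch of this epoch)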
        wait += 1
        if val_loss < best:
            best = val_loss
            wait = 0
        if wait >= patience:
            break
                
    for test_step in range(len(dataset_test)):
        test_batch_x, test_batch_mlm_mask, test_batch_mcc_mask, test_origin_x, test_batch_segment, test_batch_padding_mask, test_batch_text_contents_embed, test_batch_image_labels_embed = dataset_test[test_step]
           
        test_mlm_predict, test_sequence_output = model((test_batch_x, test_batch_mlm_mask, test_batch_segment, test_batch_text_contents_embed, test_batch_image_labels_embed), training=False)

        test_mlm_loss = loss_fn((test_mlm_predict, test_batch_mlm_mask, test_origin_x))
        test_mlm_loss = tf.reduce_mean(test_mlm_loss)

        test_mlm_acc = calculate_pretrain_task_accuracy(test_mlm_predict, test_batch_mlm_mask, test_origin_x)

        test_loss = test_mlm_loss
        
        if test_step == len(dataset_test) - 1:
            print(
                'Test: Epoch {}, step {}, loss {:.4f}, mlm_loss {:.4f}, mlm_acc {:.4f}'.format(
                    epoch, test_step, test_loss.numpy(),
                    test_mlm_loss.numpy(),
                    test_mlm_acc,
                    ))

    # save the final model in SavedModel format
    model.save(f'../data/pretrained_model/t2p_ca1_mca1_i10t_stop{patience}_lr{Config["Learning_Rate"]}_{text_model}_{Config["Embedding_Size"]}d_{representation}_{Config["Mask_Rate"]}_{Config["Mask_Token_Rate"]}_{n}')
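
The early-stopping logic in the cell above compares `val_loss` from the last validation batch only, keeps the five most recent checkpoints (`max_to_keep=5`) with a patience of 30, and saves the model as it stands when the loop ends. One possible variant, sketched below under assumptions, averages the loss over all validation batches and remembers which epoch was best so that its checkpoint could be restored before saving. The class name `EarlyStopping` and its methods are hypothetical, and the batch losses are assumed to be plain floats (for example `val_mlm_loss.numpy()`).

```python
import math


class EarlyStopping:
    """Track the best mean validation loss and signal when to stop."""

    def __init__(self, patience):
        self.patience = patience
        self.best = math.inf
        self.best_epoch = None
        self.wait = 0

    def update(self, epoch, batch_losses):
        """Record one epoch; return True when training should stop."""
        # mean over all validation batches instead of the last batch only
        val_loss = sum(batch_losses) / len(batch_losses)
        if val_loss < self.best:
            self.best = val_loss
            self.best_epoch = epoch
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Made-up per-batch validation losses for four epochs.
history = [[1.8, 1.6], [1.3, 1.2], [1.4, 1.2], [1.3, 1.3]]

stopper = EarlyStopping(patience=2)
for epoch, losses in enumerate(history):
    if stopper.update(epoch, losses):
        break

print('best epoch:', stopper.best_epoch, 'stopped at epoch:', epoch)
```

In the training loop, `update` would be called once per epoch with the collected validation losses. Restoring the checkpoint of `best_epoch` before `model.save` would additionally require `max_to_keep` to be larger than `patience`, or a separate checkpoint for the best epoch.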


2023-09-29 06:26:46.254903: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2199995000 Hz
2023-09-29 06:26:46.255638: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55e304465560 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2023-09-29 06:26:46.255667: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2023-09-29 06:26:46.354601: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-09-29 06:26:46.355390: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55e305581640 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-09-29 06:26:46.355421: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
2023-09-2



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



2023-09-29 06:28:41.141051: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.11




To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Epoch 0, step 15, loss 1.7074, mlm_loss 1.7074, mlm_acc 0.1689
Val: Epoch 0, step 1, loss 1.7433, mlm_loss 1.7433, mlm_acc 0.1831
Epoch 10, step 15, loss 1.4053, mlm_loss 1.4053, mlm_acc 0.2564
Val: Epoch 10, step 1, loss 1.2969, mlm_loss 1.2969, mlm_acc 0.2735
Epoch 20, step 15, loss 1.2231, mlm_loss 1.2231, mlm_acc 0.3307
Val: Epoch 20, step 1, loss 1.2067, mlm_loss 1.2067, mlm_acc 0.3261
Epoch 30, step 15, loss 1.0874, mlm_loss 1.0874, mlm_acc 0.3146
Val: Epoch 30, step 1, loss 1.2477, mlm_loss 1.2477, mlm_acc 0.3245
Epoch 40, step 15, loss 1.1058, mlm_loss 1.1058, mlm_acc 0.3698
Val: Epoch 40, step 1, loss 1.1427, mlm_loss 1.1427, mlm_acc 0.3675
Epoch 50, step 15, loss 1.0422, mlm_loss 

2023-09-29 07:10:16.008276: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: ../data/pretrained_model/t2p_ca1_mca1_i10t_stop30_lr0.0002__clip_512d_lab_bins_16_0.4_0.5_0/assets
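
The progress lines printed during training follow the format strings used in the training cell, so they can be turned into records for plotting loss and accuracy curves. The sketch below is one way to do that; the names `LOG_LINE` and `parse_log` and the label `'Train'` for lines without a prefix are assumptions made for this example.

```python
import re

# Matches the lines printed by the training cell, with an optional "Val: " or "Test: " prefix.
LOG_LINE = re.compile(
    r'^(?:(?P<split>Val|Test): )?Epoch (?P<epoch>\d+), step (?P<step>\d+), '
    r'loss (?P<loss>\d+\.\d+), mlm_loss (?P<mlm_loss>\d+\.\d+), '
    r'mlm_acc (?P<mlm_acc>\d+\.\d+)$'
)


def parse_log(lines):
    """Return one dict per well-formed progress line; skip everything else."""
    records = []
    for line in lines:
        match = LOG_LINE.match(line.strip())
        if match is None:
            continue  # TensorFlow messages and truncated lines are ignored
        records.append({
            'split': match['split'] or 'Train',
            'epoch': int(match['epoch']),
            'step': int(match['step']),
            'loss': float(match['loss']),
            'mlm_loss': float(match['mlm_loss']),
            'mlm_acc': float(match['mlm_acc']),
        })
    return records


sample = [
    'Epoch 0, step 15, loss 1.7074, mlm_loss 1.7074, mlm_acc 0.1689',
    'Val: Epoch 0, step 1, loss 1.7433, mlm_loss 1.7433, mlm_acc 0.1831',
    'Epoch 50, step 15, loss 1.0422, mlm_loss ',  # truncated line, skipped
]

for record in parse_log(sample):
    print(record)
```

Since each printed line reports a single batch (the last one of the epoch), curves built from these records are noisier than curves of per-epoch averages.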
