# Train and Prediction
model training and inference
- <a href='#1'>1. model</a> 
- <a href='#2'>2. train and predict</a> 

In [1]:
import os
import sys
import pickle
import gc
from datetime import datetime

from gensim.models import Word2Vec
from gensim.models import KeyedVectors
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import (
    to_categorical,
    Sequence
)
from tensorflow.keras import (
    Input,
    Model
)
from tensorflow.keras.layers import (
    Embedding,
    Bidirectional,
    LSTM,
    GRU,
    Dense,
    concatenate,
    Activation,
    BatchNormalization,
    TimeDistributed,
    Dropout,
    Lambda,
    Conv1D,
    GlobalMaxPooling1D,
    GlobalAveragePooling1D,
    TimeDistributed,
    Dropout,
    Lambda,
    Conv1D,
    Conv2D,
    MaxPooling2D,
    Flatten
)

from tensorflow.keras.callbacks import (
    EarlyStopping,
    ReduceLROnPlateau,
    ModelCheckpoint,
    LearningRateScheduler
)
from tensorflow.keras.optimizers import Adagrad
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from transformers import (
    BertConfig,
    TFBertPreTrainedModel
)
from transformers.modeling_tf_utils import (
    keras_serializable,
    shape_list,
    get_initializer
)
from transformers.modeling_tf_bert import (
    TFBertEncoder
)
from sklearn.model_selection import StratifiedKFold

sys.path.append('../')
from utils import (
    LogManager
)
import conf
import traceback

In [None]:
# global setting
# Turn on memory growth for every visible GPU so TensorFlow allocates GPU
# memory as needed instead of reserving it all when the process starts.
gpus = tf.config.experimental.list_physical_devices('GPU')
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)

# Point the project LogManager at <LOG_DIR>/train.log before asking it for a logger.
train_log_path = os.path.join(conf.LOG_DIR, 'train.log')
LogManager.created_filename = train_log_path
logger = LogManager.get_logger(__name__)

In [None]:
# global variables
# Width of the pretrained embedding vectors. It has to match both the word2vec
# files loaded in the main cell (word2vec_<col>_128.model) and
# BertConfig.hidden_size (128); the previous value of 200 matched neither.
embed_size = 128
max_len = 80  # sequence length fed to the model; the median user clicks 22 ads
batch_size = 256
epochs = 20
age_class_num = 10
gender_class_num = 2
len_stats = 50  # width of the 'stats_feat' input vector
# Categorical id sequences; each one gets its own embedding and transformer branch.
agg_col = ['creative_id', 'ad_id', 'product_id', 'advertiser_id', 'industry']

In [None]:
# utility functions 
def decay_schedule(epoch, lr):
    """Learning-rate schedule: cut the rate to one tenth at epoch index 9.

    Args:
        epoch: zero-based epoch index supplied by ``LearningRateScheduler``.
        lr: learning rate currently in use.

    Returns:
        ``lr * 0.1`` when ``epoch`` equals 9, otherwise ``lr`` unchanged.
    """
    # The scheduler feeds the returned value back in on the next epoch, so a
    # single multiplication at epoch 9 keeps the reduced rate afterwards.
    return lr * 0.1 if epoch == 9 else lr

class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that writes the validation metrics to the logger after each epoch."""

    @staticmethod
    def on_epoch_end(epoch, logs=None):
        """Log validation loss and the age / gender validation accuracies.

        Args:
            epoch: index of the epoch that just finished.
            logs: metric dict passed by Keras. May be ``None`` or lack the
                ``val_*`` entries (e.g. when no validation data is used); the
                missing values are then logged as ``None`` instead of raising.
        """
        logs = logs or {}
        age_acc = logs.get('val_age_out_accuracy')
        gender_acc = logs.get('val_gender_out_accuracy')
        # The combined score is the sum of both accuracies; it can only be
        # computed when both metrics are present.
        if age_acc is None or gender_acc is None:
            val_acc = None
        else:
            val_acc = float(age_acc) + float(gender_acc)
        logger.info('epoch: %s, val_loss: %s, val_age_out_accuracy: %s, val_gender_out_accuracy: %s, val_acc: %s' % (
            epoch,
            logs.get('val_loss'),
            age_acc,
            gender_acc,
            val_acc
        ))
        
def search_weight(
        valid_y,
        raw_prob,
        class_num,
        step=0.001
):
    """Greedily search per-class multipliers that raise validation accuracy.

    Starting from a weight of 1.0 for every class, each class weight in turn
    is tried at ``n * step`` for ``n`` in ``range(0, 2000, 10)``; a candidate
    is kept whenever it strictly improves accuracy. Full passes over all
    classes repeat until a pass brings no improvement.

    Args:
        valid_y: true class indices of the validation samples.
        raw_prob: predicted class probabilities, one column per class.
        class_num: number of classes (length of the weight vector).
        step: multiplier turning the integer grid into weight values.

    Returns:
        Tuple ``(weight, score)``: the list of per-class weights and the
        accuracy recorded at the start of the final pass.
    """
    weight = [1.0] * class_num
    f_best = accuracy_score(y_true=valid_y, y_pred=raw_prob.argmax(axis=1))

    flag_score = 0
    round_num = 1
    while flag_score != f_best:
        logger.info("round: %s" % round_num)
        round_num += 1
        # Remember the score at the start of the pass; the loop ends once a
        # whole pass leaves it unchanged.
        flag_score = f_best
        for c in range(class_num):
            for n_w in range(0, 2000, 10):
                trial = weight[:c] + [n_w * step] + weight[c + 1:]
                scaled = raw_prob * np.array(trial)
                f = accuracy_score(y_true=valid_y, y_pred=scaled.argmax(axis=1))
                if f > f_best:
                    weight, f_best = trial, f
    logger.info('flag_score: %s' % flag_score)
    return weight, flag_score

class DataGenerator(tf.keras.utils.Sequence):
    """Keras Sequence that re-samples each row's click sequence before every epoch."""

    def __init__(self, dataset_train, labels_train):
        """Store the fold data and build the first re-sampled epoch.

        Args:
            dataset_train: dict of arrays; must contain 'click_times_length',
                'click_times_prob' and 'stats_feat' next to the sequence columns.
            labels_train: dict with 'age_out' and 'gender_out' label arrays.
        """
        self.dataset_train = dataset_train
        self.labels_train = labels_train
        self.batch_size = batch_size
        self.click_times_length = dataset_train['click_times_length']
        self.click_times_prob = dataset_train['click_times_prob']
        self.stats_feat = dataset_train['stats_feat']
        self.total_size = self.click_times_length.shape[0]
        self.indexes = np.arange(self.total_size)

        self.enhanced_data = {}
        self.epoch_count = 0
        self.on_epoch_end()

    def on_epoch_end(self):
        """Rebuild ``enhanced_data`` with a fresh random sub-sample per row and reshuffle."""
        logger.info('reinforment data begin')
        # Drop the previous epoch's arrays before allocating the new ones.
        del self.enhanced_data
        gc.collect()
        self.enhanced_data = {}

        # Everything except these columns is treated as a per-click sequence.
        passthrough = ('click_times_length', 'user_id', 'click_times_prob', 'stats_feat')
        seq_keys = [name for name in self.dataset_train.keys() if name not in passthrough]
        resampled = {name: [] for name in seq_keys}

        for row_idx, length_row in enumerate(self.click_times_length):
            # Clamp the stored length in place so it never exceeds max_len.
            if length_row[0] > max_len:
                length_row[0] = max_len
            n_clicks = length_row[0]
            # Keep a random number of clicks in [2/3 * n_clicks, n_clicks).
            keep = np.random.randint(int(n_clicks * (2 / 3)), n_clicks)
            picked = np.random.choice(n_clicks, keep, replace=False, p=self.click_times_prob[row_idx])
            pad_len = max_len - len(picked)
            for name in seq_keys:
                # The attention mask is padded with -10000, every other sequence with 0.
                fill = -10000 if name == 'attention_mask' else 0
                source_row = self.dataset_train[name][row_idx]
                resampled[name].append(np.hstack([source_row[picked], np.array([fill] * pad_len)]))

        for name in seq_keys:
            self.enhanced_data[name] = np.stack(resampled[name])

        self.enhanced_data['click_times_length'] = self.click_times_length
        self.enhanced_data['click_times_prob'] = self.click_times_prob
        self.enhanced_data['stats_feat'] = self.stats_feat
        np.random.shuffle(self.indexes)
        logger.info('reinforment data end')

    def __len__(self):
        """Return the number of full batches per epoch."""
        return self.total_size // self.batch_size

    def __getitem__(self, index):
        """Return batch ``index`` as a ``(features, labels)`` pair of dicts."""
        start = index * self.batch_size
        stop = min(start + self.batch_size, self.total_size)
        chosen = self.indexes[start:stop]
        batch_data = {name: values[chosen] for name, values in self.enhanced_data.items()}
        batch_labels = {name: self.labels_train[name][chosen] for name in ('age_out', 'gender_out')}
        return batch_data, batch_labels

class OnEpochEnd(tf.keras.callbacks.Callback):
    """Keras callback that invokes a list of zero-argument callables after each epoch."""

    def __init__(self, callbacks):
        """Store the callables to run.

        Args:
            callbacks: iterable of callables taking no arguments.
        """
        # Initialise the Keras Callback base class so its own attributes exist.
        super().__init__()
        self.callbacks = callbacks

    def on_epoch_end(self, epoch, logs=None):
        """Call every registered callable; ``epoch`` and ``logs`` are not forwarded."""
        for callback in self.callbacks:
            callback()

### <a id='1'> 1.model</a>

In [None]:
class TFBertMainLayer(tf.keras.layers.Layer):
    """BERT encoder stack without an embedding layer.

    The layer only wraps a ``TFBertEncoder``; it has no token embeddings of
    its own, so the encoder is always fed ``inputs_embeds``.
    """
    config_class = BertConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions
        self.encoder = TFBertEncoder(config, name="encoder")

    def _prune_heads(self, heads_to_prune):
        """Head pruning is not supported by this layer.

        Args:
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def call(
            self,
            inputs,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            output_attentions=None,
            training=True,
    ):
        """Run the encoder and return its sequence output.

        Args:
            inputs: either ``input_ids`` or a list/tuple
                ``[input_ids, attention_mask, token_type_ids, position_ids,
                head_mask, inputs_embeds, output_attentions]`` (trailing
                entries optional); list entries override the keyword arguments.
            attention_mask: mask broadcast to ``[batch, 1, 1, seq]`` and handed
                to the encoder as float32; defaults to all ones.
            token_type_ids, position_ids, output_attentions: accepted for
                signature compatibility; not passed to the encoder.
            head_mask: must be ``None``.
            inputs_embeds: embeddings fed to the encoder.
            training: forwarded to the encoder.

        Returns:
            The first element of the encoder outputs (the sequence output).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
            NotImplementedError: if ``head_mask`` is not ``None``.
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            # Positional entries 1..6 replace the matching keyword arguments;
            # slots that were not supplied keep the keyword value.
            keyword_values = [
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                inputs_embeds,
                output_attentions,
            ]
            positional = list(inputs[1:7])
            (
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                inputs_embeds,
                output_attentions,
            ) = positional + keyword_values[len(positional):]
            assert len(inputs) <= 7, "Too many inputs."
        else:
            input_ids = inputs

        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        if has_ids and has_embeds:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if not has_ids and not has_embeds:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        input_shape = shape_list(input_ids) if has_ids else shape_list(inputs_embeds)[:-1]

        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        # Add two broadcast axes so the 2D mask [batch, to_seq] becomes
        # [batch, 1, 1, to_seq] and can broadcast over heads and query positions.
        # The values are only cast to float32 here -- no (1 - mask) * -10000
        # conversion takes place, so the caller supplies the mask in the form
        # the encoder expects.
        # NOTE(review): presumably the encoder adds this mask to the raw
        # attention scores, which is why padded positions carry -10000 -- confirm.
        extended_attention_mask = tf.cast(
            attention_mask[:, tf.newaxis, tf.newaxis, :], tf.float32
        )

        # Head masking is unsupported: one None placeholder per hidden layer.
        if head_mask is not None:
            raise NotImplementedError
        head_mask = [None] * self.num_hidden_layers

        # NOTE(review): only inputs_embeds reaches the encoder; when just
        # input_ids is given the encoder receives None -- confirm that path is unused.
        encoder_outputs = self.encoder(
            [inputs_embeds, extended_attention_mask, head_mask, None, None], training=training
        )
        return encoder_outputs[0]

class TransformerBasedModel(tf.keras.layers.Layer):
    """Feature extractor combining a shared BiLSTM/CNN path with one transformer branch per field.

    ``call`` concatenates:
      * max-pooled Conv1D features (kernel sizes 2, 3, 4) over a two-layer
        BiLSTM run on the concatenation of all field embeddings,
      * the masked max-pool of that BiLSTM output,
      * for each of the five fields: transformer -> BiLSTM -> max-pool,
      * a small dense projection of the statistics features,
    and maps the result through two dense layers to a 128-dim vector.
    """

    def __init__(self, config, embedding_group, *inputs, **kwargs):
        """Create all sub-layers.

        Args:
            config: ``BertConfig`` shared by the five transformer branches.
            embedding_group: mapping field name -> embedding matrix; needs at
                least five entries.
            *inputs: extra positional arguments forwarded to ``Layer.__init__``.
            **kwargs: keyword arguments forwarded to ``Layer.__init__`` (e.g. ``name``).
        """
        # The first positional parameter of Layer.__init__ is `trainable`.
        # The config object used to be passed in that slot; pass an explicit
        # True instead so `trainable` is a real boolean.
        super().__init__(True, *inputs, **kwargs)
        self.config = config
        vocab_len_list = [len(embedding_group[key]) for key in embedding_group]

        # vocab_size is updated on the shared config before each branch is
        # built. Attribute names and creation order are kept as they were so
        # automatic layer naming and checkpoint structure stay stable.
        self.config.vocab_size = vocab_len_list[0]
        self.trans_layer_1 = TFBertMainLayer(self.config, name='transformer1')
        self.config.vocab_size = vocab_len_list[1]
        self.trans_layer_2 = TFBertMainLayer(self.config, name='transformer2')
        self.config.vocab_size = vocab_len_list[2]
        self.trans_layer_3 = TFBertMainLayer(self.config, name='transformer3')
        self.config.vocab_size = vocab_len_list[3]
        self.trans_layer_4 = TFBertMainLayer(self.config, name='transformer4')
        self.config.vocab_size = vocab_len_list[4]
        self.trans_layer_5 = TFBertMainLayer(self.config, name='transformer5')
        self.trans_layer_list = [
            self.trans_layer_1,
            self.trans_layer_2,
            self.trans_layer_3,
            self.trans_layer_4,
            self.trans_layer_5,
        ]

        # One BiLSTM per field, applied on top of that field's transformer output.
        self.bilstm_layer_1 = Bidirectional(LSTM(128, return_sequences=True, kernel_initializer='glorot_uniform'))
        self.bilstm_layer_2 = Bidirectional(LSTM(128, return_sequences=True, kernel_initializer='glorot_uniform'))
        self.bilstm_layer_3 = Bidirectional(LSTM(128, return_sequences=True, kernel_initializer='glorot_uniform'))
        self.bilstm_layer_4 = Bidirectional(LSTM(128, return_sequences=True, kernel_initializer='glorot_uniform'))
        self.bilstm_layer_5 = Bidirectional(LSTM(128, return_sequences=True, kernel_initializer='glorot_uniform'))
        self.bilstm_layer_list = [
            self.bilstm_layer_1,
            self.bilstm_layer_2,
            self.bilstm_layer_3,
            self.bilstm_layer_4,
            self.bilstm_layer_5,
        ]
        # Two stacked BiLSTMs for the shared path over the concatenated embeddings.
        self.bilstm_layer_7 = Bidirectional(LSTM(640, return_sequences=True, kernel_initializer='glorot_uniform'))
        self.bilstm_layer_8 = Bidirectional(LSTM(320, return_sequences=True, kernel_initializer='glorot_uniform'))

        # Pools 1-5 serve the per-field branches; pool 6 serves the shared path.
        self.bilstm_max_pool_1 = GlobalMaxPooling1D()
        self.bilstm_max_pool_2 = GlobalMaxPooling1D()
        self.bilstm_max_pool_3 = GlobalMaxPooling1D()
        self.bilstm_max_pool_4 = GlobalMaxPooling1D()
        self.bilstm_max_pool_5 = GlobalMaxPooling1D()
        self.bilstm_max_pool_6 = GlobalMaxPooling1D()

        self.bilstm_max_pool_list = [
            self.bilstm_max_pool_1,
            self.bilstm_max_pool_2,
            self.bilstm_max_pool_3,
            self.bilstm_max_pool_4,
            self.bilstm_max_pool_5,
            self.bilstm_max_pool_6,
        ]

        self.conv_max_pool_1 = GlobalMaxPooling1D()
        self.conv_max_pool_2 = GlobalMaxPooling1D()
        self.conv_max_pool_3 = GlobalMaxPooling1D()
        self.conv_max_pool_list = [
            self.conv_max_pool_1,
            self.conv_max_pool_2,
            self.conv_max_pool_3
        ]

        self.conv_1d_1 = Conv1D(128, kernel_size=2, padding="valid", kernel_initializer="he_uniform")
        self.conv_1d_2 = Conv1D(128, kernel_size=3, padding="valid", kernel_initializer="he_uniform")
        self.conv_1d_3 = Conv1D(128, kernel_size=4, padding="valid", kernel_initializer="he_uniform")
        self.conv_1d_list = [
            self.conv_1d_1,
            self.conv_1d_2,
            self.conv_1d_3
        ]

        self.dropout_1 = Dropout(0.4)
        self.dropout_2 = Dropout(0.4)

        self.dense_1 = Dense(512, activation='relu')
        self.dense_2 = Dense(128, activation='relu')

        self.stats_dense1 = Dense(256, activation='relu')
        self.stats_dropout_1 = Dropout(0.2)
        self.stats_dropout_2 = Dropout(0.2)
        self.stats_dense2 = Dense(128, activation='relu')

    def call(self, embeds, attention_mask, stats_feat, **kwargs):
        """Compute the fused 128-dim representation.

        Args:
            embeds: list of embedded sequences, one per field, in the same
                order as the transformer branches.
            attention_mask: per-position mask; clipped to [0, 1] for the shared
                path and passed unchanged to the transformer branches.
                # assumes 1 marks real clicks and -10000 marks padding -- TODO confirm
            stats_feat: statistics feature vector.
            **kwargs: accepted for Keras compatibility; not used.

        Returns:
            Output of the final dense layer (128 units).
        """
        temp_out = []
        origin_mask = attention_mask[:, :, tf.newaxis]
        # 1 where the mask is >= 1, 0 where it is <= 0.
        clip_mask = tf.clip_by_value(origin_mask, clip_value_min=0, clip_value_max=1)
        # Large negative offset on masked positions so max-pooling ignores them.
        max_mask = (1.0 - clip_mask) * -10000.0

        # Shared path over the concatenation of all field embeddings.
        embed_all = concatenate(embeds)
        out = self.bilstm_layer_7(embed_all)
        out = out * clip_mask
        out = self.bilstm_layer_8(out)
        for_max = out + max_mask

        for conv_index, conv1_d in enumerate(self.conv_1d_list):
            con_x = conv1_d(out)
            temp_out.append(self.conv_max_pool_list[conv_index](con_x))
        temp_out.append(self.bilstm_max_pool_list[-1](for_max))

        # Per-field branches. Each field uses its OWN transformer, BiLSTM and
        # pooling layer, selected by `index`. (The loop used to index with the
        # stale conv-loop variable, sending every field through branch 3.)
        # NOTE(review): checkpoints written before this fix only hold weights
        # for branch 3 and will not fully restore into the corrected model.
        for index, embed in enumerate(embeds):
            # Slot layout expected by TFBertMainLayer.call:
            # [input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds]
            trans_inputs = [None, attention_mask, None, None, None, embed]
            # NOTE(review): training=True is hard-coded, so the encoder is
            # called in training mode during predict() as well -- confirm intent.
            trans_outputs = self.trans_layer_list[index](trans_inputs, training=True)
            bilstm_outputs = self.bilstm_layer_list[index](trans_outputs)
            bilstm_max_outputs = self.bilstm_max_pool_list[index](bilstm_outputs)
            temp_out.append(bilstm_max_outputs)

        # Statistics features: dropout -> dense(256) -> dropout -> dense(128).
        stats_out = self.stats_dropout_1(stats_feat)
        stats_out = self.stats_dense1(stats_out)
        stats_out = self.stats_dropout_2(stats_out)
        stats_out = self.stats_dense2(stats_out)
        temp_out.append(stats_out)

        x = concatenate(temp_out)
        x = self.dropout_1(x)
        x = self.dense_1(x)
        x = self.dropout_2(x)
        x = self.dense_2(x)
        return x

def model_(
        config,
        embedding_group
):
    """Build and compile the two-headed (age / gender) Keras model.

    Args:
        config: ``BertConfig`` handed to ``TransformerBasedModel``.
        embedding_group: mapping column name -> pretrained embedding matrix
            of shape ``(vocab, dim)``, one entry per column in ``agg_col``.

    Returns:
        A compiled ``tf.keras.Model`` with outputs ``age_out`` (10-way softmax)
        and ``gender_out`` (2-way softmax), created under a ``MirroredStrategy`` scope.
    """
    strategy = tf.distribute.MirroredStrategy()
    logger.info('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    with strategy.scope():
        all_inputs = []
        embeds = []
        for col in agg_col:
            pretrained = embedding_group[col]
            input_x = Input(shape=(max_len,), name=col)
            # Frozen embedding initialised from the pretrained matrix.
            embedding_layer = Embedding(
                input_dim=pretrained.shape[0],
                output_dim=pretrained.shape[1],
                weights=[pretrained],
                input_length=max_len,
                trainable=False,
                name=col + '_embeding',
                mask_zero=False,
            )
            embeds.append(embedding_layer(input_x))
            all_inputs.append(input_x)

        attention_mask = Input(shape=(max_len,), name='attention_mask')
        click_times_prob = Input(shape=(max_len,), name='click_times_prob')
        click_length_input = Input(shape=(1,), name='click_times_length')
        stats_feat = Input(shape=(len_stats,), name='stats_feat')
        # Only attention_mask and stats_feat are consumed by the network; the
        # other two inputs are declared so the data dicts can carry them.
        all_inputs.extend([attention_mask, click_length_input, stats_feat, click_times_prob])
        gc.collect()

        backbone = TransformerBasedModel(config, embedding_group, name='transformer')
        x = backbone(embeds, attention_mask, stats_feat=stats_feat)
        age_out = Dense(10, activation='softmax', name='age_out')(x)
        gender_out = Dense(2, activation='softmax', name='gender_out')(x)
        model = Model(inputs=all_inputs, outputs=[age_out, gender_out])

        heads = ('age_out', 'gender_out')
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, epsilon=1e-08, clipnorm=1.0)
        model.compile(
            optimizer=optimizer,
            loss={head: 'categorical_crossentropy' for head in heads},
            metrics={head: 'accuracy' for head in heads},
            # Both heads contribute equally to the total loss.
            loss_weights={head: 0.5 for head in heads},
        )
    return model

### <a id='2'> 2.train and predict</a>

In [None]:
def train(
        dataset_train_x,
        dataset_test_x,
        df_test_final,
        label_age,
        label_gender,
        folds,
        config,
        embedding_group,
):
    """Run K-fold training and return fold-averaged test predictions.

    For every fold a fresh model is built and trained on the training split
    (through ``DataGenerator``); the weights with the best ``val_loss`` are
    restored, per-class probability weights are searched on the validation
    split, and the test set is predicted. Per-fold predictions are saved
    under ``./age_prob`` and ``./gender_prob``.

    Args:
        dataset_train_x: dict of feature arrays for the labelled data; must
            contain every column in ``agg_col`` plus 'attention_mask',
            'click_times_length', 'click_times_prob' and 'stats_feat'.
        dataset_test_x: dict of feature arrays for the test data.
        df_test_final: test DataFrame; only its row count is used.
        label_age: one-hot age labels aligned with ``dataset_train_x``.
        label_gender: one-hot gender labels aligned with ``dataset_train_x``.
        folds: iterable of ``(train_index, test_index)`` pairs.
        config: ``BertConfig`` for the model.
        embedding_group: mapping column name -> pretrained embedding matrix.

    Returns:
        Tuple ``(age_sub_weight, age_sub, gender_sub_weight, gender_sub)`` of
        test probabilities averaged over the folds, with and without the
        searched class weights.

    Raises:
        ValueError: if ``folds`` is empty.
        Exception: any failure is logged with its traceback and re-raised.
    """
    try:
        folds = list(folds)
        if not folds:
            raise ValueError('folds must contain at least one (train_index, test_index) pair')
        # Average over the folds actually supplied rather than assuming five.
        n_folds = len(folds)
        n_test = df_test_final.shape[0]

        # Checkpoints and per-fold predictions are written into these
        # directories; create them up front so a finished fold is never lost
        # to a missing folder.
        for out_dir in ('./model', './age_prob', './gender_prob'):
            os.makedirs(out_dir, exist_ok=True)

        # 'click_times_prob' is required by DataGenerator and is also a
        # declared model input, so it must be part of every fold slice.
        fold_cols = agg_col + ['attention_mask', 'click_times_length', 'click_times_prob', 'stats_feat']

        age_score = []
        age_score_weight = []
        age_sub = np.zeros((n_test, age_class_num))
        age_sub_weight = np.zeros((n_test, age_class_num))
        gender_score = []
        gender_score_weight = []
        gender_sub = np.zeros((n_test, gender_class_num))
        gender_sub_weight = np.zeros((n_test, gender_class_num))

        for count, (train_index, test_index) in enumerate(folds):
            logger.info("FOLD | %s" % (count + 1))
            logger.info("###" * 35)
            gc.collect()
            filepath = "./model/nn_v1_%s.ckpt" % count
            checkpoint = ModelCheckpoint(
                filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min', save_weights_only=True)
            earlystopping = EarlyStopping(
                monitor='val_loss', min_delta=0.0001, patience=5, verbose=1, mode='min')
            lr_scheduler = LearningRateScheduler(decay_schedule)
            logcallback = CustomCallback()
            logger.info('model params: %s' % config)
            model = model_(
                config,
                embedding_group
            )
            if count == 0:
                # Log the architecture once.
                stringlist = []
                model.summary(print_fn=stringlist.append)
                logger.info("\n".join(stringlist))

            fold_data_train_x = {col: dataset_train_x[col][train_index] for col in fold_cols}
            fold_data_val_x = {col: dataset_train_x[col][test_index] for col in fold_cols}
            fold_data_train_y = {
                'age_out': label_age[train_index],
                'gender_out': label_gender[train_index]
            }
            fold_data_val_y = {
                'age_out': label_age[test_index],
                'gender_out': label_gender[test_index]
            }
            data_generator = DataGenerator(fold_data_train_x, fold_data_train_y)
            # OnEpochEnd explicitly triggers the generator's re-sampling after each epoch.
            callbacks = [checkpoint, earlystopping, lr_scheduler, logcallback,
                         OnEpochEnd([data_generator.on_epoch_end])]
            hist = model.fit(
                data_generator,
                epochs=epochs,
                validation_data=(fold_data_val_x, fold_data_val_y),
                callbacks=callbacks,
                verbose=1,
                shuffle=True
            )
            # Restore the weights saved at the best val_loss.
            model.load_weights(filepath)
            age_val_prob, gender_val_prob = model.predict(fold_data_val_x, batch_size=512, verbose=1)
            # search_weight(valid_y, raw_prob, class_num): no leading task-name argument.
            age_weight, age_flag_score = search_weight(
                fold_data_val_y['age_out'].argmax(axis=1),
                age_val_prob,
                age_class_num
            )
            gender_weight, gender_flag_score = search_weight(
                fold_data_val_y['gender_out'].argmax(axis=1),
                gender_val_prob,
                gender_class_num
            )

            age_tmp_sub, gender_tmp_sub = model.predict(dataset_test_x, batch_size=512, verbose=1)
            # This fold's share of the final average, plain and class-weighted.
            age_fold = age_tmp_sub / n_folds
            gender_fold = gender_tmp_sub / n_folds
            age_fold_weighted = age_fold * np.asarray(age_weight)
            gender_fold_weighted = gender_fold * np.asarray(gender_weight)
            best_age_acc = np.max(hist.history['val_age_out_accuracy'])
            best_gender_acc = np.max(hist.history['val_gender_out_accuracy'])

            cur_time = datetime.now().isoformat()
            np.save('./age_prob/age_result_fold_%s_acc_%s_%s' % (count + 1, best_age_acc, cur_time),
                    age_fold)
            np.save('./age_prob/age_result_weight_fold_%s_acc_%s_%s' % (count + 1, age_flag_score, cur_time),
                    age_fold_weighted)
            np.save('./gender_prob/gender_result_fold_%s_acc_%s_%s' % (count + 1, best_gender_acc, cur_time),
                    gender_fold)
            np.save('./gender_prob/gender_result_weight_fold_%s_acc_%s_%s' % (count + 1, gender_flag_score, cur_time),
                    gender_fold_weighted)

            age_sub_weight += age_fold_weighted
            gender_sub_weight += gender_fold_weighted
            age_sub += age_fold
            gender_sub += gender_fold
            logger.info(np.min(hist.history['val_loss']))
            age_score.append(best_age_acc)
            gender_score.append(best_gender_acc)
            age_score_weight.append(age_flag_score)
            gender_score_weight.append(gender_flag_score)

            # Release the fold's model and data before building the next one.
            del model, data_generator, hist, fold_data_train_x, fold_data_train_y, fold_data_val_x, fold_data_val_y
            tf.keras.backend.clear_session()
            gc.collect()

        logger.info("age score: %s" % age_score)
        logger.info("age acc: %s" % np.mean(age_score))
        logger.info("age score weight: %s" % age_score_weight)
        logger.info("age weight acc: %s" % np.mean(age_score_weight))

        logger.info("gender score: %s" % gender_score)
        logger.info("gender acc: %s" % np.mean(gender_score))
        logger.info("gender score weight: %s" % gender_score_weight)
        logger.info("gender weight acc: %s" % np.mean(gender_score_weight))
        return age_sub_weight, age_sub, gender_sub_weight, gender_sub

    except Exception as e:
        logger.error(str(e))
        logger.error(traceback.format_exc())
        # Re-raise: returning None here made the caller fail later with an
        # unrelated unpacking error that hid the real cause.
        raise

In [None]:
# main 
if __name__ == "__main__":
    # Script entry point: load cached features and word vectors, build the
    # labels, embedding matrices and CV folds, train, and save the averaged
    # test predictions.

    logger.info('train start')
    with open("./model/cache/word_index.pkl", "rb") as f:
        word_indexs = pickle.load(f)
    fe_df = pd.read_parquet('./model/cache/fe_df_150.parquet')
    logger.info('data has been loaded')

    # load w2v
    wv_group = {}
    for col in agg_col:
        wv_group[col] = KeyedVectors.load(f"./martin_model/model/word2vec_{col}_128.model", mmap='r')

    # divide data into train and test: rows without an age label form the test set
    df_test_final = fe_df[fe_df.age.isna()]
    # Copy the labelled slice so the 'y' column added below is written to an
    # independent frame rather than to a view of fe_df.
    df_train_val_final = fe_df[~fe_df.age.isna()].copy()
    assert df_train_val_final.shape[0] == 3000000

    del fe_df
    gc.collect()

    # combine age and gender into one label:
    # id = (gender - 1) * age_class_num + (age - 1), i.e. 0..19 for 2 x 10 classes
    label_map_dict = {
        (gender - 1) * age_class_num + (age - 1): (gender, age)
        for gender in range(1, gender_class_num + 1)
        for age in range(1, age_class_num + 1)
    }
    reverse_label_map_dict = {value: key for key, value in label_map_dict.items()}
    df_train_val_final['y'] = [
        reverse_label_map_dict[pair]
        for pair in zip(df_train_val_final['gender'], df_train_val_final['age'])
    ]
    label_age = to_categorical(df_train_val_final['age'] - 1)
    label_gender = to_categorical(df_train_val_final['gender'] - 1)
    label_y = to_categorical(df_train_val_final['y'])

    # Convert the per-row arrays to stacked numpy arrays so they can be indexed by fold.
    dataset_train_x = {}
    dataset_test_x = {}

    for col in agg_col + ['attention_mask'] + ['click_times_length'] + ['click_times_prob'] + ['stats_feat']:
        dataset_train_x[col] = np.stack(df_train_val_final[col].values)
        dataset_test_x[col] = np.stack(df_test_final[col].values)

    dataset_train_y = {
        'age_out': label_age,
        'gender_out': label_gender,
        'y_out': label_y
    }

    # Build one embedding matrix per column. Row i holds the vector of the
    # word whose index is i; words without a pretrained vector stay all-zero.
    embedding_group = {}
    for col in agg_col:
        vectors = wv_group[col]
        # Size the matrix from the loaded vectors so it always matches them.
        embed_dim = vectors.vector_size
        nb_words = len(word_indexs[col]) + 1
        embedding_matrix = np.zeros((nb_words, embed_dim))
        count = 0
        for word, i in tqdm(word_indexs[col].items()):
            try:
                embedding_vector = vectors[word]
            except KeyError:
                count += 1
                continue
            embedding_matrix[i] = embedding_vector
        embedding_group[col] = embedding_matrix
        logger.info("col: %s, null cnt: %s" % (col, count))

    del wv_group, word_indexs
    gc.collect()

    # generate folds (cached on disk so repeated runs use identical splits)
    fold_cache_path = './temp/folder.pkl'
    if os.path.exists(fold_cache_path):
        with open(fold_cache_path, "rb") as f:
            folds = pickle.load(f)
    else:
        skf = StratifiedKFold(n_splits=5, random_state=1011, shuffle=True)
        folds = list(skf.split(df_train_val_final, df_train_val_final['y']))
        os.makedirs(os.path.dirname(fold_cache_path), exist_ok=True)
        with open(fold_cache_path, "wb") as f:
            pickle.dump(folds, f)

    # train
    # model params
    config = BertConfig(
        vocab_size=None,
        hidden_size=128,  # for transformer, should be same as emb_dim
        num_hidden_layers=1,  # for transformer
        num_attention_heads=8,
        intermediate_size=256,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=max_len,
        type_vocab_size=1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
    )
    # Fail fast with a clear message instead of a shape error deep inside the model.
    for col, matrix in embedding_group.items():
        if matrix.shape[1] != config.hidden_size:
            raise ValueError(
                "embedding dim %s of column %s does not match transformer hidden_size %s"
                % (matrix.shape[1], col, config.hidden_size)
            )

    cur_time = datetime.now().isoformat()
    age_sub_weight, age_sub, gender_sub_weight, gender_sub = train(
        dataset_train_x,
        dataset_test_x,
        df_test_final,
        label_age,
        label_gender,
        folds,
        config,
        embedding_group,
    )
    age_weight_result = age_sub_weight
    age_result = age_sub
    gender_weight_result = gender_sub_weight
    gender_result = gender_sub
    np.save('./model/age_result_5zhe_weight_transformer_%s' % cur_time, age_weight_result)
    np.save('./model/age_result_5zhe_transformer_%s' % cur_time, age_result)
    np.save('./model/gender_result_5zhe_weight_transformer_%s' % cur_time, gender_weight_result)
    np.save('./model/gender_result_5zhe_transformer_%s' % cur_time, gender_result)
    logger.info('train has been completed')