## 아래 git 저장소를 참고함
+ (참고한 git 저장소 링크 누락 — 추가 필요)
+ PyTorch 구현 -> TensorFlow로 재구현함

In [None]:
from dataclasses import dataclass

import tensorflow
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model

## model 사용 시 활용되는 config 정의

In [None]:
@dataclass
class SidConfig:
    """Hyper-parameters shared by every Sid* layer.

    The first four fields are required and keep their original positional
    order ``(num_numerical_features, num_categorical_embedding_features,
    num_total_features, num_labels)``. The rest are optional hyper-parameters
    with defaults; they are real dataclass fields, so they can be overridden
    through the constructor, e.g. ``SidConfig(3, 10, 5, 2, hidden_size=128)``.

    Attributes:
        num_numerical_features: Number of numerical input columns.
        num_categorical_embedding_features: ``input_dim`` of the categorical
            ``Embedding`` table; category ids must lie in ``[0, value)``.
        num_total_features: Output width of the attention head. It should
            equal (number of categorical columns + num_numerical_features),
            because it is broadcast against the concatenated embeddings.
        num_labels: Number of output units of the classifier head.
        embedding_size: Width of every per-feature embedding.
        im_size: Width of the first Dense layer inside each residual block.
        hidden_size: Width of the hidden state carried between SidLayers.
        num_hidden_layers: Number of stacked SidLayers.
        num_attention_blocks: Residual blocks in each attention head.
        num_transform_blocks: Residual blocks in each transform stack.
        hidden_dropout_prob: Dropout rate inside residual blocks.
        attention_dropout_prob: Dropout rate applied to attention weights.
    """

    # Required fields: no defaults, so they must precede the defaulted ones.
    num_numerical_features: int
    num_categorical_embedding_features: int
    num_total_features: int
    num_labels: int

    # Optional hyper-parameters (same default values as before).
    embedding_size: int = 64
    im_size: int = 1024
    hidden_size: int = 256
    num_hidden_layers: int = 18
    num_attention_blocks: int = 1
    num_transform_blocks: int = 1
    hidden_dropout_prob: float = 0.5
    attention_dropout_prob: float = 0.5

## 모델 정의

In [None]:
class SidEmbeddings(Layer):
    """Embed categorical and numerical features into one token sequence.

    Categorical ids go through a learned ``Embedding`` table. Each numerical
    feature ``j`` with value ``x`` is embedded as
    ``x * numerical_direction[j] + numerical_anchor[j]``. The two groups are
    concatenated along the feature axis.
    """

    def __init__(self, config, **kwargs):
        super(SidEmbeddings, self).__init__(**kwargs)

        self.cat_emb = layers.Embedding(config.num_categorical_embedding_features,
                                        config.embedding_size,
                                        name='categorical_embedding')

        shape = (config.num_numerical_features, config.embedding_size)
        # Initialised from U[0, 1) (tf.random.uniform defaults). `name` belongs
        # to tf.Variable, not to the initial-value op.
        self.numerical_direction = tf.Variable(tf.random.uniform(shape=shape),
                                               name='numerical_direction')
        self.numerical_anchor = tf.Variable(tf.random.uniform(shape=shape),
                                            name='numerical_anchor')

    def call(self, categorical_inputs, numerical_inputs):
        """
        Args:
            categorical_inputs: (bs, num_categorical) integer category ids.
            numerical_inputs: (bs, num_numerical) numerical values.

        Returns:
            (bs, num_categorical + num_numerical, embedding_size)
        """
        # (bs, num_cat_features, embedding_size)
        categorical_embeddings = self.cat_emb(categorical_inputs)

        # Match the variables' dtype so int / float64 inputs don't fail the multiply.
        numerical_inputs = tf.cast(numerical_inputs, self.numerical_direction.dtype)

        # (bs, num_numerical_features, 1) * (num_numerical_features, embedding_size)
        numerical_embeddings = numerical_inputs[:, :, None] * self.numerical_direction
        # (bs, num_numerical_features, embedding_size) + (num_numerical_features, embedding_size),
        # the anchor broadcasts over the batch axis.
        numerical_embeddings = numerical_embeddings + self.numerical_anchor

        return tf.concat((categorical_embeddings, numerical_embeddings), axis=1)
    
    
class SidResidualBlock(Layer):
    """Gated residual feed-forward block.

    The input is projected to ``2 * hidden_size`` and split into a value half
    and a gate half. The block computes ``value * sigmoid(gate)`` (GLU-style
    gating), applies dropout, and adds the result back to the input. The
    input's last dimension must therefore equal ``config.hidden_size``.
    """

    def __init__(self, config, **kwargs):
        super(SidResidualBlock, self).__init__(**kwargs)

        self.feedforward = tf.keras.Sequential([
            layers.Dense(config.im_size, activation='elu'),
            layers.Dense(config.hidden_size * 2),
        ])
        self.dropout = layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states):
        """
        Args:
            hidden_states: (bs, hidden_size)

        Returns:
            (bs, hidden_size): ``hidden_states`` plus the gated update.
        """
        # (bs, hidden_size * 2)
        output = self.feedforward(hidden_states)

        # (bs, hidden_size), (bs, hidden_size). Split the same axis Dense acted on.
        value, gate = tf.split(output, num_or_size_splits=2, axis=-1)

        # Element-wise gate in (0, 1).
        gated = value * tf.math.sigmoid(gate)

        return hidden_states + self.dropout(gated)
    
    
class SidLayer(Layer):
    """One step of the Sid encoder.

    When ``use_attention`` is set and a previous hidden state exists, the
    hidden state is mapped to a per-feature weight in (0, 1) (sigmoid) that
    rescales the input embeddings. The (possibly reweighted) embeddings are
    then flattened, projected to ``hidden_size`` and passed through residual
    transform blocks. The result is added to the previous hidden state, if any.
    """

    def __init__(self, config, use_attention=True, **kwargs):
        super(SidLayer, self).__init__(**kwargs)

        self.use_attention = use_attention
        if use_attention:
            attention_blocks = [SidResidualBlock(config, name=f'attention_block_{i}')
                                for i in range(config.num_attention_blocks)]
            # One weight per input feature, so the output width must equal the
            # number of embedded features.
            self.attention = tf.keras.Sequential(
                attention_blocks
                + [layers.Dense(config.num_total_features, activation='sigmoid')]
            )
            self.dropout = layers.Dropout(config.attention_dropout_prob)

        # Built once here rather than on every call.
        self.flatten = layers.Flatten()
        self.projection = layers.Dense(config.hidden_size)
        self.transform = tf.keras.Sequential(
            [SidResidualBlock(config, name=f'transform_block_{i}')
             for i in range(config.num_transform_blocks)]
        )

    def call(self, input_embeddings, hidden_states=None):
        """
        Args:
            input_embeddings: (bs, num_cat_features + num_numerical_features, embedding_size)
            hidden_states: (bs, hidden_size), or None for the first layer.

        Returns:
            (bs, hidden_size)
        """
        if self.use_attention and hidden_states is not None:
            # (bs, num_total_features)
            attention_probs = self.attention(hidden_states)
            attention_probs = self.dropout(attention_probs)
            # Broadcast one weight per feature across the embedding axis.
            input_embeddings = input_embeddings * attention_probs[:, :, None]

        # (bs, num_total_features * embedding_size)
        output = self.flatten(input_embeddings)
        # (bs, hs)
        output = self.projection(output)
        # (bs, hs)
        output = self.transform(output)

        if hidden_states is None:
            return output
        return hidden_states + output
        
    
class SidModel(Layer):
    """Sid encoder: shared feature embeddings -> stacked SidLayers -> LayerNorm.

    Layer 0 is built without attention (there is no previous hidden state yet).
    Every later layer is built with attention enabled and receives the running
    hidden state.
    """

    def __init__(self, config, **kwargs):
        super(SidModel, self).__init__(**kwargs)

        self.config = config
        self.embeddings = SidEmbeddings(config)

        stack = []
        for idx in range(config.num_hidden_layers):
            stack.append(SidLayer(config, use_attention=idx > 0, name=f'SidLayer_{idx}'))
        self.layers = stack

        self.normalization = layers.LayerNormalization()

    def call(self, inputs):
        """
        Args:
            inputs: dict with
                'categorical': (bs, num_cat_features) category ids,
                'numerical': (bs, num_numerical_features) values.

        Returns:
            (bs, hidden_size) layer-normalised hidden state.
        """
        categorical_inputs = inputs['categorical']
        numerical_inputs = inputs['numerical']

        # (bs, num_cat_features + num_numerical_features, embedding_size)
        embedded = self.embeddings(categorical_inputs, numerical_inputs)

        state = None
        for sid_layer in self.layers:
            state = sid_layer(embedded, state)

        # (bs, hidden_size)
        return self.normalization(state)
    
    
class SidClassifier(Model):
    """SidModel encoder followed by a per-label sigmoid output head."""

    def __init__(self, config, **kwargs):
        super(SidClassifier, self).__init__(**kwargs)

        self.config = config
        self.model = SidModel(config, name='SidModel')
        self.classifier = layers.Dense(config.num_labels,
                                       activation='sigmoid',
                                       name='final_dense')

    def call(self, inputs):
        """Return (bs, num_labels) sigmoid outputs for the given feature dict.

        The head already applies a sigmoid, so these are probabilities in
        (0, 1), not raw logits. A paired loss should not use ``from_logits=True``.
        """
        encoded = self.model(inputs)        # (bs, hs)
        return self.classifier(encoded)     # (bs, num_labels)

## Trainer 예시

In [None]:
import tensorflow as tf
import tensorflow.keras.backend as K
from tqdm import tqdm

In [None]:
class Trainer:
    """Minimal custom training loop for a Keras model.

    Args:
        model: Keras model, called as ``model(inputs, training=...)``.
        optimizer: Optimizer applied in ``train_step`` when ``training`` is True.
        loss_fn: Callable ``loss_fn(y_true, y_pred)`` returning a scalar loss.
            ``SidClassifier`` outputs sigmoid probabilities, so a
            probability-based loss (``from_logits=False``) is presumably intended.
        metrics: Stored on the instance; not used by this class.

    Attributes:
        loss_plot: Mean training loss per epoch (Python floats).
        val_loss_plot: Mean validation loss per epoch (only for epochs that
            ran a validation pass).
    """

    # Progress-bar keys: (last-batch loss, running mean loss, running mean accuracy, AUC).
    _TRAIN_KEYS = ('Loss', 'Total Loss', 'Accuracy', 'AUC')
    _VALID_KEYS = ('val_Loss', 'val Total Loss', 'val Accuracy', 'val AUC')

    def __init__(self, model, optimizer=None, loss_fn=None, metrics=None):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.metrics = metrics

        self.loss_plot = []
        self.val_loss_plot = []

    @tf.function
    def train_step(self, inp, tar, training=True):
        """Run one forward pass and, if ``training``, one optimizer update.

        ``training`` is a Python bool, so ``tf.function`` traces one graph per value.

        Returns:
            ``(loss, predictions, acc)``, where ``acc`` is the element-wise
            accuracy of the rounded predictions against ``tar``.
        """
        if training:
            with tf.GradientTape() as tape:
                # An explicit training=True is needed in a custom loop for
                # Dropout to be active; otherwise Keras runs it in inference mode.
                predictions = self.model(inp, training=True)
                loss = self.loss_fn(tar, predictions)
            gradients = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        else:
            # No tape needed for evaluation.
            predictions = self.model(inp, training=False)
            loss = self.loss_fn(tar, predictions)

        acc = self._get_acc(tar, predictions)

        return loss, predictions, acc

    def _get_acc(self, y_true, y_pred):
        """Fraction of elements where ``round(y_pred)`` equals ``y_true`` (float32 scalar).

        Uses native TF ops (no ``py_function``), so it stays inside the traced graph.
        """
        matches = tf.equal(tf.cast(y_true, tf.int64),
                           tf.cast(tf.round(y_pred), tf.int64))
        return tf.reduce_mean(tf.cast(matches, tf.float32))

    def save_model(self, savepath=None):
        """Save the model weights to ``savepath``.

        Raises:
            ValueError: if ``savepath`` is None.
        """
        # Reject a missing path up front instead of failing inside save_weights.
        if savepath is None:
            raise ValueError('savepath must be provided')

        self.model.save_weights(savepath)
        print(f'success model weights save, {savepath}')

    @staticmethod
    def _reset_metric(metric):
        """Reset a Keras metric (``reset_state`` on newer TF, ``reset_states`` on older)."""
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

    def _run_epoch(self, dataset, auc_metric, epoch, training, keys):
        """Iterate once over ``dataset``; return (mean loss, mean accuracy, AUC).

        ``keys`` names the four progress-bar fields. An empty dataset yields zeros.
        """
        loss_key, total_loss_key, acc_key, auc_key = keys
        self._reset_metric(auc_metric)

        total_loss = 0.
        total_acc = 0.
        mean_loss = mean_acc = auc_value = 0.

        progress = tqdm(enumerate(dataset))
        for batch, (tensor, target) in progress:
            loss, predictions, acc = self.train_step(tensor, target, training=training)
            total_loss += float(loss)
            total_acc += float(acc)
            auc_metric.update_state(target, predictions)

            mean_loss = total_loss / (batch + 1)
            mean_acc = total_acc / (batch + 1)
            auc_value = float(auc_metric.result())

            progress.set_postfix({
                'Epoch': epoch,
                loss_key: '{:.4f}'.format(float(loss)),
                total_loss_key: '{:.4f}'.format(mean_loss),
                acc_key: '{:.4f}'.format(mean_acc),
                auc_key: '{:.4f}'.format(auc_value),
            })

        return mean_loss, mean_acc, auc_value

    def train(self, train_dataset, valid_dataset=None, epochs=1):
        """Train for ``epochs`` epochs, optionally validating after each one.

        Args:
            train_dataset: Iterable of ``(inputs, targets)`` batches.
            valid_dataset: Optional iterable of ``(inputs, targets)`` batches.
                Validation is skipped when it is None.
            epochs: Number of passes over ``train_dataset``.
        """
        train_auc = tf.keras.metrics.AUC(multi_label=True)
        val_auc = tf.keras.metrics.AUC(multi_label=True)

        for epoch in range(epochs):
            train_loss, train_acc, train_auc_value = self._run_epoch(
                train_dataset, train_auc, epoch, training=True, keys=self._TRAIN_KEYS)
            self.loss_plot.append(train_loss)

            # Implicit concatenation keeps the summary on one clean line.
            summary = (f'{epoch} - Training loss: {train_loss:.4f}'
                       f' - Training Accuracy: {train_acc:.4f}'
                       f' - Training AUC: {train_auc_value:.4f}')

            if valid_dataset is not None:
                val_loss, val_acc, val_auc_value = self._run_epoch(
                    valid_dataset, val_auc, epoch, training=False, keys=self._VALID_KEYS)
                self.val_loss_plot.append(val_loss)
                summary += (f' - Validation loss: {val_loss:.4f}'
                            f' - Validation Accuracy: {val_acc:.4f}'
                            f' - Validation AUC: {val_auc_value:.4f}')

            print(summary)