# tf.kerasのカスタムloss/metricsデモ

In [13]:
import warnings
warnings.filterwarnings('ignore')

## サンプルデータセット用意

In [14]:
import numpy as np
import tensorflow as tf


In [15]:
with tf.device("/CPU:0"):
    MAX_GROUP_ID_NUMBER = 10
    SAMPLE_SIZE = 1000
    target = tf.random.uniform([SAMPLE_SIZE], maxval=2, dtype=tf.int32)
    feat_grpby = tf.random.uniform(
        [SAMPLE_SIZE], minval=1, maxval=MAX_GROUP_ID_NUMBER + 1, dtype=tf.int32
    )
    feat_grpby_ohe = tf.one_hot(feat_grpby, depth=tf.reduce_max(feat_grpby))
    feat_grouped = tf.random.normal([SAMPLE_SIZE, 2])

    explains = tf.concat([feat_grouped, feat_grpby_ohe], axis=1)
    target_w_feat_grpby = tf.stack([target, feat_grpby], axis=1)


In [16]:
from random import random

group_weight = {k: 30.0 * random() for k in range(1, MAX_GROUP_ID_NUMBER + 1)}
group_weight


{1: 15.422958906298708,
 2: 28.35185428831561,
 3: 6.526315065849767,
 4: 9.755581898876136,
 5: 17.092259540150007,
 6: 1.2714226403896756,
 7: 24.640910300692433,
 8: 20.271732513558447,
 9: 2.705301692460056,
 10: 7.096694648754539}

## Custom Loss

NOTE:

微分可能な処理でなければならない．
勾配情報を保持できるかは以下を参照のこと

- https://www.tensorflow.org/api_docs/python/tf/raw_ops/
- https://stackoverflow.com/a/44575034


In [17]:
from typing import Optional

from tensorflow.keras.losses import Loss, binary_crossentropy as bce
import pandas as pd


class GroupWeightedBinaryCrossentropy(Loss):
    def __init__(self, group_weight: dict, name="GroupWeightedBCE"):
        super().__init__(name=name)
        self.name = name

        if not isinstance(group_weight, dict):
            errmsg = "For the feature column to be grouped, "
            errmsg += "give the weights of each ID as arguments in dict format."
            raise TypeError(errmsg)
        self.group_weight = group_weight

        self.avg_loss_grpby = {
            k: tf.convert_to_tensor(0, dtype=tf.float32) for k in group_weight.keys()
        }
        self.global_avg_loss = tf.convert_to_tensor(0, dtype=tf.float32)
        self.batch_cnt = tf.convert_to_tensor(0, dtype=tf.float32) 

    def _update_loss_for_each_group(self, y_true_w_id, y_pred):
        y_true = y_true_w_id[:, 0]
        feat_id = y_true_w_id[:, 1]

        y_pred = tf.convert_to_tensor(y_pred)
        y_pred = tf.cast(y_pred, tf.float32)
        y_pred = tf.reshape(y_pred, shape=y_true.shape)
        y_true = tf.cast(y_true, y_pred.dtype)

        # Group内，IDごとの平均lossを計算
        for k in self.group_weight.keys():
            selected_indices = tf.where(tf.equal(feat_id, k))
            # NaN対策（IDが存在しない場合）
            if selected_indices.shape[0] < 1:
                continue
            y_pred_grpby = tf.gather_nd(y_pred, selected_indices)
            y_true_grpby = tf.gather_nd(y_true, selected_indices)
            y_pred_grpby = tf.reshape(y_pred_grpby, shape=[1, -1])
            y_true_grpby = tf.reshape(y_true_grpby, shape=[1, -1])
            loss_for_group = bce(y_true_grpby, y_pred_grpby)[0]

            multiplied = tf.multiply(self.avg_loss_grpby[k], self.batch_cnt)
            numerator = tf.add(multiplied, loss_for_group)
            denominator = tf.add(self.batch_cnt, 1)
            self.avg_loss_grpby[k] = tf.divide(numerator, denominator)

        self.batch_cnt = tf.add(self.batch_cnt, 1)

    def _update_global_loss(self, avg_loss_grpby, group_weight: dict):
        sum_loss_weighted = tf.convert_to_tensor(0, dtype=tf.float32)
        denominator = tf.convert_to_tensor(0, dtype=tf.float32)
        for k, avg_loss in avg_loss_grpby.items():
            weight = group_weight[k]
            multiplied = tf.multiply(avg_loss, weight)
            sum_loss_weighted = tf.add(sum_loss_weighted, multiplied)
            denominator = tf.add(denominator, weight)

        self.global_avg_loss = tf.divide(sum_loss_weighted, denominator)

    def call(self, y_true_w_id, y_pred):
        # model.compile, fit時の内部挙動でエラーが吐きそうなため
        # y_trueはdictではなくarray likeに与える
        self._update_loss_for_each_group(y_true_w_id, y_pred)
        self._update_global_loss(self.avg_loss_grpby, self.group_weight)

        return self.global_avg_loss


In [18]:
# 計算値確認
tmp = GroupWeightedBinaryCrossentropy(group_weight=group_weight)
y_pred = tf.random.normal(shape=[target_w_feat_grpby.shape[0]])
global_weighted_loss = tmp.call(target_w_feat_grpby, y_pred)
global_weighted_loss


<tf.Tensor: shape=(), dtype=float32, numpy=5.236294>

In [19]:
# 重み付け確認
wo_weight = np.mean([v.numpy() for v in tmp.avg_loss_grpby.values()])

assert wo_weight != global_weighted_loss


## 適当なモデルを用意

In [20]:
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(1, activation=tf.nn.softmax)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)


In [21]:
model = MyModel()
# TODO: 適当にTrainする
model.compile(
    loss=GroupWeightedBinaryCrossentropy(group_weight=group_weight),
    optimizer="adam",
    run_eagerly=True,
)


## 学習

### Train-test split

In [22]:
from sklearn.model_selection import train_test_split


train_indices, valid_indices = train_test_split(np.arange(SAMPLE_SIZE), train_size=0.9)

with tf.device("/CPU:0"):
    X_train = tf.gather(explains, train_indices, axis=0)
    X_valid = tf.gather(explains, valid_indices, axis=0)
    y_train_w_feat_grpby = tf.gather(target_w_feat_grpby, train_indices, axis=0)
    y_valid_w_feat_grpby = tf.gather(target_w_feat_grpby, valid_indices, axis=0)


## Fit

In [23]:
with tf.device("/CPU:0"):
    model.fit(
        X_train,
        y_train_w_feat_grpby,
        validation_data=(X_valid, y_valid_w_feat_grpby),
        epochs=1,
        batch_size=32,
    )


