# Prepare


In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [4]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

In [107]:
import tensorflow as tf
import tensorflow.keras.backend as K


class NegativePenaltySparseCategoricalCrossentropy(tf.keras.losses.Loss):
    def __init__(self, class_num:int, p_indices:list, alpha=1.0, penalty_scale=None, \
                 reduction=tf.keras.losses.Reduction.AUTO, \
                 name='negative_penalty_sparse_categorical_crossentropy'):
        super(NegativePenaltySparseCategoricalCrossentropy, self).__init__(reduction=reduction, name=name)
        self.p_indices = [[p_index] for p_index in p_indices]
        self.alpha = alpha
        self.penalty_scale = float(len(p_indices)) if penalty_scale is None else penalty_scale
        self.penalty_label = _get_penalty_label(class_num, p_indices)

    def call(self, y_true, y_pred):
        num_classes = y_pred.shape[-1]
        y_true = tf.squeeze(tf.one_hot(y_true, num_classes), axis=1)
        losses = _get_losses(y_true, y_pred, self.p_indices, self.penalty_label, self.alpha, self.penalty_scale)
        return losses


def _get_losses(y_true, y_pred, p_indices:list, penalty_label:list, alpha:float, penalty_scale:float):
    batch_size = 64
    y_true = tf.cast(y_true, tf.float32)
    # cce loss part for positive samples
    cce_loss_sample_weights = tf.cast(
        tf.reduce_any(
            tf.transpose(tf.math.equal(tf.math.argmax(y_true, axis=1), p_indices), perm=(1, 0)), axis=1
        ),
        dtype=tf.float32
    )
    cce_losses = K.categorical_crossentropy(y_true, y_pred, from_logits=False)  # shape: (batch_size,)
    cce_losses = cce_loss_sample_weights * cce_losses
    # penalty loss part for negative samples
    y_penalty = tf.repeat(tf.expand_dims(tf.constant(penalty_label), axis=0), batch_size, axis=0)
    y_penalty = tf.cast(y_penalty, tf.float32)
    penalty_loss_sample_weights = tf.where(cce_loss_sample_weights == 1.0, 0.0, 1.0)  # 1.0: negative sample, 0.0: postive sample
    # option 1
    # penalty_losses = 1 / K.categorical_crossentropy(y_penalty, y_pred, from_logits=False)  # shape: (batch_size,)
    # option 2
    # penalty_losses = K.categorical_crossentropy(y_penalty, 1.0 - y_pred, from_logits=False)  # shape: (batch_size,)
    # option 3
    penalty_losses = -tf.math.reduce_sum(
        y_penalty * tf.math.log(tf.clip_by_value(1.0 - y_pred, K.epsilon(), 1.0 - K.epsilon())),
        axis=-1
    )
    # scale penalty_losses
    penalty_losses = penalty_losses / penalty_scale
    penalty_losses = penalty_loss_sample_weights * penalty_losses
    # total loss
    losses = cce_losses
    return losses


def _get_penalty_label(class_num:int, p_indices:list):
    penalty_label = [1 if i in p_indices else 0 for i in range(0, class_num)]
    return penalty_label

In [83]:
class NegativePenaltySparseCategoricalAccuracy(tf.keras.metrics.Metric):
    def __init__(self, p_indices:list, name='negative_penalty_sparse_categorical_accuracy', **kwargs):
        super().__init__(name=name, **kwargs)
        self.p_indices = [[p_index] for p_index in p_indices]

    def update_state(self, y_true, y_pred, sample_weight=None):
        num_classes = y_pred.shape[-1]
        y_true = tf.squeeze(tf.one_hot(y_true, num_classes), axis=1)
        self.accuracy = _get_accuracy(y_true, y_pred, sample_weight, self.p_indices)

    def result(self):
        return self.accuracy


def _get_accuracy(y_true, y_pred, sample_weight, p_indices:list):
    batch_size = y_true.shape[0]
    y_true = tf.cast(y_true, tf.float32)
    # compute accuracy for positive samples in a batch
    positive_sample_weights = tf.cast(
        tf.reduce_any(
            tf.transpose(tf.math.equal(tf.math.argmax(y_true, axis=1), p_indices), perm=(1, 0)), axis=1
        ),
        dtype=tf.float32
    )  # 1.0: postive sample, 0.0: negative sample
    positive_sample_values = tf.where(
        tf.math.argmax(y_true, axis=1) ==  tf.math.argmax(y_pred, axis=1), 1.0, 0.0
    )
    positive_sample_values = positive_sample_weights * positive_sample_values
    # compute accuracy for negative samples in a batch
    negative_sample_weights = tf.where(positive_sample_weights == 1.0, 0.0, 1.0)  # 1.0: negative sample, 0.0: postive sample
    negative_sample_values = tf.cast(
        tf.reduce_all(
            tf.transpose(tf.math.not_equal(tf.math.argmax(y_pred, axis=1), p_indices), perm=(1, 0)), axis=1
        ),
        dtype=tf.float32
    )
    negative_sample_values = negative_sample_weights * negative_sample_values
    # combine positive values and negative values
    positive_sample_values = tf.cast(positive_sample_values, tf.bool)
    negative_sample_values = tf.cast(negative_sample_values, tf.bool)
    values = tf.math.logical_or(positive_sample_values, negative_sample_values)
    values = tf.cast(values, tf.float32)
    if sample_weight is not None:
        values = tf.math.multiply(values, tf.squeeze(sample_weight, axis=1))
    else:
        sample_weight = tf.repeat([[1.0]], batch_size, axis=0)
        values = tf.math.multiply(values, tf.squeeze(sample_weight, axis=1))
    accuracy = tf.math.reduce_sum(values, axis=None) / tf.math.reduce_sum(sample_weight, axis=None)
    return accuracy

In [108]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [109]:
ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(64, drop_remainder=True)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(64, drop_remainder=True)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

# Simply train model on mnist dataset with normal categorical crossentroy loss

In [None]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(20, activation='softmax')  # set to 20 not 10
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=2,
    validation_data=ds_test,
)

Epoch 1/2
Epoch 2/2


In [9]:
model.evaluate(ds_test)



[0.08085029572248459, 0.9753000140190125]

# Train model on mnist dataset with the proposed 'NegativePenaltySparseCategoricalCrossentropy' loss function

In [110]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(20, activation='softmax')  # set to 20 not 10
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=NegativePenaltySparseCategoricalCrossentropy(class_num=20, p_indices=[0, 1, 2, 3, 4, 5, 6, 7]),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=2,
    validation_data=ds_test,
)

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x79266fc1b5b0>

# Conclusion: Accuracy dropped to approximately 80% (79.60%) which means the proposed 'NegativePenaltySparseCategoricalCrossentropy' loss function workd as expected

In [111]:
model.evaluate(ds_test)



[0.01796453818678856, 0.795973539352417]