# Prepare


In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

def print_confusion_matrix(model, testset, num_classes:int):
  predictions = []
  labels = []
  for batch in testset:
    x, y = batch
    predictions += list(tf.math.argmax(model.predict(x, verbose=0), axis=1).numpy())
    labels += list(y.numpy())
  print('confusion matrix: ')
  print(tf.math.confusion_matrix(labels, predictions, num_classes=num_classes).numpy())

In [3]:
import tensorflow as tf
import tensorflow.keras.backend as K


class NegativePenaltySparseCategoricalCrossentropy(tf.keras.losses.Loss):
    def __init__(self, class_num:int, p_indices:list, alpha=1.0, penalty_scale=None, \
                 reduction=tf.keras.losses.Reduction.AUTO, \
                 name='negative_penalty_sparse_categorical_crossentropy'):
        super(NegativePenaltySparseCategoricalCrossentropy, self).__init__(reduction=reduction, name=name)
        self.p_indices = [[p_index] for p_index in p_indices]
        self.alpha = alpha
        self.penalty_scale = float(len(p_indices)) if penalty_scale is None else penalty_scale
        self.penalty_label = _get_penalty_label(class_num, p_indices)

    def call(self, y_true, y_pred):
        num_classes = y_pred.shape[-1]
        y_true = tf.squeeze(tf.one_hot(y_true, num_classes), axis=1)
        losses = _get_losses(y_true, y_pred, self.p_indices, self.penalty_label, self.alpha, self.penalty_scale)
        return losses


def _get_losses(y_true, y_pred, p_indices:list, penalty_label:list, alpha:float, penalty_scale:float):
    batch_size = y_true.shape[0]
    y_true = tf.cast(y_true, tf.float32)
    # cce loss part for positive samples
    cce_loss_sample_weights = tf.cast(
        tf.reduce_any(
            tf.transpose(tf.math.equal(tf.math.argmax(y_true, axis=1), p_indices), perm=(1, 0)), axis=1
        ),
        dtype=tf.float32
    )
    cce_losses = K.categorical_crossentropy(y_true, y_pred, from_logits=False)  # shape: (batch_size,)
    cce_losses = cce_loss_sample_weights * cce_losses
    # penalty loss part for negative samples
    y_penalty = tf.repeat(tf.expand_dims(tf.constant(penalty_label), axis=0), batch_size, axis=0)
    y_penalty = tf.cast(y_penalty, tf.float32)
    penalty_loss_sample_weights = tf.where(cce_loss_sample_weights == 1.0, 0.0, 1.0)  # 1.0: negative sample, 0.0: postive sample
    # option 1
    # penalty_losses = 1 / K.categorical_crossentropy(y_penalty, y_pred, from_logits=False)  # shape: (batch_size,)
    # option 2
    # penalty_losses = K.categorical_crossentropy(y_penalty, 1.0 - y_pred, from_logits=False)  # shape: (batch_size,)
    # option 3
    penalty_losses = -tf.math.reduce_sum(
        y_penalty * tf.math.log(tf.clip_by_value(1.0 - y_pred, K.epsilon(), 1.0 - K.epsilon())),
        axis=-1
    )
    # scale penalty_losses
    penalty_losses = penalty_losses / penalty_scale
    penalty_losses = penalty_loss_sample_weights * penalty_losses
    # total loss
    losses = cce_losses + alpha * penalty_losses
    return losses


def _get_penalty_label(class_num:int, p_indices:list):
    penalty_label = [1 if i in p_indices else 0 for i in range(0, class_num)]
    return penalty_label

In [4]:
class NegativePenaltySparseCategoricalAccuracy(tf.keras.metrics.Metric):
    def __init__(self, p_indices:list, name='negative_penalty_sparse_categorical_accuracy', **kwargs):
        super().__init__(name=name, **kwargs)
        self.p_indices = [[p_index] for p_index in p_indices]

    def update_state(self, y_true, y_pred, sample_weight=None):
        num_classes = y_pred.shape[-1]
        y_true = tf.squeeze(tf.one_hot(y_true, num_classes), axis=1)
        self.accuracy = _get_accuracy(y_true, y_pred, sample_weight, self.p_indices)

    def result(self):
        return self.accuracy


def _get_accuracy(y_true, y_pred, sample_weight, p_indices:list):
    batch_size = y_true.shape[0]
    y_true = tf.cast(y_true, tf.float32)
    # compute accuracy for positive samples in a batch
    positive_sample_weights = tf.cast(
        tf.reduce_any(
            tf.transpose(tf.math.equal(tf.math.argmax(y_true, axis=1), p_indices), perm=(1, 0)), axis=1
        ),
        dtype=tf.float32
    )  # 1.0: postive sample, 0.0: negative sample
    positive_sample_values = tf.where(
        tf.math.argmax(y_true, axis=1) ==  tf.math.argmax(y_pred, axis=1), 1.0, 0.0
    )
    positive_sample_values = positive_sample_weights * positive_sample_values
    # compute accuracy for negative samples in a batch
    negative_sample_weights = tf.where(positive_sample_weights == 1.0, 0.0, 1.0)  # 1.0: negative sample, 0.0: postive sample
    negative_sample_values = tf.cast(
        tf.reduce_all(
            tf.transpose(tf.math.not_equal(tf.math.argmax(y_pred, axis=1), p_indices), perm=(1, 0)), axis=1
        ),
        dtype=tf.float32
    )
    negative_sample_values = negative_sample_weights * negative_sample_values
    # combine positive values and negative values
    positive_sample_values = tf.cast(positive_sample_values, tf.bool)
    negative_sample_values = tf.cast(negative_sample_values, tf.bool)
    values = tf.math.logical_or(positive_sample_values, negative_sample_values)
    values = tf.cast(values, tf.float32)
    if sample_weight is not None:
        values = tf.math.multiply(values, tf.squeeze(sample_weight, axis=1))
    else:
        sample_weight = tf.repeat([[1.0]], batch_size, axis=0)
        values = tf.math.multiply(values, tf.squeeze(sample_weight, axis=1))
    accuracy = tf.math.reduce_sum(values, axis=None) / tf.math.reduce_sum(sample_weight, axis=None)
    return accuracy

In [5]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


In [6]:
ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(32, drop_remainder=True)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(32, drop_remainder=True)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

# Simply train model on mnist dataset with normal categorical crossentroy loss

In [7]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(14, activation='softmax')  # set to 14 not 10
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=3,
    validation_data=ds_test,
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x7f43ba7ed1b0>

In [8]:
model.evaluate(ds_test)



[0.06670098006725311, 0.9779647588729858]

plot confusion matrix

In [9]:
print_confusion_matrix(model, ds_test, num_classes=14)

confusion matrix: 
[[ 968    0    1    0    0    1    3    1    4    0    0    0    0    0]
 [   2 1120    1    1    0    1    4    1    3    0    0    0    0    0]
 [   2    3 1018    1    0    0    1    2    3    0    0    0    0    0]
 [   1    0    5  983    0    8    0    1    9    2    0    0    0    0]
 [   0    0    3    1  947    0   12    1    2   14    0    0    0    0]
 [   2    0    0    3    0  877    3    1    2    0    0    0    0    0]
 [   3    2    0    0    1    2  948    0    2    0    0    0    0    0]
 [   2    2   19    2    1    0    0  975    3   23    0    0    0    0]
 [   6    0    2    1    1    4    2    0  952    5    0    0    0    0]
 [   3    4    3    4    5    4    0    3    6  976    0    0    0    0]
 [   0    0    0    0    0    0    0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0 

# Train model on mnist dataset with the proposed 'NegativePenaltySparseCategoricalCrossentropy' loss function

In [11]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(14, activation='softmax')  # set to 14 not 10
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=NegativePenaltySparseCategoricalCrossentropy(class_num=14, p_indices=[0, 1, 2, 3, 4, 5, 6, 7]),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(), NegativePenaltySparseCategoricalAccuracy(p_indices=[0, 1, 2, 3, 4, 5, 6, 7])],
)

model.fit(
    ds_train,
    epochs=5,
    validation_data=ds_test,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x7f43b9504eb0>

# Conclusion:
#### 1. Accuracy dropped to approximately 80% (79.72%) indicates lots of images belongs to '8' and '9' were not classified as '8' or '9' as before.
#### 2. Most images belongs to '8' and '9' were not classified as '0', '1', '2', '3', '4', '5', '6' or '7' which is exactly what we expected.
## Above 1. and 2. proved that the proposed 'NegativePenaltySparseCategoricalCrossentropy' loss function workd as expected

In [12]:
model.evaluate(ds_test)



[0.01882198266685009, 0.7971754670143127, 0.96875]

plot confusion matrix

In [13]:
print_confusion_matrix(model, ds_test, num_classes=14)

confusion matrix: 
[[ 976    0    0    0    0    0    1    1    0    0    0    0    0    0]
 [   0 1130    1    0    0    0    0    2    0    0    0    0    0    0]
 [   0    0 1024    0    1    0    1    4    0    0    0    0    0    0]
 [   0    0    0 1004    0    1    0    4    0    0    0    0    0    0]
 [   0    0    0    0  978    0    0    2    0    0    0    0    0    0]
 [   0    0    1    8    0  876    1    2    0    0    0    0    0    0]
 [   2    3    0    0    2    2  949    0    0    0    0    0    0    0]
 [   0    1    4    0    0    0    0 1022    0    0    0    0    0    0]
 [   4    4   10   19   11    4    2   11    0    0    0    0  908    0]
 [   1    3    3    7   19    3    0   32    0    0    0    0  940    0]
 [   0    0    0    0    0    0    0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0 