This notebook tests the various custom Keras metrics used for model evaluation.

In [None]:
import time
import sys
import os
import glob
import math
import threading
import concurrent.futures as cf
import random
import re

import numpy as np
import pandas as pd
import tensorflow as tf
from keras import Input, Model, layers, metrics, losses, callbacks, optimizers, models, utils
from keras import backend as K
import gc
import keras_tuner as kt
from pyfaidx import Fasta

# Begin from a clean Keras/TensorFlow session and reclaim unreferenced memory.
K.clear_session()
gc.collect()

# Dataset and model directories, relative to the current working directory
# (normally the folder containing this notebook).
datasets_path, models_path = "../../Datasets/", "../../Models/"

2025-02-26 01:30:58.996611: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-26 01:30:59.185169: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-26 01:30:59.239137: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-26 01:30:59.609060: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
@utils.register_keras_serializable()
class CustomNoBackgroundF1Score(metrics.Metric):
    """
    F1 score restricted to the non-background classes (indices 1..num_classes-1).

    Column 0 is treated as the dominant/background class and never contributes.
    A ground-truth value counts as positive only when it is exactly 1 (smoothed /
    partial-credit values below 1 are negatives); a prediction counts as positive
    when it is >= `threshold`.
    """

    def __init__(self, num_classes, average='weighted', threshold=0.5, name='no_background_f1', **kwargs):
        """
        Args:
            num_classes (int): Total number of classes, including the background class at index 0.
            average (str): 'weighted' (default) weights each class by its support (TP + FN);
                'macro' takes an unweighted mean over classes 1..num_classes-1.
            threshold (float): Predictions >= threshold are counted as positive (default 0.5).
            name (str): Name of the metric.
            **kwargs: Forwarded to `keras.metrics.Metric`.

        Raises:
            ValueError: If `average` is not 'weighted' or 'macro'.
        """
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.threshold = threshold
        if average not in ['weighted', 'macro']:
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average

        # Per-class accumulators of length num_classes. Index 0 (background) only keeps
        # the indices aligned with the label columns; update_state never adds to it.
        self.true_positives = self.add_weight(
            name='tp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Accumulate (optionally weighted) per-class TP / FP / FN counts for one batch.

        Args:
            y_true: Labels whose last dimension is num_classes; leading dimensions are flattened.
                For non-background classes a label is positive only if it equals 1 exactly.
            y_pred: Predictions (e.g. probabilities) with the same layout as y_true.
            sample_weight: Optional weights, reshaped to (-1, 1), so one weight per flattened row.
                Fractional weights are supported.
        """
        # Flatten all leading dimensions so rows are individual label vectors.
        y_true = tf.reshape(y_true, [-1, self.num_classes])
        y_pred = tf.reshape(y_pred, [-1, self.num_classes])

        # Drop the background column (index 0).
        y_true_fg = y_true[:, 1:]
        y_pred_fg = y_pred[:, 1:]

        # Strict positives in the labels; thresholded positives in the predictions.
        one_value = tf.cast(1.0, dtype=y_true_fg.dtype)
        y_true_bin = tf.cast(tf.equal(y_true_fg, one_value), tf.float32)
        y_pred_bin = tf.cast(y_pred_fg >= self.threshold, tf.float32)

        # Apply the weight once per confusion cell and keep it float. Weighting y_true and
        # y_pred separately would square it in TP and drive FP negative for weights > 1,
        # and an integer cast would zero out fractional weights.
        if sample_weight is not None:
            weights = tf.reshape(tf.cast(sample_weight, tf.float32), (-1, 1))
        else:
            weights = 1.0

        tp = tf.reduce_sum(y_true_bin * y_pred_bin * weights, axis=0)
        fp = tf.reduce_sum((1.0 - y_true_bin) * y_pred_bin * weights, axis=0)
        fn = tf.reduce_sum(y_true_bin * (1.0 - y_pred_bin) * weights, axis=0)

        # Prepend a zero for the background slot so the update matches the state length.
        pad = tf.zeros([1], dtype=tf.float32)
        self.true_positives.assign_add(tf.concat([pad, tp], axis=0))
        self.false_positives.assign_add(tf.concat([pad, fp], axis=0))
        self.false_negatives.assign_add(tf.concat([pad, fn], axis=0))

    def result(self):
        """
        Computes the F1 score over the non-background classes (indices 1...num_classes-1).
        """
        tp = self.true_positives[1:]
        fp = self.false_positives[1:]
        fn = self.false_negatives[1:]

        precision = tf.math.divide_no_nan(tp, tp + fp)
        recall = tf.math.divide_no_nan(tp, tp + fn)
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)

        if self.average == 'weighted':
            support = tp + fn
            return tf.reduce_sum(f1 * support) / (tf.reduce_sum(support) + K.epsilon())
        # macro
        return tf.reduce_mean(f1)

    def reset_state(self):
        """Zero every accumulator (`reset_state` is the hook name used by current Keras)."""
        for v in self.variables:
            v.assign(tf.zeros_like(v))

    def reset_states(self):
        """Backward-compatible alias for `reset_state` (older tf.keras name)."""
        self.reset_state()

    def get_config(self):
        """
        Returns the configuration of the metric, so it can be recreated later.
        """
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'average': self.average,
            'threshold': self.threshold,
        })
        return config

@utils.register_keras_serializable()
class CustomConditionalF1Score(metrics.Metric):
    """
    F1 score over target columns 1-4 only, restricted to rows chosen by `filter_mode`.

    Column 0 (the dominant background class) is ignored completely. Labels count as
    positive only when exactly 1; predictions count as positive when >= `threshold`.
    """

    def __init__(self, threshold=0.5, average='weighted', filter_mode='either', name='conditional_f1', **kwargs):
        """
        Args:
            threshold (float): Threshold on y_pred to decide a positive (default = 0.5).
            average (str): 'weighted' (default) to weight by support or 'macro' for a simple average.
            filter_mode (str): Which rows to include, based on the target columns:
                - 'pred': rows where y_pred (after thresholding) has at least one positive.
                - 'true': rows where y_true (exactly equal to 1) has at least one positive.
                - 'either': rows where y_true or y_pred has at least one positive.
            name (str): Base name of the metric; `_{filter_mode}` is appended unless the
                name already ends with it (as it does when rebuilt from `get_config()`).
            **kwargs: Forwarded to `keras.metrics.Metric`.

        Raises:
            ValueError: If `average` or `filter_mode` is not one of the allowed values.
        """
        suffix = f'_{filter_mode}'
        # get_config() returns the already-suffixed name; don't append the suffix twice
        # when the metric is reconstructed through from_config().
        if isinstance(name, str) and name.endswith(suffix):
            metric_name = name
        else:
            metric_name = f'{name}{suffix}'

        super().__init__(name=metric_name, **kwargs)
        self.threshold = threshold
        if average not in ['weighted', 'macro']:
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average

        if filter_mode not in ['pred', 'true', 'either']:
            raise ValueError("filter_mode must be 'pred', 'true', or 'either'")
        self.filter_mode = filter_mode

        # Only the 4 target columns (columns 1 to 4) are tracked.
        self.num_target_columns = 4
        self.true_positives = self.add_weight(
            name='tp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Accumulate per-column TP / FP / FN over the rows selected by `filter_mode`.

        Args:
            y_true: Labels; all leading dimensions are flattened, columns 1-4 are used.
            y_pred: Predictions with the same layout as y_true.
            sample_weight: Optional weights, reshaped to (-1, 1), so one weight per flattened row.
        """
        # Flatten leading dimensions so each row is one label vector.
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])

        # Only consider columns 1-4 (ignoring index 0).
        y_true_subset = y_true[:, 1:5]
        y_pred_subset = y_pred[:, 1:5]

        # Work in float32 so the weights below can be multiplied in directly.
        one_value = tf.cast(1.0, dtype=y_true_subset.dtype)
        y_true_bin = tf.cast(tf.equal(y_true_subset, one_value), tf.float32)
        y_pred_bin = tf.cast(y_pred_subset >= self.threshold, tf.float32)

        # Row-level selection according to filter_mode.
        true_any = tf.reduce_any(y_true_bin > 0.0, axis=1)
        pred_any = tf.reduce_any(y_pred_bin > 0.0, axis=1)
        if self.filter_mode == 'pred':
            row_mask = pred_any
        elif self.filter_mode == 'true':
            row_mask = true_any
        else:  # 'either'
            row_mask = tf.logical_or(pred_any, true_any)

        # Give excluded rows weight 0 instead of dropping them with boolean_mask. That keeps
        # per-row sample weights aligned with their rows and avoids int32 * float32 errors.
        row_weights = tf.cast(row_mask, tf.float32)[:, tf.newaxis]
        if sample_weight is not None:
            row_weights = row_weights * tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])

        tp = tf.reduce_sum(y_true_bin * y_pred_bin * row_weights, axis=0)
        fp = tf.reduce_sum((1.0 - y_true_bin) * y_pred_bin * row_weights, axis=0)
        fn = tf.reduce_sum(y_true_bin * (1.0 - y_pred_bin) * row_weights, axis=0)

        self.true_positives.assign_add(tp)
        self.false_positives.assign_add(fp)
        self.false_negatives.assign_add(fn)

    def result(self):
        """Return the weighted or macro F1 over the four target columns."""
        precision = tf.math.divide_no_nan(self.true_positives, self.true_positives + self.false_positives)
        recall = tf.math.divide_no_nan(self.true_positives, self.true_positives + self.false_negatives)
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)

        if self.average == 'weighted':
            support = self.true_positives + self.false_negatives
            return tf.reduce_sum(f1 * support) / (tf.reduce_sum(support) + K.epsilon())
        # 'macro'
        return tf.reduce_mean(f1)

    def reset_state(self):
        """Zero every accumulator (`reset_state` is the hook name used by current Keras)."""
        for v in self.variables:
            v.assign(tf.zeros_like(v))

    def reset_states(self):
        """Backward-compatible alias for `reset_state` (older tf.keras name)."""
        self.reset_state()

    def get_config(self):
        """Return constructor arguments; the stored name already carries the filter suffix."""
        config = super().get_config()
        config.update({
            'threshold': self.threshold,
            'average': self.average,
            'filter_mode': self.filter_mode,
        })
        return config


@utils.register_keras_serializable()
class CustomFalsePositiveDistance(metrics.Metric):
    """
    Running average "distance" error of false-positive predictions, ignoring the
    dominant (background) class at index 0.
    """

    def __init__(self, num_classes, threshold=0.5, window=100, name='false_positive_distance',
                 far_distance=125.0, **kwargs):
        """
        For each false positive (a prediction >= threshold where the strict label is not 1),
        a distance is derived from the raw label value v:

            distance = 1 + ((max_credit - v) * (window / max_credit))

        with max_credit = 0.5, so v == 0.5 gives distance 1 and v == 0 gives 1 + window.
        Any distance >= 1 + window (i.e. v <= 0, presumably positions outside the smoothing
        window) is replaced by `far_distance`.

        Args:
            num_classes (int): Total number of classes.
            threshold (float): Threshold on y_pred to decide a positive.
            window (int): Window size used in the smoothing scheme.
            name (str): Name of the metric.
            far_distance (float): Distance assigned to false positives at or beyond the window
                edge (default 125.0, which reproduces the previous behaviour for window=100).
            **kwargs: Forwarded to `keras.metrics.Metric`.
        """
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.threshold = threshold
        self.window = float(window)
        self.far_distance = float(far_distance)
        self.max_credit = 0.5  # Maximum smoothing credit of the labelling scheme.

        # Accumulators: sum of (weighted) distances and (weighted) false-positive count.
        self.total_distance = self.add_weight(
            name='total_distance', initializer='zeros', dtype=tf.float32
        )
        self.false_positive_count = self.add_weight(
            name='false_positive_count', initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        For non-background classes (indices 1:), threshold the predictions, find false
        positives (positive prediction while the strict label != 1), convert the raw label
        value at each one into a distance, and accumulate the distance sum and count.

        Args:
            y_true: Raw (smoothed) labels; reshaped to (-1, num_classes).
            y_pred: Predictions; reshaped to (-1, num_classes).
            sample_weight: Optional weights, reshaped to (-1, 1), so one weight per flattened row.
        """
        y_true = tf.reshape(y_true, [-1, self.num_classes])
        y_pred = tf.reshape(y_pred, [-1, self.num_classes])

        # Ignore the dominant/background class (index 0).
        y_true_non = y_true[:, 1:]
        y_pred_non = y_pred[:, 1:]

        y_pred_bin = tf.cast(y_pred_non >= self.threshold, tf.float32)

        # Smoothed (non-1) label values are negatives, matching the F1 metrics.
        false_positive_mask = tf.cast(
            tf.logical_and(tf.equal(y_pred_bin, 1.0), tf.not_equal(y_true_non, 1.0)),
            tf.float32,
        )

        distance = 1.0 + (self.max_credit - y_true_non) * (self.window / self.max_credit)
        # The cap tracks the configured window instead of a hard-coded 101 (only right for
        # window=100). Distances at or beyond 1 + window become far_distance.
        max_in_window = 1.0 + self.window
        distance = tf.where(
            distance >= max_in_window,
            tf.constant(self.far_distance, dtype=distance.dtype),
            distance,
        )

        if sample_weight is not None:
            weights = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            weighted_mask = false_positive_mask * weights
        else:
            weighted_mask = false_positive_mask

        self.total_distance.assign_add(tf.reduce_sum(distance * weighted_mask))
        self.false_positive_count.assign_add(tf.reduce_sum(weighted_mask))

    def result(self):
        """Returns the average distance error over all false positives (or 0 if none)."""
        return tf.math.divide_no_nan(self.total_distance, self.false_positive_count)

    def reset_state(self):
        """Reset the accumulated total distance and count (current Keras hook name)."""
        self.total_distance.assign(0.0)
        self.false_positive_count.assign(0.0)

    def reset_states(self):
        """Backward-compatible alias for `reset_state` (older tf.keras name)."""
        self.reset_state()

    def get_config(self):
        """Return constructor arguments for serialization."""
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'threshold': self.threshold,
            'window': self.window,
            'far_distance': self.far_distance,
        })
        return config

@utils.register_keras_serializable()
class CustomNoBackgroundAUC(metrics.Metric):
    """
    Mean of per-column AUCs over target columns 1-4; column 0 (background) is ignored.
    """

    def __init__(self, curve='PR', name='no_background_auc', **kwargs):
        """
        Args:
            curve (str): AUC curve type passed to `keras.metrics.AUC`: 'PR' (default) for
                precision-recall or 'ROC'.
            name (str): Name of the metric.
            **kwargs: Forwarded to `keras.metrics.Metric`.
        """
        super().__init__(name=name, **kwargs)
        # Keep the curve as a plain string so get_config() stays serializable.
        self.curve = curve
        # One AUC metric per target column (columns 1-4).
        self.auc_metrics = [
            metrics.AUC(curve=self.curve, name=f'auc_col_{i+1}')
            for i in range(4)
        ]

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Update each per-column AUC with that column's strict labels and raw predictions.

        Args:
            y_true: Labels; leading dimensions are flattened and columns 1-4 are used.
                A label is positive only if it equals 1 exactly.
            y_pred: Prediction scores with the same layout as y_true.
            sample_weight: Optional weights, flattened to one weight per row.
        """
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        # Select target columns (1-4) and ignore background (column 0).
        y_true_subset = y_true[:, 1:5]
        y_pred_subset = y_pred[:, 1:5]

        # The weight layout is the same for every column, so reshape it once.
        if sample_weight is not None:
            sample_weight = tf.reshape(sample_weight, [-1])

        one_value = tf.cast(1.0, dtype=y_true_subset.dtype)
        for i, auc_metric in enumerate(self.auc_metrics):
            y_true_col = tf.cast(tf.equal(y_true_subset[:, i], one_value), tf.float32)
            y_pred_col = y_pred_subset[:, i]
            if sample_weight is not None:
                auc_metric.update_state(y_true_col, y_pred_col, sample_weight=sample_weight)
            else:
                auc_metric.update_state(y_true_col, y_pred_col)

    def result(self):
        """Average AUC over all target columns."""
        auc_results = [auc_metric.result() for auc_metric in self.auc_metrics]
        return tf.reduce_mean(auc_results)

    def reset_state(self):
        """Reset every per-column AUC. Uses `reset_state`, which Keras 3 metrics define."""
        for auc_metric in self.auc_metrics:
            auc_metric.reset_state()

    def reset_states(self):
        """Backward-compatible alias for `reset_state` (older tf.keras name)."""
        self.reset_state()

    def get_config(self):
        """Return constructor arguments; the curve is stored as a string."""
        config = super().get_config()
        config.update({
            'curve': self.curve,
        })
        return config

@utils.register_keras_serializable()
class CustomNoBackgroundAccuracy(metrics.Metric):
    """
    Element-wise binary accuracy over target columns 1-4; column 0 (background) is ignored.
    """

    def __init__(self, threshold=0.5, name='no_background_accuracy', **kwargs):
        """
        Args:
            threshold (float): Predictions >= threshold are positive (default 0.5).
            name (str): Name of the metric.
            **kwargs: Forwarded to `keras.metrics.Metric`.
        """
        super().__init__(name=name, **kwargs)
        self.threshold = threshold
        # Running (weighted) number of correct cells and (weighted) number of cells seen.
        self.total_correct = self.add_weight(name='total_correct', initializer='zeros', dtype=tf.float32)
        self.total_count = self.add_weight(name='total_count', initializer='zeros', dtype=tf.float32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Accumulate correct-cell and cell counts for one batch.

        Args:
            y_true: Labels; leading dimensions are flattened and columns 1-4 are used.
                A label is positive only if it equals 1 exactly.
            y_pred: Predictions with the same layout as y_true.
            sample_weight: Optional weights, reshaped to (-1, 1), so one weight per flattened row.
                Each row's weight applies to all four target cells.
        """
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        y_true_subset = y_true[:, 1:5]
        y_pred_subset = y_pred[:, 1:5]

        one_value = tf.cast(1.0, dtype=y_true_subset.dtype)
        y_true_bin = tf.cast(tf.equal(y_true_subset, one_value), tf.int32)
        y_pred_bin = tf.cast(y_pred_subset >= self.threshold, tf.int32)
        correct = tf.cast(tf.equal(y_true_bin, y_pred_bin), tf.float32)

        if sample_weight is not None:
            # Reshape to a column and broadcast across the target columns. Tiling the raw
            # weights failed for rank-1 (batch,) weights.
            weights = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            cell_weights = tf.ones_like(correct) * weights
            correct = correct * cell_weights
            count = tf.reduce_sum(cell_weights)
        else:
            count = tf.cast(tf.size(correct), tf.float32)

        self.total_correct.assign_add(tf.reduce_sum(correct))
        self.total_count.assign_add(count)

    def result(self):
        """Return correct / count, or 0 when nothing has been counted."""
        return tf.math.divide_no_nan(self.total_correct, self.total_count)

    def reset_state(self):
        """Zero both accumulators (`reset_state` is the hook name used by current Keras)."""
        self.total_correct.assign(0.0)
        self.total_count.assign(0.0)

    def reset_states(self):
        """Backward-compatible alias for `reset_state` (older tf.keras name)."""
        self.reset_state()

    def get_config(self):
        """Return constructor arguments for serialization."""
        config = super().get_config()
        config.update({'threshold': self.threshold})
        return config

@utils.register_keras_serializable()
class CustomNoBackgroundPrecision(metrics.Metric):
    """
    Precision over target columns 1-4 only; column 0 (background) is ignored.
    """

    def __init__(self, threshold=0.5, average='weighted', name='no_background_precision', **kwargs):
        """
        Args:
            threshold (float): Predictions >= threshold are positive (default 0.5).
            average (str): 'weighted' (default) weights each column by its support (TP + FN);
                'macro' takes an unweighted mean.
            name (str): Name of the metric.
            **kwargs: Forwarded to `keras.metrics.Metric`.

        Raises:
            ValueError: If `average` is not 'weighted' or 'macro'.
        """
        super().__init__(name=name, **kwargs)
        self.threshold = threshold
        if average not in ['weighted', 'macro']:
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average
        self.num_target_columns = 4
        self.true_positives = self.add_weight(
            name='tp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        # False negatives are only needed for the support used in weighted averaging.
        self.false_negatives = self.add_weight(
            name='fn', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Accumulate (optionally weighted) per-column TP / FP / FN counts.

        Args:
            y_true: Labels; leading dimensions are flattened and columns 1-4 are used.
                A label is positive only if it equals 1 exactly.
            y_pred: Predictions with the same layout as y_true.
            sample_weight: Optional weights, reshaped to (-1, 1), so one weight per flattened row.
                Fractional weights are supported.
        """
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        y_true_subset = y_true[:, 1:5]
        y_pred_subset = y_pred[:, 1:5]

        one_value = tf.cast(1.0, dtype=y_true_subset.dtype)
        y_true_bin = tf.cast(tf.equal(y_true_subset, one_value), tf.float32)
        y_pred_bin = tf.cast(y_pred_subset >= self.threshold, tf.float32)

        # Apply the weight once per confusion cell and keep it float. An int cast would zero
        # fractional weights, and weighting y_true and y_pred separately would square it in
        # TP and make FP negative for weights > 1.
        if sample_weight is not None:
            weights = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
        else:
            weights = 1.0

        tp = tf.reduce_sum(y_true_bin * y_pred_bin * weights, axis=0)
        fp = tf.reduce_sum((1.0 - y_true_bin) * y_pred_bin * weights, axis=0)
        fn = tf.reduce_sum(y_true_bin * (1.0 - y_pred_bin) * weights, axis=0)
        self.true_positives.assign_add(tp)
        self.false_positives.assign_add(fp)
        self.false_negatives.assign_add(fn)

    def result(self):
        """Precision TP / (TP + FP) per column, averaged by `self.average`."""
        precision = tf.math.divide_no_nan(self.true_positives, self.true_positives + self.false_positives)
        if self.average == 'weighted':
            # Weight each column by its support (TP + FN).
            support = self.true_positives + self.false_negatives
            return tf.reduce_sum(precision * support) / (tf.reduce_sum(support) + K.epsilon())
        # macro
        return tf.reduce_mean(precision)

    def reset_state(self):
        """Zero every accumulator (`reset_state` is the hook name used by current Keras)."""
        self.true_positives.assign(tf.zeros_like(self.true_positives))
        self.false_positives.assign(tf.zeros_like(self.false_positives))
        self.false_negatives.assign(tf.zeros_like(self.false_negatives))

    def reset_states(self):
        """Backward-compatible alias for `reset_state` (older tf.keras name)."""
        self.reset_state()

    def get_config(self):
        """Return constructor arguments for serialization."""
        config = super().get_config()
        config.update({
            'threshold': self.threshold,
            'average': self.average,
        })
        return config


@utils.register_keras_serializable()
class CustomNoBackgroundRecall(metrics.Metric):
    def __init__(self, threshold=0.5, average='weighted', name='no_background_recall', **kwargs):
        """
        Custom recall metric computed only for columns 1-4 (column 0 is ignored).

        A target entry counts as positive only when it is exactly 1.0 (smoothed values
        such as 0.5 are treated as negative); a prediction counts as positive when it
        is >= ``threshold``.

        Args:
            threshold (float): Threshold for y_pred (default 0.5).
            average (str): 'weighted' (default) weights each column's recall by its
                support (TP + FN); 'macro' is the unweighted mean over the 4 columns.
            name (str): Name of the metric.
            **kwargs: Additional keyword arguments passed to ``metrics.Metric``.

        Raises:
            ValueError: If ``average`` is not 'weighted' or 'macro'.
        """
        super(CustomNoBackgroundRecall, self).__init__(name=name, **kwargs)
        self.threshold = threshold
        if average not in ['weighted', 'macro']:
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average
        self.num_target_columns = 4
        # Per-column accumulators for target columns 1-4.
        self.true_positives = self.add_weight(
            name='tp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Accumulate per-column true positives and false negatives for columns 1-4.

        Args:
            y_true: Labels; the last axis holds the class columns and all leading
                axes are flattened into rows.
            y_pred: Predictions laid out like ``y_true``.
            sample_weight: Optional weights, one per flattened row (e.g. shape
                ``(batch,)`` or ``(batch, 1)`` for 2-D inputs). Fractional weights
                are kept as-is.
        """
        # Flatten leading axes so each row is one label vector; compare in float32 so
        # integer labels also work with the exact-1.0 test below.
        y_true = tf.cast(tf.reshape(y_true, [-1, tf.shape(y_true)[-1]]), tf.float32)
        y_pred = tf.cast(tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]]), tf.float32)
        # Extract target columns (1-4).
        y_true_subset = y_true[:, 1:5]
        y_pred_subset = y_pred[:, 1:5]
        # Binarize: only an exact 1 is a positive label; predictions use the threshold.
        y_true_bin = tf.cast(tf.equal(y_true_subset, 1.0), tf.float32)
        y_pred_bin = tf.cast(y_pred_subset >= self.threshold, tf.float32)
        # Build 0/1 indicators BEFORE weighting: weighting the binarized inputs
        # themselves would square the weight for TP and make FN negative for w > 1.
        tp_mask = y_true_bin * y_pred_bin
        fn_mask = y_true_bin * (1.0 - y_pred_bin)
        if sample_weight is not None:
            # A (rows, 1) float column broadcasts over the 4 target columns and works
            # for both (batch,) and (batch, 1) weights without truncating fractions.
            weights = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            tp_mask = tp_mask * weights
            fn_mask = fn_mask * weights
        self.true_positives.assign_add(tf.reduce_sum(tp_mask, axis=0))
        self.false_negatives.assign_add(tf.reduce_sum(fn_mask, axis=0))

    def result(self):
        """Return weighted (by support) or macro-averaged recall over columns 1-4."""
        # Recall: TP / (TP + FN)
        recall = tf.math.divide_no_nan(self.true_positives, self.true_positives + self.false_negatives)
        if self.average == 'weighted':
            support = self.true_positives + self.false_negatives
            weighted_recall = tf.reduce_sum(recall * support) / (tf.reduce_sum(support) + K.epsilon())
            return weighted_recall
        else:
            return tf.reduce_mean(recall)

    def reset_state(self):
        """Zero the accumulators (the hook Keras calls between epochs/evaluations)."""
        for accumulator in (self.true_positives, self.false_negatives):
            accumulator.assign(tf.zeros_like(accumulator))

    def reset_states(self):
        """Backward-compatible alias for ``reset_state``."""
        self.reset_state()

    def get_config(self):
        """Return the base config plus ``threshold`` and ``average``."""
        config = super(CustomNoBackgroundRecall, self).get_config()
        config.update({
            'threshold': self.threshold,
            'average': self.average,
        })
        return config


In [None]:
# Each test case defines y_true and y_pred as np.arrays with shape (batch_size, 5).
# Column 0 is the dominant/background class; the no-background metrics score columns 1-4,
# where a label is positive only if it is exactly 1 and a prediction only if it is >= 0.5.

# Test Case 1: Perfect prediction on a single sample.
y_true1 = np.array([[1, 0, 1, 0, 0]], dtype=np.float32)
y_pred1 = np.array([[0.0, 0.2, 0.9, 0.1, 0.0]], dtype=np.float32)
# Explanation: Only the non-background class at index 2 is positive, and y_pred1[0, 2] = 0.9 >= 0.5
# while every other target column stays below the threshold.

# Test Case 2: Two samples with different positives in the target columns.
y_true2 = np.array([
    [1, 1, 0, 0, 0],
    [1, 0, 1, 1, 0]
], dtype=np.float32)
y_pred2 = np.array([
    [0.0, 0.8, 0.2, 0.1, 0.1],
    [0.0, 0.3, 0.7, 0.9, 0.2]
], dtype=np.float32)
# Explanation: Row 1 has one positive and row 2 has two; after thresholding both rows are
# predicted exactly right on columns 1-4, so this is a perfect multi-positive case.

# Test Case 3: All negatives on target columns.
y_true3 = np.array([
    [1, 0, 0, 0, 0],
    [1, 0, 0, 0, 0]
], dtype=np.float32)
y_pred3 = np.array([
    [0.0, 0.1, 0.2, 0.3, 0.4],
    [0.0, 0.1, 0.2, 0.3, 0.4]
], dtype=np.float32)
# Explanation: No target positive exists in either sample, and no prediction reaches 0.5.

# Test Case 4: Over-prediction scenario with many false positives.
y_true4 = np.array([
    [1, 0, 1, 0, 0],
    [1, 0, 0, 1, 0]
], dtype=np.float32)
y_pred4 = np.array([
    [0.0, 0.9, 0.8, 0.9, 0.9],
    [0.0, 0.9, 0.9, 0.8, 0.9]
], dtype=np.float32)
# Explanation: Every target column is predicted positive, giving 2 TP and 6 FP.

# Test Case 5: Under-prediction scenario with many false negatives.
y_true5 = np.array([
    [1, 1, 1, 1, 1],
    [1, 0, 1, 0, 1]
], dtype=np.float32)
y_pred5 = np.array([
    [0.0, 0.2, 0.2, 0.2, 0.2],
    [0.0, 0.2, 0.2, 0.2, 0.2]
], dtype=np.float32)
# Explanation: Even though y_true has several positives, y_pred values are too low to cross the threshold.

# Test Case 6: Mixed values with smoothing (note: only exact 1 counts as positive).
y_true6 = np.array([
    [1, 0, 0.5, 1, 0],
    [1, 0, 0, 1, 0]
], dtype=np.float32)
y_pred6 = np.array([
    [0.0, 0.6, 0.6, 0.95, 0.3],
    [0.0, 0.4, 0.4, 0.9, 0.3]
], dtype=np.float32)
# Explanation: The 0.5 in y_true6 is not counted as positive by the metric (only an exact 1 is),
# so the 0.6 prediction at index 2 is a false positive.

# Test Case 7: Borderline threshold values.
y_true7 = np.array([[1, 1, 0, 1, 0]], dtype=np.float32)
y_pred7 = np.array([[0.0, 0.5, 0.5, 0.499, 0.5]], dtype=np.float32)
# Explanation: Values exactly equal to 0.5 are interpreted as positive; those below remain negative.
# Columns 1-4 give TP, FP, FN, FP -> precision = 1/3, recall = 1/2, element accuracy = 1/4, support = 2.

# Test Case 8: Multiple samples where no sample has a target positive.
y_true8 = np.array([
    [1, 0, 0, 0, 0],
    [1, 0, 0, 0, 0],
    [1, 0, 0, 0, 0]
], dtype=np.float32)
y_pred8 = np.array([
    [0.0, 0.2, 0.3, 0.1, 0.4],
    [0.0, 0.2, 0.3, 0.1, 0.4],
    [0.0, 0.2, 0.3, 0.1, 0.4]
], dtype=np.float32)
# Explanation: All rows are "negative" for target classes.

# Test Case 9: Larger batch in which every row is predicted correctly.
y_true9 = np.array([
    [1, 1, 0, 0, 0],
    [1, 0, 1, 0, 0],
    [1, 0, 0, 1, 0],
    [1, 1, 1, 1, 1]
], dtype=np.float32)
y_pred9 = np.array([
    [0.0, 0.9, 0.2, 0.1, 0.3],
    [0.0, 0.3, 0.8, 0.2, 0.1],
    [0.0, 0.2, 0.1, 0.85, 0.4],
    [0.0, 0.95, 0.95, 0.95, 0.95]
], dtype=np.float32)
# Explanation: Single-positive rows for indices 1-3 plus a row where all targets are positive;
# thresholded predictions match y_true everywhere on columns 1-4.

# Test Case 10: Two samples with complementary positive patterns.
y_true10 = np.array([
    [1, 0, 1, 0, 1],
    [1, 1, 0, 1, 0]
], dtype=np.float32)
y_pred10 = np.array([
    [0.0, 0.1, 0.8, 0.2, 0.9],
    [0.0, 0.85, 0.2, 0.9, 0.1]
], dtype=np.float32)
# Explanation: One row has positives at indices 2 and 4, the other at indices 1 and 3;
# both rows are predicted correctly.

# The pairs (y_true1, y_pred1) through (y_true10, y_pred10) can be fed to each custom metric.
example_cases = [
    (y_true1, y_pred1),
    (y_true2, y_pred2),
    (y_true3, y_pred3),
    (y_true4, y_pred4),
    (y_true5, y_pred5),
    (y_true6, y_pred6),
    (y_true7, y_pred7),
    (y_true8, y_pred8),
    (y_true9, y_pred9),
    (y_true10, y_pred10),
]
for case_number, (case_true, case_pred) in enumerate(example_cases, start=1):
    print(f"Test Case {case_number} - y_true:", case_true, "\ny_pred:", case_pred)


Test Case 1 - y_true: [[1. 0. 1. 0. 0.]] 
y_pred: [[0.  0.2 0.9 0.1 0. ]]
Test Case 2 - y_true: [[1. 1. 0. 0. 0.]
 [1. 0. 1. 1. 0.]] 
y_pred: [[0.  0.8 0.2 0.1 0.1]
 [0.  0.3 0.7 0.9 0.2]]
Test Case 3 - y_true: [[1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]] 
y_pred: [[0.  0.1 0.2 0.3 0.4]
 [0.  0.1 0.2 0.3 0.4]]
Test Case 4 - y_true: [[1. 0. 1. 0. 0.]
 [1. 0. 0. 1. 0.]] 
y_pred: [[0.  0.9 0.8 0.9 0.9]
 [0.  0.9 0.9 0.8 0.9]]
Test Case 5 - y_true: [[1. 1. 1. 1. 1.]
 [1. 0. 1. 0. 1.]] 
y_pred: [[0.  0.2 0.2 0.2 0.2]
 [0.  0.2 0.2 0.2 0.2]]
Test Case 6 - y_true: [[1.  0.  0.5 1.  0. ]
 [1.  0.  0.  1.  0. ]] 
y_pred: [[0.   0.6  0.6  0.95 0.3 ]
 [0.   0.4  0.4  0.9  0.3 ]]
Test Case 7 - y_true: [[1. 1. 0. 1. 0.]] 
y_pred: [[0.    0.5   0.5   0.499 0.5  ]]
Test Case 8 - y_true: [[1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]] 
y_pred: [[0.  0.2 0.3 0.1 0.4]
 [0.  0.2 0.3 0.1 0.4]
 [0.  0.2 0.3 0.1 0.4]]
Test Case 9 - y_true: [[1. 1. 0. 0. 0.]
 [1. 0. 1. 0. 0.]
 [1. 0. 0. 1. 0.]
 [1. 1. 1. 1. 1

In [7]:
# Metrics to exercise on every hand-built test case. filter_mode for the conditional
# F1 can be 'pred', 'true' or 'either'. CustomNoBackgroundAUC(curve='ROC') is disabled.
metrics_lst = [
    CustomNoBackgroundF1Score(num_classes=5, threshold=0.5, average='weighted'),
    CustomConditionalF1Score(threshold=0.5, average='weighted', filter_mode='pred'),
    CustomConditionalF1Score(threshold=0.5, average='weighted', filter_mode='true'),
    CustomFalsePositiveDistance(num_classes=5, threshold=0.5, window=100),
    CustomNoBackgroundAccuracy(threshold=0.5),
    CustomNoBackgroundPrecision(threshold=0.5, average='macro'),
    CustomNoBackgroundRecall(threshold=0.5, average='macro'),
    CustomNoBackgroundF1Score(num_classes=5, threshold=0.5, average='macro'),
    CustomConditionalF1Score(threshold=0.5, average='macro', filter_mode='pred'),
    CustomConditionalF1Score(threshold=0.5, average='macro', filter_mode='true'),
]

test_arrays = [
    (y_true1, y_pred1),
    (y_true2, y_pred2),
    (y_true3, y_pred3),
    (y_true4, y_pred4),
    (y_true5, y_pred5),
    (y_true6, y_pred6),
    (y_true7, y_pred7),
    (y_true8, y_pred8),
    (y_true9, y_pred9),
    (y_true10, y_pred10),
]

# Score each case in isolation: update, report, then clear the accumulators.
for metric in metrics_lst:
    print(metric)
    for case_number, (case_true, case_pred) in enumerate(test_arrays, start=1):
        print(f'Test Array {case_number}')
        metric.update_state(case_true, case_pred)
        print(metric.result().numpy())
        metric.reset_states()
        

<CustomNoBackgroundF1Score name=no_background_f1>
Test Array 1
0.9999999
Test Array 2
1.0
Test Array 3
0.0
Test Array 4
0.6666667
Test Array 5
0.0
Test Array 6
1.0
Test Array 7
0.5
Test Array 8
0.0
Test Array 9
1.0
Test Array 10
1.0
<CustomConditionalF1Score name=conditional_f1_pred>
Test Array 1
0.9999999
Test Array 2
1.0
Test Array 3
0.0
Test Array 4
0.6666667
Test Array 5
0.0
Test Array 6
1.0
Test Array 7
0.5
Test Array 8
0.0
Test Array 9
1.0
Test Array 10
1.0
<CustomConditionalF1Score name=conditional_f1_true>
Test Array 1
0.9999999
Test Array 2
1.0
Test Array 3
0.0
Test Array 4
0.6666667
Test Array 5
0.0
Test Array 6
1.0
Test Array 7
0.5
Test Array 8
0.0
Test Array 9
1.0
Test Array 10
1.0
<CustomFalsePositiveDistance name=false_positive_distance>
Test Array 1
0.0
Test Array 2
0.0
Test Array 3
0.0
Test Array 4
125.0
Test Array 5
0.0
Test Array 6
63.0
Test Array 7
125.0
Test Array 8
0.0
Test Array 9
0.0
Test Array 10
0.0
<CustomNoBackgroundAccuracy name=no_background_accuracy>
Test 

In [2]:
class CustomNonZeroF1Score(tf.keras.metrics.Metric):
    def __init__(self, num_classes, average='weighted', name='non_zero_f1', **kwargs):
        """
        Custom F1 score metric that only considers non-zero classes.

        Predictions are turned into class labels with argmax and compared with integer
        labels through a confusion matrix; class 0 is excluded from the reported score.

        Args:
            num_classes (int): Total number of classes. Class 0 is assumed to be the "background" class.
            average (str): 'weighted' (default) to weight by support or 'macro' for a simple average.
            name (str): Name of the metric.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If ``average`` is not 'weighted' or 'macro'.
        """
        super(CustomNonZeroF1Score, self).__init__(name=name, **kwargs)
        self.num_classes = num_classes
        if average not in ['weighted', 'macro']:
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average

        # Accumulate counts per class
        self.true_positives = self.add_weight(
            name='tp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Updates the confusion matrix statistics.

        Args:
            y_true: Tensor of shape (batch_size, seq_length) with integer class labels.
            y_pred: Tensor of shape (batch_size, seq_length, num_classes) with probability distributions.
            sample_weight: Optional weights broadcastable to ``y_true`` after right-padding
                with singleton axes, e.g. (batch_size,) or (batch_size, seq_length). Each
                position contributes its weight (instead of 1) to the confusion matrix.
        """
        y_true = tf.convert_to_tensor(y_true)
        # Convert predictions to class labels using argmax along the last axis.
        y_pred = tf.argmax(y_pred, axis=-1)

        weights = None
        if sample_weight is not None:
            weights = tf.cast(sample_weight, tf.float32)
            # Right-pad with singleton axes so (batch_size,) weights broadcast over the sequence.
            for _ in range(len(y_true.shape) - len(weights.shape)):
                weights = tf.expand_dims(weights, axis=-1)
            weights = tf.reshape(tf.broadcast_to(weights, tf.shape(y_true)), [-1])

        # Flatten the batch and sequence dimensions.
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

        # Compute (optionally weighted) confusion matrix over all predictions.
        cm = tf.math.confusion_matrix(
            y_true, y_pred, num_classes=self.num_classes, weights=weights, dtype=tf.float32
        )
        tp = tf.linalg.diag_part(cm)
        fp = tf.reduce_sum(cm, axis=0) - tp  # column sum: predicted as class, minus hits
        fn = tf.reduce_sum(cm, axis=1) - tp  # row sum: actually class, minus hits

        # Update state variables.
        self.true_positives.assign_add(tp)
        self.false_positives.assign_add(fp)
        self.false_negatives.assign_add(fn)

    def result(self):
        """
        Computes the F1 score for non-zero classes.

        Returns:
            F1 score computed over the non-zero classes.
        """
        precision = tf.math.divide_no_nan(
            self.true_positives, self.true_positives + self.false_positives
        )
        recall = tf.math.divide_no_nan(
            self.true_positives, self.true_positives + self.false_negatives
        )
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)

        # Exclude class 0 (the background) from the evaluation.
        f1_non_zero = f1[1:]
        support_non_zero = (self.true_positives + self.false_negatives)[1:]

        if self.average == 'weighted':
            # Weight F1 by the support of each class.
            weighted_f1 = tf.reduce_sum(f1_non_zero * support_non_zero) / (tf.reduce_sum(support_non_zero) + K.epsilon())
            return weighted_f1
        else:  # macro
            return tf.reduce_mean(f1_non_zero)

    def reset_state(self):
        """
        Resets the metric state variables (the hook Keras calls between epochs/evaluations).
        """
        for v in self.variables:
            v.assign(tf.zeros_like(v))

    def reset_states(self):
        """Backward-compatible alias for ``reset_state``."""
        self.reset_state()

    def get_config(self):
        """Return the base config plus the constructor arguments needed by ``from_config``."""
        config = super(CustomNonZeroF1Score, self).get_config()
        config.update({'num_classes': self.num_classes, 'average': self.average})
        return config

In [None]:
def update_state(self, y_true, y_pred, sample_weight=None):
    """
    Debugging variant of the argmax/confusion-matrix ``update_state`` that prints its statistics.

    Computes the confusion matrix and per-class TP/FP/FN exactly like the argmax-based
    ``CustomNonZeroF1Score.update_state``, prints them, then accumulates them. It can be
    swapped in temporarily, e.g. ``CustomNonZeroF1Score.update_state = update_state``.

    Assumes ``self`` exposes ``num_classes`` and the ``true_positives`` /
    ``false_positives`` / ``false_negatives`` accumulators. ``sample_weight`` is ignored.
    """
    # Argmax predictions to labels, then flatten batch and sequence axes.
    y_pred = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])
    y_true = tf.reshape(y_true, [-1])

    cm = tf.math.confusion_matrix(
        y_true, y_pred, num_classes=self.num_classes, dtype=tf.float32
    )
    tp = tf.linalg.diag_part(cm)
    fp = tf.reduce_sum(cm, axis=0) - tp
    fn = tf.reduce_sum(cm, axis=1) - tp

    tf.print("Confusion matrix:", cm)
    tf.print("True Positives:", tp)
    tf.print("False Positives:", fp)
    tf.print("False Negatives:", fn)

    self.true_positives.assign_add(tp)
    self.false_positives.assign_add(fp)
    self.false_negatives.assign_add(fn)


In [3]:
# Assume CustomNonZeroF1Score is already defined as provided earlier.
# NOTE: these cases target the argmax / confusion-matrix version of the class (integer
# labels of shape (batch, seq)); later cells redefine CustomNonZeroF1Score with a
# thresholded multi-encoded interface, so re-run that earlier class cell first if needed.

def test_custom_nonzero_f1():
    """
    Check CustomNonZeroF1Score against hand-computed F1 values.

    Raises:
        AssertionError: If the metric deviates from the expected scores.
    """
    num_classes = 3  # e.g., class 0 is background, classes 1 and 2 are "interesting"

    # Instantiate the metric (using weighted average for non-zero classes)
    metric = CustomNonZeroF1Score(num_classes=num_classes, average='weighted')

    ### Test Case 1: Perfect Predictions
    # A single batch with a sequence of 4 values: y_true = [0, 1, 2, 1].
    y_true = tf.constant([[0, 1, 2, 1]], dtype=tf.int32)
    # One-hot predictions matching y_true exactly.
    y_pred = tf.one_hot(y_true, depth=num_classes)  # shape: (1, 4, 3)

    metric.update_state(y_true, y_pred)
    result = metric.result().numpy()
    print("Test 1 (Perfect Predictions) - Non-zero F1:", result)
    # Perfect predictions: F1 for classes 1 and 2 is 1.
    assert np.isclose(result, 1.0, atol=1e-6), f"Test 1 expected 1.0, got {result}"

    metric.reset_states()

    ### Test Case 2: Imperfect Predictions
    #   y_true: [0, 1, 1, 2, 0]
    #   y_pred: [0, 1, 2, 2, 0]
    # Class 1: one correct and one missed (predicted as 2).
    # Class 2: one correct and one false positive (predicted instead of a 1).
    y_true = tf.constant([[0, 1, 1, 2, 0]], dtype=tf.int32)

    # One-hot vector for the predicted class at each position.
    y_pred = tf.constant([[
        [1, 0, 0],  # Correctly predicts class 0.
        [0, 1, 0],  # Correctly predicts class 1.
        [0, 0, 1],  # Incorrectly predicts class 2 (should be class 1).
        [0, 0, 1],  # Correctly predicts class 2.
        [1, 0, 0]   # Correctly predicts class 0.
    ]], dtype=tf.float32)

    metric.update_state(y_true, y_pred)
    result = metric.result().numpy()
    print("Test 2 (Imperfect Predictions) - Non-zero F1:", result)
    # Class 1: TP=1, FN=1 -> precision=1, recall=0.5, F1 = 2/3.
    # Class 2: TP=1, FP=1 -> precision=0.5, recall=1, F1 = 2/3.
    # Weighted F1 = (2/3 * 2 + 2/3 * 1) / 3 = 2/3.
    assert np.isclose(result, 2.0 / 3.0, atol=1e-6), f"Test 2 expected 2/3, got {result}"

    metric.reset_states()

# Run the test function
test_custom_nonzero_f1()

I0000 00:00:1738565800.403144     713 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1738565800.597694     713 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1738565800.597805     713 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1738565800.602279     713 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1738565800.602337     713 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:04:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:0

Test 1 (Perfect Predictions) - Non-zero F1: 1.0
Test 2 (Imperfect Predictions) - Non-zero F1: 0.6666667


In [None]:
import tensorflow as tf

def custom_binary_crossentropy_loss(
    dominant_class_index=0,
    dominant_correct_multiplier=0.1,      # dominant-column weight where y_true marks the dominant class
    dominant_incorrect_multiplier=2.0,    # dominant-column weight everywhere else
    other_class_multiplier=1.0,           # multiplier for non-dominant classes when y_true == 1
    smoothing_multiplier=1.0              # multiplier for non-dominant classes when y_true is a smoothed value (0 < y_true < 1)
):
    """
    Returns a custom binary crossentropy loss function that treats the dominant class specially,
    and applies different multipliers for non-dominant classes based on their true label values.

    For the dominant class column (specified by dominant_class_index):
      - If y_true for that column == 1, its loss is scaled by dominant_correct_multiplier.
      - Otherwise, it is scaled by dominant_incorrect_multiplier.
      The choice depends only on the label, not on what the model predicted.

    For non-dominant classes:
      - If y_true == 1, the loss is scaled by other_class_multiplier.
      - If 0 < y_true < 1 (e.g. label-smoothed values), the loss is scaled by smoothing_multiplier.
      - Otherwise (y_true == 0), no additional scaling is applied.

    Labels and predictions may have any leading shape, e.g. (batch, classes) or
    (batch, seq, classes); the class axis is always the last one.

    Parameters:
      dominant_class_index (int): Index of the dominant class in the output vector.
      dominant_correct_multiplier (float): Weight on the dominant column where its label is 1.
      dominant_incorrect_multiplier (float): Weight on the dominant column where its label is not 1.
      other_class_multiplier (float): Multiplier for non-dominant classes when the true label is 1.
      smoothing_multiplier (float): Multiplier for non-dominant classes when the true label is a smoothed value (0 < y_true < 1).

    Returns:
      A callable loss function usable with model.compile(loss=...); it returns the mean
      weighted element-wise BCE as a scalar.
    """
    def loss(y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        dtype = y_pred.dtype
        # Work in the prediction dtype so integer / float64 labels compare and multiply cleanly.
        y_true = tf.cast(y_true, dtype)

        # Prevent issues with log(0)
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

        # Compute standard element-wise binary crossentropy.
        base_loss = -(y_true * tf.math.log(y_pred) +
                      (1.0 - y_true) * tf.math.log(1.0 - y_pred))

        # The class axis is the last axis, whatever the rank.
        num_classes = tf.shape(y_true)[-1]

        # One-hot mask (shape (classes,)) selecting the dominant column; broadcasts over leading axes.
        dominant_mask = tf.one_hot(dominant_class_index, depth=num_classes, dtype=dtype)
        non_dominant_mask = 1.0 - dominant_mask

        # --- Dominant Class Weighting ---
        dominant_true = y_true[..., dominant_class_index]  # Shape: leading axes only
        dominant_weight = tf.where(
            tf.equal(dominant_true, 1.0),
            tf.constant(dominant_correct_multiplier, dtype=dtype),
            tf.constant(dominant_incorrect_multiplier, dtype=dtype),
        )
        dominant_weight = tf.expand_dims(dominant_weight, axis=-1)  # broadcast over the class axis

        # --- Non-Dominant Class Weighting ---
        #   - other_class_multiplier if y_true == 1
        #   - smoothing_multiplier if 0 < y_true < 1
        #   - otherwise (y_true == 0) leave as 1.
        non_dominant_weight = tf.where(
            tf.equal(y_true, 1.0),
            tf.constant(other_class_multiplier, dtype=dtype),
            tf.where(tf.greater(y_true, 0.0),
                     tf.constant(smoothing_multiplier, dtype=dtype),
                     tf.constant(1.0, dtype=dtype))
        )

        # Combine the weights for each class.
        weights = dominant_mask * dominant_weight + non_dominant_mask * non_dominant_weight

        # Compute and return the weighted loss.
        weighted_loss = base_loss * weights
        return tf.reduce_mean(weighted_loss)

    return loss


In [None]:
# Weighted BCE: the dominant column's loss is scaled by 0.1 where its label is 1 and by
# 2.0 elsewhere; non-dominant columns keep weight 1.0 for hard positives and are scaled
# by 0.8 where the label is a smoothed value (0 < y_true < 1).
loss_fn = custom_binary_crossentropy_loss(
    dominant_class_index=0,
    dominant_correct_multiplier=0.1,
    dominant_incorrect_multiplier=2.0,
    other_class_multiplier=1.0,
    smoothing_multiplier=0.8,
)

model.compile(loss=loss_fn, optimizer='adam')


In [None]:
import tensorflow as tf
from keras import backend as K

class CustomNonZeroF1Score(tf.keras.metrics.Metric):
    def __init__(self, num_classes, average='weighted', threshold=0.5, name='non_zero_f1', **kwargs):
        """
        Custom F1 score metric that only considers non-dominant classes (ignoring index 0).

        This version is designed for multi-encoded labels where:
          - The dominant class (index 0) is represented as a hard label [1, 0, 0, ...]
          - For non-dominant classes (indices 1 to num_classes-1), only an exact label of 1 is considered positive.
            (Any partial credit/smoothed values below 1 are treated as 0.)
          - Predictions are thresholded (default threshold = 0.5) to decide 1 vs. 0.

        Args:
            num_classes (int): Total number of classes.
            average (str): 'weighted' (default) to weight by support or 'macro' for a simple average.
            threshold (float): Threshold on y_pred to decide a positive (default 0.5).
            name (str): Name of the metric.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If ``average`` is not 'weighted' or 'macro'.
        """
        super(CustomNonZeroF1Score, self).__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.threshold = threshold
        if average not in ['weighted', 'macro']:
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average

        # Per-class accumulators of length num_classes; index 0 (dominant) is never updated.
        self.true_positives = self.add_weight(
            name='tp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Updates the metric state.

        Args:
            y_true: Multi-encoded labels whose last axis has num_classes entries, e.g.
                (batch_size, num_classes) or (batch_size, seq_length, num_classes); all
                leading axes are flattened. For non-dominant classes, a label is positive
                only if it is exactly 1.
            y_pred: Predictions (e.g. probabilities) laid out like ``y_true``.
            sample_weight: Optional weights, one per flattened row (e.g. shape
                (batch_size,) for 2-D inputs). Fractional weights are honoured.
        """
        y_true = tf.convert_to_tensor(y_true)
        y_pred = tf.convert_to_tensor(y_pred)
        # Flatten leading axes: slicing [:, 1:] on a 3-D tensor would cut the
        # sequence axis instead of the class axis.
        y_true = tf.cast(tf.reshape(y_true, [-1, tf.shape(y_true)[-1]]), tf.float32)
        y_pred = tf.cast(tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]]), tf.float32)

        # Ignore the dominant class (index 0); exact-1 labels and thresholded predictions.
        y_true_bin = tf.cast(tf.equal(y_true[:, 1:], 1.0), tf.float32)
        y_pred_bin = tf.cast(y_pred[:, 1:] >= self.threshold, tf.float32)

        # 0/1 indicators per element, built before weighting so each position adds its
        # weight exactly once and counts can never go negative.
        tp_mask = y_true_bin * y_pred_bin
        fp_mask = (1.0 - y_true_bin) * y_pred_bin
        fn_mask = y_true_bin * (1.0 - y_pred_bin)

        if sample_weight is not None:
            # Float (rows, 1) column: keeps fractional weights and broadcasts over classes.
            weights = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            tp_mask = tp_mask * weights
            fp_mask = fp_mask * weights
            fn_mask = fn_mask * weights

        # Prepend a zero for the dominant class so updates match the (num_classes,) state.
        zeros = tf.zeros([1], dtype=tf.float32)
        self.true_positives.assign_add(tf.concat([zeros, tf.reduce_sum(tp_mask, axis=0)], axis=0))
        self.false_positives.assign_add(tf.concat([zeros, tf.reduce_sum(fp_mask, axis=0)], axis=0))
        self.false_negatives.assign_add(tf.concat([zeros, tf.reduce_sum(fn_mask, axis=0)], axis=0))

    def result(self):
        """
        Computes the F1 score over the non-dominant classes (indices 1...num_classes-1).
        """
        # Select non-dominant classes only.
        tp = self.true_positives[1:]
        fp = self.false_positives[1:]
        fn = self.false_negatives[1:]

        precision = tf.math.divide_no_nan(tp, tp + fp)
        recall = tf.math.divide_no_nan(tp, tp + fn)
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)

        if self.average == 'weighted':
            support = tp + fn
            weighted_f1 = tf.reduce_sum(f1 * support) / (tf.reduce_sum(support) + K.epsilon())
            return weighted_f1
        else:  # macro
            return tf.reduce_mean(f1)

    def reset_state(self):
        """
        Resets all of the metric state variables (the hook Keras calls between epochs/evaluations).
        """
        for v in self.variables:
            v.assign(tf.zeros_like(v))

    def reset_states(self):
        """Backward-compatible alias for ``reset_state``."""
        self.reset_state()

    def get_config(self):
        """Return the base config plus the constructor arguments needed by ``from_config``."""
        config = super(CustomNonZeroF1Score, self).get_config()
        config.update({
            'num_classes': self.num_classes,
            'average': self.average,
            'threshold': self.threshold,
        })
        return config


In [None]:
# Keras only resolves string losses for built-in names, so pass the loss object itself:
# `loss_fn` is the weighted BCE built by custom_binary_crossentropy_loss in the loss cell.
# Requires `model` and `loss_fn` from earlier cells.
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=[CustomNonZeroF1Score(num_classes=5, average='weighted', threshold=0.5)])


In [None]:
class CustomNonZeroF1Score(tf.keras.metrics.Metric):
    def __init__(self, num_classes, average='weighted', threshold=0.5, name='non_zero_f1', **kwargs):
        """
        Custom F1 score metric that only considers non-dominant classes (ignoring index 0).

        This version is designed for multi-encoded labels where:
          - The dominant class (index 0) is represented as a hard label (e.g. [1, 0, 0, ...]).
          - For non-dominant classes (indices 1 to num_classes-1), only an exact label of 1 is considered positive.
            (Any partial credit/smoothed values below 1 are treated as 0.)
          - Predictions are thresholded (default threshold = 0.5) to decide a positive.

        Args:
            num_classes (int): Total number of classes.
            average (str): 'weighted' (default) to weight by support or 'macro' for a simple average.
            threshold (float): Threshold on y_pred to decide a positive (default 0.5).
            name (str): Name of the metric.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If ``average`` is not 'weighted' or 'macro'.
        """
        super(CustomNonZeroF1Score, self).__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.threshold = threshold
        if average not in ['weighted', 'macro']:
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average

        # State variables: accumulate true positives, false positives, and false negatives per class.
        # Only indices 1...num_classes-1 are ever incremented.
        self.true_positives = self.add_weight(
            name='tp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """
        Updates the metric state.

        Args:
            y_true: Multi-encoded labels whose last axis has num_classes entries, e.g.
                (batch_size, num_classes) or (batch_size, seq_length, num_classes); all
                leading axes are flattened. For non-dominant classes, a label is positive
                only if it is exactly 1.
            y_pred: Predicted scores/probabilities laid out like ``y_true``.
            sample_weight: Optional weights, one per flattened row (e.g. shape
                (batch_size,) for 2-D inputs). Fractional weights are honoured.
        """
        y_true = tf.convert_to_tensor(y_true)
        y_pred = tf.convert_to_tensor(y_pred)
        # Flatten leading axes: slicing [:, 1:] on a 3-D tensor would cut the
        # sequence axis instead of the class axis.
        y_true = tf.cast(tf.reshape(y_true, [-1, tf.shape(y_true)[-1]]), tf.float32)
        y_pred = tf.cast(tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]]), tf.float32)

        # Ignore the dominant class (index 0); exact-1 labels and thresholded predictions.
        y_true_bin = tf.cast(tf.equal(y_true[:, 1:], 1.0), tf.float32)
        y_pred_bin = tf.cast(y_pred[:, 1:] >= self.threshold, tf.float32)

        # 0/1 indicators per element, built before weighting so each position adds its
        # weight exactly once and counts can never go negative.
        tp_mask = y_true_bin * y_pred_bin
        fp_mask = (1.0 - y_true_bin) * y_pred_bin
        fn_mask = y_true_bin * (1.0 - y_pred_bin)

        if sample_weight is not None:
            # Float (rows, 1) column: keeps fractional weights and broadcasts over classes.
            weights = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            tp_mask = tp_mask * weights
            fp_mask = fp_mask * weights
            fn_mask = fn_mask * weights

        # Prepend a zero for the dominant class so updates match the (num_classes,) state.
        zeros = tf.zeros([1], dtype=tf.float32)
        self.true_positives.assign_add(tf.concat([zeros, tf.reduce_sum(tp_mask, axis=0)], axis=0))
        self.false_positives.assign_add(tf.concat([zeros, tf.reduce_sum(fp_mask, axis=0)], axis=0))
        self.false_negatives.assign_add(tf.concat([zeros, tf.reduce_sum(fn_mask, axis=0)], axis=0))

    def result(self):
        """
        Computes the F1 score over the non-dominant classes (indices 1 to num_classes-1).
        """
        # Use only non-dominant classes.
        tp = self.true_positives[1:]
        fp = self.false_positives[1:]
        fn = self.false_negatives[1:]

        precision = tf.math.divide_no_nan(tp, tp + fp)
        recall = tf.math.divide_no_nan(tp, tp + fn)
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)

        if self.average == 'weighted':
            support = tp + fn
            weighted_f1 = tf.reduce_sum(f1 * support) / (tf.reduce_sum(support) + K.epsilon())
            return weighted_f1
        else:  # macro average
            return tf.reduce_mean(f1)

    def reset_state(self):
        """
        Resets all of the metric state variables (the hook Keras calls between epochs/evaluations).
        """
        for v in self.variables:
            v.assign(tf.zeros_like(v))

    def reset_states(self):
        """Backward-compatible alias for ``reset_state``."""
        self.reset_state()

    def get_config(self):
        """Return the base config plus the constructor arguments needed by ``from_config``."""
        config = super(CustomNonZeroF1Score, self).get_config()
        config.update({
            'num_classes': self.num_classes,
            'average': self.average,
            'threshold': self.threshold,
        })
        return config

# --- Testing the Custom Metric ---

def test_custom_non_zero_f1():
    """
    Sanity-check CustomNonZeroF1Score against hand-computed F1 values.

    Two synthetic batches are fed through weighted- and macro-averaged
    instances of the metric. The scores are printed and then asserted against
    values derived by hand in the comments below (they also agree with
    sklearn's ``f1_score`` on the same data).

    Raises:
        AssertionError: if any computed score deviates from its expected value.
    """
    num_classes = 5
    # Create an instance of the metric (we'll test both weighted and macro averaging).
    metric_weighted = CustomNonZeroF1Score(num_classes=num_classes, average='weighted', threshold=0.5)
    metric_macro = CustomNonZeroF1Score(num_classes=num_classes, average='macro', threshold=0.5)

    # Batch 1: three multi-encoded samples. Column 0 is the dominant class and is
    # ignored; for the remaining columns only an exact 1 in y_true is positive and
    # y_pred is thresholded at 0.5.
    #
    #   sample | y_true          | y_pred          | non-dom true | non-dom pred
    #   1      | [1, 1, 0, 0, 0] | [1, 1, 0, 1, 0] | [1, 0, 0, 0] | [1, 0, 1, 0]
    #   2      | [1, 0, 1, 1, 1] | [1, 0, 1, 0, 1] | [0, 1, 1, 1] | [0, 1, 0, 1]
    #   3      | [1, 0, 0, 1, 0] | [1, 0, 0, 1, 0] | [0, 0, 1, 0] | [0, 0, 1, 0]
    #
    # Per-class F1 (classes 1..4): 1.0, 1.0, 0.5, 1.0 with supports 1, 1, 2, 1
    #   -> weighted = (1 + 1 + 2 * 0.5 + 1) / 5 = 0.8
    #   -> macro    = (1 + 1 + 0.5 + 1) / 4     = 0.875
    y_true = np.array([
        [1, 1, 0, 0, 0],
        [1, 0, 1, 1, 1],
        [1, 0, 0, 1, 0]
    ], dtype=np.float32)

    y_pred = np.array([
        [1, 1, 0, 1, 0],
        [1, 0, 1, 0, 1],
        [1, 0, 0, 1, 0]
    ], dtype=np.float32)

    # Update the metric state with our batch.
    metric_weighted.update_state(y_true, y_pred)
    metric_macro.update_state(y_true, y_pred)

    # Get the results.
    f1_weighted = metric_weighted.result().numpy()
    f1_macro = metric_macro.result().numpy()

    print("Custom Non-Zero F1 Score (weighted average):", f1_weighted)
    print("Custom Non-Zero F1 Score (macro average):", f1_macro)

    np.testing.assert_allclose(f1_weighted, 0.8, rtol=1e-5)
    np.testing.assert_allclose(f1_macro, 0.875, rtol=1e-5)

    # Reset the states and test with a new batch.
    metric_weighted.reset_states()
    metric_macro.reset_states()

    # Batch 2: two samples.
    #
    #   sample | y_true          | y_pred          | non-dom true | non-dom pred
    #   1      | [1, 0, 1, 0, 1] | [1, 0, 1, 1, 1] | [0, 1, 0, 1] | [0, 1, 1, 1]
    #   2      | [1, 1, 0, 1, 0] | [1, 1, 0, 0, 0] | [1, 0, 1, 0] | [1, 0, 0, 0]
    #
    # Per-class F1: 1.0, 1.0, 0.0, 1.0, each with support 1 -> weighted = macro = 0.75
    y_true2 = np.array([
        [1, 0, 1, 0, 1],
        [1, 1, 0, 1, 0]
    ], dtype=np.float32)
    y_pred2 = np.array([
        [1, 0, 1, 1, 1],
        [1, 1, 0, 0, 0]
    ], dtype=np.float32)

    metric_weighted.update_state(y_true2, y_pred2)
    metric_macro.update_state(y_true2, y_pred2)
    f1_weighted_new = metric_weighted.result().numpy()
    print("Custom Non-Zero F1 Score after reset (weighted average):", f1_weighted_new)

    # The reset must have cleared batch 1; otherwise these would not equal 0.75.
    np.testing.assert_allclose(f1_weighted_new, 0.75, rtol=1e-5)
    np.testing.assert_allclose(metric_macro.result().numpy(), 0.75, rtol=1e-5)

if __name__ == "__main__":
    test_custom_non_zero_f1()


Custom Non-Zero F1 Score (weighted average): 0.8
Custom Non-Zero F1 Score (macro average): 0.875
Custom Non-Zero F1 Score after reset (weighted average): 0.75


In [None]:
from sklearn.metrics import f1_score

# Reference check: recompute the batch-2 scores from test_custom_non_zero_f1 with
# scikit-learn and confirm that CustomNonZeroF1Score agrees.
# Each row is one sample; column 0 is the dominant class and is ignored.
# For non-dominant classes, only an exact 1 is treated as positive.
y_true = np.array([
    [1, 0, 1, 0, 1],
    [1, 1, 0, 1, 0]
], dtype=np.float32)

# Predictions for each class (hard 0/1 here; probabilities would be binarized
# by the same threshold below).
y_pred = np.array([
    [1, 0, 1, 1, 1],
    [1, 1, 0, 0, 0]
], dtype=np.float32)

# Define the threshold for converting probabilities to binary predictions.
threshold = 0.5

# Since the dominant class is at index 0 and should be ignored, work only on columns 1 to end.
y_true_non_dom = y_true[:, 1:]
y_pred_non_dom = y_pred[:, 1:]

# Binarize the ground truth: Only a value exactly equal to 1 counts as positive.
y_true_bin = (y_true_non_dom == 1).astype(int)

# Binarize the predictions using the threshold.
y_pred_bin = (y_pred_non_dom >= threshold).astype(int)

print("Ground truth (non-dominant classes):")
print(y_true_bin)
print("Predictions (non-dominant classes):")
print(y_pred_bin)

# Compute the F1 score using scikit-learn.
# 'macro' simply averages the per-class F1 scores; 'weighted' weights them by
# support (number of true instances). zero_division=0 mirrors the custom metric,
# which uses divide_no_nan and so scores ill-defined classes as 0.
f1_macro = f1_score(y_true_bin, y_pred_bin, average='macro', zero_division=0)
f1_weighted = f1_score(y_true_bin, y_pred_bin, average='weighted', zero_division=0)

print("Sklearn F1 Score (macro average):", f1_macro)
print("Sklearn F1 Score (weighted average):", f1_weighted)

# Cross-check: the custom Keras metric must reproduce sklearn on the same batch.
for average, sklearn_score in (("macro", f1_macro), ("weighted", f1_weighted)):
    reference_metric = CustomNonZeroF1Score(
        num_classes=y_true.shape[1], average=average, threshold=threshold
    )
    reference_metric.update_state(y_true, y_pred)
    np.testing.assert_allclose(
        reference_metric.result().numpy(),
        sklearn_score,
        rtol=1e-5,
        err_msg=f"CustomNonZeroF1Score ({average}) disagrees with sklearn",
    )

Ground truth (non-dominant classes):
[[0 1 0 1]
 [1 0 1 0]]
Predictions (non-dominant classes):
[[0 1 1 1]
 [1 0 0 0]]
Sklearn F1 Score (macro average): 0.75
Sklearn F1 Score (weighted average): 0.75
