In [None]:
import time
import sys
import os
import glob
import math
import threading
import concurrent.futures as cf
import random
import re

import numpy as np
import pandas as pd
import tensorflow as tf
from keras import Input, Model, layers, metrics, losses, callbacks, optimizers, models, utils
from keras import backend as K
import gc
import keras_tuner as kt
from pyfaidx import Fasta

# Start from a clean slate when the cell is re-run: reset Keras' global state
# and force a garbage-collection pass.
K.clear_session()
gc.collect()

# Relative paths to the dataset and model directories.
datasets_path = "../../Datasets/"
models_path = "../../Models/"

The next cell contains the custom metrics, which are required because the labels are sparse.  The background column (index 0 in Python), which is more than 99% 1s, would otherwise have too strong an effect on any scoring.  The final metric does score the background column, which rapidly reaches an F1 score of 0.98 or 0.99 during training.

In [None]:
@utils.register_keras_serializable()
class CustomNoBackgroundF1Score(metrics.Metric):
    """Streaming F1 score over the non-background classes (indices 1..num_classes-1).

    Labels are binarized strictly: a target counts as positive only when it is
    exactly 1, so smoothed / partial-credit values are treated as negatives.
    Predictions are binarized with ``threshold``. The background column
    (index 0) never contributes to the score.
    """

    def __init__(self, num_classes, average='weighted', threshold=0.5, name='no_background_f1', **kwargs):
        """
        Args:
            num_classes (int): Total number of classes, including the background class.
            average (str): 'weighted' (default) to weight each class by its support,
                or 'macro' for a plain mean over classes.
            threshold (float): Predictions >= threshold count as positive (default 0.5).
            name (str): Name of the metric.
            **kwargs: Forwarded to ``metrics.Metric``.

        Raises:
            ValueError: If ``average`` is neither 'weighted' nor 'macro'.
        """
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.threshold = threshold
        if average not in ('weighted', 'macro'):
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average

        # Per-class confusion counts. The vectors keep length num_classes so that
        # index i always refers to class i; slot 0 (background) simply stays at 0.
        self.true_positives = self.add_weight(
            name='tp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate TP/FP/FN counts for the non-background classes.

        Args:
            y_true: Labels whose last dimension is ``num_classes``; any leading
                dimensions are flattened into rows.
            y_pred: Predictions with the same layout as ``y_true``.
            sample_weight: Optional weights, one per flattened row. Each row's
                contribution to TP/FP/FN is scaled by its weight, so fractional
                weights are honoured and a weight of 0 removes the row.
        """
        # Flatten all dimensions except the last one (num_classes).
        y_true = tf.reshape(y_true, [-1, self.num_classes])
        y_pred = tf.reshape(y_pred, [-1, self.num_classes])

        # Drop the background column.
        y_true_fg = y_true[:, 1:]
        y_pred_fg = y_pred[:, 1:]

        # Strict binarization of the labels; thresholding of the predictions.
        one_value = tf.cast(1.0, dtype=y_true_fg.dtype)
        is_true = tf.cast(tf.equal(y_true_fg, one_value), tf.float32)
        is_pred = tf.cast(y_pred_fg >= self.threshold, tf.float32)

        # Classify every element first, then weight it. Weighting the binarized
        # operands themselves would square the weight in TP and corrupt the
        # (1 - x) terms for any weight other than 0 or 1.
        tp = is_true * is_pred
        fp = (1.0 - is_true) * is_pred
        fn = is_true * (1.0 - is_pred)

        if sample_weight is not None:
            row_weight = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            tp = tp * row_weight
            fp = fp * row_weight
            fn = fn * row_weight

        # Prepend a zero for the background slot so shapes match the state vectors.
        background_slot = tf.zeros([1], dtype=tf.float32)
        self.true_positives.assign_add(
            tf.concat([background_slot, tf.reduce_sum(tp, axis=0)], axis=0))
        self.false_positives.assign_add(
            tf.concat([background_slot, tf.reduce_sum(fp, axis=0)], axis=0))
        self.false_negatives.assign_add(
            tf.concat([background_slot, tf.reduce_sum(fn, axis=0)], axis=0))

    def result(self):
        """Return the averaged F1 score over classes 1..num_classes-1."""
        tp = self.true_positives[1:]
        fp = self.false_positives[1:]
        fn = self.false_negatives[1:]

        precision = tf.math.divide_no_nan(tp, tp + fp)
        recall = tf.math.divide_no_nan(tp, tp + fn)
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)

        if self.average == 'weighted':
            # Weight each class by the number of true instances seen so far.
            support = tp + fn
            return tf.reduce_sum(f1 * support) / (tf.reduce_sum(support) + K.epsilon())
        return tf.reduce_mean(f1)  # macro

    def reset_state(self):
        """Zero every accumulated count."""
        for v in self.variables:
            v.assign(tf.zeros_like(v))

    def reset_states(self):
        """Legacy alias of ``reset_state`` kept for older call sites."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'average': self.average,
            'threshold': self.threshold,
        })
        return config

@utils.register_keras_serializable()
class CustomConditionalF1Score(metrics.Metric):
    """Streaming F1 score over target columns 1-4, restricted to selected rows.

    Column 0 (the dominant background class) is ignored completely. Labels are
    positive only when exactly 1; predictions are positive when >= ``threshold``.
    Rows that do not satisfy ``filter_mode`` contribute nothing to the counts.
    """

    def __init__(self, threshold=0.5, average='weighted', filter_mode='either', name='conditional_f1', **kwargs):
        """
        Args:
            threshold (float): Predictions >= threshold count as positive (default 0.5).
            average (str): 'weighted' (default) to weight by support, or 'macro'.
            filter_mode (str): Which rows to include, judged on columns 1-4:
                - 'pred': rows where the thresholded prediction has at least one 1.
                - 'true': rows where the strict label has at least one 1.
                - 'either': rows where either of the above holds.
            name (str): Base name of the metric; the reported name is
                ``f'{name}_{filter_mode}'``.
            **kwargs: Forwarded to ``metrics.Metric``.

        Raises:
            ValueError: If ``average`` or ``filter_mode`` is not a supported value.
        """
        metric_name = f'{name}_{filter_mode}'

        super().__init__(name=metric_name, **kwargs)
        # Keep the un-suffixed name so get_config() can round-trip through
        # __init__ without the filter_mode suffix being appended a second time.
        self._unsuffixed_name = name
        self.threshold = threshold
        if average not in ('weighted', 'macro'):
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average

        if filter_mode not in ('pred', 'true', 'either'):
            raise ValueError("filter_mode must be 'pred', 'true', or 'either'")
        self.filter_mode = filter_mode

        # Only the 4 target columns (1 to 4) are tracked.
        self.num_target_columns = 4
        self.true_positives = self.add_weight(
            name='tp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate TP/FP/FN for columns 1-4 over the rows selected by ``filter_mode``.

        Args:
            y_true: Labels; leading dimensions are flattened into rows.
            y_pred: Predictions with the same layout as ``y_true``.
            sample_weight: Optional weights, one per flattened row.
        """
        # Reshape inputs so that the last dimension is the number of classes.
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])

        # Only consider columns 1-4 (ignoring index 0).
        true_positive_label = tf.equal(y_true[:, 1:5], 1.0)
        predicted_positive = y_pred[:, 1:5] >= self.threshold

        has_pred = tf.reduce_any(predicted_positive, axis=1)
        has_true = tf.reduce_any(true_positive_label, axis=1)
        if self.filter_mode == 'pred':
            keep_row = has_pred
        elif self.filter_mode == 'true':
            keep_row = has_true
        else:  # 'either'
            keep_row = tf.logical_or(has_pred, has_true)

        # Express the row filter as a 0/1 factor instead of physically dropping
        # rows: the sums are identical, and per-row sample weights stay aligned
        # with their rows (dropping rows first left the weights un-filtered).
        row_factor = tf.cast(keep_row, tf.float32)[:, tf.newaxis]
        if sample_weight is not None:
            row_factor = row_factor * tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])

        is_true = tf.cast(true_positive_label, tf.float32)
        is_pred = tf.cast(predicted_positive, tf.float32)

        # Compute per-column true positives, false positives, and false negatives.
        tp = tf.reduce_sum(is_true * is_pred * row_factor, axis=0)
        fp = tf.reduce_sum((1.0 - is_true) * is_pred * row_factor, axis=0)
        fn = tf.reduce_sum(is_true * (1.0 - is_pred) * row_factor, axis=0)

        self.true_positives.assign_add(tp)
        self.false_positives.assign_add(fp)
        self.false_negatives.assign_add(fn)

    def result(self):
        """Return the averaged F1 score over the four target columns."""
        precision = tf.math.divide_no_nan(self.true_positives, self.true_positives + self.false_positives)
        recall = tf.math.divide_no_nan(self.true_positives, self.true_positives + self.false_negatives)
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)

        if self.average == 'weighted':
            support = self.true_positives + self.false_negatives
            return tf.reduce_sum(f1 * support) / (tf.reduce_sum(support) + K.epsilon())
        return tf.reduce_mean(f1)  # 'macro'

    def reset_state(self):
        """Zero every accumulated count."""
        for v in self.variables:
            v.assign(tf.zeros_like(v))

    def reset_states(self):
        """Legacy alias of ``reset_state`` kept for older call sites."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        config.update({
            # Report the base name: __init__ appends the filter_mode suffix itself.
            'name': self._unsuffixed_name,
            'threshold': self.threshold,
            'average': self.average,
            'filter_mode': self.filter_mode,
        })
        return config


@utils.register_keras_serializable()
class CustomFalsePositiveDistance(metrics.Metric):
    """Running average "distance" error of false positive predictions.

    The dominant (background) class at index 0 is ignored. A false positive is
    a prediction >= ``threshold`` at a position whose label is not exactly 1.
    Its distance is derived from the raw label value ``v``:

        distance = 1 + ((max_credit - v) * (window / max_credit))

    With ``max_credit`` = 0.5, ``v == 0.5`` gives a distance of 1 and ``v == 0``
    gives ``1 + window``. Any distance that reaches ``1 + window`` (a label
    with no smoothing credit) is replaced by the fixed ``no_credit_distance``.
    """

    def __init__(self, num_classes, threshold=0.5, window=100, name='false_positive_distance',
                 no_credit_distance=125.0, **kwargs):
        """
        Args:
            num_classes (int): Total number of classes.
            threshold (float): Threshold on y_pred to decide a positive.
            window (int): Window size used in the smoothing scheme.
            name (str): Name of the metric.
            no_credit_distance (float): Distance charged to a false positive whose
                label carries no smoothing credit (default 125.0).
            **kwargs: Forwarded to ``metrics.Metric``.
        """
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.threshold = threshold
        self.window = float(window)
        self.max_credit = 0.5  # Based on smoothing scheme.
        self.no_credit_distance = float(no_credit_distance)

        # State variables to accumulate total distance and count of false positives.
        self.total_distance = self.add_weight(
            name='total_distance', initializer='zeros', dtype=tf.float32
        )
        self.false_positive_count = self.add_weight(
            name='false_positive_count', initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the summed distance and the count of false positives.

        Args:
            y_true: Raw (possibly smoothed) labels whose last dimension is
                ``num_classes``; leading dimensions are flattened into rows.
            y_pred: Predictions with the same layout as ``y_true``.
            sample_weight: Optional weights, one per flattened row; they scale
                both the distance sum and the false positive count.
        """
        # Ensure shape (rows, num_classes).
        y_true = tf.reshape(y_true, [-1, self.num_classes])
        y_pred = tf.reshape(y_pred, [-1, self.num_classes])

        # Ignore the dominant/background class (index 0).
        y_true_non = y_true[:, 1:]
        y_pred_non = y_pred[:, 1:]

        # False positive: predicted positive while the strict label is not 1
        # (smoothed values are negatives). The comparison is done in the
        # label's own dtype, before any cast, so "exactly 1" stays exact.
        predicted_positive = y_pred_non >= self.threshold
        not_strict_positive = tf.not_equal(y_true_non, tf.cast(1.0, y_true_non.dtype))
        false_positive_mask = tf.cast(
            tf.logical_and(predicted_positive, not_strict_positive), tf.float32)

        # Distance arithmetic in float32 so it matches the mask and the state
        # variables whatever dtype the labels arrive in.
        credit = tf.cast(y_true_non, tf.float32)
        distance = 1.0 + (self.max_credit - credit) * (self.window / self.max_credit)
        # 1 + window is the distance of a label with zero credit; derive the
        # cutoff from ``window`` so it stays correct for non-default windows.
        distance = tf.where(
            distance >= 1.0 + self.window,
            tf.constant(self.no_credit_distance, dtype=distance.dtype),
            distance,
        )

        # Only include entries that are false positives.
        false_positive_distance = distance * false_positive_mask

        if sample_weight is not None:
            row_weight = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            sum_distance = tf.reduce_sum(false_positive_distance * row_weight)
            count = tf.reduce_sum(false_positive_mask * row_weight)
        else:
            sum_distance = tf.reduce_sum(false_positive_distance)
            count = tf.reduce_sum(false_positive_mask)

        self.total_distance.assign_add(sum_distance)
        self.false_positive_count.assign_add(count)

    def result(self):
        """Returns the average distance error over all false positives (or 0 if none)."""
        return tf.math.divide_no_nan(self.total_distance, self.false_positive_count)

    def reset_state(self):
        """Resets the accumulated total distance and count."""
        self.total_distance.assign(0.0)
        self.false_positive_count.assign(0.0)

    def reset_states(self):
        """Legacy alias of ``reset_state`` kept for older call sites."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'threshold': self.threshold,
            'window': self.window,
            'no_credit_distance': self.no_credit_distance,
        })
        return config

@utils.register_keras_serializable()
class CustomNoBackgroundAUC(metrics.Metric):
    """Mean AUC over target columns 1-4, ignoring the background column 0.

    One ``metrics.AUC`` instance is kept per target column; the reported value
    is the unweighted mean of the four per-column AUCs. Labels are positive
    only when exactly 1.
    """

    def __init__(self, curve='PR', name='no_background_auc', **kwargs):
        """
        Args:
            curve (str): The type of AUC curve to use: 'PR' (default) or 'ROC'.
            name (str): Name of the metric.
            **kwargs: Forwarded to ``metrics.Metric``.
        """
        super().__init__(name=name, **kwargs)
        # Store the curve parameter as a string to aid serialization.
        self.curve = curve
        # Create one AUC metric per target column (columns 1-4).
        self.auc_metrics = [
            metrics.AUC(curve=self.curve, name=f'auc_col_{i+1}')
            for i in range(4)
        ]

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Feed each target column to its own AUC accumulator.

        Args:
            y_true: Labels; leading dimensions are flattened into rows.
            y_pred: Predictions with the same layout as ``y_true``.
            sample_weight: Optional weights, one per flattened row.
        """
        # Ensure inputs are 2D tensors with shape (rows, num_classes).
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        # Select target columns (1-4) and ignore background (column 0).
        y_true_subset = y_true[:, 1:5]
        y_pred_subset = y_pred[:, 1:5]

        # The weights are the same for every column, so flatten them once.
        if sample_weight is not None:
            sample_weight = tf.reshape(sample_weight, [-1])

        for i, auc_metric in enumerate(self.auc_metrics):
            # Ground truth: positive only if exactly equal to 1.
            y_true_col = tf.cast(tf.equal(y_true_subset[:, i], 1.0), tf.float32)
            auc_metric.update_state(y_true_col, y_pred_subset[:, i], sample_weight=sample_weight)

    def result(self):
        """Return the mean of the four per-column AUC values."""
        auc_results = [auc_metric.result() for auc_metric in self.auc_metrics]
        return tf.reduce_mean(auc_results)

    def reset_state(self):
        """Reset every per-column AUC accumulator."""
        for auc_metric in self.auc_metrics:
            auc_metric.reset_state()

    def reset_states(self):
        """Legacy alias of ``reset_state`` kept for older call sites."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        # Return the curve as a string.
        config.update({
            'curve': self.curve,
        })
        return config

@utils.register_keras_serializable()
class CustomNoBackgroundAccuracy(metrics.Metric):
    """Element-wise accuracy over target columns 1-4, ignoring background column 0.

    Labels are positive only when exactly 1; predictions are positive when
    >= ``threshold``. Every (row, column) cell counts as one decision.
    """

    def __init__(self, threshold=0.5, name='no_background_accuracy', **kwargs):
        """
        Args:
            threshold (float): Threshold for y_pred (default 0.5).
            name (str): Name of the metric.
            **kwargs: Forwarded to ``metrics.Metric``.
        """
        super().__init__(name=name, **kwargs)
        self.threshold = threshold
        self.total_correct = self.add_weight(name='total_correct', initializer='zeros', dtype=tf.float32)
        self.total_count = self.add_weight(name='total_count', initializer='zeros', dtype=tf.float32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the (weighted) number of correct cells and of all cells.

        Args:
            y_true: Labels; leading dimensions are flattened into rows.
            y_pred: Predictions with the same layout as ``y_true``.
            sample_weight: Optional weights, one per flattened row; each weight
                applies to all target columns of its row.
        """
        # Reshape inputs to 2D tensors.
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        # Extract columns 1-4.
        y_true_subset = y_true[:, 1:5]
        y_pred_subset = y_pred[:, 1:5]
        # Binarize ground truth (exactly 1) and predictions (threshold).
        y_true_bin = tf.cast(tf.equal(y_true_subset, 1.0), tf.int32)
        y_pred_bin = tf.cast(y_pred_subset >= self.threshold, tf.int32)
        # Element-wise correctness.
        correct = tf.cast(tf.equal(y_true_bin, y_pred_bin), tf.float32)

        if sample_weight is not None:
            # Reshape to a column and broadcast across the target columns, so
            # both rank-1 and (rows, 1) weights are accepted.
            row_weight = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            cell_weight = row_weight * tf.ones_like(correct)
            correct = correct * cell_weight
            count = tf.reduce_sum(cell_weight)
        else:
            count = tf.cast(tf.size(correct), tf.float32)

        self.total_correct.assign_add(tf.reduce_sum(correct))
        self.total_count.assign_add(count)

    def result(self):
        """Return correct / total, or 0 when nothing has been accumulated."""
        return tf.math.divide_no_nan(self.total_correct, self.total_count)

    def reset_state(self):
        """Zero both accumulators."""
        self.total_correct.assign(0.0)
        self.total_count.assign(0.0)

    def reset_states(self):
        """Legacy alias of ``reset_state`` kept for older call sites."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        config.update({'threshold': self.threshold})
        return config

@utils.register_keras_serializable()
class CustomNoBackgroundPrecision(metrics.Metric):
    """Streaming precision over target columns 1-4, ignoring background column 0.

    Labels are positive only when exactly 1; predictions are positive when
    >= ``threshold``.
    """

    def __init__(self, threshold=0.5, average='weighted', name='no_background_precision', **kwargs):
        """
        Args:
            threshold (float): Threshold for y_pred (default 0.5).
            average (str): 'weighted' (default) or 'macro'.
            name (str): Name of the metric.
            **kwargs: Forwarded to ``metrics.Metric``.

        Raises:
            ValueError: If ``average`` is neither 'weighted' nor 'macro'.
        """
        super().__init__(name=name, **kwargs)
        self.threshold = threshold
        if average not in ('weighted', 'macro'):
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average
        self.num_target_columns = 4
        self.true_positives = self.add_weight(
            name='tp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        # For weighted averaging, we also need the support (true positives + false negatives).
        self.false_negatives = self.add_weight(
            name='fn', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate per-column TP/FP/FN counts.

        Args:
            y_true: Labels; leading dimensions are flattened into rows.
            y_pred: Predictions with the same layout as ``y_true``.
            sample_weight: Optional weights, one per flattened row. Each row's
                contribution is scaled by its weight (fractional weights are
                honoured; a weight of 0 removes the row).
        """
        # Reshape inputs.
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        # Binarize the target columns (1-4).
        is_true = tf.cast(tf.equal(y_true[:, 1:5], 1.0), tf.float32)
        is_pred = tf.cast(y_pred[:, 1:5] >= self.threshold, tf.float32)

        # Classify each cell first, then weight it: weighting the binarized
        # operands would square the weight in TP and break the (1 - x) terms.
        tp = is_true * is_pred
        fp = (1.0 - is_true) * is_pred
        fn = is_true * (1.0 - is_pred)

        if sample_weight is not None:
            # A (rows, 1) column broadcasts over the target columns and accepts
            # both rank-1 and rank-2 weights.
            row_weight = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            tp = tp * row_weight
            fp = fp * row_weight
            fn = fn * row_weight

        self.true_positives.assign_add(tf.reduce_sum(tp, axis=0))
        self.false_positives.assign_add(tf.reduce_sum(fp, axis=0))
        self.false_negatives.assign_add(tf.reduce_sum(fn, axis=0))

    def result(self):
        """Return the averaged precision TP / (TP + FP) over the target columns."""
        precision = tf.math.divide_no_nan(self.true_positives, self.true_positives + self.false_positives)
        if self.average == 'weighted':
            # Weight each column by its support (TP + FN).
            support = self.true_positives + self.false_negatives
            return tf.reduce_sum(precision * support) / (tf.reduce_sum(support) + K.epsilon())
        return tf.reduce_mean(precision)  # macro

    def reset_state(self):
        """Zero every accumulated count."""
        self.true_positives.assign(tf.zeros_like(self.true_positives))
        self.false_positives.assign(tf.zeros_like(self.false_positives))
        self.false_negatives.assign(tf.zeros_like(self.false_negatives))

    def reset_states(self):
        """Legacy alias of ``reset_state`` kept for older call sites."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        config.update({
            'threshold': self.threshold,
            'average': self.average,
        })
        return config


@utils.register_keras_serializable()
class CustomNoBackgroundRecall(metrics.Metric):
    """Streaming recall over target columns 1-4, ignoring background column 0.

    Labels are positive only when exactly 1; predictions are positive when
    >= ``threshold``.
    """

    def __init__(self, threshold=0.5, average='weighted', name='no_background_recall', **kwargs):
        """
        Args:
            threshold (float): Threshold for y_pred (default 0.5).
            average (str): 'weighted' (default) or 'macro'.
            name (str): Name of the metric.
            **kwargs: Forwarded to ``metrics.Metric``.

        Raises:
            ValueError: If ``average`` is neither 'weighted' nor 'macro'.
        """
        super().__init__(name=name, **kwargs)
        self.threshold = threshold
        if average not in ('weighted', 'macro'):
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average
        self.num_target_columns = 4
        self.true_positives = self.add_weight(
            name='tp', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(self.num_target_columns,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate per-column TP/FN counts.

        Args:
            y_true: Labels; leading dimensions are flattened into rows.
            y_pred: Predictions with the same layout as ``y_true``.
            sample_weight: Optional weights, one per flattened row. Each row's
                contribution is scaled by its weight (fractional weights are
                honoured; a weight of 0 removes the row).
        """
        # Reshape inputs.
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        # Binarize the target columns (1-4).
        is_true = tf.cast(tf.equal(y_true[:, 1:5], 1.0), tf.float32)
        is_pred = tf.cast(y_pred[:, 1:5] >= self.threshold, tf.float32)

        # Classify each cell first, then weight it: weighting the binarized
        # operands would square the weight in TP and break the (1 - x) term.
        tp = is_true * is_pred
        fn = is_true * (1.0 - is_pred)

        if sample_weight is not None:
            # A (rows, 1) column broadcasts over the target columns and accepts
            # both rank-1 and rank-2 weights.
            row_weight = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])
            tp = tp * row_weight
            fn = fn * row_weight

        self.true_positives.assign_add(tf.reduce_sum(tp, axis=0))
        self.false_negatives.assign_add(tf.reduce_sum(fn, axis=0))

    def result(self):
        """Return the averaged recall TP / (TP + FN) over the target columns."""
        recall = tf.math.divide_no_nan(self.true_positives, self.true_positives + self.false_negatives)
        if self.average == 'weighted':
            support = self.true_positives + self.false_negatives
            return tf.reduce_sum(recall * support) / (tf.reduce_sum(support) + K.epsilon())
        return tf.reduce_mean(recall)

    def reset_state(self):
        """Zero every accumulated count."""
        self.true_positives.assign(tf.zeros_like(self.true_positives))
        self.false_negatives.assign(tf.zeros_like(self.false_negatives))

    def reset_states(self):
        """Legacy alias of ``reset_state`` kept for older call sites."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        config.update({
            'threshold': self.threshold,
            'average': self.average,
        })
        return config
@utils.register_keras_serializable()
class CustomBackgroundOnlyF1Score(metrics.Metric):
    """Streaming F1 score for the dominant (background) class at index 0 only.

    A background label is positive only when exactly 1; predictions are
    positive when >= ``threshold``. All other columns are ignored.
    """

    def __init__(self, num_classes, average='weighted', threshold=0.5, name='background_only_f1', **kwargs):
        """
        Args:
            num_classes (int): Total number of classes.
            average (str): 'weighted' (default) or 'macro'. Since only one class is
                considered, this choice makes little difference.
            threshold (float): Threshold on y_pred to decide a positive (default 0.5).
            name (str): Name of the metric.
            **kwargs: Forwarded to ``metrics.Metric``.

        Raises:
            ValueError: If ``average`` is neither 'weighted' nor 'macro'.
        """
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.threshold = threshold
        if average not in ('weighted', 'macro'):
            raise ValueError("average must be 'weighted' or 'macro'")
        self.average = average

        # The vectors keep length num_classes, but only index 0 is ever updated.
        self.true_positives = self.add_weight(
            name='tp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_positives = self.add_weight(
            name='fp', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )
        self.false_negatives = self.add_weight(
            name='fn', shape=(num_classes,), initializer='zeros', dtype=tf.float32
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate TP/FP/FN for the background column.

        Args:
            y_true: Labels whose last dimension is ``num_classes``; leading
                dimensions are flattened into rows.
            y_pred: Predictions with the same layout as ``y_true``.
            sample_weight: Optional weights, one per flattened row. Each row's
                contribution is scaled by its weight.
        """
        # Reshape to (-1, num_classes) in case additional dimensions exist.
        y_true = tf.reshape(y_true, [-1, self.num_classes])
        y_pred = tf.reshape(y_pred, [-1, self.num_classes])

        # Extract the dominant class (index 0).
        y_true_dominant = y_true[:, 0]
        y_pred_dominant = y_pred[:, 0]

        # Strict binarization of the labels; thresholding of the predictions.
        one_value = tf.cast(1.0, dtype=y_true_dominant.dtype)
        is_true = tf.cast(tf.equal(y_true_dominant, one_value), tf.float32)
        is_pred = tf.cast(y_pred_dominant >= self.threshold, tf.float32)

        # Classify each row first, then weight it: weighting the binarized
        # vectors would square the weight in TP and break the (1 - x) terms.
        tp_rows = is_true * is_pred
        fp_rows = (1.0 - is_true) * is_pred
        fn_rows = is_true * (1.0 - is_pred)

        if sample_weight is not None:
            row_weight = tf.reshape(tf.cast(sample_weight, tf.float32), [-1])
            tp_rows = tp_rows * row_weight
            fp_rows = fp_rows * row_weight
            fn_rows = fn_rows * row_weight

        tp = tf.reduce_sum(tp_rows)
        fp = tf.reduce_sum(fp_rows)
        fn = tf.reduce_sum(fn_rows)

        # Place each scalar at index 0 and pad the remaining slots with zeros.
        zeros = tf.zeros([self.num_classes - 1], dtype=tf.float32)
        self.true_positives.assign_add(tf.concat([[tp], zeros], axis=0))
        self.false_positives.assign_add(tf.concat([[fp], zeros], axis=0))
        self.false_negatives.assign_add(tf.concat([[fn], zeros], axis=0))

    def result(self):
        """Return the F1 score of the background class (index 0)."""
        tp = self.true_positives[0]
        fp = self.false_positives[0]
        fn = self.false_negatives[0]

        precision = tf.math.divide_no_nan(tp, tp + fp)
        recall = tf.math.divide_no_nan(tp, tp + fn)
        f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)

        # Averaging is not critical with a single class; the interface is mirrored.
        if self.average == 'weighted':
            support = tp + fn
            return tf.math.divide_no_nan(f1 * support, support + K.epsilon())
        return f1  # macro

    def reset_state(self):
        """Zero every accumulated count."""
        for v in self.variables:
            v.assign(tf.zeros_like(v))

    def reset_states(self):
        """Legacy alias of ``reset_state`` kept for older call sites."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'average': self.average,
            'threshold': self.threshold,
        })
        return config

NameError: name 'utils' is not defined

Custom loss function, written to give granular control over how the model learns.  In particular, it addresses the background column (dominant class) and the "smoothing" data, which the basic Keras binary focal loss function would not handle well.

In [None]:
@utils.register_keras_serializable()
class CustomBinaryFocalLoss(losses.Loss):
    """
    Binary focal loss with a per-element weight chosen from the label and the
    thresholded prediction.

    The elementwise focal term is multiplied by a weight that depends on:
      - whether the column is the dominant class column or not,
      - whether the label is hard (exactly 0 or 1) or smoothed (0 < y_true < 1),
      - whether the prediction is "positive" (``y_pred >= threshold``).

    Args:
        dominant_class_index (int): Column index of the dominant class.
        dominant_correct_multiplier (float): Weight for the dominant column
            where its label equals 1.
        dominant_incorrect_multiplier (float): Weight for the dominant column
            where its label is not 1.
        other_class_true_positive_multiplier (float): Non-dominant, label 1,
            prediction positive.
        other_class_false_negative_multiplier (float): Non-dominant, label 1,
            prediction negative.
        other_class_false_positive_multiplier (float): Non-dominant, label 0,
            prediction positive.  Also used for a negative prediction on a
            smoothed label when ``smoothing_as_correct`` is True.
        other_class_true_negative_multiplier (float): Non-dominant, label 0,
            prediction negative.  Also used for a negative prediction on a
            smoothed label when ``smoothing_as_correct`` is False.
        smoothing_multiplier (float): Scales the weight given to a positive
            prediction on a smoothed label.
        smoothing_as_correct (bool): If True, a positive prediction on a
            smoothed label gets weight ``(1 - y_true) * smoothing_multiplier``;
            otherwise it gets ``1 + (1 - y_true) * smoothing_multiplier``.
        threshold (float): Prediction threshold for "positive".
        focal_gamma (float): Focal-loss focusing parameter.
        focal_alpha (float): Focal-loss balance parameter.
        name (str): Name of the loss.
        reduction (str): Forwarded to ``losses.Loss``.  ``call`` already returns
            the mean over all elements, so this is kept mainly so the
            serialized config reflects the value that was passed in.
    """

    def __init__(self,
                 dominant_class_index=0,
                 # Dominant class multipliers
                 dominant_correct_multiplier=0.99,
                 dominant_incorrect_multiplier=2.5,
                 # Non-dominant multipliers for hard labels
                 other_class_true_positive_multiplier=0.05,
                 other_class_false_negative_multiplier=3.0,
                 other_class_false_positive_multiplier=1.0,
                 other_class_true_negative_multiplier=0.99,
                 # For smoothed labels (0 < y_true < 1)
                 smoothing_multiplier=0.5,
                 smoothing_as_correct=True,
                 threshold=0.5,
                 # Focal loss parameters
                 focal_gamma=2.0,
                 focal_alpha=0.25,
                 name="custom_binary_focal_loss",
                 reduction="sum_over_batch_size"):
        # Forward `reduction` instead of dropping it, so get_config() (which
        # reports the parent's reduction) matches what the caller requested.
        super().__init__(name=name, reduction=reduction)
        self.dominant_class_index = dominant_class_index
        self.dominant_correct_multiplier = dominant_correct_multiplier
        self.dominant_incorrect_multiplier = dominant_incorrect_multiplier

        self.other_class_true_positive_multiplier = other_class_true_positive_multiplier
        self.other_class_false_negative_multiplier = other_class_false_negative_multiplier
        self.other_class_false_positive_multiplier = other_class_false_positive_multiplier
        self.other_class_true_negative_multiplier = other_class_true_negative_multiplier

        self.smoothing_multiplier = smoothing_multiplier
        self.smoothing_as_correct = smoothing_as_correct
        self.threshold = threshold

        self.focal_gamma = focal_gamma
        self.focal_alpha = focal_alpha

    def call(self, y_true, y_pred):
        """
        Computes the weighted focal loss.

        Both inputs are flattened to ``(-1, num_classes)`` using their last
        dimension as the class axis.

        Returns:
            Scalar tensor: mean of the weighted elementwise focal loss.
        """
        # Work in a single dtype so every tf.where / multiplication below agrees.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        dtype = y_pred.dtype

        def const(value):
            return tf.constant(value, dtype=dtype)

        # Prevent log(0) issues.
        epsilon = K.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

        # Reshape to (batch_size, num_classes)
        y_true = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
        y_pred = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
        num_classes = tf.shape(y_true)[1]

        is_hard_positive = tf.equal(y_true, const(1.0))
        is_hard_negative = tf.equal(y_true, const(0.0))
        is_hard = tf.logical_or(is_hard_positive, is_hard_negative)
        is_smoothed = tf.logical_and(
            tf.greater(y_true, const(0.0)),
            tf.less(y_true, const(1.0))
        )
        pred_positive = tf.greater_equal(y_pred, const(self.threshold))

        # Elementwise focal term.  p_t = y_pred only where y_true == 1; every
        # other label value (including smoothed ones) uses 1 - y_pred.
        p_t = tf.where(is_hard_positive, y_pred, 1 - y_pred)
        focal_loss = - self.focal_alpha * tf.pow(1 - p_t, self.focal_gamma) * tf.math.log(p_t)

        # === Dominant Class Weighting ===
        # Chosen from the dominant column's label only (not from the prediction).
        dominant_true = y_true[:, self.dominant_class_index]  # shape: (batch_size,)
        dominant_weight = tf.where(
            tf.equal(dominant_true, const(1.0)),
            const(self.dominant_correct_multiplier),
            const(self.dominant_incorrect_multiplier)
        )
        dominant_weight = tf.expand_dims(dominant_weight, axis=1)  # shape: (batch_size, 1)

        # === Non-Dominant Class Weighting ===
        # Hard labels: TP / FN for y_true == 1, FP / TN otherwise.  The
        # "otherwise" branch is only ever selected for y_true == 0 because
        # hard_weight is masked by is_hard below.
        hard_weight = tf.where(
            is_hard_positive,
            tf.where(
                pred_positive,
                const(self.other_class_true_positive_multiplier),
                const(self.other_class_false_negative_multiplier)
            ),
            tf.where(
                pred_positive,
                const(self.other_class_false_positive_multiplier),
                const(self.other_class_true_negative_multiplier)
            )
        )

        # Smoothed labels: values strictly between 0 and 1.
        if self.smoothing_as_correct:
            smoothed_weight = tf.where(
                pred_positive,
                # Shrinks the loss; the weight grows as the label moves away from 1.
                (1.0 - y_true) * self.smoothing_multiplier,
                const(self.other_class_false_positive_multiplier)
            )
        else:
            smoothed_weight = tf.where(
                pred_positive,
                # Penalty that grows as the label moves away from 1.
                1.0 + (1 - y_true) * self.smoothing_multiplier,
                const(self.other_class_true_negative_multiplier)
            )

        # Anything that is neither hard nor smoothed falls back to a neutral weight of 1.
        non_dominant_weight = tf.where(
            is_hard,
            hard_weight,
            tf.where(is_smoothed, smoothed_weight, const(1.0))
        )

        # Column masks of shape (1, num_classes) so they broadcast over the batch.
        dominant_mask = tf.reshape(
            tf.one_hot(self.dominant_class_index, depth=num_classes, dtype=dtype),
            tf.stack([tf.constant(1, dtype=tf.int32), num_classes])
        )
        non_dominant_mask = 1.0 - dominant_mask

        # dominant_weight for the dominant column, non_dominant_weight elsewhere.
        weights = dominant_mask * dominant_weight + non_dominant_mask * non_dominant_weight

        # Compute the final weighted loss.
        weighted_loss = focal_loss * weights
        return tf.reduce_mean(weighted_loss)

    def get_config(self):
        """Returns the parent config extended with every constructor multiplier/parameter."""
        config = super().get_config()
        config.update({
            'dominant_class_index': self.dominant_class_index,
            'dominant_correct_multiplier': self.dominant_correct_multiplier,
            'dominant_incorrect_multiplier': self.dominant_incorrect_multiplier,
            'other_class_true_positive_multiplier': self.other_class_true_positive_multiplier,
            'other_class_false_negative_multiplier': self.other_class_false_negative_multiplier,
            'other_class_false_positive_multiplier': self.other_class_false_positive_multiplier,
            'other_class_true_negative_multiplier': self.other_class_true_negative_multiplier,
            'smoothing_multiplier': self.smoothing_multiplier,
            'smoothing_as_correct': self.smoothing_as_correct,
            'threshold': self.threshold,
            'focal_gamma': self.focal_gamma,
            'focal_alpha': self.focal_alpha
        })
        return config

In [None]:
@utils.register_keras_serializable()
def tile_to_batch(z):
    """
    Repeats a tensor along axis 0 to match another tensor's leading dimension.

    Args:
        z: Pair ``(pe, x)``.  ``pe`` is tiled with multiples
           ``[tf.shape(x)[0], 1, 1]``, so it is expected to be rank 3.

    Returns:
        ``pe`` tiled along its first axis.
    """
    encoding, reference = z
    batch_size = tf.shape(reference)[0]
    return tf.tile(encoding, [batch_size, 1, 1])

@utils.register_keras_serializable()
def create_dcnn_model(
    input_dim=5,
    sequence_length=5000,
    num_classes=5
):
    """
    Builds the dilated-CNN model that emits one sigmoid vector per position.

    Layout:
      - the input is concatenated with a positional encoding of width
        ``num_classes``;
      - a plain Conv1D branch (``cnn``) and a stack of eight dilated Conv1D
        layers (dilation 1, 2, 4, ..., 128) grouped in four residual pairs,
        each pair added to a 1x1-convolution projection of its input;
      - the input, ``cnn``, the dilated stack output and the first dilated
        layer's output are concatenated and passed through two 1x1 Conv1D
        "dense" layers and a final 1x1 sigmoid Conv1D.

    No pooling is used, so the sequence length is preserved end to end.

    Args:
        input_dim (int): Number of input channels per position.
        sequence_length (int): Number of positions per example.
        num_classes (int): Output channels (also the positional-encoding width).

    Returns:
        An uncompiled Keras ``Model``.
    """
    # Initial training hyperparameters
    early_dropout = 0
    middle_dropout = 0.1
    late_dropout = 0.2

    def conv_bn_drop(tensor, rate, **conv_kwargs):
        # Conv1D -> BatchNormalization -> Dropout, created in that order.
        tensor = layers.Conv1D(**conv_kwargs)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.Dropout(rate)(tensor)

    def dilated(tensor, filters, dilation, rate):
        # Width-9 ReLU convolution with the given dilation rate.
        return conv_bn_drop(
            tensor, rate,
            filters=filters, kernel_size=9, dilation_rate=dilation,
            activation='relu', padding='same',
        )

    def projection(tensor, filters):
        # 1x1 convolution that matches the channel count for the residual add.
        return layers.Conv1D(filters=filters, kernel_size=1, padding='same')(tensor)

    inputs = Input(shape=(sequence_length, input_dim))

    # Condensed positional encoding block.
    # NOTE(review): the Embedding is applied to a plain tf.range tensor rather
    # than to a symbolic Keras input; confirm its weights end up tracked and
    # trained by the returned Model.
    positions = tf.range(start=0, limit=sequence_length, delta=1)
    pos_encoding = layers.Embedding(input_dim=sequence_length, output_dim=num_classes)(positions)
    pos_encoding = tf.expand_dims(pos_encoding, axis=0)
    pos_encoding = layers.Lambda(tile_to_batch)([pos_encoding, inputs])

    concat_input = layers.Concatenate(axis=-1)([inputs, pos_encoding])

    # Plain (non-dilated) convolution branch.
    cnn = conv_bn_drop(
        concat_input, early_dropout,
        filters=64, kernel_size=9, activation='relu', padding='same',
    )

    # First residual pair: the first dilated conv reads the projected input,
    # and its output is also kept for the final concatenation.
    skip = projection(concat_input, 64)
    low_dcnn = dilated(skip, 64, 1, early_dropout)
    dcnn = dilated(low_dcnn, 64, 2, early_dropout)
    dcnn = layers.Add()([dcnn, skip])

    # Remaining residual pairs: (filters, first dilation, second dilation).
    for filters, first_rate, second_rate in ((160, 4, 8), (192, 16, 32), (192, 64, 128)):
        skip = projection(dcnn, filters)
        branch = dilated(dcnn, filters, first_rate, middle_dropout)
        branch = dilated(branch, filters, second_rate, middle_dropout)
        dcnn = layers.Add()([branch, skip])

    second_concat = layers.Concatenate(axis=-1)([concat_input, cnn, dcnn, low_dcnn])

    # Instead of flattening, 1x1 convolutions act as per-position dense layers.
    dense = conv_bn_drop(second_concat, late_dropout, filters=128, kernel_size=1, activation='relu')
    dense = conv_bn_drop(dense, late_dropout, filters=128, kernel_size=1, activation='relu')

    # Final classification layer applied at every time step.
    outputs = layers.Conv1D(num_classes, kernel_size=1, activation='sigmoid')(dense)

    return Model(inputs=inputs, outputs=outputs)


# Build and compile one model at cell level so that its summary can be viewed
# (the architecture itself lives inside a function definition).
metrics_lst = [
    CustomNoBackgroundF1Score(num_classes=5, average='weighted', threshold=0.5),
    # filter_mode accepts 'pred', 'true' or 'either'
    CustomConditionalF1Score(filter_mode='pred', average='weighted', threshold=0.5),
    CustomConditionalF1Score(filter_mode='true', average='weighted', threshold=0.5),
    CustomFalsePositiveDistance(num_classes=5, window=100, threshold=0.5),
    CustomNoBackgroundAUC(curve='PR'),
    CustomNoBackgroundAccuracy(threshold=0.5),
    CustomNoBackgroundPrecision(average='weighted', threshold=0.5),
    CustomNoBackgroundRecall(average='weighted', threshold=0.5),
    CustomBackgroundOnlyF1Score(num_classes=5, average='weighted', threshold=0.5),
]

loss_fn = CustomBinaryFocalLoss(
    dominant_class_index=0,
    dominant_correct_multiplier=0.98,
    dominant_incorrect_multiplier=2.5,
    other_class_true_positive_multiplier=0.075,
    other_class_false_negative_multiplier=5.5,
    other_class_false_positive_multiplier=2.0,
    other_class_true_negative_multiplier=0.98,
    smoothing_multiplier=0.06,
    smoothing_as_correct=True,
    threshold=0.5,
    focal_gamma=2.0,
    focal_alpha=0.25,
)

# Learning-rate schedule: warm up to the target rate, then cosine decay.
decay_steps = 10000  # realistically, make this a function of total epochs
initial_learning_rate = 0.001
warmup_steps = 1000
target_learning_rate = 0.01
lr_warmup_decayed_fn = optimizers.schedules.CosineDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=decay_steps,
    warmup_target=target_learning_rate,
    warmup_steps=warmup_steps,
)

dcnn_model = create_dcnn_model(input_dim=5, sequence_length=5000, num_classes=5)
dcnn_model.compile(
    optimizer=optimizers.Adam(learning_rate=lr_warmup_decayed_fn),
    loss=loss_fn,
    metrics=metrics_lst,
)
dcnn_model.summary()

# Original training parameters after tuning
# 'explore_filters_1': 64, 'explore_filters_2': 160, 'explore_filters_3': 192, 'explore_kernel_size': 9, 'explore_dropout': 0.4, 'learning_rate': 0.0005343042689938801, 'dominant_correct_multiplier': 0.95, 'dominant_incorrect_multiplier': 2.0, 
#  'other_class_true_positive_multiplier': 0.125, 'other_class_false_negative_multiplier': 5.0, 'other_class_false_positive_multiplier': 2.0, 'other_class_true_negative_multiplier': 1.0, 'smoothing_multiplier': 0.30000000000000004, 

The next cell defines the helpers that read TFRecord files into `tf.data` datasets.

In [None]:
def drop_exact_records(dataset: "tf.data.Dataset", total_records, num_to_drop, seed=None):
    '''
    Function to drop n records from data before constructing parsed dataset.
    Mostly for bug checking.

    Args:
        dataset: Dataset whose elements are dropped by position.
        total_records (int): Size of the index range ``[0, total_records)``
            the dropped positions are sampled from.
        num_to_drop (int): Number of distinct positions to drop.  ``0`` or
            ``None`` returns ``dataset`` unchanged.
        seed (int | None): If not None, NumPy's global RNG is reseeded with it
            before sampling (``seed=0`` is honoured).

    Returns:
        The dataset with the sampled positions filtered out.

    Raises:
        ValueError: If records must be dropped but ``total_records`` is None.
    '''
    if not num_to_drop:
        # Nothing to drop; also avoids comparing indices against an empty list.
        return dataset
    if total_records is None:
        raise ValueError("total_records must be provided when num_to_drop is set")

    if seed is not None:
        np.random.seed(seed)
    drop_indices = np.random.choice(total_records, num_to_drop, replace=False)

    # Built once, as int64 to match the index dtype produced by enumerate().
    drop_tensor = tf.constant(np.asarray(drop_indices, dtype=np.int64))

    def keep_record(index, _record):
        return tf.logical_not(tf.reduce_any(tf.equal(index, drop_tensor)))

    dataset = dataset.enumerate()
    dataset = dataset.filter(keep_record)
    dataset = dataset.map(lambda _index, record: record)
    return dataset


def parse_chunk_example(serialized_example, num_features=5, num_labels=5):
    """
    Parses a single serialized tf.train.Example back into tensors.
    Used in testing datasets and in piping tfrecords to DL Algorithms

    Args:
        serialized_example: One serialized ``tf.train.Example``.
        num_features (int): Width of each row of ``X`` (default 5).
        num_labels (int): Width of each row of ``y`` (default 5).

    Returns:
        Tuple ``(X, y, record_id, cstart, cend, strand)`` where ``X`` has shape
        ``[chunk_size, num_features]``, ``y`` has shape
        ``[chunk_size, num_labels]``, ``cstart``/``cend`` are int64 scalars and
        ``record_id``/``strand`` are string scalars.
    """
    feature_spec = {
        'X':          tf.io.VarLenFeature(tf.float32),
        'y':          tf.io.VarLenFeature(tf.float32),
        'record_id':  tf.io.FixedLenFeature([], tf.string),
        'cstart':     tf.io.FixedLenFeature([1], tf.int64),
        'cend':       tf.io.FixedLenFeature([1], tf.int64),
        'strand':     tf.io.FixedLenFeature([], tf.string),
        'chunk_size': tf.io.FixedLenFeature([1], tf.int64),
    }

    parsed = tf.io.parse_single_example(serialized_example, feature_spec)

    # chunk_size is stored with shape [1]; take the scalar.
    chunk_size = parsed['chunk_size'][0]

    # X and y are stored flattened as variable-length features: densify them,
    # then restore the [chunk_size, width] layout.
    X_flat = tf.sparse.to_dense(parsed['X'])
    y_flat = tf.sparse.to_dense(parsed['y'])
    X_reshaped = tf.reshape(X_flat, [chunk_size, num_features])
    y_reshaped = tf.reshape(y_flat, [chunk_size, num_labels])

    record_id = parsed['record_id']
    cstart = parsed['cstart'][0]
    cend = parsed['cend'][0]
    strand = parsed['strand']

    return X_reshaped, y_reshaped, record_id, cstart, cend, strand


def prepare_for_model(X, y, record_id, cstart, cend, strand):
    '''
    Helper function that keeps only the (features, labels) pair of a parsed
    record for feeding to DL Models.

    record_id, cstart, cend and strand are accepted so the function can be
    mapped over the full parsed tuple, but they are discarded; they remain
    available in the original TestValTrain records if needed.
    No reshaping is applied to X or y.
    '''
    model_inputs, model_targets = X, y
    return model_inputs, model_targets


def prep_dataset_from_tfrecord(
    tfrecord_path,
    batch_size=28,
    compression_type='GZIP',
    shuffled = False,
    shuffle_buffer=25000,
    total_records=None,
    num_to_drop=None,
    seed=None
):
    '''
    Imports tfrecord, optionally drops and shuffles records, then parses and
    batches them for use in fitting a model.

    Args:
        tfrecord_path: Path(s) passed to ``tf.data.TFRecordDataset``.
        batch_size (int): Number of parsed examples per batch.
        compression_type (str): Compression of the TFRecord file(s).
        shuffled (bool): If truthy, shuffle at the record level before parsing
            (reshuffled on every iteration).  Default is no shuffling.
        shuffle_buffer (int): Shuffle buffer size, used only when ``shuffled``.
        total_records (int | None): Required when ``num_to_drop`` is set.
        num_to_drop (int | None): Number of records to drop at random.
        seed (int | None): Seed forwarded to ``drop_exact_records``.

    Returns:
        A batched, prefetched dataset yielding ``(X, y)`` pairs.

    Raises:
        ValueError: If ``num_to_drop`` is set without ``total_records``.
    '''
    # Fail early with a clear message instead of deep inside the sampling code.
    if num_to_drop and total_records is None:
        raise ValueError("total_records must be provided when num_to_drop is set")

    # Loads in records in a round robin fashion for slightly increased mixing
    dataset = tf.data.TFRecordDataset(
        tfrecord_path,
        compression_type=compression_type,
        num_parallel_reads=tf.data.AUTOTUNE,
    )

    if num_to_drop:
        dataset = drop_exact_records(dataset, total_records=total_records, num_to_drop=num_to_drop, seed=seed)

    if shuffled:
        # Shuffle at the record level (still-serialized records are cheap to buffer).
        dataset = dataset.shuffle(shuffle_buffer, reshuffle_each_iteration=True)

    dataset = dataset.map(parse_chunk_example, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.map(prepare_for_model, num_parallel_calls=tf.data.AUTOTUNE)

    # Rebatch parsed and prefetch for efficient reading
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

This cell defines callbacks for running the model training.

In [None]:
class TimeLimit(callbacks.Callback):
    """
    Callback that requests a stop once training has run longer than
    ``max_time_seconds`` of wall-clock time.

    The clock starts in ``on_train_begin``.  The limit is checked before every
    training batch (with a printed notice) and again at the end of every epoch
    (silently); in both cases ``self.model.stop_training`` is set to True.
    """

    def __init__(self, max_time_seconds):
        super().__init__()
        self.start_time = None
        self.max_time_seconds = max_time_seconds

    def _limit_exceeded(self):
        # Elapsed wall-clock time since on_train_begin vs. the configured budget.
        return (time.time() - self.start_time) > self.max_time_seconds

    def on_train_begin(self, logs=None):
        self.start_time = time.time()

    def on_train_batch_begin(self, batch, logs=None):
        if not self._limit_exceeded():
            return
        print(f"\n⏳ Time limit of {self.max_time_seconds} sec reached. Stopping training!")
        self.model.stop_training = True

    def on_epoch_end(self, epoch, logs=None):
        if self._limit_exceeded():
            self.model.stop_training = True
            
class DebugCallback(callbacks.Callback):
    """
    Callback that prints progress markers and flushes stdout immediately.

    Epoch start/end are always reported; batch start/end are reported only
    for batch numbers divisible by 1000.
    """

    def on_epoch_begin(self, epoch, logs=None):
        print(f"\n🚀 Starting Epoch {epoch+1}", flush=True)

    def on_batch_begin(self, batch, logs=None):
        if batch % 1000:
            return
        print(f"🔄 Processing Batch {batch}", flush=True)

    def on_batch_end(self, batch, logs=None):
        if batch % 1000:
            return
        print(f"✅ Finished Batch {batch}", flush=True)

    def on_epoch_end(self, epoch, logs=None):
        print(f"\n🏁 Epoch {epoch+1} Completed!", flush=True)
        
class CleanupCallback(callbacks.Callback):
    """
    Callback that forces a garbage-collection pass after every epoch and
    prints a confirmation line.

    NOTE(review): a second class with this same name is defined further down
    in this cell (without the print); being later, that definition is the one
    bound to the name once the cell has run.
    """

    def on_epoch_end(self, epoch, logs=None):
        # Force garbage collection; any further cleanup (closing files,
        # flushing logs, freeing external resources) would also go here.
        gc.collect()
        print(f"Cleanup done at the end of epoch {epoch+1}")
        

# Save the full model (architecture + weights) at the end of every epoch.
# The epoch number and validation F1 are embedded in the file name.
checkpoint_cb = callbacks.ModelCheckpoint(
    filepath=models_path + 'checkpoints/epoch-{epoch:03d}-val_no_background_f1-{val_no_background_f1:.4f}.keras',
    save_freq='epoch',
    save_best_only=False,          # keep every epoch, not only the best one
    save_weights_only=False,       # full model rather than weights only
    monitor='val_no_background_f1',  # alternative: 'val_loss'
    mode='max',                    # needed for an F1 monitor; drop it when monitoring val_loss
)

# Stop when validation F1 has not improved by at least min_delta for
# `patience` epochs, then roll back to the best weights seen.
early_stopping_cb = callbacks.EarlyStopping(
    monitor='val_no_background_f1',  # alternative: 'val_loss'
    mode='max',
    min_delta=1e-4,
    patience=20,
    restore_best_weights=True,
)

class CleanupCallback(callbacks.Callback):
    """
    Callback that forces a garbage-collection pass after every epoch.

    NOTE(review): this rebinds the CleanupCallback name defined earlier in
    this cell; unlike that earlier version, nothing is printed here.
    """

    def on_epoch_end(self, epoch, logs=None):
        # Reclaim unreferenced objects between epochs.
        gc.collect()

The next cell sets up the tuner.

In [None]:
@utils.register_keras_serializable()
def create_dcnn_model_tunable(
    input_dim=5,
    sequence_length=5000,
    num_classes=5,
    early_dropout_val=0.0,
    middle_dropout_val=0.1,
    late_dropout_val=0.2
):
    """
    Builds the dilated-CNN model with configurable dropout rates.

    The Keras session is cleared before any layer is created.  The network is
    the same layout as the fixed-dropout builder: input + positional encoding,
    one plain Conv1D branch, eight dilated Conv1D layers (dilation 1..128) in
    four residual pairs, then two 1x1 Conv1D "dense" layers and a 1x1 sigmoid
    Conv1D output at every position.

    Args:
        input_dim (int): Number of input channels per position.
        sequence_length (int): Number of positions per example.
        num_classes (int): Output channels (also the positional-encoding width).
        early_dropout_val (float): Dropout after the plain branch and the
            first residual pair.
        middle_dropout_val (float): Dropout in the remaining residual pairs.
        late_dropout_val (float): Dropout after each 1x1 "dense" layer.

    Returns:
        An uncompiled Keras ``Model``.
    """
    K.clear_session()

    def conv_bn_drop(tensor, rate, **conv_kwargs):
        # Conv1D -> BatchNormalization -> Dropout, created in that order.
        tensor = layers.Conv1D(**conv_kwargs)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.Dropout(rate)(tensor)

    def dilated(tensor, filters, dilation, rate):
        # Width-9 ReLU convolution with the given dilation rate.
        return conv_bn_drop(
            tensor, rate,
            filters=filters, kernel_size=9, dilation_rate=dilation,
            activation='relu', padding='same',
        )

    def projection(tensor, filters):
        # 1x1 convolution that matches the channel count for the residual add.
        return layers.Conv1D(filters=filters, kernel_size=1, padding='same')(tensor)

    inputs = Input(shape=(sequence_length, input_dim))

    # Positional encoding block.
    # NOTE(review): the Embedding is applied to a plain tf.range tensor rather
    # than to a symbolic Keras input; confirm its weights end up tracked and
    # trained by the returned Model.
    positions = tf.range(start=0, limit=sequence_length, delta=1)
    pos_encoding = layers.Embedding(input_dim=sequence_length, output_dim=num_classes)(positions)
    pos_encoding = tf.expand_dims(pos_encoding, axis=0)
    pos_encoding = layers.Lambda(tile_to_batch)([pos_encoding, inputs])

    concat_input = layers.Concatenate(axis=-1)([inputs, pos_encoding])

    # First convolution branch.
    cnn = conv_bn_drop(
        concat_input, early_dropout_val,
        filters=64, kernel_size=9, activation='relu', padding='same',
    )

    # First residual pair: the first dilated conv reads the projected input,
    # and its output is also kept for the final concatenation.
    skip = projection(concat_input, 64)
    low_dcnn = dilated(skip, 64, 1, early_dropout_val)
    dcnn = dilated(low_dcnn, 64, 2, early_dropout_val)
    dcnn = layers.Add()([dcnn, skip])

    # Remaining residual pairs: (filters, first dilation, second dilation).
    for filters, first_rate, second_rate in ((160, 4, 8), (192, 16, 32), (192, 64, 128)):
        skip = projection(dcnn, filters)
        branch = dilated(dcnn, filters, first_rate, middle_dropout_val)
        branch = dilated(branch, filters, second_rate, middle_dropout_val)
        dcnn = layers.Add()([branch, skip])

    second_concat = layers.Concatenate(axis=-1)([concat_input, cnn, dcnn, low_dcnn])

    # "Dense" layers implemented as 1D convolutions.
    dense = conv_bn_drop(second_concat, late_dropout_val, filters=128, kernel_size=1, activation='relu')
    dense = conv_bn_drop(dense, late_dropout_val, filters=128, kernel_size=1, activation='relu')

    outputs = layers.Conv1D(num_classes, kernel_size=1, activation='sigmoid')(dense)

    return Model(inputs=inputs, outputs=outputs)

def build_model(hp):
    """Build and compile the tunable DCNN for a single tuner trial.

    Registers the dropout rates, the warmup + cosine-decay learning-rate
    schedule and the focal-loss weighting terms as hyperparameters on ``hp``,
    then wires them into the model, the loss and an Adam optimizer.

    Args:
        hp: Hyperparameter container handed in by the tuner. Only its
            ``Float`` and ``Int`` methods are used here.

    Returns:
        The compiled model returned by ``create_dcnn_model_tunable``.
    """
    n_classes = 5

    # Dropout rates. The dict literal evaluates top to bottom, so the
    # hyperparameters are registered in the order early -> middle -> late.
    dropout_kwargs = {
        'early_dropout_val': hp.Float('early_dropout', min_value=0.0, max_value=0.3, step=0.1, default=0.1),
        'middle_dropout_val': hp.Float('middle_dropout', min_value=0.0, max_value=0.3, step=0.1, default=0.1),
        'late_dropout_val': hp.Float('late_dropout', min_value=0.1, max_value=0.4, step=0.1, default=0.2),
    }

    # Learning-rate schedule (a fixed learning rate was tried previously):
    # learning_rate = hp.Float('learning_rate', min_value=1e-5, max_value=1e-3, sampling='log', default=0.000534)
    # NOTE(review): the initial range (1e-4..1e-2) overlaps the target range
    # (1e-3..1e-1), so a trial can sample a warmup target below its initial
    # rate -- confirm that this is intended.
    start_lr = hp.Float('initial_learning_rate', min_value=0.0001, max_value=0.01, sampling='log', default=0.001)
    n_warmup = hp.Int('warmup_steps', min_value=800, max_value=1500, step=100, default=1000)
    peak_lr = hp.Float('target_learning_rate', min_value=0.001, max_value=0.1, sampling='log', default=0.01)
    lr_schedule = optimizers.schedules.CosineDecay(
        initial_learning_rate=start_lr,
        decay_steps=15000,
        warmup_target=peak_lr,
        warmup_steps=n_warmup,
    )

    # Focal-loss search space. Every hyperparameter name is also the keyword
    # argument of CustomBinaryFocalLoss that receives the sampled value.
    # NOTE(review): the default of 'smoothing_multiplier' (0.0) lies outside
    # its own [1.0, 3.0] range -- confirm that this is intended.
    focal_space = (
        # (name, min_value, max_value, step, default)
        ('dominant_correct_multiplier', 0.95, 1.0, 0.01, 0.98),
        ('dominant_incorrect_multiplier', 1.0, 3.0, 0.5, 2.5),
        ('other_class_true_positive_multiplier', 0.05, 0.1, 0.01, 0.075),
        ('other_class_false_negative_multiplier', 5.0, 6.0, 0.5, 5.5),
        ('other_class_false_positive_multiplier', 1.0, 5.0, 1.0, 2.0),
        ('other_class_true_negative_multiplier', 0.95, 1.0, 0.01, 0.98),
        ('smoothing_multiplier', 1.0, 3.0, 0.5, 0.0),
        ('focal_gamma', 1.0, 4.0, 0.5, 2.0),
        ('focal_alpha', 0.05, 0.5, 0.05, 0.25),
    )
    focal_kwargs = {
        hp_name: hp.Float(hp_name, min_value=lo, max_value=hi, step=step, default=default)
        for hp_name, lo, hi, step, default in focal_space
    }

    # Network with the tunable dropout rates plugged in.
    model = create_dcnn_model_tunable(
        input_dim=5,
        sequence_length=5000,
        num_classes=n_classes,
        **dropout_kwargs,
    )

    # Custom loss; smoothing_as_correct and threshold stay fixed (not tuned).
    loss_fn = CustomBinaryFocalLoss(
        dominant_class_index=0,
        smoothing_as_correct=False,
        threshold=0.5,
        **focal_kwargs,
    )

    # Settings shared by all the weighted-average metrics below.
    weighted = {'threshold': 0.5, 'average': 'weighted'}
    metrics_lst = [
        CustomNoBackgroundF1Score(num_classes=n_classes, **weighted),
        CustomConditionalF1Score(filter_mode='pred', **weighted),  # filter_mode: 'pred', 'true' or 'either'
        CustomConditionalF1Score(filter_mode='true', **weighted),
        CustomFalsePositiveDistance(num_classes=n_classes, threshold=0.5, window=100),
        CustomNoBackgroundAUC(curve='PR'),
        CustomNoBackgroundAccuracy(threshold=0.5),
        CustomNoBackgroundPrecision(**weighted),
        CustomNoBackgroundRecall(**weighted),
        CustomBackgroundOnlyF1Score(num_classes=n_classes, **weighted),
    ]

    # Adam driven by the tuned warmup/cosine schedule.
    adam = optimizers.Adam(learning_rate=lr_schedule)

    # To load in model weights for re-tuning
    # model.load_weights(models_path + 'checkpoints/epoch-032-val_no_background_f1-0.6029.keras')

    model.compile(optimizer=adam, loss=loss_fn, metrics=metrics_lst)
    return model


# Set up the Hyperband tuner around build_model.
# The objective is minimised; "val_false_positive_distance" is presumably the
# validation value of the CustomFalsePositiveDistance metric compiled in
# build_model -- confirm that the metric's name matches.
tuner = kt.Hyperband(
    build_model,
    objective=kt.Objective("val_false_positive_distance", direction="min"),
    max_epochs=10,
    factor=3,
    # Tuner state is written beneath the models directory.
    directory=models_path + 'Tuning Data/DCNN_Tuner_11',
    project_name='dcnn_tuning',
    # NOTE(review): with overwrite=False an existing project in this directory
    # is presumably reloaded rather than restarted -- confirm against the
    # KerasTuner documentation before reusing the directory name.
    overwrite=False
)

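The next cell runs the hyperparameter search: it builds the training and validation datasets, calls `tuner.search`, and then prints the results summary and the best hyperparameters found.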

In [None]:
# --- Run configuration for the hyperparameter search ---
max_time_seconds = 3600/2  # 1 hour is 3600 seconds
batch_size = 28
# Epoch count handed to tuner.search below.
# NOTE(review): Hyperband appears to choose each trial's epoch count itself
# (bounded by its max_epochs), so this value may be overridden -- confirm.
epochs = 10
steps_per_epoch = 1500

# Record counts of the TFRecord files. The number of records to drop is derived
# from them so the batches stay full if batch_size is changed
# (batch size 28 leaves 1 training record and 13 validation records).
train_total_records = 200985
val_total_records = 23645

print('Compiling train dataset')
train_dataset = prep_dataset_from_tfrecord(datasets_path + "TestValTrain/train.tfrecord.gz",
                                batch_size=batch_size,
                                compression_type='GZIP',
                                shuffled=True,
                                shuffle_buffer=5000,
                                total_records=train_total_records,
                                num_to_drop=train_total_records % batch_size
                                )
train_dataset = train_dataset.repeat() # This is needed here because tuning a full epoch was not working and resulted in crashes.

print('Compiling val dataset')
val_dataset = prep_dataset_from_tfrecord(datasets_path + "TestValTrain/val.tfrecord.gz",
                                batch_size=batch_size,
                                compression_type='GZIP',
                                shuffled=False,
                                shuffle_buffer=5000,
                                total_records=val_total_records,
                                num_to_drop=val_total_records % batch_size,
                                seed=42 # Seed for dropping the same records every time
                                )

# stop_early = callbacks.EarlyStopping(monitor='val_loss', patience=5)
time_limit_callback = TimeLimit(max_time_seconds=max_time_seconds)

# Pass the configured `epochs` rather than a second hard-coded literal so the
# setting above is the single source of truth.
tuner.search(
    train_dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_dataset,
    callbacks=[early_stopping_cb, time_limit_callback, CleanupCallback()],
)

tuner.results_summary()

# Retrieve the best model and hyperparameters:
best_model = tuner.get_best_models(num_models=1)[0]
best_hyperparameters = tuner.get_best_hyperparameters(num_trials=1)[0]

print("Best Hyperparameters:")
print(best_hyperparameters.values)

In [None]:
'''This cell evaluates the validation data with each checkpoint'''

# Carried over from the training cell. Nothing below reads these three values;
# they are kept because they rebind notebook-level names.
max_time_seconds = 3600*12  # 1 hour is 3600 seconds
batch_size = 28
epochs = 400  # Set high enough to allow stopping by callback
steps_per_epoch = 7178

# print('Compiling train dataset')
# train_dataset = prep_dataset_from_tfrecord(datasets_path + "TestValTrain/train.tfrecord.gz",
#                                 batch_size=batch_size, 
#                                 compression_type='GZIP', 
#                                 shuffled=True,
#                                 shuffle_buffer=10000,
#                                 total_records=200985,
#                                 num_to_drop=1 # Batch size 28 leaves remainder of 1 record
#                                 )
# train_dataset = train_dataset.repeat()

# Derive the number of records to drop from the record count so the batches
# stay full if batch_size is changed (batch size 28 leaves 13 records).
val_total_records = 23645

print('Compiling val dataset')
val_dataset = prep_dataset_from_tfrecord(datasets_path + "TestValTrain/val.tfrecord.gz",
                                batch_size=batch_size,
                                compression_type='GZIP',
                                shuffled=False,
                                shuffle_buffer=5000,
                                total_records=val_total_records,
                                num_to_drop=val_total_records % batch_size,
                                seed=42 # Seed for dropping the same records every time
                                )

# The loss object must already exist from an earlier cell. Fail with an
# actionable message instead of the bare NameError that the previous
# `loss_fn = loss_fn` self-assignment produced.
try:
    loss_fn
except NameError as err:
    raise NameError(
        "loss_fn is not defined; run the cell that creates the loss before this one"
    ) from err

metrics_lst=[
        CustomNoBackgroundF1Score(num_classes=5, threshold=0.5, average='weighted'),  # Existing F1 metric
        CustomConditionalF1Score(threshold=0.5, average='weighted', filter_mode='pred'),  # 'pred' 'true' or 'either'
        CustomConditionalF1Score(threshold=0.5, average='weighted', filter_mode='true'),  # 'pred' 'true' or 'either'
        CustomFalsePositiveDistance(num_classes=5, threshold=0.5, window=100),
        CustomNoBackgroundAUC(curve='PR'),
        CustomNoBackgroundAccuracy(threshold=0.5),
        CustomNoBackgroundPrecision(threshold=0.5, average='weighted'),
        CustomNoBackgroundRecall(threshold=0.5, average='weighted'),
        CustomBackgroundOnlyF1Score(num_classes=5, threshold=0.5, average='weighted')
    ]

# Define checkpoint directory
checkpoint_dir = models_path + "checkpoints"

# One record per evaluated checkpoint: {"Checkpoint": filename, "Results": evaluate() output}
results_list = []

# Walk the saved models in sorted filename order.
for filename in sorted(os.listdir(checkpoint_dir)):
    if not filename.endswith(".keras"):  # Adjust if using TensorFlow checkpoints
        continue

    model_path = os.path.join(checkpoint_dir, filename)
    print(f"Evaluating {filename}...")

    # Load the checkpoint, then recompile it with the current loss and metrics
    # so every checkpoint is scored with the same set.
    model = models.load_model(model_path)
    model.compile(loss=loss_fn, metrics=metrics_lst)

    # Evaluate on the validation dataset (verbose=1 shows the progress bar;
    # use verbose=0 to suppress it).
    results = model.evaluate(val_dataset, verbose=1)

    # Store results (Modify column names as needed)
    results_list.append({"Checkpoint": filename, "Results" : results})

# Convert results to a DataFrame
df_results = pd.DataFrame(results_list)

# Display the DataFrame
print(df_results)

# Save the DataFrame to a CSV file for later analysis after renaming
# df_results.to_csv(models_path + "Results/validation_results.csv", index=False)

# print("Validation results saved to 'Results/validation_results.csv'.")

In [None]:
# Dump the raw per-checkpoint records (dicts with "Checkpoint" and "Results"
# keys) collected by the checkpoint-evaluation loop.
print(results_list)