diff --git a/src/capfinder/collate.py b/src/capfinder/collate.py
index 5e23c8f..756d83f 100644
--- a/src/capfinder/collate.py
+++ b/src/capfinder/collate.py
@@ -488,35 +488,13 @@ def signal_handler(signum: signal.Signals, frame: Any) -> None:
 
 
 if __name__ == "__main__":
-    # bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/8_20231114_randomCAP1v3_rna004/1_basecall_subset/sorted.calls.bam"
-    # pod5_dir = "/export/valenfs/data/raw_data/minion/20231114_randomCAP1v3_rna004"
-    # num_processes = 3
-    # reference = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT"
-    # cap0_pos = 52
-    # train_or_test = "test"
-    # output_dir = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/8_20231114_randomCAP1v3_rna004/test_OTE_vizs19"
-    # plot_signal = True
-    # cap_class = 1
-    # collate_bam_pod5(
-    #     bam_filepath,
-    #     pod5_dir,
-    #     num_processes,
-    #     reference,
-    #     cap_class,
-    #     cap0_pos,
-    #     train_or_test,
-    #     plot_signal,
-    #     output_dir,
-    # )
-
-    bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data_old/7_20231025_capjump_rna004/2_alignment/1_basecalled/sorted.calls.bam"
-    # bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/7_20231025_capjump_rna004/1_basecall_subset/sorted.calls.bam"
-    pod5_dir = "/export/valenfs/data/raw_data/minion/7_20231025_capjump_rna004/20231025_CapJmpCcGFP_RNA004/20231025_1536_MN29576_FAX71885_5b8c42a6"
-    num_processes = 120
-    reference = "TTCGTCTCCGGACTTATCGCACCACCTAT"
-    cap0_pos = 43  # 59
+    bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/8_20231114_randomCAP1v3_rna004/1_basecall_subset/sorted.calls.bam"
+    pod5_dir = "/export/valenfs/data/raw_data/minion/20231114_randomCAP1v3_rna004"
+    num_processes = 3
+    reference = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT"
+    cap0_pos = 52
     train_or_test = "test"
-    output_dir = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/7_20231025_capjump_rna004/output_full12"
+    output_dir = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/8_20231114_randomCAP1v3_rna004/test_OTE_vizs19"
     plot_signal = True
     cap_class = 1
     collate_bam_pod5(
@@ -530,3 +508,25 @@ def signal_handler(signum: signal.Signals, frame: Any) -> None:
         plot_signal,
         output_dir,
     )
+
+    # bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data_old/7_20231025_capjump_rna004/2_alignment/1_basecalled/sorted.calls.bam"
+    # # bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/7_20231025_capjump_rna004/1_basecall_subset/sorted.calls.bam"
+    # pod5_dir = "/export/valenfs/data/raw_data/minion/7_20231025_capjump_rna004/20231025_CapJmpCcGFP_RNA004/20231025_1536_MN29576_FAX71885_5b8c42a6"
+    # num_processes = 120
+    # reference = "TTCGTCTCCGGACTTATCGCACCACCTAT"
+    # cap0_pos = 43  # 59
+    # train_or_test = "test"
+    # output_dir = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/7_20231025_capjump_rna004/output_full12"
+    # plot_signal = True
+    # cap_class = 1
+    # collate_bam_pod5(
+    #     bam_filepath,
+    #     pod5_dir,
+    #     num_processes,
+    #     reference,
+    #     cap_class,
+    #     cap0_pos,
+    #     train_or_test,
+    #     plot_signal,
+    #     output_dir,
+    # )
diff --git a/src/capfinder/model.py b/src/capfinder/model.py
index 515bee8..570fb02 100644
--- a/src/capfinder/model.py
+++ b/src/capfinder/model.py
@@ -2,9 +2,56 @@
 
 import keras
 from keras import Model, layers
+from keras import ops as kops
+from keras.layers import Input
 from keras_tuner import HyperModel, HyperParameters
 
 
+def mcc(y_true: Input, y_pred: Input) -> Input:
+    """
+    Calculates the multi-class Matthews Correlation Coefficient (MCC) of a batch.
+
+    The model is compiled with a sparse categorical loss, so integer class
+    labels are expected here rather than one-hot vectors. The coefficient is
+    computed from the batch confusion-matrix totals (Gorodkin's R_K), which
+    reduces to the usual binary MCC when there are two classes.
+
+    Args:
+        y_true: Integer class labels (shape: [batch] or [batch, 1]).
+        y_pred: Predicted class probabilities (shape: [batch, num_classes]).
+
+    Returns:
+        A scalar tensor in [-1, 1] with the MCC of the batch. It is 0 when
+        either the labels or the predictions contain a single class only.
+    """
+    num_classes = y_pred.shape[-1]
+
+    # Reduce both inputs to one integer class index per sample
+    true_labels = kops.cast(kops.reshape(y_true, (-1,)), "int32")
+    pred_labels = kops.cast(kops.argmax(y_pred, axis=-1), "int32")
+    pred_labels = kops.reshape(pred_labels, (-1,))
+
+    true_one_hot = kops.one_hot(true_labels, num_classes, dtype="float32")
+    pred_one_hot = kops.one_hot(pred_labels, num_classes, dtype="float32")
+
+    # Confusion-matrix totals
+    n_correct = kops.sum(true_one_hot * pred_one_hot)
+    true_per_class = kops.sum(true_one_hot, axis=0)
+    pred_per_class = kops.sum(pred_one_hot, axis=0)
+    n_samples = kops.sum(true_per_class)
+
+    # Calculate MCC
+    numerator = n_correct * n_samples - kops.sum(true_per_class * pred_per_class)
+    denominator = kops.sqrt(
+        (kops.square(n_samples) - kops.sum(kops.square(pred_per_class)))
+        * (kops.square(n_samples) - kops.sum(kops.square(true_per_class)))
+    )
+    # epsilon keeps the degenerate single-class case at 0 instead of NaN
+    mcc_value = numerator / (denominator + keras.config.epsilon())
+
+    return mcc_value
+
+
 def transformer_encoder(
     inputs: keras.layers.Layer,
     head_size: int,
@@ -174,7 +221,7 @@ def build(self, hp: HyperParameters) -> Model:
         model.compile(
             loss="sparse_categorical_crossentropy",
             optimizer="adam",
-            metrics=["sparse_categorical_accuracy"],
+            metrics=["sparse_categorical_accuracy", mcc],
         )
 
         # Return only the full model to Keras Tuner
diff --git a/src/capfinder/training.py b/src/capfinder/training.py
index 5a603ea..aade0a2 100644
--- a/src/capfinder/training.py
+++ b/src/capfinder/training.py
@@ -346,9 +346,15 @@ def run_training_pipeline(
         )
         experiment.end()
 
+    # Convert the integer labels to one-hot encoded labels
+    # y_train = keras.utils.to_categorical(y_train, etl_params["n_classes"])
+    # y_test = keras.utils.to_categorical(y_test, etl_params["n_classes"])
+
     logger.info("Dataset loaded successfully!")
     logger.info(f"x_train shape: {x_train.shape}")
     logger.info(f"x_test shape: {x_test.shape}")
+    logger.info(f"y_train shape: {y_train.shape}")
+    logger.info(f"y_test shape: {y_test.shape}")
     logger.info(f"Dataset version: {dataset_info["version"]}")
 
     """