Skip to content

Commit

Permalink
Added mcc as metric
Browse files Browse the repository at this point in the history
  • Loading branch information
adnaniazi committed Jun 6, 2024
1 parent ad5296b commit 7666e96
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 29 deletions.
56 changes: 28 additions & 28 deletions src/capfinder/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,35 +488,13 @@ def signal_handler(signum: signal.Signals, frame: Any) -> None:


if __name__ == "__main__":
# bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/8_20231114_randomCAP1v3_rna004/1_basecall_subset/sorted.calls.bam"
# pod5_dir = "/export/valenfs/data/raw_data/minion/20231114_randomCAP1v3_rna004"
# num_processes = 3
# reference = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT"
# cap0_pos = 52
# train_or_test = "test"
# output_dir = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/8_20231114_randomCAP1v3_rna004/test_OTE_vizs19"
# plot_signal = True
# cap_class = 1
# collate_bam_pod5(
# bam_filepath,
# pod5_dir,
# num_processes,
# reference,
# cap_class,
# cap0_pos,
# train_or_test,
# plot_signal,
# output_dir,
# )

bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data_old/7_20231025_capjump_rna004/2_alignment/1_basecalled/sorted.calls.bam"
# bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/7_20231025_capjump_rna004/1_basecall_subset/sorted.calls.bam"
pod5_dir = "/export/valenfs/data/raw_data/minion/7_20231025_capjump_rna004/20231025_CapJmpCcGFP_RNA004/20231025_1536_MN29576_FAX71885_5b8c42a6"
num_processes = 120
reference = "TTCGTCTCCGGACTTATCGCACCACCTAT"
cap0_pos = 43 # 59
bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/8_20231114_randomCAP1v3_rna004/1_basecall_subset/sorted.calls.bam"
pod5_dir = "/export/valenfs/data/raw_data/minion/20231114_randomCAP1v3_rna004"
num_processes = 3
reference = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT"
cap0_pos = 52
train_or_test = "test"
output_dir = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/7_20231025_capjump_rna004/output_full12"
output_dir = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/8_20231114_randomCAP1v3_rna004/test_OTE_vizs19"
plot_signal = True
cap_class = 1
collate_bam_pod5(
Expand All @@ -530,3 +508,25 @@ def signal_handler(signum: signal.Signals, frame: Any) -> None:
plot_signal,
output_dir,
)

# bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data_old/7_20231025_capjump_rna004/2_alignment/1_basecalled/sorted.calls.bam"
# # bam_filepath = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/7_20231025_capjump_rna004/1_basecall_subset/sorted.calls.bam"
# pod5_dir = "/export/valenfs/data/raw_data/minion/7_20231025_capjump_rna004/20231025_CapJmpCcGFP_RNA004/20231025_1536_MN29576_FAX71885_5b8c42a6"
# num_processes = 120
# reference = "TTCGTCTCCGGACTTATCGCACCACCTAT"
# cap0_pos = 43 # 59
# train_or_test = "test"
# output_dir = "/export/valenfs/data/processed_data/MinION/9_madcap/1_data/7_20231025_capjump_rna004/output_full12"
# plot_signal = True
# cap_class = 1
# collate_bam_pod5(
# bam_filepath,
# pod5_dir,
# num_processes,
# reference,
# cap_class,
# cap0_pos,
# train_or_test,
# plot_signal,
# output_dir,
# )
41 changes: 40 additions & 1 deletion src/capfinder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,48 @@

import keras
from keras import Model, layers
from keras import ops as kops
from keras.layers import Input
from keras_tuner import HyperModel, HyperParameters


def mcc(y_true: Input, y_pred: Input) -> Input:
"""
Calculates the Matthews Correlation Coefficient (MCC) metric using Keras functions.
Args:
y_true: Ground truth labels (one-hot encoded, shape: [None]).
y_pred: Predicted probabilities (shape: [None, num_classes]).
Returns:
A tensor with the MCC value for each sample in the batch.
"""
y_true = kops.cast(y_true, "float32") # Cast y_true to float32 for calculations

# Reshape y_pred only if necessary
if kops.ndim(y_pred) == 1:
# Reshape y_pred to match the number of classes in y_true (one-hot encoded)
num_classes = keras.backend.int_shape(y_true)[-1]
y_pred = kops.reshape(y_pred, (-1, num_classes))

# Convert probabilities to binary predictions (optional)
# You can uncomment this line if you need binary predictions
# y_pred = K.cast(y_pred > 0.5, np.float32)

# True positives, negatives, etc. (using Keras functions)
tp = kops.mean(kops.equal(y_true, kops.argmax(y_pred, axis=-1)), axis=-1)
tn = kops.mean(kops.equal(1 - y_true, 1 - kops.max(y_pred, axis=-1)), axis=-1)
fp = kops.mean(kops.equal(y_true, 1 - kops.max(y_pred, axis=-1)), axis=-1)
fn = kops.mean(kops.equal(1 - y_true, kops.max(y_pred, axis=-1)), axis=-1)

# Calculate MCC
numerator = tp * tn - fp * fn
denominator = kops.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
mcc_value = numerator / (denominator + keras.config.epsilon())

return mcc_value


def transformer_encoder(
inputs: keras.layers.Layer,
head_size: int,
Expand Down Expand Up @@ -174,7 +213,7 @@ def build(self, hp: HyperParameters) -> Model:
model.compile(
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["sparse_categorical_accuracy"],
metrics=["sparse_categorical_accuracy", mcc],
)

# Return only the full model to Keras Tuner
Expand Down
6 changes: 6 additions & 0 deletions src/capfinder/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,15 @@ def run_training_pipeline(
)
experiment.end()

# Convert the integer labels to one-hot encoded labels
# y_train = keras.utils.to_categorical(y_train, etl_params["n_classes"])
# y_test = keras.utils.to_categorical(y_test, etl_params["n_classes"])

logger.info("Dataset loaded successfully!")
logger.info(f"x_train shape: {x_train.shape}")
logger.info(f"x_test shape: {x_test.shape}")
logger.info(f"y_train shape: {y_train.shape}")
logger.info(f"y_test shape: {y_test.shape}")
logger.info(f"Dataset version: {dataset_info["version"]}")

"""
Expand Down

0 comments on commit 7666e96

Please sign in to comment.