From c0c94e6687bb91d6ddaaf36c038087bb3c22ea91 Mon Sep 17 00:00:00 2001 From: adnaniazi Date: Sun, 9 Jun 2024 16:26:06 +0200 Subject: [PATCH] Improved cli --- src/capfinder/cli.py | 126 +++++++++++++++++++++++++++---------------- 1 file changed, 81 insertions(+), 45 deletions(-) diff --git a/src/capfinder/cli.py b/src/capfinder/cli.py index f074125..0457f9c 100644 --- a/src/capfinder/cli.py +++ b/src/capfinder/cli.py @@ -8,21 +8,31 @@ app = typer.Typer( help="capfinder: A Python package for decoding RNA cap types using an encoder-based deep learning model.", - add_completion=False, + add_completion=True, rich_markup_mode="rich", ) @app.command() -def fetch_cap_signal( - bam_filepath: Annotated[str, typer.Argument(help="Path to the BAM file")], +def extract_cap_signal( + bam_filepath: Annotated[ + str, typer.Option("--bam_filepath", "-b", help="Path to the BAM file") + ] = "", pod5_dir: Annotated[ - str, typer.Argument(help="Path to directory containing POD5 files") - ], - reference: Annotated[str, typer.Argument(help="Reference Sequence (5' -> 3')")], + str, + typer.Option( + "--pod5_dir", "-p", help="Path to directory containing POD5 files" + ), + ] = "", + reference: Annotated[ + str, + typer.Option("--reference", "-r", help="Reference Sequence (5' -> 3')"), + ] = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT", cap_class: Annotated[ int, - typer.Argument( + typer.Option( + "--cap_class", + "-c", help="""\n Integer-based class label for the RNA cap type. \n - 0 represents Cap_0 \n @@ -33,24 +43,30 @@ def fetch_cap_signal( - 5 represents NAD Cap \n - 6 represents FAD Cap \n - -99 represents and unknown cap(s). \n - """ + """, ), - ], + ] = -99, cap_n1_pos0: Annotated[ int, - typer.Argument( - help="0-based index of 1st nucleotide (N1) of cap in the reference" + typer.Option( + "--cap_n1_pos0", + "-p", + help="0-based index of 1st nucleotide (N1) of cap in the reference", ), - ], + ] = 52, train_or_test: Annotated[ str, - typer.Argument( - help="set to train or test depending on whether it is training or testing data" + typer.Option( + "--train_or_test", + "-t", + help="set to train or test depending on whether it is training or testing data", ), - ], + ] = "test", output_dir: Annotated[ str, - typer.Argument( + typer.Option( + "--output_dir", + "-o", help=textwrap.dedent( """ Path to the output directory which will contain: \n @@ -60,15 +76,14 @@ def fetch_cap_signal( └── (Optional) plots directory containing cap signal plots, if plot_signal is set to True.\n \u200B ├── good_reads: Directory that contains the plots for the good reads.\n \u200B ├── bad_reads: Directory that contains the plots for the bad reads.\n - \u200B └── plotpaths.csv: CSV file containing the paths to the plots based on the read ID.\n - """ - ) + \u200B └── plotpaths.csv: CSV file containing the paths to the plots based on the read ID.\n""" + ), ), - ], + ] = "", n_workers: Annotated[ int, typer.Option( - "--n_workers", help="Number of CPUs to use for parallel processing" + "--n_workers", "-n", help="Number of CPUs to use for parallel processing" ), ] = 1, plot_signal: Annotated[ @@ -106,38 +121,56 @@ def fetch_cap_signal( @app.command() -def perpare_train_dataset( - data_dir: Annotated[ +def make_train_dataset( + csv_dir: Annotated[ str, - typer.Argument( - help="Directory containing all the cap signal data files (data__cap_x.csv)" + typer.Option( + "--csv_dir", + "-c", + help="Directory containing all the cap signal data files (data__cap_x.csv)", ), - ], + ] = "", save_dir: Annotated[ str, - typer.Argument( - help="Directory where the processed data will be saved as csv files." + typer.Option( + "--save_dir", + "-s", + help="Directory where the processed data will be saved as csv files.", ), - ], + ] = "", target_length: Annotated[ int, - typer.Argument( - help="Number of signal points in cap signal to consider. If the signal is shorter, it will be padded with zeros. If the signal is longer, it will be truncated." + typer.Option( + "--target_length", + "-t", + help="Number of signal points in cap signal to consider. If the signal is shorter, it will be padded with zeros. If the signal is longer, it will be truncated.", ), - ], + ] = 500, dtype: Annotated[ str, - typer.Argument( - help="Data type to transform the dataset to Valid values are 'float16', 'float32', or 'float64'." + typer.Option( + "--dtype", + "-d", + help="Data type to transform the dataset to Valid values are 'float16', 'float32', or 'float64'.", ), - ], + ] = "float32", n_workers: Annotated[ int, - typer.Argument(help="Number of CPUs to use for parallel processing"), - ], + typer.Option( + "--n_workers", "-n", help="Number of CPUs to use for parallel processing" + ), + ] = 1, ) -> None: """ Prepares dataset for training the ML model. + + Example command: + capfinder make-train-dataset \\ + --csv_dir /path/to/csv_dir \\ + --save_dir /path/to/save_dir \\ + --target_length 500 \\ + --dtype float32 \\ + --n_workers 10 """ from typing import cast @@ -154,7 +187,7 @@ def perpare_train_dataset( ) train_etl( - data_dir=data_dir, + data_dir=csv_dir, save_dir=save_dir, target_length=target_length, dtype=dt, @@ -165,8 +198,11 @@ def perpare_train_dataset( @app.command() def create_train_config( file_path: Annotated[ - str, typer.Argument(help="File path to save the JSON configuration file") - ], + str, + typer.Option( + "--file_path", "-f", help="File path to save the JSON configuration file" + ), + ] = "", ) -> None: """Creats a dummy JSON configuration file at the specified path. Edit it to suit your needs.""" config = { @@ -208,12 +244,12 @@ def create_train_config( def train_model( config_file: Annotated[ str, - typer.Argument( - help="""\n - Path to the JSON configuration file containing the parameters for the training pipeline. \n - """ + typer.Option( + "--file_path", + "-f", + help="""Path to the JSON configuration file containing the parameters for the training pipeline.""", ), - ], + ] = "", ) -> None: """Trains the model using the parameters in the JSON configuration file.""" from capfinder.training import run_training_pipeline