Skip to content

Commit

Permalink
Improved cli
Browse files Browse the repository at this point in the history
  • Loading branch information
adnaniazi committed Jun 9, 2024
1 parent 9f33fa3 commit c0c94e6
Showing 1 changed file with 81 additions and 45 deletions.
126 changes: 81 additions & 45 deletions src/capfinder/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,31 @@

app = typer.Typer(
help="capfinder: A Python package for decoding RNA cap types using an encoder-based deep learning model.",
add_completion=False,
add_completion=True,
rich_markup_mode="rich",
)


@app.command()
def fetch_cap_signal(
bam_filepath: Annotated[str, typer.Argument(help="Path to the BAM file")],
def extract_cap_signal(
bam_filepath: Annotated[
str, typer.Option("--bam_filepath", "-b", help="Path to the BAM file")
] = "",
pod5_dir: Annotated[
str, typer.Argument(help="Path to directory containing POD5 files")
],
reference: Annotated[str, typer.Argument(help="Reference Sequence (5' -> 3')")],
str,
typer.Option(
"--pod5_dir", "-p", help="Path to directory containing POD5 files"
),
] = "",
reference: Annotated[
str,
typer.Option("--reference", "-r", help="Reference Sequence (5' -> 3')"),
] = "GCTTTCGTTCGTCTCCGGACTTATCGCACCACCTATCCATCATCAGTACTGT",
cap_class: Annotated[
int,
typer.Argument(
typer.Option(
"--cap_class",
"-c",
help="""\n
Integer-based class label for the RNA cap type. \n
- 0 represents Cap_0 \n
Expand All @@ -33,24 +43,30 @@ def fetch_cap_signal(
- 5 represents NAD Cap \n
- 6 represents FAD Cap \n
- -99 represents and unknown cap(s). \n
"""
""",
),
],
] = -99,
cap_n1_pos0: Annotated[
int,
typer.Argument(
help="0-based index of 1st nucleotide (N1) of cap in the reference"
typer.Option(
"--cap_n1_pos0",
"-p",
help="0-based index of 1st nucleotide (N1) of cap in the reference",
),
],
] = 52,
train_or_test: Annotated[
str,
typer.Argument(
help="set to train or test depending on whether it is training or testing data"
typer.Option(
"--train_or_test",
"-t",
help="set to train or test depending on whether it is training or testing data",
),
],
] = "test",
output_dir: Annotated[
str,
typer.Argument(
typer.Option(
"--output_dir",
"-o",
help=textwrap.dedent(
"""
Path to the output directory which will contain: \n
Expand All @@ -60,15 +76,14 @@ def fetch_cap_signal(
└── (Optional) plots directory containing cap signal plots, if plot_signal is set to True.\n
\u200B ├── good_reads: Directory that contains the plots for the good reads.\n
\u200B ├── bad_reads: Directory that contains the plots for the bad reads.\n
\u200B └── plotpaths.csv: CSV file containing the paths to the plots based on the read ID.\n
"""
)
\u200B └── plotpaths.csv: CSV file containing the paths to the plots based on the read ID.\n"""
),
),
],
] = "",
n_workers: Annotated[
int,
typer.Option(
"--n_workers", help="Number of CPUs to use for parallel processing"
"--n_workers", "-n", help="Number of CPUs to use for parallel processing"
),
] = 1,
plot_signal: Annotated[
Expand Down Expand Up @@ -106,38 +121,56 @@ def fetch_cap_signal(


@app.command()
def perpare_train_dataset(
data_dir: Annotated[
def make_train_dataset(
csv_dir: Annotated[
str,
typer.Argument(
help="Directory containing all the cap signal data files (data__cap_x.csv)"
typer.Option(
"--csv_dir",
"-c",
help="Directory containing all the cap signal data files (data__cap_x.csv)",
),
],
] = "",
save_dir: Annotated[
str,
typer.Argument(
help="Directory where the processed data will be saved as csv files."
typer.Option(
"--save_dir",
"-s",
help="Directory where the processed data will be saved as csv files.",
),
],
] = "",
target_length: Annotated[
int,
typer.Argument(
help="Number of signal points in cap signal to consider. If the signal is shorter, it will be padded with zeros. If the signal is longer, it will be truncated."
typer.Option(
"--target_length",
"-t",
help="Number of signal points in cap signal to consider. If the signal is shorter, it will be padded with zeros. If the signal is longer, it will be truncated.",
),
],
] = 500,
dtype: Annotated[
str,
typer.Argument(
help="Data type to transform the dataset to Valid values are 'float16', 'float32', or 'float64'."
typer.Option(
"--dtype",
"-d",
help="Data type to transform the dataset to Valid values are 'float16', 'float32', or 'float64'.",
),
],
] = "float32",
n_workers: Annotated[
int,
typer.Argument(help="Number of CPUs to use for parallel processing"),
],
typer.Option(
"--n_workers", "-n", help="Number of CPUs to use for parallel processing"
),
] = 1,
) -> None:
"""
Prepares dataset for training the ML model.
Example command:
capfinder make-train-dataset \\
--csv_dir /path/to/csv_dir \\
--save_dir /path/to/save_dir \\
--target_length 500 \\
--dtype float32 \\
--n_workers 10
"""
from typing import cast

Expand All @@ -154,7 +187,7 @@ def perpare_train_dataset(
)

train_etl(
data_dir=data_dir,
data_dir=csv_dir,
save_dir=save_dir,
target_length=target_length,
dtype=dt,
Expand All @@ -165,8 +198,11 @@ def perpare_train_dataset(
@app.command()
def create_train_config(
file_path: Annotated[
str, typer.Argument(help="File path to save the JSON configuration file")
],
str,
typer.Option(
"--file_path", "-f", help="File path to save the JSON configuration file"
),
] = "",
) -> None:
"""Creates a dummy JSON configuration file at the specified path. Edit it to suit your needs."""
config = {
Expand Down Expand Up @@ -208,12 +244,12 @@ def create_train_config(
def train_model(
config_file: Annotated[
str,
typer.Argument(
help="""\n
Path to the JSON configuration file containing the parameters for the training pipeline. \n
"""
typer.Option(
"--file_path",
"-f",
help="""Path to the JSON configuration file containing the parameters for the training pipeline.""",
),
],
] = "",
) -> None:
"""Trains the model using the parameters in the JSON configuration file."""
from capfinder.training import run_training_pipeline
Expand Down

0 comments on commit c0c94e6

Please sign in to comment.