Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aws_doc_sdk_examples_tools/lliam/domain/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class CreatePrompts(Command):
@dataclass
class RunAilly(Command):
batches: List[str]
packages: List[str]


@dataclass
Expand Down
12 changes: 12 additions & 0 deletions aws_doc_sdk_examples_tools/lliam/domain/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import dataclass


@dataclass
class DomainError:
pass


@dataclass
class CommandExecutionError(DomainError):
command_name: str
message: str
40 changes: 24 additions & 16 deletions aws_doc_sdk_examples_tools/lliam/entry_points/lliam_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
import logging
import typer

from aws_doc_sdk_examples_tools.lliam.config import (
AILLY_DIR,
BATCH_PREFIX
)
from aws_doc_sdk_examples_tools.lliam.domain import commands
from aws_doc_sdk_examples_tools.lliam.config import AILLY_DIR, BATCH_PREFIX
from aws_doc_sdk_examples_tools.lliam.domain import commands, errors
from aws_doc_sdk_examples_tools.lliam.service_layer import messagebus, unit_of_work

logging.basicConfig(
Expand All @@ -31,36 +28,39 @@ def create_prompts(iam_tributary_root: str, system_prompts: List[str] = []):
out_dir=AILLY_DIR,
)
uow = unit_of_work.FsUnitOfWork()
messagebus.handle(cmd, uow)
errors = messagebus.handle(cmd, uow)
handle_domain_errors(errors)


@app.command()
def run_ailly(
batches: Annotated[
Optional[str],
typer.Option(
help="Batch names to process (comma-separated list)"
),
typer.Option(help="Batch names to process (comma-separated list)"),
] = None,
packages: Annotated[
Optional[str], typer.Option(help="Comma delimited list of packages to update")
] = None,
) -> None:
"""
Run ailly to generate IAM policy content and process the results.
If batches is specified, only those batches will be processed.
If batches is omitted, all batches will be processed.
If packages is specified, only those packages will be processed.
"""
requested_batches = parse_batch_names(batches)
cmd = commands.RunAilly(batches=requested_batches)
messagebus.handle(cmd)
package_names = parse_package_names(packages)
cmd = commands.RunAilly(batches=requested_batches, packages=package_names)
errors = messagebus.handle(cmd)
handle_domain_errors(errors)


@app.command()
def update_reservoir(
iam_tributary_root: str,
batches: Annotated[
Optional[str],
typer.Option(
help="Batch names to process (comma-separated list)"
),
typer.Option(help="Batch names to process (comma-separated list)"),
] = None,
packages: Annotated[
Optional[str], typer.Option(help="Comma delimited list of packages to update")
Expand All @@ -77,7 +77,15 @@ def update_reservoir(
cmd = commands.UpdateReservoir(
root=doc_gen_root, batches=batch_names, packages=package_names
)
messagebus.handle(cmd)
errors = messagebus.handle(cmd)
handle_domain_errors(errors)


def handle_domain_errors(errors: List[errors.DomainError]):
if errors:
for error in errors:
logger.error(error)
typer.Exit(code=1)


def parse_batch_names(batch_names_str: Optional[str]) -> List[str]:
Expand All @@ -86,7 +94,7 @@ def parse_batch_names(batch_names_str: Optional[str]) -> List[str]:
"""
if not batch_names_str:
return []

batch_names = []

for name in batch_names_str.split(","):
Expand Down
13 changes: 5 additions & 8 deletions aws_doc_sdk_examples_tools/lliam/service_layer/messagebus.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,21 @@
Message = commands.Command


def handle(
message: commands.Command, uow: Optional[unit_of_work.FsUnitOfWork] = None
):
def handle(message: commands.Command, uow: Optional[unit_of_work.FsUnitOfWork] = None):
queue = [message]

while queue:
message = queue.pop(0)
if isinstance(message, commands.Command):
handle_command(message, uow)
return handle_command(message, uow)
else:
raise Exception(f"{message} was not a Command")


def handle_command(
command: commands.Command, uow: Optional[unit_of_work.FsUnitOfWork]
):
def handle_command(command: commands.Command, uow: Optional[unit_of_work.FsUnitOfWork]):
handler = COMMAND_HANDLERS[type(command)]
handler(command, uow)
errors = handler(command, uow)
return errors


COMMAND_HANDLERS: Dict[Type[commands.Command], Callable] = {
Expand Down
68 changes: 56 additions & 12 deletions aws_doc_sdk_examples_tools/lliam/service_layer/run_ailly.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,50 @@
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path
from subprocess import run
from typing import Any, Dict, List, Optional, Set

from aws_doc_sdk_examples_tools.lliam.domain.commands import RunAilly
from aws_doc_sdk_examples_tools.lliam.domain.errors import (
CommandExecutionError,
DomainError,
)
from aws_doc_sdk_examples_tools.lliam.config import (
AILLY_DIR_PATH,
BATCH_PREFIX,
)

AILLY_CMD_BASE = [
"ailly",
"--max-depth",
"10",
"--root",
str(AILLY_DIR_PATH),
]

logger = logging.getLogger(__file__)


def handle_run_ailly(cmd: RunAilly, uow: None):
resolved_batches = resolve_requested_batches(cmd.batches)

errors: List[DomainError] = []

if resolved_batches:
total_start_time = time.time()

for batch in resolved_batches:
run_ailly_single_batch(batch)
try:
run_ailly_single_batch(batch, cmd.packages)
except FileNotFoundError as e:
errors.append(
CommandExecutionError(
command_name=cmd.__class__.__name__, message=str(e)
)
)

total_end_time = time.time()
total_duration = total_end_time - total_start_time
Expand All @@ -32,6 +54,8 @@ def handle_run_ailly(cmd: RunAilly, uow: None):
f"[TIMECHECK] {num_batches} batches took {format_duration(total_duration)} to run"
)

return errors


def resolve_requested_batches(batch_names: List[str]) -> List[Path]:
if not batch_names:
Expand All @@ -56,19 +80,26 @@ def resolve_requested_batches(batch_names: List[str]) -> List[Path]:
return batch_paths


def run_ailly_single_batch(batch: Path) -> None:
def run_ailly_single_batch(batch: Path, packages: List[str] = []) -> None:
"""Run ailly and process files for a single batch."""
batch_start_time = time.time()
iam_updates_path = AILLY_DIR_PATH / f"updates_{batch.name}.json"

cmd = [
"ailly",
"--max-depth",
"10",
"--root",
str(AILLY_DIR_PATH),
batch.name,
]
if packages:
paths = []
for package in packages:
package_files = [
f"{batch.name}/{p.name}" for p in batch.glob(f"*{package}*.md")
]
paths.extend(package_files)

if not paths:
raise FileNotFoundError(f"No matching files found for packages: {packages}")

cmd = AILLY_CMD_BASE + paths
else:
cmd = AILLY_CMD_BASE + [batch.name]

logger.info(f"Running {cmd}")
run(cmd)

Expand All @@ -79,7 +110,9 @@ def run_ailly_single_batch(batch: Path) -> None:
)

logger.info(f"Processing generated content for {batch.name}")
process_ailly_files(input_dir=batch, output_file=iam_updates_path)
process_ailly_files(
input_dir=batch, output_file=iam_updates_path, packages=packages
)


EXPECTED_KEYS: Set[str] = set(["title", "title_abbrev"])
Expand Down Expand Up @@ -177,7 +210,10 @@ def parse_package_name(policy_update: Dict[str, str]) -> Optional[str]:


def process_ailly_files(
input_dir: Path, output_file: Path, file_pattern: str = "*.md.ailly.md"
input_dir: Path,
output_file: Path,
file_pattern: str = "*.md.ailly.md",
packages: List[str] = [],
) -> None:
"""
Process all .md.ailly.md files in the input directory and write the results as JSON to the output file.
Expand All @@ -186,6 +222,7 @@ def process_ailly_files(
input_dir: Directory containing .md.ailly.md files
output_file: Path to the output JSON file
file_pattern: Pattern to match files (default: "*.md.ailly.md")
packages: Optional list of packages to filter by
"""
results = defaultdict(list)

Expand All @@ -197,6 +234,13 @@ def process_ailly_files(
package_name = parse_package_name(policy_update)
if not package_name:
raise TypeError(f"Could not get package name from policy update.")

if packages and package_name not in packages:
logger.info(
f"Skipping package {package_name} (not in requested packages)"
)
continue

results[package_name].append(policy_update)

with open(output_file, "w", encoding="utf-8") as out_file:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def make_title_abbreviation(old: Example, new: Example, abbreviations: Counter):
version = language.versions[0]
source = version.source
source_title = source.title if source else ""
base = f"{new.title_abbrev} (from '{source_title}' docs)"
base = f"{new.title_abbrev} (from '{source_title}' guide)"
abbreviations[base] += 1
count = abbreviations[base]
return f"{base} ({count})" if count > 1 else base
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ def test_update_examples_title_abbrev(doc_gen_tributary: DocGen):
updated_example = doc_gen_tributary.examples["iam_policies_example"]
assert (
updated_example.title_abbrev
== "Updated Title Abbrev (from 'AWS Account Management' docs)"
== "Updated Title Abbrev (from 'AWS Account Management' guide)"
)
Loading