In [2]:
import boto3
import time

In [7]:
# Submit a batch job
def submit_job(
    job_name,
    job_queue,
    job_definition,
    container_overrides,
    parameters,
    multi=False,
    multi_size=10,
    *,
    scheduling_priority=99,
    share_identifier="alex",
    retry_attempts=1,
    client=None,
):
    """Submit a job to AWS Batch and return its job ID.

    Args:
        job_name: Name given to the submitted job.
        job_queue: Name or ARN of the job queue to submit to.
        job_definition: Name (optionally ``name:revision``) or ARN of the job
            definition.
        container_overrides: ``containerOverrides`` dict passed straight to
            ``SubmitJob`` (e.g. ``command``, ``resourceRequirements``).
        parameters: Mapping of values substituted for ``Ref::<name>``
            placeholders in the job's command.
        multi: If True, submit an array job of ``multi_size`` child jobs.
        multi_size: Number of child jobs when ``multi`` is True.
        scheduling_priority: Value for ``schedulingPriorityOverride``; only
            meaningful on fair-share queues.
        share_identifier: Fair-share identifier (``shareIdentifier``) the job
            is submitted under.
        retry_attempts: Value for ``retryStrategy.attempts``.
        client: Optional pre-built boto3 ``batch`` client. A new one is created
            when None.

    Returns:
        The ``jobId`` string returned by AWS Batch.
    """
    if client is None:
        client = boto3.client("batch")

    # Only send arrayProperties for array jobs; omitting it submits a single job.
    extras = {}
    if multi:
        extras["arrayProperties"] = {"size": multi_size}

    response = client.submit_job(
        jobName=job_name,
        jobQueue=job_queue,
        jobDefinition=job_definition,
        containerOverrides=container_overrides,
        parameters=parameters,
        schedulingPriorityOverride=scheduling_priority,
        shareIdentifier=share_identifier,
        retryStrategy={"attempts": retry_attempts},
        **extras,
    )
    return response["jobId"]


# Get the status of a job
def get_job_status(job_id, client=None):
    """Return the current AWS Batch status string for ``job_id``.

    The value is the ``status`` field from ``DescribeJobs``. The polling loop
    in this notebook treats ``RUNNING``, ``SUCCEEDED`` and ``FAILED`` specially.

    Args:
        job_id: AWS Batch job ID, as returned by ``submit_job``.
        client: Optional pre-built boto3 ``batch`` client. A new one is created
            when None.

    Raises:
        IndexError: If Batch returns no job for ``job_id``, e.g. an unknown ID.
    """
    if client is None:
        client = boto3.client("batch")
    jobs = client.describe_jobs(jobs=[job_id])["jobs"]
    if not jobs:
        # Keep IndexError (what the bare jobs[0] used to raise) but say why.
        raise IndexError(f"No AWS Batch job found with id {job_id!r}")
    return jobs[0]["status"]


def get_cloudwatch_logs(
    job_id,
    log_group_name="/aws/batch/auspatious-ldn",
    *,
    batch_client=None,
    logs_client=None,
):
    """Return every CloudWatch log event written so far by a Batch job.

    Args:
        job_id: AWS Batch job ID whose container log stream should be read.
        log_group_name: CloudWatch log group holding the job's stream.
            NOTE(review): this must match the log group configured for the
            job definition; the default is the value previously hard-coded
            here.
        batch_client: Optional pre-built boto3 ``batch`` client.
        logs_client: Optional pre-built boto3 ``logs`` client.

    Returns:
        List of ``GetLogEvents`` event dicts (each has a ``message`` key), read
        from the head of the stream. Returns an empty list when the job has no
        log stream yet. That is the case before its container starts, and
        presumably for an array *parent* job, whose output lives in its
        children's streams.

    Raises:
        IndexError: If Batch returns no job for ``job_id``.
    """
    if batch_client is None:
        batch_client = boto3.client("batch")
    jobs = batch_client.describe_jobs(jobs=[job_id])["jobs"]
    if not jobs:
        raise IndexError(f"No AWS Batch job found with id {job_id!r}")

    # logStreamName is absent until a container attempt exists; treat that as
    # "no logs yet" instead of crashing with KeyError.
    log_stream_name = jobs[0].get("container", {}).get("logStreamName")
    if not log_stream_name:
        return []

    if logs_client is None:
        logs_client = boto3.client("logs")

    request = {
        "logGroupName": log_group_name,
        "logStreamName": log_stream_name,
        "startFromHead": True,
    }
    events = []
    token = None
    # GetLogEvents is paginated. It signals end-of-stream by returning the same
    # nextForwardToken that was passed in, so follow tokens until it repeats.
    while True:
        if token is not None:
            request["nextToken"] = token
        response = logs_client.get_log_events(**request)
        events.extend(response["events"])
        next_token = response.get("nextForwardToken")
        if not next_token or next_token == token:
            break
        token = next_token

    return events

In [None]:
job_name = "real-job"
job_queue = "normalQueue"
job_definition = "auspatious-ldn"
container_overrides = {
    "command": [
        "ldn-processor",
        # "--tile",
        # "Ref::tile",
        "--year",
        "Ref::year",
        "--version",
        "Ref::version",
        "--n-workers",
        "Ref::n_workers",
        "--threads-per-worker",
        "Ref::threads_per_worker",
        "--memory-limit",
        "Ref::memory_limit",
        "Ref::overwrite",
    ],
    "resourceRequirements": [
        {"type": "VCPU", "value": "16"},
        {"type": "MEMORY", "value": "122880"},
    ],
}
# NOTE(review): "tile" is currently unused because the --tile arguments above
# are commented out. Also, 4 workers x 32 threads is 128 threads on a 16-vCPU
# container. If --memory-limit applies per worker (as it does in dask),
# 4 x 80GB exceeds the 122880 MiB requested above. Confirm this is intended.
parameters = {
    "tile": "238,47",
    "year": "2023",
    "version": "0.1.0",
    "n_workers": "4",
    "threads_per_worker": "32",
    "memory_limit": "80GB",
    "overwrite": "--overwrite",
}

POLL_INTERVAL_SECONDS = 10
TERMINAL_STATUSES = ("SUCCEEDED", "FAILED")


def _print_new_logs(job_id, already_printed):
    """Print the job's log lines past index ``already_printed``; return the new count.

    get_cloudwatch_logs returns the stream from its start on every call, so
    only the unseen tail is printed. Counting events, rather than remembering
    message strings, keeps legitimately repeated log lines.
    """
    try:
        logs = get_cloudwatch_logs(job_id)
    except KeyError:
        # The job has no log stream yet (e.g. it never started, or it is an
        # array parent job), so there is nothing to print.
        return already_printed
    for log in logs[already_printed:]:
        print(log["message"])
    return max(already_printed, len(logs))


# multi=True submits an array job; submit_job defaults multi_size to 10.
job_id = submit_job(
    job_name, job_queue, job_definition, container_overrides, parameters, multi=True
)
print("Job submitted with id:", job_id)

printed_log_count = 0
status = None

# Poll the job every POLL_INTERVAL_SECONDS, reporting each status change and
# streaming new log lines, until it succeeds or fails.
while True:
    new_status = get_job_status(job_id)
    if new_status != status:
        print("Job status:", new_status)
        status = new_status

    # Also fetch once after the job finishes, so lines written between the last
    # RUNNING poll and completion are not lost.
    if status == "RUNNING" or status in TERMINAL_STATUSES:
        printed_log_count = _print_new_logs(job_id, printed_log_count)

    if status in TERMINAL_STATUSES:
        break

    time.sleep(POLL_INTERVAL_SECONDS)

print(f"Job finished with status {status}")