In [None]:
import json
import math
import os
import subprocess
import tempfile
import time
import traceback

import dask
import dask.bag as db
from dask_cloudprovider.aws import EC2Cluster, FargateCluster

In [None]:
from dataplug import CloudObject
from dataplug.formats.genomics.fastq import FASTQGZip, partition_reads_batches

In [None]:
def filter_dedup_fastq(slice):
    """Download one FASTQ slice, filter and deduplicate it, and report timings.

    The slice is written to a local file, run through
    ``fastq-filter -e 0.001`` and then ``czid-dedup``. All intermediate files
    live in a private temporary directory that is removed on exit, including
    when one of the external tools fails.

    Args:
        slice: a dataplug data slice; only its ``to_file(path)`` method is used.
            (The name shadows the ``slice`` builtin but is kept so existing
            callers are unaffected.)

    Returns:
        dict with ``chunk_size`` / ``result_size`` (bytes of the downloaded
        chunk and of the deduplicated output), ``fetch_time`` /
        ``filter_time`` / ``dedup_time`` (seconds, from ``time.perf_counter``),
        and ``task_start`` / ``task_end`` epoch timestamps plus their
        difference ``task_time``.

    Raises:
        subprocess.CalledProcessError: if either external tool exits non-zero.
    """
    print("Processing slice: ", slice)
    task_start = time.time()

    # TemporaryDirectory replaces the race-prone tempfile.mktemp() and makes
    # sure the (potentially large) chunk files are deleted even on failure.
    with tempfile.TemporaryDirectory() as workdir:
        chunk_filename = os.path.join(workdir, "chunk.fastq")
        t0 = time.perf_counter()
        slice.to_file(chunk_filename)
        fetch_time = time.perf_counter() - t0
        chunk_size = os.stat(chunk_filename).st_size

        filtered = chunk_filename + ".filtered"
        t0 = time.perf_counter()
        proc = subprocess.check_output(["fastq-filter", "-e", "0.001", "-o", filtered, chunk_filename])
        filter_time = time.perf_counter() - t0
        print(proc)

        deduped = filtered + ".dedup"
        t0 = time.perf_counter()
        proc = subprocess.check_output(["czid-dedup", "-i", filtered, "-o", deduped])
        dedup_time = time.perf_counter() - t0
        print(proc)

        res_size = os.stat(deduped).st_size

    task_end = time.time()
    return {
        "chunk_size": chunk_size,
        "result_size": res_size,
        "fetch_time": fetch_time,
        "filter_time": filter_time,
        "dedup_time": dedup_time,
        "task_time": task_end - task_start,
        "task_start": task_start,
        "task_end": task_end,
    }

In [None]:
def index_fastq(slice):
    """Download one FASTQ slice, convert it to FASTA with ``seqtk seq -a`` and report timings.

    The slice and its FASTA conversion are written to a private temporary
    directory that is removed on exit, including when seqtk fails.

    Args:
        slice: a dataplug data slice; only its ``to_file(path)`` method is used.
            (The name shadows the ``slice`` builtin but is kept so existing
            callers are unaffected.)

    Returns:
        dict with ``chunk_size`` / ``result_size`` (bytes of the FASTQ input
        and of the FASTA output), ``fetch_time`` / ``transform_time``
        (seconds, from ``time.perf_counter``), and ``task_start`` /
        ``task_end`` epoch timestamps plus their difference ``task_time``.

    Raises:
        subprocess.CalledProcessError: if seqtk exits with a non-zero status.
    """
    task_start = time.time()

    # TemporaryDirectory replaces the race-prone tempfile.mktemp() and cleans
    # up both files even if the conversion fails.
    with tempfile.TemporaryDirectory() as workdir:
        chunk_filename = os.path.join(workdir, "chunk.fastq")
        t0 = time.perf_counter()
        slice.to_file(chunk_filename)
        fetch_time = time.perf_counter() - t0
        chunk_size = os.stat(chunk_filename).st_size

        # `seqtk seq` emits its result on stdout rather than taking an output
        # path, so stdout is redirected into output_file. This keeps the FASTA
        # out of memory/logs and ensures output_file actually exists.
        output_file = chunk_filename + ".fa"
        t0 = time.perf_counter()
        with open(output_file, "wb") as fasta_out:
            subprocess.run(["seqtk", "seq", "-a", chunk_filename], stdout=fasta_out, check=True)
        transform_time = time.perf_counter() - t0
        result_size = os.stat(output_file).st_size

        # NOTE: an optional `seqtk trimfq` step reading output_file belongs
        # here; redirect its stdout to a file as above (a ">" element in the
        # argv list is passed to seqtk literally, not treated as a redirect).

    task_end = time.time()
    return {
        "chunk_size": chunk_size,
        "result_size": result_size,
        "fetch_time": fetch_time,
        "transform_time": transform_time,
        "task_time": task_end - task_start,
        "task_start": task_start,
        "task_end": task_end,
    }

In [None]:
# Record the Dask version used for this run (displayed as the cell output).
dask.__version__

In [None]:
# S3 access settings for dataplug. Credentials are read from the standard AWS
# environment variables instead of being typed into the notebook (where they
# would end up in version control); each falls back to the empty string this
# cell previously hard-coded.
storage_config = {
    "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID", ""),
    "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY", ""),
    "aws_session_token": os.environ.get("AWS_SESSION_TOKEN", ""),
    "region_name": "us-east-1",
    "use_token": False,
}

# Handle to the gzipped FASTQ dataset on S3, interpreted with the FASTQGZip format.
co = CloudObject.from_path(FASTQGZip, "s3://lithops-datasets/fastq.gz/13gb.fastq.gz", storage_config=storage_config)

In [None]:
# Dask cluster on EC2: an m6i.large scheduler plus m6i.2xlarge workers, all
# running the dataplug FASTQ Docker image in a fixed us-east-1a subnet.
# NOTE(review): security=False turns off Dask's auto-generated TLS for
# scheduler/worker traffic — acceptable only inside a private VPC; confirm.
cluster_options = dict(
    region="us-east-1",
    availability_zone="us-east-1a",
    subnet_id="subnet-0a95cf6e",
    scheduler_instance_type="m6i.large",
    worker_instance_type="m6i.2xlarge",
    docker_image="aitorarjona/dataplug-fastq-dask:0.4",
    security=False,
)
cluster = EC2Cluster(**cluster_options)

In [None]:
# Split the dataset into read batches; each batch is handed to one call of the
# workload function later on.
NUM_BATCHES = 25
data_slices = co.partition(partition_reads_batches, num_batches=NUM_BATCHES)

In [None]:
# Connect a Dask client to the cluster; the bare `client` expression displays
# its summary as the cell output.
client = cluster.get_client()
client

In [None]:
# Warm-up: bring up a single worker and run index_fastq on only the first
# slice. Presumably a smoke test of the worker image/tools before scaling out.
cluster.scale(1)
cluster.wait_for_workers(1)
wl = db.from_sequence([data_slices[0]]).map(index_fastq)
fut = client.compute(wl)
# Blocks until the task finishes and displays its metrics dict(s).
fut.result()

In [None]:
# vCPUs per worker VM (m6i.2xlarge has 8). One VM is requested per `vm_cpus`
# slices, presumably so each slice can get its own core.
vm_cpus = 8
workers = -(-len(data_slices) // vm_cpus)  # integer ceiling division

In [None]:
# Measure how long it takes to provision the full worker pool.
scale_start = time.perf_counter()
cluster.scale(workers)
cluster.wait_for_workers(workers)
scale_seconds = time.perf_counter() - scale_start

print(f"Scaling took {scale_seconds} seconds")

In [None]:
# Run the full workload (one index_fastq call per slice) and time it end to
# end. Swap in the commented filter_dedup_fastq line for the dedup experiment.
t0 = time.perf_counter()
# wl = db.from_sequence(data_slices).map(filter_dedup_fastq)
wl = db.from_sequence(data_slices).map(index_fastq)
fut = client.compute(wl)
try:
    results = fut.result()
except Exception:
    # Best effort: keep going so the elapsed time is still reported, but show
    # the full traceback (print(e) alone can be empty or uninformative) and
    # state explicitly that the results about to be saved are empty.
    results = []
    traceback.print_exc()
    print("Workload failed; `results` is empty.")
t1 = time.perf_counter()
print(f"Execution took {t1 - t0} seconds")

In [None]:
# Persist the per-task metrics; the file name encodes workload, number of
# batches ("b") and number of workers ("w"). Use the commented name for the
# dedup workload.
# experiment_name = f"fastq-dedup-{len(data_slices)}b-{workers}w"
experiment_name = f"fastq2fasta-{len(data_slices)}b-{workers}w"
results_path = f"{experiment_name}.json"
with open(results_path, "w") as results_file:
    serialized = json.dumps(results, indent=4)
    results_file.write(serialized)

In [None]:
# Shut down the EC2 cluster created above.
cluster.close()