In [49]:
import warnings

import kfp
import kfp.dsl as dsl
import kfp.components as comp
from kfp.components import InputPath, OutputPath

# Suppress every Python warning for the rest of the kernel session to keep cell output short.
warnings.filterwarnings("ignore")


## 2. Set configurations

In [50]:
import os
import matplotlib

# matplotlib.use("agg")
import matplotlib.pyplot as plt

region_name = "Hawaii_201801_202206"
dir_name = region_name
# exist_ok makes the cell idempotent on re-run and avoids the exists()/mkdir() race.
os.makedirs(dir_name, exist_ok=True)


def root_dir(x):
    """Return the path of `x` inside the local output directory `dir_name`."""
    return os.path.join(dir_name, x)


# True: call the step functions in-process; False: compile and submit the pipeline to Kubeflow.
run_local = False


In [51]:
def set_config(
    index_json: OutputPath("json"),
    config_json: OutputPath("json"),
    datetime_json: OutputPath("json"),
    num_parallel: int = 1,
) -> list:
    """Write the run configuration and split the daily time windows across parallel workers.

    Runs as a KFP lightweight component, so every import lives inside the body.

    Args:
        index_json: Output path. Receives a JSON list of ``num_parallel`` lists; list ``k``
            holds the day indices (into ``starttimes``) assigned to worker ``k``, round-robin.
        config_json: Output path. Receives region name, center, lon/lat limits in degrees,
            ``degree2km``, ISO start/end times, networks, channels and FDSN client name.
        datetime_json: Output path. Receives ``{"starttimes": [...], "interval": seconds}``
            with one ISO start time per day in ``[starttime, endtime)``.
        num_parallel: Number of worker groups; ``0`` means ``min(30, number of days)``.

    Returns:
        ``list(range(num_parallel))``, the worker ids fanned out by the pipeline.

    Raises:
        ValueError: If ``num_parallel`` is negative.
    """
    import datetime
    import json
    import math

    import obspy

    if num_parallel < 0:
        raise ValueError(f"num_parallel must be >= 0, got {num_parallel}")

    # Arc length of one degree on a sphere of radius 6371 km (exact pi, not a truncated literal).
    degree2km = math.pi * 6371 / 180

    region_name = "Hawaii"
    center = (-155.32, 19.39)  # (longitude, latitude)
    horizontal_degree = 2.0
    vertical_degree = 2.0
    # starttime = obspy.UTCDateTime("2018-01-01T00")
    starttime = obspy.UTCDateTime("2021-01-01T00")
    endtime = obspy.UTCDateTime("2022-05-18T00")
    client = "IRIS"
    network_list = ["HV", "PT"]
    channel_list = "HH*,BH*,EH*,HN*"

    ####### save config ########
    config = {
        "region": region_name,
        "center": center,
        "xlim_degree": [center[0] - horizontal_degree / 2, center[0] + horizontal_degree / 2],
        "ylim_degree": [center[1] - vertical_degree / 2, center[1] + vertical_degree / 2],
        "degree2km": degree2km,
        "starttime": starttime.datetime.isoformat(),
        "endtime": endtime.datetime.isoformat(),
        "networks": network_list,
        "channels": channel_list,
        "client": client,
    }
    with open(config_json, "w") as fp:
        json.dump(config, fp)

    ####### one time window per day ########
    one_day = datetime.timedelta(days=1)
    starttimes = []
    tmp_start = starttime
    while tmp_start < endtime:
        starttimes.append(tmp_start.datetime.isoformat())
        tmp_start += one_day

    with open(datetime_json, "w") as fp:
        json.dump({"starttimes": starttimes, "interval": one_day.total_seconds()}, fp)

    if num_parallel == 0:
        num_parallel = min(30, len(starttimes))

    # Round-robin: worker k gets days k, k + num_parallel, k + 2 * num_parallel, ...
    idx = [list(range(k, len(starttimes), num_parallel)) for k in range(num_parallel)]
    with open(index_json, "w") as fp:
        json.dump(idx, fp)

    return list(range(num_parallel))


In [52]:
if run_local:
    # Local dry run with a single parallel group; outputs are written under `dir_name`.
    idx = set_config(
        index_json=root_dir("index.json"),
        config_json=root_dir("config.json"),
        datetime_json=root_dir("datetimes.json"),
        num_parallel=1,
    )

In [53]:
# Package `set_config` as a lightweight KFP component: its source runs on a stock
# python:3.8 image after the listed packages are installed.
config_op = comp.func_to_container_op(
    set_config,
    base_image="python:3.8",
    packages_to_install=["numpy", "obspy"],
)


## 3. Download events in the routine catalog

This catalog is not used by QuakeFlow itself; it is only used to compare against QuakeFlow's detection results.

In [54]:
def download_events(config_json: InputPath("json"), event_csv: OutputPath(str)):
    """Download the routine catalog for the configured region/time window and save it as TSV.

    The catalog is not used by QuakeFlow itself; it is kept for comparing detection results.
    Events lacking an origin or a magnitude are skipped; otherwise the first origin and the
    first magnitude of each event are used.

    Args:
        config_json: Path to the JSON configuration written by ``set_config``.
        event_csv: Output path. Tab-separated table with columns
            ``time, magnitude, longitude, latitude, depth(m)``, sorted by time.
    """
    import json

    import matplotlib

    #     matplotlib.use("agg")
    import matplotlib.pyplot as plt
    import pandas as pd
    from obspy.clients.fdsn import Client

    with open(config_json, "r") as fp:
        config = json.load(fp)

    # Same spatial/temporal filter for the configured client and for the IRIS fallback.
    query = {
        "starttime": config["starttime"],
        "endtime": config["endtime"],
        "minlongitude": config["xlim_degree"][0],
        "maxlongitude": config["xlim_degree"][1],
        "minlatitude": config["ylim_degree"][0],
        "maxlatitude": config["ylim_degree"][1],
    }

    ####### IRIS catalog ########
    try:
        events = Client(config["client"]).get_events(**query)
    except Exception as err:
        # Best-effort fallback, but report why instead of silently swallowing the error.
        print(f"Catalog request to {config['client']} failed ({err}); falling back to IRIS")
        events = Client("iris").get_events(**query)
    print(f"Number of events: {len(events)}")

    ####### Save catalog ########
    columns = ["time", "magnitude", "longitude", "latitude", "depth(m)"]
    records = [
        {
            "time": event.origins[0].time.datetime,
            "magnitude": event.magnitudes[0].mag,
            "longitude": event.origins[0].longitude,
            "latitude": event.origins[0].latitude,
            "depth(m)": event.origins[0].depth,
        }
        for event in events
        if event.magnitudes and event.origins
    ]
    # Passing `columns` keeps the schema (and sort_values) valid even when no event qualifies.
    catalog = pd.DataFrame(records, columns=columns).sort_values(["time"])
    catalog.to_csv(
        event_csv,
        sep="\t",
        index=False,
        float_format="%.3f",
        date_format='%Y-%m-%dT%H:%M:%S.%f',
        columns=columns,
    )

    ####### Plot catalog ########
    plt.figure()
    plt.plot(catalog["longitude"], catalog["latitude"], '.', markersize=1)
    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.axis("scaled")
    plt.xlim(config["xlim_degree"])
    plt.ylim(config["ylim_degree"])
    plt.show()

    plt.figure()
    # plt.plot handles datetime x-values directly; plt.plot_date is deprecated in matplotlib.
    plt.plot(catalog["time"], catalog["magnitude"], '.', markersize=1)
    plt.gcf().autofmt_xdate()
    plt.ylabel("Magnitude")
    # Count the events actually saved (those with an origin and a magnitude).
    plt.title(f"Number of events: {len(catalog)}")
    plt.savefig("events_mag_time.png")
    plt.show()


In [55]:
if run_local:
    # Fetch the comparison catalog using the config written by `set_config`.
    download_events(
        config_json=root_dir("config.json"),
        event_csv=root_dir("events.csv"),
    )


In [56]:
# Package `download_events` as a lightweight KFP component on a stock python:3.8 image
# with its runtime dependencies installed at start-up.
download_events_op = comp.func_to_container_op(
    download_events,
    base_image="python:3.8",
    packages_to_install=["obspy", "pandas", "matplotlib"],
)


## 4. Download stations

In [57]:
def download_stations(
    config_json: InputPath("json"), station_csv: OutputPath(str), station_pkl: OutputPath("pickle"),
):
    """Download the station inventory (with responses) for the configured region and save it.

    Channels whose codes differ only in the last letter (the component) are merged into one
    row keyed ``NET.STA.LOC.CH``, with comma-separated component letters and sensitivities.
    Channels without an instrument sensitivity are skipped with a message.

    Args:
        config_json: Path to the JSON configuration written by ``set_config``.
        station_csv: Output path. Tab-separated table indexed by ``station`` with columns
            ``longitude, latitude, elevation(m), unit, component, response``.
        station_pkl: Output path. Pickled obspy inventory as returned by ``get_stations``.
    """
    import json
    import pickle

    import matplotlib

    #     matplotlib.use("agg")
    import matplotlib.pyplot as plt
    import pandas as pd
    from obspy.clients.fdsn import Client

    with open(config_json, "r") as fp:
        config = json.load(fp)

    print("Network:", ",".join(config["networks"]))
    ####### Download stations ########
    stations = Client(config["client"]).get_stations(
        network=",".join(config["networks"]),
        station="*",
        starttime=config["starttime"],
        endtime=config["endtime"],
        minlongitude=config["xlim_degree"][0],
        maxlongitude=config["xlim_degree"][1],
        minlatitude=config["ylim_degree"][0],
        maxlatitude=config["ylim_degree"][1],
        channel=config["channels"],
        level="response",
    )
    print("Number of stations: {}".format(sum(len(network) for network in stations)))

    ####### Save stations ########
    columns = ["longitude", "latitude", "elevation(m)", "unit", "component", "response"]
    station_locs = {}
    for network in stations:
        for station in network:
            for chn in station:
                sensitivity = chn.response.instrument_sensitivity if chn.response is not None else None
                if sensitivity is None:
                    # Previously an AttributeError that aborted the whole step.
                    print(f"Skip {network.code}.{station.code}.{chn.location_code}.{chn.code}: no instrument sensitivity")
                    continue
                sid = f"{network.code}.{station.code}.{chn.location_code}.{chn.code[:-1]}"
                if sid in station_locs:
                    station_locs[sid]["component"] += f",{chn.code[-1]}"
                    station_locs[sid]["response"] += f",{sensitivity.value:.2f}"
                else:
                    station_locs[sid] = {
                        "longitude": chn.longitude,
                        "latitude": chn.latitude,
                        "elevation(m)": chn.elevation,
                        "component": chn.code[-1],
                        "response": f"{sensitivity.value:.2f}",
                        # Unit of the first channel seen for this sid.
                        "unit": sensitivity.input_units.lower(),
                    }

    # reindex guarantees the expected columns even when no channel survived, so to_csv/plot work.
    station_locs = pd.DataFrame.from_dict(station_locs, orient="index").reindex(columns=columns)
    station_locs.to_csv(
        station_csv,
        sep="\t",
        float_format="%.3f",
        index_label="station",
        columns=columns,
    )

    # Trusted pipeline artifact, consumed by download_waveform.
    with open(station_pkl, "wb") as fp:
        pickle.dump(stations, fp)

    ####### Plot stations ########
    plt.figure()
    plt.plot(station_locs["longitude"], station_locs["latitude"], "^", label="Stations")
    # Coordinates are geographic degrees, not km.
    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.axis("scaled")
    plt.xlim(config["xlim_degree"])
    plt.ylim(config["ylim_degree"])
    plt.legend()
    plt.title(f"Number of stations: {len(station_locs)}")
    plt.show()


In [58]:
if run_local:
    # Fetch the station inventory using the config written by `set_config`.
    download_stations(
        config_json=root_dir("config.json"),
        station_csv=root_dir("stations.csv"),
        station_pkl=root_dir("stations.pkl"),
    )


In [59]:
# Package `download_stations` as a lightweight KFP component on a stock python:3.8 image
# with its runtime dependencies installed at start-up.
download_stations_op = comp.func_to_container_op(
    download_stations,
    base_image="python:3.8",
    packages_to_install=["obspy", "pandas", "matplotlib"],
)


## 5. Download waveform data

In [60]:
def download_waveform(
    i: int,
    index_json: InputPath("json"),
    config_json: InputPath("json"),
    datetime_json: InputPath("json"),
    station_pkl: InputPath("pickle"),
    fname_csv: OutputPath(str),
    data_path: str,
    bucket_name: str = "waveforms",
    s3_url: str = "minio-service:9000",
    secure: bool = True,
) -> str:
    """Download the daily waveforms of one parallel group and push them to the remote host.

    For every day assigned to group ``i``, each station in the inventory is requested from the
    FDSN client. The data is merged and zero-padded to the day window, interpolated to 100 Hz
    when needed, and written as one miniSEED file per trace. Each file is rsync'ed to
    ``zhuwq@wintermute:/scratch/zhuwq/Hawaii/wf/<YYYY>/<DDD>/`` and removed locally once the
    upload succeeds. Days are processed in batches of up to ``MAX_THREADS`` threads.

    Runs as a KFP lightweight component, so all imports and helpers live inside the body.

    Args:
        i: Index of the parallel group (an element of ``set_config``'s return value).
        index_json: Day-index groups written by ``set_config``.
        config_json: Configuration written by ``set_config``.
        datetime_json: Daily start times and window length written by ``set_config``.
        station_pkl: Pickled obspy inventory written by ``download_stations``. It is
            unpickled, so it must come from that trusted pipeline step only.
        fname_csv: Output path. Newline-separated list headed by ``fname`` with one
            ``YYYY/DDD`` folder per finished day, in completion order.
        data_path: Local scratch root; files are staged under ``<data_path>/<region>/waveforms``.
        bucket_name: Currently unused; kept for interface compatibility.
        s3_url: Currently unused; kept for interface compatibility.
        secure: Currently unused; kept for interface compatibility.

    Returns:
        The local waveform staging directory.
    """
    import json
    import os
    import pickle
    import subprocess
    import threading
    import time

    import obspy
    from obspy.clients.fdsn import Client

    MAX_THREADS = 4
    max_retry = 10
    retry_wait = 5  # seconds between failed attempts
    remote_host = "zhuwq@wintermute"
    remote_root = "/scratch/zhuwq/Hawaii/wf"
    no_data_message = "No data available for request."

    with open(index_json, "r") as fp:
        idx = json.load(fp)[i]
    with open(config_json, "r") as fp:
        config = json.load(fp)
    with open(datetime_json, "r") as fp:
        tmp = json.load(fp)
    starttimes = tmp["starttimes"]
    interval = tmp["interval"]
    with open(station_pkl, "rb") as fp:
        stations = pickle.load(fp)

    waveform_dir = os.path.join(data_path, config["region"], "waveforms")
    os.makedirs(waveform_dir, exist_ok=True)

    ####### Download data ########
    client = Client(config["client"])
    fname_list = ["fname"]
    lock = threading.Lock()

    def run_with_retry(cmd):
        """Run argv list ``cmd`` up to ``max_retry`` times; True as soon as it exits with 0.

        Sleeps only between failed attempts (never after a success), and uses no shell so
        paths are passed verbatim.
        """
        for attempt in range(max_retry):
            try:
                if subprocess.run(cmd).returncode == 0:
                    return True
            except OSError as err:  # e.g. the ssh/rsync binary is missing
                print(f"Error running {' '.join(cmd)}: {err}")
            if attempt < max_retry - 1:
                time.sleep(retry_wait)
        return False

    def download(day_index):
        """Download, upload and clean up every station's traces for one day window."""
        starttime = obspy.UTCDateTime(starttimes[day_index])
        endtime = starttime + interval
        folder = starttime.datetime.strftime("%Y/%j")
        local_folder = os.path.join(waveform_dir, folder)
        os.makedirs(local_folder, exist_ok=True)

        mkdir_cmd = ["ssh", remote_host, "mkdir", "-p", f"{remote_root}/{folder}/"]
        if not run_with_retry(mkdir_cmd):
            print(f"Failed: {' '.join(mkdir_cmd)}")

        print(f"{folder} download starts")
        for network in stations:
            for station in network:
                print(f"********{network.code}.{station.code}********")
                retry = 0
                while retry < max_retry:
                    try:
                        stream = client.get_waveforms(
                            network.code, station.code, "*", config["channels"], starttime, endtime,
                        )
                        if len(stream) > 0:
                            stream = stream.merge(fill_value=0)
                            stream = stream.trim(starttime, endtime, pad=True, fill_value=0)

                            for trace in stream:
                                if trace.stats.sampling_rate != 100:
                                    trace = trace.interpolate(100, method="linear")
                                trace_name = f"{trace.stats.network}.{trace.stats.station}.{trace.stats.channel}.mseed"
                                local_file = os.path.join(local_folder, trace_name)
                                trace.write(local_file, format="mseed")

                                rsync_cmd = [
                                    "rsync",
                                    "-av",
                                    local_file,
                                    f"{remote_host}:{remote_root}/{folder}/{trace_name}",
                                ]
                                if not run_with_retry(rsync_cmd):
                                    print(f"Failed: {' '.join(rsync_cmd)}")
                                    continue
                                # Uploaded: free local disk. A failure here must not trigger a
                                # re-download of the whole station, so it is only reported.
                                try:
                                    os.remove(local_file)
                                except OSError as err:
                                    print(f"Failed to remove {local_file}: {err}")
                        break
                    except Exception as err:
                        print("Error {}.{}: {}".format(network.code, station.code, err))
                        # The service reports an empty request this way; retrying cannot help.
                        if str(err).startswith(no_data_message):
                            break
                        retry += 1
                        time.sleep(retry_wait)

                if retry == max_retry:
                    print(f"{folder}: MAX {max_retry} retries reached : {network.code}.{station.code}")

        print(f"{folder} download succeeds")
        with lock:
            fname_list.append(folder)

    # Start days one second apart and join them in batches of MAX_THREADS.
    threads = []
    for n_started, day_index in enumerate(idx, start=1):
        thread = threading.Thread(target=download, args=(day_index,))
        thread.start()
        time.sleep(1)
        threads.append(thread)
        if n_started % MAX_THREADS == 0:
            for running in threads:
                running.join()
            threads = []
    for running in threads:
        running.join()

    with open(fname_csv, "w") as fp:
        fp.write("\n".join(fname_list))

    return waveform_dir


In [61]:
if run_local:
    # Process parallel group 0 in-process, staging files under `dir_name`.
    waveform_path = download_waveform(
        i=0,
        index_json=root_dir("index.json"),
        config_json=root_dir("config.json"),
        datetime_json=root_dir("datetimes.json"),
        station_pkl=root_dir("stations.pkl"),
        fname_csv=root_dir("fname.csv"),
        data_path=root_dir(""),
    )


In [62]:
# Package `download_waveform` on a prebuilt image with no start-up installs. The image is
# assumed to already provide obspy plus the ssh/rsync binaries the step calls — TODO confirm.
download_waveform_op = comp.func_to_container_op(
    download_waveform,
    base_image="zhuwq0/waveform-env:1.1",
)


In [63]:
@dsl.pipeline(name='QuakeFlow', description='')
def quakeflow_pipeline(
    data_path: str = "/tmp/",
    num_parallel=0,
    bucket_catalog: str = "catalogs",
    s3_url: str = "minio-service:9000",
    secure: bool = False,
):
    # DAG: set_config -> {catalog download, station download} -> one waveform download per
    # parallel group returned by set_config. `bucket_catalog` is accepted but not used yet.
    config = config_op(num_parallel)
    config_json = config.outputs["config_json"]

    # The routine catalog is only for later comparison; no downstream step consumes it.
    download_events_op(config_json).set_display_name('Download Events')
    stations = download_stations_op(config_json).set_display_name('Download Stations')

    with dsl.ParallelFor(config.outputs["output"]) as i:
        download_op_ = download_waveform_op(
            i,
            config.outputs["index_json"],
            config_json,
            config.outputs["datetime_json"],
            stations.outputs["station_pkl"],
            data_path=data_path,
            bucket_name="waveforms",
            s3_url=s3_url,
            secure=secure,
        )
        download_op_.set_cpu_request("800m")
        download_op_.set_retry(3)
        download_op_.set_display_name('Download Waveforms')
        # Reuse cached results for up to 30 days (ISO-8601 duration).
        download_op_.execution_options.caching_strategy.max_cache_staleness = "P30D"
        download_op_.set_image_pull_policy("Always")


In [64]:
import os

# Credentials for the hosted Kubeflow Pipelines endpoint. setdefault keeps a value that is
# already exported in the environment instead of overwriting it with this machine-local path.
os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", "/home/weiqiang/.dotbot/cloud/quakeflow_zhuwq.json")
experiment_name = 'QuakeFlow'
pipeline_func = quakeflow_pipeline
run_name = pipeline_func.__name__ + '_run'

# Runtime parameters for `quakeflow_pipeline`; num_parallel=0 lets set_config pick the fan-out.
arguments = {
    "data_path": "/tmp",
    "num_parallel": 0,
    "bucket_catalog": "catalogs",
    "s3_url": "minio-service:9000",
    "secure": False,
}

if not run_local:
    pipeline_conf = kfp.dsl.PipelineConf()
    pipeline_conf.set_image_pull_policy("Always")
    pipeline_conf.ttl_seconds_after_finished = 60 * 10  # 10 minutes
    # KFP_HOST overrides the hosted endpoint, e.g. "http://localhost:8080" for a port-forwarded cluster.
    kfp_host = os.environ.get("KFP_HOST", "7f176775ae43f263-dot-us-west1.pipelines.googleusercontent.com")
    client = kfp.Client(host=kfp_host)
    # Keep a compiled copy of the pipeline package on disk.
    kfp.compiler.Compiler().compile(pipeline_func, f"{experiment_name}.zip", pipeline_conf=pipeline_conf)
    # The run is compiled again from pipeline_func, so pipeline_conf must be passed here too;
    # otherwise the submitted run silently drops the image-pull policy and TTL set above.
    results = client.create_run_from_pipeline_func(
        pipeline_func,
        experiment_name=experiment_name,
        run_name=run_name,
        arguments=arguments,
        pipeline_conf=pipeline_conf,
    )
