In [1]:
import kfp
import kfp.dsl as dsl
import kfp.components as comp
from kfp.components import InputPath, OutputPath
import json

## Download Catalog

In [2]:
def download_catalog(catalog_path: OutputPath(str)):
    """Build the NCEDC P/S phase catalog and write it to ``catalog_path``.

    Steps:

    1. If the blob ``ncedc_catalogs.txt`` exists in the ``quakeflow`` bucket it
       is downloaded to ``catalogs.txt``. Otherwise every yearly ``*.phase.Z``
       file listed under ``root_url`` is downloaded, uncompressed, merged into
       ``catalogs.txt`` and uploaded to the bucket under that blob name.
    2. The merged catalog is grouped by event line (lines longer than 130
       characters) and by station id (``line[:7] + line[111:113]``).
    3. For each station the P line and the S line with the smallest weight
       digit (column 16 for P, column 49 for S) are kept; stations lacking
       either phase are dropped.
    4. The result is written to ``catalogs_ps.txt`` and copied to
       ``catalog_path``.

    Imports and helpers are deliberately local to the function body: the
    function is handed to ``comp.func_to_container_op`` in this notebook.

    Args:
        catalog_path: Output file path that receives the reduced catalog.

    Raises:
        Exception: the last storage error when an upload/download still fails
            after ``max_attempts`` tries.
    """
    import collections
    import os
    import re
    import shutil
    import subprocess
    import time
    import urllib.request as request
    from google.cloud import storage

    storage_client = storage.Client()
    bucket = storage_client.bucket("quakeflow")
    max_attempts = 10  # bounded so a permanent failure cannot spin forever

    def upload_blob(source_file_name, destination_blob_name):
        """Upload a local file to the bucket, retrying with back-off."""
        for attempt in range(1, max_attempts + 1):
            try:
                blob = bucket.blob(destination_blob_name)
                blob.upload_from_filename(source_file_name, timeout=3600)
                print(f"File {source_file_name} uploaded to {destination_blob_name}")
                return
            except Exception as e:
                print(f"Error: File {source_file_name} uploade failed\n{e}")
                if attempt == max_attempts:
                    raise
                time.sleep(min(2 ** attempt, 60))

    def download_blob(source_file_name, destination_blob_name):
        """Download a blob to a local file, retrying with back-off."""
        for attempt in range(1, max_attempts + 1):
            try:
                blob = bucket.blob(source_file_name)
                blob.download_to_filename(destination_blob_name, timeout=3600)
                print(f"File {source_file_name} download to {destination_blob_name}")
                return
            except Exception as e:
                print(f"Error: File {source_file_name} download failed\n{e}")
                if attempt == max_attempts:
                    raise
                time.sleep(min(2 ** attempt, 60))

    def exist_blob(source_file_name):
        """Return whether the named blob exists in the bucket."""
        stats = storage.Blob(bucket=bucket, name=source_file_name).exists(storage_client)
        if stats:
            print(f"File {source_file_name} exist")
        return stats

    def get_years(root_url):
        """Map each 4-digit year directory linked from ``root_url`` to its URL."""
        html = request.urlopen(root_url).read().decode()
        # The last listed year is dropped, as in the original code
        # (presumably the still-incomplete current year -- TODO confirm).
        years = re.findall(r'<a href="(\d\d\d\d)/">', html)[:-1]
        return {yr: root_url + "/" + yr for yr in years}

    def get_files(year_urls):
        """Map each year to the URLs of its ``*.phase.Z`` files."""
        link_pattern = re.compile(r'<a href=".*?\.phase\.Z">', re.S)
        file_urls = {}
        for year, url in year_urls.items():
            html = request.urlopen(url).read().decode()
            links = re.findall(link_pattern, html)
            files = [re.findall(r"\d.*?\.phase\.Z", link)[0] for link in links]
            file_urls[year] = [url + "/" + fl for fl in files]
        return file_urls

    def download_files(file_urls, root_dir):
        """Download every listed file into ``root_dir/<year>`` and uncompress it."""
        for year, urls in file_urls.items():
            data_dir = os.path.join(root_dir, year)
            os.makedirs(data_dir, exist_ok=True)
            for url in urls:
                print("Downloading: " + url)
                local_path = os.path.join(data_dir, url.split('/')[-1])
                request.urlretrieve(url, local_path)
                # Argument list instead of a shell string: the file name comes
                # from a remote HTML listing and must not be shell-interpreted.
                result = subprocess.run(["uncompress", local_path])
                if result.returncode != 0:
                    print(f"Error: uncompress {local_path} exited with {result.returncode}")

    def merge_files(file_urls, root_dir, fout):
        """Concatenate all uncompressed phase files into ``fout``."""
        num_lines = 0
        with open(fout, 'w') as fp_out:
            for year, urls in file_urls.items():
                for url in urls:
                    name = url.split('/')[-1]
                    # Remove the ".Z" suffix only; str.rstrip(".Z") would strip
                    # any trailing run of '.' and 'Z' characters.
                    if name.endswith(".Z"):
                        name = name[:-len(".Z")]
                    with open(os.path.join(root_dir, year, name), 'r') as fp_in:
                        for line in fp_in:
                            fp_out.write(line)
                            num_lines += 1
        print(f"Finish writing {num_lines} lines to {fout}")

    def build_dict(catalog):
        """Group catalog lines as ``{event_line: {station_id: [phase_lines]}}``."""
        lines_by_event = collections.OrderedDict()
        event_id = None
        with open(catalog) as fp:
            for line in fp:
                if line[0].isspace():
                    continue
                if len(line) > 130:
                    # Lines longer than 130 characters start a new event.
                    event_id = line
                    lines_by_event[event_id] = []
                elif event_id is None:
                    # Phase line with no preceding event line: nothing to attach to.
                    print("Unrecognized line: %s" % line)
                else:
                    lines_by_event[event_id].append(line)

        print("Build dict:")
        by_station = collections.OrderedDict()
        for event, lines in lines_by_event.items():
            stationset = collections.OrderedDict()
            for line in lines:
                sta_id = line[:7] + line[111:113]
                stationset.setdefault(sta_id, []).append(line)
            by_station[event] = stationset
        return by_station

    def extract_ps(catalog):
        """Keep, per station, the lowest-weight P line and lowest-weight S line."""
        print("Extract P/S picks:")
        dataset = collections.OrderedDict()
        for event, stations in catalog.items():
            stationset = collections.OrderedDict()
            for sta_id, lines in stations.items():
                best_p = best_s = 10  # larger than any single-digit weight
                id_p = id_s = None
                for j, line in enumerate(lines):
                    if line[14] == 'P' and int(line[16]) < best_p:
                        best_p = int(line[16])
                        id_p = j
                    if line[47] == 'S' and int(line[49]) < best_s:
                        # Track the S weight (column 49); the previous code
                        # stored column 20 here, which is part of the date.
                        best_s = int(line[49])
                        id_s = j
                if id_p is not None and id_s is not None:
                    stationset[sta_id] = [lines[id_p], lines[id_s]]
            # As before, events without any station line are left out.
            if stations:
                dataset[event] = stationset
        return dataset

    def write_ps(dataset, fname, with_event=False):
        """Write the selected P/S lines, optionally preceded by their event line."""
        with open(fname, 'w') as fp:
            print("Write P/S picks:")
            for event, stations in dataset.items():
                if with_event and len(stations) > 0:
                    fp.write(event)
                for station_lines in stations.values():
                    for line in station_lines:
                        fp.write(line)

    root_dir = "catalogs"
    root_url = "http://ncedc.org/ftp/pub/catalogs/ncss/hypoinverse/phase2k"
    os.makedirs(root_dir, exist_ok=True)

    if exist_blob("ncedc_catalogs.txt"):
        # Reuse the merged catalog cached in the bucket.
        download_blob("ncedc_catalogs.txt", "catalogs.txt")
    else:
        year_urls = get_years(root_url)
        file_urls = get_files(year_urls)
        download_files(file_urls, root_dir)
        merge_files(file_urls, root_dir, "catalogs.txt")
        upload_blob("catalogs.txt", "ncedc_catalogs.txt")

    catalog_dict = build_dict("catalogs.txt")
    dataset_ps = extract_ps(catalog_dict)
    write_ps(dataset_ps, "catalogs_ps.txt", with_event=True)

    shutil.copyfile("catalogs_ps.txt", catalog_path)

In [3]:
# download_catalog(credentials=credentials, "test.txt")

# Package download_catalog as a container component; it needs the GCS client.
download_catalog_op = comp.func_to_container_op(
    download_catalog,
    base_image='python:3.8',
    packages_to_install=["google-cloud-storage"],
)

In [4]:
#%%
def read_ps_catalog(catalog: InputPath(str),
                    index_path: OutputPath("pickle"),
                    events_path: OutputPath("pickle"),
                    phases_path: OutputPath("pickle")) -> list:
    """Split the P/S catalog into events and phases and shard the event indices.

    Lines longer than 130 characters are event lines and are numbered
    0, 1, 2, ... in file order; every other line is appended to the phase list
    of the most recent event. Event indices are then dealt round-robin into
    ``num_parallel`` lists.

    Args:
        catalog: Path of the text catalog to read.
        index_path: Output pickle holding the list of per-shard index lists.
        events_path: Output pickle holding ``{event_index: event_line}``.
        phases_path: Output pickle holding ``{event_index: [phase_line, ...]}``.

    Returns:
        ``list(range(num_parallel))``, i.e. one shard number per index list.
    """
    import pickle
    from collections import OrderedDict

    events = OrderedDict()
    phases = OrderedDict()
    index = -1
    with open(catalog) as fp:
        for line in fp:
            if len(line) > 130:
                index += 1
                events[index] = line
                phases[index] = []
            elif index < 0:
                # Phase line before any event line: no event to attach it to.
                print("Unrecognized line: %s" % line)
            else:
                phases[index].append(line)

    with open(events_path, "wb") as fp:
        pickle.dump(events, fp)

    with open(phases_path, "wb") as fp:
        pickle.dump(phases, fp)

    # `index` is the index of the LAST event, so the event count is index + 1.
    # Iterating range(index) silently dropped the final event from every shard.
    num_events = index + 1
    num_parallel = 48*4 - 3
    idxs = [[] for _ in range(num_parallel)]
    for event_index in range(num_events):
        idxs[event_index % num_parallel].append(event_index)

    with open(index_path, "wb") as fp:
        pickle.dump(idxs, fp)

    print(f"Events number: {num_events}")
    print(f"Parallel number: {num_parallel}")

    return list(range(num_parallel))

In [5]:
# read_ps_catalog("catalogs_ps.txt", "events.pkl", "phases.pkl")

# Package read_ps_catalog as a container component; no extra packages are installed.
read_ps_catalog_op = comp.func_to_container_op(
    read_ps_catalog,
    base_image='python:3.8',
    packages_to_install=[],
)

In [6]:
def build_dataset(i: int,
                  index_input: InputPath("pickle"),
                  events_input: InputPath("pickle"),
                  phases_input: InputPath("pickle"),
                  events_path: OutputPath(str),
                  phases_path: OutputPath(str)):
    """Build waveform samples for shard ``i`` of the P/S catalog.

    For every event index in shard ``i`` (processed in reverse order) and every
    P/S line pair of that event, station metadata and a 120 s waveform window
    centred on the P pick are obtained -- from the ``quakeflow`` bucket when
    already cached there, otherwise from the NCEDC FDSN service -- converted to
    a 12000 x 3 array, saved as ``.npz`` and uploaded. Per-shard
    ``<index>_events.csv`` / ``<index>_phases.csv`` tables are uploaded under
    ``catalogs/`` and copied to the output paths.

    Imports and helpers are deliberately local to the function body: the
    function is handed to ``comp.func_to_container_op`` in this notebook.

    Args:
        i: Shard number; selects ``pickle.load(index_input)[i]``.
        index_input: Pickle with the list of per-shard event-index lists.
        events_input: Pickle mapping event index to event line.
        phases_input: Pickle mapping event index to its list of phase lines.
        events_path: Output file receiving the events table (tab separated);
            left empty when the shard produced no sample.
        phases_path: Output file receiving the phases table (tab separated);
            left empty when the shard produced no sample.

    Raises:
        Exception: the last storage error when an upload/download still fails
            after ``max_attempts`` tries.
    """
    import os
    import pickle
    import shutil
    import time
    from collections import namedtuple

    import numpy as np
    import pandas as pd
    import obspy
    from obspy.clients.fdsn import Client
    from google.cloud import storage

    # NOTE(review): these pickles are assumed to come from the upstream
    # pipeline step; pickle must not be used on untrusted input.
    with open(index_input, "rb") as fp:
        index = pickle.load(fp)[i][::-1]

    if len(index) == 0:
        print("len(index) = 0")
        for output_path in (events_path, phases_path):
            with open(output_path, "w") as fout:
                fout.write("")
        return

    client = Client("NCEDC")
    storage_client = storage.Client()
    bucket = storage_client.bucket("quakeflow")
    max_attempts = 10  # bounded so a permanent failure cannot spin forever

    def upload_blob(source_file_name, destination_blob_name):
        """Upload a local file to the bucket, retrying with back-off."""
        for attempt in range(1, max_attempts + 1):
            try:
                blob = bucket.blob(destination_blob_name)
                blob.upload_from_filename(source_file_name, timeout=3600)
                print(f"File {source_file_name} uploaded to {destination_blob_name}")
                return
            except Exception as e:
                print(f"Error: File {source_file_name} uploade failed\n{e}")
                if attempt == max_attempts:
                    raise
                time.sleep(min(2 ** attempt, 60))

    def download_blob(source_file_name, destination_blob_name):
        """Download a blob to a local file, retrying with back-off."""
        for attempt in range(1, max_attempts + 1):
            try:
                blob = bucket.blob(source_file_name)
                blob.download_to_filename(destination_blob_name, timeout=3600)
                print(f"File {source_file_name} download to {destination_blob_name}")
                return
            except Exception as e:
                print(f"Error: File {source_file_name} download failed\n{e}")
                if attempt == max_attempts:
                    raise
                time.sleep(min(2 ** attempt, 60))

    def exist_blob(source_file_name):
        """Return whether the named blob exists in the bucket."""
        stats = storage.Blob(bucket=bucket, name=source_file_name).exists(storage_client)
        if stats:
            print(f"File {source_file_name} exist")
        return stats

    def to_float(string):
        """Parse a fixed-width numeric field; a blank field counts as 0."""
        stripped = string.strip()
        return float(stripped) if stripped != '' else 0

    Event = namedtuple("event", ["time", "latitude", "longitude",
                                 "depth_km", "magnitude", "magnitude_type", "index"])

    def read_event_line(line, idx):
        """Parse one fixed-width event line into an ``Event``."""
        time_str = (line[:4]+"-"+line[4:6]+"-"+line[6:8]+"T"
                    +line[8:10]+":"+line[10:12]+":"+line[12:14]+"."+line[14:16])
        # Degrees plus hundredths of minutes (value / 100 / 60).
        latitude = to_float(line[16:18]) + to_float(line[19:23])/6000.0
        # Longitude is negated by default; an 'E' flag flips it back to positive.
        longitude = -(to_float(line[23:26]) + to_float(line[27:31])/6000.0)
        if line[18] == 'S':
            latitude *= -1.0
        if line[26] == 'E':
            longitude *= -1.0
        depth_km = to_float(line[31:36])/100.0
        magnitude_type = (line[146:147]).strip()
        magnitude = to_float(line[147:150])/100.0
        return Event(time=time_str, latitude=np.round(latitude, 4), longitude=np.round(longitude, 4),
                     depth_km=depth_km, magnitude=magnitude, magnitude_type=magnitude_type, index=idx)

    def parse_pick_time(line, seconds, hundredths):
        """Combine the line's date/hour/minute with a seconds field.

        Seconds values of 60 or more cannot be written into an ISO time
        string, so they are added as an offset to the start of the minute.
        """
        minute_prefix = (line[17:21]+"-"+line[21:23]+"-"
                         +line[23:25]+"T"+line[25:27]+":"
                         +line[27:29]+":")
        seconds_value = float(seconds.replace(" ", "0"))
        if seconds_value < 60:
            return obspy.UTCDateTime((minute_prefix+seconds+'.'+hundredths).replace(" ", "0"))
        return obspy.UTCDateTime((minute_prefix+'00'+'.'+hundredths).replace(" ", "0")) + seconds_value

    def read_p_pick(line):
        """Return (time, remark, weight, channel) of the P pick on ``line``."""
        tp = parse_pick_time(line, line[30:32], line[32:34])
        return tp, line[13:15].strip(), line[16:17].strip(), line[9:12].strip()

    def read_s_pick(line):
        """Return (time, remark, weight, channel) of the S pick on ``line``."""
        ts = parse_pick_time(line, line[42:44], line[44:46])
        return ts, line[46:48].strip(), line[49:50].strip(), line[9:12].strip()

    Pick = namedtuple("pick",["p_time", "p_remark", "p_weight", "p_channel",
                              "s_time", "s_remark", "s_weight", "s_channel",
                              "first_motion", "distance_km",
                              "emergence_angle", "azimuth",
                              "network", "station", "location_code", "event_index"])

    def read_phase_line(p_line, s_line, index):
        """Build a ``Pick`` from a P line and an S line of the same station.

        Station-level fields are read from the P line.
        """
        line = p_line
        p_time, p_remark, p_weight, p_channel = read_p_pick(p_line)
        s_time, s_remark, s_weight, s_channel = read_s_pick(s_line)
        return Pick(p_time=p_time, p_remark=p_remark, p_weight=p_weight, p_channel=p_channel,
                    s_time=s_time, s_remark=s_remark, s_weight=s_weight, s_channel=s_channel,
                    first_motion=(line[15:16]).strip(),
                    distance_km=to_float(line[74:78])/10.0,
                    emergence_angle=to_float(line[78:81]),
                    azimuth=to_float(line[91:94]),
                    network=(line[5:7]).strip(), station=(line[:5]).strip(),
                    location_code=(line[111:113]).strip(), event_index=index)

    def resample(stream, default_sampling_rate):
        """Bring ``stream`` to ``default_sampling_rate`` (rate read from its last trace)."""
        sampling_rate = stream[-1].stats.sampling_rate
        if sampling_rate < default_sampling_rate:
            print("Resample %s" % stream)
            stream.resample(default_sampling_rate)
            print("After resample: %s" % stream)
        elif sampling_rate > default_sampling_rate:
            print("Resample %s" % stream)
            if np.mod(sampling_rate, default_sampling_rate) == 0:
                # Integer ratio: plain decimation, samples are dropped unfiltered.
                stream.decimate(int(sampling_rate//default_sampling_rate), strict_length=False, no_filter=True)
            else:
                stream.resample(default_sampling_rate)
            print("After resample: %s" % stream)
        return stream

    Station = namedtuple("station", ["id", "latitude", "longitude", "elevation_m", "unit"])

    def download_waveform(pick, waveform_path, station_path):
        """Collect (stream, station) pairs for every channel family of ``pick``.

        Returns two parallel lists: streams and ``Station`` tuples.
        """
        Tstart = pick.p_time - 60.0
        Tend = pick.p_time + 60.0

        # Channel families of the picks plus a fixed set; sorted so the
        # processing order does not depend on set iteration order.
        channels = sorted({pick.p_channel[:-1], pick.s_channel[:-1],
                           "HN", "HH", "EH", "BH", "DP"})

        stream_list = []
        station_list = []
        for channel in channels:
            fname = pick.network+"."+pick.station+"."+pick.location_code+"."+channel+"."+f"{pick.event_index:07d}"
            station_file = os.path.join(station_path, fname+".xml")
            waveform_file = os.path.join(waveform_path, fname+".mseed")

            station = None
            station_upload = False
            if exist_blob(station_file):
                download_blob(station_file, station_file)
                station = obspy.read_inventory(station_file, format="STATIONXML")
                os.remove(station_file)
            else:
                try:
                    station = client.get_stations(network=pick.network, station=pick.station, location=pick.location_code,
                                                  channel=channel+"?", starttime=Tstart, endtime=Tend, level="response")
                    station_upload = True
                except Exception as e:
                    if not str(e).startswith("No data available"):
                        print("Failed downloading station: "+fname, "Error: "+str(e))

            stream = None
            waveform_upload = False
            if exist_blob(waveform_file):
                download_blob(waveform_file, waveform_file)
                stream = obspy.read(waveform_file, format="MSEED")
                os.remove(waveform_file)
            else:
                try:
                    fetched = client.get_waveforms(network=pick.network, station=pick.station, location=pick.location_code,
                                                   channel=channel+"?", starttime=Tstart, endtime=Tend)
                    fetched = fetched.detrend('linear')
                    fetched = fetched.merge(fill_value=0)
                    fetched = fetched.trim(Tstart, Tend, pad=True, fill_value=0)
                    fetched = resample(fetched, 100)  # resample to 100 Hz
                    stream = fetched.sort()
                    waveform_upload = True
                except Exception as e:
                    if not str(e).startswith("No data available"):
                        print("Failed downloading waveform: "+fname, "Error: "+str(e))

            if station is None or stream is None:
                continue

            try:
                stream.attach_response(station)
                if waveform_upload:
                    # Cached waveforms were written to the bucket AFTER this
                    # correction, so applying it again on a cache hit would
                    # divide by the sensitivity twice.
                    stream.remove_sensitivity()
                coord = station.get_coordinates(stream[0].get_id(), datetime=Tstart)
                response = station.get_response(stream[0].get_id(), datetime=Tstart)
            except Exception as e:
                trace_id = stream[0].get_id() if len(stream) > 0 else fname
                print("Error: ", trace_id, station, e)
                continue

            sta_id = pick.network+"."+pick.station+"."+pick.location_code+"."+channel
            unit = response.instrument_sensitivity.input_units.lower()
            station_list.append(Station(id=sta_id,
                                        latitude=np.round(coord["latitude"], 4),
                                        longitude=np.round(coord["longitude"], 4),
                                        elevation_m=coord["elevation"], unit=unit))
            stream_list.append(stream)

            # Cache freshly downloaded data in the bucket for later runs.
            if waveform_upload:
                stream.write(waveform_file, format="MSEED")
                upload_blob(waveform_file, waveform_file)
                os.remove(waveform_file)
            if station_upload:
                station.write(station_file, format="STATIONXML")
                upload_blob(station_file, station_file)
                os.remove(station_file)

        return stream_list, station_list

    def calc_snr(vec, anchor, dt):
        """Per-column ratio of the std 3 s after ``anchor`` to the std 3 s before."""
        npts = int(3/dt)
        eps = 10*np.finfo(vec.dtype).eps  # keeps the ratio finite on flat traces
        signal = np.std(vec[anchor:anchor+npts, :], axis=0) + eps
        noise = np.std(vec[anchor-npts:anchor, :], axis=0) + eps
        return signal / noise

    def data2vec(i, data, vec, shift, window_size):
        """Copy ``data`` into column ``i`` of ``vec`` offset by ``shift`` samples.

        Samples falling outside ``[0, window_size)`` are cut off. ``vec`` is
        modified in place and also returned.
        """
        if shift >= 0:
            if len(data)+shift <= window_size:
                vec[shift:len(data)+shift, i] = data
            else:
                vec[shift:, i] = data[:window_size-shift]
        else:
            if len(data)+shift <= window_size:
                vec[:len(data)+shift, i] = data[-shift:]
            else:
                vec[:, i] = data[-shift:window_size-shift]
        return vec

    Extra = namedtuple("pick_extra", ["p_idx", "s_idx", "channels", "snr", "dt", "station_index", "latitude", "longitude", "elevation_m", "unit", "fname"])

    def convert_sample(pick, event, stream, station, sample_path):
        """Save one stream as an ``.npz`` sample and upload it.

        Returns ``(0, Extra)`` on success and ``(1, None)`` when the component
        codes of the stream cannot be mapped to three columns.
        """
        dt = stream[-1].stats.delta
        npts = stream[-1].stats.npts
        starttime = stream[-1].stats.starttime
        endtime = stream[-1].stats.endtime

        p_idx = int(np.around( (pick.p_time - starttime)/(endtime - starttime)*npts ))
        s_idx = int(np.around( (pick.s_time - starttime)/(endtime - starttime)*npts ))

        # The P pick is moved to sample `anchor` of a `window_size` window.
        anchor = 6000
        window_size = 12000
        vec = np.zeros([window_size, 3])
        shift = anchor - p_idx
        if np.abs(shift) <= 1:
            shift = 0
        p_idx += shift
        s_idx += shift

        # Sort key for the last character of a trace id.
        order = {key: rank for rank, key in enumerate(['3','2','1','E','N','Z'])}
        comps = [x.get_id() for x in stream]
        try:
            comps = sorted(comps, key=lambda x: order[x[-1]])
        except (KeyError, IndexError):
            print(f"Unknown channels: {comps}")
            return (1, None)

        fname = comps[0][:-1]+"."+f"{pick.event_index:07d}"+".npz"
        if len(comps) == 3:
            for col, c in enumerate(comps):
                data2vec(col, stream.select(id=c)[0].data, vec, shift, window_size)
        elif len(comps) < 3:
            # Incomplete streams are only accepted with E/N/Z component codes.
            column_of = {"E": 0, "N": 1, "Z": 2}
            for c in comps:
                if c[-1] not in column_of:
                    print(f"Unknown channels: {comps}")
                    return (1, None)
                data2vec(column_of[c[-1]], stream.select(id=c)[0].data, vec, shift, window_size)
        else:
            print(f"Unknown channels: {comps}")
            return (1, None)

        snr = calc_snr(vec, anchor, dt)
        channels = ",".join([x.split(".")[-1] for x in comps])
        extra = Extra(p_idx=p_idx, s_idx=s_idx, snr=tuple(snr.tolist()), dt=dt,
                      station_index=station.id, channels=channels,
                      latitude=station.latitude, longitude=station.longitude, elevation_m=station.elevation_m,
                      unit=station.unit, fname=fname)
        sample_file = os.path.join(sample_path, fname)
        np.savez(sample_file,
                data=vec.astype("float32"), dt=dt, p_idx=p_idx, s_idx=s_idx, snr=snr.tolist(),
                p_time=pick.p_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3], s_time=pick.s_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3],
                p_remark=pick.p_remark, p_weight=pick.p_weight, s_remark=pick.s_remark, s_weight=pick.s_weight,
                first_motion=pick.first_motion, distance_km=pick.distance_km,
                azimuth=pick.azimuth, emergence_angle=pick.emergence_angle,
                network=pick.network, station=pick.station, location_code=pick.location_code,
                station_latitude=station.latitude, station_longitude=station.longitude, station_elevation_m=station.elevation_m,
                event_latitude=event.latitude, event_longitude=event.longitude, event_depth_km=event.depth_km,
                event_time=event.time, event_magnitude=event.magnitude, event_magnitude_type=event.magnitude_type,
                unit=station.unit, channels=channels, event_index=pick.event_index)

        upload_blob(sample_file, sample_file)
        os.remove(sample_file)
        return (0, extra)

    def create_dataset(idx, events, phases, waveform_path, sample_path, station_path):
        """Process every event index in ``idx``.

        Returns the events that produced at least one sample, plus parallel
        lists of picks and their ``Extra`` records.
        """
        events_ = []
        phases_ = []
        extras_ = []
        last_index = max(idx)
        for event_index in idx:
            print(f"Create dataset {event_index}/{last_index}")
            event = read_event_line(events[event_index], event_index)
            phase_lines = phases[event_index]
            has_phase = False
            # Phase lines come in (P, S) pairs; a dangling last line is ignored.
            for j in range(0, len(phase_lines) - 1, 2):
                pick = read_phase_line(phase_lines[j], phase_lines[j+1], event_index)
                (waveforms, stations) = download_waveform(pick, waveform_path, station_path)
                assert(len(waveforms)==len(stations))
                for w, s in zip(waveforms, stations):
                    (status, extra) = convert_sample(pick, event, w, s, sample_path)
                    if (status == 0):
                        phases_.append(pick)
                        extras_.append(extra)
                        has_phase = True
            if has_phase:
                events_.append(event)

        return events_, phases_, extras_

    def export_table(local_csv, written, output_path):
        """Upload ``local_csv`` and copy it to ``output_path``; empty output if not written."""
        if written:
            upload_blob(local_csv, os.path.join("catalogs", local_csv))
            shutil.copyfile(local_csv, output_path)
        else:
            # No sample was produced: still create the declared output file
            # instead of failing on a missing CSV.
            with open(output_path, "w") as fout:
                fout.write("")

    waveform_path = "mseeds"
    station_path = "stations"
    sample_path = "data"
    for directory in (waveform_path, station_path, sample_path):
        os.makedirs(directory, exist_ok=True)

    with open(events_input, "rb") as fp:
        events = pickle.load(fp)
    with open(phases_input, "rb") as fp:
        phases = pickle.load(fp)

    events_tuple, phases_tuple, extra_tuple = create_dataset(index, events, phases, waveform_path, sample_path, station_path)

    # File names are keyed by the first (i.e. largest, after reversal) event index.
    events_csv = f"{index[0]:03d}_events.csv"
    phases_csv = f"{index[0]:03d}_phases.csv"

    has_events = len(events_tuple) >= 1
    if has_events:
        events_df = pd.DataFrame(data=events_tuple).set_index("index")
        events_df.to_csv(events_csv, sep="\t")

    has_phases = len(phases_tuple) >= 1
    if has_phases:
        phases_df = pd.DataFrame(data=phases_tuple)
        extra_df = pd.DataFrame(data=extra_tuple)
        for column in ["fname", "p_idx", "s_idx", "dt", "latitude", "longitude",
                       "elevation_m", "channels", "unit"]:
            phases_df[column] = extra_df[column]
        phases_df["snr"] = extra_df["snr"].apply(lambda x: ",".join([f"{v:.2f}" for v in x]))
        for column in ["p_time", "s_time"]:
            # Millisecond precision: drop the last three microsecond digits.
            phases_df[column] = phases_df[column].apply(lambda x: x.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3])
        phases_df.to_csv(phases_csv, sep="\t", index=False,
            columns=["fname", "network","station","location_code",
            "p_idx", "p_time","p_remark","p_weight",
            "s_idx", "s_time","s_remark","s_weight",
            "first_motion","distance_km","emergence_angle","azimuth",
            "latitude", "longitude", "elevation_m", "unit", "dt",
            "event_index", "channels", "snr"])

    export_table(events_csv, has_events, events_path)
    export_table(phases_csv, has_phases, phases_path)

In [7]:
# build_dataset(index = [100000], credentials=credentials, events_input = "events.pkl", phases_input = "phases.pkl",
#               events_output = "events.out", phases_output = "phases.out")

# Package build_dataset as a container component with its third-party dependencies.
build_dataset_op = comp.func_to_container_op(
    build_dataset,
    base_image='python:3.8',
    packages_to_install=[
        "obspy",
        "pandas",
        "google-cloud-storage",
    ],
)

In [8]:
def merge_result(data_path="/tmp/"):
    """Merge the per-chunk catalog CSVs stored in GCS into two combined catalogs.

    Every object under the ``catalogs`` prefix of the ``quakeflow`` bucket whose
    file name is ``<digits>_events.csv`` or ``<digits>_phases.csv`` is downloaded
    into ``data_path``. The event files are concatenated and sorted by ``time``,
    the phase files by ``p_time``; each result is written to ``data_path`` as
    ``combined_<kind>.csv`` (tab-separated) and uploaded to
    ``catalogs/combined_<kind>.csv``. A message is printed when no chunk files
    of a kind are found.

    Args:
        data_path: Local working directory. Created, including missing parents,
            if it does not exist yet.
    """
    # Imports are function-local because this function is converted into a
    # container component with comp.func_to_container_op.
    from glob import glob
    import pandas as pd
    from google.cloud import storage
    import os
    import re

    bucket_name = "quakeflow"
    prefix = "catalogs"
    # One client is shared by all transfers instead of creating one per call.
    storage_client = storage.Client()
    # Full-name match with an escaped dot, so e.g. "001_events.csv.bak" or
    # "combined_events.csv" are not treated as chunk files.
    chunk_name = re.compile(r"[0-9]+_(events|phases)\.csv")

    def upload_blob(source_file_name, destination_blob_name):
        """Upload a local file to the bucket under the given object name."""
        blob = storage_client.bucket(bucket_name).blob(destination_blob_name)
        blob.upload_from_filename(source_file_name, timeout=3600)
        print(f"File {source_file_name} uploaded to {destination_blob_name}")

    def download_chunks():
        """Download every per-chunk events/phases CSV into data_path."""
        # makedirs also handles nested paths and an already existing directory.
        os.makedirs(data_path, exist_ok=True)
        for blob in storage_client.list_blobs(bucket_name, prefix=prefix, delimiter=""):
            base_name = blob.name.split("/")[-1]
            if not chunk_name.fullmatch(base_name):
                continue
            destination_file_name = os.path.join(data_path, base_name)
            blob.download_to_filename(destination_file_name)
            print("Blob {} downloaded to {}.".format(blob.name, destination_file_name))

    def merge_kind(kind, sort_key):
        """Concatenate all chunk files of one kind, sort, save and upload them."""
        # Sorted so the concatenation order does not depend on directory listing order.
        chunk_files = sorted(glob(os.path.join(data_path, f"[0-9]*_{kind}.csv")))
        if not chunk_files:
            print(f"No {kind}.csv found!")
            return
        frames = [pd.read_csv(path, sep="\t", dtype=str) for path in chunk_files]
        combined = pd.concat(frames).sort_values(by=sort_key)
        combined_name = f"combined_{kind}.csv"
        combined_path = os.path.join(data_path, combined_name)
        combined.to_csv(combined_path, sep="\t", index=False)
        # GCS object names always use "/" regardless of the local OS.
        upload_blob(combined_path, f"{prefix}/{combined_name}")

    download_chunks()
    merge_kind("events", "time")
    merge_kind("phases", "p_time")

In [9]:
# Direct call for local debugging (kept disabled):
# merge_result("./tmp")

# Turn merge_result into a container component that runs on the python:3.8
# image with pandas and the Cloud Storage client installed.
merge_result_op = comp.func_to_container_op(
    merge_result,
    base_image="python:3.8",
    packages_to_install=["pandas", "google-cloud-storage"],
)

In [10]:
def merge_hdf5(data_path="/tmp/"):
    """Pack the combined catalogs and all referenced waveform files into one HDF5 file.

    Downloads ``catalogs/combined_events.csv`` and ``catalogs/combined_phases.csv``
    from the ``quakeflow`` bucket into the current working directory, stores them
    in ``ncedc.h5`` under ``/events`` and ``/catalog``, then downloads
    ``data/<fname>`` for every ``fname`` listed in the phase catalog and adds it
    to the ``/data`` group: the ``data`` array becomes a float32 dataset and
    every other entry of the file becomes an attribute of that dataset. Finally
    ``ncedc.h5`` is uploaded to the bucket.

    Args:
        data_path: Local scratch directory for the downloaded waveform files.
            Created, including missing parents, if it does not exist yet. Each
            file is deleted again once it has been processed.
    """
    # Imports are function-local because this function is converted into a
    # container component with comp.func_to_container_op.
    from google.cloud import storage
    import h5py
    import pandas as pd
    import numpy as np
    import os

    os.makedirs(data_path, exist_ok=True)
    bucket_name = "quakeflow"
    prefix = "data"
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    for catalog_name in ("combined_events.csv", "combined_phases.csv"):
        catalog_blob_name = f"catalogs/{catalog_name}"
        bucket.blob(catalog_blob_name).download_to_filename(catalog_name, timeout=3600)
        print(f"File {catalog_blob_name} download to {catalog_name}")

    file_name = "ncedc.h5"

    # `key` is passed by keyword: newer pandas versions deprecate passing it positionally.
    events = pd.read_csv("combined_events.csv", sep="\t").sort_values("index")
    events.to_hdf(file_name, key="/events", format="table", mode="w")
    print(events)
    catalogs = pd.read_csv("combined_phases.csv", sep="\t")
    catalogs.to_hdf(file_name, key="/catalog", format="table", mode="r+")
    print(catalogs)

    with h5py.File(file_name, "r+", libver="latest") as fp:
        data = fp.create_group("/data")
        for fname in catalogs["fname"]:
            # Each waveform needs its own blob handle; reusing the handle of the
            # phases CSV would download that CSV instead of the waveform file.
            source_blob_name = f"{prefix}/{fname}"
            destination_file_name = os.path.join(data_path, fname)
            bucket.blob(source_blob_name).download_to_filename(destination_file_name)
            print("Blob {} downloaded to {}.".format(source_blob_name, destination_file_name))

            try:
                # The context manager closes the archive before the file is removed.
                with np.load(destination_file_name) as meta:
                    ds = data.create_dataset(fname, data=meta["data"], dtype="float32")
                    for k in meta:
                        if k == "data":
                            continue
                        value = meta[k]
                        # String arrays are stored as plain Python strings.
                        if value.dtype.type is np.str_:
                            ds.attrs[k] = str(value)
                        else:
                            ds.attrs[k] = value
            finally:
                # os.remove avoids spawning a shell with an interpolated file name.
                os.remove(destination_file_name)

    blob = bucket.blob(file_name)
    blob.upload_from_filename(file_name, timeout=3600)
    print(f"File {file_name} uploaded to {file_name}")


In [11]:
# Turn merge_hdf5 into a container component that runs on the python:3.8 image;
# "tables" is listed alongside pandas because merge_hdf5 calls DataFrame.to_hdf.
merge_hdf5_op = comp.func_to_container_op(
    merge_hdf5,
    base_image="python:3.8",
    packages_to_install=["pandas", "google-cloud-storage", "numpy", "h5py", "tables"],
)

In [12]:
# Pipeline definition: download catalog -> parse P/S catalog -> build dataset
# chunks in parallel -> merge the per-chunk CSV files.
@dsl.pipeline(name='QuakeFlow', description='')
def dataset_pipeline():

    work_dir = "/tmp/"

    catalog_task = download_catalog_op()#.set_memory_request("60G")
    ps_task = read_ps_catalog_op(catalog_task.outputs["catalog"])
    # Both preparation steps accept cached results with a staleness limit of "P30D".
    for cached_task in (catalog_task, ps_task):
        cached_task.execution_options.caching_strategy.max_cache_staleness = "P30D"

    # One build_dataset task per item of the "output" list produced by the P/S step.
    with kfp.dsl.ParallelFor(ps_task.outputs["output"]) as chunk:

        build_task = build_dataset_op(
            chunk,
            ps_task.outputs["index"],
            ps_task.outputs["events"],
            ps_task.outputs["phases"],
        ).set_memory_request("1100M")#.set_retry(1)
        build_task.execution_options.caching_strategy.max_cache_staleness = "P30D"

    # The merge is ordered after the dataset build and uses a "P0D" staleness limit.
    merge_task = merge_result_op(work_dir).after(build_task)
    merge_task.execution_options.caching_strategy.max_cache_staleness = "P0D"

    # Optional HDF5 packing step, currently disabled:
#     hdf5 = merge_hdf5_op(data_path).after(csv)
#     hdf5.execution_options.caching_strategy.max_cache_staleness = "P0D"

In [13]:
import os

# Use the service-account key below only when the environment does not already
# provide GOOGLE_APPLICATION_CREDENTIALS, so the notebook also runs on machines
# where this user-specific path does not exist.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS",
    "/Users/weiqiang/.dotbot/cloud/quakeflow_wayne.json",
)

# Kubeflow Pipelines endpoint; KF_PIPELINES_ENDPOINT overrides the default host.
host_url = (
    os.environ.get("KF_PIPELINES_ENDPOINT")
    or "10b36473d207cad7-dot-us-west1.pipelines.googleusercontent.com"
)
client = kfp.Client(host=host_url)

experiment_name = 'Dataset'
pipeline_func = dataset_pipeline
run_name = pipeline_func.__name__ + '_run'

# dataset_pipeline takes no parameters, so no run arguments are passed.
arguments = {}

# Compile pipeline to generate compressed YAML definition of the pipeline.
# kfp.compiler.Compiler().compile(pipeline_func, '{}.zip'.format(experiment_name))

# Submit pipeline directly from pipeline function
results = client.create_run_from_pipeline_func(
    pipeline_func,
    experiment_name=experiment_name,
    run_name=run_name,
    arguments=arguments,
)