In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import obspy
import pandas as pd
import torch
from tqdm import tqdm

from mccc import MCCCPicker

In [None]:
# Input: per-station waveform files (*.npy); output: per-station MCCC figures.
input_path = Path("output/waveforms_by_station")
output_path = Path("output/mccc_results")
# exist_ok keeps this cell idempotent and avoids the exists()/mkdir() race;
# it still raises if `output_path` exists but is not a directory.
output_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Overview map of the catalog in selected_events.csv. It is kept under its own
# name so that re-running this cell cannot clobber `events`, which the next
# cell loads from filt_events.csv and the rest of the notebook uses.
selected_events = pd.read_csv("output/selected_events.csv")
fig, ax = plt.subplots()
ax.scatter(selected_events["longitude"], selected_events["latitude"], s=1)
ax.axis("scaled")
ax.set(xlabel="Longitude", ylabel="Latitude", title=f"Selected events: {len(selected_events)}")
plt.show()

In [None]:
# Filtered catalog: this `events` frame is the one the latitude split below uses.
events = pd.read_csv("output/filt_events.csv")
print(len(events))
fig, ax = plt.subplots()
ax.scatter(events["longitude"], events["latitude"], s=1)
ax.axis("scaled")
ax.set(xlabel="Longitude", ylabel="Latitude", title=f"Filtered events: {len(events)}")
plt.show()

In [None]:
# Latitude boundaries between the three groups (same units as the catalog's
# "latitude" column).
LOA_MAX_LATITUDE = 19.52  # events strictly south of this are "Loa"
KEA_MIN_LATITUDE = 19.8  # events strictly north of this are "Kea"

# Sort south-to-north and renumber, so the index_* objects below are row
# positions into the sorted `events` frame.
events = events.sort_values(by="latitude").reset_index(drop=True)
events_loa = events[events["latitude"] < LOA_MAX_LATITUDE]
events_kea = events[events["latitude"] > KEA_MIN_LATITUDE]
# Everything between the boundaries (both inclusive) is ambiguous.
# NOTE: rows with a NaN latitude fall into none of the three groups.
events_chain = events[
    (events["latitude"] >= LOA_MAX_LATITUDE) & (events["latitude"] <= KEA_MIN_LATITUDE)
]

index_loa = events_loa.index
index_kea = events_kea.index
index_chain = events_chain.index

In [None]:
# Map of the three latitude groups; each legend entry shows the group's event
# count. Drawing order (ambiguous, Loa, Kea) controls which points end up on top.
group_styles = [
    (events_chain, "C1", "Ambiguous"),
    (events_loa, "C0", "Loa"),
    (events_kea, "C2", "Kea"),
]
fig, ax = plt.subplots(figsize=(4, 8))
for group, color, name in group_styles:
    ax.scatter(group["longitude"], group["latitude"], c=color, s=2, label=f"{name}: {len(group)}")
ax.axis("scaled")
ax.set(xlabel="Longitude", ylabel="Latitude")
ax.legend(markerscale=6)
fig.savefig("events.png", dpi=300, bbox_inches="tight")
plt.show()

In [7]:
# Per-station MCCC alignment. Each `*.npy` file in `input_path` holds a set of
# waveforms for one station (presumably one trace per event — TODO confirm
# against the step that wrote them). For every file this cell:
#   1. trims all traces to the shortest one and stacks them,
#   2. runs MCCCPicker in "align" mode to obtain one time shift per trace,
#   3. writes <stem>_dt.png (the shifts), <stem>_waveform.png (wiggles) and
#      <stem>_colormesh.png (image) of the shifted, normalized traces.

DT = 0.01  # sample interval (s), passed to the picker and used for the time axis
T_INTERP = np.linspace(-6, 31, 3701)  # common plotting grid: -6..31 s at 0.01 s spacing
DT_ZERO_THRESHOLD = 1.0  # seconds; see the NOTE(review) in the loop


def _stack_traces(waveforms):
    """Trim every trace to the shortest length and stack them.

    Returns a float32 array of shape (n_traces, min_length).
    """
    min_nt = min(len(w) for w in waveforms)
    return np.array([w[:min_nt] for w in waveforms], dtype=np.float32)


def _normalize(x):
    """Standardize `x` to zero mean and unit standard deviation.

    ddof=1 keeps the scaling identical to torch's default (unbiased) `Tensor.std()`.
    """
    return (x - x.mean()) / x.std(ddof=1)


def _shift_and_resample(data, cc_dt, t_interp, dt):
    """Normalize each trace, shift it by -cc_dt[i] seconds, and resample it onto `t_interp`.

    Grid points outside a trace's shifted time span are NaN.
    Returns an array of shape (n_traces, len(t_interp)).
    """
    t = np.arange(data.shape[1]) * dt
    return np.array([
        np.interp(t_interp, t - cc_dt[i], _normalize(trace), left=np.nan, right=np.nan)
        for i, trace in enumerate(data)
    ])


def _save_dt_plot(cc_dt, path):
    """Save a line plot of the per-trace time shifts to `path`."""
    fig, ax = plt.subplots()
    ax.plot(cc_dt)
    ax.set(xlabel="Event #", ylabel="cc_dt (s)")
    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def _save_waveform_plot(t_interp, traces, path):
    """Save a wiggle plot to `path`: trace i is scaled by 1/6 and offset to y = i."""
    fig, ax = plt.subplots(figsize=(8, 30))
    for i, trace in enumerate(traces):
        ax.plot(t_interp, trace / 6 + i, linewidth=0.5, color="k")
    ax.grid(True)
    ax.set_ylim([-1, len(traces) + 1])
    ax.set(xlabel="Time (s)", ylabel="Event #")
    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def _save_colormesh_plot(t_interp, traces, path):
    """Save all traces (one per row) as an image with a diverging colormap clipped to +/-1.5."""
    fig, ax = plt.subplots(figsize=(10, 20))
    ax.pcolormesh(
        t_interp, np.arange(traces.shape[0]), traces,
        cmap="seismic", vmin=-1.5, vmax=1.5, shading="auto",
    )
    ax.grid(True)
    ax.set(xlabel="Time (s)", ylabel="Event #")
    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.close(fig)


# Sorted so stations are processed in a deterministic order on every run.
files = sorted(input_path.glob("*.npy"))
for f in tqdm(files):
    # NOTE(review): allow_pickle=True lets np.load execute code embedded in the
    # file; only point this at .npy files produced by this pipeline.
    waveforms = np.load(f, allow_pickle=True)
    if len(waveforms) == 0:
        continue
    data = _stack_traces(waveforms)

    picker = MCCCPicker(
        torch.tensor(data, dtype=torch.float32).cuda(),
        DT,
        mccc_mincc=0.5,
        mccc_maxlag=1.0,
        mccc_maxwin=100,
        max_niter=1,
        chunk_size=1000,
        ma=1,
        damp=10,
        whitening_waterlevel=0.1,
        mode="align",
    )
    result = picker.solve()
    cc_dt = result["cc_dt"].detach().cpu().numpy()
    _save_dt_plot(cc_dt, output_path / f"{f.stem}_dt.png")

    # NOTE(review): this discards the alignment when every |shift| is *below*
    # DT_ZERO_THRESHOLD, i.e. exactly when the shifts look small/plausible.
    # If the intent was to reject implausibly large shifts, the comparison is
    # inverted — confirm before relying on these figures.
    if np.max(np.abs(cc_dt)) < DT_ZERO_THRESHOLD:
        cc_dt = np.zeros_like(cc_dt)

    y = _shift_and_resample(data, cc_dt, T_INTERP, DT)
    _save_waveform_plot(T_INTERP, y, output_path / f"{f.stem}_waveform.png")
    _save_colormesh_plot(T_INTERP, y, output_path / f"{f.stem}_colormesh.png")