# Read GNN outputs (graphs) into xarray
* Read all graph predictions
* Average the predictions where they overlap
* Retrend and rescale the results
* Concatenate everything into an xarray dataset

In [None]:
import torch
import xarray as xr
import numpy as np
import pandas as pd
import joblib
import matplotlib.pyplot as plt
from scipy.special import softmax

## read result graphs

In [None]:
# Load the saved result graphs, mapping any stored tensors onto the CPU.
# NOTE(review): torch.load unpickles the file - only load result files from a trusted source.
results = torch.load('../results_400km_k10.pt', map_location=torch.device('cpu'))
# Preview the first entry.
results[0]

In [None]:
# Number of entries in the loaded results.
len(results)

## creating netcdf

In [None]:
# Dimension order used for every variable of the per-prediction datasets.
dims = ['station', 'time', 'pred_id']

In [None]:
import os


def to_pred_array(values):
    """Cast one graph attribute to float32, shaped (station, 30 time steps, 1 pred_id)."""
    return np.array(values).astype(np.float32).reshape(-1, 30, 1)


# Dataset variable name -> (attribute read from the result graph, description).
graph_fields = {
    "pred_n": ("n_out", "scaled prediction"),
    "pred_e": ("e_out", "scaled prediction"),
    "pred_z": ("z_out", "scaled prediction"),
    "input_n": ("signal_n", "scaled input"),
    "input_e": ("signal_e", "scaled input"),
    "input_z": ("signal_z", "scaled input"),
}

# Make sure the output folder exists before the first file is written.
os.makedirs("./tmp", exist_ok=True)

# NOTE(review): `indexes` is not defined in the visible part of the notebook -
# presumably the start position of each snapshot in `results`; confirm.
for i, index in enumerate(indexes):
    for pred_nb in range(10):
        # Only running past the end of `results` should stop this snapshot;
        # an IndexError raised while building or writing the dataset is a
        # real error and must not be swallowed.
        try:
            data = results[index + pred_nb]
        except IndexError:
            break

        xr.Dataset(
            data_vars={
                name: (dims, to_pred_array(getattr(data, attr)), dict(description=description))
                for name, (attr, description) in graph_fields.items()
            },
            coords=dict(
                station=(["station"], data.id),
                time=pd.date_range(start=data.date_start, periods=30, freq='1D'),
                pred_id=[pred_nb],
            ),
        ).to_netcdf(f"./tmp/pred{pred_nb}_{str(data.date_start)[:10]}.nc")

    # Report the number of snapshots actually completed, once every 100.
    if (i + 1) % 100 == 0:
        print(f"{i + 1} snapshots processed")

In [None]:
# Open one multi-file dataset per prediction index from the files written to ./tmp.
ds_pred = []
for pred_idx in range(10):
    member = xr.open_mfdataset(f"./tmp/pred{pred_idx}*")
    ds_pred.append(member)
    print(f"read pred nb {pred_idx}")

In [None]:
# Combine the per-prediction datasets into a single dataset.
ds_pred_merge = xr.merge(ds_pred)
ds_pred_merge

## Calculate denoised signal

In [None]:
# The denoised signal is the input minus the prediction, for each component.
for comp in ('n', 'e', 'z'):
    signal = ds_pred_merge[f'input_{comp}']
    predicted = ds_pred_merge[f'pred_{comp}']
    ds_pred_merge[f'denoised_{comp}'] = signal - predicted
ds_pred_merge

In [None]:
# Spread of the denoised signal across the predictions (std over pred_id), per component.
for comp in ('n', 'e', 'z'):
    denoised = ds_pred_merge[f'denoised_{comp}']
    ds_pred_merge[f'denoised_std_{comp}'] = denoised.std(dim="pred_id")
ds_pred_merge

In [None]:
# Simple average over the predictions: collapse the pred_id dimension,
# then evaluate the result with compute().
ds_pred_merge = ds_pred_merge.mean(dim="pred_id").compute()
ds_pred_merge

## combine with original ds

In [None]:
# Open the original dataset and discard its *_norm variables.
norm_vars = ['n_norm', 'e_norm', 'z_norm']
ds = xr.open_dataset('../original_dataset.nc').drop_vars(norm_vars)
ds

In [None]:
# Join the original data with the averaged predictions, then keep 2000-2023 only.
merged = xr.merge([ds, ds_pred_merge])
ds_out = merged.sel(time=slice("2000-01-01", "2023-12-31"))
ds_out

In [None]:
# Cast the e, n and z variables to float32.
for comp in ('e', 'n', 'z'):
    ds_out[comp] = ds_out[comp].astype(np.float32)
ds_out

## retrend and unscale

In [None]:
# Load the saved scalers (presumably the per-component min/max used to scale the data - TODO confirm).
# NOTE(review): joblib.load unpickles the file - only load it from a trusted source.
scalers = joblib.load("../scalers_daily")

def inv_min_max(x_scaled, xmin, xmax):
    """Map values from the [-1, 1] min-max scale back to the [xmin, xmax] range.

    ``x_scaled == -1`` gives ``xmin`` and ``x_scaled == 1`` gives ``xmax``.
    """
    half_range = (xmax - xmin) / 2
    return half_range * (x_scaled + 1) + xmin
# Display the loaded scalers.
scalers

In [None]:
# Open the daily fit dataset (presumably the trend removed before training - TODO confirm).
ds_fit = xr.open_dataset("../fit_daily.nc")
ds_fit

In [None]:
# Two-dimensional layout used when assigning the variables derived below.
dims = ['station', 'time']

In [None]:
var_to_modify = ['pred_n', 'pred_e', 'pred_z', 'input_n', 'input_e', 'input_z', 'denoised_n', 'denoised_e', 'denoised_z', 'denoised_std_n', 'denoised_std_e', 'denoised_std_z']
ds_out = ds_out.sel(time=slice('2010-01-01', '2023-12-31'))

# For every scaled variable:
#   detrend_<var> = the variable mapped back through inv_min_max
#   final_<var>   = detrend_<var> plus the fit of the matching component
# NOTE(review): inv_min_max is affine (scale plus an offset of (xmax + xmin) / 2).
# Applying it to differences (denoised_*) and to standard deviations
# (denoised_std_*) only equals a pure rescale when that offset is zero -
# confirm the scalers are symmetric around zero.
for var in var_to_modify:
    component = var[-1]  # the last character selects the n / e / z scaler and fit
    tmp = inv_min_max(ds_out[var], scalers[component][0], scalers[component][1])
    ds_out = ds_out.assign(variables={f"detrend_{var}": (dims, tmp.data.astype(np.float32))})
    # No final_* variable is kept for the standard deviations, so it is not
    # computed at all instead of being computed and dropped afterwards.
    if var.startswith('denoised_std_'):
        continue
    ds_out = ds_out.assign(variables={f"final_{var}": (dims, (tmp + ds_fit[component]).data.astype(np.float32))})

ds_out

In [None]:
# Absolute difference between the retrended prediction and the retrended input, per component.
for comp in ('n', 'e', 'z'):
    predicted = ds_out[f'final_pred_{comp}']
    observed = ds_out[f'final_input_{comp}']
    abs_diff = abs(predicted - observed).data.astype(np.float32)
    ds_out = ds_out.assign(variables={f"final_error_{comp}": (dims, abs_diff)})
ds_out

# Verification

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Take the first 6 * 365 time steps, then drop the stations whose three
# prediction variables contain only NaN.
first_steps = slice(0, 365 * 6)
ds_sel = ds_out.isel(time=first_steps).dropna(
    dim="station",
    how="all",
    subset=['pred_n', 'pred_e', 'pred_z'],
)
ds_sel

In [None]:
# Scatter plot of the station positions taken from the original dataset.
fig, ax = plt.subplots(figsize=(6, 6))
lon, lat = ds.longitude, ds.latitude
scatter = ax.scatter(x=lon, y=lat, s=15, alpha=1)
plt.tight_layout()
plt.show()

In [None]:
# Compare the input, predicted and denoised east series at ALBH over a short window.
fig, ax1 = plt.subplots(figsize=(8, 8))
albh_window = ds_out.sel(station="ALBH", time=slice("2017-03-01", "2017-04-15"))
for var_name, series_label in (
    ("detrend_input_e", "input"),
    ("detrend_pred_e", "pred"),
    ("detrend_denoised_e", "denoised"),
):
    albh_window[var_name].plot(label=series_label, ax=ax1)

ax1.set_ylabel("E-ref(m)")
ax1.set_title(label=f"E at station ALBH: 2017-03-01, 2017-04-15")
ax1.legend(loc='center left', bbox_to_anchor=(1, 0.8))

In [None]:
# Average every variable over the station dimension.
ds_mean = ds_out.mean(dim='station')
ds_mean

In [None]:
# Plot one 1000-step chunk of the station-averaged east component after a
# centered 5-step rolling mean.
i = 2  # 1-based number of the 1000-step chunk to display
first_step, last_step = (i - 1) * 1000, i * 1000
smoothed = ds_mean.rolling(time=5, center=True, min_periods=1).mean()
ds_mean_tmp = smoothed.isel(time=slice(first_step, last_step))

fig, ax = plt.subplots(figsize=(8, 6))
# ("detrend_pred_e", "pred") can be added to this tuple to overlay the prediction.
for var_name, series_label in (
    ("detrend_input_e", "input"),
    ("detrend_denoised_e", "denoised"),
):
    ds_mean_tmp[var_name].plot(ax=ax, label=series_label)

# Zero reference line.
ax.axhline(y=0, color='r', linestyle='-')

ax.legend(loc='upper right')

# NOTE(review): the plotted detrend_* variables have been passed through
# inv_min_max, so "scaled" in these labels may be outdated - confirm.
ax.set_ylabel("scaled detrend average position")
ax.set_title(label=f"scaled detrend average position in e")
plt.tight_layout()

In [None]:
# Full east series at ALBH (observed, retrended prediction, retrended denoised)
# on the left axis, with the tremor count on a twin right axis when available.
fig, ax1 = plt.subplots(figsize=(8, 8))
ax2 = ax1.twinx()
albh = ds_out.sel(station="ALBH")
albh.e.plot(label="e", ax=ax1)
albh.final_pred_e.plot(label="pred_e", ax=ax1)
albh.final_denoised_e.plot(label="denoised_e", ax=ax1)

# Test for the optional variable explicitly instead of hiding every possible
# error behind a bare except.
if "tremor_count" in albh:
    albh.tremor_count.plot(label="tremor", ax=ax2)
else:
    print("tremor not in ds")

# plt.legend() only collects the artists of the current axes, so the handles
# of both axes are gathered by hand to get every series into one legend.
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(handles1 + handles2, labels1 + labels2, loc='center left', bbox_to_anchor=(1, 0.5))

In [None]:
# Inspect the assembled dataset.
ds_out

## final cleanup

In [None]:
# Turn latitude and longitude into coordinates of the dataset.
ds_out = ds_out.set_coords(("latitude", "longitude"))

### Calculate daily "noise"

Noise here is defined as the average position of the network relative to 0, assuming that the network should be centered on average at any time. It is calculated on the detrended data.

In [None]:
# Noise variables: absolute value of the detrended input and denoised series, per component.
# NOTE(review): this is an element-wise absolute value; no averaging over the
# stations happens in this cell.
for comp in ('e', 'n', 'z'):
    for kind in ('input', 'denoised'):
        detrended = ds_out[f'detrend_{kind}_{comp}']
        ds_out[f'noise_{kind}_{comp}'] = abs(detrended).astype(np.float32)
ds_out

In [None]:
# Record the data source and the processing summary as global attributes.
ds_out.attrs.update({
    'dataset source': "CWU daily solution from earthscope",
    'data processing': (
        "raw data are detrended for each channel (detrend_input*). "
        "Denoising is done by a GNN with (GAT), encoding MLP and decoding MLP size 512. "
        "Graph is along edges > 400km The noise is the network abs position at all time steps. "
        "The trend added back, using the CWU trend"
    ),
})
ds_out

## Save DS

In [None]:
# Save the final dataset (time range 2010-2023) as a netCDF file.
ds_out.to_netcdf("../daily_results_2023_400km_k10.nc")

## clean up the tmp folder

In [None]:
# IPython magic: clear the user-defined names from the interactive namespace
# before cleaning up (presumably to release the datasets opened from ./tmp - confirm).
%reset

In [None]:
# delete tmp files
import os
import shutil


def clear_folder(folder):
    """Delete every entry inside ``folder`` while keeping the folder itself.

    Files and symbolic links are unlinked; sub-directories are removed
    recursively. A missing ``folder`` is treated as already clean instead of
    raising. An entry that cannot be removed is reported and skipped so that
    one failure does not stop the rest of the cleanup.
    """
    if not os.path.isdir(folder):
        print(f'Nothing to delete: {folder} does not exist')
        return
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            print(f'Failed to delete {file_path}. Reason: {e}')


folder = './tmp'
# Run the cleanup when executed as a notebook cell or script, not when imported.
if __name__ == "__main__":
    clear_folder(folder)