In [1]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from util import *
from joblib import Parallel, delayed


def batch_get_sediment_quantity(date_range, site_codes):
    """Load the sediment and discharge tables restricted to ``date_range``.

    Both CSV files are read from the ``discharge/`` directory relative to the
    current working directory, with their ``date`` column parsed as datetimes.

    Args:
        date_range: Collection of dates to keep; rows whose ``date`` value is
            not a member are dropped.
        site_codes: Accepted for interface compatibility; not used here, so
            every column present in the CSV files is returned.

    Returns:
        tuple: ``(sediment_data, discharge_data)`` DataFrames indexed by
        ``date``. Missing values in the sediment table are replaced with 0;
        missing values in the discharge table are left as NaN.
    """
    def _read_table(csv_path):
        # Both files share the same layout: a 'date' column plus value columns.
        return pd.read_csv(csv_path, parse_dates=['date'])

    sediment = _read_table('discharge/SSC_sediment.csv').fillna(0)
    discharge = _read_table('discharge/SSC_discharge.csv')

    # Keep only rows inside the requested window, then index each table by date.
    sediment = sediment.loc[sediment['date'].isin(date_range)].set_index(['date'])
    discharge = discharge.loc[discharge['date'].isin(date_range)].set_index(['date'])

    return sediment, discharge


def process_site(site, date_range, sediment_data, discharge_data):
    """Build the per-day feature matrix for a single site.

    Args:
        site: Site identifier; converted with ``str()`` and looked up as a
            column name in both tables.
        date_range: Ordered date labels (e.g. a ``pd.DatetimeIndex``) that
            define the rows of the result. Labels should be of the same type
            as the tables' index.
        sediment_data: DataFrame indexed by date, one column per site.
        discharge_data: DataFrame indexed by date, one column per site.

    Returns:
        numpy.ndarray: Float array of shape ``(len(date_range), 2)``.
        Column 0 holds the sediment value and column 1 the discharge value.
        Dates absent from a table, or a site absent from its columns, yield 0.
        A NaN stored for a date that is present is kept as NaN.

    Raises:
        ValueError: If a table's index contains duplicate dates.
    """
    column = str(site)
    data = np.zeros((len(date_range), 2))

    for feature, frame in enumerate((sediment_data, discharge_data)):
        if column not in frame.columns:
            # Unknown site in this table: leave the feature column at zero.
            continue
        # One aligned lookup for the whole window instead of a Python loop of
        # scalar .loc calls; fill_value only applies to dates missing from the
        # index, so NaN values stored on existing dates pass through unchanged.
        aligned = frame[column].reindex(date_range, fill_value=0)
        data[:, feature] = aligned.to_numpy(dtype=float)

    return data


def generate_spatio_temporal_data(start_date, end_date, site_codes):
    """Assemble a (site, day, feature) array of sediment and discharge values.

    Args:
        start_date: First day of the window, as a ``'%Y-%m-%d'`` string.
        end_date: Last day of the window (inclusive), as a ``'%Y-%m-%d'`` string.
        site_codes: Iterable of site identifiers to include, in output order.

    Returns:
        numpy.ndarray: Array of shape ``(n_sites, n_days, 2)`` where the last
        axis is ``[sediment, discharge]`` as produced by ``process_site``.
        When ``site_codes`` is empty the result has shape ``(0, n_days, 2)``.

    Raises:
        ValueError: If either date string does not match ``'%Y-%m-%d'``.
    """
    start_date = datetime.strptime(start_date, '%Y-%m-%d')
    end_date = datetime.strptime(end_date, '%Y-%m-%d')
    date_range = pd.date_range(start_date, end_date, freq='D')

    # Materialise once so a one-shot iterable is not exhausted before the
    # parallel map, and so emptiness can be checked reliably.
    site_codes = list(site_codes)
    if not site_codes:
        # np.stack refuses an empty sequence; return a well-shaped empty array.
        return np.zeros((0, len(date_range), 2))

    sediment_data, discharge_data = batch_get_sediment_quantity(date_range, site_codes)

    # One task per site; each yields an (n_days, 2) matrix.
    results = Parallel(n_jobs=-1)(
        delayed(process_site)(site, date_range, sediment_data, discharge_data)
        for site in site_codes
    )
    return np.stack(results, axis=0)

In [2]:
# Inclusive daily window covered by the exported array.
start_date = '2015-04-15'
end_date = '2022-12-24'

# Site identifiers to export, in the order they appear along axis 0.
site_codes = [
    "4178000", "4182000", "4183000", "4183500", "4184500",
    "4185000", "4185318", "4185440", "4186500", "4188100",
    "4188496", "4189000", "4190000", "4191058", "4191444",
    "4191500", "4192500", "4192574", "4192599", "4193500",
]

spatio_temporal_data = generate_spatio_temporal_data(start_date, end_date, site_codes)

# Layout is (N, T, F) -> (20, 2811, 2): number of sites, number of time steps,
# number of features.
# F=0: sediment, F=1: discharge
print(f"Data shape: {spatio_temporal_data.shape}")
print(spatio_temporal_data)

# Persist the array for later use.
np.save('data/discharge/data_encoder.npy', spatio_temporal_data)

Data shape: (20, 2811, 2)
[[[0.0000e+00 4.7700e+02]
  [0.0000e+00 4.1800e+02]
  [0.0000e+00 3.8000e+02]
  ...
  [0.0000e+00 1.5700e+02]
  [0.0000e+00 1.3300e+02]
  [0.0000e+00 1.1700e+02]]

 [[0.0000e+00 1.1800e+03]
  [0.0000e+00 7.9500e+02]
  [0.0000e+00 5.6800e+02]
  ...
  [0.0000e+00 3.0300e+01]
  [0.0000e+00 2.7800e+01]
  [0.0000e+00 2.9200e+01]]

 [[0.0000e+00 2.0800e+03]
  [0.0000e+00 1.6100e+03]
  [0.0000e+00 1.3700e+03]
  ...
  [0.0000e+00 3.8000e+02]
  [0.0000e+00 3.7000e+02]
  [0.0000e+00 3.2300e+02]]

 ...

 [[0.0000e+00 1.3000e+01]
  [0.0000e+00 1.1200e+01]
  [0.0000e+00 1.0200e+01]
  ...
  [0.0000e+00 1.0000e-02]
  [0.0000e+00 1.0000e-01]
  [0.0000e+00 1.0000e-02]]

 [[0.0000e+00 6.1100e+01]
  [0.0000e+00 5.3200e+01]
  [0.0000e+00 4.8000e+01]
  ...
  [0.0000e+00 4.6300e+00]
  [0.0000e+00 7.5400e+00]
  [0.0000e+00 7.5100e+00]]

 [[9.3065e+01 7.6300e+03]
  [5.8823e+01 5.4800e+03]
  [5.6013e+01 3.2900e+03]
  ...
  [0.0000e+00 8.8100e+02]
  [0.0000e+00 9.1800e+02]
  [0.0000e+0

In [3]:
# Display the array's dimensions as the cell output.
spatio_temporal_data.shape

(20, 2811, 2)