In [5]:
import os
import glob
import numpy as np
import pandas as pd
import xarray as xr
from tqdm import tqdm
from darts import TimeSeries
from darts.models import DLinearModel

In [6]:
def check_data_folder(folder):
    """
    Return True when `folder` names an existing directory, False otherwise.
    """
    # isdir() is already False for paths that do not exist, so a single
    # call answers both the "exists" and the "is a directory" question.
    return os.path.isdir(folder)

def generate_date_range(start_date, end_date):
    """
    Build the list of daily dates between start_date and end_date
    (both inclusive), each formatted as a 'YYYYMMDD' string.
    """
    days = pd.date_range(start=start_date, end=end_date, freq='D')
    return [day.strftime('%Y%m%d') for day in days]

def load_data(file_path):
    """
    Open the NetCDF file at file_path with xarray and return the dataset.

    Raises FileNotFoundError when file_path does not exist.
    """
    # Guard clause: reject a missing path before handing it to xarray.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f'File not found: {file_path}')
    return xr.open_dataset(file_path)


# main program
# Loads one variable from every nc4 file in `data_folder`, stacks the files
# along time into a single float32 array, sorts it chronologically and
# flattens it into a (cells x time) matrix plus a matching lon/lat table.
# ========== 0. Basic settings ==========
data_folder = "nc4"
var_name = "TLML"   # name of the variable to extract from each file

# Reuse the helper defined above instead of repeating the exists/isdir test.
if not check_data_folder(data_folder):
    raise FileNotFoundError(f"Data folder not found: {data_folder}")
print(f"Data folder found: {data_folder}")

# ========== 1. Collect all nc4 files ==========
# Adjust the pattern to the actual file naming scheme;
# names such as ...20240101.nc4.dap.nc4 match the pattern below.
pattern = os.path.join(data_folder, "*.nc4.dap.nc4")
file_list = sorted(glob.glob(pattern))

if not file_list:
    raise FileNotFoundError(f"No nc4 files found with pattern: {pattern}")

print(f"Found {len(file_list)} files.")

# ========== 2. Read the first file to get the grid shape ==========
# The context manager closes the file even when the KeyError below fires.
with xr.open_dataset(file_list[0]) as sample:
    if var_name not in sample:
        raise KeyError(f"Variable '{var_name}' not found in file: {file_list[0]}")

    data0 = sample[var_name]
    time_dim = data0.sizes["time"]
    nlat = data0.sizes["lat"]
    nlon = data0.sizes["lon"]
    # Copy the coordinate vectors out now so step 6 does not have to reopen
    # (and leak) the file just to read them.
    lon = sample["lon"].values
    lat = sample["lat"].values

print(f"Each file: time={time_dim}, lat={nlat}, lon={nlon}")

# ========== 3. Pre-allocate the combined arrays ==========
# Sized on the assumption that every file has `time_dim` steps; step 4
# checks that assumption instead of trusting it.
n_files = len(file_list)
ntot = n_files * time_dim

combined = np.empty((ntot, nlat, nlon), dtype=np.float32)
time_list = np.empty(ntot, dtype="datetime64[ns]")

# ========== 4. Read each file and fill the arrays ==========
idx = 0
for f in tqdm(file_list, desc="Loading nc4"):
    with xr.open_dataset(f) as ds:
        # Force (time, lat, lon) order so values line up with `combined`.
        da = ds[var_name].transpose("time", "lat", "lon")
        # open_dataset decodes CF time by default, so the coordinate can be
        # read directly without a second decode_cf pass.
        t = ds["time"].values

        n_t = da.shape[0]
        if idx + n_t > ntot:
            raise ValueError(
                f"File {f} has {n_t} time steps; total exceeds the "
                f"{ntot} steps pre-allocated from the first file"
            )
        combined[idx:idx+n_t] = da.values.astype(np.float32)
        time_list[idx:idx+n_t] = t.astype("datetime64[ns]")

    idx += n_t

# Files with fewer steps than the first one leave the tail of the np.empty
# buffers uninitialised; drop those rows so no garbage reaches the output.
if idx < ntot:
    combined = combined[:idx]
    time_list = time_list[:idx]

print(f"Combined shape: {combined.shape}")  # (ntot, nlat, nlon)

# ========== 5. Tidy the time axis into a pandas.DatetimeIndex ==========
# Make sure time is ascending (it usually already is; this is just to be
# safe). A stable sort keeps file order for any duplicated timestamps.
sort_idx = np.argsort(time_list, kind="stable")
time_index = pd.to_datetime(time_list)[sort_idx]
combined = combined[sort_idx]

print(f"Time index length: {len(time_index)}  from {time_index[0]} to {time_index[-1]}")

# ========== 6. Flatten to cell x time (same layout as the original R code) ==========
ntot, nlat, nlon = combined.shape
ncell = nlat * nlon

# cell x time
y_all = combined.reshape(ntot, ncell).T   # (ncell, ntot)

# lon/lat pair for each cell, in the same row-major (lat, lon) order that
# the reshape above uses for the cells
lon_grid, lat_grid = np.meshgrid(lon, lat)
gg = np.vstack([lon_grid.ravel(), lat_grid.ravel()]).T   # (ncell, 2)

print(f"y_all shape: {y_all.shape}  (cells x time)")
print(f"gg shape: {gg.shape}        (cells x 2)")

FileNotFoundError: Data folder not found: nc4