In [1]:
import numpy as np

def modes_for_energy(S, energy_frac=0.95):
    """
    Compute the minimal number of singular modes needed to capture
    at least `energy_frac` (default 0.95) of the total energy of S.

    The energy of mode i is sigma_i**2 (the squared i-th singular value of S),
    so the cumulative energy ratio after k modes is
    sum_{i<k} sigma_i**2 / sum_i sigma_i**2.

    Parameters
    ----------
    S : array_like, shape (m, n)
        Matrix to analyse (e.g. a snapshot matrix, one snapshot per column).
    energy_frac : float, default 0.95
        Target fraction of the total energy; must lie in (0, 1].

    Returns
    -------
    k : int
        Smallest number of leading modes whose cumulative energy ratio is
        >= energy_frac, clamped to min(m, n).
    cum_energy : ndarray of float64, shape (min(m, n),)
        Cumulative energy ratio after each mode (last entry ~1.0).

    Raises
    ------
    ValueError
        If energy_frac is not in (0, 1], or if S has zero total energy
        (the ratio would be 0/0).
    """
    if not 0.0 < energy_frac <= 1.0:
        raise ValueError(f"energy_frac must be in (0, 1], got {energy_frac!r}")

    # Only the singular values are needed; skipping U and Vt avoids allocating
    # factors that can be as large as S itself.
    s = np.linalg.svd(S, compute_uv=False)
    # Accumulate in float64 so float32 inputs still reach ~1.0 at the end.
    energies = s.astype(np.float64) ** 2
    total = energies.sum()
    if total == 0.0:
        raise ValueError("S has zero total energy; energy fraction is undefined")
    cum_energy = np.cumsum(energies) / total

    # searchsorted (side='left') returns the first index i with
    # cum_energy[i] >= energy_frac; +1 turns that index into a mode count.
    # Clamp in case rounding leaves cum_energy[-1] just below energy_frac,
    # which would otherwise report one more mode than exists.
    k = int(np.searchsorted(cum_energy, energy_frac)) + 1
    return min(k, s.size), cum_energy

In [2]:
import pickle

import h5py  # NOTE(review): not used in the visible cells
import numpy as np
import xarray as xr

# ERA5 zarr store on the public WeatherBench2 GCS bucket
# (per the store name: 1959-2022, 1-hourly, 0.25 deg).
obs_path = (
    "gs://weatherbench2/datasets/era5/"
    "1959-2022-full_37-1h-0p25deg-chunk-1.zarr-v2/"
)
ds = xr.open_zarr(obs_path)

In [3]:
# Time windows passed to .sel(time=slice(start, end)) for each split.
splits = {
    # v1 windows:
    # "train": ['2019-01-01T00:00:00', '2019-02-11T00:00:00'],
    # "test":  ['2019-05-10T00:00:00', '2019-05-15T03:00:00'],
    # "val":   ['2019-07-20T00:00:00', '2019-07-30T03:00:00'],

    "train": ['2019-01-01T00:00:00', '2019-03-11T00:00:00'],
    "test":  ['2019-05-10T00:00:00', '2019-06-15T03:00:00'],
    "val":   ['2019-07-20T00:00:00', '2019-08-30T03:00:00'],

    # Small debug windows:
    # "train": ['2019-01-01T00:00:00', '2019-01-01T04:00:00'],
    # "test":  ['2019-05-10T00:00:00', '2019-05-10T03:00:00'],
    # "val":   ['2019-07-20T00:00:00', '2019-07-20T02:00:00'],
}

# Latitude is given north -> south (50 -> 10). The recorded (time, 161, 141)
# shapes show this selects the full 40 x 35 deg box at 0.25 deg, i.e. the
# store's latitude axis is descending.
lat_slice = slice(50., 10.)
lon_slice = slice(260., 295.)

# vars_list = ['10m_u_component_of_wind', '10m_v_component_of_wind']
vars_list = ['10m_u_component_of_wind']

# The snapshot matrices below keep only variable index 0, so any extra
# variable would be loaded and then silently discarded.
assert len(vars_list) == 1, (
    "only the first variable is used below; extend the snapshot layout before adding variables"
)

dataset = {}
for split, timerange in splits.items():
    sub = ds[vars_list].sel(
        latitude=lat_slice,
        longitude=lon_slice,
        time=slice(*timerange),
    )
    da = (
        sub.to_array(dim="variable")
        .transpose("time", "variable", "latitude", "longitude")
    )
    dataset[split] = {
        "data":     da.values,         # (T, V, Y, X)
        "times":    da.time.values,
        "lat_vals": da.latitude.values,
        "lon_vals": da.longitude.values,
    }


def _snapshot_matrix(frames):
    """Flatten (T, Y, X) frames into a (Y*X, T) matrix with one time step per column."""
    n_times, n_lat, n_lon = frames.shape
    return frames.reshape((n_times, n_lat * n_lon)).T


train_data = dataset['train']['data']
test_data = dataset['test']['data']
val_data = dataset['val']['data']

# (T, Y, X) fields of the single selected variable.
S_train_org = train_data[:, 0, :, :]
S_test_org = test_data[:, 0, :, :]
S_val_org = val_data[:, 0, :, :]

# (Y*X, T) snapshot matrices.
S_train = _snapshot_matrix(S_train_org)
S_test = _snapshot_matrix(S_test_org)
S_val = _snapshot_matrix(S_val_org)
# All splits come from the same lat/lon box, so their row counts must agree.
assert S_train.shape[0] == S_test.shape[0] == S_val.shape[0], "splits are on different grids"

# Min-max normalization (currently disabled).
# NOTE(review): the pickle read in later cells is named '..._normalized.pkl',
# so it was presumably written with this block enabled; the in-memory S_*
# arrays produced here are NOT normalized. Confirm before mixing the two.
# min_train_val, max_train_val = S_train.min(), S_train.max()

# S_train_normalized = (S_train - min_train_val)/ (max_train_val - min_train_val)
# S_val_normalized = (S_val - min_train_val)/ (max_train_val - min_train_val)
# S_test_normalized = (S_test - min_train_val)/ (max_train_val - min_train_val)

# One-step pairs: column j of Y_train is the snapshot right after column j of X_train.
X_train = S_train[:, :-1]
Y_train = S_train[:, 1:]

# POD basis of the input snapshots X_train.
U_vals, sing_vals, Vt_vals = np.linalg.svd(X_train, full_matrices=False)

train_vals = {'data': S_train, 'time': dataset['train']['times']}
test_vals = {'data': S_test, 'time': dataset['test']['times']}
val_vals = {'data': S_val, 'time': dataset['val']['times']}

# Number of modes capturing 95% of the energy, derived from sing_vals so that
# r95 describes the same basis as U_vals / sing_vals (and no second SVD of the
# full snapshot matrix is needed). Accumulated in float64; clamped so float
# rounding can never report more modes than exist.
energy = sing_vals.astype(np.float64) ** 2
cum_e = np.cumsum(energy) / energy.sum()
r95 = min(int(np.searchsorted(cum_e, 0.95)) + 1, sing_vals.size)

# u_10m_comp_vals_3 = {
#     'lat_vals': dataset['train']['lat_vals'],
#     'lon_vals': dataset['train']['lon_vals'],
#     'train': train_vals,
#     'test': test_vals,
#     'val': val_vals,
#     'U_vals': U_vals,
#     'sing_vals': sing_vals,
#     'Vt_vals': Vt_vals,
#     'r95': r95,
#     'min_train_val': min_train_val,
#     'max_train_val': max_train_val,
# }

# with open('../data/u_10m_comp_vals_3_normalized.pkl', 'wb') as f:
#     pickle.dump(u_10m_comp_vals_3, f)

In [4]:
# Raw (time, lat, lon) field shapes for the train / test / val splits.
print(*(frames.shape for frames in (S_train_org, S_test_org, S_val_org)))


(1657, 161, 141) (868, 161, 141) (988, 161, 141)


In [None]:
# Reload the cached dict and report each split's (n_grid_points, n_times) shape.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load pickles this project wrote itself.
with open('../data/u_10m_comp_vals_3_normalized.pkl', 'rb') as f:
    sample = pickle.load(f)

# v1 shapes:      train (22701, 985),  test (22701, 124), val (22701, 244)
# current shapes: train (22701, 1657), test (22701, 868), val (22701, 988)
for split_key in ('train', 'test', 'val'):
    print(sample[split_key]['data'].shape)

(22701, 1657)
(22701, 868)
(22701, 988)


In [7]:
# Second reload of the cached pickle (same file as the earlier reload cell).
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load pickles this project wrote itself.
with open('../data/u_10m_comp_vals_3_normalized.pkl', 'rb') as f:
    sample = pickle.load(f)

# Expected shapes match the recorded output (the v1 sizes 985/124/244 are obsolete).
print(sample['train']['data'].shape) # (22701, 1657)
print(sample['test']['data'].shape)  # (22701, 868)
print(sample['val']['data'].shape)   # (22701, 988)


(22701, 1657)
(22701, 868)
(22701, 988)


In [22]:
# Project the training inputs onto the leading POD modes.
#   l_val: number of leading singular triplets kept (U_l, Sig_l, Vt_l)
#   r_val: reduced dimension used for X_hat (must satisfy r_val <= l_val)
#   p_val: number of non-linearities to select from r_val x len(library)
import jax.numpy as jnp
l_val = 20
r_val = 8
p_val = 20 # number of non-linearities to select from r_val x len(library)

# Fail early with a readable message instead of a negative-shape error in
# jnp.zeros or a shape mismatch in the matmul below.
assert 0 < r_val <= l_val <= sing_vals.size, (
    f"need 0 < r_val ({r_val}) <= l_val ({l_val}) <= number of modes ({sing_vals.size})"
)

U_l = U_vals[:, : l_val]
Sig_l = sing_vals[: l_val]
Vt_l = Vt_vals[: l_val, :]
# Last recorded run: U_vals in [-0.0988, 0.0654], U_l in [-0.0422, 0.0434].

# phi_mat = [I_r; 0] of shape (l_val, r_val): it selects the first r_val of
# the l_val basis vectors, so U_r equals U_l[:, :r_val].
phi_mat = jnp.vstack([
    jnp.eye(r_val, dtype=jnp.float32),
    jnp.zeros((l_val - r_val, r_val), dtype=jnp.float32),
])

U_r = U_l @ phi_mat
# Reduced coordinates of the input snapshots, shape (r_val, number of X_train columns).
X_hat = U_r.T @ X_train
# Last recorded run: X_hat in [-76.06, 14.70] (data not normalized).




