# Idealized Synthetic Data

*Under development*

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
from IPython.display import display  # so can run as script too

from melodies_monet import driver

In [None]:
# Create the MELODIES MONET analysis driver object and point it at this
# example's YAML control file.
an = driver.analysis()
an.control = "control_idealized.yaml"
# Read the control file (presumably this populates `an.control_dict`,
# which later cells use -- confirm against the driver).
an.read_control()
# Bare expression on the last line so the notebook displays the object.
an

````{admonition} Note: This is the complete file that was loaded.
:class: dropdown

```{literalinclude} control_idealized.yaml
:caption:
:linenos:
```
````

## Generate data

### Model

In [None]:
# Build a synthetic "model" dataset: a Gaussian blob that drifts in x and
# widens over time, stored once per variable listed in the control file.

# Seeded generator so the synthetic data are reproducible between runs.
rs = np.random.RandomState(42)

control = an.control_dict

# Grid size
nlat = 100
nlon = 200

# Regular lat/lon grid
lon = np.linspace(-161, -60, nlon)
lat = np.linspace(18, 60, nlat)
Lon, Lat = np.meshgrid(lon, lat)

# 3-hourly time axis spanning the analysis window from the control file.
# Lowercase "h" is used because the uppercase "H" alias is deprecated
# since pandas 2.2 and emits a FutureWarning.
time = pd.date_range(control['analysis']['start_time'], control['analysis']['end_time'], freq="3h")
ntime = time.size

# Generate translating and expanding Gaussian on a normalized [-1, 1] grid.
# `mu` (center in x) and `sigma` (width) vary linearly with time; broadcasting
# the (ntime, 1, 1) parameters against the (nlat, nlon) grid gives `g` the
# shape (ntime, nlat, nlon).
x_ = np.linspace(-1, 1, lon.size)
y_ = np.linspace(-1, 1, lat.size)
x, y = np.meshgrid(x_, y_)
mu = np.linspace(-0.5, 0.5, ntime)
sigma = np.linspace(0.3, 1, ntime)
g = np.exp(
    -(
        (
            (x[np.newaxis, ...] - mu[:, np.newaxis, np.newaxis])**2
            + y[np.newaxis, ...]**2
        ) / (
            2 * sigma[:, np.newaxis, np.newaxis]**2
        )
    )
)

# Coordinates
lat_da = xr.DataArray(lat, dims="lat", attrs={'longname': 'latitude', 'units': 'degN'}, name="lat")
lon_da = xr.DataArray(lon, dims="lon", attrs={'longname': 'longitude', 'units': 'degE'}, name="lon")
time_da = xr.DataArray(time, dims="time", name="time")

# Generate dataset: one DataArray per variable named in the control file.
field_names = control['model']['test_model']['variables'].keys()
ds_dict = {}
for field_name in field_names:
    units = control['model']['test_model']['variables'][field_name]['units']
    # Every variable shares the same Gaussian field; swap in
    # `rs.rand(ntime, nlat, nlon)` here for uncorrelated noise instead.
    data = g
    da = xr.DataArray(
        data,
        coords=[time_da, lat_da, lon_da],
        dims=['time', 'lat', 'lon'],
        attrs={'units': units},
    )
    ds_dict[field_name] = da

# Insert a singleton vertical dimension "z" after "time" and label its level.
ds = xr.Dataset(ds_dict).expand_dims("z", axis=1)
ds["z"] = [1]

# Keep a handle on the model dataset; `ds` is reused for the obs dataset later.
ds_mod = ds
ds_mod

In [None]:
# Drop the singleton "z" dimension and plot variable A as a grid of 2-D
# panels, one per time step (5 panels per row). The trailing semicolon
# suppresses the notebook's text output for the returned plot object.
ds.squeeze("z").A.plot(col="time", col_wrap=5, size=3);

In [None]:
# Write the synthetic model dataset to the path the control file gives under
# model -> test_model -> files (presumably what `an.open_models()` reads later).
ds.to_netcdf(control['model']['test_model']['files'])

### Obs

In [None]:
# Build a synthetic "obs" dataset: sample the model field at random point
# locations, add Gaussian noise, and store it with (time, y, x) dims where
# "y" is a singleton and "x" indexes the sites.
from string import ascii_lowercase

# Generate positions
# TODO: only within land boundaries would be cleaner
n = 500
lats = rs.uniform(lat[0], lat[-1], n)
lons = rs.uniform(lon[0], lon[-1], n)
# Zero-padded site IDs ("000", "001", ...), padded to the number of digits in `n`.
siteid = np.array([f"{x:0{len(str(n))}}" for x in range(n)])[np.newaxis, :]
# A per-site string of random lowercase letters with random length 2-7,
# i.e. a variable-length string field.
something_vlen = np.array(
    [
        "".join(rs.choice(list(ascii_lowercase), size=x, replace=True))
        for x in rs.randint(low=2, high=8, size=n)
    ],
    dtype=str,
)[np.newaxis, :]

# Generate dataset
field_names = control['model']['test_model']['variables'].keys()
ds_dict = dict()
for field_name0 in field_names:
    # Obs variable name that the control file maps this model variable to.
    field_name = control['model']['test_model']['mapping']['test_obs'][field_name0]
    units = control['model']['test_model']['variables'][field_name0]['units']
    # Sample the matching model variable (not a hard-coded one) at the site
    # locations. Both indexers are 1-D DataArrays sharing the same default
    # dimension, so this is pointwise interpolation giving shape (ntime, n).
    # Noise is then added and a singleton "y" axis inserted -> (ntime, 1, n).
    values = (
        ds_mod[field_name0].squeeze().interp(lat=xr.DataArray(lats), lon=xr.DataArray(lons)).values
        + rs.normal(scale=0.3, size=(ntime, n))
    )[:, np.newaxis]
    da = xr.DataArray(
        values,
        coords={
            # NOTE(review): the original flagged this explicit "x" index with
            # "!!!" -- presumably it is required downstream; confirm.
            "x": ("x", np.arange(n)),  # !!!
            "time": ("time", time),
            "latitude": (("y", "x"), lats[np.newaxis, :], lat_da.attrs),
            "longitude": (("y", "x"), lons[np.newaxis, :], lon_da.attrs),
            "siteid": (("y", "x"), siteid),
            "something_vlen": (("y", "x"), something_vlen),
        },
        dims=("time", "y", "x"),
        attrs={'units': units},
    )
    ds_dict[field_name] = da
ds = xr.Dataset(ds_dict)
ds

In [None]:
# Write the synthetic obs dataset to the path the control file gives under
# obs -> test_obs -> filename (presumably what `an.open_obs()` reads later).
ds.to_netcdf(control['obs']['test_obs']['filename'])

## Load

In [None]:
# Have the driver open the model data (presumably the file configured in the
# control file; results are accessed through `an.models`).
an.open_models()

In [None]:
# Display the data object held by the 'test_model' entry.
an.models['test_model'].obj

In [None]:
# Have the driver open the obs data (presumably the file configured in the
# control file; results are accessed through `an.obs`).
an.open_obs()

In [None]:
# Display the data object held by the 'test_obs' entry.
an.obs['test_obs'].obj

## Pair

In [None]:
%%time

# Pair the model and obs data (results are accessed through `an.paired`).
# The `%%time` cell magic above reports how long this cell takes; it must stay
# on the first line of the cell.
an.pair_data()

In [None]:
# Display the collection of pair entries held by the driver.
an.paired

In [None]:
# Display the data object for the 'test_obs_test_model' pair.
an.paired['test_obs_test_model'].obj

In [None]:
# Display the dimensions of the paired data object.
an.paired['test_obs_test_model'].obj.dims

## Plot

In [None]:
%%time

# Run the driver's plotting step (presumably producing the plots configured
# in the control file -- confirm). The `%%time` cell magic above reports how
# long this cell takes; it must stay on the first line of the cell.
an.plotting()

## Save/load paired data -- netCDF

Save the paired data to netCDF, read it back, and compare the result to the original pair object.

In [None]:
from copy import deepcopy

pair_key = 'test_obs_test_model'

# Snapshot the in-memory pair before it is replaced by the re-read version.
p0 = deepcopy(an.paired[pair_key].obj)

# Round trip: save the analysis, then read it back into the driver.
an.save_analysis()
an.read_analysis()

# Take an independent copy of the re-read pair, then close it.
p1 = deepcopy(an.paired[pair_key].obj)
p1.close()

for pair_obj in (p0, p1):
    display(pair_obj)

# The re-read pair must be a distinct object with equal contents.
assert p1 is not p0
assert p1.equals(p0)

## Save/load paired data -- Python object

In [None]:
pkl_path = "asdf.joblib"

# Reconfigure how paired data are saved: switch the method to "pkl".
print(an.save)
save_cfg = an.save["paired"]
save_cfg["method"] = "pkl"
# `prefix` is unused by this method and could be left in place, but
# `output_name` has to be set.
del save_cfg["prefix"]
save_cfg["output_name"] = pkl_path
print("->", an.save)

# Reconfigure the matching read settings to load the same file.
print()
print(an.read)
read_cfg = an.read["paired"]
read_cfg["method"] = "pkl"
read_cfg["filenames"] = pkl_path
print("->", an.read)

# Round trip with the new settings and take an independent copy of the result.
print()
an.save_analysis()
an.read_analysis()
p2 = deepcopy(an.paired['test_obs_test_model'].obj)
p2.close()

display(p2)

# The re-read pair must be a distinct object equal to the original `p0`.
assert p2 is not p0
assert p2.equals(p0)