In [1]:
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pickle
import re

In [2]:
# Root directory of the per-gridcell trace files read by the cells below; it
# is searched recursively for files whose names start with "lon".
# NOTE(review): absolute, user-specific path — adjust before running elsewhere.
trace_dir = Path("/p/tmp/annabu/attrici_interpolation/output_corr/testarea_31/traces/tas")

In [3]:
def get_float_from_string(file_name):
    """Extract the single number embedded in a file or directory name.

    Parameters
    ----------
    file_name : str
        Name expected to contain exactly one integer or decimal number,
        e.g. ``"lat_52.25"`` or ``"lon-13.75"``. A ``+`` or ``-`` directly in
        front of the digits is interpreted as the sign of the number.

    Returns
    -------
    float
        The number found in ``file_name``.

    Raises
    ------
    ValueError
        If ``file_name`` contains no number or more than one number.
    """
    # Optional sign, optional "<digits>." prefix, then mandatory digits. At
    # most one decimal point is allowed per match, so every match can be
    # parsed by float().
    floats_in_string = re.findall(r"[-+]?(?:\d*\.)?\d+", file_name)
    if len(floats_in_string) != 1:
        raise ValueError(
            f"expected exactly one float in {file_name!r}, "
            f"found {len(floats_in_string)}"
        )
    return float(floats_in_string[0])

# Merge the single parameter files into one dataset that can be written to netCDF

In [4]:
# Build one single-gridcell dataset per trace file, then merge them all into
# one dataset covering every (lat, lon) found under trace_dir.
parameter_files = []
for trace_file in trace_dir.glob("**/lon*"):
    # The latitude is taken from the parent directory name, the longitude
    # from the file name.
    lat = get_float_from_string(trace_file.parent.name)
    lon = get_float_from_string(trace_file.name)
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only load trace files from a trusted source.
    with open(trace_file, "rb") as trace:
        free_params = pickle.load(trace)

    data_vars = []
    for key in free_params.keys():
        # Promote scalar parameters (0-d arrays, NumPy or Python scalars) to
        # length-1 vectors so that every parameter has a "d" dimension. This
        # replaces the former check against the text of a TypeError message,
        # which only recognised 0-d arrays.
        values = np.atleast_1d(free_params[key])
        d = np.arange(len(values))

        data_vars.append(
            xr.DataArray(
                dims=["lat", "lon", "d"],
                data=values.reshape((1, 1, -1)),
                coords={
                    "lat": ("lat", [lat]),
                    "lon": ("lon", [lon]),
                    "d": ("d", d),
                },
                name=key,
            )
        )
    # Parameters of different length are combined on the union of their "d"
    # coordinates (xr.merge joins with "outer" by default).
    parameter_files.append(xr.merge(data_vars))

merged_parameters = xr.merge(parameter_files)

In [7]:
# Interactive IPython help lookup for merged_parameters.to_netcdf; this is
# IPython syntax, not plain Python.
?merged_parameters.to_netcdf

# Split the merged dataset back into single parameter files

In [5]:
test_output = Path("test_output/tas")
test_output.mkdir(exist_ok=True, parents=True)

# Split the merged dataset back into one pickle file per gridcell, written to
# test_output/lat_<lat>/lon<lon>.
for i in range(len(merged_parameters.lat)):
    for j in range(len(merged_parameters.lon)):
        # Select the gridcell once and reuse it for every lookup below.
        gridcell = merged_parameters.isel(lat=i, lon=j)

        # If all "logp" values of a gridcell are NaN, no parameter file is
        # stored for it.
        if len(gridcell["logp"].dropna(dim="d")) == 0:
            continue

        # Drop the NaN entries along "d" and squeeze, so that length-1
        # parameters come back as 0-d arrays.
        # NOTE(review): this also drops genuine NaN values inside a parameter
        # and turns true length-1 vectors into scalars — confirm that neither
        # occurs in the traces.
        parameter_dict = {}
        for key in merged_parameters:
            parameter = gridcell[key].dropna(dim="d")
            parameter_dict[key] = parameter.to_numpy().squeeze()

        lat = gridcell.lat.item()
        lon = gridcell.lon.item()

        outdir = test_output / f"lat_{lat}"
        outdir.mkdir(exist_ok=True, parents=True)
        with open(outdir / f"lon{lon}", "wb") as trace:
            # pickle.dump returns None, so its result is not assigned (the old
            # assignment overwrote free_params with None).
            pickle.dump(parameter_dict, trace)


# Test if merging parameter files and writing them back into single parameter files is the identity function

In [6]:
def _load_params(path):
    """Return the object unpickled from the parameter file at ``path``.

    Raises ``FileNotFoundError`` if ``path`` does not exist.
    """
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only load parameter files from a trusted source.
    with open(path, "rb") as trace:
        return pickle.load(trace)


# Forward check: every parameter file under trace_dir must have an equal
# counterpart (same parent directory name and file name) under test_output.
for trace_file in trace_dir.glob("**/lon*"):
    params_from_model = _load_params(trace_file)
    params_from_merged_file = _load_params(
        test_output / trace_file.parent.name / trace_file.name
    )
    np.testing.assert_equal(params_from_model, params_from_merged_file)

# Backward check: every file written to test_output must have an equal
# counterpart under trace_dir, i.e. no extra files were produced.
for trace_file in test_output.glob("**/lon*"):
    params_from_model = _load_params(
        trace_dir / trace_file.parent.name / trace_file.name
    )
    params_from_merged_file = _load_params(trace_file)
    np.testing.assert_equal(params_from_model, params_from_merged_file)