In [13]:
import fsspec
import json
import os
import matplotlib.pyplot as plt
import fv3fit
from vcm import DerivedMapping
from vcm.catalog import catalog
import numpy as np

from loaders.mappers import open_nudge_to_fine

import pandas as pd
import xarray as xr
from vcm import parse_datetime_from_str
import vcm

import pandas as pd


In [5]:
runs = {
    "Base": "gs://vcm-ml-experiments/n2f-pire-stable-ml/2022-08-24/ablation-b22-data-sampling-higher-lr-no-lat-seed-0/",
    "Base + latitude": "gs://vcm-ml-experiments/n2f-pire-stable-ml/2022-08-18/ablation-b22-data-sampling-higher-lr-seed-0/",
    "Base + latitude + improved sampling": "gs://vcm-ml-experiments/n2f-pire-stable-ml/2022-08-16/ablation-higher-lr-seed-0/",
    "Base + latitude + lower LR": "gs://vcm-ml-experiments/n2f-pire-stable-ml/2022-08-16/ablation-b22-data-sampling-seed-0",
    "Base + latitude + improved sampling + lower LR": "gs://vcm-ml-experiments/n2f-pire-stable-ml/2022-05-17/tapering-effect-mae-no-taper-ensemble/"
}

In [6]:

metrics = {}
for step, rundir in runs.items():
    with fsspec.open(os.path.join(rundir, "offline_diags", "tq_tendencies", "scalar_metrics.json"), "r") as f:
        metrics[step] = json.load(f)

In [7]:
names = list(metrics.keys())
dQ1_col_int_R2 = []
dQ2_col_int_R2 = []
polar_dQ1_bias = []

In [22]:
for step, step_metrics in metrics.items():
    print(f"{step}: {step_metrics['column_integrated_dq1_r2_2d_global']}")
    dQ1_col_int_R2 .append(step_metrics['column_integrated_dq1_r2_2d_global']['mean'])

print()

for step, step_metrics in metrics.items():
    print(f"{step}: {step_metrics['column_integrated_dq2_r2_2d_global']}")
    dQ2_col_int_R2.append(step_metrics['column_integrated_dq2_r2_2d_global']['mean'])

Base: {'mean': 0.17367746414761803, 'std': 0.006323067113701896}
Base + latitude: {'mean': 0.18145564902158468, 'std': 0.0074711910188758465}
Base + latitude + improved sampling: {'mean': 0.1773509485181309, 'std': 0.006837434967932587}
Base + latitude + lower LR: {'mean': 0.27824892628496994, 'std': 0.009523459102442856}
Base + latitude + improved sampling + lower LR: {'mean': 0.29144340847073585, 'std': 0.008304898439304131}

Base: {'mean': 0.1478540474154636, 'std': 0.006050290804313889}
Base + latitude: {'mean': 0.15604672217871, 'std': 0.005923853798441016}
Base + latitude + improved sampling: {'mean': 0.14702832450291373, 'std': 0.006719409487227927}
Base + latitude + lower LR: {'mean': 0.22176653248656003, 'std': 0.006255953148123605}
Base + latitude + improved sampling + lower LR: {'mean': 0.2263754181937519, 'std': 0.0071515851451324}


In [11]:

models = {}
for step, rundir in runs.items():
    models[step] = fv3fit.load(os.path.join(rundir, "trained_models", "tq_tendencies",) )




In [15]:
grid = catalog["grid/c48"].read()
train_mapper = open_nudge_to_fine(
    data_path = "gs://vcm-ml-experiments/n2f-pire-sfc-updates/2022-01-21/nudged-run/fv3gfs_run/",
    nudging_variables = ["air_temperature", "specific_humidity", "pressure_thickness_of_atmospheric_layer"],
    cache_size_mb=4000,
)


In [14]:
ranges = [
    ("20200119.000000","20200124.000000"),
    ("20200114.000000","20200129.000000")
]

In [17]:
full_model = models["Base"]


In [18]:
for i, (tmin, tmax) in enumerate(ranges):
    keys = [key for key in train_mapper.keys() if ((key>tmin) and (key<tmax))]
    time_coords = [parse_datetime_from_str(key) for key in keys]
    ds = xr.concat([train_mapper[key] for key in keys], pd.Index(time_coords, name="time"))
    ds_input_N = DerivedMapping(ds.merge(grid)).dataset(full_model.input_variables + full_model.output_variables +  [ "pressure_thickness_of_atmospheric_layer",]) \
        .where(grid.lat > 60)
    ds_input_S = DerivedMapping(ds.merge(grid)).dataset(full_model.input_variables + full_model.output_variables +  [ "pressure_thickness_of_atmospheric_layer",]) \
        .where(grid.lat < -60)
    ds_input_N_stacked = ds_input_N.stack(sample=["x", "y", "tile", "time"]) \
        .dropna('sample') \
        .transpose("sample", ...)
    ds_input_S_stacked = ds_input_S.stack(sample=["x", "y", "tile", "time"]) \
        .dropna('sample') \
        .transpose("sample", ...)
    ds_input_N_stacked.reset_index('sample', drop=True).to_netcdf(f"temp/N_{i}.nc")
    ds_input_S_stacked.reset_index('sample', drop=True).to_netcdf(f"temp/S_{i}.nc")
    
    del ds_input_N_stacked, ds_input_S_stacked, ds_input_N, ds_input_S

TypeError: can only concatenate list (not "tuple") to list

In [19]:
data_files = [os.path.join("temp",file) for file in os.listdir("temp") if file.endswith(".nc") and file.startswith("N")]

sample_count = 0
pole_test_data = []
for file in data_files:
    ds_ = xr.open_dataset(file)
    ds_.assign_coords({"sample": range(sample_count, sample_count+len(ds_.sample))})
    sample_count += len(ds_.sample)
    pole_test_data .append(ds_)

pole_test_data = xr.concat(pole_test_data, dim="sample")

In [21]:

data_files = [os.path.join("temp",file) for file in os.listdir("temp") if file.endswith(".nc") ]

sample_count = 0
pole_test_data = []
for file in data_files:
    ds_ = xr.open_dataset(file)
    ds_.assign_coords({"sample": range(sample_count, sample_count+len(ds_.sample))})
    sample_count += len(ds_.sample)
    pole_test_data .append(ds_)

pole_test_data = xr.concat(pole_test_data, dim="sample")
predictions = {
    step: step_model.predict(pole_test_data) for step, step_model in models.items()
}
biases = {
    step: step_predictions[["dQ1", "dQ2"]]-pole_test_data[["dQ1", "dQ2"]]
    for step, step_predictions in predictions.items()
}
interp_dQ1_bias = {}
for step, bias in biases.items():
    bias_dQ1_plev = vcm.interpolate_to_pressure_levels(
        bias['dQ1'],
        pole_test_data["pressure_thickness_of_atmospheric_layer"],
        dim="z",
        #levels=np.array([20000.,]),
    )
    interp_dQ1_bias[step] = bias_dQ1_plev
    #print(step, bias_dQ1_200hPa.sel(pressure=20000).values.mean())
    
    print(step, bias_dQ1_plev.sel(pressure=slice(15000, 40000)).mean("pressure").values.mean())
    polar_dQ1_bias.append( bias_dQ1_plev.sel(pressure=slice(15000, 40000)).mean("pressure").values.mean())

Base 1.1158522423034133e-06
Base + latitude 8.105506845318954e-07
Base + latitude + improved sampling 6.992832160864569e-07
Base + latitude + lower LR -1.7640969219544372e-07
Base + latitude + improved sampling + lower LR 2.1738980763588844e-07


In [23]:
df = pd.DataFrame(
    list(zip(names, dQ1_col_int_R2, dQ2_col_int_R2, polar_dQ1_bias)),
    columns =[
        'Name', 
        'Column-integrated dQ1 $R^2$', 
        'Column-integrated dQ2 $R^2$', 
        '150-400 hPa mean dQ1 bias ($\lvert \mathrm{lat} \rvert > 60^{\circ}$, days 0-10) [K/s]'])

In [24]:
df.to_latex('sensitivity_table.tex', formatters={'cost':'${:,.2f}'.format})


In [25]:
df

Unnamed: 0,Name,Column-integrated dQ1 $R^2$,Column-integrated dQ2 $R^2$,"150-400 hPa mean dQ1 bias ($\lvert \mathrm{lat} \rvert > 60^{\circ}$, days 0-10) [K/s]"
0,Base,0.173677,0.147854,1.115852e-06
1,Base + latitude,0.181456,0.156047,8.105507e-07
2,Base + latitude + improved sampling,0.177351,0.147028,6.992832e-07
3,Base + latitude + lower LR,0.278249,0.221767,-1.764097e-07
4,Base + latitude + improved sampling + lower LR,0.291443,0.226375,2.173898e-07
