In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd

In [3]:
import xarray as xr

In [4]:
from file_processing import get_netcdf_files, get_rhime_outs

In [5]:
from country_totals import get_country_trace, get_x_to_country_mat

In [40]:
from country_totals import get_xr_dummies, sparse_xr_dot, make_quantiles

In [6]:
#!pip install sparse

In [7]:
species = "sf6"

In [8]:
files = get_netcdf_files("/home/brendan/Documents/inversions/plotting/sf6_best")

In [9]:
outs = get_rhime_outs(files)

In [10]:
countries = xr.open_dataset("/home/brendan/Documents/inversions/openghg_inversions/countries/country_EUROPE.nc")

In [11]:
countries_ukmo = xr.open_dataset("/home/brendan/Documents/inversions/openghg_inversions/countries/country-ukmo_EUROPE.nc")

In [12]:
x_to_country_mats = [get_x_to_country_mat(countries, ds, sparse=True) for ds in outs]
x_to_country_mats_ukmo = [get_x_to_country_mat(countries_ukmo, ds, sparse=True) for ds in outs]

In [13]:
country_traces = [get_country_trace(countries, species, hbmcmc_outs=ds, x_to_country=mat) for ds, mat in zip(outs, x_to_country_mats)]

In [14]:
country_traces_ukmo = [get_country_trace(countries_ukmo, species, hbmcmc_outs=ds, x_to_country=mat) for ds, mat in zip(outs, x_to_country_mats_ukmo)]

In [15]:
country_traces = [trace.expand_dims({"time": [out_ds.Ytime.min().values]}) for trace, out_ds in zip(country_traces, outs)]
country_trace_ds = xr.concat(country_traces, dim="time")

In [16]:
country_traces_ukmo = [trace.expand_dims({"time": [out_ds.Ytime.min().values]}) for trace, out_ds in zip(country_traces_ukmo, outs)]
country_trace_ukmo_ds = xr.concat(country_traces_ukmo, dim="time")

# Getting country prior traces

In [17]:
# Passed to get_prior_samples below; presumably a lower bound on the model error.
# NOTE(review): 0.15 is a magic number — confirm it matches the value used for the
# inversion runs being post-processed (units/meaning defined in flux_output_format).
min_model_error = 0.15

In [18]:
from flux_output_format import get_prior_samples

In [19]:
# Prior samples for each inversion output; later cells read `idata.prior.x` with a
# "chain" dimension, so these are presumably arviz InferenceData objects — TODO confirm.
# NOTE(review): the logged "Sampling: [...]" lines suggest random prior draws and no seed
# is set here, so prior summaries may change between runs unless get_prior_samples seeds.
idatas = [get_prior_samples(ds, min_model_error) for ds in outs]

Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]
Sampling: [bc, sigma, x, ymod, ymodbc]


In [20]:
prior_x_traces = [idata.prior.x.isel(chain=0).rename({"x_dim_0": "basis_region"}) for idata in idatas]
country_prior_traces = [get_country_trace(countries, species, hbmcmc_outs=ds, x_to_country=mat, x_trace=xtrace) for ds, xtrace, mat in zip(outs, prior_x_traces, x_to_country_mats)]
country_prior_traces_ukmo = [get_country_trace(countries_ukmo, species, hbmcmc_outs=ds, x_to_country=mat, x_trace=xtrace) for ds, xtrace, mat in zip(outs, prior_x_traces, x_to_country_mats_ukmo)]

In [21]:
country_prior_traces = [trace.expand_dims({"time": [out_ds.Ytime.min().values]}) for trace, out_ds in zip(country_prior_traces, outs)]
country_prior_trace_ds = xr.concat(country_prior_traces, dim="time")

In [22]:
country_prior_traces_ukmo = [trace.expand_dims({"time": [out_ds.Ytime.min().values]}) for trace, out_ds in zip(country_prior_traces_ukmo, outs)]
country_prior_trace_ukmo_ds = xr.concat(country_prior_traces_ukmo, dim="time")

In [29]:
def summarise_country_traces(prior_trace_ds, posterior_trace_ds):
    """Merge mean and quantile summaries of prior and posterior country-total traces.

    Args:
        prior_trace_ds: prior country-total trace DataArray with a "draw" sample dimension.
        posterior_trace_ds: posterior country-total trace DataArray with a "steps" sample
            dimension (its quantiles use make_quantiles' default sample dimension).

    Returns:
        A Dataset with variables countryapriori, pcountryapriori, countryapost, pcountryapost.
    """
    return xr.merge([
        prior_trace_ds.mean("draw").rename("countryapriori"),
        make_quantiles(prior_trace_ds, sample_dim="draw").rename("pcountryapriori"),
        posterior_trace_ds.mean("steps").rename("countryapost"),
        make_quantiles(posterior_trace_ds).rename("pcountryapost"),
    ])


# One summary dataset per country mask.
country_merged_ds = summarise_country_traces(country_prior_trace_ds, country_trace_ds)
country_ukmo_merged_ds = summarise_country_traces(country_prior_trace_ukmo_ds, country_trace_ukmo_ds)

# Process flux

We can't use the same method as for the country totals (producing full flux traces), because the traces would be massive: 1000 samples over the EUROPE domain would take up several gigabytes.

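Since the scaling factors `x` are constant on each basis region, and the flux in each grid cell is the prior flux multiplied by its region's scaling factor, we can compute means and quantiles of `x` per basis region first and only then map them to the original lat/lon domain. For the quantiles this relies on the prior flux being non-negative, so that scaling preserves the ordering of the samples.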

In [24]:
basis_mats = [get_xr_dummies(ds.basisfunctions, cat_dim="basis_region") for ds in outs]

In [31]:
fluxes = [ds.fluxapriori for ds in outs]
x_traces = [ds.xtrace.rename({"nparam": "basis_region"}) for ds in outs]

In [32]:
x_means = [trace.mean("steps") for trace in x_traces]
x_quantiles = [make_quantiles(trace) for trace in x_traces]

In [34]:
flux_means = [sparse_xr_dot(flux * mat, mean) for flux, mean, mat in zip(fluxes, x_means, basis_mats)] 

In [36]:
flux_quantiles = [sparse_xr_dot(flux * mat, quantiles) for flux, quantiles, mat in zip(fluxes, x_quantiles, basis_mats)] 

In [37]:
prior_x_means = [trace.mean("draw") for trace in prior_x_traces]
prior_x_quantiles = [make_quantiles(trace, sample_dim="draw") for trace in prior_x_traces]

In [38]:
prior_flux_means = [sparse_xr_dot(flux * mat, mean) for flux, mean, mat in zip(fluxes, prior_x_means, basis_mats)] 

In [39]:
prior_flux_quantiles = [sparse_xr_dot(flux * mat, quantiles) for flux, quantiles, mat in zip(fluxes, prior_x_quantiles, basis_mats)] 

In [41]:
times = [ds.Ytime.min().values for ds in outs]

In [45]:
prior_flux_mean_ds = xr.concat([da.expand_dims({"time": [time]}) for da, time in zip(prior_flux_means, times)], dim="time")
flux_mean_ds = xr.concat([da.expand_dims({"time": [time]}) for da, time in zip(flux_means, times)], dim="time")

In [46]:
prior_x_quantile_ds = xr.concat([da.expand_dims({"time": [time]}) for da, time in zip(prior_flux_quantiles, times)], dim="time")
flux_quantile_ds = xr.concat([da.expand_dims({"time": [time]}) for da, time in zip(flux_quantiles, times)], dim="time")

In [47]:
flux_merged_ds = xr.merge([
    prior_flux_mean_ds.rename("fluxapriori"),
    prior_flux_quantile_ds.rename("pfluxapriori"),
    flux_mean_ds.rename("fluxapost"),
    flux_quantile_ds.rename("pfluxapot")
])

# Merge all

In [51]:
paris_countries = ["IRELAND", "UNITED KINGDOM OF GREAT BRITAIN AND NORTHERN IRELAND", "FRANCE", 
             "BELGIUM", "NETHERLANDS", "GERMANY", "DENMARK", "SWITZERLAND", "AUSTRIA", "ITALY", 
             "CZECHIA", "POLAND", "HUNGARY", "SLOVAKIA", "NORWAY", "SWEDEN", "FINLAND"]

In [54]:
country_filt = country_merged_ds.country.isin(paris_countries)

In [59]:
country_ukmo_filt = country_ukmo_merged_ds.country.isin(["BENELUX", "RestEU", "SpaPor"])

In [55]:
country_merged_ds.where(country_filt, drop=True)

In [62]:
# Combine the selected standard-mask countries with the selected UKMO regions.
# Concatenate along whichever dimension the `country` coordinate lives on. If that is
# "ncountries" this is identical to hard-coding it. If `country` is its own dimension
# coordinate, concatenating along a new "ncountries" dimension would instead outer-join the
# two country lists and NaN-pad every variable.
country_dim = country_merged_ds.country.dims[0]
country_final_ds = xr.concat(
    [
        country_merged_ds.where(country_filt, drop=True),
        country_ukmo_merged_ds.where(country_ukmo_filt, drop=True),
    ],
    dim=country_dim,
)

In [63]:
# Final output: grid-level flux summaries and selected country-total summaries in one Dataset.
rhime_emissions = xr.merge([flux_merged_ds, country_final_ds])

In [64]:
rhime_emissions