# Example usage of data-driven-quadrature annealing optimization for spectral fluxes

### Neal Ma

This notebook shows an example of using the data-driven-quadrature package for estimating spectral fluxes. 

TODO:

- Data availability and attribution

In [1]:
# Import all necessary packages 
import numpy as np
import xarray as xr
import cvxpy as cp
import datadrivenquadrature as ddq

# Not a necessary package but helpful
import matplotlib.pyplot as plt

In [2]:
# Open the pyarts training file as an xarray dataset
x_file_path = './../data/pyarts_training_72000_all_levels.h5'
x = xr.open_mfdataset(
    x_file_path,
    combine='nested',
    concat_dim='column',
    engine="netcdf4",
)

# Materialize the reference fluxes as a NumPy array and collapse them to 1-D
reference_flux_grid = np.array(x['reference_fluxes'].data)
y_ref = reference_flux_grid.flatten()

In [3]:
# Show the dataset summary, the shape of the flattened reference fluxes,
# and the length of the last axis of spectral_fluxes
y_shape = np.asarray(y_ref).shape
spectral_axis_len = x['spectral_fluxes'].shape[-1]
print("x:\n", x)
print("y:\n", y_shape)
print("axis_len:\n", spectral_axis_len)

x:
 <xarray.Dataset>
Dimensions:            (half_level: 55, column: 50, spectral_coord: 72120,
                        level: 54)
Coordinates:
  * half_level         (half_level) int64 0 1 2 3 4 5 6 ... 48 49 50 51 52 53 54
  * level              (level) int64 0 1 2 3 4 5 6 7 ... 46 47 48 49 50 51 52 53
  * column             (column) int64 0 1 2 3 4 5 6 7 ... 43 44 45 46 47 48 49
  * spectral_coord     (spectral_coord) float32 0.0002 0.0202 ... 3.26e+03
Data variables:
    spectral_fluxes    (half_level, column, spectral_coord) float64 dask.array<chunksize=(55, 50, 72120), meta=np.ndarray>
    reference_fluxes   (half_level, column) float64 dask.array<chunksize=(55, 50), meta=np.ndarray>
    reference_heating  (column, level) float64 dask.array<chunksize=(50, 54), meta=np.ndarray>
    pressures          (column, half_level) float32 dask.array<chunksize=(50, 55), meta=np.ndarray>
y:
 (2750,)
axis_len:
 72120


In [4]:
# Calculate the scale to use for integration axis normalization.
# norm_vector[0][0] is the value used below to scale both the sampled
# spectral fluxes (in map_func) and the reference fluxes y_ref.
norm_vector = ddq.find_normalization_vector(x, ['spectral_coord'])

def user_cost_fnc(y, y_hat):
    """Return the cvxpy norm of the residual ``y - y_hat``.

    ``cp.norm`` is called with its default arguments.
    """
    return cp.norm(y - y_hat)

def map_func(x, point_set, x_sup=None):
    """Sample the spectral fluxes at the chosen integration points.

    Parameters
    ----------
    x : dataset exposing a 3-D ``spectral_fluxes`` variable whose last axis
        is the integration axis.
    point_set : sequence of points; the first element of each point is used
        as the index along the last axis of ``spectral_fluxes``.
    x_sup : unused in this implementation; kept so the signature is unchanged.

    Returns
    -------
    2-D array with one row per (first axis, second axis) pair and one column
    per point in ``point_set``, divided by ``norm_vector[0][0]``.
    """
    point_idxs = [point[0] for point in point_set]
    fluxes = np.asarray(x.spectral_fluxes[:, :, point_idxs].values)
    # Derive the output shape from the data instead of hard-coding
    # (55*50, 15), so other values of params['n_points'] and other dataset
    # sizes work as well.
    n_rows = fluxes.shape[0] * fluxes.shape[1]
    points = fluxes.reshape((n_rows, len(point_idxs))) / norm_vector[0][0]
    return points

In [5]:
# Settings for the ddq routines, gathered into a single dict
integration_list = ['spectral_coord']
params = {
    'integration_list': integration_list,
    'n_points': 15,
    'epochs': 5,
    'success': 5,
    'block_size': 10,
}
# Run ddq.check_params on the same arguments that are later given to ddq.optimize
ddq.check_params(x, y_ref / norm_vector[0][0], user_cost_fnc, map_func, params)

(2750, 15) (2750, 15)
True


0

In [6]:
# Run the optimization against the reference fluxes scaled by norm_vector[0][0]
history = ddq.optimize(
    x,
    y_ref / norm_vector[0][0],
    user_cost_fnc,
    map_func,
    params,
)

<class 'cvxpy.atoms.pnorm.Pnorm'>
INITIAL BLOCK: iteration 0
INITIAL BLOCK: iteration 1
INITIAL BLOCK: iteration 2
INITIAL BLOCK: iteration 3
INITIAL BLOCK: iteration 4
INITIAL BLOCK: iteration 5
INITIAL BLOCK: iteration 6
INITIAL BLOCK: iteration 7
INITIAL BLOCK: iteration 8
INITIAL BLOCK: iteration 9
1 0 [[36765], [6150], [51451], [36334], [25143], [43256], [56241], [44248], [35309], [9267], [51636], [54818], [8469], [32101], [43756]] 3.458477602247485
1 1 [[36765], [6150], [51451], [36334], [25143], [43256], [56241], [44248], [35309], [9267], [51636], [54818], [8469], [41535], [43756]] 3.4584776022474455
1 2 [[36765], [6150], [56911], [36334], [25143], [43256], [56241], [44248], [35309], [9267], [51636], [54818], [8469], [41535], [43756]] 3.45847760259763
1 3 [[36765], [6150], [56911], [36334], [25143], [43256], [56241], [58388], [35309], [9267], [51636], [54818], [8469], [41535], [43756]] 3.458477602246507
1 4 [[36765], [6150], [56911], [36334], [25143], [43211], [56241], [58388], 

In [None]:
# Plot cost history: one value per entry of history['cost'], taken as that
# entry's minimum (np.min) even though the list is named mean_costs
mean_costs = [np.min(cost) for cost in history['cost']]

plt.plot(np.arange(len(mean_costs)), mean_costs)
print(history.keys())
print(history['best'])


In [None]:
# Flatten the optimization history and plot log10 of the flattened cost.
# NOTE(review): flatten_history is not defined or imported in the visible
# cells — presumably it should be ddq.flatten_history; confirm, otherwise
# this cell raises NameError on a fresh kernel.
flat_history = flatten_history(history)
flat_cost = flat_history['cost']
plt.plot(range(len(flat_cost)), np.log10(flat_cost))