# Global-scale atmospheric mass and energy conservations on hybrid sigma-pressure level data

In [1]:
import torch
import numpy as np
import xarray as xr

In [2]:
from typing import Dict, Any, Optional

### Load data

In [3]:
base_dir = '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_1deg/'
filename = base_dir + 'all_in_one/ERA5_mlevel_1deg_6h_subset_1980_conserve.zarr'

ds_surf = xr.open_zarr(filename)
ds_accum = xr.open_zarr(filename)
ds_upper = xr.open_zarr(filename)
ds_static = xr.open_zarr(base_dir + 'static/ERA5_mlevel_1deg_static_subset.zarr')

In [4]:
x = ds_surf['longitude']
y = ds_surf['latitude']

lon, lat = np.meshgrid(x, y)

In [5]:
mlevel = ds_upper['level'].values

In [6]:
mlevel[:31]

array([  1.,   9.,  19.,  29.,  39.,  49.,  59.,  69.,  79.,  89.,  97.,
       104., 111., 116., 122., 126., 131., 136.], dtype=float32)

In [7]:
coef_a = ds_static['coef_a'].values
coef_b = ds_static['coef_b'].values

tensor_shape = (len(coef_b)-1,) + lon.shape

### Convert data to `torch.Tensor`

In [8]:
batch_size = 16
target_shape_4D = (batch_size, 2)+tensor_shape
target_shape_3D = (batch_size, 2,)+tensor_shape[1:]

t_slice = np.arange(batch_size+1)

In [9]:
# (batch, time, level, lat, lon) version
def time_series_to_batch(q, target_shape):
    q_batch = torch.as_strided(
        q, size=target_shape, 
        stride=(q.stride(0), q.stride(0), *q.stride()[1:]))
    return q_batch

In [10]:
sp = torch.from_numpy(np.array(ds_surf['SP'].isel(time=t_slice))) # Pa

q = torch.from_numpy(np.array(ds_upper['specific_total_water'].isel(time=t_slice))) # kg/kg
T = torch.from_numpy(np.array(ds_upper['temperature'].isel(time=t_slice))) # kg/kg
u = torch.from_numpy(np.array(ds_upper['u_component_of_wind'].isel(time=t_slice))) # m/s
v = torch.from_numpy(np.array(ds_upper['v_component_of_wind'].isel(time=t_slice)))
precip = torch.from_numpy(np.array(ds_accum['total_precipitation'].isel(time=t_slice)))
evapor = torch.from_numpy(np.array(ds_accum['evaporation'].isel(time=t_slice)))

GPH_surf = torch.from_numpy(np.array(ds_static['geopotential_at_surface'])) # J/m2
TOA_net = torch.from_numpy(np.array(ds_accum['top_net_solar_radiation'].isel(time=t_slice))) # J/m2
OLR = torch.from_numpy(np.array(ds_accum['top_net_thermal_radiation'].isel(time=t_slice))) # J/m2
R_short = torch.from_numpy(np.array(ds_accum['surface_net_solar_radiation'].isel(time=t_slice))) # J/m2
R_long = torch.from_numpy(np.array(ds_accum['surface_net_thermal_radiation'].isel(time=t_slice))) # J/m2
LH = torch.from_numpy(np.array(ds_accum['surface_latent_heat_flux'].isel(time=t_slice))) # J/m2
SH = torch.from_numpy(np.array(ds_accum['surface_sensible_heat_flux'].isel(time=t_slice))) # J/m2

In [11]:
sp_batch = time_series_to_batch(sp, target_shape_3D)

q_batch = time_series_to_batch(q, target_shape_4D).permute(0, 2, 1, 3, 4)
T_batch = time_series_to_batch(T, target_shape_4D).permute(0, 2, 1, 3, 4)
u_batch = time_series_to_batch(u, target_shape_4D).permute(0, 2, 1, 3, 4)
v_batch = time_series_to_batch(v, target_shape_4D).permute(0, 2, 1, 3, 4)
precip_batch = time_series_to_batch(precip, target_shape_3D)
evapor_batch = time_series_to_batch(evapor, target_shape_3D)

GPH_surf_batch = GPH_surf.unsqueeze(0).unsqueeze(0).unsqueeze(0)
TOA_net_batch = time_series_to_batch(TOA_net, target_shape_3D)
OLR_batch = time_series_to_batch(OLR, target_shape_3D)
R_short_batch = time_series_to_batch(R_short, target_shape_3D)
R_long_batch = time_series_to_batch(R_long, target_shape_3D)
LH_batch = time_series_to_batch(LH, target_shape_3D)
SH_batch = time_series_to_batch(SH, target_shape_3D)

In [12]:
longitude = torch.from_numpy(lon)
latitude = torch.from_numpy(lat)
coef_a = torch.from_numpy(coef_a)
coef_b = torch.from_numpy(coef_b)

### `credit.physics_core` pressure level class

In [13]:
# Earth's radius
RAD_EARTH = 6371000 # m
RVGAS = 461.5 # J/kg/K
RDGAS = 287.05 # J/kg/K
GRAVITY = 9.80665 # m/s^2
RHO_WATER = 1000.0 # kg/m^3
LH_WATER = 2.501e6  # J/kg
LH_ICE = 333700 # J/kg
CP_DRY = 1004.64 # J/kg K
CP_VAPOR = 1810.0 # J/kg K
CP_LIQUID = 4188.0 # J/kg K
CP_ICE = 2117.27 # J/kg K

In [14]:
class physics_hybrid_sigma_level:
    '''
    Hybrid sigma-pressure level physics

    Attributes:
        lon (torch.Tensor): Longitude in degrees.
        lat (torch.Tensor): Latitude in degrees.
        surface_pressure (torch.Tensor): Surface pressure in Pa.
        coef_a (torch.Tensor): Hybrid sigma-pressure coefficient 'a' [Pa].
        coef_b (torch.Tensor): Hybrid sigma-pressure coefficient 'b' [unitless].
        pressure (torch.Tensor): Pressure at each hybrid sigma level [Pa].
        pressure_thickness (torch.Tensor): Pressure thickness between levels [Pa].
        area (torch.Tensor): Area of grid cells [m^2].
        integral (function): Vertical integration method (midpoint or trapezoidal).
    '''

    def __init__(self,
                 lon: torch.Tensor,
                 lat: torch.Tensor,
                 coef_a: torch.Tensor,
                 coef_b: torch.Tensor,
                 midpoint: bool = False):
        '''
        Initialize the class with longitude, latitude, and hybrid sigma-pressure levels.

        All inputs must be on the same torch device.

        Full order of dimensions: (batch, level, time, latitude, longitude)

        Args:
            lon (torch.Tensor): Longitude in degrees.
            lat (torch.Tensor): Latitude in degrees.
            coef_a (torch.Tensor): Hybrid sigma-pressure coefficient 'a' [Pa] (level,).
            coef_b (torch.Tensor): Hybrid sigma-pressure coefficient 'b' [unitless] (level,).
            midpoint (bool): True if vertical level quantities are midpoint values; otherwise False.

        Note:
            pressure = coef_a + coef_b * surface_pressure
        '''
        self.lon = lon
        self.lat = lat
        self.coef_a = coef_a  # (level,)
        self.coef_b = coef_b  # (level,)

        # ========================================================================= #
        # Compute pressure on each hybrid sigma level
        # Reshape coef_a and coef_b for broadcasting
        self.coef_a = coef_a.view(1, -1, 1, 1, 1)  # (1, level, 1, 1, 1)
        self.coef_b = coef_b.view(1, -1, 1, 1, 1)  # (1, level, 1, 1, 1)
        
        # ========================================================================= #
        # compute gtid area
        # area = R^2 * d_sin(lat) * d_lon
        lat_rad = torch.deg2rad(self.lat)
        lon_rad = torch.deg2rad(self.lon)
        sin_lat_rad = torch.sin(lat_rad)
        d_phi = torch.gradient(sin_lat_rad, dim=0, edge_order=2)[0]
        d_lambda = torch.gradient(lon_rad, dim=1, edge_order=2)[0]
        d_lambda = (d_lambda + torch.pi) % (2 * torch.pi) - torch.pi
        self.area = torch.abs(RAD_EARTH**2 * d_phi * d_lambda)

        # ========================================================================== #
        # Vertical integration method
        if midpoint:
            self.integral = self.pressure_integral_midpoint
            self.integral_sliced = self.pressure_integral_midpoint_sliced
        else:
            self.integral = self.pressure_integral_trapz
            self.integral_sliced = self.pressure_integral_trapz_sliced

    def pressure_integral_midpoint(self, 
                                   q_mid: torch.Tensor,
                                   surface_pressure: torch.Tensor,) -> torch.Tensor:
        '''
        Compute the pressure level integral of a given quantity; assuming its mid-point
        values are pre-computed.

        Args:
            q_mid: The quantity with dims of (batch, level-1, time, latitude, longitude)
            surface_pressure: Surface pressure in Pa (batch, time, latitude, longitude).

        Returns:
            Pressure level integrals of q
        '''
        surface_pressure = surface_pressure.unsqueeze(1)  # (batch, 1, time, lat, lon)
        pressure = self.coef_a + self.coef_b * surface_pressure  # (batch, level, time, lat, lon)        
        delta_p = pressure.diff(dim=1).to(q_mid.device)  # (batch, level-1, time, lat, lon)
        q_area = q_mid * delta_p  # Element-wise multiplication
        q_integral = torch.sum(q_area, dim=1)  # Sum over level dimension
        return q_integral
        
    def pressure_integral_midpoint_sliced(self,
                                          q_mid: torch.Tensor,
                                          surface_pressure: torch.Tensor,
                                          ind_start: int,
                                          ind_end: int) -> torch.Tensor:
        '''
        As in `pressure_integral_midpoint`, but supports pressure level indexing,
        so it can calculate integrals of a subset of levels.
        '''
        surface_pressure = surface_pressure.unsqueeze(1)  # (batch, 1, time, lat, lon)
        pressure = self.coef_a + self.coef_b * surface_pressure  # (batch, level, time, lat, lon)
        pressure_thickness = pressure.diff(dim=1)  # (batch, level-1, time, lat, lon)
        delta_p = pressure_thickness[:, ind_start:ind_end, :, :, :].to(q_mid.device)
        
        q_mid_sliced = q_mid[:, ind_start:ind_end, :, :, :]
        q_area = q_mid_sliced * delta_p
        q_integral = torch.sum(q_area, dim=1)
        return q_integral

    def pressure_integral_trapz(self, 
                                q: torch.Tensor,
                                surface_pressure: torch.Tensor) -> torch.Tensor:
        '''
        Compute the pressure level integral of a given quantity using the trapezoidal rule.

        Args:
            q: The quantity with dims of (batch, level, time, latitude, longitude)

        Returns:
            Pressure level integrals of q
        '''
        surface_pressure = surface_pressure.unsqueeze(1)  # (batch, 1, time, lat, lon)
        pressure = self.coef_a + self.coef_b * surface_pressure  # (batch, level, time, lat, lon)
        delta_p = pressure.diff(dim=1).to(q.device)  # (batch, level-1, time, lat, lon)
        q1 = q[:, :-1, :, :, :]  # (batch, level-1, time, lat, lon)
        q2 = q[:, 1:, :, :, :]   # (batch, level-1, time, lat, lon)
        q_area = 0.5 * (q1 + q2) * delta_p  # Trapezoidal rule
        q_trapz = torch.sum(q_area, dim=1)  # Sum over level dimension
        return q_trapz

    def pressure_integral_trapz_sliced(self,
                                       q: torch.Tensor,
                                       surface_pressure: torch.Tensor,
                                       ind_start: int,
                                       ind_end: int) -> torch.Tensor:
        '''
        As in `pressure_integral_trapz`, but supports pressure level indexing,
        so it can calculate integrals of a subset of levels.
        '''
        surface_pressure = surface_pressure.unsqueeze(1)  # (batch, 1, time, lat, lon)
        pressure = self.coef_a + self.coef_b * surface_pressure  # (batch, level, time, lat, lon)
        delta_p = pressure[:, ind_start:ind_end, :, :, :].diff(dim=1).to(q.device)
        
        q_slice = q[:, ind_start:ind_end, :, :, :]
        q1 = q_slice[:, :-1, :, :, :]
        q2 = q_slice[:, 1:, :, :, :]
        q_area = 0.5 * (q1 + q2) * delta_p
        q_trapz = torch.sum(q_area, dim=1)
        return q_trapz

    def weighted_sum(self,
                     q: torch.Tensor,
                     axis: Dict[tuple, None] = None, 
                     keepdims: bool = False) -> torch.Tensor:
        '''
        Compute the weighted sum of a given quantity for PyTorch tensors.
        
        Args:
            data: the quantity to be summed (PyTorch tensor)
            axis: dims to compute the sum (can be int or tuple of ints)
            keepdims: whether to keep the reduced dimensions or not
    
        Returns:
            Weighted sum (PyTorch tensor)
        '''
        q_w = q * self.area.to(q.device)
        q_sum = torch.sum(q_w, dim=axis, keepdim=keepdims)
        return q_sum

    def total_dry_air_mass(self, 
                           q: torch.Tensor,
                           surface_pressure: torch.Tensor) -> torch.Tensor:
        '''
        Compute the total mass of dry air over the entire globe [kg]
        '''
        mass_dry_per_area = self.integral(1-q, surface_pressure) / GRAVITY # kg/m^2
        # weighted sum on latitude and longitude dimensions
        mass_dry_sum = self.weighted_sum(mass_dry_per_area, axis=(-2, -1)) # kg
        
        return mass_dry_sum

    def total_column_water(self, 
                           q: torch.Tensor,
                           surface_pressure: torch.Tensor,) -> torch.Tensor:
        '''
        Compute total column water (TCW) per air column [kg/m2]
        '''
        TWC = self.integral(q, surface_pressure) / GRAVITY # kg/m^2
        
        return TWC


In [15]:
flag_midpoint = True
physics_core = physics_hybrid_sigma_level(longitude, latitude, coef_a, coef_b, midpoint=flag_midpoint)

## Conservation of total dry air mass

In [16]:
ind_fix = 7

if flag_midpoint:
    ind_fix_start = ind_fix
else:
    ind_fix_start = ind_fix-1

N_levels = len(coef_a)-1

In [17]:
q_batch_correct = q_batch.clone()
sp_batch_correct = sp_batch.clone()

mass_dry_per_area_hold = physics_core.integral_sliced(1-q_batch_correct, sp_batch_correct, 0, ind_fix) / GRAVITY
mass_dry_sum_hold = physics_core.weighted_sum(mass_dry_per_area_hold, axis=(-2, -1))

mass_dry_per_area_fix = physics_core.integral_sliced(1-q_batch_correct, sp_batch_correct, ind_fix_start, N_levels) / GRAVITY
mass_dry_sum_fix = physics_core.weighted_sum(mass_dry_per_area_fix, axis=(-2, -1))

mass_dry_sum = mass_dry_sum_hold + mass_dry_sum_fix

# ------------------------------------------------------------------------------ #
# check residual term
mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
# ------------------------------------------------------------------------------ #

mass_residual_on_fix = mass_dry_sum[:, 0] - mass_dry_sum_hold[:, 1]

# Compute the ratio
q_correct_ratio = mass_residual_on_fix / mass_dry_sum_fix[:, 1]
q_correct_ratio = q_correct_ratio.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
q_batch_correct[:, ind_fix-1:, 1, ...] = 1 - (1 - q_batch_correct[:, ind_fix-1:, 1, ...]) * q_correct_ratio

mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct, sp_batch_correct)

delta_coef_a = coef_a.diff()
delta_coef_b = coef_b.diff()

if flag_midpoint:
    p_dry_a = ((delta_coef_a.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4)) * (1 - q_batch_correct)).sum(1)
    p_dry_b = ((delta_coef_b.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4)) * (1 - q_batch_correct)).sum(1)
else:
    q_mid = (q_batch_correct[:, :-1, ...] + q_batch_correct[:, 1:, ...]) / 2
    p_dry_a = ((delta_coef_a.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4)) * (1 - q_mid)).sum(1)
    p_dry_b = ((delta_coef_b.unsqueeze(0).unsqueeze(2).unsqueeze(3).unsqueeze(4)) * (1 - q_mid)).sum(1)
    
mass_dry_a = (p_dry_a * physics_core.area.unsqueeze(0).unsqueeze(0)).sum((-2, -1)) / GRAVITY
mass_dry_b = (p_dry_b * sp_batch_correct * physics_core.area.unsqueeze(0).unsqueeze(0)).sum((-2, -1)) / GRAVITY

# sp correction ratio using t0 dry air mass and t1 moisture
sp_correct_ratio = (mass_dry_sum[:, 0, ...] - mass_dry_a[:, 1, ...]) / mass_dry_b[:, 1, ...]
sp_correct_ratio = sp_correct_ratio.unsqueeze(-1).unsqueeze(-1)
sp_batch_correct[:, 1, ...] = sp_correct_ratio * sp_batch_correct[:, 1, ...]

Residual to conserve the dry air mass [kg]: tensor([ 3.6834e+13,  3.0511e+14,  4.0132e+13, -3.3700e+14,  1.5943e+13,
         1.6493e+12,  1.9241e+13, -2.4519e+14,  1.0995e+12,  2.3530e+14,
         3.6284e+13, -1.1710e+14,  2.2540e+13,  4.0132e+13,  2.1440e+13,
        -2.1276e+14])


In [19]:
# ------------------------------------------------------------------------------ #
mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct, sp_batch_correct)
mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
# ------------------------------------------------------------------------------ #

Residual to conserve the dry air mass [kg]: tensor([ 0.0000e+00,  0.0000e+00, -5.4976e+11,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  1.0995e+12,  0.0000e+00,  5.4976e+11, -5.4976e+11,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4976e+11,  5.4976e+11,
         0.0000e+00])


In [20]:
(q_batch_correct - q_batch).abs().max()

tensor(7.3224e-05)

In [21]:
(sp_batch_correct - sp_batch).abs().max()

tensor(0.3359)

**Old**

In [20]:
q_batch_correct = q_batch.clone()

correction_cycle_num = 1 # iterative to handle numrical precision

for i in range(correction_cycle_num):
    mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct, sp_batch)
    
    # ------------------------------------------------------------------------------ #
    # check residual term
    mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
    print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
    # ------------------------------------------------------------------------------ #
    
    q_correct_ratio = mass_dry_sum[:, 0] / mass_dry_sum[:, 1]
    q_correct_ratio = q_correct_ratio.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    q_batch_correct[:, :, 1, ...] = 1 - (1 - q_batch_correct[:, :, 1, ...]) * q_correct_ratio

Residual to conserve the dry air mass [kg]: tensor([ 2.9687e+13,  3.1226e+14,  3.1336e+13, -3.3700e+14,  2.0341e+13,
         5.4976e+11,  1.5943e+13, -2.4959e+14,  9.3458e+12,  2.3859e+14,
         3.3535e+13, -1.1380e+14,  2.3090e+13,  4.2331e+13,  2.0341e+13,
        -2.1331e+14])


In [21]:
# ------------------------------------------------------------------------------ #
mass_dry_sum = physics_core.total_dry_air_mass(q_batch_correct, sp_batch)
mass_dry_res = mass_dry_sum[:, 1] - mass_dry_sum[:, 0]
print('Residual to conserve the dry air mass [kg]: {}'.format(mass_dry_res))
# ------------------------------------------------------------------------------ #

Residual to conserve the dry air mass [kg]: tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  5.4976e+11,
         0.0000e+00,  5.4976e+11,  5.4976e+11,  5.4976e+11,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00, -5.4976e+11,  0.0000e+00,
        -1.0995e+12])


In [22]:
(q_batch_correct - q_batch).max()

tensor(6.1125e-05)

## Conservation of moisture

In [23]:
N_seconds = 3600 * 6 # 6 hourly data

precip_batch_flux = precip_batch[:, 1, ...] * RHO_WATER / N_seconds # m/hour --> kg/m^2/s, positive
evapor_batch_flux = evapor_batch[:, 1, ...] * RHO_WATER / N_seconds # kg/m^2/s, negative

precip_batch_correct = precip_batch_flux.clone()

# pre-compute TWC
TWC = physics_core.total_column_water(q_batch_correct, sp_batch)
dTWC_dt = (TWC[:, 1, ...] - TWC[:, 0, ...]) / N_seconds # kg/m^2/s
TWC_sum = physics_core.weighted_sum(dTWC_dt, axis=(-2, -1)) # kg/s

# pre-compute evaporation
E_sum = physics_core.weighted_sum(evapor_batch_flux, axis=(-2, -1)) # kg/s

correction_cycle_num = 1

for i in range(correction_cycle_num):
    P_sum = physics_core.weighted_sum(precip_batch_correct, axis=(-2, -1)) # kg/s
    residual = -TWC_sum - E_sum - P_sum # kg/s

    # ------------------------------------------------------------------------------ #
    print('Residual to conserve moisture budge [kg/s]: {}'.format(residual))
    # ------------------------------------------------------------------------------ #
    
    # P_correct = P_sum + residual # kg/s
    P_correct_ratio = (P_sum + residual) / P_sum
    P_correct_ratio = P_correct_ratio.unsqueeze(-1).unsqueeze(-1)
    precip_batch_correct = precip_batch_correct * P_correct_ratio

Residual to conserve moisture budge [kg/s]: tensor([-1.5418e+09, -1.5883e+10, -2.3156e+09,  1.4763e+10, -1.1064e+09,
        -1.2916e+09, -1.8417e+09,  1.0411e+10, -9.3729e+08, -1.1939e+10,
        -2.1515e+09,  4.8960e+09, -1.1885e+09, -3.1335e+09, -1.5421e+09,
         7.1912e+09])


In [24]:
# ------------------------------------------------------------------------------ #
P_sum = physics_core.weighted_sum(precip_batch_correct, axis=(-2, -1)) # kg/s
residual = -TWC_sum - E_sum - P_sum # kg/s
print('Residual to conserve moisture budge [kg/s]: {}'.format(residual))
# ------------------------------------------------------------------------------ #

Residual to conserve moisture budge [kg/s]: tensor([    0.,     0.,  1024., -2048.,     0.,  1024.,     0.,     0., -1024.,
            0.,  1024.,     0.,  1024.,  1024.,  1024.,     0.])


In [25]:
(precip_batch_correct - precip_batch_flux).mean()

tensor(-7.0661e-07)

In [26]:
(precip_batch_correct - precip_batch_flux).max()

tensor(0.0022)

### Conservation of energy

In [27]:
N_seconds = 3600 * 6 # 6 hourly data

# C_p (batch, time, level, lat, lon)
C_p = (1 - q_batch_correct) * CP_DRY + q_batch_correct * CP_VAPOR
# kinetic energy (batch, time, level, lat, lon)
ken = 0.5 * (u_batch ** 2 + v_batch ** 2)

# initialize T_correct
T_batch_correct = T_batch.clone()

# layer-wise atmospheric energy, but without thermal energy 
# (batch, time, level, lat, lon)
E_qgk = LH_WATER * q_batch_correct + GPH_surf_batch + ken

# TOA net energy flux (batch, time, lat, lon)
R_T = (TOA_net_batch + OLR_batch) / N_seconds
R_T = R_T[:, 1, :, :]
# R_T global sum
R_T_sum = physics_core.weighted_sum(R_T, axis=(-2, -1))

# surface net energy flux (batch, time, lat, lon)
F_S = (R_short_batch + R_long_batch + LH_batch + SH_batch) / N_seconds
F_S = F_S[:, 1, :, :]  # Extract time index 1
# F_S global sum
F_S_sum = physics_core.weighted_sum(F_S, axis=(-2, -1))

correction_cycle_num = 1

for i in range(correction_cycle_num):

    # layer-wise atmospheric energy (sensible heat + others)
    #  (batch, time, level, lat, lon)
    E_level = C_p * T_batch_correct + E_qgk

    # total atmospheric energy (TE) of an air column
    # (batch, time, lat, lon)
    TE = physics_core.integral(E_level, sp_batch) / GRAVITY

    # ---------------------------------------------------------------------------- #
    # tendency of TE (batch, lat, lon)
    dTE_dt = (TE[:, 1, :, :] - TE[:, 0, :, :]) / N_seconds
    # global sum of TE tendency (batch,)
    dTE_sum = physics_core.weighted_sum(dTE_dt, axis=(1, 2), keepdims=False)
    # compute the residual (batch,)
    delta_dTE_sum = (R_T_sum - F_S_sum) - dTE_sum
    print('Residual to conserve energy budget [Watts]: {}'.format(delta_dTE_sum))
    print('Sources & sinks [Watts]: {}'.format(R_T_sum - F_S_sum))
    print('Tendency [Watts]: {}'.format(dTE_sum))
    # ---------------------------------------------------------------------------- #

    # TE at t0 and t1 (batch,)
    total_weighted_TE_t0 = physics_core.weighted_sum(TE[:, 0, :, :], axis=(-2, -1)) 
    total_weighted_TE_t1 = physics_core.weighted_sum(TE[:, 1, :, :], axis=(-2, -1))

    # calculate the correction ratio for E_t1 (batch,) --> (batch, 1, 1, 1)
    E_correct_ratio = (N_seconds * (R_T_sum - F_S_sum) + total_weighted_TE_t0) / total_weighted_TE_t1
    E_correct_ratio = E_correct_ratio.view(-1, 1, 1, 1)

    # Apply the correction to layer-wise atmospheric energy at t1
    # (batch, level, lat, lon)
    E_t1_correct = E_level[:, :, 1, :, :] * E_correct_ratio

    # barotropic correction of T at t1
    T_batch_correct[:, :, 1, :, :] = (E_t1_correct - E_qgk[:, :, 1, :, :]) / C_p[:, :, 1, :, :]

Residual to conserve energy budget [Watts]: tensor([-4.8474e+15, -5.8542e+16, -4.8065e+15,  4.1037e+16, -4.0271e+15,
        -9.7805e+15, -3.2388e+15,  2.5913e+16, -3.0752e+15, -4.0709e+16,
        -4.6738e+15,  1.4178e+16, -4.5057e+15, -1.7434e+16, -3.3635e+15,
         1.6267e+16])
Sources & sinks [Watts]: tensor([-8.8138e+15, -2.2061e+15, -3.2601e+15, -1.2018e+16, -9.7891e+15,
        -2.2000e+15, -2.4414e+15, -1.0756e+16, -8.4184e+15, -1.1962e+15,
        -1.5977e+15, -1.0163e+16, -7.5918e+15,  6.7430e+14,  1.1722e+15,
        -6.8434e+15])
Tendency [Watts]: tensor([-3.9664e+15,  5.6335e+16,  1.5464e+15, -5.3055e+16, -5.7620e+15,
         7.5804e+15,  7.9741e+14, -3.6669e+16, -5.3433e+15,  3.9513e+16,
         3.0761e+15, -2.4341e+16, -3.0860e+15,  1.8109e+16,  4.5357e+15,
        -2.3110e+16])


In [28]:
# ---------------------------------------------------------------------------- #
E_level = C_p * T_batch_correct + E_qgk
TE = physics_core.integral(E_level, sp_batch) / GRAVITY
dTE_dt = (TE[:, 1, :, :] - TE[:, 0, :, :]) / N_seconds
dTE_sum = physics_core.weighted_sum(dTE_dt, axis=(-2, -1), keepdims=False)
energy_residual = dTE_sum - (R_T_sum - F_S_sum)
print('Residual to conserve energy budget [Watts]: {}'.format(energy_residual))
# ---------------------------------------------------------------------------- #

Residual to conserve energy budget [Watts]: tensor([-2.9710e+12,  2.3767e+12, -4.2099e+12,  7.4958e+12,  3.9589e+12,
         5.6746e+12, -2.2218e+12, -1.0437e+12, -2.0213e+12, -1.4498e+12,
         5.2101e+12,  3.2760e+12,  1.5521e+12,  6.7621e+12,  3.4212e+11,
        -2.6532e+12])


In [41]:
(T_batch_correct - T_batch).mean()

tensor(-0.0077)

In [42]:
(T_batch_correct - T_batch).abs().max()

tensor(0.3331)