# Expand each sample .nc file with additional features, such as information from previous time steps

## Load modules, determine available cpus, create list of input files

In [18]:
import os
import glob
import xarray as xr
import numpy as np
import multiprocessing as mp
from climsim_adding_input_import import process_one_file, get_pressure_thickness, tropopause_profile_2d
from dask.distributed import Client, progress
from dask.diagnostics import ProgressBar
from glob import glob
from tqdm.auto import tqdm

import metpy.constants.nounit as metconstnondim
from numba import njit

In [19]:
# Ask the OS how many logical CPUs it reports; the dask cluster cell further
# down sizes its worker count from this value.
num_cpus = os.cpu_count()
print(f"Number of available CPUs: {num_cpus}")

Number of available CPUs: 96


In [32]:
# Alternative source for the file list (disabled): read it from a text file.
# with open('./my_rad_files.txt', 'r') as file:
#     rad_files_in = file.read().splitlines()

# Recursively collect every radiation (.rad.) file below the training
# directory and put the paths into lexicographic order.
rad_pattern = '/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/**/*.rad.*.nc'
rad_files_in = glob(rad_pattern, recursive=True)
rad_files_in.sort()

# Optional sub-selection for quick tests (disabled).
# rad_files_in = rad_files_in[3:7]

# The matching input (.mli.) path is obtained by swapping the tag in the
# path of each radiation file.
mli_files_in = sorted(path.replace('.rad.', '.mli.') for path in rad_files_in)

# Cell output: both list lengths plus the first four entries of each list.
len(rad_files_in), len(mli_files_in), rad_files_in[:4], mli_files_in[:4]

(4,
 4,
 ['/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0003-12/E3SM-MMF.rad.0003-12-01-75600.nc',
  '/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0003-12/E3SM-MMF.rad.0003-12-01-76800.nc',
  '/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0003-12/E3SM-MMF.rad.0003-12-01-78000.nc',
  '/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0003-12/E3SM-MMF.rad.0003-12-01-79200.nc'],
 ['/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0003-12/E3SM-MMF.mli.0003-12-01-75600.nc',
  '/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0003-12/E3SM-MMF.mli.0003-12-01-76800.nc',
  '/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0003-12/E3SM-MMF.mli.0003-12-01-78000.nc',
  '/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0003-12/E3SM-MMF.mli.0003-12-01-79200.nc'])

In [33]:
import re

# Extract the YYYY-MM-DD stamp from every .mli. file name and list the
# distinct days.  The pattern is a raw string so that `\d` reaches the regex
# engine untouched (a plain string triggers an invalid-escape warning), and
# the year is matched as any four digits instead of only 0000-0009.
dates = [re.match(r'.*(\d{4}-\d\d-\d\d).*', os.path.basename(p)).group(1) for p in mli_files_in]
dates_uniq = sorted(np.unique(dates))
dates_uniq

['0003-12-01']

In [34]:
from datetime import datetime, timedelta

def split_consecutive_dates(date_list, return_months=False):
    """
    Split a list of file paths into runs of consecutive 20-minute time steps.

    The basename of every path must contain a ``YYYY-MM-DD-SSSSS`` stamp
    (calendar day plus seconds into that day), for example
    ``E3SM-MMF.mli.0003-12-01-75600.nc``.  Paths are walked in the order
    given; a new run is started whenever a stamp is not exactly 1200 s
    later than the stamp of the path before it.

    Args:
        date_list: iterable of str
            File paths (or bare file names) in the order they should be
            walked, normally sorted chronologically.
        return_months: bool
            If True, also return the ``YYYY-MM`` string of the last file of
            every run.

    Returns:
        list of list of str
            The runs, or the tuple ``(runs, months)`` when ``return_months``
            is True.  Empty input yields ``[]`` (or ``([], [])``).

    Raises:
        ValueError: if a basename carries no parsable stamp.
    """
    # Define the timestep of 20 minutes in seconds
    timestep = 1200  # 20 minutes = 1200 seconds
    # Raw string so `\d` is a regex class, not a (deprecated) string escape.
    # The leading greedy `.*` makes the last stamp in the name win.
    stamp_pattern = re.compile(r'.*(\d{4}-\d\d-\d\d)-(\d+).*')

    def parse_stamp(path):
        # Return (datetime of the step, 'YYYY-MM') taken from the basename.
        found = stamp_pattern.match(os.path.basename(path))
        if found is None:
            raise ValueError(f'No YYYY-MM-DD-SSSSS stamp found in {path!r}')
        day_str, seconds_str = found.groups()
        # NOTE(review): datetime follows the Gregorian leap-year rules. If the
        # model output uses a no-leap calendar, a run would be split at
        # Feb 28 -> Mar 1 of years such as 0004 -- TODO confirm the calendar.
        moment = datetime.strptime(day_str, "%Y-%m-%d") + timedelta(seconds=int(seconds_str))
        return moment, day_str[:7]

    result = []
    months = []
    current_sublist = []
    previous_moment = None
    previous_month = ''

    for date_str in date_list:
        moment, month_str = parse_stamp(date_str)

        # Anything but an exact one-timestep gap to the previous file closes
        # the running sublist.
        if current_sublist and (moment - previous_moment).total_seconds() != timestep:
            result.append(current_sublist)
            months.append(previous_month)
            current_sublist = []

        current_sublist.append(date_str)
        previous_moment, previous_month = moment, month_str

    # Append the last sublist
    if current_sublist:
        result.append(current_sublist)
        months.append(previous_month)

    if return_months:
        return result, months
    return result

In [35]:
# Report, for every run of consecutive time steps, the month tag of the run
# and the number of files it contains.
runs, run_months = split_consecutive_dates(mli_files_in, return_months=True)
for month_tag, run in zip(run_months, runs):
    # Optional filter (disabled): only report runs shorter than 144 files.
    # if len(run) < 144:
    #     print(month_tag, len(run))
    print(month_tag, len(run))

0003-12 4


## Create new nc files that contains additional input features
Below we will use multiprocessing to speed up the data processing work.

In [36]:
# Open the grid-info file of the high-resolution data set; the latitude and
# longitude arrays taken from it are handed to process_one_file later on.
# Alternative locations (disabled):
# ds1 = xr.open_dataset('../../grid_info/ClimSim_low-res_grid-info.nc')
# ds0 = xr.open_dataset('/scratch/b/b309215/LEAP/ClimSim_high-res/ClimSim_high-res_grid-info.nc')
ds0 = xr.open_dataset('/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/ClimSim_high-res_grid-info.nc')
lat, lon = ds0['lat'], ds0['lon']

In [37]:
# Bundle the three grid-info variables that get_pressure_thickness receives
# through its `iface` argument.
iface = {
    'hyai': ds0['hyai'],
    'hybi': ds0['hybi'],
    'P0': ds0['P0'],
}

# Quick manual check of the pressure thickness (disabled):
# mli = xr.open_dataset(mli_files_in[0])
# dP = get_pressure_thickness(mli['state_ps'], iface, mli['state_pmid'].coords)

In [39]:
# test = []
if __name__ == '__main__':
    # Choose the 'spawn' start method under the main guard, as the
    # multiprocessing documentation asks.  force=True lets this cell be
    # re-run; without it a second call raises
    # "RuntimeError: context has already been set".
    mp.set_start_method('spawn', force=True)

    # Upper bound for the number of worker processes; replace with a fixed
    # number if preferred.
    num_processes = mp.cpu_count()

    # Per-date debugging variant (disabled):
    # for d in np.unique(dates):
    #     in_files_date = [p for p in mli_files_in if re.match(rf'.*/E3SM-MMF.mli.{d}-\d*.nc', p)]
    #     print(d, ': ', len(in_files_date))
    #     # test.append(in_files_date)
    for in_files_date in tqdm(split_consecutive_dates(mli_files_in)):
    # for in_files_date in tqdm(split_consecutive_dates(mli_files_in)[-6:-5]):

        # Start at index 2 because the processing function also reads the
        # files at i-1 and i-2 of the same consecutive run.  The tuple layout
        # follows the unpacking in process_one_file_copy: (i, files, lat, lon,
        # iface, mask_strato, input_abbrev, output_abbrev, rad_abbrev,
        # input_abbrev_new); new files get the tag 'mlexpandcnv'.
        args_for_processing = [
            (i, in_files_date, lat, lon, iface, False, 'mli', 'mlo', 'rad', 'mlexpandcnv')
            for i in range(2, len(in_files_date))
        ]

        # Runs with fewer than three files yield no work; skip them instead
        # of spawning a pool for nothing.
        if not args_for_processing:
            continue

        # Never start more workers than there are tasks in this run.
        with mp.Pool(min(num_processes, len(args_for_processing))) as pool:
            # Use pool.map to process files in parallel
            pool.map(process_one_file, args_for_processing)

  0%|          | 0/1 [00:00<?, ?it/s]

In [13]:
# Sanity check of the file ordering: print every file whose seconds-of-day
# stamp is smaller than the stamp of the file visited just before it.
# NOTE(review): no active assignment to `test` is visible in this notebook;
# presumably it was filled by temporarily enabling the commented-out
# `test.append(...)` lines of the multiprocessing cell -- confirm before use.
# NOTE(review): old_date is not reset between groups, and the seconds stamp
# restarts at 0 on a new day, so the first file after such a boundary is
# reported even when the ordering is fine -- verify that this is intended.
old_date = -1
for t in test:
    for f in t:
        # Raw string so that `\d` is passed to the regex engine unchanged.
        s = re.match(r'.*E3SM-MMF.mli.\d*-\d*-\d*-(\d*).nc', f).group(1)
        if old_date > float(s):
            print('Error: ', f)
        old_date = float(s)

Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-03/E3SM-MMF.mli.0001-03-01-00000.nc
Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-04/E3SM-MMF.mli.0001-04-01-00000.nc
Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-05/E3SM-MMF.mli.0001-05-01-00000.nc
Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-06/E3SM-MMF.mli.0001-06-01-00000.nc
Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-07/E3SM-MMF.mli.0001-07-01-00000.nc
Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-08/E3SM-MMF.mli.0001-08-01-00000.nc
Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-09/E3SM-MMF.mli.0001-09-01-00000.nc
Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-10/E3SM-MMF.mli.0001-10-01-00000.nc
Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-11/E3SM-MMF.mli.0001-11-01-00000.nc
Error:  /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-r

In [None]:
# Shell escape: run `scancel` for job ID 10756330 (presumably a SLURM batch
# job of one past session -- update or remove the ID before re-running).
!scancel 10756330

In [12]:
# Count the files matching *mlexpandcnv*.nc in the 0001-02 training directory.
%ls /p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-02/*mlexpandcnv*.nc | wc -l

382


## What does the process_one_file function do

We had to put the process_one_file function in a separate .py file so that multiprocessing works without problems. For your convenience, the process_one_file function from climsim_adding_input.py is copied below (as process_one_file_copy), so you can check what it does.

In [9]:
def process_one_file_copy(args):
    """
    Write an expanded copy of one input NetCDF file with additional features.

    The input file at index ``i`` is opened together with its output and
    radiation counterparts and with the same three files of the two
    preceding indices.  Variables derived from them are added to the input
    dataset, which is then written to a new file whose path is the input
    path with ``input_abbrev`` replaced by ``input_abbrev_new``.

    Added variables:
        * ``tm_state_*``, ``tm_pbuf_*``, ``tm_state_ps``: values of file i-1.
        * ``state_*_phy``: (output - input) / timestep of file i; for
          temperature the radiation ``ptend_t`` is subtracted as well.
        * ``state_*_prvphy`` / ``tm_state_*_prvphy``: the same tendencies for
          files i-1 and i-2 (t, q0001-q0003 and u only).
        * ``state_*_dyn`` / ``tm_state_*_dyn``: (input of a step - output of
          the step before it) / timestep for t, the sum of q0001-q0003, and u.
        * ``dP`` from get_pressure_thickness, plus ``lat``, ``lon``,
          ``clat``, ``slat`` and the 1-based column index ``icol``.

    Args:
        args: tuple with the following fields, in this order:
            i: int
                The index of the current file in ``nc_files_in``; i-1 and
                i-2 must be valid indices too.
            nc_files_in: list of str
                List of the full filenames.
            lat: xarray.DataArray
                DataArray of latitude (used in degrees for clat/slat).
            lon: xarray.DataArray
                DataArray of longitude.
            iface: dict
                Passed unchanged to get_pressure_thickness; the notebook
                fills it with 'hyai', 'hybi' and 'P0' from the grid-info file.
            mask_strato: bool
                If true, the ``*phy`` variables present at that point (the
                current-step ones) are set to 0 where ``state_pmid`` is below
                column 1 of tropopause_profile_2d's result (presumably the
                tropopause pressure -- TODO confirm).
            input_abbrev: str
                The input file name abbreviation, the default input data should be 'mli'.
            output_abbrev: str
                The output file name abbreviation, the default output data should be 'mlo'.
            rad_abbrev: str
                The radiation file name abbreviation ('rad' in this notebook).
            input_abbrev_new: str
                The abbreviation for the new input file name.

    Returns:
        None
    """
    # Local import: keeps this notebook cell self-contained.
    from contextlib import ExitStack

    timestep = 1200 # s

    i, nc_files_in, lat, lon, iface, mask_strato, input_abbrev, output_abbrev, rad_abbrev, input_abbrev_new = args
    # Number of columns, taken from the grid instead of a hard-coded 21600.
    ncols = lat.shape[-1]
    xr_args = dict()
    # xr_args = dict(chunks='auto')
    # xr_args = dict(chunks={'lev':1})

    # Every opened file is registered on the stack and closed when the block
    # is left, also on errors; before, none of the nine files was closed.
    with ExitStack() as stack:
        def open_ds(path):
            # Open one dataset and register it for closing.
            return stack.enter_context(xr.open_dataset(path, **xr_args))

        dsin = open_ds(nc_files_in[i])
        dsin_prev = open_ds(nc_files_in[i-1])
        dsin_prev2 = open_ds(nc_files_in[i-2])
        dsout = open_ds(nc_files_in[i].replace(input_abbrev, output_abbrev))
        dsout_prev = open_ds(nc_files_in[i-1].replace(input_abbrev, output_abbrev))
        dsout_prev2 = open_ds(nc_files_in[i-2].replace(input_abbrev, output_abbrev))
        # The radiation files name the column dimension 'col'; rename it so
        # that it lines up with 'ncol' of the other files.  The stack holds
        # the dataset that was opened, not the renamed copy.
        dsrad = open_ds(nc_files_in[i].replace(input_abbrev, rad_abbrev)).rename({'col': 'ncol'})
        dsrad_prev = open_ds(nc_files_in[i-1].replace(input_abbrev, rad_abbrev)).rename({'col': 'ncol'})
        dsrad_prev2 = open_ds(nc_files_in[i-2].replace(input_abbrev, rad_abbrev)).rename({'col': 'ncol'})

        # State of the previous time step.
        dsin['tm_state_t'] = dsin_prev['state_t']
        dsin['tm_state_q0001'] = dsin_prev['state_q0001']
        dsin['tm_state_q0002'] = dsin_prev['state_q0002']
        dsin['tm_state_q0003'] = dsin_prev['state_q0003']
        dsin['tm_state_u'] = dsin_prev['state_u']
        dsin['tm_state_v'] = dsin_prev['state_v']

        # Tendencies over the current step; the temperature tendency has the
        # radiation tendency removed.
        dsin['state_t_phy'] = (dsout['state_t'] - dsin['state_t']) / timestep - dsrad['ptend_t']
        dsin['state_q0001_phy'] = (dsout['state_q0001'] - dsin['state_q0001']) / timestep
        dsin['state_q0002_phy'] = (dsout['state_q0002'] - dsin['state_q0002']) / timestep
        dsin['state_q0003_phy'] = (dsout['state_q0003'] - dsin['state_q0003']) / timestep
        dsin['state_u_phy'] = (dsout['state_u'] - dsin['state_u']) / timestep
        dsin['state_v_phy'] = (dsout['state_v'] - dsin['state_v']) / timestep

        if mask_strato:
            tropopauses = tropopause_profile_2d(dsin.state_pmid.values,
                                                 dsin.state_t.values,
                                                 qv_profile=dsin.state_q0001.values,
                                                 pmin=0.01, pmax=450e2)
            ptp = tropopauses[:,1]
            # Separate name so the boolean argument is not overwritten.
            strato_mask = dsin.state_pmid.values < ptp

            # This loop must stay ahead of the *_prvphy assignments below:
            # those names end in 'phy' as well and would otherwise be masked.
            for varname in dsin:
                if varname.endswith('phy'):
                    # print(f'Stratospheric masking of {varname}')
                    dsin[varname].values[strato_mask] = 0

        # The same tendencies for the previous step ...
        dsin['state_t_prvphy'] = (dsout_prev['state_t'] - dsin_prev['state_t']) / timestep - dsrad_prev['ptend_t']
        dsin['state_q0001_prvphy'] = (dsout_prev['state_q0001'] - dsin_prev['state_q0001']) / timestep
        dsin['state_q0002_prvphy'] = (dsout_prev['state_q0002'] - dsin_prev['state_q0002']) / timestep
        dsin['state_q0003_prvphy'] = (dsout_prev['state_q0003'] - dsin_prev['state_q0003']) / timestep
        dsin['state_u_prvphy'] = (dsout_prev['state_u'] - dsin_prev['state_u']) / timestep

        # ... and for the step before that.
        dsin['tm_state_t_prvphy'] = (dsout_prev2['state_t'] - dsin_prev2['state_t']) / timestep - dsrad_prev2['ptend_t']
        dsin['tm_state_q0001_prvphy'] = (dsout_prev2['state_q0001'] - dsin_prev2['state_q0001']) / timestep
        dsin['tm_state_q0002_prvphy'] = (dsout_prev2['state_q0002'] - dsin_prev2['state_q0002']) / timestep
        dsin['tm_state_q0003_prvphy'] = (dsout_prev2['state_q0003'] - dsin_prev2['state_q0003']) / timestep
        dsin['tm_state_u_prvphy'] = (dsout_prev2['state_u'] - dsin_prev2['state_u']) / timestep

        # Change between the output of the previous step and the input of
        # the current one.
        dsin['state_t_dyn'] = (dsin['state_t'] - dsout_prev['state_t']) / timestep
        dsin['state_q0_dyn'] = (dsin['state_q0001'] - dsout_prev['state_q0001'] + dsin['state_q0002'] - dsout_prev['state_q0002'] + dsin['state_q0003'] - dsout_prev['state_q0003']) / timestep
        dsin['state_u_dyn'] = (dsin['state_u'] - dsout_prev['state_u']) / timestep

        dsin['tm_state_t_dyn'] = (dsin_prev['state_t'] - dsout_prev2['state_t']) / timestep
        dsin['tm_state_q0_dyn'] = (dsin_prev['state_q0001'] - dsout_prev2['state_q0001'] + dsin_prev['state_q0002'] - dsout_prev2['state_q0002'] + dsin_prev['state_q0003'] - dsout_prev2['state_q0003']) / timestep
        dsin['tm_state_u_dyn'] = (dsin_prev['state_u'] - dsout_prev2['state_u']) / timestep

        dsin['dP'] = get_pressure_thickness(dsin['state_ps'], iface, dsin['state_pmid'].coords)

        dsin['tm_state_ps'] = dsin_prev['state_ps']
        dsin['tm_pbuf_SOLIN'] = dsin_prev['pbuf_SOLIN']
        dsin['tm_pbuf_SHFLX'] = dsin_prev['pbuf_SHFLX']
        dsin['tm_pbuf_LHFLX'] = dsin_prev['pbuf_LHFLX']
        dsin['tm_pbuf_COSZRS'] = dsin_prev['pbuf_COSZRS']

        # Grid information: cosine and sine of latitude (lat taken in
        # degrees) and a 1-based column index, all shaped like lat.
        dsin['lat'] = lat
        dsin['lon'] = lon
        clat = lat.copy()
        slat = lat.copy()
        icol = lat.copy()
        clat[:] = np.cos(lat*np.pi/180.)
        slat[:] = np.sin(lat*np.pi/180.)
        icol[:] = np.arange(1,ncols+1)
        dsin['clat'] = clat
        dsin['slat'] = slat
        dsin['icol'] = icol

        # Write while all source files are still open.
        new_file_path = nc_files_in[i].replace(input_abbrev, input_abbrev_new)
        dsin.to_netcdf(new_file_path)
        # delayed_obj = dsin.to_netcdf(new_file_path, compute=False)
        # with ProgressBar():
        #     results = delayed_obj.compute()

    return None

In [10]:
# Try the in-notebook copy on a single file: take the first run of
# consecutive time steps and build one argument tuple per usable index
# (index 2 onwards), with the mask_strato field set to True and the new
# files tagged 'mlexpandconserv_mask'.
in_files_date = split_consecutive_dates(mli_files_in)[0]
args_for_processing = [
    (idx, in_files_date, lat, lon, iface, True, 'mli', 'mlo', 'rad', 'mlexpandconserv_mask')
    for idx in range(2, len(in_files_date))
]

# Only the first tuple is processed here.
process_one_file_copy(args_for_processing[0])

In [12]:
# Only the banner is printed at the moment; the comparison of the masked and
# unmasked output files below is disabled (commented out).
print("Compare masked vs unmasked")

# ds1 = xr.open_dataset('/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-02/E3SM-MMF.mlexpandconserv.0001-02-01-02400.nc')
# ds2 = xr.open_dataset('/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-02/E3SM-MMF.mlexpandconserv_mask.0001-02-01-02400.nc')

# ds_compare = ds1 == ds2
# # display(ds_compare)

# np.all(ds_compare.state_t_phy.values)

# for varname in ds_compare:
#     if varname.endswith('_phy'):
#         print(np.sum(ds_compare[varname], axis=1))

Compare masked vs unmasked


In [10]:
# Start a local dask cluster with a fixed number of threads per worker and
# as many workers as fit into the available CPUs; the dashboard is bound to
# port 8787.
# Earlier variant that fixed the worker count instead (disabled):
# n_workers = num_cpus
# client = Client(threads_per_worker=num_cpus//n_workers, n_workers=n_workers, dashboard_address=':8787')#, memory_limit="auto")
threads_per_worker = 4
n_workers = num_cpus // threads_per_worker
client = Client(
    threads_per_worker=threads_per_worker,
    n_workers=n_workers,
    dashboard_address=':8787',
)  # , memory_limit="auto")
print('Dashboard Link: ', client.dashboard_link)

Dashboard Link:  http://127.0.0.1:8787/status


In [11]:
# Process files 32..63 of the full list on the dask cluster.
# The tuples use the same 10-field layout that the multiprocessing cell
# passes to process_one_file (i, files, lat, lon, iface, mask_strato,
# input_abbrev, output_abbrev, rad_abbrev, input_abbrev_new); the former
# 7-field tuples lacked iface, mask_strato and the 'rad' tag.  The upper
# bound is clamped so that no index runs past the end of the file list.
# NOTE(review): the full list is indexed here without splitting it into
# consecutive runs, so files i-1 / i-2 may lie across a gap -- confirm.
args_for_processing = [
    (i, mli_files_in, lat, lon, iface, False, 'mli', 'mlo', 'rad', 'mlexpandcnv')
    for i in range(32, min(64, len(mli_files_in)))
]

futures = client.map(process_one_file, args_for_processing)
results = client.gather(futures)
results

2024-09-19 21:35:59,279 - distributed.worker - ERROR - 
Traceback (most recent call last):
  File "/p/project/icon-a-ml/mambaforge/envs/heuer1_climlab/lib/python3.11/asyncio/locks.py", line 213, in wait
    await fut
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/p/project/icon-a-ml/mambaforge/envs/heuer1_climlab/lib/python3.11/site-packages/distributed/nanny.py", line 981, in run
    await worker.finished()
  File "/p/project/icon-a-ml/mambaforge/envs/heuer1_climlab/lib/python3.11/site-packages/distributed/core.py", line 630, in finished
    await self._event_finished.wait()
  File "/p/project/icon-a-ml/mambaforge/envs/heuer1_climlab/lib/python3.11/asyncio/locks.py", line 216, in wait
    self._waiters.remove(fut)
  File "/p/project/icon-a-ml/mambaforge/envs/heuer1_climlab/lib/python3.11/asyncio/runners.py", line 157, in _on_sigint
    raise KeyboardInterrupt()
KeyboardInterrupt

g han

# debug

In [21]:
# process_one_file_copy unpacks ten fields (i, files, lat, lon, iface,
# mask_strato, input_abbrev, output_abbrev, rad_abbrev, input_abbrev_new);
# the former 7-field tuple raised "ValueError: not enough values to unpack".
# Index 5 requires mli_files_in to hold at least six files.
process_one_file_copy((5, mli_files_in, lat, lon, iface, False, 'mli', 'mlo', 'rad', 'mlexpandcnv'))

In [32]:
# mli_files_in[2]

'/p/scratch/icon-a-ml/heuer1/LEAP/ClimSim_high-res/train/0001-02/E3SM-MMF.mli.0001-02-01-02400.nc\n'

In [8]:
# from dask.distributed import Client, progress
# n_workers = 4
# client = Client(threads_per_worker=num_cpus//n_workers, n_workers=n_workers)
# client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 96,Total memory: 503.18 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:36673,Workers: 8
Dashboard: http://127.0.0.1:8787/status,Total threads: 96
Started: Just now,Total memory: 503.18 GiB

0,1
Comm: tcp://127.0.0.1:45287,Total threads: 12
Dashboard: http://127.0.0.1:36939/status,Memory: 62.90 GiB
Nanny: tcp://127.0.0.1:43141,
Local directory: /tmp/dask-scratch-space-21188/worker-n57wusn3,Local directory: /tmp/dask-scratch-space-21188/worker-n57wusn3

0,1
Comm: tcp://127.0.0.1:37157,Total threads: 12
Dashboard: http://127.0.0.1:37481/status,Memory: 62.90 GiB
Nanny: tcp://127.0.0.1:36883,
Local directory: /tmp/dask-scratch-space-21188/worker-spop0x3r,Local directory: /tmp/dask-scratch-space-21188/worker-spop0x3r

0,1
Comm: tcp://127.0.0.1:42063,Total threads: 12
Dashboard: http://127.0.0.1:36911/status,Memory: 62.90 GiB
Nanny: tcp://127.0.0.1:45597,
Local directory: /tmp/dask-scratch-space-21188/worker-atm2afpf,Local directory: /tmp/dask-scratch-space-21188/worker-atm2afpf

0,1
Comm: tcp://127.0.0.1:39557,Total threads: 12
Dashboard: http://127.0.0.1:46747/status,Memory: 62.90 GiB
Nanny: tcp://127.0.0.1:39501,
Local directory: /tmp/dask-scratch-space-21188/worker-jtqicfbx,Local directory: /tmp/dask-scratch-space-21188/worker-jtqicfbx

0,1
Comm: tcp://127.0.0.1:41245,Total threads: 12
Dashboard: http://127.0.0.1:38079/status,Memory: 62.90 GiB
Nanny: tcp://127.0.0.1:37883,
Local directory: /tmp/dask-scratch-space-21188/worker-r5fbvzj0,Local directory: /tmp/dask-scratch-space-21188/worker-r5fbvzj0

0,1
Comm: tcp://127.0.0.1:41279,Total threads: 12
Dashboard: http://127.0.0.1:43043/status,Memory: 62.90 GiB
Nanny: tcp://127.0.0.1:43535,
Local directory: /tmp/dask-scratch-space-21188/worker-r68plhlt,Local directory: /tmp/dask-scratch-space-21188/worker-r68plhlt

0,1
Comm: tcp://127.0.0.1:35597,Total threads: 12
Dashboard: http://127.0.0.1:44287/status,Memory: 62.90 GiB
Nanny: tcp://127.0.0.1:46009,
Local directory: /tmp/dask-scratch-space-21188/worker-n789hvbn,Local directory: /tmp/dask-scratch-space-21188/worker-n789hvbn

0,1
Comm: tcp://127.0.0.1:35779,Total threads: 12
Dashboard: http://127.0.0.1:37499/status,Memory: 62.90 GiB
Nanny: tcp://127.0.0.1:41523,
Local directory: /tmp/dask-scratch-space-21188/worker-ddw0zc61,Local directory: /tmp/dask-scratch-space-21188/worker-ddw0zc61


In [None]:
# args_for_processing = [(i, nc_files_in, lat, lon, 'mli', 'mlo', 'mlexpandrad') for i in range(2, 32)] # will create new input files with .mlexpand.

# futures = client.map(process_one_file, args_for_processing)
# results = client.gather(futures)