# Chunking data onto the MPAS grid

## Import packages

In [1]:
import numpy as np
import joblib as jl
from joblib import Parallel, delayed


## Load data

In [2]:
# Directory holding the interpolated CloudSat files
path: str = '/work/b11209013/2024_Research/CloudSat/'

# Case number; zero-padded to three digits in the file name
case: int = 8

# Input file name, e.g. .../CloudSat_itp_008.joblib for case 8
fname: str = f'{path}CloudSat_itp_{case:03}.joblib'

# Load the dictionary of arrays written by the interpolation step.
# NOTE(review): joblib.load unpickles the file, so only load trusted files.
data = jl.load(fname)

# Plain `np.ndarray` annotations: `np.ndarray[float]` is not a valid ndarray
# parameterization (the generic takes shape and dtype type parameters, and
# NumPy 1.x rejects a single argument), and module-level annotations are
# evaluated at run time.
lon: np.ndarray = data['lon']
lat: np.ndarray = data['lat']
hgt: np.ndarray = data['hgt']
qlw: np.ndarray = data['qlw']
qsw: np.ndarray = data['qsw']

# qsw is 4-D. The names suggest (date, swath, ray, vertical bin); confirm
# against the step that wrote the file.
ldate, lswath, lray, lbin = qsw.shape

# Drop the container dict; the arrays above keep their own references
del data

## Chunking data

In [3]:
# Collapse all leading axes into one "track" axis:
#   lon/lat -> (n_track, lray), qlw/qsw -> (n_track, lray, lbin).
# Later the track axis is split back into (ldate, lswath).
lon_rs: np.ndarray = lon.reshape((-1, lray))
lat_rs: np.ndarray = lat.reshape((-1, lray))
qlw_rs: np.ndarray = qlw.reshape((-1, lray, lbin))
qsw_rs: np.ndarray = qsw.reshape((-1, lray, lbin))

# Bin edges of the regular 0.5-degree target grid: lon 160..260, lat -15..15.
# The +0.1 makes sure the end point survives floating-point arange.
lon_cond: np.ndarray = np.arange(160, 260+0.1, 0.5)
lat_cond: np.ndarray = np.arange(-15, 15+0.1, 0.5)

# Pre-allocate outputs (track, lat_bin, lon_bin, vertical_bin), NaN = no data.
# The vertical size is taken from the data (lbin) instead of a hard-coded 18,
# so it matches the length of the profiles stored into it.
qlw_chunk = np.full((qlw_rs.shape[0], len(lat_cond) - 1, len(lon_cond) - 1, lbin), np.nan)
qsw_chunk = np.full((qlw_rs.shape[0], len(lat_cond) - 1, len(lon_cond) - 1, lbin), np.nan)

# lon_rs/lat_rs are already 2-D, so these reshapes return same-shaped views.
# They are kept under the *_flat names used below.
lon_rs_flat = lon_rs.reshape(lon_rs.shape[0], -1)
lat_rs_flat = lat_rs.reshape(lat_rs.shape[0], -1)

# One boolean mask per bin, each shaped like lat_rs_flat / lon_rs_flat.
# Bins are half-open [lower, upper), so points exactly on the last edge are
# not assigned to any bin.
lat_masks = [
    (lat_rs_flat >= lower) & (lat_rs_flat < upper)
    for lower, upper in zip(lat_cond[:-1], lat_cond[1:])
]
lon_masks = [
    (lon_rs_flat >= lower) & (lon_rs_flat < upper)
    for lower, upper in zip(lon_cond[:-1], lon_cond[1:])
]

# (n_track, lray, lbin) views of the data, indexed per track by the worker
qlw_rs_flat = qlw_rs.reshape(qlw_rs.shape[0], -1, qlw_rs.shape[2])
qsw_rs_flat = qsw_rs.reshape(qsw_rs.shape[0], -1, qsw_rs.shape[2])
# Function to process a single time slice
def process_time_slice(time_idx):
    """Average one track's qlw/qsw profiles onto the lat/lon bin grid.

    Parameters
    ----------
    time_idx : int
        Row of the flattened track axis (first axis of ``qlw_rs_flat``).

    Returns
    -------
    tuple of np.ndarray
        ``(qlw_result, qsw_result)``. Each is shaped
        ``(len(lat_cond) - 1, len(lon_cond) - 1, n_vertical)`` and holds the
        NaN-ignoring mean profile of the points that fall in each cell. A cell
        stays NaN when no point, or only all-NaN data, falls in it.

    Notes
    -----
    Reads the module-level ``lat_cond``, ``lon_cond``, ``lat_masks``,
    ``lon_masks``, ``qlw_rs_flat`` and ``qsw_rs_flat``.
    """
    # Per-track slices, hoisted out of the 2-D bin loop
    qlw_track = qlw_rs_flat[time_idx]  # (lray, n_vertical)
    qsw_track = qsw_rs_flat[time_idx]
    lon_rows = [lon_mask[time_idx] for lon_mask in lon_masks]

    n_lat_bin = len(lat_cond) - 1
    n_lon_bin = len(lon_cond) - 1

    # Vertical size comes from the data instead of a hard-coded 18
    qlw_result = np.full((n_lat_bin, n_lon_bin, qlw_track.shape[-1]), np.nan)
    qsw_result = np.full((n_lat_bin, n_lon_bin, qsw_track.shape[-1]), np.nan)

    for la, lat_mask in enumerate(lat_masks):
        lat_row = lat_mask[time_idx]
        # A latitude band without points cannot contain a non-empty cell
        if not lat_row.any():
            continue

        for lo, lon_row in enumerate(lon_rows):
            combined_mask = lat_row & lon_row
            if not combined_mask.any():
                continue

            # Select every point in the cell, keeping all vertical bins
            qlw_data = qlw_track[combined_mask, :]
            qsw_data = qsw_track[combined_mask, :]

            # Mean over points (axis 0). All-NaN cells stay NaN. A vertical bin
            # that is all-NaN inside an otherwise valid cell becomes NaN, and
            # np.nanmean emits a RuntimeWarning for it.
            if not np.all(np.isnan(qlw_data)):
                qlw_result[la, lo, :] = np.nanmean(qlw_data, axis=0)

            if not np.all(np.isnan(qsw_data)):
                qsw_result[la, lo, :] = np.nanmean(qsw_data, axis=0)

    return qlw_result, qsw_result

# One task per track row, run in parallel (n_jobs=-1: use all available CPUs).
# NOTE(review): process_time_slice reads large module-level arrays. With a
# process-based joblib backend these may be serialized to the workers, so check
# the overhead if this cell is slow.
results = Parallel(n_jobs=-1)(
    delayed(process_time_slice)(i) for i in range(qlw_rs.shape[0])
)

# Copy each track's (qlw, qsw) grids into the pre-allocated outputs
for i, (qlw_result, qsw_result) in enumerate(results):
    qlw_chunk[i] = qlw_result
    qsw_chunk[i] = qsw_result

# Split the flattened track axis back into (date, swath). The trailing
# (lat, lon, vertical) axes come from the arrays themselves instead of a
# hard-coded 18 vertical bins.
qlw_chunk = qlw_chunk.reshape((ldate, lswath) + qlw_chunk.shape[1:])
qsw_chunk = qsw_chunk.reshape((ldate, lswath) + qsw_chunk.shape[1:])

In [4]:
# Package the gridded fields. 'lon'/'lat' are the lower edges of the bins
# (the final edge of lon_cond/lat_cond is dropped), not the cell centres.
output_dict = {
    'lon': lon_cond[:-1],
    'lat': lat_cond[:-1],
    'hgt': hgt,
    'qlw': qlw_chunk,
    'qsw': qsw_chunk,
    'description': 'qlw, qsw dimension: (date, max_swath, lat, lon, hgt)'
}

# Write into the same directory as the input, with the same zero-padded case
# number format as the input name ({case:03}), using light zlib compression.
# The resulting path is identical to the previously hard-coded one.
jl.dump(output_dict, f'{path}CloudSat_chunk_{case:03}.joblib', compress=('zlib', 1))

['/work/b11209013/2024_Research/CloudSat/CloudSat_chunk_008.joblib']