In [1]:
import os
import re
import numpy as np
import xarray as xr
from glob import glob

# 1. 配置

In [None]:
# Root of the CESM data tree and the dataset sub-directory processed here.
fp = '/data/keeling/a/xx24/e/proj_ml/cesm_data'
dataset = 'd651001'
data_dir = os.path.join(fp, dataset)

# Destination of the per-letter output files; created up front so the
# saving cell can write into it without further checks.
output_dir = os.path.join(fp, dataset, "ensemble_with_bulk")
os.makedirs(output_dir, exist_ok=True)

# Material density per species, keyed by the species prefix that the bulk
# computation matches against variable names.
# NOTE(review): values look like kg m^-3 -- confirm the units.
densities = dict(
    so4=1800,
    ncl=2200,
    pom=1000,
    bc=1800,
    dst=2600,
    soa=1400,
)

# 2. 扫描文件 & 按 (ens_letter, time, variable) 分组

In [None]:
# Expected file name layout:
#   b.e21.BHIST_CESM2_2010_01_ens<num><letter>.cam.h0.<var>.<time>.nc
# The trailing <letter> is optional.
pattern = re.compile(
    r'b\.e21\.(?P<prefix>BHIST_CESM2_2010_01_ens)'
    r'(?P<num>\d+)(?P<letter>[a-z]?)\.cam\.h0'
    r'\.(?P<var>[^.]+)\.(?P<time>[^.]+)\.nc'
)

# Maps (letter, time, var) -> list of file paths belonging to that group.
grouped = {}
candidates = glob(os.path.join(data_dir, "b.e21.*.cam.h0.*.nc"))
for nc_path in candidates:
    hit = pattern.match(os.path.basename(nc_path))
    if hit is None:
        # The glob is looser than the regex; ignore anything that does not parse.
        continue
    fields = hit.groupdict()
    # Files without a letter suffix are filed under 'a'.
    group_key = (fields['letter'] or 'a', fields['time'], fields['var'])
    if group_key not in grouped:
        grouped[group_key] = []
    grouped[group_key].append(nc_path)

# 3. 对每组进行 ensemble 平均

In [None]:
# Average every (letter, time, variable) group over its member files.
# Result: (letter, time, var) -> (mean DataArray, coords of the first file used).
ensemble_avg = {}
for (letter, time_part, var), paths in grouped.items():
    arrs = []
    coords0 = None
    for p in paths:
        # xarray opens files lazily. Pull the values into memory *before* the
        # file is closed so the arrays kept in `arrs` never refer to a closed
        # handle, and use `with` so the handle is released even on errors.
        with xr.open_dataset(p) as ds:
            if var not in ds:
                continue
            arrs.append(ds[var].load())
            if coords0 is None:
                # Same coordinate set as `ds.coords`, but detached from the file.
                coords0 = ds.coords.to_dataset().load().coords
    if not arrs:
        # None of the files actually contained the variable named in their file name.
        continue
    combined = xr.concat(arrs, dim="ens_member")
    avg_da = combined.mean(dim="ens_member")
    ensemble_avg[(letter, time_part, var)] = (avg_da, coords0)

# 4. 将 RHO_CLUBB 从 ilev 降到 lev

In [None]:
def _snap_to_existing_lev(avg_map, letter, time_part, lev_mid):
    """Return the `lev` axis of a sibling variable when it matches `lev_mid`.

    xarray aligns arithmetic operands on exact coordinate labels, so midpoints
    that differ from another variable's `lev` values only by floating-point
    round-off would not line up with it. If a variable of the same
    (letter, time_part) group carries a `lev` coordinate that is numerically
    close to `lev_mid`, those values are returned; otherwise `lev_mid` is
    returned unchanged.
    """
    for (l, t, _), (other, _) in avg_map.items():
        if l != letter or t != time_part or 'lev' not in other.coords:
            continue
        lev_ref = np.asarray(other.coords['lev'].values)
        if lev_ref.shape == lev_mid.shape and np.allclose(lev_ref, lev_mid):
            return lev_ref
    return lev_mid


new_avg = {}
for (letter, time_part, var), (da, coords) in ensemble_avg.items():
    # Convert only while RHO_CLUBB still lives on `ilev`. Because `ensemble_avg`
    # is rebound below, this check keeps the cell safe to run more than once.
    if var == 'RHO_CLUBB' and 'ilev' in da.dims:
        # Fix the axis order so the positional slices below always act on `ilev`.
        da = da.transpose('time', 'ilev', 'lat', 'lon')
        ilev = da.coords['ilev'].values
        data = da.data

        # Mean of each pair of adjacent interface levels -> one value per layer.
        data_mid = 0.5 * (data[:, :-1, :, :] + data[:, 1:, :, :])
        lev_mid = _snap_to_existing_lev(
            ensemble_avg, letter, time_part, 0.5 * (ilev[:-1] + ilev[1:])
        )

        da_mid = xr.DataArray(
            data_mid,
            dims=('time', 'lev', 'lat', 'lon'),
            coords={
                'time': da.coords['time'].values,
                'lev': lev_mid,
                'lat': da.coords['lat'].values,
                'lon': da.coords['lon'].values,
            },
            name='RHO_CLUBB'
        )
        new_avg[(letter, time_part, var)] = (da_mid, da_mid.coords)
    else:
        new_avg[(letter, time_part, var)] = (da, coords)

ensemble_avg = new_avg

# 5. 计算 bulk diameter 及各 species mass_vol / number conc

In [None]:
species_list = ['so4','ncl','pom','bc','dst','soa','num']
bulk_results = {}

# Index the averaged variables once by (letter, time) instead of rescanning
# the whole mapping for every species.
vars_by_group = {}
for (l, t, var), (da, _) in ensemble_avg.items():
    vars_by_group.setdefault((l, t), {})[var] = da

for (letter, time_part), members in vars_by_group.items():
    # Sum all variables whose name starts with "<species>_".
    summed = {}
    coords0 = None
    for sp in species_list:
        arrs = [da for var, da in members.items() if var.startswith(sp + '_')]
        if arrs:
            summed[sp] = sum(arrs)
            coords0 = arrs[0].coords

    da_rho = members.get('RHO_CLUBB')
    if da_rho is None:
        continue

    # Each species sum is multiplied by RHO_CLUBB.
    # NOTE(review): presumably this turns a per-kg-of-air mixing ratio into a
    # per-volume concentration -- confirm the input units.
    mass_vol = {sp: summed[sp] * da_rho for sp in densities if sp in summed}

    da_num = summed.get('num')
    if da_num is None:
        continue
    if not mass_vol:
        # Without any species the totals below are the plain ints 0 and 0, and
        # `M_tot / V_tot` would raise ZeroDivisionError; skip the group instead.
        continue
    N_tot = da_num * da_rho

    # xarray multiplies on the intersection of coordinate labels, so a label
    # mismatch between the species axes and RHO_CLUBB silently drops data.
    shrunk = [d for d in da_num.dims if N_tot.sizes[d] != da_num.sizes[d]]
    if shrunk:
        raise ValueError(
            f"RHO_CLUBB coordinates do not match the species variables along "
            f"{shrunk} for ens '{letter}', time '{time_part}'"
        )

    M_tot = sum(mass_vol.values())
    V_tot = sum(mass_vol[sp] / densities[sp] for sp in mass_vol)
    rho_mix = M_tot / V_tot
    # Diameter of a sphere holding the mean per-particle mass at density rho_mix.
    d_bulk  = ((6 * M_tot) / (np.pi * rho_mix * N_tot)) ** (1/3)

    bulk_results[(letter, time_part)] = {
        'coords': coords0,
        'mass_vol': mass_vol,
        'N_tot':   N_tot,
        'bulk':    d_bulk
    }

# 6. 合并所有变量到一个 Dataset 并保存

In [None]:
# Write one NetCDF file per (letter, time) group containing the species
# fields, the total number concentration, the bulk diameter and a few
# pass-through ensemble-mean variables.
passthrough_vars = ['T', 'RELHUM', 'RHO_CLUBB', 'CCN3']

for group_key, result in bulk_results.items():
    letter, time_part = group_key

    # Gather (name, array) pairs first, in the order they are written.
    fields = list(result['mass_vol'].items())
    fields.append(('tot_number_conc', result['N_tot']))
    fields.append(('bulk_diameter', result['bulk']))
    for var in passthrough_vars:
        entry = ensemble_avg.get((letter, time_part, var))
        if entry:
            # Entries are (DataArray, coords) pairs; only the array is stored.
            fields.append((var, entry[0]))

    # Start from the group's coordinates and attach the fields one by one.
    ds_out = xr.Dataset(coords=result['coords'])
    for name, field in fields:
        ds_out[name] = field

    fname = f"b.e21.ens_{letter}.cam.h0.all_vars.{time_part}.nc"
    outpath = os.path.join(output_dir, fname)
    ds_out.to_netcdf(outpath)
    print("Saved:", outpath)

In [None]:
# Average the per-letter files into a single ensemble-mean file.
#
# The inputs are the files written by the saving cell, so they are looked up
# in `output_dir` rather than through hard-coded absolute paths; this keeps the
# notebook consistent under "Restart & Run All".
ens_members = ['a', 'b', 'c', 'd', 'e']
combined_time_part = '201001-201112'

filepaths = [
    os.path.join(output_dir, f"b.e21.ens_{member}.cam.h0.all_vars.{combined_time_part}.nc")
    for member in ens_members
]

datasets = []
for member, file in zip(ens_members, filepaths):
    # Load the values, then let `with` close the file handle right away.
    with xr.open_dataset(file) as ds:
        # Add a length-1 `ensemble` dimension labelled with the member letter.
        datasets.append(ds.load().expand_dims(ensemble=[member]))

combined = xr.concat(datasets, dim='ensemble')

mean_ds = combined.mean(dim='ensemble', keep_attrs=True)  # keep attributes

mean_ds.to_netcdf(os.path.join(output_dir, 'ensemble_mean.nc'))

print("Saved as ensemble_mean.nc")