In [1]:
import xarray as xr
import matplotlib.pyplot as plt
import xarray_tools as xrt
import cf_units, datetime
import numpy as np

ModuleNotFoundError: No module named 'xarray_tools'

In [None]:
# CMIP7 UOEXETER-CMIP-2-0-0 monthly aerosol-properties file covering 175001-202312.
d = xr.open_dataset(
    '/g/data/qv56/replicas/input4MIPs/CMIP7/CMIP/uoexeter/UOEXETER-CMIP-2-0-0/atmos/mon/ext/gnz/v20250227/ext_input4MIPs_aerosolProperties_CMIP_UOEXETER-CMIP-2-0-0_gnz_175001-202312.nc'
)

In [None]:
# Index of the 550 nm wavelength along the last axis of d.ext.
# NOTE(review): assumes index 7 is 550 nm in the 2-0-0 file — confirm against its wavelength coordinate.
w550 = 7

# Latitude area weights: a band's area is proportional to sin(lat_upper) - sin(lat_lower).
# Convert explicitly to a plain ndarray and reshape (rather than assigning .shape in place);
# reshape(len(d.lat)) still fails loudly if the bounds don't match the lat axis.
lat_wts = np.diff(np.sin(np.radians(np.asarray(d.lat_bnds))), axis=1).reshape(len(d.lat))
# Normalise so the weights sum to 1 (this also fixes the sign if bounds are stored upper-first).
lat_wts /= lat_wts.sum()

In [None]:
# Area-weighted global mean extinction at 550 nm. The (nlat, 1) weight column broadcasts
# by position, presumably against (time, lat, height) axes of d.ext — TODO confirm dim order.
ext_g = (d.ext[:, :, :, w550] * lat_wts.reshape(-1, 1)).sum('lat')

In [None]:
# Height-time section of global-mean extinction from 1850 onward.
# The x1000 scaling is labelled km^-1, as in the comparison figure below
# (presumably converting m^-1 -> km^-1 — confirm units of d.ext).
fig, axes = plt.subplots(figsize=(10,4))
(1000*ext_g.sel(time=slice('1850', None))).plot(
    ax=axes, cmap='Reds', x='time', vmax=0.01, cbar_kwargs={'label':'km$^{-1}$'})
axes.set_ylabel('Height')
axes.set_title('CMIP7 (2-0-0) extinction at 550 nm')

In [None]:
# Previous CMIP6Plus UOEXETER-CMIP-1-1-3 file (same 175001-202312 span) for comparison.
d113 = xr.open_dataset(
    '/g/data/qv56/replicas/input4MIPs/CMIP6Plus/CMIP/uoexeter/UOEXETER-CMIP-1-1-3/atmos/mon/ext/gnz/v20240903/ext_input4MIPs_aerosolProperties_CMIP_UOEXETER-CMIP-1-1-3_gnz_175001-202312.nc'
)

In [None]:
# Note different wavelength: index 9 is used here instead of w550.
# NOTE(review): assumes index 9 is 550 nm in the 1-1-3 file — confirm against its wavelength coordinate.
# lat_wts was computed from the 2-0-0 latitude grid and is applied by axis position,
# so refuse to continue if the 1-1-3 file uses a different grid.
assert d113.lat.shape == d.lat.shape and np.allclose(d113.lat.values, d.lat.values), \
    "1-1-3 and 2-0-0 latitude grids differ; recompute lat_wts for d113"
ext113_g = (d113.ext[:,:,:,9] * lat_wts[:, np.newaxis]).sum('lat')

In [None]:
# Stacked comparison of the two datasets from 1850 onward, on the same colour scale.
# The 2-0-0 file lives under input4MIPs/CMIP7, so it is labelled CMIP7 (not "CMIP7Plus").
fig, (ax_113, ax_200) = plt.subplots(2, 1, figsize=(10,8))
for axes, ext_glob, title in [
    (ax_113, ext113_g, 'CMIP6Plus (1-1-3) extinction at 550 nm'),
    (ax_200, ext_g, 'CMIP7 (2-0-0) extinction at 550 nm'),
]:
    (1000*ext_glob.sel(time=slice('1850', None))).plot(
        ax=axes, cmap='Reds', x='time', vmax=0.01, cbar_kwargs={'label':'km$^{-1}$'})
    axes.set_ylabel('Height')
    axes.set_title(title)
plt.tight_layout()

In [None]:
# Thickness of each height layer (upper bound minus lower bound), in the units of
# height_bnds — presumably metres, so that ext * thick is dimensionless; TODO confirm.
thick = d.height_bnds[..., 1] - d.height_bnds[..., 0]

In [None]:
# Global-mean column optical depth (SAOD): sum extinction x layer thickness over height.
ext_tot = (ext_g * thick).sum('height')

# Annual means weighted by days per month, computed with plain xarray.
# The original call, xrt.annual_mean, comes from the unavailable xarray_tools package.
# NOTE(review): assumed to match xrt.annual_mean's weighting — confirm.
month_days = ext_tot.time.dt.days_in_month
ext_tot_ann = ((ext_tot * month_days).groupby('time.year').sum()
               / month_days.groupby('time.year').sum())

fig = plt.figure(figsize=(10,4))
axes = fig.add_subplot(1,1,1)
ext_tot.plot(ax=axes)
axes.set_title('Global mean monthly SAOD at 550 nm')
axes.set_xlabel('Year')
axes.set_ylim(0,0.35)
# The x axis is a matplotlib date axis. Limits and ticks are given as days since
# 1970-01-01, which assumes matplotlib's default date epoch (1970-01-01).
taxis = cf_units.Unit("days since 1970-01-01 00:00", calendar="proleptic_gregorian")
t0 = taxis.date2num(datetime.datetime(1750,1,1,0,0,0))
t1 = taxis.date2num(datetime.datetime(2022,1,1,0,0,0))
axes.set_xlim(t0,t1)
axes.set_xticks([taxis.date2num(datetime.datetime(y,1,1,0,0,0)) for y in range(1750,2001,50)])

In [None]:
def total_mean(ds):
    """Month-length-weighted mean of a DataArray over its ``time`` dimension.

    Each time step is weighted by the number of days in its month.

    Parameters
    ----------
    ds : xarray.DataArray
        Data with a datetime-like ``time`` coordinate (must support
        ``ds.time.dt.days_in_month``).

    Returns
    -------
    xarray.DataArray
        ``ds`` averaged over ``time``; any other dimensions are kept.
        Missing values are skipped AND their months are dropped from the
        normalisation. The previous sum/sum form skipped them only in the
        numerator, which biased the mean toward zero.
    """
    month_length = ds.time.dt.days_in_month
    return ds.weighted(month_length).mean(dim='time')

In [None]:
# PI should use 1850-2021 mean.
# Partial-date bounds include every month of 1850-2021 whatever day of the month the
# time stamps fall on. An end bound of "2021-12-01" would drop December 2021 if the
# stamps are mid-month.
ext_mean = total_mean(ext_tot.sel(time=slice("1850", "2021")))

In [None]:
# Show the 1850-2021 month-length-weighted mean of the global-mean column extinction (SAOD) from the previous cell.
ext_mean

In [None]:
# Sum over height: column extinction (optical depth) at 550 nm for every time and latitude.
# Note this rebinds ext_tot, which previously held the global-mean series.
ext_tot = (d.ext[:,:,:,w550] * thick).sum('height')

# Equal area latitude bands used by ESM
bands = [slice(30,90), slice(0,30), slice(-30,0), slice(-90,-30)]
ext_band = np.zeros((len(ext_tot),4))
for b, band in enumerate(bands):
    # Select by value: .sel(lat=band) silently returns nothing (band mean 0)
    # if lat is stored north-to-south.
    in_band = (ext_tot.lat >= band.start) & (ext_tot.lat <= band.stop)
    # Area-weighted band mean: divide by the band's share of the global area.
    # For these equal-area bands that share is 1/4 (global area / band area = 4),
    # but using the actual weight sum stays exact if grid cells straddle a band edge.
    band_area = lat_wts[in_band.values].sum()
    ext_band[:,b] = (ext_tot*lat_wts).where(in_band).sum('lat') / band_area

lat_band = xr.IndexVariable('band', [1,2,3,4])
ext_band = xr.DataArray(ext_band, coords=[ext_tot.time, lat_band])

In [None]:
# One optical-depth line per ESM latitude band, from time index 1200 onward.
for b in range(4):
    ext_band.isel(time=slice(1200, None), band=b).plot()


In [None]:
# ESM1.5 CMIP6 volcanic forcing file: column 0 is the year, column 1 the month, and the
# remaining columns are averaged (presumably the ESM latitude bands — confirm).
esm_cmip6 = np.loadtxt('/g/data/vk83/experiments/inputs/access-esm1p5/modern/historical/atmosphere/forcing/resolution_independent/2021.06.22/volcts_cmip6.dat')
esm_cmip6_od = esm_cmip6[:,2:].mean(axis=1)
# Decimal-year time axis: year + month/12.
esm_cmip6_time = esm_cmip6[:,0] + esm_cmip6[:,1]/12.

# CMIP7 band-mean optical depth from 1850. Build its time axis from its own time
# coordinate using the same year + month/12 convention, rather than borrowing the ESM
# axis. The ESM axis may start at a different date or be shorter, which would misalign
# the series or raise a length mismatch.
od_cmip7 = ext_band.sel(time=slice('1850', None)).mean('band')
od_cmip7_time = od_cmip7.time.dt.year + od_cmip7.time.dt.month/12.

fig, axes = plt.subplots()
# NOTE(review): the 1e-4 factor assumes the ESM file stores optical depth scaled by 1e4 — confirm.
axes.plot(esm_cmip6_time, esm_cmip6_od*1e-4, label='ESM1.5 CMIP6')
axes.plot(od_cmip7_time.values, od_cmip7.values, label='CMIP7')
axes.set_title('Optical depth')
axes.set_xlabel('Year')
axes.set_xlim(1850,2015)
axes.set_ylim(0,0.18)
axes.legend()

In [None]:
# Inspect the per-band optical-depth series built above (dims: time, band).
ext_band

In [None]:
# 1850-2021 month-length-weighted means. In a top-to-bottom run, ext_tot here is the
# per-latitude column extinction from the band cell, so this rebinds ext_mean to a
# per-latitude mean. ext_mean_band holds one value per ESM band.
# Partial-date bounds include all of December 2021 even if time stamps are mid-month.
ext_mean = total_mean(ext_tot.sel(time=slice("1850", "2021")))
ext_mean_band = total_mean(ext_band.sel(time=slice("1850", "2021")))

In [None]:
# Per-band 1850-2021 means, then their plain average (equal-area bands, so this is the global mean).
print(ext_mean_band.values, ext_mean_band.mean(), sep='\n')

In [None]:
# 1850-2021 mean column optical depth (per latitude in a top-to-bottom run), with a title
# so the figure stands on its own.
fig, axes = plt.subplots()
ext_mean.plot(ax=axes)
axes.set_title('1850-2021 mean optical depth at 550 nm');