In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
from scipy.stats import norm
from matplotlib.lines import Line2D
from matplotlib.ticker import FixedLocator, FuncFormatter, MaxNLocator, AutoMinorLocator
# Notebook-wide display defaults: cap printed DataFrame rows and set the
# matplotlib look (resolution, marker size, font size, colour cycle) in one go.
pd.set_option('display.max_rows', 20)
plt.rcParams.update({
    'figure.dpi': 100,
    'lines.markersize': 4,
    'font.size': 14,
    'axes.prop_cycle': plt.cycler('color', 'brgmyk'),
})


In [None]:
from pathlib import Path
# Location of the processed monthly means produced by the "monthly_averages_git" notebook.
# The path is relative, so it is resolved against the kernel's current working directory.
DATA_DIR = Path("./trends/monthly_averages")

In [None]:
# Load the monthly-average N2O tables, one CSV per level (68, 46, 32, 22),
# reading them in that order and binding each to its own DataFrame name.
data_68, data_46, data_32, data_22 = (
    pd.read_csv(DATA_DIR / csv_name)
    for csv_name in (
        'monthly_averages_68_n2o.csv',
        'monthly_averages_46_n2o.csv',
        'monthly_averages_32_n2o.csv',
        'monthly_averages_22_n2o.csv',
    )
)

In [None]:
# Convert the 'date' column of every N2O table to pandas datetimes (in place).
for _n2o_table in (data_68, data_46, data_32, data_22):
    _n2o_table['date'] = pd.to_datetime(_n2o_table['date'])

In [None]:
# Load the monthly-average O3 tables, one CSV per level (68, 46, 32, 22),
# reading them in that order and binding each to its own DataFrame name.
o3_68, o3_46, o3_32, o3_22 = (
    pd.read_csv(DATA_DIR / csv_name)
    for csv_name in (
        'monthly_averages_68_o3.csv',
        'monthly_averages_46_o3.csv',
        'monthly_averages_32_o3.csv',
        'monthly_averages_22_o3.csv',
    )
)

In [None]:
# Convert the 'date' column of every O3 table to pandas datetimes (in place).
for _o3_table in (o3_68, o3_46, o3_32, o3_22):
    _o3_table['date'] = pd.to_datetime(_o3_table['date'])

In [None]:
# Lookup table: site name (as it appears in the 'location' column) -> latitude.
# Values end up on the 'Latitude (°)' axis of the plots; presumably decimal
# degrees with negative = southern hemisphere — TODO confirm against the source data.
lat_dict = {
    "bashkortostan":    55.5,
    "brunei":           5.3,
    "california":       38.0,
    "estonia":          58.4,
    "florianopolis":    -28.6,
    "florida":          26.4,
    "khabarovsk":       48.5,
    "iceland_e":        64.3,
    "iceland_w":        64.8,
    "colombia":         6.7,
    "kyrgyzstan":       41.9,
    "morocco":          35.2,
    "bozeman":          45.6,
    "huntingdon":       45.1,
    "mukhrino":         60.9,
    "myanmar":          20.6,
    "pantanal":         -16.5,
    "quistococha":      -3.9,
    "french_guiana":    5.1,
    "catalonia":        42.7,
    "france":           47.3,
    "romania":          47.3,
    "finland":          68.0,
    "taiwan":           22.2,
    "tasmania":         -41.5,
    "tierra_del_fuego": -54.7,
    "uganda_e":         0.9,
    "uganda_n":         1.2,
    "uganda_s":         -1.2,
    "nz_s":             -46.6,
    "nz_n":             -37.4,
    "wales":            53.0,
    "kongo":            1.35,
    "mexico":           19.3,
    "tarapoto":         -6.5,
}

In [None]:
# Attach a 'lat' column to every N2O table by looking each row's location up
# in lat_dict (locations absent from the dict come out as NaN).
for _n2o_table in (data_68, data_46, data_32, data_22):
    _n2o_table["lat"] = _n2o_table["location"].map(lat_dict)

In [None]:
# Attach a 'lat' column to every O3 table by looking each row's location up
# in lat_dict (locations absent from the dict come out as NaN).
for _o3_table in (o3_68, o3_46, o3_32, o3_22):
    _o3_table["lat"] = _o3_table["location"].map(lat_dict)

In [None]:
def compute_location_stats(
    df,
    location_col: str = 'location',
    value_col:    str = 'mean_concentration',
    lat_col:      str = 'lat',
    ci:           float = 0.95
) -> pd.DataFrame:
    """
    Summarise `value_col` per location with a normal-approximation
    confidence interval.

    Parameters
    ----------
    df : DataFrame holding at least `location_col`, `value_col` and `lat_col`.
    location_col : column to group by.
    value_col : column whose values are summarised.
    lat_col : column holding the latitude of each row.
    ci : two-sided confidence level, strictly between 0 and 1.

    Returns
    -------
    DataFrame with one row per location and the columns:
      - `location_col`
      - mean_concentration : group mean of `value_col`
      - std_concentration  : group standard deviation of `value_col`
      - count              : number of NON-MISSING values in the group
      - lat                : first latitude found in the group (for plotting)
      - sem                : std_concentration / sqrt(count)
      - ci_lower, ci_upper : mean_concentration ∓ z * sem

    Raises
    ------
    ValueError
        If `ci` is not strictly between 0 and 1.
    """
    # A level outside (0, 1) would make norm.ppf return NaN/inf and silently
    # produce meaningless bounds, so reject it up front.
    if not 0 < ci < 1:
        raise ValueError(f"ci must be strictly between 0 and 1, got {ci!r}")

    # 1) Aggregate mean, std, count, and grab the first lat.
    #    'count' (not 'size') is used on purpose: mean and std skip NaN, so the
    #    sample size in the SEM must skip NaN as well, otherwise the interval
    #    is too narrow whenever a group contains missing values.
    grouped = (
        df
        .groupby(location_col)
        .agg(
            mean_concentration=(value_col, 'mean'),
            std_concentration =(value_col, 'std'),
            count            =(value_col, 'count'),
            lat              =(lat_col, 'first')
        )
        .reset_index()
    )

    # 2) Standard error of the mean
    grouped['sem'] = grouped['std_concentration'] / np.sqrt(grouped['count'])

    # 3) Two-sided z-score for the desired confidence level
    z = norm.ppf(1 - (1 - ci) / 2)

    # 4) Confidence bounds, symmetric around the mean
    half_width = z * grouped['sem']
    grouped['ci_lower'] = grouped['mean_concentration'] - half_width
    grouped['ci_upper'] = grouped['mean_concentration'] + half_width

    return grouped


In [None]:
# Collect the per-level tables into lists, ordered 22, 32, 46, 68 so that they
# line up with the default `altitudes` of plot_dual_axis_grid.
dfs_n2o = [
    data_22,
    data_32,
    data_46,
    data_68,
]
dfs_o3 = [
    o3_22,
    o3_32,
    o3_46,
    o3_68,
]

In [None]:
# Latitude tick positions: majors every 30° from -60 to 60,
# minors halfway between them (-45, -15, 15, 45).
majors = list(range(-60, 61, 30))
minors = list(range(-45, 46, 30))


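# Plotting helper: 2×2 grid of N₂O / O₃ latitude profiles (x ticks come from `majors` / `minors` above)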
def _draw_ci_markers(ax, summary, color):
    """
    Plot per-location means against latitude as unconnected circle markers,
    with error bars running from `ci_lower` to `ci_upper`.
    """
    centre = summary['mean_concentration']
    ax.errorbar(
        summary['lat'],
        centre,
        # errorbar expects distances below/above the point, not absolute bounds
        yerr=[centre - summary['ci_lower'], summary['ci_upper'] - centre],
        fmt='o',
        linestyle='None',
        color=color,
        capsize=3,
        label='_nolegend_'  # the figure legend is built from proxy handles
    )


def plot_dual_axis_grid(
    n2o_dfs,
    o3_dfs,
    altitudes=(22, 32, 46, 68),
    stats_func=compute_location_stats,
    ci_level=0.95,
    save_path: str = None,
    save_dpi: int = 300,
    show: bool = True
):
    """
    Plot N₂O and O₃ against latitude on a 2×2 grid, one panel per pressure level.

    Each panel draws N₂O (dodgerblue circles, left y-axis) and O₃ (coral
    circles, right twin y-axis) as per-location means with confidence-interval
    error bars. Markers only — no connecting lines — and a single figure-level
    legend below the panels.

    Parameters
    ----------
    n2o_dfs, o3_dfs : sequences of DataFrames (at most four each, same length
                      and same order of pressure levels).
    altitudes :       pressure levels used for the panel titles ("<value> hPa");
                      must provide a label for every DataFrame pair.
    stats_func :      callable `stats_func(df, ci=...)` returning a summary with
                      the columns 'lat', 'mean_concentration', 'ci_lower' and
                      'ci_upper'.
    ci_level :        confidence level forwarded to `stats_func`.
    save_path :       if truthy, the figure is written there with `fig.savefig`.
    save_dpi :        resolution used when saving.
    show :            if True, call `plt.show()` at the end.

    X-axis tick positions are read from the module-level `majors` and `minors`.

    Raises
    ------
    ValueError
        If the two DataFrame sequences differ in length, if there are more
        levels than panels, or if `altitudes` has fewer entries than there are
        levels. Previously these cases silently dropped data.
    """
    # Materialise the inputs so their lengths can be checked (also accepts iterators).
    n2o_dfs = list(n2o_dfs)
    o3_dfs = list(o3_dfs)
    altitudes = list(altitudes)

    n_rows, n_cols = 2, 2
    max_panels = n_rows * n_cols
    if len(n2o_dfs) != len(o3_dfs):
        raise ValueError(
            f"n2o_dfs and o3_dfs must have the same length, "
            f"got {len(n2o_dfs)} and {len(o3_dfs)}"
        )
    if len(n2o_dfs) > max_panels:
        raise ValueError(
            f"at most {max_panels} pressure levels fit the "
            f"{n_rows}x{n_cols} grid, got {len(n2o_dfs)}"
        )
    if len(altitudes) < len(n2o_dfs):
        raise ValueError(
            f"altitudes has {len(altitudes)} entries but "
            f"{len(n2o_dfs)} pressure levels were given"
        )

    # Precompute summaries
    n2o_summaries = [stats_func(df, ci=ci_level) for df in n2o_dfs]
    o3_summaries  = [stats_func(df, ci=ci_level) for df in o3_dfs]

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8), sharex=True)
    axes = axes.flatten()

    for ax in axes:
        ax.set_xlim(-70, 70)
        # Major/minor tick placement
        ax.xaxis.set_major_locator(FixedLocator(majors))
        ax.xaxis.set_minor_locator(FixedLocator(minors))
        # Styling: longer/thicker majors, shorter/thinner minors
        ax.tick_params(axis='x', which='major', length=7, width=1.4)
        ax.tick_params(axis='x', which='minor', length=4, width=0.8)
        # Grid on majors only (keeps it clean)
        ax.grid(True, which='major', linestyle=':', alpha=0.6)

    for ax, n2o_sum, o3_sum, alt in zip(axes, n2o_summaries, o3_summaries, altitudes):
        # N₂O: blue circles on the left axis
        _draw_ci_markers(ax, n2o_sum, 'dodgerblue')
        ax.set_ylabel('N₂O (ppbv)')

        # O₃: coral circles on a twin right axis
        ax2 = ax.twinx()
        _draw_ci_markers(ax2, o3_sum, 'coral')
        ax2.set_ylabel('O₃ (ppmv)')

        ax.set_title(f'{alt} hPa')

        for a in (ax, ax2):
            # Aim for 4–5 majors; pick nice steps (1, 2, 2.5, 5, 10)
            a.yaxis.set_major_locator(MaxNLocator(nbins=5, steps=[1, 2, 2.5, 5, 10], prune=None))
            a.yaxis.set_minor_locator(AutoMinorLocator(2))  # 1 minor between majors
            a.tick_params(axis='y', which='major', length=7, width=1.2)
            a.tick_params(axis='y', which='minor', length=4, width=0.8)

    # Labels: bottom row only, majors labeled; minors unlabeled
    fmt = FuncFormatter(lambda x, pos: f"{int(x)}")
    for ax in axes[:n_cols]:
        ax.tick_params(axis='x', which='both', labelbottom=False)

    for ax in axes[n_cols:]:
        ax.set_xlabel('Latitude (°)')
        ax.xaxis.set_major_formatter(fmt)

    # Reserve space at bottom for legend
    fig.tight_layout(rect=[0, 0.08, 1, 1])

    # Single legend with two entries, drawn from proxy artists so that it does
    # not depend on which panels actually contain data
    proxy_handles = [
        Line2D([0], [0], marker='o', linestyle='None', markersize = 8, color='dodgerblue', label='N₂O'),
        Line2D([0], [0], marker='o', linestyle='None', markersize = 8, color='coral',  label='O₃'),
    ]
    fig.legend(
        handles=proxy_handles,
        loc='lower center',
        ncol=2,
        frameon=False,
        fontsize='medium'
    )
    if save_path:
        fig.savefig(save_path, dpi=save_dpi, bbox_inches='tight')

    if show:
        plt.show()


In [None]:
# Draw the 2×2 N₂O / O₃ figure; save_path=None means nothing is written to disk.
plot_dual_axis_grid(n2o_dfs=dfs_n2o, o3_dfs=dfs_o3, save_path=None)