In [None]:
import numpy as np
#%matplotlib notebook
import matplotlib.pyplot as plt
import os
import pandas as pd
from scipy.stats import pearsonr
pd.set_option('display.max_rows', 20)
#from numpy.polynomial.polynomial import polyfit
plt.rcParams['figure.dpi'] = 100
plt.rcParams['lines.markersize'] = 4
plt.rcParams['font.size'] = 18
plt.rcParams['axes.prop_cycle'] = plt.cycler('color', 'brgmyk')
#plt.rcParams["font.family"] = "Times New Roman"


## Load data from files

In [None]:
from pathlib import Path
# Directories holding the processed data files (created as shown in the examples
# in the data/code_examples folder): one directory per gas, both under ./data.
n2o_path, o3_path = (Path("./data") / f"data_{gas}" for gas in ("n2o", "o3"))

In [None]:
def load_gas_files(location,
                   altitudes=(68, 46, 32, 22),
                   base_paths=None,
                   has_header=False):
    """
    Load N2O and O3 data files for a given location and pressure altitudes.

    Expected file layout:
      - <base_paths[gas]>/<pressure>/<gas>_<location>_<pressure>.csv,
        e.g. ./data/data_n2o/68/n2o_quistococha_68.csv with the default paths.
      - Despite the .csv extension, columns are separated by one or more
        whitespace characters.
      - Each file has two columns: a date (formatted like 2018311.0) and a
        gas concentration.
      - Without a header row, the columns are named "date" and "concentration".

    Parameters:
        location (str): The location name as used in the file names (e.g. 'quistococha').
        altitudes (iterable of int, optional): Pressure levels (hPa) to load.
            Defaults to (68, 46, 32, 22).
        base_paths (dict, optional): Maps gas name -> base directory. When None,
            {'n2o': n2o_path, 'o3': o3_path} is used, looked up at call time.
        has_header (bool, optional): Whether the files include a header row.
            Defaults to False.

    Returns:
        dict: Nested dict data[gas][pressure]. Each leaf is a pandas DataFrame,
              or None when the file does not exist (a message is printed).
              For example, result['n2o'][68] holds 'n2o_quistococha_68.csv'.
    """
    if base_paths is None:
        base_paths = {'n2o': n2o_path, 'o3': o3_path}

    # Raw string: '\s' in a normal literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    read_kwargs = {'sep': r'\s+', 'engine': 'python'}
    if not has_header:
        # Explicitly assign column names if no header exists.
        read_kwargs.update(header=None, names=["date", "concentration"])

    data = {}
    for gas, folder in base_paths.items():
        data[gas] = {}
        for pressure in altitudes:
            filename = f"{gas}_{location}_{pressure}.csv"
            # Files live in a subfolder named after the pressure level.
            file_path = os.path.join(folder, str(pressure), filename)
            try:
                data[gas][pressure] = pd.read_csv(file_path, **read_kwargs)
            except FileNotFoundError:
                print(f"File not found: {file_path}")
                data[gas][pressure] = None

    return data

## Combine N2O and O3 data

In [None]:
def combine_all(gas_data):
    """
    Combine N2O and O3 data per altitude via an inner join on the 'date' column.

    Parameters:
        gas_data (dict): A dictionary with keys 'n2o' and 'o3'. Each maps
                         altitude -> DataFrame (or None when the file was missing).

    Returns:
        dict: Keyed by every altitude present in gas_data['n2o']. Each value is
              the merged DataFrame, or None when either gas is missing for that
              altitude or lacks a 'date' column. For inputs with 'date' and
              'concentration' columns, the result has columns 'date', 'n2o', 'o3'.
    """
    n2o_by_alt = gas_data.get('n2o', {})
    o3_by_alt = gas_data.get('o3', {})
    combined_data = {}

    for altitude, df_n2o in n2o_by_alt.items():
        # .get: an altitude present for N2O but absent for O3 counts as missing data
        df_o3 = o3_by_alt.get(altitude)

        if df_n2o is None or df_o3 is None:
            print(f"Skipping altitude {altitude} hPa because one of the data sets is missing.")
            combined_data[altitude] = None
            continue

        # Both frames need the join key.
        has_dates = True
        for label, frame in (('N2O', df_n2o), ('O3', df_o3)):
            if 'date' not in frame.columns:
                print(f"Altitude {altitude} hPa: 'date' column missing in {label} data. Columns found: {frame.columns}")
                has_dates = False
                break
        if not has_dates:
            combined_data[altitude] = None
            continue

        # Inner join keeps only dates measured for both gases; the suffixes
        # disambiguate the two 'concentration' columns before renaming them.
        combined_data[altitude] = (
            pd.merge(df_n2o, df_o3, on='date', suffixes=('_n2o', '_o3'))
            .rename(columns={'concentration_n2o': 'n2o',
                             'concentration_o3': 'o3'})
        )

    return combined_data

## Correlations with p-values

In [None]:
# Internal location identifier (the name used in the data file names) ->
# reader-friendly label, which later serves as a figure column header.
# Insertion order matters: iterating this dict fixes the processing order.
locations_dict = dict([
    ('bashkortostan', 'Bashkortostan'),
    ('bozeman', 'Montana'),
    ('brunei', 'Kalimantan'),
    ('california', 'California'),
    ('catalonia', 'Catalonia'),
    ('colombia', 'Colombia'),
    ('estonia', 'Estonia'),
    ('finland', 'Finland'),
    ('florianopolis', 'Santa Catarina'),
    ('florida', 'Florida'),
    ('france', 'France'),
    ('french_guiana', 'French Guiana'),
    ('huntingdon', 'Southern Québec'),
    ('iceland_e', 'Iceland (E)'),
    ('iceland_w', 'Iceland (W)'),
    ('khabarovsk', 'Khabarovsk'),
    ('congo', 'Congo'),
    ('kyrgyzstan', 'Kyrgyzstan'),
    ('mexico', 'Xochimilco, Mexico City'),
    ('morocco', 'Morocco'),
    ('mukhrino', 'Mukhrino'),
    ('myanmar', 'Myanmar'),
    ('nz_n', 'New Zealand (N)'),
    ('nz_s', 'New Zealand (S)'),
    ('pantanal', 'Pantanal'),
    ('quistococha', 'Peruvian Amazon'),
    ('romania', 'Romania'),
    ('taiwan', 'Taiwan'),
    ('tarapoto', 'Tarapoto'),
    ('tasmania', 'Tasmania'),
    ('tierra_del_fuego', 'Tierra del Fuego'),
    ('uganda_e', 'Uganda (E)'),
    ('uganda_n', 'Uganda (N)'),
    ('uganda_s', 'Uganda (S)'),
    ('wales', 'Wales'),
    # Add more (identifier, label) pairs as needed.
])

In [None]:
def _pearson_summary(df, label, min_points=3):
    """
    Pearson statistics for the 'n2o' and 'o3' columns of one combined DataFrame.

    Rows with a NaN in either column are dropped first. Returns a dict with keys
    'r' (correlation coefficient), 'r2' (its square) and 'p' (p-value from
    scipy.stats.pearsonr). All three are None when df is None, lacks those
    columns, has fewer than `min_points` complete rows, or r is NaN
    (e.g. constant input).
    """
    unavailable = {'r': None, 'r2': None, 'p': None}
    if df is None or not {'n2o', 'o3'}.issubset(df.columns):
        return unavailable
    paired = df[['n2o', 'o3']].dropna()
    if len(paired) < min_points:
        return unavailable
    try:
        r, p = pearsonr(paired['n2o'], paired['o3'])
    except Exception as exc:  # best effort: report and carry on with the other sites
        print(f"{label}: could not compute correlation ({exc})")
        return unavailable
    if np.isnan(r):
        return unavailable
    return {'r': float(r), 'r2': float(r) ** 2, 'p': float(p)}


def calculate_correlations_with_p(locations,
                                  altitudes=(68, 46, 32, 22),
                                  base_paths=None,
                                  has_header=False):
    """
    For each location, load the gas data and keep only days with both N₂O and O₃
    measurements. Then compute the Pearson correlation coefficient and p-value
    for each altitude using scipy.stats.pearsonr.

    Parameters:
      - locations (iterable of str): Location names used in the file names.
      - altitudes (iterable of int, optional): Pressure levels in hPa (default: (68, 46, 32, 22)).
      - base_paths (dict, optional): Maps gas name -> base directory. When None,
        {'n2o': n2o_path, 'o3': o3_path} is used.
      - has_header (bool, optional): Whether the files include a header row.

    Returns:
      dict: correlations[location][altitude] -> {'r': r, 'r2': r², 'p': p}.
            'r' is the signed correlation coefficient; previously this key held r².
            Rows with NaN are ignored, and at least 3 complete pairs are required
            (matching plot_scatter_corr). When the correlation cannot be computed,
            all three values are None.
    """
    if base_paths is None:
        base_paths = {'n2o': n2o_path, 'o3': o3_path}

    correlations = {}
    for loc in locations:
        # Load both gases, then keep only dates on which both were measured.
        gas_data = load_gas_files(loc, altitudes=altitudes, base_paths=base_paths,
                                  has_header=has_header)
        combined_data = combine_all(gas_data)

        correlations[loc] = {
            altitude: _pearson_summary(combined_data.get(altitude), f"{loc} {altitude} hPa")
            for altitude in altitudes
        }

    return correlations



In [None]:
if __name__ == '__main__':
    # Every location that has a display name (identifiers as used in the file names).
    locations = locations_dict.keys()

    # Correlation statistics for each location and pressure level.
    corr_results = calculate_correlations_with_p(locations,
                                                 altitudes=[68, 46, 32, 22],
                                                 has_header=False)

    # Print the results. p-values use scientific notation: fixed-point
    # formatting (.5f) would round significant p-values such as 1e-8 to 0.00000.
    for loc, alt_dict in corr_results.items():
        print(f"\nCorrelation coefficients for {loc}:")
        for altitude, values in alt_dict.items():
            r = values['r']
            p = values['p']
            if r is None:
                print(f"  {altitude} hPa: Not enough data to compute correlation")
                continue
            line = f"  {altitude} hPa: r = {r:.3f}"
            r2 = values.get('r2')  # present only when the function reports r² separately
            if r2 is not None:
                line += f", r² = {r2:.3f}"
            print(f"{line}, p = {p:.2e}")

## Scatter plots with trend lines

### Plotting function with p-value significance markers

In [None]:
def _stars_for_p(p, levels=(0.001, 1e-6)):
    if np.isnan(p):
        return ''
    #if p < levels[2]:
    #    return '***'
    if p < levels[1]:
        return '**'
    if p < levels[0]:
        return '*'
    return ''

def _draw_corr_panel(ax, df, color, x_line, pstar_levels):
    """
    Draw one N₂O-vs-O₃ panel on `ax`: scatter, linear trend line and r² label.

    Rows where either gas is NaN are dropped first, so the trend line and the
    correlation use the same points. If `df` is None, empty, or lacks the
    'n2o'/'o3' columns, a "No data" placeholder is drawn instead.
    """
    if df is None or df.empty or not {'n2o', 'o3'}.issubset(df.columns):
        ax.text(0.5, 0.5, "No data", transform=ax.transAxes,
                ha='center', va='center', fontsize=12)
        return

    paired = df[['n2o', 'o3']].dropna()
    xv = paired['n2o'].to_numpy()
    yv = paired['o3'].to_numpy()

    ax.scatter(xv, yv, alpha=0.7, color=color)

    # Least-squares straight line; fitting one needs at least two points.
    if xv.size >= 2:
        try:
            slope, intercept = np.polyfit(xv, yv, 1)
        except (np.linalg.LinAlgError, ValueError) as exc:
            print(f"Trend line skipped: {exc}")
        else:
            ax.plot(x_line, slope * x_line + intercept,
                    color=color if color is not None else 'black',
                    lw=2, linestyle='--')

    # r² and significance stars require at least three paired points.
    if xv.size >= 3:
        r, p = pearsonr(xv, yv)
        r2 = r * r
        stars = _stars_for_p(p, levels=pstar_levels)
    else:
        r2, stars = np.nan, ''

    ax.text(0.05, 0.95, f"$r^2$ = {r2:.2f}{stars}",
            transform=ax.transAxes, fontsize=20, weight='bold',
            va='top', bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))


def plot_scatter_corr(locations,
                      display_names=None,
                      altitudes=(68, 46, 32, 22),
                      base_paths=None,
                      has_header=False,
                      x_range=None,
                      y_range=None,
                      save_path=None,
                      save_dpi=300,
                      colors=None,
                      pstar_levels=(0.05, 1e-3, 1e-6)):
    """
    Grid of N₂O-vs-O₃ scatter plots: one column per location, one row per pressure level.

    Rows are sorted by increasing pressure, so the top row is the lowest pressure
    (highest altitude). Each subplot shows:
      - a scatter plot of N₂O (x-axis) vs. O₃ (y-axis), with NaN rows dropped;
      - a linear trend line drawn across x_range (0 to 350 when x_range is None);
      - r² from scipy.stats.pearsonr plus significance stars from _stars_for_p;
      - the pressure level (hPa) as its title.

    Parameters:
        locations (sequence of str): Location names as used in the file names.
        display_names (sequence of str, optional): Column headers aligned with
            `locations`; the location name is used when missing.
        altitudes (iterable of int, optional): Pressure levels in hPa, in any order.
        base_paths (dict, optional): Maps gas name -> base directory. When None,
            {'n2o': n2o_path, 'o3': o3_path} is used.
        has_header (bool, optional): Whether the data files include a header row.
        x_range, y_range (tuple, optional): Axis limits applied to every subplot.
        save_path (str or Path, optional): If truthy, the figure is saved there.
        save_dpi (int, optional): Resolution used when saving.
        colors (sequence, optional): One color per column; default color cycle otherwise.
        pstar_levels (sequence of float, optional): Thresholds passed to _stars_for_p.
            NOTE(review): _stars_for_p only consults the first two entries
            ('*' and '**'), so the third default entry (1e-6) has no effect.
            Confirm the intended thresholds.
    """
    if base_paths is None:
        base_paths = {'n2o': n2o_path, 'o3': o3_path}

    # Ascending pressure: the lowest pressure (highest altitude) becomes the top row.
    sorted_alts = sorted(altitudes)
    n_rows, n_cols = len(sorted_alts), len(locations)

    # Trend lines span the shared x-axis range (0-350 ppbv when none is given).
    x_lo, x_hi = x_range if x_range is not None else (0, 350)
    x_line = np.linspace(x_lo, x_hi, 100)

    # squeeze=False always returns a 2-D Axes array, even for one row or column.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 4),
                            sharex=True, sharey=True, squeeze=False)

    for col, loc in enumerate(locations):
        has_name = display_names is not None and col < len(display_names)
        disp_name = display_names[col] if has_name else loc
        col_color = colors[col] if colors is not None and col < len(colors) else None

        # Load and combine the data for this location.
        gas_data = load_gas_files(loc, altitudes=altitudes, base_paths=base_paths,
                                  has_header=has_header)
        combined_data = combine_all(gas_data)

        # Column header above the top subplot.
        axs[0, col].annotate(disp_name, xy=(0.5, 1.15), xycoords='axes fraction',
                             ha='center', va='center', fontsize=20, fontweight='bold')

        for row, alt in enumerate(sorted_alts):
            ax = axs[row, col]
            _draw_corr_panel(ax, combined_data.get(alt), col_color, x_line, pstar_levels)
            ax.set_title(f"{alt} hPa")

            if x_range is not None:
                ax.set_xlim(x_range)
            if y_range is not None:
                ax.set_ylim(y_range)

            # Axis labels only along the outer edges of the grid.
            ax.set_xlabel("N₂O (ppbv)" if row == n_rows - 1 else "")
            ax.set_ylabel("O₃ (ppmv)" if col == 0 else "")
            ax.grid(True)

    fig.tight_layout()
    fig.subplots_adjust(top=0.92)  # leave room for the column headers
    if save_path:
        fig.savefig(save_path, dpi=save_dpi)
    plt.show()


In [None]:
# Areas with similar longitude above the Americas:
if __name__ == '__main__':
    # Location identifiers as used in the file names (one figure column each).
    location_list = ['huntingdon', 'quistococha', 'tierra_del_fuego']
    # Column headers are looked up in locations_dict, so they cannot drift from
    # the labels used in the per-zone figures.
    display_list = [locations_dict[loc] for loc in location_list]
    # One color per column.
    color_list = ['steelblue', 'darkorange', 'forestgreen']
    # Pressure levels in hPa, in any order: the plot puts the lowest pressure on top.
    alt_list = [68, 46, 32, 22]

    plot_scatter_corr(locations=location_list, display_names=display_list, altitudes=alt_list,
                      has_header=False,
                      x_range=(0, 350),   # N₂O axis (ppbv); adjust as needed
                      y_range=(0, 8.2),   # O₃ axis (ppmv); adjust as needed
                      save_path=None,
                      colors=color_list)


In [None]:
# Areas with similar longitude above Asia/Oceania:
if __name__ == '__main__':
    # Location identifiers as used in the file names (one figure column each).
    location_list = ['khabarovsk', 'brunei', 'tasmania']
    # Column headers are looked up in locations_dict, so they cannot drift from
    # the labels used in the per-zone figures.
    display_list = [locations_dict[loc] for loc in location_list]
    # One color per column.
    color_list = ['steelblue', 'darkorange', 'forestgreen']
    # Pressure levels in hPa, in any order: the plot puts the lowest pressure on top.
    alt_list = [68, 46, 32, 22]

    plot_scatter_corr(locations=location_list, display_names=display_list, altitudes=alt_list,
                      has_header=False,
                      x_range=(0, 350),   # N₂O axis (ppbv); adjust as needed
                      y_range=(0, 8.2),   # O₃ axis (ppmv); adjust as needed
                      save_path=None,
                      colors=color_list)


### Plotting by latitude zone (figures with adjusted layouts are also included in the supplement)

In [None]:
# Latitude zone of each location: 1 = northern hemisphere (NH), 2 = tropics,
# 3 = southern hemisphere (SH). Keys mirror locations_dict in the same order;
# that order becomes the left-to-right column order of the per-zone figures.
zone_dict = dict([
    ("bashkortostan", 1),
    ("bozeman", 1),
    ("brunei", 2),
    ("california", 1),
    ("catalonia", 1),
    ("colombia", 2),
    ("estonia", 1),
    ("finland", 1),
    ("florianopolis", 2),
    ("florida", 2),
    ("france", 1),
    ("french_guiana", 2),
    ("huntingdon", 1),
    ("iceland_e", 1),
    ("iceland_w", 1),
    ("khabarovsk", 1),
    ("congo", 2),
    ("kyrgyzstan", 1),
    ("mexico", 2),
    ("morocco", 1),
    ("mukhrino", 1),
    ("myanmar", 2),
    ("nz_n", 3),
    ("nz_s", 3),
    ("pantanal", 2),
    ("quistococha", 2),
    ("romania", 1),
    ("taiwan", 2),
    ("tarapoto", 2),
    ("tasmania", 3),
    ("tierra_del_fuego", 3),
    ("uganda_e", 2),
    ("uganda_n", 2),
    ("uganda_s", 2),
    ("wales", 1),
])

In [None]:
# Location identifiers grouped by latitude zone (1 = NH, 2 = tropics, 3 = SH).
# Each list keeps zone_dict's insertion order.
keys_nh, keys_eq, keys_sh = (
    [loc for loc, zone in zone_dict.items() if zone == wanted]
    for wanted in (1, 2, 3)
)

In [None]:
# Reader-friendly names per zone, aligned one-to-one with keys_nh / keys_eq / keys_sh
# because they are derived from those very lists.
# Fail early with a readable message if a zoned location has no display name.
assert zone_dict.keys() <= locations_dict.keys(), (
    "Locations in zone_dict without an entry in locations_dict: "
    f"{sorted(zone_dict.keys() - locations_dict.keys())}"
)
vals_nh = [locations_dict[k] for k in keys_nh]
vals_eq = [locations_dict[k] for k in keys_eq]
vals_sh = [locations_dict[k] for k in keys_sh]

In [None]:
def categorical_colors(n, name='tab20'):
    """
    Return `n` colors taken from the Matplotlib colormap `name`.

    The colormap is resampled to exactly `n` entries and entry i is returned
    for i = 0 .. n-1. For qualitative maps such as 'tab20' the colors are only
    guaranteed to be distinct while `n` does not exceed the palette size;
    beyond that, entries may repeat.
    """
    resampled = plt.get_cmap(name, n)   # colormap resampled to n entries
    palette = list(map(resampled, range(n)))
    return palette

In [None]:
# One palette per latitude zone, sized to the number of regions in that zone.
# Other qualitative maps (e.g. 'tab20', 'Dark2') work too.
colors_nh, colors_eq, colors_sh = (
    categorical_colors(len(names), cmap_name)
    for names, cmap_name in ((vals_nh, 'tab20b'), (vals_eq, 'tab20'), (vals_sh, 'tab10'))
)

In [None]:
# Plotting all northern hemisphere (NH) study regions
if __name__ == '__main__':
    # Zone inputs prepared above: file-name identifiers, matching display
    # names, and one color per column.
    location_list, display_list, color_list = keys_nh, vals_nh, colors_nh
    # Pressure levels in hPa; order is irrelevant because the plot sorts them.
    alt_list = [68, 46, 32, 22]

    plot_scatter_corr(
        locations=location_list,
        display_names=display_list,
        altitudes=alt_list,
        has_header=False,
        x_range=(0, 350),  # N₂O axis limits (ppbv); adjust as needed
        y_range=(0, 8.2),  # O₃ axis limits (ppmv); adjust as needed
        save_path=None,
        colors=color_list,
    )
                     

In [None]:
# Plotting all tropical study regions
if __name__ == '__main__':
    # Zone inputs prepared above: file-name identifiers, matching display
    # names, and one color per column.
    location_list, display_list, color_list = keys_eq, vals_eq, colors_eq
    # Pressure levels in hPa; order is irrelevant because the plot sorts them.
    alt_list = [68, 46, 32, 22]

    plot_scatter_corr(
        locations=location_list,
        display_names=display_list,
        altitudes=alt_list,
        has_header=False,
        x_range=(0, 350),  # N₂O axis limits (ppbv); adjust as needed
        y_range=(0, 8.2),  # O₃ axis limits (ppmv); adjust as needed
        save_path=None,
        colors=color_list,
    )
                     

In [None]:
# Plotting all southern hemisphere (SH) study regions
if __name__ == '__main__':
    # Zone inputs prepared above: file-name identifiers, matching display
    # names, and one color per column.
    location_list, display_list, color_list = keys_sh, vals_sh, colors_sh
    # Pressure levels in hPa; order is irrelevant because the plot sorts them.
    alt_list = [68, 46, 32, 22]

    plot_scatter_corr(
        locations=location_list,
        display_names=display_list,
        altitudes=alt_list,
        has_header=False,
        x_range=(0, 350),  # N₂O axis limits (ppbv); adjust as needed
        y_range=(0, 8.2),  # O₃ axis limits (ppmv); adjust as needed
        save_path=None,
        colors=color_list,
    )
                     