<a href="https://colab.research.google.com/github/Jeong-HyunLee/stromatoporoid-reef/blob/main/stromatoporoid_reef_size_Camb_Dev_Suppv16.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# =============================================================================
#@title CELL 1: SETUP AND IMPORTS
# =============================================================================

# Install required packages (uncomment if needed)
# !pip install openpyxl geopandas shapely requests

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from scipy import stats
from scipy.interpolate import interp1d
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample
import statsmodels.api as sm
import warnings
import os
warnings.filterwarnings('ignore')

# Set up matplotlib for publication-quality vector figures
plt.rcParams['font.family'] = 'DejaVu Sans'
plt.rcParams['font.size'] = 10
plt.rcParams['axes.labelsize'] = 11
plt.rcParams['axes.titlesize'] = 12
plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['pdf.fonttype'] = 42  # TrueType fonts in PDF
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['svg.fonttype'] = 'none'  # Text as text in SVG

print("="*70)
print("STROMATOPOROID TURNOVER AND REEF MORPHOLOGY ANALYSIS")
print("WITH PEARSON AND SPEARMAN CORRELATIONS")
print("STAGE-LEVEL AND 5-MYR BIN ANALYSIS")
print("="*70)
print("\nLibraries loaded successfully!")

# Output directory
import os
OUTPUT_DIR = "./output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

STROMATOPOROID TURNOVER AND REEF MORPHOLOGY ANALYSIS
WITH PEARSON AND SPEARMAN CORRELATIONS
STAGE-LEVEL AND 5-MYR BIN ANALYSIS

Libraries loaded successfully!
Output directory: ./output


In [2]:
# =============================================================================
#@title CELL 2: GENERATE MACROSTRAT DATA (paleozoic_stage_data.csv, paleozoic_5myr_data.csv)
# =============================================================================

print("="*70)
print("GENERATING MACROSTRAT DATA")
print("="*70)

# Check if files already exist
macrostrat_stage_file = 'paleozoic_stage_data.csv'
macrostrat_5myr_file = 'paleozoic_5myr_data.csv'

if os.path.exists(macrostrat_stage_file) and os.path.exists(macrostrat_5myr_file):
    print(f"✓ {macrostrat_stage_file} already exists")
    print(f"✓ {macrostrat_5myr_file} already exists")
    print("Skipping Macrostrat data generation...")
else:
    print("Generating Macrostrat data from API...")

    import requests
    try:
        import geopandas as gpd
    except ImportError:
        print("Installing geopandas...")
        import subprocess
        subprocess.run(['pip', 'install', 'geopandas', '-q'])
        import geopandas as gpd

    # Define Paleozoic Period Age Ranges
    periods = {
        "Cambrian": {"start": 538.8, "end": 485.4, "color": "#7FA056"},
        "Ordovician": {"start": 485.4, "end": 443.8, "color": "#00a9ce"},
        "Silurian": {"start": 443.8, "end": 419.2, "color": "#b3e1af"},
        "Devonian": {"start": 419.2, "end": 358.9, "color": "#cb8c37"}
    }
    # Define stages (ICS 2023/04; Cambrian stage boundary ages are approximate in the ICS chart)
    cambrian_stages = {
        "Stage 10": (485.4, 489.5), "Jiangshanian": (489.5, 494.0), "Paibian": (494.0, 497.0),
        "Guzhangian": (497.0, 500.5), "Drumian": (500.5, 504.5), "Wuliuan": (504.5, 509.0),
        "Stage 4": (509.0, 514.0), "Stage 3": (514.0, 521.0), "Stage 2": (521.0, 529.0),
        "Fortunian": (529.0, 538.8)
    }
    ordovician_stages = {
        "Tremadocian": (477.7, 485.4), "Floian": (470.0, 477.7),
        "Dapingian": (467.3, 470.0), "Darriwilian": (458.4, 467.3),
        "Sandbian": (453.0, 458.4), "Katian": (445.2, 453.0),
        "Hirnantian": (443.8, 445.2)
    }
    silurian_stages = {
        "Rhuddanian": (440.8, 443.8), "Aeronian": (438.5, 440.8),
        "Telychian": (433.4, 438.5), "Sheinwoodian": (430.5, 433.4),
        "Homerian": (427.4, 430.5), "Gorstian": (425.6, 427.4),
        "Ludfordian": (423.0, 425.6), "Pridoli": (419.2, 423.0)
    }
    devonian_stages = {
        "Lochkovian": (410.8, 419.2), "Pragian": (407.6, 410.8),
        "Emsian": (393.3, 407.6), "Eifelian": (387.7, 393.3),
        "Givetian": (382.7, 387.7), "Frasnian": (372.2, 382.7),
        "Famennian": (358.9, 372.2)
    }

    # Create stages dataframe
    stages_data = []
    for stage, (end_age, start_age) in ordovician_stages.items():
        stages_data.append({"stage": stage, "start_age": start_age, "end_age": end_age,
                           "mid_age": (start_age + end_age) / 2, "period": "Ordovician"})
    for stage, (end_age, start_age) in silurian_stages.items():
        stages_data.append({"stage": stage, "start_age": start_age, "end_age": end_age,
                           "mid_age": (start_age + end_age) / 2, "period": "Silurian"})
    for stage, (end_age, start_age) in devonian_stages.items():
        stages_data.append({"stage": stage, "start_age": start_age, "end_age": end_age,
                           "mid_age": (start_age + end_age) / 2, "period": "Devonian"})
    stages_df = pd.DataFrame(stages_data)

    # Retrieve Macrostrat Data
    periods_to_fetch = ["Cambrian", "Ordovician", "Silurian", "Devonian"]
    all_units_list = []

    for period in periods_to_fetch:
        url = f"https://macrostrat.org/api/units?interval_name={period}&format=geojson&response=long"
        print(f"  Fetching data for {period}...")
        try:
            response = requests.get(url, timeout=60)
            if response.status_code == 200:
                data = response.json()
                features = data.get("success", {}).get("data", [])
                if features:
                    period_units = gpd.GeoDataFrame.from_features(features)
                    print(f"    Retrieved {len(period_units)} geological units")
                    period_units['source_period'] = period
                    all_units_list.append(period_units)
        except Exception as e:
            print(f"    Error fetching {period}: {e}")

    if all_units_list:
        units = pd.concat(all_units_list, ignore_index=True)
        print(f"  Combined dataset: {len(units)} geological units")

        # Process units
        try:
            if units.crs is None:
                units.set_crs(epsg=4326, inplace=True)
            units = units.to_crs(epsg=3857)
            if 'col_area' in units.columns:
                units['area_km2'] = pd.to_numeric(units['col_area'], errors='coerce')
            else:
                units['area_km2'] = units.geometry.area / 1e6
        except:
            units['area_km2'] = 100  # Default

        units['t_age'] = pd.to_numeric(units['t_age'], errors='coerce')
        units['b_age'] = pd.to_numeric(units['b_age'], errors='coerce')
        units['mid_age'] = (units['t_age'] + units['b_age']) / 2.0
        units.dropna(subset=['mid_age'], inplace=True)

        # Identify carbonates
        def check_if_carbonate(lithologies):
            if isinstance(lithologies, list):
                for lith in lithologies:
                    if isinstance(lith, dict) and 'type' in lith and 'carbonate' in str(lith['type']).lower():
                        return True
            elif isinstance(lithologies, str):
                return 'carbonate' in lithologies.lower()
            return False

        units['is_carbonate'] = units['lith'].apply(check_if_carbonate)
        carbonate_units = units[units['is_carbonate']].copy()

        # Assign stages
        all_stages = {**cambrian_stages, **ordovician_stages, **silurian_stages, **devonian_stages}
        def assign_stage(age):
            for stage, (end, start) in all_stages.items():
                if start >= age >= end:
                    return stage
            return None

        units['stage'] = units['mid_age'].apply(assign_stage)
        carbonate_units['stage'] = carbonate_units['mid_age'].apply(assign_stage)

        # Aggregate by stage
        stage_totals = units.groupby('stage')['area_km2'].sum().reset_index()
        stage_totals.rename(columns={'area_km2': 'total_area_km2'}, inplace=True)
        stage_carbonates = carbonate_units.groupby('stage')['area_km2'].sum().reset_index()
        stage_carbonates.rename(columns={'area_km2': 'carbonate_area_km2'}, inplace=True)

        stage_summary = pd.merge(stage_totals, stage_carbonates, on='stage', how='left')
        stage_summary['carbonate_area_km2'] = stage_summary['carbonate_area_km2'].fillna(0)
        stage_summary['carbonate_percentage'] = (stage_summary['carbonate_area_km2'] / stage_summary['total_area_km2']) * 100

        macrostrat_data = pd.merge(stages_df, stage_summary, on='stage', how='left')
        macrostrat_data = macrostrat_data.sort_values('start_age', ascending=False).reset_index(drop=True)
        macrostrat_data.to_csv(macrostrat_stage_file, index=False)
        print(f"  ✓ Saved {macrostrat_stage_file}")

        # 5 Myr bins
        max_age = 490
        min_age = 355
        manual_bins = np.arange(min_age, max_age + 5, 5)

        units['time_bin'] = pd.cut(units['mid_age'], bins=manual_bins, include_lowest=True, right=False)
        carbonate_units['time_bin'] = pd.cut(carbonate_units['mid_age'], bins=manual_bins, include_lowest=True, right=False)

        macro_all_5myr = units.groupby('time_bin')['area_km2'].sum().reset_index()
        macro_all_5myr.rename(columns={'area_km2': 'total_area_km2'}, inplace=True)
        macro_carb_5myr = carbonate_units.groupby('time_bin')['area_km2'].sum().reset_index()
        macro_carb_5myr.rename(columns={'area_km2': 'carbonate_area_km2'}, inplace=True)

        macrostrat_5myr = pd.merge(macro_all_5myr, macro_carb_5myr, on='time_bin', how='left')
        macrostrat_5myr['carbonate_area_km2'] = macrostrat_5myr['carbonate_area_km2'].fillna(0)
        macrostrat_5myr['carbonate_percentage'] = (macrostrat_5myr['carbonate_area_km2'] / macrostrat_5myr['total_area_km2']) * 100
        macrostrat_5myr['bin_mid'] = macrostrat_5myr['time_bin'].apply(lambda x: (x.left + x.right) / 2 if pd.notna(x) else np.nan)
        macrostrat_5myr.to_csv(macrostrat_5myr_file, index=False)
        print(f"  ✓ Saved {macrostrat_5myr_file}")
    else:
        print("  WARNING: Could not fetch Macrostrat data. Creating placeholder files...")
        # Create placeholder files
        pd.DataFrame(columns=['stage', 'total_area_km2', 'carbonate_area_km2', 'carbonate_percentage']).to_csv(macrostrat_stage_file, index=False)
        pd.DataFrame(columns=['bin_mid', 'total_area_km2', 'carbonate_area_km2', 'carbonate_percentage']).to_csv(macrostrat_5myr_file, index=False)

print("✓ Macrostrat data ready")

GENERATING MACROSTRAT DATA
Generating Macrostrat data from API...
  Fetching data for Cambrian...
    Retrieved 2480 geological units
  Fetching data for Ordovician...
    Retrieved 2943 geological units
  Fetching data for Silurian...
    Retrieved 1715 geological units
  Fetching data for Devonian...
    Retrieved 2793 geological units
  Combined dataset: 9931 geological units
  ✓ Saved paleozoic_stage_data.csv
  ✓ Saved paleozoic_5myr_data.csv
✓ Macrostrat data ready


In [3]:
# =============================================================================
#@title CELL 3: GENERATE PARED REEF DATA (reef stage and 5myr files)
# =============================================================================

print("\n" + "="*70)
print("GENERATING PARED REEF DATA")
print("="*70)

reef_stage_file = 'cambrian_devonian_reef_data_stage_for_analysis.csv'
reef_5myr_file = 'cambrian_devonian_reef_data_5myr_for_analysis.csv'
pared_source_file = 'PARED_reef_All_numerical.csv'

# Force regeneration to include new variable
if os.path.exists(reef_stage_file) and os.path.exists(reef_5myr_file) and False:
    print(f"✓ {reef_stage_file} already exists")
    print(f"✓ {reef_5myr_file} already exists")
    print("Skipping PARED reef data generation...")
else:
    # Check if source file exists
    if not os.path.exists(pared_source_file):
        print(f"Source file '{pared_source_file}' not found.")
        print("Please upload it now:")
        try:
            from google.colab import files
            uploaded_pared = files.upload()
            uploaded_name = list(uploaded_pared.keys())[0]
            if uploaded_name != pared_source_file:
                os.rename(uploaded_name, pared_source_file)
        except ImportError:
            raise FileNotFoundError(f"Please place '{pared_source_file}' in the current directory.")

    print(f"Processing {pared_source_file}...")

    # Load PARED data
    try:
        pared_df = pd.read_csv(pared_source_file, encoding='utf-8')
    except UnicodeDecodeError:
        try:
            pared_df = pd.read_csv(pared_source_file, encoding='latin-1')
        except:
            pared_df = pd.read_csv(pared_source_file, encoding='cp1252')

    # --- NEW CODE: Calculate Paired Ratio (Log Difference) ---
    # Ensure numeric
    pared_df['thickness'] = pd.to_numeric(pared_df['thickness'], errors='coerce')
    pared_df['width'] = pd.to_numeric(pared_df['width'], errors='coerce')

    # Calculate difference (Log Thickness - Log Width)
    # This automatically becomes NaN if either value is missing
    pared_df['t_w_log_ratio'] = pared_df['thickness'] - pared_df['width']
    # ---------------------------------------------------------

    def analyze_pared_data(dataframe, bin_definitions, analysis_type_label):
        """Calculate statistics using OVERLAP method"""
        results = []
        # Added 't_w_log_ratio' to variables
        variables = ['thickness', 'width', 'extension', 't_w_log_ratio']

        for bin_def in bin_definitions:
            s_start = bin_def['start_ma']
            s_end = bin_def['end_ma']
            name = bin_def['time_identifier']

            # OVERLAP LOGIC
            mask = (dataframe['min_ma'] < s_end) & (dataframe['max_ma'] > s_start)
            subset = dataframe[mask]

            row_data = {
                'bin_center': (s_start + s_end) / 2.0,
                'start_age': s_start,
                'end_age': s_end,
                'reef_count': len(subset),
            }

            for var in variables:
                data = subset[var].dropna() if var in subset.columns else pd.Series()
                if len(data) > 0:
                    row_data[f'{var}_mean'] = data.mean()
                    row_data[f'{var}_std'] = data.std()
                    row_data[f'{var}_stderr'] = data.sem()
                    row_data[f'{var}_median'] = data.median()
                    row_data[f'{var}_count'] = len(data)
                else:
                    for suffix in ['mean', 'std', 'stderr', 'median']:
                        row_data[f'{var}_{suffix}'] = np.nan
                    row_data[f'{var}_count'] = 0

            row_data['analysis_type'] = analysis_type_label
            row_data['time_identifier'] = name
            row_data['name'] = name
            row_data['start_ma'] = s_start
            row_data['end_ma'] = s_end
            row_data['midpoint_ma'] = (s_start + s_end) / 2.0
            row_data['duration_myr'] = round(s_end - s_start, 1)

            results.append(row_data)

        return pd.DataFrame(results)

    # Stage definitions (PRESERVED FROM ORIGINAL)
    stages_data = [
        # Devonian
        {'time_identifier': 'Famennian', 'start_ma': 358.9, 'end_ma': 372.2},
        {'time_identifier': 'Frasnian', 'start_ma': 372.2, 'end_ma': 382.7},
        {'time_identifier': 'Givetian', 'start_ma': 382.7, 'end_ma': 387.7},
        {'time_identifier': 'Eifelian', 'start_ma': 387.7, 'end_ma': 393.3},
        {'time_identifier': 'Emsian', 'start_ma': 393.3, 'end_ma': 407.6},
        {'time_identifier': 'Pragian', 'start_ma': 407.6, 'end_ma': 410.8},
        {'time_identifier': 'Lochkovian', 'start_ma': 410.8, 'end_ma': 419.2},

        # Silurian
        {'time_identifier': 'Pridoli', 'start_ma': 419.2, 'end_ma': 423.0},
        {'time_identifier': 'Ludfordian', 'start_ma': 423.0, 'end_ma': 425.6},
        {'time_identifier': 'Gorstian', 'start_ma': 425.6, 'end_ma': 427.4},
        {'time_identifier': 'Homerian', 'start_ma': 427.4, 'end_ma': 430.5},
        {'time_identifier': 'Sheinwoodian', 'start_ma': 430.5, 'end_ma': 433.4},
        {'time_identifier': 'Telychian', 'start_ma': 433.4, 'end_ma': 438.5},
        {'time_identifier': 'Aeronian', 'start_ma': 438.5, 'end_ma': 440.8},
        {'time_identifier': 'Rhuddanian', 'start_ma': 440.8, 'end_ma': 443.8},

        # Ordovician
        {'time_identifier': 'Hirnantian', 'start_ma': 443.8, 'end_ma': 445.2},
        {'time_identifier': 'Katian', 'start_ma': 445.2, 'end_ma': 453.0},
        {'time_identifier': 'Sandbian', 'start_ma': 453.0, 'end_ma': 458.4},
        {'time_identifier': 'Darriwilian', 'start_ma': 458.4, 'end_ma': 467.3},
        {'time_identifier': 'Dapingian', 'start_ma': 467.3, 'end_ma': 470.0},
        {'time_identifier': 'Floian', 'start_ma': 470.0, 'end_ma': 477.7},
        {'time_identifier': 'Tremadocian', 'start_ma': 477.7, 'end_ma': 485.4},

        # Cambrian (ICS 2023/04; approximate boundaries)
        {'time_identifier': 'Stage 10', 'start_ma': 485.4, 'end_ma': 489.5},
        {'time_identifier': 'Jiangshanian', 'start_ma': 489.5, 'end_ma': 494.0},
        {'time_identifier': 'Paibian', 'start_ma': 494.0, 'end_ma': 497.0},
        {'time_identifier': 'Guzhangian', 'start_ma': 497.0, 'end_ma': 500.5},
        {'time_identifier': 'Drumian', 'start_ma': 500.5, 'end_ma': 504.5},
        {'time_identifier': 'Wuliuan', 'start_ma': 504.5, 'end_ma': 509.0},
        {'time_identifier': 'Stage 4', 'start_ma': 509.0, 'end_ma': 514.0},
        {'time_identifier': 'Stage 3', 'start_ma': 514.0, 'end_ma': 521.0},
        {'time_identifier': 'Stage 2', 'start_ma': 521.0, 'end_ma': 529.0},
        {'time_identifier': 'Fortunian', 'start_ma': 529.0, 'end_ma': 538.8},
    ]

    # 5-Myr bins (Cambrian–Devonian)
    # We avoid spilling into pre-Cambrian by using partial bins at both ends:
    #   358.9–360.0 and 535.0–538.8.
    bins_5myr = []

    # bottom partial bin (Devonian top is 358.9)
    bins_5myr.append({'time_identifier': "358.9-360.0 Ma", 'start_ma': 358.9, 'end_ma': 360.0})

    # regular 5 Myr bins
    age = 360.0
    while age < 535.0:
        bins_5myr.append({
            'time_identifier': f"{int(age)}-{int(age+5)} Ma",
            'start_ma': float(age),
            'end_ma': float(age + 5.0)
        })
        age += 5.0

    # top partial bin (Cambrian base is 538.8)
    bins_5myr.append({'time_identifier': "535.0-538.8 Ma", 'start_ma': 535.0, 'end_ma': 538.8})


    # Generate stage data
    df_stages = analyze_pared_data(pared_df, stages_data, 'Geological_Stages')
    df_stages.to_csv(reef_stage_file, index=False)
    print(f"  ✓ Generated {reef_stage_file}")

    # Generate 5myr data
    df_5myr = analyze_pared_data(pared_df, bins_5myr, '5_myr_bins')
    df_5myr.to_csv(reef_5myr_file, index=False)
    print(f"  ✓ Generated {reef_5myr_file}")

print("✓ PARED reef data ready")


GENERATING PARED REEF DATA
Source file 'PARED_reef_All_numerical.csv' not found.
Please upload it now:


Saving PARED_reef_All_numerical.csv to PARED_reef_All_numerical.csv
Processing PARED_reef_All_numerical.csv...
  ✓ Generated cambrian_devonian_reef_data_stage_for_analysis.csv
  ✓ Generated cambrian_devonian_reef_data_5myr_for_analysis.csv
✓ PARED reef data ready


In [4]:
# =============================================================================
# @title CELL 4: GENERATE PBDB DIVERSITY AND OCCURRENCE DATA (GENERIC 540.0 BINS; Cam–Dev)
#   - Self-contained (defines smart_read_pbdb and all helpers)
#   - Works whether PBDB files are in ./output or current directory
# =============================================================================
import pandas as pd
import os
import numpy as np
from pathlib import Path

# Colab-safe upload
try:
    from google.colab import files  # type: ignore
    IN_COLAB = True
except Exception:
    IN_COLAB = False

# ==========================================
# 1. SETUP: Create Time References (Cambrian–Devonian)
#   NOTE: Cambrian stage boundary ages are approximate in the ICS chart.
# ==========================================
ics_data = """stage,series,period,start_ma,end_ma
Fortunian,Lower Cambrian,Cambrian,538.8,529.0
Stage 2,Lower Cambrian,Cambrian,529.0,521.0
Stage 3,Lower Cambrian,Cambrian,521.0,514.0
Stage 4,Lower Cambrian,Cambrian,514.0,509.0
Wuliuan,Miaolingian,Cambrian,509.0,504.5
Drumian,Miaolingian,Cambrian,504.5,500.5
Guzhangian,Miaolingian,Cambrian,500.5,497.0
Paibian,Furongian,Cambrian,497.0,494.0
Jiangshanian,Furongian,Cambrian,494.0,489.5
Stage 10,Furongian,Cambrian,489.5,485.4
Tremadocian,Lower Ordovician,Ordovician,485.4,477.7
Floian,Lower Ordovician,Ordovician,477.7,470.0
Dapingian,Middle Ordovician,Ordovician,470.0,467.3
Darriwilian,Middle Ordovician,Ordovician,467.3,458.4
Sandbian,Upper Ordovician,Ordovician,458.4,453.0
Katian,Upper Ordovician,Ordovician,453.0,445.2
Hirnantian,Upper Ordovician,Ordovician,445.2,443.8
Rhuddanian,Llandovery,Silurian,443.8,440.8
Aeronian,Llandovery,Silurian,440.8,438.5
Telychian,Llandovery,Silurian,438.5,433.4
Sheinwoodian,Wenlock,Silurian,433.4,430.5
Homerian,Wenlock,Silurian,430.5,427.4
Gorstian,Ludlow,Silurian,427.4,425.6
Ludfordian,Ludlow,Silurian,425.6,423.0
Pridoli,Pridoli,Silurian,423.0,419.2
Lochkovian,Lower Devonian,Devonian,419.2,410.8
Pragian,Lower Devonian,Devonian,410.8,407.6
Emsian,Lower Devonian,Devonian,407.6,393.3
Eifelian,Middle Devonian,Devonian,393.3,387.7
Givetian,Middle Devonian,Devonian,387.7,382.7
Frasnian,Upper Devonian,Devonian,382.7,372.2
Famennian,Upper Devonian,Devonian,372.2,358.9
"""

ICS_CSV = Path("ICS_stage_boundaries.csv")
ICS_CSV.write_text(ics_data, encoding="utf-8")

# Generic 5-Myr bins aligned to 540.0 Ma to cover full Cambrian–Devonian window.
def create_5myr_bins(start_ma=540.0, end_ma=358.9, step=5.0):
    bins = []
    current = float(start_ma)

    # Iterate downward in time (older -> younger).
    # We want to include the partial bin that contains end_ma (358.9).
    while True:
        top = current
        bottom = current - step
        bins.append({
            "bin_label": f"{top:.1f}-{bottom:.1f}",
            "bin_top": float(top),
            "bin_bottom": float(bottom)
        })

        # Stop once we've created the bin that contains end_ma.
        # Example: end_ma=358.9 lies in 360–355, so stop after creating 360–355.
        if bottom < end_ma:
            break

        current = bottom

        # Safety: avoid runaway if end_ma is mis-set
        if len(bins) > 1000:
            raise RuntimeError("Too many bins created; check start_ma/end_ma.")
    return pd.DataFrame(bins)

bins_5myr_df = create_5myr_bins()
stages_df = pd.read_csv(ICS_CSV)

print("Created Generic 5-Myr bins (Aligned to 540.0):")
display(bins_5myr_df.head())
display(bins_5myr_df.tail())

# ==========================================
# 2. FILE CHECK
# ==========================================
print("\n--- CHECKING FILE SYSTEM ---")

# Look in ./output first (your pipeline convention), then current directory
found_files = []
for pattern in ["./output/pbdb_data_*.csv", "./pbdb_data_*.csv"]:
    found_files.extend(sorted([str(p) for p in Path(".").glob(pattern.replace("./", ""))]))

# De-duplicate while preserving order
seen = set()
found_files = [f for f in found_files if not (f in seen or seen.add(f))]

if not found_files:
    print("[WARN] No pbdb_data_*.csv found in ./output or current directory.")
    if IN_COLAB:
        print("Please upload your RAW PBDB files now (pbdb_data_*.csv).")
        uploaded = files.upload()
        # after upload, they land in current directory
        found_files = sorted([f for f in os.listdir(".") if f.lower().startswith("pbdb_data_") and f.lower().endswith(".csv")])
        print(f"  Uploaded {len(found_files)} pbdb_data_*.csv files.")
    else:
        raise FileNotFoundError("No pbdb_data_*.csv found. Place files in ./output or the current directory and re-run Cell 4.")

print(f"Found {len(found_files)} files:")
print(found_files)

# ==========================================
# 3. HELPER FUNCTIONS (Self-contained)
# ==========================================
def smart_read_pbdb(file_path: str):
    """
    PBDB downloads sometimes contain metadata rows before the header.
    Detect the header row by locating 'occurrence_no'.
    """
    header_row = None
    try:
        with open(file_path, "r", encoding="utf-8", errors="replace") as f:
            lines = [f.readline() for _ in range(100)]
        for i, line in enumerate(lines):
            if "occurrence_no" in (line or "").lower():
                header_row = i
                break
        if header_row is None:
            print(f"    CRITICAL: Could not find 'occurrence_no' header in {file_path}")
            return None
        return pd.read_csv(file_path, header=header_row)
    except Exception as e:
        print(f"    Error reading {file_path}: {e}")
        return None

def extract_genus(accepted_name):
    if pd.isna(accepted_name):
        return None
    return str(accepted_name).split(" ")[0]

def get_stage_from_age(age, stages_df):
    # Strict containment: start_ma >= age > end_ma (Ma decreasing through time)
    match = stages_df[(stages_df["start_ma"] >= age) & (stages_df["end_ma"] < age)]
    if match.empty:
        # Snap very youngest boundary if needed
        if np.isfinite(age) and abs(age - stages_df["end_ma"].min()) < 1e-6:
            return stages_df.iloc[-1]["stage"]
        return None
    return match.iloc[0]["stage"]

def get_bin_from_age(age, bins_df):
    # Strict containment
    match = bins_df[(bins_df["bin_top"] >= age) & (bins_df["bin_bottom"] < age)]

    # Tolerance snap for slightly older-than-top ages (rare rounding issues)
    if match.empty:
        max_top = bins_df["bin_top"].max()
        if age > max_top and (age - max_top) < 2.0:
            return bins_df.iloc[0]["bin_label"]

    if not match.empty:
        return match.iloc[0]["bin_label"]
    return None

def process_midpoint(df, group_name, reference_df, ref_type="stage"):
    if "accepted_name" not in df.columns:
        print("    Error: 'accepted_name' column missing.")
        return None

    # Required PBDB age columns
    if ("max_ma" not in df.columns) or ("min_ma" not in df.columns):
        print("    Error: PBDB file missing 'max_ma' and/or 'min_ma' columns.")
        return None

    df["genus_name"] = df["accepted_name"].apply(extract_genus)
    df = df.dropna(subset=["genus_name"])

    # Midpoint Logic
    df["midpoint"] = (pd.to_numeric(df["max_ma"], errors="coerce") + pd.to_numeric(df["min_ma"], errors="coerce")) / 2.0
    df = df.dropna(subset=["midpoint"])

    if ref_type == "stage":
        df["assigned_interval"] = df["midpoint"].apply(lambda x: get_stage_from_age(x, reference_df))
        merge_col = "stage"
    else:
        df["assigned_interval"] = df["midpoint"].apply(lambda x: get_bin_from_age(x, reference_df))
        merge_col = "bin_label"

    df = df.dropna(subset=["assigned_interval"])

    # Aggregation
    genus_counts = df.groupby("assigned_interval")["genus_name"].nunique()
    genus_counts.name = f"{group_name}_genus"

    occ_counts = df.groupby("assigned_interval")["occurrence_no"].nunique()
    occ_counts.name = f"{group_name}_occ"

    # Merging
    final_df = reference_df.copy()
    final_df = final_df.merge(genus_counts, left_on=merge_col, right_index=True, how="left")
    final_df = final_df.merge(occ_counts, left_on=merge_col, right_index=True, how="left")

    # Cleanup
    cols_to_fix = [f"{group_name}_genus", f"{group_name}_occ"]
    for c in cols_to_fix:
        if c in final_df.columns:
            final_df[c] = final_df[c].fillna(0).astype(int)

    return final_df

# ==========================================
# 4. MAIN EXECUTION LOOP
# ==========================================
print("\n--- STARTING ANALYSIS (Generic 540.0 Bins; Cam–Dev) ---")

for file_path in found_files:
    filename = os.path.basename(file_path)
    group_name = filename.replace("pbdb_data_", "").replace(".csv", "")
    print(f"\nAnalyzing {group_name}...")

    df_pbdb = smart_read_pbdb(file_path)
    if df_pbdb is None:
        continue

    try:
        # A. Stages
        stage_df = process_midpoint(df_pbdb.copy(), group_name, stages_df, "stage")
        if stage_df is not None:
            out_stage = f"pbdb_{group_name}_midpoint_stages.csv"
            stage_df.to_csv(out_stage, index=False)
            print(f"  -> Created: {out_stage}")

        # B. 5-Myr Bins
        bin_df = process_midpoint(df_pbdb.copy(), group_name, bins_5myr_df, "bin")
        if bin_df is not None:
            out_bin = f"pbdb_{group_name}_midpoint_5myr_bins.csv"
            bin_df.to_csv(out_bin, index=False)
            print(f"  -> Created: {out_bin}")

    except Exception as e:
        print(f"  Error: {e}")

print("\n" + "="*40)
print("DONE! New 540.0-aligned bin files created (Cambrian–Devonian).")
print("="*40)


Created Generic 5-Myr bins (Aligned to 540.0):


Unnamed: 0,bin_label,bin_top,bin_bottom
0,540.0-535.0,540.0,535.0
1,535.0-530.0,535.0,530.0
2,530.0-525.0,530.0,525.0
3,525.0-520.0,525.0,520.0
4,520.0-515.0,520.0,515.0


Unnamed: 0,bin_label,bin_top,bin_bottom
32,380.0-375.0,380.0,375.0
33,375.0-370.0,375.0,370.0
34,370.0-365.0,370.0,365.0
35,365.0-360.0,365.0,360.0
36,360.0-355.0,360.0,355.0



--- CHECKING FILE SYSTEM ---
[WARN] No pbdb_data_*.csv found in ./output or current directory.
Please upload your RAW PBDB files now (pbdb_data_*.csv).


Saving pbdb_data_Actinostromatida.csv to pbdb_data_Actinostromatida.csv
Saving pbdb_data_Amphiporida.csv to pbdb_data_Amphiporida.csv
Saving pbdb_data_Clathrodictyida.csv to pbdb_data_Clathrodictyida.csv
Saving pbdb_data_Labechiida.csv to pbdb_data_Labechiida.csv
Saving pbdb_data_Rugosa.csv to pbdb_data_Rugosa.csv
Saving pbdb_data_Stromatoporellida.csv to pbdb_data_Stromatoporellida.csv
Saving pbdb_data_Stromatoporida.csv to pbdb_data_Stromatoporida.csv
Saving pbdb_data_Syringostromatida.csv to pbdb_data_Syringostromatida.csv
Saving pbdb_data_Tabulata.csv to pbdb_data_Tabulata.csv
  Uploaded 9 pbdb_data_*.csv files.
Found 9 files:
['pbdb_data_Actinostromatida.csv', 'pbdb_data_Amphiporida.csv', 'pbdb_data_Clathrodictyida.csv', 'pbdb_data_Labechiida.csv', 'pbdb_data_Rugosa.csv', 'pbdb_data_Stromatoporellida.csv', 'pbdb_data_Stromatoporida.csv', 'pbdb_data_Syringostromatida.csv', 'pbdb_data_Tabulata.csv']

--- STARTING ANALYSIS (Generic 540.0 Bins; Cam–Dev) ---

Analyzing Actinostromatida

In [5]:
# =============================================================================
#@title CELL 5: UPLOAD ENVIRONMENT DATA FILES (Google Colab) - CONDITIONAL
# =============================================================================

from google.colab import files
import io

# Define required files
required_files = [
    'temperature.csv',
    'DO.csv',
    'oxygen.csv',
    'sealevel.csv',
    'd13C_5Myr_Cam-Dev.csv',
    'd13C_stage_binned_Cam-Dev.csv'
]
# Check which files are missing
missing_files = [f for f in required_files if not os.path.exists(f)]

if missing_files:
    print("The following files are missing and need to be uploaded:")
    for i, f in enumerate(missing_files, 1):
        print(f"  {i}. {f}")
    print("\nClick 'Choose Files' and select the missing files...")

    uploaded = files.upload()

    print(f"\n✓ Uploaded {len(uploaded)} files:")
    for fn in uploaded.keys():
        print(f"  - {fn}")
else:
    print("✓ All required files already exist in the folder")
    uploaded = {}
    # Load existing files into uploaded dict for compatibility
    for f in required_files:
        with open(f, 'rb') as file:
            uploaded[f] = file.read()

print(f"\n✓ {len(required_files)} files ready for analysis")


The following files are missing and need to be uploaded:
  1. temperature.csv
  2. DO.csv
  3. oxygen.csv
  4. sealevel.csv
  5. d13C_5Myr_Cam-Dev.csv
  6. d13C_stage_binned_Cam-Dev.csv

Click 'Choose Files' and select the missing files...


Saving d13C_5Myr_Cam-Dev.csv to d13C_5Myr_Cam-Dev.csv
Saving d13C_stage_binned_Cam-Dev.csv to d13C_stage_binned_Cam-Dev.csv
Saving DO.csv to DO.csv
Saving oxygen.csv to oxygen.csv
Saving sealevel.csv to sealevel.csv
Saving temperature.csv to temperature.csv

✓ Uploaded 6 files:
  - d13C_5Myr_Cam-Dev.csv
  - d13C_stage_binned_Cam-Dev.csv
  - DO.csv
  - oxygen.csv
  - sealevel.csv
  - temperature.csv

✓ 6 files ready for analysis


In [6]:
# =============================================================================
# @title CELL 6: LOAD AND PROCESS DATA (FROM CONTENT FOLDER)
# =============================================================================
import pandas as pd
import os

# Helper: Find file path by keywords in the current directory
def get_file_path_robust(keywords, search_dir='.'):
    """
    Finds a filename in the search_dir that contains ALL keywords (case-insensitive).
    """
    try:
        files = os.listdir(search_dir)
    except FileNotFoundError:
        print(f"  ! Error: Directory '{search_dir}' not found.")
        return None

    for filename in files:
        if not filename.endswith(('.csv', '.xlsx', '.xls')):
            continue
        # Check if ALL keywords are present in this filename
        if all(str(k).lower() in filename.lower() for k in keywords):
            return os.path.join(search_dir, filename)
    return None

def load_and_merge_from_disk(target_list, merge_cols, dataset_name):
    """
    Iterates through a list of target filenames, finds them on disk, and merges them.
    """
    merged_df = None
    print(f"\nProcessing {dataset_name}...")

    for target in target_list:
        # Extract keywords from the target filename
        clean_name = target.replace('.csv', '').replace('.xlsx', '')
        keywords = [k for k in clean_name.split('_') if k and k.lower() != 'pbdb']

        # Search for the file
        filepath = get_file_path_robust(keywords)

        if filepath:
            print(f"  ✓ Found: {filepath}")
            try:
                df = pd.read_csv(filepath)
                df.columns = df.columns.str.strip()

                if merged_df is None:
                    merged_df = df
                else:
                    merged_df = pd.merge(merged_df, df, on=merge_cols, how='outer')
            except Exception as e:
                print(f"  ! Error reading {filepath}: {e}")
        else:
            # Retry with minimal keywords (Taxon + Resolution)
            # This handles cases where the user filename might differ slightly from the instruction
            taxon = next((k for k in keywords if k.lower() not in ['midpoint', 'stages', '5myr', 'bins']), None)
            resolution = '5myr' if '5myr' in target.lower() else 'stages'

            if taxon:
                filepath_retry = get_file_path_robust([taxon, resolution])
                if filepath_retry:
                    print(f"  ✓ Found (fallback): {filepath_retry}")
                    try:
                        df = pd.read_csv(filepath_retry)
                        df.columns = df.columns.str.strip()
                        if merged_df is None: merged_df = df
                        else: merged_df = pd.merge(merged_df, df, on=merge_cols, how='outer')
                    except Exception as e:
                        print(f"  ! Error reading {filepath_retry}: {e}")
                else:
                     print(f"  x Could not find file for: {taxon} ({resolution})")
            else:
                 print(f"  x Could not find file matching: {keywords}")

    return merged_df

# Common merge columns
stage_merge_cols = ['stage', 'series', 'period', 'start_ma', 'end_ma']
bin_merge_cols = ['bin_label', 'bin_top', 'bin_bottom']

# =============================================================================
# 1. Substitute Stromatoporoid Data (Stage)
# =============================================================================
strom_stage_targets = [
    "pbdb_Stromatoporida_midpoint_stages.csv",
    "pbdb_Labechiida_midpoint_stages.csv",
    "pbdb_Actinostromatida_midpoint_stages.csv",
    "pbdb_clathrodictyida_midpoint_stages.csv",
    "pbdb_Syringostromatida_midpoint_stages.csv",
    "pbdb_Stromatoporellida_midpoint_stages.csv",
    "pbdb_Amphiporida_midpoint_stages.csv"
]

strom_df = load_and_merge_from_disk(strom_stage_targets, stage_merge_cols, "Stromatoporoid (Stage)")

if strom_df is not None:
    strom_df = strom_df.fillna(0)
    # Calculate Totals
    genus_cols = [c for c in strom_df.columns if c.endswith('_genus')]
    occ_cols = [c for c in strom_df.columns if c.endswith('_occ')]
    strom_df['Total_genus'] = strom_df[genus_cols].sum(axis=1)
    strom_df['Total_occ'] = strom_df[occ_cols].sum(axis=1)

    # === FILTER: Remove empty stages ===
    dropped_s = len(strom_df[strom_df['Total_occ'] == 0])
    strom_df = strom_df[strom_df['Total_occ'] > 0].copy()
    if dropped_s > 0:
        print(f"  -> Dropped {dropped_s} Stromatoporoid stage(s) with zero occurrences.")
    # ===================================

    print(f"  -> Merged Stromatoporoid Stages. Rows: {len(strom_df)}")
else:
    print("  ! Error: Stromatoporoid dataframe is empty.")

# =============================================================================
# 2. Substitute Coral Data (Stage)
# =============================================================================
coral_stage_targets = [
    "pbdb_tabulata_midpoint_stages.csv",
    "pbdb_Rugosa_midpoint_stages.csv"
]

coral_df = load_and_merge_from_disk(coral_stage_targets, stage_merge_cols, "Coral (Stage)")

if coral_df is not None:
    coral_df = coral_df.fillna(0)

    # Calculate Totals (Added for filtering)
    c_genus_cols = [c for c in coral_df.columns if c.endswith('_genus')]
    c_occ_cols = [c for c in coral_df.columns if c.endswith('_occ')]
    coral_df['Total_genus'] = coral_df[c_genus_cols].sum(axis=1)
    coral_df['Total_occ'] = coral_df[c_occ_cols].sum(axis=1)

    # === FILTER: Remove empty stages ===
    dropped_c = len(coral_df[coral_df['Total_occ'] == 0])
    coral_df = coral_df[coral_df['Total_occ'] > 0].copy()
    if dropped_c > 0:
        print(f"  -> Dropped {dropped_c} Coral stage(s) with zero occurrences.")
    # ===================================

    print(f"  -> Merged Coral Stages. Rows: {len(coral_df)}")

# =============================================================================
# 3. Create New Dataset for 5-Myr Bins
# =============================================================================
# Stromatoporoids
strom_bin_targets = [
    "pbdb_Stromatoporida_midpoint_5myr_bins.csv",
    "pbdb_Labechiida_midpoint_5myr_bins.csv",
    "pbdb_Actinostromatida_midpoint_5myr_bins.csv",
    "pbdb_clathrodictyida_midpoint_5myr_bins.csv",
    "pbdb_Syringostromatida_midpoint_5myr_bins.csv",
    "pbdb_Stromatoporellida_midpoint_5myr_bins.csv",
    "pbdb_Amphiporida_midpoint_5myr_bins.csv"
]

strom_5myr_df = load_and_merge_from_disk(strom_bin_targets, bin_merge_cols, "Stromatoporoid (5-Myr)")

if strom_5myr_df is not None:
    strom_5myr_df = strom_5myr_df.fillna(0)
    genus_cols = [c for c in strom_5myr_df.columns if c.endswith('_genus')]
    occ_cols = [c for c in strom_5myr_df.columns if c.endswith('_occ')]
    strom_5myr_df['Total_genus'] = strom_5myr_df[genus_cols].sum(axis=1)
    strom_5myr_df['Total_occ'] = strom_5myr_df[occ_cols].sum(axis=1)

    # === FILTER: Remove empty bins ===
    dropped_s5 = len(strom_5myr_df[strom_5myr_df['Total_occ'] == 0])
    strom_5myr_df = strom_5myr_df[strom_5myr_df['Total_occ'] > 0].copy()
    if dropped_s5 > 0:
        print(f"  -> Dropped {dropped_s5} Stromatoporoid 5-Myr bin(s) with zero occurrences.")
    # =================================

    print(f"  -> Merged Stromatoporoid 5-Myr Bins. Rows: {len(strom_5myr_df)}")

# Corals
coral_bin_targets = [
    "pbdb_tabulata_midpoint_5myr_bins.csv",
    "pbdb_Rugosa_midpoint_5myr_bins.csv"
]
coral_5myr_df = load_and_merge_from_disk(coral_bin_targets, bin_merge_cols, "Coral (5-Myr)")

if coral_5myr_df is not None:
    coral_5myr_df = coral_5myr_df.fillna(0)

    # Calculate Totals (Added for filtering)
    c5_genus_cols = [c for c in coral_5myr_df.columns if c.endswith('_genus')]
    c5_occ_cols = [c for c in coral_5myr_df.columns if c.endswith('_occ')]
    coral_5myr_df['Total_genus'] = coral_5myr_df[c5_genus_cols].sum(axis=1)
    coral_5myr_df['Total_occ'] = coral_5myr_df[c5_occ_cols].sum(axis=1)

    # === FILTER: Remove empty bins ===
    dropped_c5 = len(coral_5myr_df[coral_5myr_df['Total_occ'] == 0])
    coral_5myr_df = coral_5myr_df[coral_5myr_df['Total_occ'] > 0].copy()
    if dropped_c5 > 0:
        print(f"  -> Dropped {dropped_c5} Coral 5-Myr bin(s) with zero occurrences.")
    # =================================

    print(f"  -> Merged Coral 5-Myr Bins. Rows: {len(coral_5myr_df)}")

# =============================================================================
# 4. Load Remaining Contextual Data
# =============================================================================
print("\nLoading Contextual Data...")

def quick_load(keywords):
    path = get_file_path_robust(keywords)
    return pd.read_csv(path) if path else pd.DataFrame()

reef_df = quick_load(["reef_data", "stage"])
reef_5myr_df = quick_load(["reef_data", "5myr"])

macro_stage = quick_load(["paleozoic_stage_data"])
if 'stage' in macro_stage.columns: macro_stage['stage'] = macro_stage['stage'].replace('Pridolian', 'Pridoli')

macro_5myr = quick_load(["paleozoic_5myr_data"])

# Environmental
env_files = {
    'temperature': 'temperature',
    'do': 'DO',
    'oxygen': 'oxygen',
    'sealevel': 'sealevel',
    # δ13C files (already binned; DO NOT re-bin/interpolate these)
    'd13c_5myr': 'd13C_5Myr_Cam-Dev',
    'd13c_stage': 'd13C_stage_binned_Cam-Dev'
}
env_dfs = {}
for var, key in env_files.items():
    df = quick_load([key])
    if not df.empty: df.columns = df.columns.str.strip().str.replace('\ufeff', '')
    env_dfs[var] = df

temp_df = env_dfs.get('temperature', pd.DataFrame())
do_df = env_dfs.get('do', pd.DataFrame())
oxygen_df = env_dfs.get('oxygen', pd.DataFrame())
sealevel_df = env_dfs.get('sealevel', pd.DataFrame())
d13c_5myr_df = env_dfs.get('d13c_5myr', pd.DataFrame())
d13c_stage_df = env_dfs.get('d13c_stage', pd.DataFrame())
print("\n✓ Data loading complete.")

print("\nSaving intermediate merged datasets...")

if 'strom_df' in locals() and strom_df is not None:
    strom_df.to_csv(f"{OUTPUT_DIR}/intermediate_strom_stage.csv", index=False)
if 'coral_df' in locals() and coral_df is not None:
    coral_df.to_csv(f"{OUTPUT_DIR}/intermediate_coral_stage.csv", index=False)

if 'strom_5myr_df' in locals() and strom_5myr_df is not None:
    strom_5myr_df.to_csv(f"{OUTPUT_DIR}/intermediate_strom_5myr.csv", index=False)
if 'coral_5myr_df' in locals() and coral_5myr_df is not None:
    coral_5myr_df.to_csv(f"{OUTPUT_DIR}/intermediate_coral_5myr.csv", index=False)

print("✓ Intermediate files saved to ./output")



Processing Stromatoporoid (Stage)...
  ✓ Found: ./pbdb_Stromatoporida_midpoint_stages.csv
  ✓ Found: ./pbdb_Labechiida_midpoint_stages.csv
  ✓ Found: ./pbdb_Actinostromatida_midpoint_stages.csv
  ✓ Found: ./pbdb_Clathrodictyida_midpoint_stages.csv
  ✓ Found: ./pbdb_Syringostromatida_midpoint_stages.csv
  ✓ Found: ./pbdb_Stromatoporellida_midpoint_stages.csv
  ✓ Found: ./pbdb_Amphiporida_midpoint_stages.csv
  -> Dropped 11 Stromatoporoid stage(s) with zero occurrences.
  -> Merged Stromatoporoid Stages. Rows: 21

Processing Coral (Stage)...
  ✓ Found: ./pbdb_Tabulata_midpoint_stages.csv
  ✓ Found: ./pbdb_Rugosa_midpoint_stages.csv
  -> Dropped 11 Coral stage(s) with zero occurrences.
  -> Merged Coral Stages. Rows: 21

Processing Stromatoporoid (5-Myr)...
  ✓ Found: ./pbdb_Stromatoporida_midpoint_5myr_bins.csv
  ✓ Found: ./pbdb_Labechiida_midpoint_5myr_bins.csv
  ✓ Found: ./pbdb_Actinostromatida_midpoint_5myr_bins.csv
  ✓ Found: ./pbdb_Clathrodictyida_midpoint_5myr_bins.csv
  ✓ Found: 

In [7]:
# =============================================================================
# @title CELL 7: DEFINE CONSTANTS AND STAGE INFORMATION
# =============================================================================

# 1. Stage definitions (ICS 2023/04)
# Note: Cambrian stage boundary numerical ages are approximate (~) in the ICS chart.
STAGES = {
    # Cambrian
    'Fortunian':    {'start': 538.8, 'end': 529.0, 'mid': (538.8 + 529.0) / 2, 'period': 'Cambrian'},
    'Stage 2':      {'start': 529.0, 'end': 521.0, 'mid': (529.0 + 521.0) / 2, 'period': 'Cambrian'},
    'Stage 3':      {'start': 521.0, 'end': 514.0, 'mid': (521.0 + 514.0) / 2, 'period': 'Cambrian'},
    'Stage 4':      {'start': 514.0, 'end': 509.0, 'mid': (514.0 + 509.0) / 2, 'period': 'Cambrian'},
    'Wuliuan':      {'start': 509.0, 'end': 504.5, 'mid': (509.0 + 504.5) / 2, 'period': 'Cambrian'},
    'Drumian':      {'start': 504.5, 'end': 500.5, 'mid': (504.5 + 500.5) / 2, 'period': 'Cambrian'},
    'Guzhangian':   {'start': 500.5, 'end': 497.0, 'mid': (500.5 + 497.0) / 2, 'period': 'Cambrian'},
    'Paibian':      {'start': 497.0, 'end': 494.0, 'mid': (497.0 + 494.0) / 2, 'period': 'Cambrian'},
    'Jiangshanian': {'start': 494.0, 'end': 489.5, 'mid': (494.0 + 489.5) / 2, 'period': 'Cambrian'},
    'Stage 10':     {'start': 489.5, 'end': 485.4, 'mid': (489.5 + 485.4) / 2, 'period': 'Cambrian'},

    # Ordovician–Devonian
    'Tremadocian': {'start': 485.4, 'end': 477.7, 'mid': 481.55, 'period': 'Ordovician'},
    'Floian': {'start': 477.7, 'end': 470.0, 'mid': 473.85, 'period': 'Ordovician'},
    'Dapingian': {'start': 470.0, 'end': 467.3, 'mid': 468.65, 'period': 'Ordovician'},
    'Darriwilian': {'start': 467.3, 'end': 458.4, 'mid': 462.85, 'period': 'Ordovician'},
    'Sandbian': {'start': 458.4, 'end': 453.0, 'mid': 455.7, 'period': 'Ordovician'},
    'Katian': {'start': 453.0, 'end': 445.2, 'mid': 449.1, 'period': 'Ordovician'},
    'Hirnantian': {'start': 445.2, 'end': 443.8, 'mid': 444.5, 'period': 'Ordovician'},
    'Rhuddanian': {'start': 443.8, 'end': 440.8, 'mid': 442.3, 'period': 'Silurian'},
    'Aeronian': {'start': 440.8, 'end': 438.5, 'mid': 439.65, 'period': 'Silurian'},
    'Telychian': {'start': 438.5, 'end': 433.4, 'mid': 435.95, 'period': 'Silurian'},
    'Sheinwoodian': {'start': 433.4, 'end': 430.5, 'mid': 431.95, 'period': 'Silurian'},
    'Homerian': {'start': 430.5, 'end': 427.4, 'mid': 428.95, 'period': 'Silurian'},
    'Gorstian': {'start': 427.4, 'end': 425.6, 'mid': 426.5, 'period': 'Silurian'},
    'Ludfordian': {'start': 425.6, 'end': 423.0, 'mid': 424.3, 'period': 'Silurian'},
    'Pridoli': {'start': 423.0, 'end': 419.2, 'mid': 421.1, 'period': 'Silurian'},
    'Lochkovian': {'start': 419.2, 'end': 410.8, 'mid': 415.0, 'period': 'Devonian'},
    'Pragian': {'start': 410.8, 'end': 407.6, 'mid': 409.2, 'period': 'Devonian'},
    'Emsian': {'start': 407.6, 'end': 393.3, 'mid': 400.45, 'period': 'Devonian'},
    'Eifelian': {'start': 393.3, 'end': 387.7, 'mid': 390.5, 'period': 'Devonian'},
    'Givetian': {'start': 387.7, 'end': 382.7, 'mid': 385.2, 'period': 'Devonian'},
    'Frasnian': {'start': 382.7, 'end': 372.2, 'mid': 377.45, 'period': 'Devonian'},
    'Famennian': {'start': 372.2, 'end': 358.9, 'mid': 365.55, 'period': 'Devonian'}
}
STAGE_ORDER = list(STAGES.keys())

# 2. Period colors (ICS standard)
PERIOD_COLORS = {
    'Cambrian': '#7FA056',
    'Ordovician': '#009270',
    'Silurian': '#B3E1B6',
    'Devonian': '#CB8C37'
}

# 3. Stromatoporoid order colors (phylogenetically informed)
STROM_COLORS = {
    'Labechiida': '#8B0000',       # Dark red - basal
    'Clathrodictyida': '#CD5C5C',  # Indian red - early-diverging
    'Actinostromatida': '#FF8C00', # Dark orange - derived
    'Stromatoporida': '#FFD700',   # Gold - derived
    'Stromatoporellida': '#32CD32',# Lime green - derived
    'Syringostromatida': '#4169E1',# Royal blue - derived reef builders
    'Amphiporida': '#9370DB'       # Medium purple - derived
}
STROM_ORDERS = ['Labechiida', 'Clathrodictyida', 'Actinostromatida',
                'Stromatoporida', 'Stromatoporellida', 'Syringostromatida', 'Amphiporida']

# 4. Coral colors (New definitions for Rugosa/Tabulata)
CORAL_COLORS = {
    'Rugosa': '#800080',    # Purple
    'Tabulata': '#D2691E'   # Chocolate/Orange-Brown
}
CORAL_GROUPS = ['Rugosa', 'Tabulata']

# =============================================================================
# DATA NORMALIZATION: Ensure DataFrame columns match capitalized constants
# =============================================================================
# Some files were lowercase (e.g., 'clathrodictyida'), but constants are TitleCase.
# We fix this here to prevent KeyErrors in future plotting cells.

def normalize_columns(df, target_orders):
    if df is None: return df

    # Get current columns
    cols = df.columns.tolist()
    rename_map = {}

    for order in target_orders:
        # Check if TitleCase version exists (e.g., 'Clathrodictyida_genus')
        title_genus = f"{order}_genus"
        title_occ = f"{order}_occ"

        # Check if LowerCase version exists (e.g., 'clathrodictyida_genus')
        lower_genus = f"{order.lower()}_genus"
        lower_occ = f"{order.lower()}_occ"

        # If TitleCase missing but LowerCase present, map Lower -> Title
        if title_genus not in cols and lower_genus in cols:
            rename_map[lower_genus] = title_genus
        if title_occ not in cols and lower_occ in cols:
            rename_map[lower_occ] = title_occ

        # Also handle "tabulata" (lowercase t)
        if order == 'Tabulata' and 'tabulata_genus' in cols:
             rename_map['tabulata_genus'] = 'Tabulata_genus'
             rename_map['tabulata_occ'] = 'Tabulata_occ'

    if rename_map:
        print(f"  Note: Renaming columns to match standard capitalization: {list(rename_map.keys())}")
        df = df.rename(columns=rename_map)
    return df

# Apply normalization to the datasets loaded in Cell 5
if 'strom_df' in locals(): strom_df = normalize_columns(strom_df, STROM_ORDERS)
if 'strom_5myr_df' in locals(): strom_5myr_df = normalize_columns(strom_5myr_df, STROM_ORDERS)
if 'coral_df' in locals(): coral_df = normalize_columns(coral_df, CORAL_GROUPS)
if 'coral_5myr_df' in locals(): coral_5myr_df = normalize_columns(coral_5myr_df, CORAL_GROUPS)

print("✓ Constants defined and DataFrames normalized.")

✓ Constants defined and DataFrames normalized.


In [8]:
# =============================================================================
# @title CELL 8: RELOAD ENV DATA & INTERPOLATE (ROBUST FIX)
# =============================================================================
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
import os

# -----------------------------------------------------------------------------
# 1. RELOAD ENVIRONMENTAL DATA (To ensure it is not empty)
# -----------------------------------------------------------------------------
print("Reloading environmental datasets...")

def load_env_file(filenames, target_cols):
    """Try to load a file from a list of possible names"""
    for fname in filenames:
        if os.path.exists(fname):
            try:
                df = pd.read_csv(fname)
                df.columns = df.columns.str.strip() # clean whitespace
                print(f"  ✓ Loaded {fname} ({len(df)} rows)")
                return df
            except Exception as e:
                print(f"  ! Error loading {fname}: {e}")
    print(f"  ! WARNING: Could not find any of {filenames}")
    return pd.DataFrame() # Return empty if not found

# Load with specific fallbacks
temp_df     = load_env_file(['temperature.csv', 'Temperature.csv'], ['Age', 'SST'])
do_df       = load_env_file(['DO.csv', 'do.csv'], ['Age', 'DO'])
oxygen_df   = load_env_file(['oxygen.csv', 'Oxygen.csv'], ['Age', 'O2'])
sealevel_df = load_env_file(['sealevel.csv', 'Sealevel.csv'], ['Age', 'Eustatic'])

# -----------------------------------------------------------------------------
# 2. STANDARDIZE COLUMNS
# -----------------------------------------------------------------------------
def standardize_env_columns(df, name, target_age='Age', target_val=None):
    if df.empty: return df

    # Fix Age column
    if target_age not in df.columns:
        for candidate in ['age', 'AGE', 'Time', 'Ma', 'time']:
            if candidate in df.columns:
                df = df.rename(columns={candidate: target_age})
                break

    # Fix Value column
    if target_val and target_val not in df.columns:
        # Check case-insensitive match
        for col in df.columns:
            if col.lower() == target_val.lower():
                df = df.rename(columns={col: target_val})
                break

    return df

do_df = standardize_env_columns(do_df, 'Dissolved Oxygen', target_val='DO')
temp_df = standardize_env_columns(temp_df, 'Temperature', target_val='SST')
oxygen_df = standardize_env_columns(oxygen_df, 'Atmosphere', target_age='Age')
sealevel_df = standardize_env_columns(sealevel_df, 'Sea Level', target_val='Eustatic Sea Level')

# -----------------------------------------------------------------------------
# 3. INTERPOLATION
# -----------------------------------------------------------------------------
def interpolate_to_ages(source_df, age_col, value_col, target_ages):
    """Interpolate values to target ages, skipping if columns missing."""
    # Check if empty or missing columns
    if source_df.empty or age_col not in source_df.columns or value_col not in source_df.columns:
        return np.full_like(target_ages, np.nan)

    # Drop NaNs in source
    source_df = source_df.dropna(subset=[age_col, value_col])
    # Sort by Age (Crucial for interp1d)
    source_df = source_df.sort_values(age_col)

    if len(source_df) < 2:
        return np.full_like(target_ages, np.nan)

    # Interpolate
    f = interp1d(source_df[age_col], source_df[value_col],
                 kind='linear', fill_value='extrapolate', bounds_error=False)
    return f(target_ages)

# Get Stage Midpoints from Cell 7 constants
stage_midpoints = np.array([STAGES[s]['mid'] for s in STAGE_ORDER])

print("\nInterpolating environmental proxies to stage midpoints...")
env_data = pd.DataFrame({'stage': STAGE_ORDER, 'midpoint_ma': stage_midpoints})

env_data['temperature']     = interpolate_to_ages(temp_df, 'Age', 'SST', stage_midpoints)
env_data['dissolved_O2']    = interpolate_to_ages(do_df, 'Age', 'DO', stage_midpoints)
env_data['atm_O2']          = interpolate_to_ages(oxygen_df, 'Age', 'O2', stage_midpoints)
env_data['atm_CO2']         = interpolate_to_ages(oxygen_df, 'Age', 'CO2', stage_midpoints)
env_data['sea_level']       = interpolate_to_ages(sealevel_df, 'Age', 'Eustatic Sea Level', stage_midpoints)
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# MERGE δ13C (STAGE-BINNED) — prefer Stage-name join; fallback to age-nearest
# Uses attached d13C_stage_binned_Cam-Dev.csv WITHOUT re-binning/conversion.
# -----------------------------------------------------------------------------
import re
from pathlib import Path

def _norm_stage(s):
    if pd.isna(s):
        return np.nan
    s = str(s).strip().lower()
    # remove spaces/punct to survive minor naming differences
    return re.sub(r'[^a-z0-9]+', '', s)

try:
    # Load attached stage-binned δ13C table
    cand = [
        Path("d13C_stage_binned_Cam-Dev.csv"),
        Path("/mnt/data/d13C_stage_binned_Cam-Dev.csv")
    ]
    f = next((p for p in cand if p.exists()), None)
    if f is None:
        raise FileNotFoundError("d13C_stage_binned_Cam-Dev.csv not found in working dir or /mnt/data")

    d13_stage = pd.read_csv(f)
    d13_stage.columns = d13_stage.columns.str.strip().str.replace('\ufeff', '')

    # Required columns in attached file
    if "Stage" not in d13_stage.columns or "Mid_Ma" not in d13_stage.columns or "d13C_mean" not in d13_stage.columns:
        raise ValueError("δ13C stage file must have columns: Stage, Mid_Ma, d13C_mean")

    # Clean numeric + stage keys
    d13_stage["Mid_Ma"] = pd.to_numeric(d13_stage["Mid_Ma"], errors="coerce")
    d13_stage["d13C_mean"] = pd.to_numeric(d13_stage["d13C_mean"], errors="coerce")
    d13_stage["stage_key"] = d13_stage["Stage"].apply(_norm_stage)

    env_data["midpoint_ma"] = pd.to_numeric(env_data["midpoint_ma"], errors="coerce")

    # --- 1) Stage-name merge (best) ---
    if "stage" in env_data.columns:
        env_data["stage_key"] = env_data["stage"].apply(_norm_stage)

        _m1 = env_data.merge(
            d13_stage[["stage_key", "d13C_mean", "Mid_Ma"]],
            on="stage_key",
            how="left",
            suffixes=("", "_d13")
        )

        # Keep δ13C as a single downstream name
        _m1["d13C"] = _m1["d13C_mean"]

        # --- 2) Fallback: for unmatched stages, fill by age-nearest ---
        missing = _m1["d13C"].isna() & _m1["midpoint_ma"].notna()
        if missing.any():
            _left = _m1.loc[missing, ["midpoint_ma"]].copy().sort_values("midpoint_ma")
            _right = d13_stage[["Mid_Ma", "d13C_mean"]].dropna().sort_values("Mid_Ma")

            _fill = pd.merge_asof(
                _left,
                _right,
                left_on="midpoint_ma",
                right_on="Mid_Ma",
                direction="nearest",
                tolerance=2.0  # stage midpoints can differ by >0.25; allow reasonable slack
            )
            _m1.loc[missing, "d13C"] = _fill["d13C_mean"].values

        # Cleanup
        env_data = _m1.drop(columns=[c for c in ["d13C_mean", "Mid_Ma", "stage_key"] if c in _m1.columns])
        env_data = env_data.sort_values("midpoint_ma", ascending=False).reset_index(drop=True)

    else:
        # If env_data has no stage column, do age-nearest only (more tolerant)
        _left = env_data.dropna(subset=["midpoint_ma"]).sort_values("midpoint_ma")
        _right = d13_stage[["Mid_Ma", "d13C_mean"]].dropna().sort_values("Mid_Ma")
        _m = pd.merge_asof(_left, _right, left_on="midpoint_ma", right_on="Mid_Ma", direction="nearest", tolerance=2.0)
        env_data = _m.drop(columns=["Mid_Ma"]).rename(columns={"d13C_mean": "d13C"})
        env_data = env_data.sort_values("midpoint_ma", ascending=False).reset_index(drop=True)

    n_valid = int(env_data["d13C"].notna().sum()) if "d13C" in env_data.columns else 0
    print(f"  -> δ13C (stage-binned) merged: {n_valid}/{len(env_data)} values")

except Exception as e:
    env_data["d13C"] = np.nan
    print("  ! Warning: δ13C stage merge failed:", e)


# -----------------------------------------------------------------------------
# 4. INTERPOLATE TO 5-MYR BINS
# -----------------------------------------------------------------------------
print("\nInterpolating environmental proxies to 5-Myr bin midpoints...")

# Define bins (Logic from previous step)
if 'reef_5myr_df' in locals() and not reef_5myr_df.empty:
    bin_midpoints_5myr = reef_5myr_df['midpoint_ma'].values
    bin_ids = reef_5myr_df['time_identifier']
elif 'strom_5myr_df' in locals() and not strom_5myr_df.empty:
    # Recalculate if column missing
    if 'midpoint_ma' not in strom_5myr_df.columns:
        strom_5myr_df['midpoint_ma'] = (strom_5myr_df['bin_top'] + strom_5myr_df['bin_bottom']) / 2
    bin_midpoints_5myr = strom_5myr_df['midpoint_ma'].values
    bin_ids = strom_5myr_df['bin_label']
else:
    # Fallback
    bin_midpoints_5myr = np.arange(482.5, 359, -5)
    bin_ids = [f"{x}" for x in bin_midpoints_5myr]

env_data_5myr = pd.DataFrame({'bin_id': bin_ids, 'midpoint_ma': bin_midpoints_5myr})

env_data_5myr['temperature']     = interpolate_to_ages(temp_df, 'Age', 'SST', bin_midpoints_5myr)
env_data_5myr['dissolved_O2']    = interpolate_to_ages(do_df, 'Age', 'DO', bin_midpoints_5myr)
env_data_5myr['atm_O2']          = interpolate_to_ages(oxygen_df, 'Age', 'O2', bin_midpoints_5myr)
env_data_5myr['atm_CO2']         = interpolate_to_ages(oxygen_df, 'Age', 'CO2', bin_midpoints_5myr)
env_data_5myr['sea_level']       = interpolate_to_ages(sealevel_df, 'Age', 'Eustatic Sea Level', bin_midpoints_5myr)
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# 4b. MERGE δ13C (5-MYR BINNED) WITHOUT INTERPOLATION / RE-BINNING
#   Float bin midpoints often differ by tiny rounding; if exact merge yields
#   few/zero matches, fall back to merge_asof (requires ascending sort).
# -----------------------------------------------------------------------------
try:
    if 'd13c_5myr_df' in globals() and isinstance(d13c_5myr_df, pd.DataFrame) and (not d13c_5myr_df.empty):
        _d13_5 = d13c_5myr_df.copy()
    else:
        cand = [
            'd13C_5Myr_Cam-Dev.csv',
            './output/d13C_5Myr_Cam-Dev.csv',
            'd13C_5Myr.csv',
            './output/d13C_5Myr.csv'
        ]
        _d13_5 = None
        for p in cand:
            if os.path.exists(p):
                _d13_5 = pd.read_csv(p)
                break
        if _d13_5 is None:
            raise FileNotFoundError("No 5-Myr δ13C file found (tried: " + ", ".join(cand) + ").")

    _d13_5.columns = _d13_5.columns.str.strip().str.replace('\ufeff', '')

    # Standardize columns
    if 'age_Ma' in _d13_5.columns and 'midpoint_ma' not in _d13_5.columns:
        _d13_5 = _d13_5.rename(columns={'age_Ma': 'midpoint_ma'})
    if 'Mid_Ma' in _d13_5.columns and 'midpoint_ma' not in _d13_5.columns:
        _d13_5 = _d13_5.rename(columns={'Mid_Ma': 'midpoint_ma'})
    if 'd13Ccarb_permille' in _d13_5.columns and 'd13C' not in _d13_5.columns:
        _d13_5 = _d13_5.rename(columns={'d13Ccarb_permille': 'd13C'})
    if 'd13C_mean' in _d13_5.columns and 'd13C' not in _d13_5.columns:
        _d13_5 = _d13_5.rename(columns={'d13C_mean': 'd13C'})

    _d13_5['midpoint_ma'] = pd.to_numeric(_d13_5['midpoint_ma'], errors='coerce')
    _d13_5['d13C'] = pd.to_numeric(_d13_5['d13C'], errors='coerce')
    _d13_5 = _d13_5.dropna(subset=['midpoint_ma', 'd13C']).copy()

    # --- Attempt exact merge after rounding to reduce floating mismatch
    env_data_5myr['midpoint_ma'] = pd.to_numeric(env_data_5myr['midpoint_ma'], errors='coerce')
    _left = env_data_5myr.copy()
    _left['midpoint_ma_round'] = _left['midpoint_ma'].round(3)
    _right = _d13_5[['midpoint_ma', 'd13C']].copy()
    _right['midpoint_ma_round'] = _right['midpoint_ma'].round(3)

    env_data_5myr = _left.merge(_right[['midpoint_ma_round', 'd13C']], on='midpoint_ma_round', how='left')
    env_data_5myr = env_data_5myr.drop(columns=['midpoint_ma_round'])

    n_valid = env_data_5myr['d13C'].notna().sum()

    # --- Fallback: nearest-age join if exact merge fails
    if n_valid == 0 and len(_d13_5) > 0:
        _left_sorted = env_data_5myr.drop(columns=['d13C'], errors='ignore').copy()
        _left_sorted = _left_sorted.dropna(subset=['midpoint_ma']).sort_values('midpoint_ma', ascending=True).reset_index(drop=True)

        _right_sorted = _d13_5[['midpoint_ma', 'd13C']].sort_values('midpoint_ma', ascending=True).reset_index(drop=True)

        _m = pd.merge_asof(
            _left_sorted,
            _right_sorted,
            on='midpoint_ma',
            direction='nearest',
            tolerance=2.6  # ~half of 5-Myr bin width
        )
        env_data_5myr = _m.sort_values('midpoint_ma', ascending=False).reset_index(drop=True)
        n_valid = env_data_5myr['d13C'].notna().sum()

    print(f"  -> δ13C (5-Myr binned) merged: {n_valid}/{len(env_data_5myr)} values")

except Exception as e:
    env_data_5myr['d13C'] = np.nan
    print("  ! Warning: δ13C 5-Myr merge failed:", e)

print("✓ Environmental proxies interpolated (Stage & 5-Myr).")

# =============================================================================
# [ADDED] SAVE INTERPOLATED ENV DATA
# =============================================================================
if 'env_data' in locals() and not env_data.empty:
    env_data.to_csv(f"{OUTPUT_DIR}/intermediate_env_data_stage.csv", index=False)

if 'env_data_5myr' in locals() and not env_data_5myr.empty:
    env_data_5myr.to_csv(f"{OUTPUT_DIR}/intermediate_env_data_5myr.csv", index=False)

print(f"✓ Interpolated environmental data saved to {OUTPUT_DIR}")


Reloading environmental datasets...
  ✓ Loaded temperature.csv (387 rows)
  ✓ Loaded DO.csv (499 rows)
  ✓ Loaded oxygen.csv (58 rows)
  ✓ Loaded sealevel.csv (101 rows)

Interpolating environmental proxies to stage midpoints...
  -> δ13C (stage-binned) merged: 32/32 values

Interpolating environmental proxies to 5-Myr bin midpoints...
  -> δ13C (5-Myr binned) merged: 37/37 values
✓ Environmental proxies interpolated (Stage & 5-Myr).
✓ Interpolated environmental data saved to ./output


In [9]:
# =============================================================================
# @title CELL 9: CREATE MASTER DATASET (STAGES AND 5-MYR BINS)
# =============================================================================
import numpy as np
import pandas as pd

# -----------------------------------------------------------------------------
# 1. BUILD MASTER STAGE DATASET
# -----------------------------------------------------------------------------
print("Building Master Stage Dataset...")
data = []

for stage, info in STAGES.items():
    row = {
        'stage': stage,
        'midpoint_ma': info['mid'],
        'start_ma': info['start'],
        'end_ma': info['end'],
        'period': info['period']
    }

# A. Reef data (if available)
    if 'reef_df' in locals() and not reef_df.empty:
        reef_row = reef_df[reef_df['name'] == stage]
        if len(reef_row) > 0:
            for col in [
                'thickness_mean', 'thickness_std', 'thickness_stderr', 'thickness_median',
                'thickness_min', 'thickness_max', 'thickness_q25', 'thickness_q75', 'thickness_count',
                'width_mean', 'width_std', 'width_median', 'width_min', 'width_max',
                'reef_count'
            ]:
                if col in reef_row.columns:
                    row[col] = reef_row[col].values[0]

    # B. Stromatoporoid data (from strom_df)
    if 'strom_df' in locals() and not strom_df.empty:
        # Check column name (case insensitive)
        strom_row = strom_df[strom_df['stage'].str.lower() == stage.lower()]
        if len(strom_row) > 0:
            for order in STROM_ORDERS:
                col_occ = f'{order}_occ'
                col_gen = f'{order}_genus'
                # Check actual columns (constants are TitleCase, data is normalized in Cell 7)
                if col_occ in strom_row.columns:
                    row[col_occ] = strom_row[col_occ].values[0]
                if col_gen in strom_row.columns:
                    row[col_gen] = strom_row[col_gen].values[0]

            if 'Total_occ' in strom_row.columns:
                row['strom_total_occ'] = strom_row['Total_occ'].values[0]
            if 'Total_genus' in strom_row.columns:
                row['strom_total_gen'] = strom_row['Total_genus'].values[0]

    # C. Coral data (from coral_df)
    if 'coral_df' in locals() and not coral_df.empty:
        coral_row = coral_df[coral_df['stage'].str.lower() == stage.lower()]
        if len(coral_row) > 0:
            # Rugosa
            if 'Rugosa_genus' in coral_row.columns: row['rugose_div'] = coral_row['Rugosa_genus'].values[0]
            if 'Rugosa_occ' in coral_row.columns:   row['rugose_occ'] = coral_row['Rugosa_occ'].values[0]

            # Tabulata
            if 'Tabulata_genus' in coral_row.columns: row['tabulate_div'] = coral_row['Tabulata_genus'].values[0]
            if 'Tabulata_occ' in coral_row.columns:   row['tabulate_occ'] = coral_row['Tabulata_occ'].values[0]

    # D. Macrostrat data
    if 'macro_stage' in locals() and not macro_stage.empty:
        macro_row = macro_stage[macro_stage['stage'] == stage]
        if len(macro_row) > 0:
            row['total_area_km2'] = macro_row['total_area_km2'].values[0]
            row['carbonate_area_km2'] = macro_row['carbonate_area_km2'].values[0]
            row['carbonate_percentage'] = macro_row['carbonate_percentage'].values[0]

    # E. Environmental proxies
    if 'env_data' in locals() and not env_data.empty:
        env_row = env_data[env_data['stage'] == stage]
        if len(env_row) > 0:
            row['temperature'] = env_row['temperature'].values[0]
            row['dissolved_O2'] = env_row['dissolved_O2'].values[0]
            row['atm_O2'] = env_row['atm_O2'].values[0]
            row['atm_CO2'] = env_row['atm_CO2'].values[0]
            row['sea_level'] = env_row['sea_level'].values[0]
            if 'd13C' in env_row.columns: row['d13C'] = env_row['d13C'].values[0]

    data.append(row)

df = pd.DataFrame(data)

# Calculate Stromatoporoid Proportions
if 'strom_total_occ' in df.columns:
    for order in STROM_ORDERS:
        col_occ = f'{order}_occ'
        if col_occ in df.columns:
            df[f'{order}_prop'] = np.where(
                df['strom_total_occ'] > 0,
                df[col_occ].fillna(0) / df['strom_total_occ'],
                0.0
            )

    # Derived vs Basal
    derived_orders = ['Actinostromatida', 'Stromatoporida', 'Stromatoporellida',
                      'Syringostromatida', 'Amphiporida']
    basal_orders = ['Labechiida', 'Clathrodictyida']

    # Occurrences
    df['derived_strom_occ'] = sum(df[f'{o}_occ'].fillna(0) for o in derived_orders if f'{o}_occ' in df.columns)
    df['basal_strom_occ'] = sum(df[f'{o}_occ'].fillna(0) for o in basal_orders if f'{o}_occ' in df.columns)

    # Diversity
    df['derived_strom_div'] = sum(df[f'{o}_genus'].fillna(0) for o in derived_orders if f'{o}_genus' in df.columns)
    df['basal_strom_div'] = sum(df[f'{o}_genus'].fillna(0) for o in basal_orders if f'{o}_genus' in df.columns)

    # Proportions
    df['derived_strom_prop'] = np.where(df['strom_total_occ'] > 0, df['derived_strom_occ'] / df['strom_total_occ'], 0.0)
    df['basal_strom_prop'] = np.where(df['strom_total_occ'] > 0, df['basal_strom_occ'] / df['strom_total_occ'], 0.0)

# Sort
df = df.sort_values('midpoint_ma', ascending=False).reset_index(drop=True)
print(f"✓ Master dataset (STAGES) created: {len(df)} stages, {len(df.columns)} variables")


# -----------------------------------------------------------------------------
# 2. BUILD MASTER 5-MYR DATASET
# -----------------------------------------------------------------------------
print("\n" + "="*70)
print("CREATING 5-MYR BIN MASTER DATASET")
print("="*70)

# Determine the primary source for 5-Myr bins
# We prefer reef data if available, otherwise we use the biological data
if 'reef_5myr_df' in locals() and not reef_5myr_df.empty:
    primary_bins = reef_5myr_df
    print("Using Reef Data as primary 5-Myr bin source.")
elif 'strom_5myr_df' in locals() and not strom_5myr_df.empty:
    primary_bins = strom_5myr_df
    # Add midpoint if missing
    if 'midpoint_ma' not in primary_bins.columns:
        primary_bins['midpoint_ma'] = (primary_bins['bin_top'] + primary_bins['bin_bottom']) / 2
    print("Using Stromatoporoid Data as primary 5-Myr bin source.")
else:
    # Fallback to env_data_5myr if bio data missing
    primary_bins = env_data_5myr
    print("Using Environmental Data as primary 5-Myr bin source.")

data_5myr = []

# Iterate through the chosen primary bins
for idx, row_ref in primary_bins.iterrows():
    # Identify bin
    if 'time_identifier' in row_ref: bin_id = row_ref['time_identifier']
    elif 'bin_label' in row_ref: bin_id = row_ref['bin_label']
    elif 'bin_id' in row_ref: bin_id = row_ref['bin_id']
    else: bin_id = idx # fallback

    midpoint = row_ref['midpoint_ma']

    row = {
        'bin_id': bin_id,
        'midpoint_ma': midpoint
    }

    # A. Reef Data (if available)
    if 'reef_5myr_df' in locals() and not reef_5myr_df.empty:
        # Find matching reef row (if not already iterating it)
        if primary_bins is not reef_5myr_df:
            # Match by midpoint proximity (float comparison)
            reef_match = reef_5myr_df[abs(reef_5myr_df['midpoint_ma'] - midpoint) < 0.1]
            if len(reef_match) > 0:
                row_ref_for_reef = reef_match.iloc[0]
            else:
                row_ref_for_reef = pd.Series()
        else:
            row_ref_for_reef = row_ref

        for col in ['thickness_mean', 'thickness_std', 'thickness_stderr', 'thickness_median',
                    'thickness_count', 'width_mean', 'width_std', 'reef_count']:
            if col in row_ref_for_reef.index:
                row[col] = row_ref_for_reef[col]

    # B. Stromatoporoid Data (5-Myr) - NOW INCLUDED
    if 'strom_5myr_df' in locals() and not strom_5myr_df.empty:
        # Match by midpoint
        strom_match = strom_5myr_df[abs(((strom_5myr_df['bin_top'] + strom_5myr_df['bin_bottom'])/2) - midpoint) < 0.1]
        if len(strom_match) > 0:
            s_row = strom_match.iloc[0]
            for order in STROM_ORDERS:
                if f'{order}_occ' in s_row: row[f'{order}_occ'] = s_row[f'{order}_occ']
                if f'{order}_genus' in s_row: row[f'{order}_genus'] = s_row[f'{order}_genus']
            if 'Total_occ' in s_row: row['strom_total_occ'] = s_row['Total_occ']
            if 'Total_genus' in s_row: row['strom_total_gen'] = s_row['Total_genus']

    # C. Coral Data (5-Myr) - NOW INCLUDED
    if 'coral_5myr_df' in locals() and not coral_5myr_df.empty:
        # Match by midpoint
        coral_match = coral_5myr_df[abs(((coral_5myr_df['bin_top'] + coral_5myr_df['bin_bottom'])/2) - midpoint) < 0.1]
        if len(coral_match) > 0:
            c_row = coral_match.iloc[0]
            if 'Rugosa_occ' in c_row: row['rugose_occ'] = c_row['Rugosa_occ']
            if 'Rugosa_genus' in c_row: row['rugose_div'] = c_row['Rugosa_genus']
            if 'Tabulata_occ' in c_row: row['tabulate_occ'] = c_row['Tabulata_occ']
            if 'Tabulata_genus' in c_row: row['tabulate_div'] = c_row['Tabulata_genus']

    # D. Macrostrat Data
    if 'macro_5myr' in locals() and not macro_5myr.empty and 'bin_mid' in macro_5myr.columns:
        macro_match = macro_5myr[abs(macro_5myr['bin_mid'] - midpoint) < 2.5]
        if len(macro_match) > 0:
            row['total_area_km2'] = macro_match['total_area_km2'].values[0]
            row['carbonate_area_km2'] = macro_match['carbonate_area_km2'].values[0]
            row['carbonate_percentage'] = macro_match['carbonate_percentage'].values[0]

    # E. Environmental Proxies
    if 'env_data_5myr' in locals() and not env_data_5myr.empty:
        # Match by midpoint
        env_match = env_data_5myr[abs(env_data_5myr['midpoint_ma'] - midpoint) < 0.1]
        if len(env_match) > 0:
            env_val = env_match.iloc[0]
            row['temperature'] = env_val['temperature']
            row['dissolved_O2'] = env_val['dissolved_O2']
            row['atm_O2'] = env_val['atm_O2']
            row['atm_CO2'] = env_val['atm_CO2']
            row['sea_level'] = env_val['sea_level']
            if 'd13C' in env_val.index: row['d13C'] = env_val['d13C']

    data_5myr.append(row)

df_5myr = pd.DataFrame(data_5myr)
df_5myr = df_5myr.sort_values('midpoint_ma', ascending=False).reset_index(drop=True)
print(f"✓ Master dataset (5-MYR BINS) created: {len(df_5myr)} bins, {len(df_5myr.columns)} variables")

# =============================================================================
# [ADDED] SAVE MASTER DATASETS
# =============================================================================
# Save the master datasets immediately after creation
if 'df' in locals() and not df.empty:
    df.to_csv(f"{OUTPUT_DIR}/MASTER_dataset_stage.csv", index=False)
    print(f"✓ Saved: {OUTPUT_DIR}/MASTER_dataset_stage.csv")

if 'df_5myr' in locals() and not df_5myr.empty:
    df_5myr.to_csv(f"{OUTPUT_DIR}/MASTER_dataset_5myr.csv", index=False)
    print(f"✓ Saved: {OUTPUT_DIR}/MASTER_dataset_5myr.csv")


Building Master Stage Dataset...
✓ Master dataset (STAGES) created: 32 stages, 56 variables

CREATING 5-MYR BIN MASTER DATASET
Using Reef Data as primary 5-Myr bin source.
✓ Master dataset (5-MYR BINS) created: 37 bins, 39 variables
✓ Saved: ./output/MASTER_dataset_stage.csv
✓ Saved: ./output/MASTER_dataset_5myr.csv


In [10]:
# =============================================================================
# @title CELL 10: CLR COMPOSITIONAL TRANSFORMATION (WITH BASAL/DERIVED GROUPS)
# =============================================================================
import numpy as np
import pandas as pd
from scipy.stats import gmean
import scipy.stats as stats
print("\n" + "="*90)
print("PRE-PROCESSING: CLR TRANSFORMATION")
print("Transforming closed compositional data (proportions) to open log-ratios")
print("="*90)
# Global list for CLR results
global_clr_results = []
def calculate_missing_props(input_df):
    """
    Helper: If _prop columns are missing, calculate them from _occ columns.
    """
    df_copy = input_df.copy()
    orders = ['Labechiida', 'Clathrodictyida', 'Actinostromatida',
              'Stromatoporida', 'Stromatoporellida',
              'Syringostromatida', 'Amphiporida']
    # Check if we have occurrence columns
    occ_cols = [f"{o}_occ" for o in orders if f"{o}_occ" in df_copy.columns]
    if not occ_cols:
        return df_copy
    # Recalculate Total
    if 'strom_total_occ' not in df_copy.columns:
        df_copy['strom_total_occ'] = df_copy[occ_cols].sum(axis=1)
    # Calculate Proportions
    for o in orders:
        occ_col = f"{o}_occ"
        prop_col = f"{o}_prop"
        if occ_col in df_copy.columns:
            total = df_copy['strom_total_occ'].replace(0, np.nan)
            df_copy[prop_col] = df_copy[occ_col] / total
            df_copy[prop_col] = df_copy[prop_col].fillna(0)
    return df_copy
def apply_clr(input_df, label):
    """Apply CLR transformation and calculate group-level metrics."""

    dataset = input_df.copy()

    # 1. AUTO-REPAIR: Ensure Proportion Columns Exist
    dataset = calculate_missing_props(dataset)
    # 2. Identify Proportion Columns
    prop_cols = ['Labechiida_prop', 'Clathrodictyida_prop', 'Actinostromatida_prop',
                 'Stromatoporida_prop', 'Stromatoporellida_prop',
                 'Syringostromatida_prop', 'Amphiporida_prop']
    available_cols = [c for c in prop_cols if c in dataset.columns]
    # Define group membership
    basal_orders = ['Labechiida_prop', 'Clathrodictyida_prop']
    derived_orders = ['Actinostromatida_prop', 'Stromatoporida_prop',
                      'Stromatoporellida_prop', 'Syringostromatida_prop',
                      'Amphiporida_prop']
    if len(available_cols) < 2:
        print(f"  {label}: ! Skipped. Found only {len(available_cols)} proportion columns.")
        return dataset
    # 3. Extract and Handle Zeros
    comp_data = dataset[available_cols].replace(0, 1e-5)
    # 4. Geometric Mean per Row
    gmeans = gmean(comp_data, axis=1)
    # 5. Transform: ln(x / gmean)
    clr_data = np.log(comp_data.div(gmeans, axis=0))

    # Use CLR_ prefix (avoiding duplicates)
    new_col_names = []
    for c in available_cols:
        new_name = f"CLR_{c}"
        # Drop existing column if present to avoid duplicates
        if new_name in dataset.columns:
            dataset = dataset.drop(columns=[new_name])
        new_col_names.append(new_name)
    clr_data.columns = new_col_names
    # 6. Create 'Derived vs Basal' Log-Ratio
    basal_in = [c for c in basal_orders if c in dataset.columns]
    derived_in = [c for c in derived_orders if c in dataset.columns]
    if basal_in and derived_in:
        b_sum = dataset[basal_in].sum(axis=1).replace(0, 1e-5)
        d_sum = dataset[derived_in].sum(axis=1).replace(0, 1e-5)
        dataset['log_derived_basal_ratio'] = np.log(d_sum / b_sum)
        print(f"  {label}: Created 'log_derived_basal_ratio'")
    # 7. Merge CLR columns back to dataset
    dataset = pd.concat([dataset.reset_index(drop=True), clr_data.reset_index(drop=True)], axis=1)
    print(f"  {label}: Generated {len(clr_data.columns)} CLR variables.")

    # 8. Calculate GROUP-LEVEL CLR means
    clr_basal_cols = [f"CLR_{c}" for c in basal_in]
    clr_derived_cols = [f"CLR_{c}" for c in derived_in]

    if clr_basal_cols:
        dataset['CLR_basal_mean'] = dataset[clr_basal_cols].mean(axis=1)
    if clr_derived_cols:
        dataset['CLR_derived_mean'] = dataset[clr_derived_cols].mean(axis=1)

    return dataset
# Apply to both master datasets
df = apply_clr(df, "Stage-Level")
df_5myr = apply_clr(df_5myr, "5-Myr Bins")
# =============================================================================
# CLR CORRELATION ANALYSIS (Individual Taxa + Basal/Derived Groups)
# =============================================================================
print("\n" + "-"*90)
print("CLR CORRELATION ANALYSIS")
print("-"*90)
prop_cols = ['Labechiida_prop', 'Clathrodictyida_prop', 'Actinostromatida_prop',
             'Stromatoporida_prop', 'Stromatoporellida_prop',
             'Syringostromatida_prop', 'Amphiporida_prop']
targets = ['thickness_mean', 'width_mean']
def safe_spearman(x, y):
    """Calculate Spearman correlation safely, returning scalars."""
    try:
        # Ensure 1D numpy arrays
        x_arr = np.array(x).flatten()
        y_arr = np.array(y).flatten()

        # Remove NaN pairs
        mask = ~(np.isnan(x_arr) | np.isnan(y_arr))
        x_clean = x_arr[mask]
        y_clean = y_arr[mask]

        if len(x_clean) < 3:
            return np.nan, np.nan

        result = stats.spearmanr(x_clean, y_clean)
        # Handle both old and new scipy return types
        if hasattr(result, 'correlation'):
            return float(result.correlation), float(result.pvalue)
        else:
            return float(result[0]), float(result[1])
    except:
        return np.nan, np.nan
for label, data in [('Stage-Level', df), ('5-Myr Bins', df_5myr)]:
    print(f"\n--- {label} ---")
    print(f"  {'Taxon/Group':<28s} | {'Orig ρ':>8s} | {'CLR ρ':>8s} | {'Diff':>7s} | {'CLR p':>10s}")
    print("  " + "-"*75)

    # A. INDIVIDUAL TAXA
    for prop in prop_cols:
        clr_col = f"CLR_{prop}"
        if prop not in data.columns or clr_col not in data.columns:
            continue

        for target in targets:
            if target not in data.columns:
                continue

            # Get data as 1D arrays
            x_orig = data[prop].values
            x_clr = data[clr_col].values
            y = data[target].values

            # Calculate correlations
            r_orig, p_orig = safe_spearman(x_orig, y)
            r_clr, p_clr = safe_spearman(x_clr, y)

            if np.isnan(r_orig) or np.isnan(r_clr):
                continue

            diff = r_clr - r_orig

            # Interpretation
            if abs(diff) < 0.1:
                interp = "Stable"
            elif diff < -0.2:
                interp = "SUPPRESSED"
            elif diff > 0.2:
                interp = "INFLATED"
            else:
                interp = "Moderate"

            global_clr_results.append({
                'Dataset': label, 'Level': 'Individual',
                'Predictor': prop.replace('_prop', ''),
                'Target': target, 'Original_Rho': r_orig, 'CLR_Rho': r_clr,
                'Difference': diff, 'CLR_P': p_clr, 'Interpretation': interp
            })

            if target == 'thickness_mean':
                sig = "*" if p_clr < 0.05 else ""
                print(f"  {prop.replace('_prop',''):<28s} | {r_orig:>+8.3f} | {r_clr:>+8.3f} | {diff:>+7.3f} | {p_clr:>10.4f} {sig}")

    # B. BASAL vs DERIVED GROUPS
    print("  " + "-"*75)
    print("  GROUP COMPARISONS:")

    group_vars = [
        ('CLR_basal_mean', 'Basal (Labech+Clathr) CLR'),
        ('CLR_derived_mean', 'Derived (5 taxa) CLR'),
        ('log_derived_basal_ratio', 'Log(Derived/Basal) Ratio')
    ]

    for col, name in group_vars:
        if col not in data.columns:
            continue

        for target in targets:
            if target not in data.columns:
                continue

            x = data[col].values
            y = data[target].values

            r, p = safe_spearman(x, y)

            if np.isnan(r):
                continue

            global_clr_results.append({
                'Dataset': label, 'Level': 'Group',
                'Predictor': name, 'Target': target,
                'Original_Rho': np.nan, 'CLR_Rho': r,
                'Difference': np.nan, 'CLR_P': p, 'Interpretation': 'Group-level'
            })

            if target == 'thickness_mean':
                sig = "*" if p < 0.05 else ""
                print(f"  {name:<28s} |      N/A | {r:>+8.3f} |     N/A | {p:>10.4f} {sig}")
# =============================================================================
# SAVE CLR RESULTS
# =============================================================================
pd.DataFrame(global_clr_results).to_csv(f"{OUTPUT_DIR}/results_clr.csv", index=False)
print(f"\nSaved: {OUTPUT_DIR}/results_clr.csv ({len(global_clr_results)} rows)")
# Update predictor lists
if 'all_test_vars' in locals():
    new_vars = [
        ('log_derived_basal_ratio', 'Log(Derived/Basal)'),
        ('CLR_basal_mean', 'CLR Basal Mean'),
        ('CLR_derived_mean', 'CLR Derived Mean')
    ]
    for var in new_vars:
        if not any(v[0] == var[0] for v in all_test_vars):
            all_test_vars.append(var)



PRE-PROCESSING: CLR TRANSFORMATION
Transforming closed compositional data (proportions) to open log-ratios
  Stage-Level: Created 'log_derived_basal_ratio'
  Stage-Level: Generated 7 CLR variables.
  5-Myr Bins: Created 'log_derived_basal_ratio'
  5-Myr Bins: Generated 7 CLR variables.

------------------------------------------------------------------------------------------
CLR CORRELATION ANALYSIS
------------------------------------------------------------------------------------------

--- Stage-Level ---
  Taxon/Group                  |   Orig ρ |    CLR ρ |    Diff |      CLR p
  ---------------------------------------------------------------------------
  Labechiida                   |   +0.075 |   -0.364 |  -0.439 |     0.0406 *
  Clathrodictyida              |   +0.520 |   +0.416 |  -0.104 |     0.0178 *
  Actinostromatida             |   +0.712 |   +0.405 |  -0.307 |     0.0213 *
  Stromatoporida               |   +0.679 |   +0.359 |  -0.320 |     0.0433 *
  Stromatoporelli

In [11]:
# =============================================================================
# @title CELL 11: CORRELATION FUNCTIONS - SPEARMAN AND PEARSON
# =============================================================================
import numpy as np
import scipy.stats as stats

"""
CORRELATION ANALYSIS METHODS
============================
This cell defines functions for both Spearman and Pearson correlations.

SPEARMAN'S RHO (ρ):
- Non-parametric rank correlation
- Measures monotonic relationships (not just linear)
- Robust to outliers and non-normal distributions
- Appropriate for ordinal data or data with outliers
- Reference: Spearman, C. (1904). American Journal of Psychology, 15(1), 72-101.

PEARSON'S r:
- Parametric correlation coefficient
- Measures linear relationships specifically
- Assumes normally distributed variables
- More powerful when assumptions are met
- Reference: Pearson, K. (1895). Philosophical Transactions of the Royal Society A, 186, 343-414.

For geological time series:
- Spearman is often preferred due to non-normal distributions
- Pearson provides information on linear relationships
- Presenting both allows comparison and assessment of relationship type
"""

def calc_spearman(data, v1, v2):
    """
    Calculate Spearman rank correlation with significance

    Parameters:
    -----------
    data : DataFrame
        Data containing the variables
    v1, v2 : str
        Column names to correlate

    Returns:
    --------
    rho : float
        Spearman correlation coefficient
    p : float
        Two-tailed p-value
    n : int
        Number of valid pairs
    """
    # Check if columns exist
    if v1 not in data.columns or v2 not in data.columns:
        return np.nan, np.nan, 0

    valid = data[[v1, v2]].dropna()

    if len(valid) >= 5:
        # Spearmanr returns a Result object or tuple depending on version, generic unpacking is safer
        result = stats.spearmanr(valid[v1], valid[v2])
        # Handle cases where result might be a struct or tuple
        try:
            r, p = result.correlation, result.pvalue
        except AttributeError:
            r, p = result[0], result[1]

        return r, p, len(valid)

    return np.nan, np.nan, 0

def calc_pearson(data, v1, v2):
    """
    Calculate Pearson correlation with significance

    Parameters:
    -----------
    data : DataFrame
        Data containing the variables
    v1, v2 : str
        Column names to correlate

    Returns:
    --------
    r : float
        Pearson correlation coefficient
    p : float
        Two-tailed p-value
    n : int
        Number of valid pairs
    """
    if v1 not in data.columns or v2 not in data.columns:
        return np.nan, np.nan, 0

    valid = data[[v1, v2]].dropna()

    if len(valid) >= 5:
        r, p = stats.pearsonr(valid[v1], valid[v2])
        return r, p, len(valid)

    return np.nan, np.nan, 0

def calc_both_correlations(data, v1, v2):
    """
    Calculate both Spearman and Pearson correlations

    Returns:
    --------
    dict with both correlation results
    """
    spearman_r, spearman_p, n = calc_spearman(data, v1, v2)
    pearson_r, pearson_p, _ = calc_pearson(data, v1, v2)

    return {
        'spearman_rho': spearman_r,
        'spearman_p': spearman_p,
        'pearson_r': pearson_r,
        'pearson_p': pearson_p,
        'n': n
    }

def get_significance_stars(p):
    """Convert p-value to significance stars"""
    if p is None or np.isnan(p):
        return ''
    if p < 0.001:
        return '***'
    elif p < 0.01:
        return '**'
    elif p < 0.05:
        return '*'
    else:
        return 'ns'

print("✓ Correlation functions defined (Spearman and Pearson)")

✓ Correlation functions defined (Spearman and Pearson)


In [12]:
# =============================================================================
# @title CELL 12: COMPREHENSIVE CORRELATION ANALYSIS (MULTI-METRIC & 5-MYR BIOLOGY)
#   - NO T/W ratio computed or compared
#   - Excludes rows with NO stromatoporoids (strom_total_occ<=0) from statistics
#   - Includes δ13C (expects column name 'd13C')
#   - Standardizes atmospheric O2/CO2 column names for correlation output:
#       atmospheric_O2, atmospheric_CO2
# =============================================================================
import numpy as np
import pandas as pd
import scipy.stats as stats
import re

print("="*90)
print("COMPREHENSIVE CORRELATION ANALYSIS")
print("Metrics: Thickness, Width")
print("Scopes:  Stage-Level AND 5-Myr Bins")
print("Notes:   Excluding strom_total_occ<=0 rows; NO T/W ratio; includes δ13C + atm O2/CO2 if present")
print("="*90)

# -----------------------------------------------------------------------------
# 0. SAFETY: define STROM_ORDERS if missing
# -----------------------------------------------------------------------------
if 'STROM_ORDERS' not in globals():
    STROM_ORDERS = [
        'Labechiida', 'Clathrodictyida', 'Actinostromatida',
        'Stromatoporida', 'Stromatoporellida', 'Syringostromatida', 'Amphiporida'
    ]
    print("[WARN] STROM_ORDERS not found in globals(); using default list.")

# -----------------------------------------------------------------------------
# 1. PRE-PROCESSING: CALCULATE DERIVED VARIABLES (NO T/W)
# -----------------------------------------------------------------------------
print("Pre-processing data...")

def calculate_derived_metrics(dataset, label):
    if dataset is None or dataset.empty:
        print(f"  - [{label}] Empty dataset; skip derived metrics.")
        return dataset

    dataset = dataset.copy()

    # Remove T/W ratio if it exists from older runs (do NOT compute it)
    if 'thickness_width_ratio' in dataset.columns:
        dataset = dataset.drop(columns=['thickness_width_ratio'])

    # Stromatoporoid Proportions & Groups
    strom_cols = [c for c in dataset.columns if c.endswith('_occ') and c != 'strom_total_occ']
    if strom_cols:
        if 'strom_total_occ' not in dataset.columns:
            dataset['strom_total_occ'] = dataset[strom_cols].sum(axis=1)

        # Ensure numeric
        dataset['strom_total_occ'] = pd.to_numeric(dataset['strom_total_occ'], errors='coerce')

        for order in STROM_ORDERS:
            col_occ = f'{order}_occ'
            if col_occ in dataset.columns:
                dataset[col_occ] = pd.to_numeric(dataset[col_occ], errors='coerce').fillna(0)
                dataset[f'{order}_prop'] = np.where(
                    dataset['strom_total_occ'] > 0,
                    dataset[col_occ] / dataset['strom_total_occ'],
                    np.nan
                )

        derived_orders = ['Actinostromatida', 'Stromatoporida', 'Stromatoporellida',
                          'Syringostromatida', 'Amphiporida']
        basal_orders   = ['Labechiida', 'Clathrodictyida']

        # occurrences
        dataset['derived_strom_occ'] = 0.0
        dataset['basal_strom_occ']   = 0.0
        for o in derived_orders:
            c = f'{o}_occ'
            if c in dataset.columns:
                dataset['derived_strom_occ'] += pd.to_numeric(dataset[c], errors='coerce').fillna(0)
        for o in basal_orders:
            c = f'{o}_occ'
            if c in dataset.columns:
                dataset['basal_strom_occ'] += pd.to_numeric(dataset[c], errors='coerce').fillna(0)

        # diversity (genus)
        dataset['derived_strom_div'] = 0.0
        dataset['basal_strom_div']   = 0.0
        for o in derived_orders:
            c = f'{o}_genus'
            if c in dataset.columns:
                dataset['derived_strom_div'] += pd.to_numeric(dataset[c], errors='coerce').fillna(0)
        for o in basal_orders:
            c = f'{o}_genus'
            if c in dataset.columns:
                dataset['basal_strom_div'] += pd.to_numeric(dataset[c], errors='coerce').fillna(0)

        dataset['derived_strom_prop'] = np.where(
            dataset['strom_total_occ'] > 0,
            dataset['derived_strom_occ'] / dataset['strom_total_occ'],
            np.nan
        )
        dataset['basal_strom_prop'] = np.where(
            dataset['strom_total_occ'] > 0,
            dataset['basal_strom_occ'] / dataset['strom_total_occ'],
            np.nan
        )

        print(f"  [OK] [{label}] Calculated Stromatoporoid proportions and groupings")
    else:
        print(f"  - [{label}] No Stromatoporoid occurrence data found.")

    return dataset

df = calculate_derived_metrics(df, "Stage")
df_5myr = calculate_derived_metrics(df_5myr, "5-Myr")

# -----------------------------------------------------------------------------
# 1A. STANDARDIZE ATMOSPHERIC O2/CO2 COLUMN NAMES (Stage + 5-Myr)
#   Outputs standardized columns:
#     - atmospheric_O2
#     - atmospheric_CO2
# -----------------------------------------------------------------------------
def standardize_atm_cols(dataset, label):
    if dataset is None or dataset.empty:
        return dataset
    dataset = dataset.copy()

    def _find_col(candidates, regex_pat=None):
        for c in candidates:
            if c in dataset.columns:
                return c
        if regex_pat is not None:
            hits = [c for c in dataset.columns if re.search(regex_pat, c, flags=re.IGNORECASE)]
            return hits[0] if hits else None
        return None

    # Atmospheric O2
    if 'atmospheric_O2' not in dataset.columns:
        o2_src = _find_col(
            candidates=['atm_O2','atmospheric_o2','oxygen','pO2','PO2','O2_atm','oxygen_atm'],
            regex_pat=r'(atm|atmos).*o2|po2'
        )
        if o2_src is not None:
            dataset['atmospheric_O2'] = pd.to_numeric(dataset[o2_src], errors='coerce')
            print(f"  [OK] [{label}] atmospheric_O2 <- {o2_src}")
        else:
            print(f"  [WARN] [{label}] No atmospheric O2 column found (atmospheric_O2 will be missing).")

    # Atmospheric CO2
    if 'atmospheric_CO2' not in dataset.columns:
        co2_src = _find_col(
            candidates=['atm_CO2','atmospheric_co2','co2','pCO2','PCO2','CO2_atm','co2_atm'],
            regex_pat=r'(atm|atmos).*co2|pco2'
        )
        if co2_src is not None:
            dataset['atmospheric_CO2'] = pd.to_numeric(dataset[co2_src], errors='coerce')
            print(f"  [OK] [{label}] atmospheric_CO2 <- {co2_src}")
        else:
            print(f"  [WARN] [{label}] No atmospheric CO2 column found (atmospheric_CO2 will be missing).")

    return dataset

df = standardize_atm_cols(df, "Stage")
df_5myr = standardize_atm_cols(df_5myr, "5-Myr")

# -----------------------------------------------------------------------------
# 1B. FILTER: EXCLUDE ROWS WITH NO STROMATOPOROIDS FROM STATISTICS
# -----------------------------------------------------------------------------
def filter_has_strom(df_in, label):
    if df_in is None or df_in.empty:
        print(f"  - [{label}] Empty dataset; nothing to filter.")
        return df_in
    if 'strom_total_occ' not in df_in.columns:
        print(f"  [WARN] [{label}] 'strom_total_occ' not found; keeping all rows.")
        return df_in
    df_out = df_in[df_in['strom_total_occ'].notna() & (pd.to_numeric(df_in['strom_total_occ'], errors='coerce') > 0)].copy()
    removed = len(df_in) - len(df_out)
    print(f"  - [{label}] Excluding strom_total_occ<=0 rows: removed {removed}, kept {len(df_out)}.")
    return df_out

df_corr = filter_has_strom(df, "Stage")
df_5myr_corr = filter_has_strom(df_5myr, "5-Myr")

# -----------------------------------------------------------------------------
# 2. DEFINE VARIABLE GROUPS
# -----------------------------------------------------------------------------
reef_targets = [
    ('thickness_mean', 'Thickness (Log)'),
    ('width_mean', 'Width (Log)')
]

strom_prop_vars = [('derived_strom_prop', 'Derived Proportion'), ('basal_strom_prop', 'Basal Proportion')] + \
                  [(f'{order}_prop', f'{order} Prop') for order in STROM_ORDERS]

strom_occ_vars = [('strom_total_occ', 'Total Occurrence'), ('derived_strom_occ', 'Derived Occurrence'), ('basal_strom_occ', 'Basal Occurrence')] + \
                 [(f'{order}_occ', f'{order} Occ') for order in STROM_ORDERS]

strom_div_vars = [('strom_total_gen', 'Total Diversity'), ('derived_strom_div', 'Derived Diversity'), ('basal_strom_div', 'Basal Diversity')] + \
                 [(f'{order}_genus', f'{order} Div') for order in STROM_ORDERS]

coral_div_vars = [('rugose_div', 'Rugose Diversity'), ('tabulate_div', 'Tabulate Diversity')]
coral_occ_vars = [('rugose_occ', 'Rugose Occurrence'), ('tabulate_occ', 'Tabulate Occurrence')]

env_vars_macrostrat = [
    ('total_area_km2', 'Total Area'),
    ('carbonate_area_km2', 'Carb Area'),
    ('carbonate_percentage', 'Carb %')
]

# IMPORTANT: δ13C expected as 'd13C'
env_vars_proxies = [
    ('temperature', 'SST'),
    ('sea_level', 'Sea Level'),
    ('atmospheric_O2', 'Atm O2'),
    ('atmospheric_CO2', 'Atm CO2'),
    ('dissolved_O2', 'Dissolved O2'),
    ('d13C', 'δ¹³C')
]

var_groups = [
    ('Strom Props', strom_prop_vars),
    ('Strom Occ', strom_occ_vars),
    ('Strom Div', strom_div_vars),
    ('Coral Div', coral_div_vars),
    ('Coral Occ', coral_occ_vars),
    ('Macrostrat', env_vars_macrostrat),
    ('Proxies', env_vars_proxies),
]

# -----------------------------------------------------------------------------
# 3. ANALYSIS FUNCTIONS
# -----------------------------------------------------------------------------
def get_significance_stars(p):
    if p is None or np.isnan(p):
        return ""
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p < 0.05:
        return "*"
    return ""

def calc_stats(x, y):
    mask = np.isfinite(x) & np.isfinite(y)
    x = x[mask]
    y = y[mask]
    n = len(x)
    if n < 5:
        return dict(n=n, spearman_rho=np.nan, spearman_p=np.nan, pearson_r=np.nan, pearson_p=np.nan)
    sr = stats.spearmanr(x, y)
    pr = stats.pearsonr(x, y)
    return dict(
        n=n,
        spearman_rho=float(sr.correlation),
        spearman_p=float(sr.pvalue),
        pearson_r=float(pr.statistic),
        pearson_p=float(pr.pvalue)
    )

def run_correlation_suite(dataset, target_col, target_name):
    if dataset is None or dataset.empty or target_col not in dataset.columns:
        print(f"[WARN] Missing target {target_col} or dataset empty; skipping {target_name}.")
        return []

    print("\n" + "-"*90)
    print(f"TARGET: {target_name.upper()}  |  rows used (after no-strom filter): {len(dataset)}")
    print("-"*90)
    print("{:<35} {:>8} {:>10} {:>8} {:>10} {:>5}".format("Variable", "rho", "p(rho)", "r", "p(r)", "n"))

    results = []
    y = pd.to_numeric(dataset[target_col], errors='coerce').values.astype(float)

    for group_name, group_vars in var_groups:
        for var, label in group_vars:
            if var not in dataset.columns:
                continue

            x = pd.to_numeric(dataset[var], errors='coerce').values.astype(float)
            s = calc_stats(x, y)

            # Only print/store if enough data
            if s['n'] >= 5 and np.isfinite(s['spearman_rho']):
                s_sig = get_significance_stars(s['spearman_p'])
                p_sig = get_significance_stars(s['pearson_p'])

                print(f"{label:35s} {s['spearman_rho']:8.2f}{s_sig:3s} {s['spearman_p']:10.3g} "
                      f"{s['pearson_r']:8.2f}{p_sig:3s} {s['pearson_p']:10.3g} {s['n']:5d}")

                results.append({
                    'Scope': 'Stage' if dataset is df_corr else '5-Myr',
                    'Target': target_col,
                    'Predictor': var,
                    'Label': label,
                    'Group': group_name,
                    'Spearman_Rho': s['spearman_rho'],
                    'Spearman_P': s['spearman_p'],
                    'Pearson_R': s['pearson_r'],
                    'Pearson_P': s['pearson_p'],
                    'N': int(s['n'])
                })

    return results

# -----------------------------------------------------------------------------
# 4. RUN: Stage + 5-Myr
# -----------------------------------------------------------------------------
all_results_stage = []
for target_col, target_name in reef_targets:
    all_results_stage.extend(run_correlation_suite(df_corr, target_col, target_name))

all_results_5myr = []
for target_col, target_name in reef_targets:
    all_results_5myr.extend(run_correlation_suite(df_5myr_corr, target_col, target_name))

stage_results_df = pd.DataFrame(all_results_stage)
myr_results_df = pd.DataFrame(all_results_5myr)

if 'OUTPUT_DIR' in globals():
    if not stage_results_df.empty:
        stage_results_df.to_csv(f"{OUTPUT_DIR}/results_correlations_stage.csv", index=False, encoding="utf-8-sig")
        print(f"\nSaved: {OUTPUT_DIR}/results_correlations_stage.csv")
    else:
        print("\n[WARN] Stage correlations empty (nothing met n>=5 criteria).")

    if not myr_results_df.empty:
        myr_results_df.to_csv(f"{OUTPUT_DIR}/results_correlations_5myr.csv", index=False, encoding="utf-8-sig")
        print(f"Saved: {OUTPUT_DIR}/results_correlations_5myr.csv")
    else:
        print("[WARN] 5-Myr correlations empty (nothing met n>=5 criteria).")

# Show quick views
display(stage_results_df if not stage_results_df.empty else pd.DataFrame())
display(myr_results_df if not myr_results_df.empty else pd.DataFrame())


COMPREHENSIVE CORRELATION ANALYSIS
Metrics: Thickness, Width
Scopes:  Stage-Level AND 5-Myr Bins
Notes:   Excluding strom_total_occ<=0 rows; NO T/W ratio; includes δ13C + atm O2/CO2 if present
Pre-processing data...
  [OK] [Stage] Calculated Stromatoporoid proportions and groupings
  [OK] [5-Myr] Calculated Stromatoporoid proportions and groupings
  [OK] [Stage] atmospheric_O2 <- atm_O2
  [OK] [Stage] atmospheric_CO2 <- atm_CO2
  [OK] [5-Myr] atmospheric_O2 <- atm_O2
  [OK] [5-Myr] atmospheric_CO2 <- atm_CO2
  - [Stage] Excluding strom_total_occ<=0 rows: removed 11, kept 21.
  - [5-Myr] Excluding strom_total_occ<=0 rows: removed 13, kept 24.

------------------------------------------------------------------------------------------
TARGET: THICKNESS (LOG)  |  rows used (after no-strom filter): 21
------------------------------------------------------------------------------------------
Variable                                 rho     p(rho)        r       p(r)     n
Derived Proportion 

Unnamed: 0,Scope,Target,Predictor,Label,Group,Spearman_Rho,Spearman_P,Pearson_R,Pearson_P,N
0,Stage,thickness_mean,derived_strom_prop,Derived Proportion,Strom Props,0.872563,2.501763e-07,0.853369,8.771189e-07,21
1,Stage,thickness_mean,basal_strom_prop,Basal Proportion,Strom Props,-0.872563,2.501763e-07,-0.853369,8.771189e-07,21
2,Stage,thickness_mean,Labechiida_prop,Labechiida Prop,Strom Props,-0.778648,3.214033e-05,-0.733873,1.527280e-04,21
3,Stage,thickness_mean,Clathrodictyida_prop,Clathrodictyida Prop,Strom Props,0.047557,8.378025e-01,-0.151025,5.134546e-01,21
4,Stage,thickness_mean,Actinostromatida_prop,Actinostromatida Prop,Strom Props,0.613505,3.098310e-03,0.595730,4.377395e-03,21
...,...,...,...,...,...,...,...,...,...,...
79,Stage,width_mean,sea_level,Sea Level,Proxies,-0.037675,8.712021e-01,-0.007430,9.745005e-01,21
80,Stage,width_mean,atmospheric_O2,Atm O2,Proxies,0.105229,6.498606e-01,0.107905,6.415237e-01,21
81,Stage,width_mean,atmospheric_CO2,Atm CO2,Proxies,-0.048068,8.360847e-01,-0.058657,8.006085e-01,21
82,Stage,width_mean,dissolved_O2,Dissolved O2,Proxies,0.049367,8.317147e-01,0.074091,7.495915e-01,21


Unnamed: 0,Scope,Target,Predictor,Label,Group,Spearman_Rho,Spearman_P,Pearson_R,Pearson_P,N
0,5-Myr,thickness_mean,derived_strom_prop,Derived Proportion,Strom Props,0.878380,1.673329e-08,0.927584,7.061008e-11,24
1,5-Myr,thickness_mean,basal_strom_prop,Basal Proportion,Strom Props,-0.878380,1.673329e-08,-0.927584,7.061008e-11,24
2,5-Myr,thickness_mean,Labechiida_prop,Labechiida Prop,Strom Props,-0.828941,5.599693e-07,-0.825183,6.980943e-07,24
3,5-Myr,thickness_mean,Clathrodictyida_prop,Clathrodictyida Prop,Strom Props,0.144726,4.998510e-01,-0.068724,7.496635e-01,24
4,5-Myr,thickness_mean,Actinostromatida_prop,Actinostromatida Prop,Strom Props,0.708982,1.051787e-04,0.666808,3.730702e-04,24
...,...,...,...,...,...,...,...,...,...,...
79,5-Myr,width_mean,sea_level,Sea Level,Proxies,-0.188113,3.787188e-01,-0.114646,5.937347e-01,24
80,5-Myr,width_mean,atmospheric_O2,Atm O2,Proxies,0.065752,7.601729e-01,0.095466,6.572337e-01,24
81,5-Myr,width_mean,atmospheric_CO2,Atm CO2,Proxies,-0.058350,7.865258e-01,-0.088770,6.799854e-01,24
82,5-Myr,width_mean,dissolved_O2,Dissolved O2,Proxies,0.094927,6.590544e-01,0.095234,6.580181e-01,24
