<a href="https://colab.research.google.com/github/Jeong-HyunLee/stromatoporoid-reef/blob/main/stromatoporoid_reef_size_v17.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# =============================================================================
#@title CELL 1: SETUP AND IMPORTS
# =============================================================================

# Install required packages (uncomment if needed)
# !pip install openpyxl geopandas shapely requests

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from scipy import stats
from scipy.interpolate import interp1d
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample
import statsmodels.api as sm
import warnings
import os
warnings.filterwarnings('ignore')

# Matplotlib defaults for publication-quality vector figures.
# Applied in one call so the full set of overrides is visible at a glance.
plt.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'pdf.fonttype': 42,      # TrueType fonts in PDF
    'ps.fonttype': 42,
    'svg.fonttype': 'none',  # Text as text in SVG
})

print("="*70)
print("STROMATOPOROID TURNOVER AND REEF MORPHOLOGY ANALYSIS")
print("WITH PEARSON AND SPEARMAN CORRELATIONS")
print("STAGE-LEVEL AND 5-MYR BIN ANALYSIS")
print("="*70)
print("\nLibraries loaded successfully!")

# Output directory (`os` is already imported with the other imports at the top of this cell).
OUTPUT_DIR = "./output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

STROMATOPOROID TURNOVER AND REEF MORPHOLOGY ANALYSIS
WITH PEARSON AND SPEARMAN CORRELATIONS
STAGE-LEVEL AND 5-MYR BIN ANALYSIS

Libraries loaded successfully!
Output directory: ./output


In [2]:
# =============================================================================
#@title CELL 2: GENERATE MACROSTRAT DATA (paleozoic_stage_data.csv, paleozoic_5myr_data.csv)
# =============================================================================

print("="*70)
print("GENERATING MACROSTRAT DATA")
print("="*70)

# Cached output files
macrostrat_stage_file = 'paleozoic_stage_data.csv'
macrostrat_5myr_file = 'paleozoic_5myr_data.csv'


def _csv_has_rows(path):
    """Return True when `path` exists and holds at least one data row below its header line."""
    if not os.path.exists(path):
        return False
    with open(path, 'r', encoding='utf-8', errors='replace') as handle:
        next(handle, None)  # skip the header line
        return any(line.strip() for line in handle)


# Reuse the cached files only when they actually contain data. The header-only
# placeholder files written after a failed fetch (see the end of this cell) must
# not count as a valid cache, otherwise a single network failure would be
# remembered forever.
if _csv_has_rows(macrostrat_stage_file) and _csv_has_rows(macrostrat_5myr_file):
    print(f"✓ {macrostrat_stage_file} already exists")
    print(f"✓ {macrostrat_5myr_file} already exists")
    print("Skipping Macrostrat data generation...")
else:
    print("Generating Macrostrat data from API...")

    import requests
    try:
        import geopandas as gpd
    except ImportError:
        print("Installing geopandas...")
        import subprocess
        import sys
        # Install into the interpreter that runs this kernel (not whichever
        # `pip` happens to be first on PATH) and fail loudly if it does not work.
        subprocess.run([sys.executable, '-m', 'pip', 'install', 'geopandas', '-q'], check=True)
        import geopandas as gpd

    # Define Paleozoic Period Age Ranges
    periods = {
        "Ordovician": {"start": 485.4, "end": 443.8, "color": "#00a9ce"},
        "Silurian": {"start": 443.8, "end": 419.2, "color": "#b3e1af"},
        "Devonian": {"start": 419.2, "end": 358.9, "color": "#cb8c37"}
    }

    # Define stages as name -> (younger bound, older bound) in Ma.
    # NOTE(review): the Tremadocian/Floian boundary here (478.6) differs from the
    # 477.7 used by the reef and PBDB cells, and "Pridolian" is spelled "Pridoli"
    # there. Left unchanged because downstream use is not visible from this cell;
    # confirm before merging stage tables on name or age.
    ordovician_stages = {
        "Tremadocian": (478.6, 485.4), "Floian": (470.0, 478.6),
        "Dapingian": (467.3, 470.0), "Darriwilian": (458.4, 467.3),
        "Sandbian": (453.0, 458.4), "Katian": (445.2, 453.0),
        "Hirnantian": (443.8, 445.2)
    }
    silurian_stages = {
        "Rhuddanian": (440.8, 443.8), "Aeronian": (438.5, 440.8),
        "Telychian": (433.4, 438.5), "Sheinwoodian": (430.5, 433.4),
        "Homerian": (427.4, 430.5), "Gorstian": (425.6, 427.4),
        "Ludfordian": (423.0, 425.6), "Pridolian": (419.2, 423.0)
    }
    devonian_stages = {
        "Lochkovian": (410.8, 419.2), "Pragian": (407.6, 410.8),
        "Emsian": (393.3, 407.6), "Eifelian": (387.7, 393.3),
        "Givetian": (382.7, 387.7), "Frasnian": (372.2, 382.7),
        "Famennian": (358.9, 372.2)
    }

    # Create stages dataframe (one row per stage, oldest period first)
    stages_data = []
    for period_name, period_stages in (("Ordovician", ordovician_stages),
                                       ("Silurian", silurian_stages),
                                       ("Devonian", devonian_stages)):
        for stage, (end_age, start_age) in period_stages.items():
            stages_data.append({"stage": stage, "start_age": start_age, "end_age": end_age,
                               "mid_age": (start_age + end_age) / 2, "period": period_name})
    stages_df = pd.DataFrame(stages_data)

    # Retrieve Macrostrat Data
    periods_to_fetch = ["Ordovician", "Silurian", "Devonian"]
    all_units_list = []

    for period in periods_to_fetch:
        url = f"https://macrostrat.org/api/units?interval_name={period}&format=geojson&response=long"
        print(f"  Fetching data for {period}...")
        try:
            response = requests.get(url, timeout=60)
            if response.status_code != 200:
                # Previously skipped silently, which hid partial datasets.
                print(f"    WARNING: HTTP {response.status_code} for {period}; skipping")
                continue
            data = response.json()
            features = data.get("success", {}).get("data", [])
            if features:
                period_units = gpd.GeoDataFrame.from_features(features)
                print(f"    Retrieved {len(period_units)} geological units")
                period_units['source_period'] = period
                all_units_list.append(period_units)
            else:
                print(f"    WARNING: no units returned for {period}")
        except Exception as e:
            print(f"    Error fetching {period}: {e}")

    if all_units_list:
        if len(all_units_list) < len(periods_to_fetch):
            print("  WARNING: some periods could not be fetched; the saved files will be incomplete.")

        units = pd.concat(all_units_list, ignore_index=True)
        print(f"  Combined dataset: {len(units)} geological units")

        # Process units: reproject so geometry areas are in square metres.
        reprojected = False
        try:
            if units.crs is None:
                units = units.set_crs(epsg=4326)
            units = units.to_crs(epsg=3857)
            reprojected = True
        except Exception as e:
            print(f"  WARNING: could not reproject units to EPSG:3857: {e}")

        # Prefer the area reported by the API; it does not depend on reprojection,
        # so a CRS failure no longer discards it.
        if 'col_area' in units.columns:
            units['area_km2'] = pd.to_numeric(units['col_area'], errors='coerce')
        elif reprojected:
            units['area_km2'] = units.geometry.area / 1e6
        else:
            print("  WARNING: no 'col_area' column and no usable geometry; using a constant "
                  "100 km2 per unit, so area totals are NOT meaningful.")
            units['area_km2'] = 100  # Default

        units['t_age'] = pd.to_numeric(units['t_age'], errors='coerce')
        units['b_age'] = pd.to_numeric(units['b_age'], errors='coerce')
        units['mid_age'] = (units['t_age'] + units['b_age']) / 2.0
        units.dropna(subset=['mid_age'], inplace=True)

        # Identify carbonates
        def check_if_carbonate(lithologies):
            """True if any lithology dict has 'carbonate' in its 'type', or the string form mentions it."""
            if isinstance(lithologies, list):
                for lith in lithologies:
                    if isinstance(lith, dict) and 'type' in lith and 'carbonate' in str(lith['type']).lower():
                        return True
            elif isinstance(lithologies, str):
                return 'carbonate' in lithologies.lower()
            return False

        units['is_carbonate'] = units['lith'].apply(check_if_carbonate)
        carbonate_units = units[units['is_carbonate']].copy()

        # Assign stages
        all_stages = {**ordovician_stages, **silurian_stages, **devonian_stages}

        def assign_stage(age):
            """Return the first stage whose closed [end, start] range contains `age`, else None.

            An age exactly on a boundary goes to the stage listed first in
            `all_stages` (the older one, given the ordering above).
            """
            for stage, (end, start) in all_stages.items():
                if start >= age >= end:
                    return stage
            return None

        units['stage'] = units['mid_age'].apply(assign_stage)
        carbonate_units['stage'] = carbonate_units['mid_age'].apply(assign_stage)

        # Aggregate by stage
        stage_totals = units.groupby('stage')['area_km2'].sum().reset_index()
        stage_totals.rename(columns={'area_km2': 'total_area_km2'}, inplace=True)
        stage_carbonates = carbonate_units.groupby('stage')['area_km2'].sum().reset_index()
        stage_carbonates.rename(columns={'area_km2': 'carbonate_area_km2'}, inplace=True)

        stage_summary = pd.merge(stage_totals, stage_carbonates, on='stage', how='left')
        stage_summary['carbonate_area_km2'] = stage_summary['carbonate_area_km2'].fillna(0)
        stage_summary['carbonate_percentage'] = (stage_summary['carbonate_area_km2'] / stage_summary['total_area_km2']) * 100

        macrostrat_data = pd.merge(stages_df, stage_summary, on='stage', how='left')
        macrostrat_data = macrostrat_data.sort_values('start_age', ascending=False).reset_index(drop=True)
        macrostrat_data.to_csv(macrostrat_stage_file, index=False)
        print(f"  ✓ Saved {macrostrat_stage_file}")

        # 5 Myr bins: left-closed intervals [355, 360), ..., [485, 490)
        max_age = 490
        min_age = 355
        manual_bins = np.arange(min_age, max_age + 5, 5)

        units['time_bin'] = pd.cut(units['mid_age'], bins=manual_bins, include_lowest=True, right=False)
        carbonate_units['time_bin'] = pd.cut(carbonate_units['mid_age'], bins=manual_bins, include_lowest=True, right=False)

        # observed=False keeps every bin (including empty ones) in the output and
        # pins the behaviour regardless of the pandas default for categorical keys.
        macro_all_5myr = units.groupby('time_bin', observed=False)['area_km2'].sum().reset_index()
        macro_all_5myr.rename(columns={'area_km2': 'total_area_km2'}, inplace=True)
        macro_carb_5myr = carbonate_units.groupby('time_bin', observed=False)['area_km2'].sum().reset_index()
        macro_carb_5myr.rename(columns={'area_km2': 'carbonate_area_km2'}, inplace=True)

        macrostrat_5myr = pd.merge(macro_all_5myr, macro_carb_5myr, on='time_bin', how='left')
        macrostrat_5myr['carbonate_area_km2'] = macrostrat_5myr['carbonate_area_km2'].fillna(0)
        macrostrat_5myr['carbonate_percentage'] = (macrostrat_5myr['carbonate_area_km2'] / macrostrat_5myr['total_area_km2']) * 100
        macrostrat_5myr['bin_mid'] = macrostrat_5myr['time_bin'].apply(lambda x: (x.left + x.right) / 2 if pd.notna(x) else np.nan)
        macrostrat_5myr.to_csv(macrostrat_5myr_file, index=False)
        print(f"  ✓ Saved {macrostrat_5myr_file}")
    else:
        print("  WARNING: Could not fetch Macrostrat data. Creating placeholder files...")
        # Create header-only placeholder files so later reads do not fail on a
        # missing file; the cache check at the top of this cell ignores them.
        pd.DataFrame(columns=['stage', 'total_area_km2', 'carbonate_area_km2', 'carbonate_percentage']).to_csv(macrostrat_stage_file, index=False)
        pd.DataFrame(columns=['bin_mid', 'total_area_km2', 'carbonate_area_km2', 'carbonate_percentage']).to_csv(macrostrat_5myr_file, index=False)

print("✓ Macrostrat data ready")

GENERATING MACROSTRAT DATA
Generating Macrostrat data from API...
  Fetching data for Ordovician...
    Retrieved 2943 geological units
  Fetching data for Silurian...
    Retrieved 1715 geological units
  Fetching data for Devonian...
    Retrieved 2793 geological units
  Combined dataset: 7451 geological units
  ✓ Saved paleozoic_stage_data.csv
  ✓ Saved paleozoic_5myr_data.csv
✓ Macrostrat data ready


In [3]:
# =============================================================================
#@title CELL 3: GENERATE PARED REEF DATA (reef stage and 5myr files)
# =============================================================================

print("\n" + "="*70)
print("GENERATING PARED REEF DATA")
print("="*70)

reef_stage_file = 'ordovician_devonian_reef_data_stage_for_analysis.csv'
reef_5myr_file = 'ordovician_devonian_reef_data_5myr_for_analysis.csv'
pared_source_file = 'PARED_reef_All_numerical.csv'

# Force regeneration to include new variable. Set to False to reuse existing
# output files instead of rebuilding them on every run.
FORCE_REGENERATE_REEF_DATA = True

if (not FORCE_REGENERATE_REEF_DATA) and os.path.exists(reef_stage_file) and os.path.exists(reef_5myr_file):
    print(f"✓ {reef_stage_file} already exists")
    print(f"✓ {reef_5myr_file} already exists")
    print("Skipping PARED reef data generation...")
else:
    # Check if source file exists
    if not os.path.exists(pared_source_file):
        print(f"Source file '{pared_source_file}' not found.")
        print("Please upload it now:")
        try:
            from google.colab import files
        except ImportError:
            raise FileNotFoundError(f"Please place '{pared_source_file}' in the current directory.")
        uploaded_pared = files.upload()
        if not uploaded_pared:
            # A cancelled upload returns an empty mapping; fail with a clear message
            # instead of an IndexError on the lookup below.
            raise FileNotFoundError(f"No file was uploaded; '{pared_source_file}' is required.")
        uploaded_name = list(uploaded_pared.keys())[0]
        if uploaded_name != pared_source_file:
            os.rename(uploaded_name, pared_source_file)

    print(f"Processing {pared_source_file}...")

    # Load PARED data, trying encodings in order. Only decoding failures trigger
    # the next attempt; any other read error propagates unchanged.
    for candidate_encoding in ('utf-8', 'latin-1', 'cp1252'):
        try:
            pared_df = pd.read_csv(pared_source_file, encoding=candidate_encoding)
            break
        except UnicodeDecodeError:
            continue
    else:
        raise ValueError(f"Could not decode '{pared_source_file}' with utf-8, latin-1 or cp1252.")

    # --- Paired thickness/width contrast ---
    # Ensure numeric
    pared_df['thickness'] = pd.to_numeric(pared_df['thickness'], errors='coerce')
    pared_df['width'] = pd.to_numeric(pared_df['width'], errors='coerce')

    # Plain difference of the two columns; NaN if either value is missing.
    # NOTE(review): this equals a log ratio only if 'thickness' and 'width' are
    # already log-scaled in the source file — confirm against the data dictionary.
    pared_df['t_w_log_ratio'] = pared_df['thickness'] - pared_df['width']
    # ---------------------------------------------------------

    def analyze_pared_data(dataframe, bin_definitions, analysis_type_label):
        """Summarise reef variables per time bin using the OVERLAP method.

        A reef contributes to every bin whose (start_ma, end_ma) span its
        (min_ma, max_ma) range strictly overlaps, so one reef can be counted
        in several bins.

        Parameters
        ----------
        dataframe : DataFrame with 'min_ma' and 'max_ma' columns plus any of the
            summarised variables.
        bin_definitions : list of dicts with 'time_identifier', 'start_ma'
            (younger bound) and 'end_ma' (older bound).
        analysis_type_label : str stored in the 'analysis_type' column.

        Returns
        -------
        DataFrame with one row per bin: bin bounds, reef count and, for each
        variable, mean / std / stderr / median / count.
        """
        results = []
        variables = ['thickness', 'width', 'extension', 't_w_log_ratio']

        for bin_def in bin_definitions:
            s_start = bin_def['start_ma']
            s_end = bin_def['end_ma']
            name = bin_def['time_identifier']

            # OVERLAP LOGIC
            mask = (dataframe['min_ma'] < s_end) & (dataframe['max_ma'] > s_start)
            subset = dataframe[mask]

            row_data = {
                'bin_center': (s_start + s_end) / 2.0,
                'start_age': s_start,
                'end_age': s_end,
                'reef_count': len(subset),
            }

            for var in variables:
                # Explicit dtype avoids the object-dtype default of an empty Series.
                data = subset[var].dropna() if var in subset.columns else pd.Series(dtype=float)
                if len(data) > 0:
                    row_data[f'{var}_mean'] = data.mean()
                    row_data[f'{var}_std'] = data.std()
                    row_data[f'{var}_stderr'] = data.sem()
                    row_data[f'{var}_median'] = data.median()
                    row_data[f'{var}_count'] = len(data)
                else:
                    for suffix in ['mean', 'std', 'stderr', 'median']:
                        row_data[f'{var}_{suffix}'] = np.nan
                    row_data[f'{var}_count'] = 0

            row_data['analysis_type'] = analysis_type_label
            row_data['time_identifier'] = name
            row_data['name'] = name
            row_data['start_ma'] = s_start
            row_data['end_ma'] = s_end
            row_data['midpoint_ma'] = (s_start + s_end) / 2.0
            row_data['duration_myr'] = round(s_end - s_start, 1)

            results.append(row_data)

        return pd.DataFrame(results)

    # Stage definitions (PRESERVED FROM ORIGINAL).
    # Here 'start_ma' is the younger bound and 'end_ma' the older bound.
    stages_data = [
        {'time_identifier': 'Famennian', 'start_ma': 358.9, 'end_ma': 372.2},
        {'time_identifier': 'Frasnian', 'start_ma': 372.2, 'end_ma': 382.7},
        {'time_identifier': 'Givetian', 'start_ma': 382.7, 'end_ma': 387.7},
        {'time_identifier': 'Eifelian', 'start_ma': 387.7, 'end_ma': 393.3},
        {'time_identifier': 'Emsian', 'start_ma': 393.3, 'end_ma': 407.6},
        {'time_identifier': 'Pragian', 'start_ma': 407.6, 'end_ma': 410.8},
        {'time_identifier': 'Lochkovian', 'start_ma': 410.8, 'end_ma': 419.2},
        {'time_identifier': 'Pridoli', 'start_ma': 419.2, 'end_ma': 423.0},
        {'time_identifier': 'Ludfordian', 'start_ma': 423.0, 'end_ma': 425.6},
        {'time_identifier': 'Gorstian', 'start_ma': 425.6, 'end_ma': 427.4},
        {'time_identifier': 'Homerian', 'start_ma': 427.4, 'end_ma': 430.5},
        {'time_identifier': 'Sheinwoodian', 'start_ma': 430.5, 'end_ma': 433.4},
        {'time_identifier': 'Telychian', 'start_ma': 433.4, 'end_ma': 438.5},
        {'time_identifier': 'Aeronian', 'start_ma': 438.5, 'end_ma': 440.8},
        {'time_identifier': 'Rhuddanian', 'start_ma': 440.8, 'end_ma': 443.8},
        {'time_identifier': 'Hirnantian', 'start_ma': 443.8, 'end_ma': 445.2},
        {'time_identifier': 'Katian', 'start_ma': 445.2, 'end_ma': 453.0},
        {'time_identifier': 'Sandbian', 'start_ma': 453.0, 'end_ma': 458.4},
        {'time_identifier': 'Darriwilian', 'start_ma': 458.4, 'end_ma': 467.3},
        {'time_identifier': 'Dapingian', 'start_ma': 467.3, 'end_ma': 470.0},
        {'time_identifier': 'Floian', 'start_ma': 470.0, 'end_ma': 477.7},
        {'time_identifier': 'Tremadocian', 'start_ma': 477.7, 'end_ma': 485.4},
    ]

    # 5-Myr bins: 355-360 Ma up to 485-490 Ma
    bins_5myr = []
    for age in range(355, 490, 5):
        bins_5myr.append({
            'time_identifier': f"{age}-{age+5} Ma",
            'start_ma': float(age),
            'end_ma': float(age + 5)
        })

    # Generate stage data
    df_stages = analyze_pared_data(pared_df, stages_data, 'Geological_Stages')
    df_stages.to_csv(reef_stage_file, index=False)
    print(f"  ✓ Generated {reef_stage_file}")

    # Generate 5myr data
    df_5myr = analyze_pared_data(pared_df, bins_5myr, '5_myr_bins')
    df_5myr.to_csv(reef_5myr_file, index=False)
    print(f"  ✓ Generated {reef_5myr_file}")

print("✓ PARED reef data ready")


GENERATING PARED REEF DATA
Source file 'PARED_reef_All_numerical.csv' not found.
Please upload it now:


Saving PARED_reef_All_numerical.csv to PARED_reef_All_numerical.csv
Processing PARED_reef_All_numerical.csv...
  ✓ Generated ordovician_devonian_reef_data_stage_for_analysis.csv
  ✓ Generated ordovician_devonian_reef_data_5myr_for_analysis.csv
✓ PARED reef data ready


In [4]:
# =============================================================================
# @title CELL 4: GENERATE PBDB DIVERSITY AND OCCURRENCE DATA (GENERIC 485.0 BINS)
# =============================================================================
import pandas as pd
import glob
import os
import numpy as np
from google.colab import files

# ==========================================
# 1. SETUP: Create Time References
# ==========================================
# Stage table embedded as CSV text (Ordovician through Devonian, oldest first).
# In these rows start_ma is the OLDER (larger) bound and end_ma the YOUNGER one.
ics_data = """stage,series,period,start_ma,end_ma
Tremadocian,Lower Ordovician,Ordovician,485.4,477.7
Floian,Lower Ordovician,Ordovician,477.7,470
Dapingian,Middle Ordovician,Ordovician,470,467.3
Darriwilian,Middle Ordovician,Ordovician,467.3,458.4
Sandbian,Upper Ordovician,Ordovician,458.4,453
Katian,Upper Ordovician,Ordovician,453,445.2
Hirnantian,Upper Ordovician,Ordovician,445.2,443.8
Rhuddanian,Llandovery,Silurian,443.8,440.8
Aeronian,Llandovery,Silurian,440.8,438.5
Telychian,Llandovery,Silurian,438.5,433.4
Sheinwoodian,Wenlock,Silurian,433.4,430.5
Homerian,Wenlock,Silurian,430.5,427.4
Gorstian,Ludlow,Silurian,427.4,425.6
Ludfordian,Ludlow,Silurian,425.6,423
Pridoli,Pridoli,Silurian,423,419.2
Lochkovian,Lower Devonian,Devonian,419.2,410.8
Pragian,Lower Devonian,Devonian,410.8,407.6
Emsian,Lower Devonian,Devonian,407.6,393.3
Eifelian,Middle Devonian,Devonian,393.3,387.7
Givetian,Middle Devonian,Devonian,387.7,382.7
Frasnian,Upper Devonian,Devonian,382.7,372.2
Famennian,Upper Devonian,Devonian,372.2,358.9"""

# Persist the table to disk; it is read back with pd.read_csv further down in this cell.
with open("ICS_stage_boundaries.csv", "w") as f:
    f.write(ics_data)

# MODIFIED: Start at 485.0 to align with generic grid
def create_5myr_bins(start_ma=485.0, end_ma=358.9, step=5.0):
    """Build fixed-width time bins running from `start_ma` down past `end_ma`.

    Bins are emitted oldest first. The last bin is the first one whose bottom
    lies strictly below `end_ma`, so `end_ma` itself always falls inside a bin
    under the (bin_bottom, bin_top] convention used by `get_bin_from_age`.
    With the defaults this yields 485.0-480.0 ... 360.0-355.0.

    Parameters
    ----------
    start_ma : float
        Top (oldest age) of the first bin.
    end_ma : float
        Youngest age that must be covered.
    step : float
        Bin width in Myr; must be positive.

    Returns
    -------
    DataFrame with columns 'bin_label', 'bin_top', 'bin_bottom'. The frame is
    empty (but keeps these columns) when `start_ma` is younger than `end_ma`.

    Raises
    ------
    ValueError
        If `step` is not positive (the loop could never terminate).
    """
    if step <= 0:
        raise ValueError("step must be positive")

    columns = ['bin_label', 'bin_top', 'bin_bottom']
    bins = []
    top = start_ma
    # Keep stepping down while the bin top has not yet passed below end_ma;
    # the bin that straddles (or touches) end_ma is therefore the last one.
    while top >= end_ma:
        bottom = top - step
        bins.append({
            'bin_label': f"{top:.1f}-{bottom:.1f}",
            'bin_top': top,
            'bin_bottom': bottom,
        })
        top = bottom
    return pd.DataFrame(bins, columns=columns)

bins_5myr_df = create_5myr_bins()
stages_df = pd.read_csv("ICS_stage_boundaries.csv")

print("Created Generic 5-Myr bins (Aligned to 485.0):")
print(bins_5myr_df.head())
print(bins_5myr_df.tail())

# ==========================================
# 2. FILE CHECK
# ==========================================
print("\n--- CHECKING FILE SYSTEM ---")


def _find_pbdb_files(directory='.'):
    """Return the sorted names in `directory` matching 'pbdb_data_*.csv' (case-insensitive).

    Sorting makes the processing order reproducible; os.listdir order is arbitrary.
    """
    return sorted(
        name for name in os.listdir(directory)
        if name.lower().startswith("pbdb_data_") and name.lower().endswith(".csv")
    )


found_files = _find_pbdb_files()

if not found_files:
    print("No 'pbdb_data_*.csv' files found. Please upload your RAW PBDB files now.")
    files.upload()
    # Re-scan after the upload using the same matching rule.
    found_files = _find_pbdb_files()

print(f"Found {len(found_files)} files: {found_files}")

# ==========================================
# 3. HELPER FUNCTIONS
# ==========================================

def smart_read_pbdb(file_path):
    """Read a PBDB download, skipping any metadata preamble above the column header.

    The header is the first of the file's first 50 lines that contains
    "occurrence_no".

    Parameters
    ----------
    file_path : str
        Path to the raw PBDB CSV file.

    Returns
    -------
    DataFrame, or None if the header line is not found or the file cannot be
    read (a message is printed in both cases).
    """
    try:
        header_row = None
        with open(file_path, 'r', encoding='utf-8', errors='replace') as handle:
            # zip() stops after 50 lines or at end of file, whichever comes first.
            for line_no, line in zip(range(50), handle):
                if "occurrence_no" in line:
                    header_row = line_no
                    break
        if header_row is None:
            print(f"    CRITICAL: Could not find 'occurrence_no' header in {file_path}")
            return None
        # `header_row` is a physical line index. `skiprows` counts physical lines,
        # whereas `header=` ignores blank lines, so a blank line in the preamble
        # would make `header=header_row` pick a data row as the header.
        return pd.read_csv(file_path, skiprows=header_row)
    except Exception as e:
        print(f"    Error reading {file_path}: {e}")
        return None

def extract_genus(accepted_name):
    """Return the first whitespace-delimited token of `accepted_name`.

    Returns None for missing values and for names that are empty or
    whitespace-only, so such rows can be removed with dropna() instead of
    being counted as a genus named ''.
    """
    if pd.isna(accepted_name):
        return None
    tokens = str(accepted_name).split()
    return tokens[0] if tokens else None

def get_stage_from_age(age, stages_df):
    """Return the stage whose (end_ma, start_ma] range contains `age`, else None.

    `stages_df` needs 'stage', 'start_ma' (older bound) and 'end_ma' (younger
    bound) columns. An age equal (within 0.001) to the youngest 'end_ma' in the
    table would otherwise match nothing, so it is assigned to the stage that
    ends there — located by value, not by row position, so row order does not
    matter.
    """
    match = stages_df[(stages_df['start_ma'] >= age) & (stages_df['end_ma'] < age)]
    if not match.empty:
        return match.iloc[0]['stage']

    youngest_end = stages_df['end_ma'].min()
    if abs(age - youngest_end) < 0.001:
        youngest_pos = stages_df['end_ma'].to_numpy().argmin()
        return stages_df.iloc[youngest_pos]['stage']
    return None

def get_bin_from_age(age, bins_df, snap_tolerance=1.5):
    """Return the label of the bin whose (bin_bottom, bin_top] range contains `age`.

    Ages slightly older than the oldest bin top (by less than `snap_tolerance`
    Myr, e.g. 485.4 against a grid starting at 485.0) are snapped into the
    oldest bin. The oldest bin is found by value, so row order does not matter.

    Parameters
    ----------
    age : float
    bins_df : DataFrame with 'bin_label', 'bin_top' and 'bin_bottom' columns.
    snap_tolerance : float
        Maximum overshoot above the oldest bin top that is still snapped
        (default 1.5, the value previously hard-coded).

    Returns
    -------
    The matching 'bin_label', or None when the age falls outside every bin.
    """
    # Strict containment
    match = bins_df[(bins_df['bin_top'] >= age) & (bins_df['bin_bottom'] < age)]
    if not match.empty:
        return match.iloc[0]['bin_label']

    # Tolerance/snap for the oldest points
    oldest_top = bins_df['bin_top'].max()
    if age > oldest_top and (age - oldest_top) < snap_tolerance:
        oldest_pos = bins_df['bin_top'].to_numpy().argmax()
        return bins_df.iloc[oldest_pos]['bin_label']
    return None

def process_midpoint(df, group_name, reference_df, ref_type="stage"):
    """Count genera and occurrences per interval using each occurrence's age midpoint.

    Every occurrence is placed in the single interval containing
    (max_ma + min_ma) / 2. The input frame is not modified.

    Parameters
    ----------
    df : DataFrame
        Occurrences with 'accepted_name', 'max_ma', 'min_ma' and
        'occurrence_no' columns.
    group_name : str
        Prefix for the two output columns ('<group>_genus', '<group>_occ').
    reference_df : DataFrame
        Stage table (with a 'stage' column) or bin table (with 'bin_label').
    ref_type : str
        "stage" to assign with get_stage_from_age; anything else assigns with
        get_bin_from_age.

    Returns
    -------
    A copy of `reference_df` with the two integer count columns added
    (0 where an interval has no occurrences), or None when 'accepted_name'
    is missing.
    """
    if 'accepted_name' not in df.columns:
        print(f"    Error: 'accepted_name' column missing.")
        return None

    if ref_type == "stage":
        assign_interval, merge_col = get_stage_from_age, 'stage'
    else:
        assign_interval, merge_col = get_bin_from_age, 'bin_label'

    # NOTE(review): the first token of accepted_name is treated as a genus even
    # when the occurrence may be identified only to a higher rank — confirm
    # whether a rank filter is wanted.
    # .assign() builds new frames, so the caller's frame is left untouched and
    # no column is written onto a filtered view.
    occurrences = df.assign(genus_name=df['accepted_name'].apply(extract_genus))
    occurrences = occurrences.dropna(subset=['genus_name'])

    # Midpoint logic; non-numeric ages become NaN and are dropped below because
    # they match no interval.
    midpoint = (pd.to_numeric(occurrences['max_ma'], errors='coerce')
                + pd.to_numeric(occurrences['min_ma'], errors='coerce')) / 2
    occurrences = occurrences.assign(
        midpoint=midpoint,
        assigned_interval=midpoint.apply(lambda age: assign_interval(age, reference_df)),
    )
    occurrences = occurrences.dropna(subset=['assigned_interval'])

    # Aggregation
    grouped = occurrences.groupby('assigned_interval')
    genus_counts = grouped['genus_name'].nunique()
    genus_counts.name = f'{group_name}_genus'
    occ_counts = grouped['occurrence_no'].nunique()
    occ_counts.name = f'{group_name}_occ'

    # Merging
    final_df = reference_df.copy()
    final_df = final_df.merge(genus_counts, left_on=merge_col, right_index=True, how='left')
    final_df = final_df.merge(occ_counts, left_on=merge_col, right_index=True, how='left')

    # Cleanup: intervals without occurrences get 0 rather than NaN
    cols_to_fix = [f'{group_name}_genus', f'{group_name}_occ']
    final_df[cols_to_fix] = final_df[cols_to_fix].fillna(0).astype(int)

    return final_df

# ==========================================
# 4. MAIN EXECUTION LOOP
# ==========================================
print("\n--- STARTING ANALYSIS (Generic 485.0 Bins) ---")

# File discovery matches the prefix/suffix case-insensitively, so strip them by
# length rather than with a case-sensitive str.replace().
_PBDB_PREFIX_LEN = len("pbdb_data_")
_PBDB_SUFFIX_LEN = len(".csv")

for file_path in found_files:
    filename = os.path.basename(file_path)
    group_name = filename[_PBDB_PREFIX_LEN:-_PBDB_SUFFIX_LEN]
    print(f"\nAnalyzing {group_name}...")

    df = smart_read_pbdb(file_path)
    if df is not None:
        # Best effort per file: report the problem and continue with the next group.
        try:
            # A. Stages
            stage_df = process_midpoint(df.copy(), group_name, stages_df, "stage")
            if stage_df is not None:
                out_stage = f"pbdb_{group_name}_midpoint_stages.csv"
                stage_df.to_csv(out_stage, index=False)
                print(f"  -> Created: {out_stage}")

            # B. 5-Myr Bins
            bin_df = process_midpoint(df.copy(), group_name, bins_5myr_df, "bin")
            if bin_df is not None:
                out_bin = f"pbdb_{group_name}_midpoint_5myr_bins.csv"
                bin_df.to_csv(out_bin, index=False)
                print(f"  -> Created: {out_bin}")
        except Exception as e:
            print(f"  Error: {e}")

print("\n" + "="*40)
print("DONE! New 485.0-aligned bin files created.")
print("="*40)

Created Generic 5-Myr bins (Aligned to 485.0):
     bin_label  bin_top  bin_bottom
0  485.0-480.0    485.0       480.0
1  480.0-475.0    480.0       475.0
2  475.0-470.0    475.0       470.0
3  470.0-465.0    470.0       465.0
4  465.0-460.0    465.0       460.0
      bin_label  bin_top  bin_bottom
21  380.0-375.0    380.0       375.0
22  375.0-370.0    375.0       370.0
23  370.0-365.0    370.0       365.0
24  365.0-360.0    365.0       360.0
25  360.0-355.0    360.0       355.0

--- CHECKING FILE SYSTEM ---
No 'pbdb_data_*.csv' files found. Please upload your RAW PBDB files now.


Saving pbdb_data_Actinostromatida.csv to pbdb_data_Actinostromatida.csv
Saving pbdb_data_Amphiporida.csv to pbdb_data_Amphiporida.csv
Saving pbdb_data_Clathrodictyida.csv to pbdb_data_Clathrodictyida.csv
Saving pbdb_data_Labechiida.csv to pbdb_data_Labechiida.csv
Saving pbdb_data_Rugosa.csv to pbdb_data_Rugosa.csv
Saving pbdb_data_Stromatoporellida.csv to pbdb_data_Stromatoporellida.csv
Saving pbdb_data_Stromatoporida.csv to pbdb_data_Stromatoporida.csv
Saving pbdb_data_Syringostromatida.csv to pbdb_data_Syringostromatida.csv
Saving pbdb_data_Tabulata.csv to pbdb_data_Tabulata.csv
Found 9 files: ['pbdb_data_Clathrodictyida.csv', 'pbdb_data_Labechiida.csv', 'pbdb_data_Rugosa.csv', 'pbdb_data_Amphiporida.csv', 'pbdb_data_Syringostromatida.csv', 'pbdb_data_Tabulata.csv', 'pbdb_data_Actinostromatida.csv', 'pbdb_data_Stromatoporellida.csv', 'pbdb_data_Stromatoporida.csv']

--- STARTING ANALYSIS (Generic 485.0 Bins) ---

Analyzing Clathrodictyida...
  -> Created: pbdb_Clathrodictyida_midpoin

In [5]:
# =============================================================================
#@title CELL 5: UPLOAD ENVIRONMENT DATA FILES (Google Colab) - CONDITIONAL
# =============================================================================

from google.colab import files
import io

# Define required files
required_files = [
    'temperature.csv',
    'DO.csv',
    'oxygen.csv',
    'sealevel.csv',
    'd13C_5Myr_Cam-Dev.csv',
    'd13C_stage_binned_Cam-Dev.csv'
]
# Check which files are missing
missing_files = [f for f in required_files if not os.path.exists(f)]

if missing_files:
    print("The following files are missing and need to be uploaded:")
    for i, f in enumerate(missing_files, 1):
        print(f"  {i}. {f}")
    print("\nClick 'Choose Files' and select the missing files...")

    uploaded = files.upload()

    print(f"\n✓ Uploaded {len(uploaded)} files:")
    for fn in uploaded.keys():
        print(f"  - {fn}")
else:
    print("✓ All required files already exist in the folder")
    uploaded = {}
    # Load existing files into uploaded dict for compatibility
    for f in required_files:
        with open(f, 'rb') as file:
            uploaded[f] = file.read()

# Verify on disk instead of assuming the upload supplied everything: a cancelled
# or partial upload (or a renamed file) would otherwise be reported as ready.
still_missing = [name for name in required_files if not os.path.exists(name)]
if still_missing:
    print(f"\nWARNING: {len(still_missing)} required file(s) are still missing:")
    for name in still_missing:
        print(f"  - {name}")
else:
    print(f"\n✓ {len(required_files)} files ready for analysis")


The following files are missing and need to be uploaded:
  1. temperature.csv
  2. DO.csv
  3. oxygen.csv
  4. sealevel.csv
  5. d13C_5Myr_Cam-Dev.csv
  6. d13C_stage_binned_Cam-Dev.csv

Click 'Choose Files' and select the missing files...


Saving d13C_5Myr_Cam-Dev.csv to d13C_5Myr_Cam-Dev.csv
Saving d13C_stage_binned_Cam-Dev.csv to d13C_stage_binned_Cam-Dev.csv
Saving DO.csv to DO.csv
Saving oxygen.csv to oxygen.csv
Saving sealevel.csv to sealevel.csv
Saving temperature.csv to temperature.csv

✓ Uploaded 6 files:
  - d13C_5Myr_Cam-Dev.csv
  - d13C_stage_binned_Cam-Dev.csv
  - DO.csv
  - oxygen.csv
  - sealevel.csv
  - temperature.csv

✓ 6 files ready for analysis


In [6]:
# =============================================================================
# @title CELL 6: LOAD AND PROCESS DATA (FROM CONTENT FOLDER)
# =============================================================================
import pandas as pd
import os

# Helper: Find file path by keywords in the current directory
def get_file_path_robust(keywords, search_dir='.'):
    """
    Finds a filename in the search_dir that contains ALL keywords (case-insensitive).

    Only names ending in .csv, .xlsx or .xls (any letter case) are considered.
    Candidates are examined in sorted order, so when several files match the
    same one is returned on every run.

    Parameters
    ----------
    keywords : iterable
        Substrings that must all occur in the filename; each is converted with str().
    search_dir : str
        Directory to scan (not recursive).

    Returns
    -------
    str path (search_dir joined with the filename) of the first match, or None
    when nothing matches or the directory does not exist.
    """
    try:
        candidates = sorted(os.listdir(search_dir))
    except FileNotFoundError:
        print(f"  ! Error: Directory '{search_dir}' not found.")
        return None

    wanted = [str(k).lower() for k in keywords]  # lower-case once, not per file
    for filename in candidates:
        lowered = filename.lower()
        if not lowered.endswith(('.csv', '.xlsx', '.xls')):
            continue
        # Check if ALL keywords are present in this filename
        if all(k in lowered for k in wanted):
            return os.path.join(search_dir, filename)
    return None

def load_and_merge_from_disk(target_list, merge_cols, dataset_name):
    """
    Locate each target table on disk and outer-merge them into one DataFrame.

    For every name in ``target_list`` the extension and the ``pbdb`` token are
    removed and the remaining underscore-separated parts are used as search
    keywords for ``get_file_path_robust``. When that search fails, a second,
    looser search is made with only the taxon keyword plus the resolution
    ('5myr' or 'stages').

    Parameters
    ----------
    target_list : list of str
        Expected file names, e.g. "pbdb_Rugosa_midpoint_stages.csv".
    merge_cols : list of str
        Key columns used for the outer merge between tables.
    dataset_name : str
        Label used only in progress messages.

    Returns
    -------
    pandas.DataFrame or None
        The merged table, or None when no file could be loaded.
    """
    merged_df = None
    print(f"\nProcessing {dataset_name}...")

    def _read_and_merge(path, current):
        """Read one table and merge it into ``current``; on any failure report it and return ``current`` unchanged."""
        try:
            # The finder also returns Excel files, which read_csv cannot parse
            if path.lower().endswith(('.xlsx', '.xls')):
                df = pd.read_excel(path)
            else:
                df = pd.read_csv(path)
            df.columns = df.columns.str.strip()

            if current is None:
                return df
            return pd.merge(current, df, on=merge_cols, how='outer')
        except Exception as e:
            print(f"  ! Error reading {path}: {e}")
            return current

    for target in target_list:
        # Extract keywords from the target filename
        clean_name = target.replace('.csv', '').replace('.xlsx', '')
        keywords = [k for k in clean_name.split('_') if k and k.lower() != 'pbdb']

        # Search for the file
        filepath = get_file_path_robust(keywords)
        if filepath:
            print(f"  ✓ Found: {filepath}")
            merged_df = _read_and_merge(filepath, merged_df)
            continue

        # Retry with minimal keywords (Taxon + Resolution)
        # This handles cases where the user filename might differ slightly from the instruction
        taxon = next((k for k in keywords if k.lower() not in ['midpoint', 'stages', '5myr', 'bins']), None)
        resolution = '5myr' if '5myr' in target.lower() else 'stages'

        if not taxon:
            print(f"  x Could not find file matching: {keywords}")
            continue

        filepath_retry = get_file_path_robust([taxon, resolution])
        if filepath_retry:
            print(f"  ✓ Found (fallback): {filepath_retry}")
            merged_df = _read_and_merge(filepath_retry, merged_df)
        else:
            print(f"  x Could not find file for: {taxon} ({resolution})")

    return merged_df

# Common merge columns: the key columns on which the per-taxon tables are
# outer-merged (stage-level tables and 5-Myr-bin tables respectively).
stage_merge_cols = ['stage', 'series', 'period', 'start_ma', 'end_ma']
bin_merge_cols = ['bin_label', 'bin_top', 'bin_bottom']

# =============================================================================
# 1. Stromatoporoid Data (Stage): merge the per-order tables into one frame
# =============================================================================
# File names are search hints, not exact paths: load_and_merge_from_disk
# matches them on disk by keyword, case-insensitively.
strom_stage_targets = [
    "pbdb_Stromatoporida_midpoint_stages.csv",
    "pbdb_Labechiida_midpoint_stages.csv",
    "pbdb_Actinostromatida_midpoint_stages.csv",
    "pbdb_clathrodictyida_midpoint_stages.csv",
    "pbdb_Syringostromatida_midpoint_stages.csv",
    "pbdb_Stromatoporellida_midpoint_stages.csv",
    "pbdb_Amphiporida_midpoint_stages.csv"
]

strom_df = load_and_merge_from_disk(strom_stage_targets, stage_merge_cols, "Stromatoporoid (Stage)")

if strom_df is not None:
    # The outer merge leaves NaN where a table has no row for a key combination.
    # NOTE: fillna(0) is applied to every column of the merged frame, not only counts.
    strom_df = strom_df.fillna(0)
    # Totals are row-wise sums over all columns ending in '_genus' / '_occ'
    genus_cols = [c for c in strom_df.columns if c.endswith('_genus')]
    occ_cols = [c for c in strom_df.columns if c.endswith('_occ')]
    strom_df['Total_genus'] = strom_df[genus_cols].sum(axis=1)
    strom_df['Total_occ'] = strom_df[occ_cols].sum(axis=1)

    # === FILTER: keep only stages with at least one occurrence ===
    # Count the rows about to be removed first, so the number can be reported.
    dropped_s = len(strom_df[strom_df['Total_occ'] == 0])
    strom_df = strom_df[strom_df['Total_occ'] > 0].copy()
    if dropped_s > 0:
        print(f"  -> Dropped {dropped_s} Stromatoporoid stage(s) with zero occurrences.")
    # ===================================

    print(f"  -> Merged Stromatoporoid Stages. Rows: {len(strom_df)}")
else:
    # Reached when load_and_merge_from_disk returned None (no table could be loaded)
    print("  ! Error: Stromatoporoid dataframe is empty.")

# =============================================================================
# 2. Coral Data (Stage): merge the Tabulata and Rugosa tables into one frame
# =============================================================================
coral_stage_targets = [
    "pbdb_tabulata_midpoint_stages.csv",
    "pbdb_Rugosa_midpoint_stages.csv"
]

coral_df = load_and_merge_from_disk(coral_stage_targets, stage_merge_cols, "Coral (Stage)")

# NOTE: unlike the stromatoporoid section there is no else-branch here, so a
# failed load (coral_df is None) produces no message at this point.
if coral_df is not None:
    # fillna(0) is applied to every column of the merged frame
    coral_df = coral_df.fillna(0)

    # Totals are row-wise sums over all '_genus' / '_occ' columns; Total_occ drives the filter below
    c_genus_cols = [c for c in coral_df.columns if c.endswith('_genus')]
    c_occ_cols = [c for c in coral_df.columns if c.endswith('_occ')]
    coral_df['Total_genus'] = coral_df[c_genus_cols].sum(axis=1)
    coral_df['Total_occ'] = coral_df[c_occ_cols].sum(axis=1)

    # === FILTER: keep only stages with at least one occurrence ===
    dropped_c = len(coral_df[coral_df['Total_occ'] == 0])
    coral_df = coral_df[coral_df['Total_occ'] > 0].copy()
    if dropped_c > 0:
        print(f"  -> Dropped {dropped_c} Coral stage(s) with zero occurrences.")
    # ===================================

    print(f"  -> Merged Coral Stages. Rows: {len(coral_df)}")

# =============================================================================
# 3. 5-Myr Bin Datasets: same merge / totals / filter steps as the stage-level
#    sections, but joined on bin_merge_cols instead of stage_merge_cols
# =============================================================================
# Stromatoporoids
strom_bin_targets = [
    "pbdb_Stromatoporida_midpoint_5myr_bins.csv",
    "pbdb_Labechiida_midpoint_5myr_bins.csv",
    "pbdb_Actinostromatida_midpoint_5myr_bins.csv",
    "pbdb_clathrodictyida_midpoint_5myr_bins.csv",
    "pbdb_Syringostromatida_midpoint_5myr_bins.csv",
    "pbdb_Stromatoporellida_midpoint_5myr_bins.csv",
    "pbdb_Amphiporida_midpoint_5myr_bins.csv"
]

strom_5myr_df = load_and_merge_from_disk(strom_bin_targets, bin_merge_cols, "Stromatoporoid (5-Myr)")

# strom_5myr_df stays None when no table could be loaded; nothing is printed in that case
if strom_5myr_df is not None:
    # fillna(0) is applied to every column of the merged frame
    strom_5myr_df = strom_5myr_df.fillna(0)
    # NOTE: genus_cols / occ_cols are rebound here, replacing the stage-level lists of the same name
    genus_cols = [c for c in strom_5myr_df.columns if c.endswith('_genus')]
    occ_cols = [c for c in strom_5myr_df.columns if c.endswith('_occ')]
    strom_5myr_df['Total_genus'] = strom_5myr_df[genus_cols].sum(axis=1)
    strom_5myr_df['Total_occ'] = strom_5myr_df[occ_cols].sum(axis=1)

    # === FILTER: keep only bins with at least one occurrence ===
    dropped_s5 = len(strom_5myr_df[strom_5myr_df['Total_occ'] == 0])
    strom_5myr_df = strom_5myr_df[strom_5myr_df['Total_occ'] > 0].copy()
    if dropped_s5 > 0:
        print(f"  -> Dropped {dropped_s5} Stromatoporoid 5-Myr bin(s) with zero occurrences.")
    # =================================

    print(f"  -> Merged Stromatoporoid 5-Myr Bins. Rows: {len(strom_5myr_df)}")

# Corals
coral_bin_targets = [
    "pbdb_tabulata_midpoint_5myr_bins.csv",
    "pbdb_Rugosa_midpoint_5myr_bins.csv"
]
coral_5myr_df = load_and_merge_from_disk(coral_bin_targets, bin_merge_cols, "Coral (5-Myr)")

if coral_5myr_df is not None:
    # fillna(0) is applied to every column of the merged frame
    coral_5myr_df = coral_5myr_df.fillna(0)

    # Totals are row-wise sums over all '_genus' / '_occ' columns; Total_occ drives the filter below
    c5_genus_cols = [c for c in coral_5myr_df.columns if c.endswith('_genus')]
    c5_occ_cols = [c for c in coral_5myr_df.columns if c.endswith('_occ')]
    coral_5myr_df['Total_genus'] = coral_5myr_df[c5_genus_cols].sum(axis=1)
    coral_5myr_df['Total_occ'] = coral_5myr_df[c5_occ_cols].sum(axis=1)

    # === FILTER: keep only bins with at least one occurrence ===
    dropped_c5 = len(coral_5myr_df[coral_5myr_df['Total_occ'] == 0])
    coral_5myr_df = coral_5myr_df[coral_5myr_df['Total_occ'] > 0].copy()
    if dropped_c5 > 0:
        print(f"  -> Dropped {dropped_c5} Coral 5-Myr bin(s) with zero occurrences.")
    # =================================

    print(f"  -> Merged Coral 5-Myr Bins. Rows: {len(coral_5myr_df)}")

# =============================================================================
# 4. Load Remaining Contextual Data
# =============================================================================
print("\nLoading Contextual Data...")

def quick_load(keywords):
    """
    Load the file located by ``get_file_path_robust(keywords)`` with ``pd.read_csv``.

    Returns an empty DataFrame when no file matches the keywords.
    """
    matched_path = get_file_path_robust(keywords)
    if not matched_path:
        return pd.DataFrame()
    return pd.read_csv(matched_path)

# Reef tables (stage-level and 5-Myr-bin); quick_load returns an empty frame when no file matches
reef_df = quick_load(["reef_data", "stage"])
reef_5myr_df = quick_load(["reef_data", "5myr"])

macro_stage = quick_load(["paleozoic_stage_data"])
# Rename 'Pridolian' to 'Pridoli', the spelling used as a key in the STAGES table (Cell 7)
if 'stage' in macro_stage.columns: macro_stage['stage'] = macro_stage['stage'].replace('Pridolian', 'Pridoli')

macro_5myr = quick_load(["paleozoic_5myr_data"])

# Environmental tables: result name -> filename keyword passed to quick_load.
# NOTE(review): keywords are matched as case-insensitive substrings of the
# file name, so a short keyword such as 'DO' matches any data file whose
# name contains "do" — confirm that no such file sits in the working folder.
env_files = {
    'temperature': 'temperature',
    'do': 'DO',
    'oxygen': 'oxygen',
    'sealevel': 'sealevel',
    # δ13C files (already binned; DO NOT re-bin/interpolate these)
    'd13c_5myr': 'd13C_5Myr_Cam-Dev',
    'd13c_stage': 'd13C_stage_binned_Cam-Dev'
}
env_dfs = {}
for var, key in env_files.items():
    df = quick_load([key])
    # Clean header names: trim whitespace and remove byte-order-mark characters
    if not df.empty: df.columns = df.columns.str.strip().str.replace('\ufeff', '')
    env_dfs[var] = df

# Expose each table under its own name; the default is never used because every key was set above
temp_df = env_dfs.get('temperature', pd.DataFrame())
do_df = env_dfs.get('do', pd.DataFrame())
oxygen_df = env_dfs.get('oxygen', pd.DataFrame())
sealevel_df = env_dfs.get('sealevel', pd.DataFrame())
d13c_5myr_df = env_dfs.get('d13c_5myr', pd.DataFrame())
d13c_stage_df = env_dfs.get('d13c_stage', pd.DataFrame())
print("\n✓ Data loading complete.")

print("\nSaving intermediate merged datasets...")

# Write the merged/filtered frames to OUTPUT_DIR; a frame that failed to load (None) is skipped
if 'strom_df' in locals() and strom_df is not None:
    strom_df.to_csv(f"{OUTPUT_DIR}/intermediate_strom_stage.csv", index=False)
if 'coral_df' in locals() and coral_df is not None:
    coral_df.to_csv(f"{OUTPUT_DIR}/intermediate_coral_stage.csv", index=False)

if 'strom_5myr_df' in locals() and strom_5myr_df is not None:
    strom_5myr_df.to_csv(f"{OUTPUT_DIR}/intermediate_strom_5myr.csv", index=False)
if 'coral_5myr_df' in locals() and coral_5myr_df is not None:
    coral_5myr_df.to_csv(f"{OUTPUT_DIR}/intermediate_coral_5myr.csv", index=False)

# NOTE: the message hard-codes "./output"; the files are written to OUTPUT_DIR
print("✓ Intermediate files saved to ./output")



Processing Stromatoporoid (Stage)...
  ✓ Found: ./pbdb_Stromatoporida_midpoint_stages.csv
  ✓ Found: ./pbdb_Labechiida_midpoint_stages.csv
  ✓ Found: ./pbdb_Actinostromatida_midpoint_stages.csv
  ✓ Found: ./pbdb_Clathrodictyida_midpoint_stages.csv
  ✓ Found: ./pbdb_Syringostromatida_midpoint_stages.csv
  ✓ Found: ./pbdb_Stromatoporellida_midpoint_stages.csv
  ✓ Found: ./pbdb_Amphiporida_midpoint_stages.csv
  -> Dropped 1 Stromatoporoid stage(s) with zero occurrences.
  -> Merged Stromatoporoid Stages. Rows: 21

Processing Coral (Stage)...
  ✓ Found: ./pbdb_Tabulata_midpoint_stages.csv
  ✓ Found: ./pbdb_Rugosa_midpoint_stages.csv
  -> Dropped 1 Coral stage(s) with zero occurrences.
  -> Merged Coral Stages. Rows: 21

Processing Stromatoporoid (5-Myr)...
  ✓ Found: ./pbdb_Stromatoporida_midpoint_5myr_bins.csv
  ✓ Found: ./pbdb_Labechiida_midpoint_5myr_bins.csv
  ✓ Found: ./pbdb_Actinostromatida_midpoint_5myr_bins.csv
  ✓ Found: ./pbdb_Clathrodictyida_midpoint_5myr_bins.csv
  ✓ Found: ./

In [7]:
# =============================================================================
# @title CELL 7: DEFINE CONSTANTS AND STAGE INFORMATION
# =============================================================================

# 1. Stage definitions (ICS 2023)
#    One entry per stage, listed from oldest (Tremadocian) to youngest (Famennian).
#    'start' / 'end' are the stage boundaries and 'mid' is their arithmetic mean,
#    (start + end) / 2; Cell 9 copies them into the start_ma / end_ma / midpoint_ma columns.
STAGES = {
    'Tremadocian': {'start': 485.4, 'end': 477.7, 'mid': 481.55, 'period': 'Ordovician'},
    'Floian': {'start': 477.7, 'end': 470.0, 'mid': 473.85, 'period': 'Ordovician'},
    'Dapingian': {'start': 470.0, 'end': 467.3, 'mid': 468.65, 'period': 'Ordovician'},
    'Darriwilian': {'start': 467.3, 'end': 458.4, 'mid': 462.85, 'period': 'Ordovician'},
    'Sandbian': {'start': 458.4, 'end': 453.0, 'mid': 455.7, 'period': 'Ordovician'},
    'Katian': {'start': 453.0, 'end': 445.2, 'mid': 449.1, 'period': 'Ordovician'},
    'Hirnantian': {'start': 445.2, 'end': 443.8, 'mid': 444.5, 'period': 'Ordovician'},
    'Rhuddanian': {'start': 443.8, 'end': 440.8, 'mid': 442.3, 'period': 'Silurian'},
    'Aeronian': {'start': 440.8, 'end': 438.5, 'mid': 439.65, 'period': 'Silurian'},
    'Telychian': {'start': 438.5, 'end': 433.4, 'mid': 435.95, 'period': 'Silurian'},
    'Sheinwoodian': {'start': 433.4, 'end': 430.5, 'mid': 431.95, 'period': 'Silurian'},
    'Homerian': {'start': 430.5, 'end': 427.4, 'mid': 428.95, 'period': 'Silurian'},
    'Gorstian': {'start': 427.4, 'end': 425.6, 'mid': 426.5, 'period': 'Silurian'},
    'Ludfordian': {'start': 425.6, 'end': 423.0, 'mid': 424.3, 'period': 'Silurian'},
    'Pridoli': {'start': 423.0, 'end': 419.2, 'mid': 421.1, 'period': 'Silurian'},
    'Lochkovian': {'start': 419.2, 'end': 410.8, 'mid': 415.0, 'period': 'Devonian'},
    'Pragian': {'start': 410.8, 'end': 407.6, 'mid': 409.2, 'period': 'Devonian'},
    'Emsian': {'start': 407.6, 'end': 393.3, 'mid': 400.45, 'period': 'Devonian'},
    'Eifelian': {'start': 393.3, 'end': 387.7, 'mid': 390.5, 'period': 'Devonian'},
    'Givetian': {'start': 387.7, 'end': 382.7, 'mid': 385.2, 'period': 'Devonian'},
    'Frasnian': {'start': 382.7, 'end': 372.2, 'mid': 377.45, 'period': 'Devonian'},
    'Famennian': {'start': 372.2, 'end': 358.9, 'mid': 365.55, 'period': 'Devonian'}
}
# Stage names in dict insertion order, i.e. oldest to youngest
STAGE_ORDER = list(STAGES.keys())

# 2. Period colors (ICS standard), keyed by the 'period' values used in STAGES
PERIOD_COLORS = {
    'Ordovician': '#009270',
    'Silurian': '#B3E1B6',
    'Devonian': '#CB8C37'
}

# 3. Stromatoporoid order colors (phylogenetically informed)
#    Keys double as the column prefixes '<Order>_genus' / '<Order>_occ' of the merged tables.
STROM_COLORS = {
    'Labechiida': '#8B0000',       # Dark red - basal
    'Clathrodictyida': '#CD5C5C',  # Indian red - early-diverging
    'Actinostromatida': '#FF8C00', # Dark orange - derived
    'Stromatoporida': '#FFD700',   # Gold - derived
    'Stromatoporellida': '#32CD32',# Lime green - derived
    'Syringostromatida': '#4169E1',# Royal blue - derived reef builders
    'Amphiporida': '#9370DB'       # Medium purple - derived
}
# Same seven orders as the STROM_COLORS keys, as an explicit ordered list
STROM_ORDERS = ['Labechiida', 'Clathrodictyida', 'Actinostromatida',
                'Stromatoporida', 'Stromatoporellida', 'Syringostromatida', 'Amphiporida']

# 4. Coral colors (New definitions for Rugosa/Tabulata)
CORAL_COLORS = {
    'Rugosa': '#800080',    # Purple
    'Tabulata': '#D2691E'   # Chocolate/Orange-Brown
}
CORAL_GROUPS = ['Rugosa', 'Tabulata']

# =============================================================================
# DATA NORMALIZATION: Ensure DataFrame columns match capitalized constants
# =============================================================================
# Some files were lowercase (e.g., 'clathrodictyida'), but constants are TitleCase.
# We fix this here to prevent KeyErrors in future plotting cells.

def normalize_columns(df, target_orders):
    """
    Rename per-taxon count columns so their capitalization matches ``target_orders``.

    For every name in ``target_orders`` and each suffix ('_genus', '_occ'):
    if the exactly-cased column (e.g. 'Clathrodictyida_genus') is absent but a
    column differing only in letter case exists (e.g. 'clathrodictyida_genus'),
    that column is renamed to the exact spelling. When the exact spelling is
    already present nothing is renamed, so no duplicate labels are created.

    Parameters
    ----------
    df : pandas.DataFrame or None
        Table to normalize; None is passed through unchanged.
    target_orders : iterable of str
        Canonical taxon names (column prefixes).

    Returns
    -------
    pandas.DataFrame or None
        ``df`` itself when nothing needs renaming, otherwise a renamed copy.
    """
    if df is None: return df

    # Get current columns
    cols = df.columns.tolist()

    # Lower-cased spelling -> first column carrying it, for case-insensitive lookup
    by_lower = {}
    for c in cols:
        by_lower.setdefault(str(c).lower(), c)

    rename_map = {}
    for order in target_orders:
        for suffix in ('_genus', '_occ'):
            wanted = f"{order}{suffix}"
            if wanted in cols:
                continue  # canonical column already present; renaming would duplicate it

            lowered = wanted.lower()
            # Prefer the all-lowercase variant, then any other capitalization
            source = lowered if lowered in cols else by_lower.get(lowered)
            if source is not None:
                rename_map[source] = wanted

    if rename_map:
        print(f"  Note: Renaming columns to match standard capitalization: {list(rename_map.keys())}")
        df = df.rename(columns=rename_map)
    return df

# Apply normalization to the datasets loaded in Cell 6.
# The `in locals()` guards skip a frame whose name was never defined (e.g. when
# Cell 6 was not run); normalize_columns itself passes None through unchanged.
if 'strom_df' in locals(): strom_df = normalize_columns(strom_df, STROM_ORDERS)
if 'strom_5myr_df' in locals(): strom_5myr_df = normalize_columns(strom_5myr_df, STROM_ORDERS)
if 'coral_df' in locals(): coral_df = normalize_columns(coral_df, CORAL_GROUPS)
if 'coral_5myr_df' in locals(): coral_5myr_df = normalize_columns(coral_5myr_df, CORAL_GROUPS)

print("✓ Constants defined and DataFrames normalized.")

✓ Constants defined and DataFrames normalized.


In [8]:
# =============================================================================
# @title CELL 8: RELOAD ENV DATA & INTERPOLATE (ROBUST FIX)
# =============================================================================
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
import os

# -----------------------------------------------------------------------------
# 1. RELOAD ENVIRONMENTAL DATA (To ensure it is not empty)
# -----------------------------------------------------------------------------
print("Reloading environmental datasets...")

def load_env_file(filenames, target_cols):
    """
    Load the first readable CSV from a list of candidate paths.

    Parameters
    ----------
    filenames : list of str
        Candidate paths, tried in order; the first one that exists and
        parses is returned.
    target_cols : list of str
        Columns the caller expects. Kept for call compatibility; this
        function does not use or validate it.

    Returns
    -------
    pandas.DataFrame
        The loaded table with cleaned header names, or an empty DataFrame
        when no candidate exists or every existing candidate fails to parse.
    """
    for fname in filenames:
        if os.path.exists(fname):
            try:
                df = pd.read_csv(fname)
                # Remove byte-order-mark characters and surrounding whitespace from
                # header names, as the other loaders in this notebook do
                df.columns = df.columns.str.replace('\ufeff', '', regex=False).str.strip()
                print(f"  ✓ Loaded {fname} ({len(df)} rows)")
                return df
            except Exception as e:
                # Report and fall through to the next candidate
                print(f"  ! Error loading {fname}: {e}")
    print(f"  ! WARNING: Could not find any of {filenames}")
    return pd.DataFrame() # Return empty if not found

# Load with specific fallbacks: each call tries the listed file names in order.
# These assignments replace the temp_df / do_df / oxygen_df / sealevel_df frames
# created in Cell 6. The second argument is not used by load_env_file.
temp_df     = load_env_file(['temperature.csv', 'Temperature.csv'], ['Age', 'SST'])
do_df       = load_env_file(['DO.csv', 'do.csv'], ['Age', 'DO'])
oxygen_df   = load_env_file(['oxygen.csv', 'Oxygen.csv'], ['Age', 'O2'])
sealevel_df = load_env_file(['sealevel.csv', 'Sealevel.csv'], ['Age', 'Eustatic'])

# -----------------------------------------------------------------------------
# 2. STANDARDIZE COLUMNS
# -----------------------------------------------------------------------------
def standardize_env_columns(df, name, target_age='Age', target_val=None):
    """
    Return ``df`` with its age and value columns renamed to the expected labels.

    * Age: when ``target_age`` is absent, the first of 'age', 'AGE', 'Time',
      'Ma', 'time' found among the columns is renamed to ``target_age``.
    * Value: when ``target_val`` is given and absent, the first column equal
      to it ignoring case is renamed to ``target_val``.

    An empty frame is returned as-is. ``name`` is accepted but not used.
    """
    if df.empty:
        return df

    # Age column: adopt the first recognised alias if the canonical label is missing
    if target_age not in df.columns:
        age_alias = next(
            (alias for alias in ('age', 'AGE', 'Time', 'Ma', 'time') if alias in df.columns),
            None,
        )
        if age_alias is not None:
            df = df.rename(columns={age_alias: target_age})

    # Value column: case-insensitive lookup of the canonical label
    if target_val and target_val not in df.columns:
        wanted = target_val.lower()
        value_alias = next((col for col in df.columns if col.lower() == wanted), None)
        if value_alias is not None:
            df = df.rename(columns={value_alias: target_val})

    return df

# Bring the age/value headers of each proxy table to the labels used by the
# interpolation calls later in this cell ('Age', 'DO', 'SST', 'Eustatic Sea Level').
# The second argument is only a label; standardize_env_columns does not use it.
do_df = standardize_env_columns(do_df, 'Dissolved Oxygen', target_val='DO')
temp_df = standardize_env_columns(temp_df, 'Temperature', target_val='SST')
# No target_val here, so only the age column of oxygen_df is standardized
oxygen_df = standardize_env_columns(oxygen_df, 'Atmosphere', target_age='Age')
sealevel_df = standardize_env_columns(sealevel_df, 'Sea Level', target_val='Eustatic Sea Level')

# -----------------------------------------------------------------------------
# 3. INTERPOLATION
# -----------------------------------------------------------------------------
def interpolate_to_ages(source_df, age_col, value_col, target_ages, extrapolate=True):
    """
    Linearly interpolate ``value_col`` of ``source_df`` at the given ages.

    Parameters
    ----------
    source_df : pandas.DataFrame
        Table holding the proxy record.
    age_col, value_col : str
        Names of the x (age) and y (value) columns.
    target_ages : array-like
        Ages at which values are wanted.
    extrapolate : bool, default True
        True keeps the original behaviour: targets outside the source age
        range are linearly extrapolated. False returns NaN for them instead.

    Returns
    -------
    numpy.ndarray
        Float array shaped like ``target_ages``. It is all NaN when the table
        is empty, a column is missing, or fewer than two complete rows remain.
    """
    # Work in float so that NaN is representable even for integer targets
    # (np.full_like on an integer array cannot hold NaN).
    target = np.asarray(target_ages, dtype=float)

    # Check if empty or missing columns
    if source_df.empty or age_col not in source_df.columns or value_col not in source_df.columns:
        return np.full(target.shape, np.nan)

    # Drop NaNs in source
    source_df = source_df.dropna(subset=[age_col, value_col])
    # Sort by Age
    source_df = source_df.sort_values(age_col)

    if len(source_df) < 2:
        return np.full(target.shape, np.nan)

    # Interpolate
    f = interp1d(source_df[age_col], source_df[value_col],
                 kind='linear',
                 fill_value='extrapolate' if extrapolate else np.nan,
                 bounds_error=False)
    return f(target)

# Get Stage Midpoints from Cell 7 constants (same order as STAGE_ORDER)
stage_midpoints = np.array([STAGES[s]['mid'] for s in STAGE_ORDER])

print("\nInterpolating environmental proxies to stage midpoints...")
# One row per stage; each proxy below adds one column evaluated at the stage midpoint
env_data = pd.DataFrame({'stage': STAGE_ORDER, 'midpoint_ma': stage_midpoints})

# interpolate_to_ages yields an all-NaN column when the source table is empty
# or lacks the named columns, so a missing proxy does not stop the cell.
env_data['temperature']     = interpolate_to_ages(temp_df, 'Age', 'SST', stage_midpoints)
env_data['dissolved_O2']    = interpolate_to_ages(do_df, 'Age', 'DO', stage_midpoints)
# Both atmospheric columns are read from oxygen_df ('O2' and 'CO2')
env_data['atm_O2']          = interpolate_to_ages(oxygen_df, 'Age', 'O2', stage_midpoints)
env_data['atm_CO2']         = interpolate_to_ages(oxygen_df, 'Age', 'CO2', stage_midpoints)
env_data['sea_level']       = interpolate_to_ages(sealevel_df, 'Age', 'Eustatic Sea Level', stage_midpoints)
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# MERGE δ13C (STAGE-BINNED) — prefer Stage-name join; fallback to age-nearest
# Uses attached d13C_stage_binned_Cam-Dev.csv WITHOUT re-binning/conversion.
# -----------------------------------------------------------------------------
import re
from pathlib import Path

def _norm_stage(s):
    if pd.isna(s):
        return np.nan
    s = str(s).strip().lower()
    # remove spaces/punct to survive minor naming differences
    return re.sub(r'[^a-z0-9]+', '', s)

try:
    # Load attached stage-binned δ13C table
    cand = [
        Path("d13C_stage_binned_Cam-Dev.csv"),
        Path("/mnt/data/d13C_stage_binned_Cam-Dev.csv")
    ]
    f = next((p for p in cand if p.exists()), None)
    if f is None:
        raise FileNotFoundError("d13C_stage_binned_Cam-Dev.csv not found in working dir or /mnt/data")

    d13_stage = pd.read_csv(f)
    d13_stage.columns = d13_stage.columns.str.strip().str.replace('\ufeff', '')

    # Required columns in attached file
    if "Stage" not in d13_stage.columns or "Mid_Ma" not in d13_stage.columns or "d13C_mean" not in d13_stage.columns:
        raise ValueError("δ13C stage file must have columns: Stage, Mid_Ma, d13C_mean")

    # Clean numeric + stage keys.
    # Ages are forced to float on both sides: merge_asof rejects key columns of
    # different dtypes (e.g. int64 ages in the file vs float64 midpoints).
    d13_stage["Mid_Ma"] = pd.to_numeric(d13_stage["Mid_Ma"], errors="coerce").astype(float)
    d13_stage["d13C_mean"] = pd.to_numeric(d13_stage["d13C_mean"], errors="coerce")
    d13_stage["stage_key"] = d13_stage["Stage"].apply(_norm_stage)

    env_data["midpoint_ma"] = pd.to_numeric(env_data["midpoint_ma"], errors="coerce").astype(float)

    # --- 1) Stage-name merge (best) ---
    if "stage" in env_data.columns:
        env_data["stage_key"] = env_data["stage"].apply(_norm_stage)

        _m1 = env_data.merge(
            d13_stage[["stage_key", "d13C_mean", "Mid_Ma"]],
            on="stage_key",
            how="left",
            suffixes=("", "_d13")
        )

        # Keep δ13C as a single downstream name
        _m1["d13C"] = _m1["d13C_mean"]

        # --- 2) Fallback: for unmatched stages, fill by age-nearest ---
        missing = _m1["d13C"].isna() & _m1["midpoint_ma"].notna()
        if missing.any():
            # merge_asof needs ascending keys and returns a fresh 0..n-1 index, so
            # each row's original label travels along in "_row" and the result is
            # written back by label. (Writing the sorted values back by position
            # would attach them to the wrong stages.)
            _left = (
                _m1.loc[missing, ["midpoint_ma"]]
                .rename_axis("_row")
                .reset_index()
                .sort_values("midpoint_ma")
            )
            _right = d13_stage[["Mid_Ma", "d13C_mean"]].dropna().sort_values("Mid_Ma")

            _fill = pd.merge_asof(
                _left,
                _right,
                left_on="midpoint_ma",
                right_on="Mid_Ma",
                direction="nearest",
                tolerance=2.0  # stage midpoints can differ by >0.25; allow reasonable slack
            )
            _m1.loc[_fill["_row"].to_numpy(), "d13C"] = _fill["d13C_mean"].to_numpy()

        # Cleanup
        env_data = _m1.drop(columns=[c for c in ["d13C_mean", "Mid_Ma", "stage_key"] if c in _m1.columns])
        env_data = env_data.sort_values("midpoint_ma", ascending=False).reset_index(drop=True)

    else:
        # If env_data has no stage column, do age-nearest only (more tolerant)
        _left = env_data.dropna(subset=["midpoint_ma"]).sort_values("midpoint_ma")
        _right = d13_stage[["Mid_Ma", "d13C_mean"]].dropna().sort_values("Mid_Ma")
        _m = pd.merge_asof(_left, _right, left_on="midpoint_ma", right_on="Mid_Ma", direction="nearest", tolerance=2.0)
        env_data = _m.drop(columns=["Mid_Ma"]).rename(columns={"d13C_mean": "d13C"})
        env_data = env_data.sort_values("midpoint_ma", ascending=False).reset_index(drop=True)

    n_valid = int(env_data["d13C"].notna().sum()) if "d13C" in env_data.columns else 0
    print(f"  -> δ13C (stage-binned) merged: {n_valid}/{len(env_data)} values")

except Exception as e:
    # Best effort: keep the other proxies and leave δ13C empty. Also remove the
    # temporary join key in case the failure happened after it was added.
    env_data = env_data.drop(columns=["stage_key"], errors="ignore")
    env_data["d13C"] = np.nan
    print("  ! Warning: δ13C stage merge failed:", e)


# -----------------------------------------------------------------------------
# 4. INTERPOLATE TO 5-MYR BINS
# -----------------------------------------------------------------------------
print("\nInterpolating environmental proxies to 5-Myr bin midpoints...")

# Choose the bin midpoints, in order of preference:
#   1) the reef 5-Myr table, 2) the stromatoporoid 5-Myr table, 3) a fixed grid.
# Both tables are checked for None before `.empty` is read: Cell 6 leaves
# strom_5myr_df as None when no stromatoporoid bin file could be loaded.
if 'reef_5myr_df' in locals() and reef_5myr_df is not None and not reef_5myr_df.empty:
    bin_midpoints_5myr = reef_5myr_df['midpoint_ma'].values
    bin_ids = reef_5myr_df['time_identifier']
elif 'strom_5myr_df' in locals() and strom_5myr_df is not None and not strom_5myr_df.empty:
    # Recalculate if column missing (this adds the column to strom_5myr_df itself)
    if 'midpoint_ma' not in strom_5myr_df.columns:
        strom_5myr_df['midpoint_ma'] = (strom_5myr_df['bin_top'] + strom_5myr_df['bin_bottom']) / 2
    bin_midpoints_5myr = strom_5myr_df['midpoint_ma'].values
    bin_ids = strom_5myr_df['bin_label']
else:
    # Fallback: midpoints 482.5, 477.5, ... down to 362.5, labelled by their own value
    bin_midpoints_5myr = np.arange(482.5, 359, -5)
    bin_ids = [f"{x}" for x in bin_midpoints_5myr]

# One row per bin; each proxy below adds one column evaluated at the bin midpoint
env_data_5myr = pd.DataFrame({'bin_id': bin_ids, 'midpoint_ma': bin_midpoints_5myr})

env_data_5myr['temperature']     = interpolate_to_ages(temp_df, 'Age', 'SST', bin_midpoints_5myr)
env_data_5myr['dissolved_O2']    = interpolate_to_ages(do_df, 'Age', 'DO', bin_midpoints_5myr)
env_data_5myr['atm_O2']          = interpolate_to_ages(oxygen_df, 'Age', 'O2', bin_midpoints_5myr)
env_data_5myr['atm_CO2']         = interpolate_to_ages(oxygen_df, 'Age', 'CO2', bin_midpoints_5myr)
env_data_5myr['sea_level']       = interpolate_to_ages(sealevel_df, 'Age', 'Eustatic Sea Level', bin_midpoints_5myr)
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# 4b. MERGE δ13C (5-MYR BINNED) WITHOUT INTERPOLATION / RE-BINNING
#   Float bin midpoints often differ by tiny rounding; if exact merge yields
#   few/zero matches, fall back to merge_asof (requires ascending sort).
# -----------------------------------------------------------------------------
try:
    # Prefer the table already loaded in Cell 6; otherwise look for it on disk
    if 'd13c_5myr_df' in globals() and isinstance(d13c_5myr_df, pd.DataFrame) and (not d13c_5myr_df.empty):
        _d13_5 = d13c_5myr_df.copy()
    else:
        cand = [
            'd13C_5Myr_Cam-Dev.csv',
            './output/d13C_5Myr_Cam-Dev.csv',
            'd13C_5Myr.csv',
            './output/d13C_5Myr.csv'
        ]
        _d13_5 = None
        for p in cand:
            if os.path.exists(p):
                _d13_5 = pd.read_csv(p)
                break
        if _d13_5 is None:
            raise FileNotFoundError("No 5-Myr δ13C file found (tried: " + ", ".join(cand) + ").")

    _d13_5.columns = _d13_5.columns.str.strip().str.replace('\ufeff', '')

    # Standardize columns: accept the known alternative headers for age and value
    if 'age_Ma' in _d13_5.columns and 'midpoint_ma' not in _d13_5.columns:
        _d13_5 = _d13_5.rename(columns={'age_Ma': 'midpoint_ma'})
    if 'Mid_Ma' in _d13_5.columns and 'midpoint_ma' not in _d13_5.columns:
        _d13_5 = _d13_5.rename(columns={'Mid_Ma': 'midpoint_ma'})
    if 'd13Ccarb_permille' in _d13_5.columns and 'd13C' not in _d13_5.columns:
        _d13_5 = _d13_5.rename(columns={'d13Ccarb_permille': 'd13C'})
    if 'd13C_mean' in _d13_5.columns and 'd13C' not in _d13_5.columns:
        _d13_5 = _d13_5.rename(columns={'d13C_mean': 'd13C'})

    _d13_5['midpoint_ma'] = pd.to_numeric(_d13_5['midpoint_ma'], errors='coerce')
    _d13_5['d13C'] = pd.to_numeric(_d13_5['d13C'], errors='coerce')
    _d13_5 = _d13_5.dropna(subset=['midpoint_ma', 'd13C']).copy()

    # --- Attempt exact merge after rounding to reduce floating mismatch
    env_data_5myr['midpoint_ma'] = pd.to_numeric(env_data_5myr['midpoint_ma'], errors='coerce')
    # Drop any earlier d13C column so the merge cannot produce d13C_x / d13C_y
    _left = env_data_5myr.drop(columns=['d13C'], errors='ignore').copy()
    _left['midpoint_ma_round'] = _left['midpoint_ma'].round(3)
    _right = _d13_5[['midpoint_ma', 'd13C']].copy()
    _right['midpoint_ma_round'] = _right['midpoint_ma'].round(3)

    env_data_5myr = _left.merge(_right[['midpoint_ma_round', 'd13C']], on='midpoint_ma_round', how='left')
    env_data_5myr = env_data_5myr.drop(columns=['midpoint_ma_round'])

    n_valid = env_data_5myr['d13C'].notna().sum()

    # --- Fallback: nearest-age join for every bin the exact merge left empty.
    # It runs for partial as well as total mismatches. Only the unmatched rows
    # are filled, by row label, so no row is dropped and the row order of
    # env_data_5myr is left untouched.
    _unmatched = env_data_5myr['d13C'].isna() & env_data_5myr['midpoint_ma'].notna()
    if _unmatched.any() and len(_d13_5) > 0:
        # merge_asof needs ascending keys of identical dtype and returns a new
        # 0..n-1 index; "_row" carries each row's original label through it.
        _left_sorted = (
            env_data_5myr.loc[_unmatched, ['midpoint_ma']]
            .astype(float)
            .rename_axis('_row')
            .reset_index()
            .sort_values('midpoint_ma', ascending=True)
        )
        _right_sorted = (
            _d13_5[['midpoint_ma', 'd13C']]
            .astype(float)
            .sort_values('midpoint_ma', ascending=True)
            .reset_index(drop=True)
        )

        _m = pd.merge_asof(
            _left_sorted,
            _right_sorted,
            on='midpoint_ma',
            direction='nearest',
            tolerance=2.6  # ~half of 5-Myr bin width
        )
        env_data_5myr.loc[_m['_row'].to_numpy(), 'd13C'] = _m['d13C'].to_numpy()
        n_valid = env_data_5myr['d13C'].notna().sum()

    print(f"  -> δ13C (5-Myr binned) merged: {n_valid}/{len(env_data_5myr)} values")

except Exception as e:
    # Best effort: keep the other proxies and leave δ13C empty
    env_data_5myr['d13C'] = np.nan
    print("  ! Warning: δ13C 5-Myr merge failed:", e)

print("✓ Environmental proxies interpolated (Stage & 5-Myr).")

# =============================================================================
# SAVE INTERPOLATED ENV DATA
# Writes the stage-level and 5-Myr environmental tables to OUTPUT_DIR (Cell 1).
# A table is skipped when its name is undefined or the frame is empty.
# =============================================================================
if 'env_data' in locals() and not env_data.empty:
    env_data.to_csv(f"{OUTPUT_DIR}/intermediate_env_data_stage.csv", index=False)

if 'env_data_5myr' in locals() and not env_data_5myr.empty:
    env_data_5myr.to_csv(f"{OUTPUT_DIR}/intermediate_env_data_5myr.csv", index=False)

# NOTE: printed unconditionally, even when both writes above were skipped
print(f"✓ Interpolated environmental data saved to {OUTPUT_DIR}")


Reloading environmental datasets...
  ✓ Loaded temperature.csv (387 rows)
  ✓ Loaded DO.csv (499 rows)
  ✓ Loaded oxygen.csv (58 rows)
  ✓ Loaded sealevel.csv (101 rows)

Interpolating environmental proxies to stage midpoints...
  -> δ13C (stage-binned) merged: 22/22 values

Interpolating environmental proxies to 5-Myr bin midpoints...
  -> δ13C (5-Myr binned) merged: 27/27 values
✓ Environmental proxies interpolated (Stage & 5-Myr).
✓ Interpolated environmental data saved to ./output


In [9]:
# =============================================================================
# @title CELL 9: CREATE MASTER DATASET (STAGES AND 5-MYR BINS)
# =============================================================================
import numpy as np
import pandas as pd

# -----------------------------------------------------------------------------
# 1. BUILD MASTER STAGE DATASET
# -----------------------------------------------------------------------------
print("Building Master Stage Dataset...")
# One dict per stage is collected here and turned into the master DataFrame
# after the loop. Keys are inserted in a fixed order (stage info, reef,
# stromatoporoid, coral, Macrostrat, environment), which determines the column
# order of the resulting frame and of the CSV written later in this cell.
data = []

for stage, info in STAGES.items():
    # Stage identity and time span copied from the STAGES lookup.
    row = {
        'stage': stage,
        'midpoint_ma': info['mid'],
        'start_ma': info['start'],
        'end_ma': info['end'],
        'period': info['period']
    }

    # A. Reef data (if available)
    # Reef rows are matched on exact equality of reef_df['name'] with the stage
    # key (case-sensitive, unlike sections B and C below).
    if 'reef_df' in locals() and reef_df is not None and not reef_df.empty:
        reef_row = reef_df[reef_df['name'] == stage]
        if len(reef_row) > 0:
            for col in [
                'thickness_mean', 'thickness_std', 'thickness_stderr', 'thickness_median',
                'thickness_min', 'thickness_max', 'thickness_q25', 'thickness_q75', 'thickness_count',
                'width_mean', 'width_std', 'width_median', 'width_min', 'width_max',
                'reef_count'
            ]:
                # Only columns present in reef_df are copied; `.values[0]`
                # takes the first matching row if there is more than one.
                if col in reef_row.columns:
                    row[col] = reef_row[col].values[0]

    # B. Stromatoporoid data (from strom_df)
    if 'strom_df' in locals() and strom_df is not None and not strom_df.empty:
        # robust stage matching (case-insensitive comparison of stage names)
        s_stage = strom_df['stage'].astype(str).str.lower()
        strom_row = strom_df[s_stage == str(stage).lower()]
        if len(strom_row) > 0:
            # Per-order occurrence and genus counts keep their source names.
            for order in STROM_ORDERS:
                col_occ = f'{order}_occ'
                col_gen = f'{order}_genus'
                if col_occ in strom_row.columns:
                    row[col_occ] = strom_row[col_occ].values[0]
                if col_gen in strom_row.columns:
                    row[col_gen] = strom_row[col_gen].values[0]

            # Totals are renamed: Total_occ -> strom_total_occ,
            # Total_genus -> strom_total_gen.
            if 'Total_occ' in strom_row.columns:
                row['strom_total_occ'] = strom_row['Total_occ'].values[0]
            if 'Total_genus' in strom_row.columns:
                row['strom_total_gen'] = strom_row['Total_genus'].values[0]

    # C. Coral data (from coral_df)
    if 'coral_df' in locals() and coral_df is not None and not coral_df.empty:
        # Same case-insensitive stage matching as for stromatoporoids.
        c_stage = coral_df['stage'].astype(str).str.lower()
        coral_row = coral_df[c_stage == str(stage).lower()]
        if len(coral_row) > 0:
            # Rugosa (genus count stored as *_div, occurrences as *_occ)
            if 'Rugosa_genus' in coral_row.columns: row['rugose_div'] = coral_row['Rugosa_genus'].values[0]
            if 'Rugosa_occ' in coral_row.columns:   row['rugose_occ'] = coral_row['Rugosa_occ'].values[0]

            # Tabulata
            if 'Tabulata_genus' in coral_row.columns: row['tabulate_div'] = coral_row['Tabulata_genus'].values[0]
            if 'Tabulata_occ' in coral_row.columns:   row['tabulate_occ'] = coral_row['Tabulata_occ'].values[0]

    # D. Macrostrat data
    # Matched on exact stage name. The three area columns are read without an
    # existence check, so a missing column raises KeyError.
    if 'macro_stage' in locals() and macro_stage is not None and not macro_stage.empty:
        macro_row = macro_stage[macro_stage['stage'] == stage]
        if len(macro_row) > 0:
            row['total_area_km2'] = macro_row['total_area_km2'].values[0]
            row['carbonate_area_km2'] = macro_row['carbonate_area_km2'].values[0]
            row['carbonate_percentage'] = macro_row['carbonate_percentage'].values[0]

    # E. Environmental proxies
    # Matched on exact stage name. Only d13C is treated as optional; the other
    # five proxy columns are read unconditionally (KeyError if absent).
    if 'env_data' in locals() and env_data is not None and not env_data.empty:
        env_row = env_data[env_data['stage'] == stage]
        if len(env_row) > 0:
            row['temperature'] = env_row['temperature'].values[0]
            row['dissolved_O2'] = env_row['dissolved_O2'].values[0]
            row['atm_O2'] = env_row['atm_O2'].values[0]
            row['atm_CO2'] = env_row['atm_CO2'].values[0]
            row['sea_level'] = env_row['sea_level'].values[0]
            if 'd13C' in env_row.columns:
                row['d13C'] = env_row['d13C'].values[0]

    data.append(row)

# Keys missing from a given stage's dict become NaN in that row.
df = pd.DataFrame(data)

# Calculate Stromatoporoid Proportions
if 'strom_total_occ' in df.columns:
    # Denominator shared by every proportion below: missing totals count as 0,
    # and rows with a zero total get a proportion of 0.0 instead of a division.
    _total_occ = df['strom_total_occ'].fillna(0)
    _has_stroms = _total_occ > 0

    # Per-order share of all stromatoporoid occurrences.
    for order in STROM_ORDERS:
        col_occ = f'{order}_occ'
        if col_occ not in df.columns:
            continue
        df[f'{order}_prop'] = np.where(_has_stroms, df[col_occ].fillna(0) / _total_occ, 0.0)

    # Derived vs Basal
    derived_orders = ['Actinostromatida', 'Stromatoporida', 'Stromatoporellida',
                      'Syringostromatida', 'Amphiporida']
    basal_orders = ['Labechiida', 'Clathrodictyida']

    def _zero_filled_sum(order_names, suffix):
        """Row-wise sum of the available '<order>_<suffix>' columns (NaN counted as 0)."""
        present = [f'{o}_{suffix}' for o in order_names if f'{o}_{suffix}' in df.columns]
        return sum(df[c].fillna(0) for c in present)

    # Occurrences
    df['derived_strom_occ'] = _zero_filled_sum(derived_orders, 'occ')
    df['basal_strom_occ'] = _zero_filled_sum(basal_orders, 'occ')

    # Diversity
    df['derived_strom_div'] = _zero_filled_sum(derived_orders, 'genus')
    df['basal_strom_div'] = _zero_filled_sum(basal_orders, 'genus')

    # Proportions
    df['derived_strom_prop'] = np.where(_has_stroms, df['derived_strom_occ'] / _total_occ, 0.0)
    df['basal_strom_prop'] = np.where(_has_stroms, df['basal_strom_occ'] / _total_occ, 0.0)

# [ADDED] Every missing NUMERIC value becomes 0; text columns are left alone.
# NOTE(review): this also zero-fills reef and environmental columns for stages
# with no source row, so downstream statistics see 0 rather than "no data" --
# confirm this is intended.
num_cols = df.select_dtypes(include=[np.number]).columns
df[num_cols] = df[num_cols].fillna(0)

# Oldest stage first (largest midpoint age), with a fresh 0..n-1 index.
df = df.sort_values('midpoint_ma', ascending=False)
df = df.reset_index(drop=True)
print(f"✓ Master dataset (STAGES) created: {len(df)} stages, {len(df.columns)} variables")


# -----------------------------------------------------------------------------
# 2. BUILD MASTER 5-MYR DATASET
# -----------------------------------------------------------------------------
print("\n" + "="*70)
print("CREATING 5-MYR BIN MASTER DATASET")
print("="*70)

# Determine the primary source for 5-Myr bins
# We prefer reef data if available, otherwise we use the biological data
# The chosen frame supplies the list of bins (one output row per input row).
if 'reef_5myr_df' in locals() and reef_5myr_df is not None and not reef_5myr_df.empty:
    # No copy is taken here on purpose: the row loop further down tests
    # `primary_bins is not reef_5myr_df` to decide whether a reef lookup is
    # needed, which relies on both names referring to the same object.
    primary_bins = reef_5myr_df
    print("Using Reef Data as primary 5-Myr bin source.")
elif 'strom_5myr_df' in locals() and strom_5myr_df is not None and not strom_5myr_df.empty:
    primary_bins = strom_5myr_df
    # Add midpoint if missing
    # (copied first so strom_5myr_df itself is not modified). If neither
    # midpoint_ma nor bin_top/bin_bottom exist, the row loop below will raise
    # KeyError when it reads row_ref['midpoint_ma'].
    if 'midpoint_ma' not in primary_bins.columns and {'bin_top', 'bin_bottom'}.issubset(primary_bins.columns):
        primary_bins = primary_bins.copy()
        primary_bins['midpoint_ma'] = (primary_bins['bin_top'] + primary_bins['bin_bottom']) / 2
    print("Using Stromatoporoid Data as primary 5-Myr bin source.")
else:
    # Fallback to env_data_5myr if bio data missing
    # (env_data_5myr is referenced without an existence check here).
    primary_bins = env_data_5myr
    print("Using Environmental Data as primary 5-Myr bin source.")

# One dict per 5-Myr bin; key insertion order fixes the column order of
# df_5myr (bin info, reef, stromatoporoid, coral, Macrostrat, environment).
data_5myr = []

# Iterate through the chosen primary bins
for idx, row_ref in primary_bins.iterrows():
    # Identify bin
    # (`in` on a Series tests its index labels, i.e. the source column names);
    # the first of these three labels that exists is used, else the row index.
    if 'time_identifier' in row_ref:
        bin_id = row_ref['time_identifier']
    elif 'bin_label' in row_ref:
        bin_id = row_ref['bin_label']
    elif 'bin_id' in row_ref:
        bin_id = row_ref['bin_id']
    else:
        bin_id = idx  # fallback

    # All other sources are joined to this bin by midpoint age.
    midpoint = row_ref['midpoint_ma']

    row = {
        'bin_id': bin_id,
        'midpoint_ma': midpoint
    }

    # A. Reef Data (if available)
    if 'reef_5myr_df' in locals() and reef_5myr_df is not None and not reef_5myr_df.empty:
        # Find matching reef row (if not already iterating it)
        if primary_bins is not reef_5myr_df:
            # Midpoints within 0.1 of each other count as the same bin; an
            # empty Series stands in when nothing matches so that the column
            # loop below copies nothing.
            reef_match = reef_5myr_df[abs(reef_5myr_df['midpoint_ma'] - midpoint) < 0.1]
            row_ref_for_reef = reef_match.iloc[0] if len(reef_match) > 0 else pd.Series(dtype='float64')
        else:
            row_ref_for_reef = row_ref

        # Fewer reef columns are carried over here than in the stage dataset.
        for col in [
            'thickness_mean', 'thickness_std', 'thickness_stderr', 'thickness_median',
            'thickness_count', 'width_mean', 'width_std', 'reef_count'
        ]:
            if col in row_ref_for_reef.index:
                row[col] = row_ref_for_reef[col]

    # B. Stromatoporoid Data (5-Myr)
    # Requires bin_top/bin_bottom in strom_5myr_df; without them this section
    # is skipped silently. The midpoint series is recomputed on every pass.
    if 'strom_5myr_df' in locals() and strom_5myr_df is not None and not strom_5myr_df.empty:
        if {'bin_top', 'bin_bottom'}.issubset(strom_5myr_df.columns):
            strom_mid = (strom_5myr_df['bin_top'] + strom_5myr_df['bin_bottom']) / 2
            strom_match = strom_5myr_df[abs(strom_mid - midpoint) < 0.1]
            if len(strom_match) > 0:
                # First match wins if several rows fall inside the tolerance.
                s_row = strom_match.iloc[0]
                for order in STROM_ORDERS:
                    if f'{order}_occ' in s_row.index: row[f'{order}_occ'] = s_row[f'{order}_occ']
                    if f'{order}_genus' in s_row.index: row[f'{order}_genus'] = s_row[f'{order}_genus']
                if 'Total_occ' in s_row.index: row['strom_total_occ'] = s_row['Total_occ']
                if 'Total_genus' in s_row.index: row['strom_total_gen'] = s_row['Total_genus']

    # C. Coral Data (5-Myr)
    # Same bin_top/bin_bottom requirement and 0.1 tolerance as section B.
    if 'coral_5myr_df' in locals() and coral_5myr_df is not None and not coral_5myr_df.empty:
        if {'bin_top', 'bin_bottom'}.issubset(coral_5myr_df.columns):
            coral_mid = (coral_5myr_df['bin_top'] + coral_5myr_df['bin_bottom']) / 2
            coral_match = coral_5myr_df[abs(coral_mid - midpoint) < 0.1]
            if len(coral_match) > 0:
                c_row = coral_match.iloc[0]
                if 'Rugosa_occ' in c_row.index: row['rugose_occ'] = c_row['Rugosa_occ']
                if 'Rugosa_genus' in c_row.index: row['rugose_div'] = c_row['Rugosa_genus']
                if 'Tabulata_occ' in c_row.index: row['tabulate_occ'] = c_row['Tabulata_occ']
                if 'Tabulata_genus' in c_row.index: row['tabulate_div'] = c_row['Tabulata_genus']

    # D. Macrostrat Data
    # Uses a much wider tolerance (2.5) than the other joins (0.1) and takes the
    # first row inside it, not necessarily the nearest one.
    # NOTE(review): if macro_5myr bins are spaced closer than 5 Myr, more than
    # one row can qualify -- confirm the bin spacing.
    if 'macro_5myr' in locals() and macro_5myr is not None and not macro_5myr.empty and 'bin_mid' in macro_5myr.columns:
        macro_match = macro_5myr[abs(macro_5myr['bin_mid'] - midpoint) < 2.5]
        if len(macro_match) > 0:
            row['total_area_km2'] = macro_match['total_area_km2'].values[0]
            row['carbonate_area_km2'] = macro_match['carbonate_area_km2'].values[0]
            row['carbonate_percentage'] = macro_match['carbonate_percentage'].values[0]

    # E. Environmental Proxies
    # Only d13C is optional; the other five proxies are read unconditionally.
    if 'env_data_5myr' in locals() and env_data_5myr is not None and not env_data_5myr.empty:
        env_match = env_data_5myr[abs(env_data_5myr['midpoint_ma'] - midpoint) < 0.1]
        if len(env_match) > 0:
            env_val = env_match.iloc[0]
            row['temperature'] = env_val['temperature']
            row['dissolved_O2'] = env_val['dissolved_O2']
            row['atm_O2'] = env_val['atm_O2']
            row['atm_CO2'] = env_val['atm_CO2']
            row['sea_level'] = env_val['sea_level']
            if 'd13C' in env_val.index:
                row['d13C'] = env_val['d13C']

    data_5myr.append(row)

# Keys absent from a bin's dict become NaN in that row.
df_5myr = pd.DataFrame(data_5myr)

# [ADDED] Fill all missing NUMERIC values with 0 (keep text columns untouched)
# Every missing NUMERIC value in the 5-Myr table becomes 0 (text columns are
# untouched), mirroring what was done for the stage table.
num_cols_5 = df_5myr.select_dtypes(include=[np.number]).columns
df_5myr[num_cols_5] = df_5myr[num_cols_5].fillna(0)

# Oldest bin first, with a fresh 0..n-1 index.
df_5myr = df_5myr.sort_values('midpoint_ma', ascending=False)
df_5myr = df_5myr.reset_index(drop=True)
print(f"✓ Master dataset (5-MYR BINS) created: {len(df_5myr)} bins, {len(df_5myr.columns)} variables")

# =============================================================================
# [ADDED] SAVE MASTER DATASETS
# =============================================================================
# Each master table is written right away, but only if it exists and has rows.
_stage_master_ready = 'df' in locals() and df is not None and not df.empty
if _stage_master_ready:
    df.to_csv(f"{OUTPUT_DIR}/MASTER_dataset_stage.csv", index=False)
    print(f"✓ Saved: {OUTPUT_DIR}/MASTER_dataset_stage.csv")

_bin_master_ready = 'df_5myr' in locals() and df_5myr is not None and not df_5myr.empty
if _bin_master_ready:
    df_5myr.to_csv(f"{OUTPUT_DIR}/MASTER_dataset_5myr.csv", index=False)
    print(f"✓ Saved: {OUTPUT_DIR}/MASTER_dataset_5myr.csv")


Building Master Stage Dataset...
✓ Master dataset (STAGES) created: 22 stages, 56 variables

CREATING 5-MYR BIN MASTER DATASET
Using Reef Data as primary 5-Myr bin source.
✓ Master dataset (5-MYR BINS) created: 27 bins, 39 variables
✓ Saved: ./output/MASTER_dataset_stage.csv
✓ Saved: ./output/MASTER_dataset_5myr.csv


In [10]:
# =============================================================================
# @title CELL 10: CLR COMPOSITIONAL TRANSFORMATION (WITH BASAL/DERIVED GROUPS)
# =============================================================================
import numpy as np
import pandas as pd
from scipy.stats import gmean
import scipy.stats as stats
# Section banner for the CLR pre-processing step.
print("\n" + "="*90)
print("PRE-PROCESSING: CLR TRANSFORMATION")
print("Transforming closed compositional data (proportions) to open log-ratios")
print("="*90)
# Global list for CLR results
# Re-created empty on every run of this cell; the correlation loop further down
# appends one dict per (dataset, predictor, target) and the list is written to
# results_clr.csv at the end of the cell.
global_clr_results = []
def calculate_missing_props(input_df):
    """
    Return a copy of ``input_df`` with per-order ``<Order>_prop`` columns.

    For each of the seven stromatoporoid orders that has an ``<Order>_occ``
    column, the proportion ``<Order>_occ / strom_total_occ`` is written to
    ``<Order>_prop``. An existing ``_prop`` column of the same name is
    overwritten, not kept. ``strom_total_occ`` is created as the row-wise sum
    of the available ``_occ`` columns only when it is absent. Rows whose total
    is 0 (or whose ratio is otherwise undefined) receive a proportion of 0.

    If no ``_occ`` column is present, the copy is returned unchanged.
    The input frame is never modified.
    """
    result = input_df.copy()
    order_names = ('Labechiida', 'Clathrodictyida', 'Actinostromatida',
                   'Stromatoporida', 'Stromatoporellida',
                   'Syringostromatida', 'Amphiporida')

    # Orders for which occurrence counts are actually available.
    present = [name for name in order_names if f"{name}_occ" in result.columns]
    if not present:
        return result

    # Build the total only when the caller did not supply one.
    if 'strom_total_occ' not in result.columns:
        result['strom_total_occ'] = result[[f"{name}_occ" for name in present]].sum(axis=1)

    # A zero total is turned into NaN so the division yields NaN (not inf),
    # which is then mapped to a proportion of 0.
    denominator = result['strom_total_occ'].replace(0, np.nan)
    for name in present:
        result[f"{name}_prop"] = (result[f"{name}_occ"] / denominator).fillna(0)
    return result
def apply_clr(input_df, label):
    """
    Add centred log-ratio (CLR) columns and group summaries to a copy of ``input_df``.

    Parameters
    ----------
    input_df : DataFrame
        Table holding stromatoporoid ``<Order>_occ`` and/or ``<Order>_prop`` columns.
    label : str
        Name used only in the progress messages printed by this function.

    Returns
    -------
    DataFrame
        A copy with its index reset to 0..n-1 and these columns added or replaced:
        ``CLR_<Order>_prop`` for every available proportion column,
        ``log_derived_basal_ratio`` (when both groups have at least one column),
        and ``CLR_basal_mean`` / ``CLR_derived_mean``. If fewer than two
        proportion columns are available the copy is returned without any of
        these columns and without the index reset.

    Notes
    -----
    Zeros are replaced by 1e-5 before taking logs; the rows are not re-closed
    afterwards.
    NOTE(review): a row whose proportions are all 0 therefore gets CLR = 0 for
    every order and a log ratio of 0 instead of a missing value -- confirm that
    this is the intended treatment of bins without stromatoporoid records.
    """
    # Proportions are always (re)derived from the occurrence columns first.
    dataset = calculate_missing_props(input_df.copy())

    basal_orders = ['Labechiida_prop', 'Clathrodictyida_prop']
    derived_orders = ['Actinostromatida_prop', 'Stromatoporida_prop',
                      'Stromatoporellida_prop', 'Syringostromatida_prop',
                      'Amphiporida_prop']
    available_cols = [c for c in basal_orders + derived_orders if c in dataset.columns]

    # A composition needs at least two parts for a log-ratio to be meaningful.
    if len(available_cols) < 2:
        print(f"  {label}: ! Skipped. Found only {len(available_cols)} proportion columns.")
        return dataset

    # CLR = ln(part / geometric mean of the row), with zeros nudged to 1e-5.
    nonzero_parts = dataset[available_cols].replace(0, 1e-5)
    row_gmean = gmean(nonzero_parts, axis=1)
    clr_data = np.log(nonzero_parts.div(row_gmean, axis=0))
    clr_data.columns = [f"CLR_{c}" for c in available_cols]

    # Remove CLR columns left over from an earlier run so the concat below
    # cannot produce duplicate column names.
    stale = [name for name in clr_data.columns if name in dataset.columns]
    if stale:
        dataset = dataset.drop(columns=stale)

    # Log ratio of the summed derived share to the summed basal share.
    basal_in = [c for c in basal_orders if c in dataset.columns]
    derived_in = [c for c in derived_orders if c in dataset.columns]
    if basal_in and derived_in:
        basal_total = dataset[basal_in].sum(axis=1).replace(0, 1e-5)
        derived_total = dataset[derived_in].sum(axis=1).replace(0, 1e-5)
        dataset['log_derived_basal_ratio'] = np.log(derived_total / basal_total)
        print(f"  {label}: Created 'log_derived_basal_ratio'")

    # Attach the CLR columns positionally (both sides re-indexed 0..n-1).
    dataset = pd.concat(
        [dataset.reset_index(drop=True), clr_data.reset_index(drop=True)],
        axis=1,
    )
    print(f"  {label}: Generated {len(clr_data.columns)} CLR variables.")

    # Group-level summaries: mean CLR across the members of each group.
    if basal_in:
        dataset['CLR_basal_mean'] = dataset[[f"CLR_{c}" for c in basal_in]].mean(axis=1)
    if derived_in:
        dataset['CLR_derived_mean'] = dataset[[f"CLR_{c}" for c in derived_in]].mean(axis=1)

    return dataset
# Apply to both master datasets
# Both master tables are replaced by their CLR-augmented copies. Re-running
# this cell is safe with respect to the CLR columns because apply_clr drops
# any existing CLR_* column before adding the new one.
df = apply_clr(df, "Stage-Level")
df_5myr = apply_clr(df_5myr, "5-Myr Bins")
# =============================================================================
# CLR CORRELATION ANALYSIS (Individual Taxa + Basal/Derived Groups)
# =============================================================================
print("\n" + "-"*90)
print("CLR CORRELATION ANALYSIS")
print("-"*90)
# Raw proportion columns whose Spearman correlation is compared with that of
# their CLR_<name> counterpart.
prop_cols = ['Labechiida_prop', 'Clathrodictyida_prop', 'Actinostromatida_prop',
             'Stromatoporida_prop', 'Stromatoporellida_prop',
             'Syringostromatida_prop', 'Amphiporida_prop']
# Reef-size response variables; both are stored in the results, only
# thickness_mean is printed by the loop below.
targets = ['thickness_mean', 'width_mean']
def safe_spearman(x, y):
    """
    Spearman rank correlation that never raises on unusable input.

    Parameters
    ----------
    x, y : array-like
        Paired observations. They are converted to flat float arrays, so
        ``None`` entries and object-dtype input are treated as missing (NaN)
        instead of aborting the calculation.

    Returns
    -------
    (rho, p) : tuple of float
        Correlation coefficient and two-sided p-value computed on the pairs in
        which both values are present. ``(nan, nan)`` is returned when the
        inputs cannot be converted to numbers, differ in length, or leave
        fewer than 3 complete pairs.
    """
    # Explicit float conversion: np.isnan() below is undefined for object
    # arrays, which previously turned a computable correlation into NaN.
    try:
        x_arr = np.asarray(x, dtype=float).ravel()
        y_arr = np.asarray(y, dtype=float).ravel()
    except (TypeError, ValueError):
        return np.nan, np.nan

    # Pairing is positional, so the two inputs must line up one-to-one.
    if x_arr.shape != y_arr.shape:
        return np.nan, np.nan

    # Keep only pairs where both members are present.
    keep = ~(np.isnan(x_arr) | np.isnan(y_arr))
    x_clean = x_arr[keep]
    y_clean = y_arr[keep]

    if len(x_clean) < 3:
        return np.nan, np.nan

    # Only the failure modes expected from the statistic itself are caught;
    # anything else (e.g. KeyboardInterrupt) is allowed to propagate.
    try:
        result = stats.spearmanr(x_clean, y_clean)
        # Handle both old and new scipy return types
        if hasattr(result, 'correlation'):
            return float(result.correlation), float(result.pvalue)
        return float(result[0]), float(result[1])
    except (TypeError, ValueError):
        return np.nan, np.nan
# Compare, for each dataset, the Spearman correlation of every raw proportion
# with that of its CLR-transformed counterpart, then report the group-level
# CLR summaries. One result dict per (dataset, predictor, target) is appended
# to global_clr_results.
# Note: the loop variable rebinds the global name `data`, which CELL 9 used for
# a list of row dicts; after this loop it refers to df_5myr.
# NOTE(review): numeric gaps in both master tables were zero-filled in CELL 9,
# so the NaN removal inside safe_spearman presumably has little to remove and
# bins without records take part in these correlations as 0 -- confirm.
for label, data in [('Stage-Level', df), ('5-Myr Bins', df_5myr)]:
    print(f"\n--- {label} ---")
    print(f"  {'Taxon/Group':<28s} | {'Orig ρ':>8s} | {'CLR ρ':>8s} | {'Diff':>7s} | {'CLR p':>10s}")
    print("  " + "-"*75)

    # A. INDIVIDUAL TAXA
    for prop in prop_cols:
        clr_col = f"CLR_{prop}"
        # Both the raw and the CLR column must exist for a comparison.
        if prop not in data.columns or clr_col not in data.columns:
            continue

        for target in targets:
            if target not in data.columns:
                continue

            # Get data as 1D arrays
            x_orig = data[prop].values
            x_clr = data[clr_col].values
            y = data[target].values

            # Calculate correlations
            r_orig, p_orig = safe_spearman(x_orig, y)
            r_clr, p_clr = safe_spearman(x_clr, y)

            # Skip the pair if either correlation could not be computed.
            if np.isnan(r_orig) or np.isnan(r_clr):
                continue

            # Signed change in rho caused by the transformation.
            diff = r_clr - r_orig

            # Interpretation
            # |diff| < 0.1 -> "Stable"; diff < -0.2 -> "SUPPRESSED";
            # diff > 0.2 -> "INFLATED"; everything in between -> "Moderate".
            # The labels depend on the sign of diff, not on whether |rho| grew.
            if abs(diff) < 0.1:
                interp = "Stable"
            elif diff < -0.2:
                interp = "SUPPRESSED"
            elif diff > 0.2:
                interp = "INFLATED"
            else:
                interp = "Moderate"

            global_clr_results.append({
                'Dataset': label, 'Level': 'Individual',
                'Predictor': prop.replace('_prop', ''),
                'Target': target, 'Original_Rho': r_orig, 'CLR_Rho': r_clr,
                'Difference': diff, 'CLR_P': p_clr, 'Interpretation': interp
            })

            # Only the thickness rows are echoed; width rows are stored above
            # but not printed.
            if target == 'thickness_mean':
                sig = "*" if p_clr < 0.05 else ""
                print(f"  {prop.replace('_prop',''):<28s} | {r_orig:>+8.3f} | {r_clr:>+8.3f} | {diff:>+7.3f} | {p_clr:>10.4f} {sig}")

    # B. BASAL vs DERIVED GROUPS
    print("  " + "-"*75)
    print("  GROUP COMPARISONS:")

    # (column created by apply_clr, label used in the output)
    group_vars = [
        ('CLR_basal_mean', 'Basal (Labech+Clathr) CLR'),
        ('CLR_derived_mean', 'Derived (5 taxa) CLR'),
        ('log_derived_basal_ratio', 'Log(Derived/Basal) Ratio')
    ]

    for col, name in group_vars:
        if col not in data.columns:
            continue

        for target in targets:
            if target not in data.columns:
                continue

            x = data[col].values
            y = data[target].values

            r, p = safe_spearman(x, y)

            if np.isnan(r):
                continue

            # Group variables have no untransformed counterpart, hence the NaN
            # placeholders for the original rho and the difference.
            global_clr_results.append({
                'Dataset': label, 'Level': 'Group',
                'Predictor': name, 'Target': target,
                'Original_Rho': np.nan, 'CLR_Rho': r,
                'Difference': np.nan, 'CLR_P': p, 'Interpretation': 'Group-level'
            })

            if target == 'thickness_mean':
                sig = "*" if p < 0.05 else ""
                print(f"  {name:<28s} |      N/A | {r:>+8.3f} |     N/A | {p:>10.4f} {sig}")
# =============================================================================
# SAVE CLR RESULTS
# =============================================================================
# Persist every CLR comparison collected in this cell.
pd.DataFrame(global_clr_results).to_csv(f"{OUTPUT_DIR}/results_clr.csv", index=False)
print(f"\nSaved: {OUTPUT_DIR}/results_clr.csv ({len(global_clr_results)} rows)")

# Register the new CLR-based predictors with the shared predictor list, but
# only if that list already exists in this session and only for entries whose
# column name is not listed yet.
# NOTE(review): all_test_vars is not defined in the cells visible before this
# point, so on a fresh top-to-bottom run this branch presumably does nothing.
if 'all_test_vars' in locals():
    new_vars = [
        ('log_derived_basal_ratio', 'Log(Derived/Basal)'),
        ('CLR_basal_mean', 'CLR Basal Mean'),
        ('CLR_derived_mean', 'CLR Derived Mean')
    ]
    for var in new_vars:
        already_listed = any(v[0] == var[0] for v in all_test_vars)
        if already_listed:
            continue
        all_test_vars.append(var)



PRE-PROCESSING: CLR TRANSFORMATION
Transforming closed compositional data (proportions) to open log-ratios
  Stage-Level: Created 'log_derived_basal_ratio'
  Stage-Level: Generated 7 CLR variables.
  5-Myr Bins: Created 'log_derived_basal_ratio'
  5-Myr Bins: Generated 7 CLR variables.

------------------------------------------------------------------------------------------
CLR CORRELATION ANALYSIS
------------------------------------------------------------------------------------------

--- Stage-Level ---
  Taxon/Group                  |   Orig ρ |    CLR ρ |    Diff |      CLR p
  ---------------------------------------------------------------------------
  Labechiida                   |   -0.575 |   -0.725 |  -0.150 |     0.0001 *
  Clathrodictyida              |   +0.158 |   +0.009 |  -0.149 |     0.9681 
  Actinostromatida             |   +0.647 |   +0.262 |  -0.385 |     0.2396 
  Stromatoporida               |   +0.571 |   +0.173 |  -0.398 |     0.4406 
  Stromatoporellida 

In [11]:
# =============================================================================
# @title CELL 11: CORRELATION FUNCTIONS - SPEARMAN AND PEARSON
# =============================================================================
import numpy as np
import scipy.stats as stats

"""
CORRELATION ANALYSIS METHODS
============================
This cell defines functions for both Spearman and Pearson correlations.

SPEARMAN'S RHO (ρ):
- Non-parametric rank correlation
- Measures monotonic relationships (not just linear)
- Robust to outliers and non-normal distributions
- Appropriate for ordinal data or data with outliers
- Reference: Spearman, C. (1904). American Journal of Psychology, 15(1), 72-101.

PEARSON'S r:
- Parametric correlation coefficient
- Measures linear relationships specifically
- Assumes normally distributed variables
- More powerful when assumptions are met
- Reference: Pearson, K. (1895). Philosophical Transactions of the Royal Society A, 186, 343-414.

For geological time series:
- Spearman is often preferred due to non-normal distributions
- Pearson provides information on linear relationships
- Presenting both allows comparison and assessment of relationship type
"""

def calc_spearman(data, v1, v2):
    """
    Spearman rank correlation between two columns of ``data``.

    Parameters
    ----------
    data : DataFrame
        Table holding both variables.
    v1, v2 : str
        Names of the columns to correlate.

    Returns
    -------
    rho : float
        Spearman correlation coefficient, or NaN if it was not computed.
    p : float
        Two-tailed p-value, or NaN if it was not computed.
    n : int
        Number of complete pairs used. This is 0 (not the actual count)
        whenever a column is missing or fewer than 5 complete pairs remain.
    """
    # Both columns have to be present.
    if any(name not in data.columns for name in (v1, v2)):
        return np.nan, np.nan, 0

    # Pairwise-complete observations only.
    pairs = data[[v1, v2]].dropna()
    n_pairs = len(pairs)
    if n_pairs < 5:
        return np.nan, np.nan, 0

    outcome = stats.spearmanr(pairs[v1], pairs[v2])
    # Depending on the scipy version the result is a named result object or a
    # plain tuple; fall back to positional access when the attributes are absent.
    if hasattr(outcome, 'correlation') and hasattr(outcome, 'pvalue'):
        rho, p_value = outcome.correlation, outcome.pvalue
    else:
        rho, p_value = outcome[0], outcome[1]

    return rho, p_value, n_pairs

def calc_pearson(data, v1, v2):
    """
    Pearson product-moment correlation between two columns of ``data``.

    Parameters
    ----------
    data : DataFrame
        Table holding both variables.
    v1, v2 : str
        Names of the columns to correlate.

    Returns
    -------
    r : float
        Pearson correlation coefficient, or NaN if it was not computed.
    p : float
        Two-tailed p-value, or NaN if it was not computed.
    n : int
        Number of complete pairs used. This is 0 (not the actual count)
        whenever a column is missing or fewer than 5 complete pairs remain.
    """
    columns_present = v1 in data.columns and v2 in data.columns
    if not columns_present:
        return np.nan, np.nan, 0

    # Pairwise-complete observations only.
    pairs = data[[v1, v2]].dropna()
    n_pairs = len(pairs)
    if n_pairs < 5:
        return np.nan, np.nan, 0

    coefficient, p_value = stats.pearsonr(pairs[v1], pairs[v2])
    return coefficient, p_value, n_pairs

def calc_both_correlations(data, v1, v2):
    """
    Run calc_spearman and calc_pearson on the same pair of columns.

    Returns
    -------
    dict
        Keys ``spearman_rho``, ``spearman_p``, ``pearson_r``, ``pearson_p``
        and ``n``. ``n`` is the pair count reported by calc_spearman; the
        count reported by calc_pearson is discarded.
    """
    rho, rho_p, n_pairs = calc_spearman(data, v1, v2)
    r_value, r_p = calc_pearson(data, v1, v2)[:2]

    summary = {}
    summary['spearman_rho'] = rho
    summary['spearman_p'] = rho_p
    summary['pearson_r'] = r_value
    summary['pearson_p'] = r_p
    summary['n'] = n_pairs
    return summary

def get_significance_stars(p):
    """
    Map a p-value to the conventional significance marker.

    Returns '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05 and
    'ns' otherwise. A missing p-value (None or NaN) gives an empty string.
    """
    if p is None or np.isnan(p):
        return ''

    # Thresholds are checked from strictest to loosest; the first one that the
    # p-value falls below decides the marker.
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return 'ns'

print("✓ Correlation functions defined (Spearman and Pearson)")

✓ Correlation functions defined (Spearman and Pearson)


In [12]:
# =============================================================================
# @title CELL 12: COMPREHENSIVE CORRELATION ANALYSIS (MULTI-METRIC & 5-MYR BIOLOGY)
#   - NO T/W ratio computed or compared
#   - FIXES:
#       (1) Pairwise-complete correlations per predictor (NaNs in other columns won't block)
#       (2) UPDATED: NO group-level presence filters (Option 1 policy)
#           * All groups use the full dataset; missingness handled per-pair only
#       (3) Optional: treat coral zeros as missing (ONLY if zeros encode "no data", not true absence)
#   - Includes δ13C (expects column name 'd13C')
#   - Standardizes atmospheric O2/CO2 column names:
#       atmospheric_O2, atmospheric_CO2
# =============================================================================
import numpy as np
import pandas as pd
import scipy.stats as stats
import re

# Section banner summarising what this cell analyses.
print("="*90)
print("COMPREHENSIVE CORRELATION ANALYSIS")
print("Metrics: Thickness, Width")
print("Scopes:  Stage-Level AND 5-Myr Bins")
print("Notes:   NO T/W ratio; pairwise-complete; NO group-level presence filters")
print("="*90)

# -----------------------------------------------------------------------------#
# SETTINGS
# -----------------------------------------------------------------------------#
# NOTE(review): presumably the minimum number of complete pairs per correlation;
# the calc_spearman/calc_pearson helpers of CELL 11 hard-code 5 and do not read
# this constant -- confirm where it is used.
MIN_N = 5
VERBOSE_SKIPS = False            # show reasons for non-results
TREAT_CORAL_ZEROS_AS_MISSING = False  # set True ONLY if 0 means "no measurement" (not true absence)

# -----------------------------------------------------------------------------#
# 0. SAFETY: define STROM_ORDERS if missing
# -----------------------------------------------------------------------------#
# Lets this cell run on its own when the cell that normally defines
# STROM_ORDERS has not been executed; the fallback is announced with a warning.
if 'STROM_ORDERS' not in globals():
    STROM_ORDERS = [
        'Labechiida', 'Clathrodictyida', 'Actinostromatida',
        'Stromatoporida', 'Stromatoporellida', 'Syringostromatida', 'Amphiporida'
    ]
    print("[WARN] STROM_ORDERS not found in globals(); using default list.")

# -----------------------------------------------------------------------------#
# 1. PRE-PROCESSING: CALCULATE DERIVED VARIABLES (NO T/W)
# -----------------------------------------------------------------------------#
print("Pre-processing data...")

def calculate_derived_metrics(dataset, label, orders=None):
    """
    Return a copy of ``dataset`` with stromatoporoid proportions and groupings.

    Parameters
    ----------
    dataset : DataFrame or None
        Master table (stage-level or 5-Myr). ``None`` or an empty frame is
        returned unchanged after a message is printed.
    label : str
        Name used only in the progress messages.
    orders : sequence of str, optional
        Stromatoporoid order names whose ``<order>_occ`` columns make up the
        composition. Defaults to the module-level ``STROM_ORDERS``.

    Returns
    -------
    DataFrame
        Copy of the input in which

        * a ``thickness_width_ratio`` column, if present, has been removed;
        * ``strom_total_occ`` exists (summed from the order columns when absent);
        * ``<order>_prop`` holds ``<order>_occ / strom_total_occ``, NaN where
          the total is not positive;
        * ``derived_strom_occ`` / ``basal_strom_occ`` and
          ``derived_strom_div`` / ``basal_strom_div`` hold row sums over the
          derived and basal orders (NaN when every member is NaN);
        * ``derived_strom_prop`` / ``basal_strom_prop`` hold the group shares.

        When none of the order occurrence columns exists, only the
        ``thickness_width_ratio`` removal is applied.

    Notes
    -----
    The fallback total is built from the stromatoporoid order columns only.
    Selecting every column ending in ``_occ`` would also add coral counts
    (``rugose_occ``, ``tabulate_occ``) and previously derived group sums
    (``derived_strom_occ``, ``basal_strom_occ``) to the total.
    """
    if dataset is None or dataset.empty:
        print(f"  - [{label}] Empty dataset; skip derived metrics.")
        return dataset

    if orders is None:
        orders = STROM_ORDERS

    dataset = dataset.copy()

    # Remove T/W ratio if it exists from older runs (do NOT compute it)
    if 'thickness_width_ratio' in dataset.columns:
        dataset = dataset.drop(columns=['thickness_width_ratio'])

    # Occurrence columns that belong to the stromatoporoid composition.
    strom_cols = [f'{o}_occ' for o in orders if f'{o}_occ' in dataset.columns]
    if not strom_cols:
        print(f"  - [{label}] No strom occurrence columns found.")
        return dataset

    # Coerce counts to numbers; unparseable entries become NaN and stay NaN.
    for col_occ in strom_cols:
        dataset[col_occ] = pd.to_numeric(dataset[col_occ], errors='coerce')

    if 'strom_total_occ' not in dataset.columns:
        dataset['strom_total_occ'] = dataset[strom_cols].sum(axis=1)
    dataset['strom_total_occ'] = pd.to_numeric(dataset['strom_total_occ'], errors='coerce')

    # Proportions are only defined where at least one occurrence was recorded.
    has_total = dataset['strom_total_occ'] > 0

    for col_occ in strom_cols:
        prop_name = col_occ[:-len('_occ')] + '_prop'
        dataset[prop_name] = np.where(
            has_total,
            dataset[col_occ] / dataset['strom_total_occ'],
            np.nan
        )

    derived_orders = ['Actinostromatida', 'Stromatoporida', 'Stromatoporellida',
                      'Syringostromatida', 'Amphiporida']
    basal_orders = ['Labechiida', 'Clathrodictyida']

    def _group_sum(order_names, suffix):
        """Row sum of the available '<order>_<suffix>' columns; all-NaN rows stay NaN."""
        cols = [f'{o}_{suffix}' for o in order_names if f'{o}_{suffix}' in dataset.columns]
        if not cols:
            return np.nan
        return dataset[cols].apply(pd.to_numeric, errors='coerce').sum(axis=1, min_count=1)

    # Occurrences per group
    dataset['derived_strom_occ'] = _group_sum(derived_orders, 'occ')
    dataset['basal_strom_occ'] = _group_sum(basal_orders, 'occ')

    # Genus diversity per group
    dataset['derived_strom_div'] = _group_sum(derived_orders, 'genus')
    dataset['basal_strom_div'] = _group_sum(basal_orders, 'genus')

    # Group shares of all stromatoporoid occurrences
    dataset['derived_strom_prop'] = np.where(
        has_total,
        dataset['derived_strom_occ'] / dataset['strom_total_occ'],
        np.nan
    )
    dataset['basal_strom_prop'] = np.where(
        has_total,
        dataset['basal_strom_occ'] / dataset['strom_total_occ'],
        np.nan
    )

    print(f"  [OK] [{label}] Calculated strom proportions and groupings")
    return dataset

# Both master tables are replaced by their recomputed copies. For the stage
# table this overwrites the proportion columns built in CELL 9: rows without
# stromatoporoid occurrences held 0.0 there and hold NaN after this step.
df = calculate_derived_metrics(df, "Stage")
df_5myr = calculate_derived_metrics(df_5myr, "5-Myr")

# -----------------------------------------------------------------------------#
# 1A. STANDARDIZE ATMOSPHERIC O2/CO2 COLUMN NAMES (Stage + 5-Myr)
# -----------------------------------------------------------------------------#
def standardize_atm_cols(dataset, label):
    """Add canonical ``atmospheric_O2`` / ``atmospheric_CO2`` columns to a copy of ``dataset``.

    For each canonical column that is not already present, a source column is
    looked up first by exact name (in the listed priority order) and then by a
    case-insensitive regex over the column labels. The chosen source is coerced
    to numeric (unparseable values become NaN) and stored under the canonical
    name; the source column itself is left in place.

    Parameters
    ----------
    dataset : pandas.DataFrame or None
        Input table. ``None`` or an empty frame is returned unchanged.
    label : str
        Tag used only in the printed log lines.

    Returns
    -------
    pandas.DataFrame or None
        A copy of ``dataset`` with the canonical columns added where a source
        was found; the input object itself when it is ``None`` or empty.
    """
    # Local import so the helper works even if the enclosing cell has not imported `re`.
    import re

    if dataset is None or dataset.empty:
        return dataset
    dataset = dataset.copy()

    def _find_col(candidates, regex_pat=None):
        """Return the first exact-name candidate present, else the first regex hit, else None."""
        for c in candidates:
            if c in dataset.columns:
                return c
        if regex_pat is not None:
            # str(c) keeps non-string column labels (e.g. integers) from raising in re.search.
            hits = [c for c in dataset.columns if re.search(regex_pat, str(c), flags=re.IGNORECASE)]
            return hits[0] if hits else None
        return None

    # Atmospheric O2
    if 'atmospheric_O2' not in dataset.columns:
        o2_src = _find_col(
            candidates=['atm_O2','atmospheric_o2','oxygen','pO2','PO2','O2_atm','oxygen_atm'],
            # (?<!c) stops "o2" from matching inside "co2"; without it a CO2 column
            # such as "atm_CO2" would be picked up as the O2 source.
            regex_pat=r'(atm|atmos).*(?<!c)o2|po2'
        )
        if o2_src is not None:
            dataset['atmospheric_O2'] = pd.to_numeric(dataset[o2_src], errors='coerce')
            print(f"  [OK] [{label}] atmospheric_O2 <- {o2_src}")
        else:
            print(f"  [WARN] [{label}] No atmospheric O2 column found (atmospheric_O2 missing).")

    # Atmospheric CO2
    if 'atmospheric_CO2' not in dataset.columns:
        co2_src = _find_col(
            candidates=['atm_CO2','atmospheric_co2','co2','pCO2','PCO2','CO2_atm','co2_atm'],
            regex_pat=r'(atm|atmos).*co2|pco2'
        )
        if co2_src is not None:
            dataset['atmospheric_CO2'] = pd.to_numeric(dataset[co2_src], errors='coerce')
            print(f"  [OK] [{label}] atmospheric_CO2 <- {co2_src}")
        else:
            print(f"  [WARN] [{label}] No atmospheric CO2 column found (atmospheric_CO2 missing).")

    return dataset

# Canonicalise the atmospheric column names in both tables so the 'Proxies'
# variable group below can refer to atmospheric_O2 / atmospheric_CO2.
df = standardize_atm_cols(df, "Stage")
df_5myr = standardize_atm_cols(df_5myr, "5-Myr")

# -----------------------------------------------------------------------------#
# 1B. CORAL TOTALS (for optional coral zero-handling)
# -----------------------------------------------------------------------------#
def add_coral_totals(dataset, label):
    """Return a copy of ``dataset`` with combined rugose + tabulate coral totals.

    The rugose/tabulate occurrence and diversity columns that exist are coerced
    to numeric (unparseable values become NaN). ``coral_total_occ`` and
    ``coral_total_div`` are then the row-wise sums of the available parts; a row
    whose parts are all NaN stays NaN (``min_count=1``). ``label`` is accepted
    for signature parity with the sibling helpers and is not used.
    ``None`` or an empty frame is returned unchanged.
    """
    if dataset is None or dataset.empty:
        return dataset

    out = dataset.copy()
    parts_by_total = {
        'coral_total_occ': ('rugose_occ', 'tabulate_occ'),
        'coral_total_div': ('rugose_div', 'tabulate_div'),
    }

    for total_name, part_names in parts_by_total.items():
        present = [name for name in part_names if name in out.columns]
        for name in present:
            out[name] = pd.to_numeric(out[name], errors='coerce')
        if present:
            out[total_name] = out[present].sum(axis=1, min_count=1)

    return out

# Attach coral_total_occ / coral_total_div (where the part columns exist) to both tables.
df = add_coral_totals(df, "Stage")
df_5myr = add_coral_totals(df_5myr, "5-Myr")

# -----------------------------------------------------------------------------#
# 1C. BUILD DATASET VIEWS (UPDATED: NO PRESENCE FILTERS)
#     Option 1 policy: keep ALL rows; missingness handled pairwise in calc_stats_pairwise().
# -----------------------------------------------------------------------------#
# NOTE: these are plain name bindings, not copies — every "view" below refers to
# the same DataFrame object as df (or df_5myr), so no rows are filtered here.
df_all = df
df_5myr_all = df_5myr

# For compatibility with choose_view() and group logic, strom/coral views are identical.
df_strom = df_all
df_5myr_strom = df_5myr_all

df_coral = df_all
df_5myr_coral = df_5myr_all

# -----------------------------------------------------------------------------#
# 2. DEFINE VARIABLE GROUPS
# -----------------------------------------------------------------------------#
# Response columns paired with the display name used in the printed tables.
reef_targets = [
    ('thickness_mean', 'Thickness (Log)'),
    ('width_mean', 'Width (Log)'),
]

# Stromatoporoid predictors: group-level summaries first, then one entry per
# order listed in STROM_ORDERS. Each entry is (column name, display label).
strom_prop_vars = [
    ('derived_strom_prop', 'Derived Proportion'),
    ('basal_strom_prop', 'Basal Proportion'),
    *((f'{order}_prop', f'{order} Prop') for order in STROM_ORDERS),
]

strom_occ_vars = [
    ('strom_total_occ', 'Total Occurrence'),
    ('derived_strom_occ', 'Derived Occurrence'),
    ('basal_strom_occ', 'Basal Occurrence'),
    *((f'{order}_occ', f'{order} Occ') for order in STROM_ORDERS),
]

strom_div_vars = [
    ('strom_total_gen', 'Total Diversity'),
    ('derived_strom_div', 'Derived Diversity'),
    ('basal_strom_div', 'Basal Diversity'),
    *((f'{order}_genus', f'{order} Div') for order in STROM_ORDERS),
]

# Coral predictors.
coral_div_vars = [
    ('rugose_div', 'Rugose Diversity'),
    ('tabulate_div', 'Tabulate Diversity'),
]
coral_occ_vars = [
    ('rugose_occ', 'Rugose Occurrence'),
    ('tabulate_occ', 'Tabulate Occurrence'),
]

# Environmental predictors, split into area-based columns and proxy columns.
env_vars_macrostrat = [
    ('total_area_km2', 'Total Area'),
    ('carbonate_area_km2', 'Carb Area'),
    ('carbonate_percentage', 'Carb %'),
]

env_vars_proxies = [
    ('temperature', 'SST'),
    ('sea_level', 'Sea Level'),
    ('atmospheric_O2', 'Atm O2'),
    ('atmospheric_CO2', 'Atm CO2'),
    ('dissolved_O2', 'Dissolved O2'),
    ('d13C', 'δ¹³C'),
]

# Ordered (group name, variable list) pairs; this order drives the printed tables.
var_groups = [
    ('Strom Props', strom_prop_vars),
    ('Strom Occ', strom_occ_vars),
    ('Strom Div', strom_div_vars),
    ('Coral Div', coral_div_vars),
    ('Coral Occ', coral_occ_vars),
    ('Macrostrat', env_vars_macrostrat),
    ('Proxies', env_vars_proxies),
]

# Group names that choose_view() routes to the strom / coral dataset views.
STROM_GROUPS = {'Strom Props', 'Strom Occ', 'Strom Div'}
CORAL_GROUPS = {'Coral Div', 'Coral Occ'}

# -----------------------------------------------------------------------------#
# 3. ANALYSIS FUNCTIONS
# -----------------------------------------------------------------------------#
def get_significance_stars(p):
    """Map a p-value to a star string: '***' < 0.001, '**' < 0.01, '*' < 0.05, else ''.

    ``None`` and NaN yield the empty string.
    """
    if p is None or np.isnan(p):
        return ""
    # Thresholds are checked from strictest to loosest; the first one met wins.
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p < cutoff:
            return stars
    return ""

def calc_stats_pairwise(x, y, min_n=5):
    """Spearman and Pearson correlation on the pairwise-complete part of ``x`` and ``y``.

    Both inputs are converted to float arrays and only positions where both
    values are finite are kept.

    Returns a dict with keys ``n``, ``spearman_rho``, ``spearman_p``,
    ``pearson_r``, ``pearson_p`` and ``status``. ``status`` is one of:
    ``"too_few"`` (fewer than ``min_n`` complete pairs), ``"constant_both"``,
    ``"constant_x"``, ``"constant_y"`` (no variation, correlation undefined) or
    ``"ok"``. The four statistics are NaN unless ``status`` is ``"ok"``.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)

    complete = np.isfinite(xs) & np.isfinite(ys)
    xs, ys = xs[complete], ys[complete]
    n_pairs = int(len(xs))

    result = {
        "n": n_pairs,
        "spearman_rho": np.nan,
        "spearman_p": np.nan,
        "pearson_r": np.nan,
        "pearson_p": np.nan,
        "status": "too_few",
    }
    if n_pairs < min_n:
        return result

    # A variable with a single distinct value has no defined correlation.
    x_flat = np.unique(xs).size <= 1
    y_flat = np.unique(ys).size <= 1
    if x_flat or y_flat:
        if x_flat and y_flat:
            result["status"] = "constant_both"
        elif x_flat:
            result["status"] = "constant_x"
        else:
            result["status"] = "constant_y"
        return result

    rho, rho_p = stats.spearmanr(xs, ys)
    r, r_p = stats.pearsonr(xs, ys)

    result["spearman_rho"] = float(rho)
    result["spearman_p"] = float(rho_p)
    result["pearson_r"] = float(r)
    result["pearson_p"] = float(r_p)
    result["status"] = "ok"
    return result

def choose_view(group_name, df_all_in, df_strom_in, df_coral_in):
    """Pick the dataset view for a variable group.

    Groups named in STROM_GROUPS get the strom view, groups named in
    CORAL_GROUPS get the coral view, and everything else gets the full view.
    """
    if group_name in STROM_GROUPS:
        selected = df_strom_in
    elif group_name in CORAL_GROUPS:
        selected = df_coral_in
    else:
        selected = df_all_in
    return selected

def run_correlation_suite(scope_label, df_all_in, df_strom_in, df_coral_in, target_col, target_name, min_n=5):
    """Correlate one reef target against every predictor in ``var_groups``.

    For each (group, variable) pair the dataset view is chosen with
    ``choose_view``; predictor and target are coerced to numeric and passed to
    ``calc_stats_pairwise``, which handles missing values pairwise. A table of
    successful correlations is printed; skipped predictors are printed only when
    the module-level ``VERBOSE_SKIPS`` flag is set and at least one complete
    pair exists.

    Parameters
    ----------
    scope_label : str
        Tag for the temporal resolution (e.g. "Stage"); stored in every result row.
    df_all_in, df_strom_in, df_coral_in : pandas.DataFrame or None
        Full, strom and coral dataset views.
    target_col : str
        Column name of the response variable.
    target_name : str
        Display name of the response, used in the printed header.
    min_n : int, default 5
        Minimum number of complete pairs required to compute a correlation.

    Returns
    -------
    list of dict
        One record per predictor column found in the chosen view, including
        skipped ones (their ``Status`` is not ``"ok"`` and statistics are NaN).
        An empty list if the full view is missing/empty or lacks ``target_col``.

    Notes
    -----
    Reads the module-level names ``var_groups``, ``CORAL_GROUPS``,
    ``TREAT_CORAL_ZEROS_AS_MISSING`` and ``VERBOSE_SKIPS``.
    """
    if df_all_in is None or df_all_in.empty or target_col not in df_all_in.columns:
        print(f"[WARN] ({scope_label}) Missing target {target_col} or dataset empty; skipping {target_name}.")
        return []

    print("\n" + "-"*90)
    print(f"{scope_label} | TARGET: {target_name.upper()}")
    print("-"*90)
    # Data rows print a 3-character significance-star field right after rho and r;
    # the header reserves the same 3 columns so the headings line up with the values.
    print("{:<12} {:<35} {:>8}{:3} {:>10} {:>8}{:3} {:>10} {:>5}".format(
        "Group", "Variable", "rho", "", "p(rho)", "r", "", "p(r)", "n"))

    results = []

    for group_name, group_vars in var_groups:
        use_df = choose_view(group_name, df_all_in, df_strom_in, df_coral_in)
        if use_df is None or use_df.empty or target_col not in use_df.columns:
            continue

        y = pd.to_numeric(use_df[target_col], errors='coerce').to_numpy(dtype=float)

        for var, label in group_vars:
            if var not in use_df.columns:
                continue

            x = pd.to_numeric(use_df[var], errors='coerce').to_numpy(dtype=float)

            # OPTIONAL: treat coral zeros as missing (ONLY if zeros mean "no data")
            if TREAT_CORAL_ZEROS_AS_MISSING and (group_name in CORAL_GROUPS):
                x = x.copy()
                x[x <= 0] = np.nan

            s = calc_stats_pairwise(x, y, min_n=min_n)

            if s["status"] == "ok":
                s_sig = get_significance_stars(s['spearman_p'])
                p_sig = get_significance_stars(s['pearson_p'])
                print(f"{group_name[:11]:<12} {label:35s} {s['spearman_rho']:8.2f}{s_sig:3s} {s['spearman_p']:10.3g} "
                      f"{s['pearson_r']:8.2f}{p_sig:3s} {s['pearson_p']:10.3g} {s['n']:5d}")
            else:
                if VERBOSE_SKIPS and s["n"] > 0:
                    # Blank statistic fields use the same widths (incl. star fields) as the data rows.
                    print(f"{group_name[:11]:<12} {label:35s} {'':>8}{'':3} {'':>10} {'':>8}{'':3} {'':>10} {s['n']:5d}  [{s['status']}]")

            # Skipped predictors are recorded too so the FULL table can report why.
            results.append({
                'Scope': scope_label,
                'Target': target_col,
                'Predictor': var,
                'Label': label,
                'Group': group_name,
                'Spearman_Rho': s['spearman_rho'],
                'Spearman_P': s['spearman_p'],
                'Pearson_R': s['pearson_r'],
                'Pearson_P': s['pearson_p'],
                'N': int(s['n']),
                'Status': s['status']
            })

    return results

# -----------------------------------------------------------------------------#
# 4. RUN: Stage + 5-Myr
# -----------------------------------------------------------------------------#
# Run the suite once per reef target at stage resolution, accumulating all rows.
all_results_stage = []
for target_col, target_name in reef_targets:
    all_results_stage.extend(
        run_correlation_suite("Stage", df_all, df_strom, df_coral, target_col, target_name, min_n=MIN_N)
    )

# Same for the 5-Myr binned tables.
all_results_5myr = []
for target_col, target_name in reef_targets:
    all_results_5myr.extend(
        run_correlation_suite("5-Myr", df_5myr_all, df_5myr_strom, df_5myr_coral, target_col, target_name, min_n=MIN_N)
    )

stage_results_df = pd.DataFrame(all_results_stage)
myr_results_df = pd.DataFrame(all_results_5myr)

# Keep only successful correlations for main outputs
# (the emptiness guard avoids indexing a 'Status' column that an empty frame does not have)
stage_ok = stage_results_df[stage_results_df['Status'] == 'ok'].copy() if not stage_results_df.empty else stage_results_df
myr_ok   = myr_results_df[myr_results_df['Status'] == 'ok'].copy() if not myr_results_df.empty else myr_results_df

# -----------------------------------------------------------------------------#
# 5. SAVE + DISPLAY
# -----------------------------------------------------------------------------#
# CSV export only happens when OUTPUT_DIR has been defined in the notebook namespace;
# the display() calls after this block run either way.
if 'OUTPUT_DIR' in globals():
    # "ok"-only tables (main outputs)
    if not stage_ok.empty:
        stage_ok.to_csv(f"{OUTPUT_DIR}/results_correlations_stage.csv", index=False, encoding="utf-8-sig")
        print(f"\nSaved: {OUTPUT_DIR}/results_correlations_stage.csv")
    else:
        print("\n[WARN] Stage correlations empty (nothing met criteria).")

    if not myr_ok.empty:
        myr_ok.to_csv(f"{OUTPUT_DIR}/results_correlations_5myr.csv", index=False, encoding="utf-8-sig")
        print(f"Saved: {OUTPUT_DIR}/results_correlations_5myr.csv")
    else:
        print("[WARN] 5-Myr correlations empty (nothing met criteria).")

    # Full tables with skip reasons
    if not stage_results_df.empty:
        stage_results_df.to_csv(f"{OUTPUT_DIR}/results_correlations_stage_FULL_with_status.csv", index=False, encoding="utf-8-sig")
        print(f"Saved: {OUTPUT_DIR}/results_correlations_stage_FULL_with_status.csv")
    if not myr_results_df.empty:
        myr_results_df.to_csv(f"{OUTPUT_DIR}/results_correlations_5myr_FULL_with_status.csv", index=False, encoding="utf-8-sig")
        print(f"Saved: {OUTPUT_DIR}/results_correlations_5myr_FULL_with_status.csv")

# Rich display of the "ok" tables; an empty placeholder frame is shown when nothing passed.
display(stage_ok if not stage_ok.empty else pd.DataFrame())
display(myr_ok if not myr_ok.empty else pd.DataFrame())


COMPREHENSIVE CORRELATION ANALYSIS
Metrics: Thickness, Width
Scopes:  Stage-Level AND 5-Myr Bins
Notes:   NO T/W ratio; pairwise-complete; NO group-level presence filters
Pre-processing data...
  [OK] [Stage] Calculated strom proportions and groupings
  [OK] [5-Myr] Calculated strom proportions and groupings
  [OK] [Stage] atmospheric_O2 <- atm_O2
  [OK] [Stage] atmospheric_CO2 <- atm_CO2
  [OK] [5-Myr] atmospheric_O2 <- atm_O2
  [OK] [5-Myr] atmospheric_CO2 <- atm_CO2

------------------------------------------------------------------------------------------
Stage | TARGET: THICKNESS (LOG)
------------------------------------------------------------------------------------------
Group        Variable                                 rho     p(rho)        r       p(r)     n
Strom Props  Derived Proportion                      0.87***    2.5e-07     0.85***   8.77e-07    21
Strom Props  Basal Proportion                       -0.87***    2.5e-07    -0.85***   8.77e-07    21
Strom Props  L

Unnamed: 0,Scope,Target,Predictor,Label,Group,Spearman_Rho,Spearman_P,Pearson_R,Pearson_P,N,Status
0,Stage,thickness_mean,derived_strom_prop,Derived Proportion,Strom Props,0.872563,2.501763e-07,0.853369,8.771189e-07,21,ok
1,Stage,thickness_mean,basal_strom_prop,Basal Proportion,Strom Props,-0.872563,2.501763e-07,-0.853369,8.771189e-07,21,ok
2,Stage,thickness_mean,Labechiida_prop,Labechiida Prop,Strom Props,-0.778648,3.214033e-05,-0.733873,1.527280e-04,21,ok
3,Stage,thickness_mean,Clathrodictyida_prop,Clathrodictyida Prop,Strom Props,0.047557,8.378025e-01,-0.151025,5.134546e-01,21,ok
4,Stage,thickness_mean,Actinostromatida_prop,Actinostromatida Prop,Strom Props,0.613505,3.098310e-03,0.595730,4.377395e-03,21,ok
...,...,...,...,...,...,...,...,...,...,...,...
79,Stage,width_mean,sea_level,Sea Level,Proxies,-0.082463,7.152392e-01,-0.027155,9.045200e-01,22,ok
80,Stage,width_mean,atmospheric_O2,Atm O2,Proxies,0.215193,3.361696e-01,0.181302,4.193975e-01,22,ok
81,Stage,width_mean,atmospheric_CO2,Atm CO2,Proxies,-0.143462,5.241721e-01,-0.107114,6.351780e-01,22,ok
82,Stage,width_mean,dissolved_O2,Dissolved O2,Proxies,0.145722,5.175893e-01,0.152138,4.991127e-01,22,ok


Unnamed: 0,Scope,Target,Predictor,Label,Group,Spearman_Rho,Spearman_P,Pearson_R,Pearson_P,N,Status
0,5-Myr,thickness_mean,derived_strom_prop,Derived Proportion,Strom Props,0.878380,1.673329e-08,0.927584,7.061008e-11,24,ok
1,5-Myr,thickness_mean,basal_strom_prop,Basal Proportion,Strom Props,-0.878380,1.673329e-08,-0.927584,7.061008e-11,24,ok
2,5-Myr,thickness_mean,Labechiida_prop,Labechiida Prop,Strom Props,-0.828941,5.599693e-07,-0.825183,6.980943e-07,24,ok
3,5-Myr,thickness_mean,Clathrodictyida_prop,Clathrodictyida Prop,Strom Props,0.144726,4.998510e-01,-0.068724,7.496635e-01,24,ok
4,5-Myr,thickness_mean,Actinostromatida_prop,Actinostromatida Prop,Strom Props,0.708982,1.051787e-04,0.666808,3.730702e-04,24,ok
...,...,...,...,...,...,...,...,...,...,...,...
79,5-Myr,width_mean,sea_level,Sea Level,Proxies,-0.167762,4.029193e-01,-0.107911,5.921266e-01,27,ok
80,5-Myr,width_mean,atmospheric_O2,Atm O2,Proxies,0.135676,4.998249e-01,0.177231,3.764953e-01,27,ok
81,5-Myr,width_mean,atmospheric_CO2,Atm CO2,Proxies,-0.115814,5.651215e-01,-0.132356,5.104707e-01,27,ok
82,5-Myr,width_mean,dissolved_O2,Dissolved O2,Proxies,0.178457,3.731488e-01,0.190122,3.421826e-01,27,ok


In [17]:
# =============================================================================
# @title CELL 13: PERFORM ADVANCED STATISTICAL ANALYSES (Centralized) — PART A (REWRITE v3 + LOWESS)
#   Updated to enforce the SAME missingness/presence logic for stromatoporoids and corals:
#   - KEEP NaNs in MASTER; each analysis uses pairwise complete-case only for its variables
#   - Apply strom_total_occ > 0 for ALL strom/coral predictors (your existing rule)
#   - Apply coral_total_occ > 0 (or rugose+tabulate fallback) for coral predictors
#   - occ/div rule:
#       * For group-summary occ/div (derived/basal + rugose/tabulate), DO NOT require pred>0
#         (zeros are allowed once the group is present)
#       * For other occ/div (taxon-specific occ/div), require pred>0
#   - prop rule: allow zeros but require enough non-zero overall (signal gate)
# =============================================================================

import pandas as pd
import numpy as np
from scipy import stats
import statsmodels.api as sm
from statsmodels.nonparametric.smoothers_lowess import lowess
from pathlib import Path
import warnings
import re

warnings.filterwarnings("ignore")

# -------------------------
# Config
# -------------------------
# Directory the MASTER dataset is read from.
DATA_DIR = Path("./output")

# Reuse an OUTPUT_DIR already defined in the notebook namespace (converted to Path);
# fall back to DATA_DIR when it is undefined or cannot be converted.
if "OUTPUT_DIR" in globals():
    try:
        OUTPUT_DIR = Path(OUTPUT_DIR)
    except Exception:
        OUTPUT_DIR = DATA_DIR
else:
    OUTPUT_DIR = DATA_DIR

# Resampling settings: iteration counts and the seed for the shared random generator.
N_BOOT = 10000
N_PERM = 10000
RNG_SEED = 0

# Sample-size gates.
# NOTE: MIN_N is also read by the stage / 5-Myr correlation run earlier in the
# notebook; this assignment rebinds that global name.
MIN_N = 5
MIN_POSITIVE = 5          # only used for presence-only occ/div variables
MIN_NONZERO_PROP = 5
VERBOSE_SKIP_COUNTS = True

# LOWESS config
# (LOWESS_FRAC and LOWESS_IT are passed to statsmodels' lowess as frac= and it=)
DO_LOWESS = True
LOWESS_FRAC = 0.4
LOWESS_IT = 0

print("=" * 80)
print("PERFORMING ADVANCED STATISTICAL ANALYSES (CENTRALIZED) — PART A (REWRITE v3 + LOWESS)")
print(f"Iterations: Bootstrap={N_BOOT}, Permutation={N_PERM}")
print(f"LOWESS: {DO_LOWESS} (frac={LOWESS_FRAC}, it={LOWESS_IT})")
print(f"Output dir: {OUTPUT_DIR}")
print("=" * 80)

# -------------------------
# Load dataset
# -------------------------
# NOTE: `df` is rebound here to the stage-level MASTER table read from disk,
# replacing whatever frame the name referred to in earlier cells.
master_path = DATA_DIR / "MASTER_dataset_stage.csv"
df = pd.read_csv(master_path, encoding="utf-8-sig")
print(f"Loaded: {master_path.name} ({len(df)} rows)")

# Abort the cell early when there is nothing to analyse.
if df.empty:
    raise SystemExit("ERROR: Dataset is empty.")

# Response variable used by every analysis block in this cell.
target = "thickness_mean"
if target not in df.columns:
    raise SystemExit(f"ERROR: target column '{target}' not found.")

# numeric core
# (unparseable entries in the time axis / target become NaN)
for c in ["midpoint_ma", target]:
    if c in df.columns:
        df[c] = pd.to_numeric(df[c], errors="coerce")

# -------------------------
# Standardize atmospheric columns -> atm_O2 / atm_CO2
# -------------------------
def _find_col(frame, candidates, regex_pat=None):
    for c in candidates:
        if c in frame.columns:
            return c
    if regex_pat:
        hits = [c for c in frame.columns if re.search(regex_pat, c, flags=re.IGNORECASE)]
        return hits[0] if hits else None
    return None

# Create atm_O2 / atm_CO2 from the first matching source column when they are absent.
if "atm_O2" not in df.columns:
    o2_src = _find_col(
        df,
        ["atmospheric_O2","atmospheric_o2","atm_o2","pO2","PO2","oxygen","O2_atm","oxygen_atm"],
        # (?<!c) stops "o2" from matching inside "co2"; without it a CO2 column
        # such as "atmospheric_CO2" would be picked up as the O2 source.
        r"(atm|atmos).*(?<!c)o2|po2"
    )
    if o2_src is not None:
        df["atm_O2"] = pd.to_numeric(df[o2_src], errors="coerce")
        print(f"[INFO] atm_O2 <- {o2_src}")

if "atm_CO2" not in df.columns:
    co2_src = _find_col(
        df,
        ["atmospheric_CO2","atmospheric_co2","atm_co2","pCO2","PCO2","co2","CO2_atm","co2_atm"],
        r"(atm|atmos).*co2|pco2"
    )
    if co2_src is not None:
        df["atm_CO2"] = pd.to_numeric(df[co2_src], errors="coerce")
        print(f"[INFO] atm_CO2 <- {co2_src}")

# -------------------------
# Global strom filter (kept)
# -------------------------
def _filter_no_strom(sub: pd.DataFrame) -> pd.DataFrame:
    if sub is None or sub.empty:
        return sub
    if "strom_total_occ" in sub.columns:
        v = pd.to_numeric(sub["strom_total_occ"], errors="coerce").fillna(0)
        sub = sub.loc[v > 0].copy()
    return sub

# -------------------------
# Global coral filter (added)
# -------------------------
def _filter_no_coral(sub: pd.DataFrame) -> pd.DataFrame:
    if sub is None or sub.empty:
        return sub
    if "coral_total_occ" in sub.columns:
        c = pd.to_numeric(sub["coral_total_occ"], errors="coerce").fillna(0)
        return sub.loc[c > 0].copy()
    # fallback if totals missing
    cols = [cc for cc in ["rugose_occ", "tabulate_occ"] if cc in sub.columns]
    if cols:
        tmp = sub[cols].apply(pd.to_numeric, errors="coerce").sum(axis=1, min_count=1).fillna(0)
        return sub.loc[tmp > 0].copy()
    return sub

# -------------------------
# Predictor typing + presence rules
# -------------------------
# Environmental predictors, restricted to the columns actually present in df.
ENV_VARS = set([v for v in [
    "carbonate_area_km2","temperature","dissolved_O2",
    "d13C","atm_O2","atm_CO2",
    "total_area_km2","carbonate_percentage","sea_level"
] if v in df.columns])

# Columns classified as "occdiv" by _pred_kind even though they lack an _occ/_div suffix
# (strom_total_occ has the suffix already; reef_count does not).
COUNTLIKE_EXTRA = set([v for v in ["reef_count","strom_total_occ"] if v in df.columns])

# NOTE: in this cell STROM_GROUPS / CORAL_GROUPS hold predictor COLUMN names.
# The earlier correlation cell binds the same global names to sets of variable-GROUP
# labels ('Strom Props', ...) that choose_view() tests against; running this cell
# rebinds them.
STROM_GROUPS = set([
    "derived_strom_prop","basal_strom_prop",
    "derived_strom_occ","basal_strom_occ",
    "derived_strom_div","basal_strom_div",
    "strom_total_occ","strom_total_gen",
    "Labechiida_prop","Clathrodictyida_prop","Actinostromatida_prop","Stromatoporida_prop",
    "Stromatoporellida_prop","Syringostromatida_prop","Amphiporida_prop",
    "Labechiida_occ","Clathrodictyida_occ","Actinostromatida_occ","Stromatoporida_occ",
    "Stromatoporellida_occ","Syringostromatida_occ","Amphiporida_occ",
    "Labechiida_genus","Clathrodictyida_genus","Actinostromatida_genus","Stromatoporida_genus",
    "Stromatoporellida_genus","Syringostromatida_genus","Amphiporida_genus",
])

CORAL_GROUPS = set(["rugose_occ","tabulate_occ","rugose_div","tabulate_div"])

# IMPORTANT: these occ/div variables are "group summaries" where zeros are valid within a present group
# (_apply_occdiv_presence_rule leaves rows untouched for these predictors).
KEEP_ZERO_OCCDIV = set([
    "derived_strom_occ","basal_strom_occ",
    "derived_strom_div","basal_strom_div",
    "rugose_occ","tabulate_occ",
    "rugose_div","tabulate_div",
    "strom_total_gen",
])

def _pred_kind(pred: str) -> str:
    """Classify a predictor column as 'env', 'prop', 'occdiv' or 'other'.

    Membership in ENV_VARS is checked first, then the ``_prop`` suffix, then
    the ``_occ`` / ``_div`` suffixes or membership in COUNTLIKE_EXTRA.
    """
    if pred in ENV_VARS:
        return "env"
    if pred.endswith("_prop"):
        return "prop"
    is_count_like = pred.endswith(("_occ", "_div")) or pred in COUNTLIKE_EXTRA
    return "occdiv" if is_count_like else "other"

def _passes_prop_signal(sub: pd.DataFrame, pred: str) -> bool:
    """Signal gate for proportion predictors.

    True only when ``sub`` has at least MIN_N rows and at least
    MIN_NONZERO_PROP of the ``pred`` values are finite and non-zero.
    """
    if sub is None or sub.empty or len(sub) < MIN_N:
        return False
    numeric = pd.to_numeric(sub[pred], errors="coerce").values
    informative = np.isfinite(numeric) & (numeric != 0)
    n_nonzero = int(informative.sum())
    return n_nonzero >= MIN_NONZERO_PROP

def _apply_occdiv_presence_rule(sub: pd.DataFrame, pred: str) -> pd.DataFrame:
    """
    Presence rule for occurrence/diversity predictors.

    Predictors listed in KEEP_ZERO_OCCDIV are returned untouched (zeros allowed).
    For any other predictor only rows with pred > 0 are kept, as a copy.
    A None or empty frame is returned as-is.
    """
    if sub is None or sub.empty or pred in KEEP_ZERO_OCCDIV:
        return sub
    is_positive = pd.to_numeric(sub[pred], errors="coerce") > 0
    return sub.loc[is_positive].copy()

def get_analysis_data(d: pd.DataFrame, pred: str, targ: str) -> pd.DataFrame:
    """Build the complete-case subset used to relate ``pred`` to ``targ``.

    Steps, in order:
      1. Select ``pred``, ``targ`` and whichever bookkeeping columns exist
         (stage, midpoint_ma, strom_total_occ, coral_total_occ, rugose_occ,
         tabulate_occ).
      2. Coerce ``pred`` and ``targ`` to numeric and keep rows where both are finite.
      3. For strom or coral predictors apply ``_filter_no_strom``; for coral
         predictors additionally apply ``_filter_no_coral``.
      4. For "occdiv" predictors apply ``_apply_occdiv_presence_rule``.

    Parameters
    ----------
    d : pandas.DataFrame
        Source table; must contain ``pred`` and ``targ``.
    pred, targ : str
        Predictor and target column names.

    Returns
    -------
    pandas.DataFrame
        A filtered copy; ``d`` itself is not modified.
    """
    cols = [pred, targ]
    for extra in ("stage", "midpoint_ma", "strom_total_occ", "coral_total_occ",
                  "rugose_occ", "tabulate_occ"):
        if extra in d.columns:
            cols.append(extra)
    # De-duplicate while preserving order: if pred/targ is itself one of the
    # bookkeeping columns (e.g. strom_total_occ), a repeated label would make
    # sub[pred] a DataFrame and break the numeric coercion below.
    cols = list(dict.fromkeys(cols))

    sub = d[cols].copy()
    sub[pred] = pd.to_numeric(sub[pred], errors="coerce")
    sub[targ] = pd.to_numeric(sub[targ], errors="coerce")

    # Pairwise complete-case
    sub = sub[np.isfinite(sub[pred].values) & np.isfinite(sub[targ].values)].copy()

    # Apply group presence filters
    if pred in STROM_GROUPS or pred in CORAL_GROUPS:
        sub = _filter_no_strom(sub)
    if pred in CORAL_GROUPS:
        sub = _filter_no_coral(sub)

    # Apply predictor-type presence
    kind = _pred_kind(pred)
    if kind == "occdiv":
        sub = _apply_occdiv_presence_rule(sub, pred)

    return sub

def get_lowess_data(d: pd.DataFrame, pred: str) -> pd.DataFrame:
    """Build the complete-case subset used to smooth ``pred`` against ``midpoint_ma``.

    Mirrors ``get_analysis_data`` with time as the second variable: selects
    ``midpoint_ma``, ``pred`` and the available bookkeeping columns, coerces
    both to numeric, keeps rows where both are finite, then applies the same
    strom/coral presence filters and the occ/div presence rule.

    Returns an empty frame (zero rows, same columns as ``d``) when ``d`` has no
    ``midpoint_ma`` column. ``d`` itself is not modified.
    """
    if "midpoint_ma" not in d.columns:
        return d.iloc[0:0].copy()

    cols = ["midpoint_ma", pred]
    for extra in ("stage", "strom_total_occ", "coral_total_occ",
                  "rugose_occ", "tabulate_occ"):
        if extra in d.columns:
            cols.append(extra)
    # De-duplicate while preserving order: if pred is itself one of the
    # bookkeeping columns (or midpoint_ma), a repeated label would make
    # sub[pred] a DataFrame and break the numeric coercion below.
    cols = list(dict.fromkeys(cols))

    sub = d[cols].copy()
    sub["midpoint_ma"] = pd.to_numeric(sub["midpoint_ma"], errors="coerce")
    sub[pred] = pd.to_numeric(sub[pred], errors="coerce")

    # Complete cases on (time, predictor)
    sub = sub[np.isfinite(sub["midpoint_ma"].values) & np.isfinite(sub[pred].values)].copy()

    # Group presence filters
    if pred in STROM_GROUPS or pred in CORAL_GROUPS:
        sub = _filter_no_strom(sub)
    if pred in CORAL_GROUPS:
        sub = _filter_no_coral(sub)

    # Predictor-type presence
    kind = _pred_kind(pred)
    if kind == "occdiv":
        sub = _apply_occdiv_presence_rule(sub, pred)

    return sub

def _safe_spearman(x, y):
    """Spearman rho and p-value over the positions where both inputs are finite.

    Inputs are coerced to numeric (unparseable values become NaN) and paired by
    position. Returns ``(nan, nan)`` when fewer than MIN_N complete pairs remain.
    """
    xs = pd.to_numeric(pd.Series(x), errors="coerce").values
    ys = pd.to_numeric(pd.Series(y), errors="coerce").values
    complete = np.isfinite(xs) & np.isfinite(ys)
    if complete.sum() < MIN_N:
        return np.nan, np.nan
    rho, p_value = stats.spearmanr(xs[complete], ys[complete])
    return float(rho), float(p_value)

def _min_n_for_partial(k_controls: int) -> int:
    return max(6, k_controls + 3)

# Single seeded generator shared by the resampling blocks below.
rng = np.random.default_rng(RNG_SEED)

# -------------------------
# Predictor list
# -------------------------
# Candidates are scanned in a fixed order; only columns present in df are kept,
# and each name is added at most once.
all_predictors = []

for v in (
    # stromatoporoid group summaries
    "derived_strom_prop", "basal_strom_prop",
    "derived_strom_occ", "basal_strom_occ",
    "derived_strom_div", "basal_strom_div",
    # corals
    "rugose_occ", "tabulate_occ", "rugose_div", "tabulate_div",
    # environment
    "carbonate_area_km2", "temperature", "dissolved_O2",
    "d13C", "atm_O2", "atm_CO2", "total_area_km2", "carbonate_percentage", "sea_level",
    # per-order proportions
    "Labechiida_prop", "Clathrodictyida_prop", "Actinostromatida_prop",
    "Stromatoporida_prop", "Stromatoporellida_prop", "Syringostromatida_prop", "Amphiporida_prop",
):
    if v in df.columns and v not in all_predictors:
        all_predictors.append(v)

print("[INFO] Predictors included:", all_predictors)

# ==============================================================================
# LOWESS (pred vs time)
# ==============================================================================
# Smooth each group-level strom predictor against time and store the fitted
# curves in long format (one row per predictor and time point).
if DO_LOWESS:
    print("\n[LOWESS] Computing LOWESS smooths vs time (midpoint_ma)...")

    if "midpoint_ma" not in df.columns:
        print("  [WARN] midpoint_ma missing; LOWESS skipped.")
    else:
        lowess_predictors = []
        for v in [
            "basal_strom_prop","derived_strom_prop",
            "basal_strom_occ","derived_strom_occ",
            "basal_strom_div","derived_strom_div"
        ]:
            if v in df.columns:
                lowess_predictors.append(v)

        lowess_predictors = list(dict.fromkeys(lowess_predictors))
        print("  LOWESS predictors:", lowess_predictors)

        lowess_rows = []
        for pred in lowess_predictors:
            sub = get_lowess_data(df, pred)
            kind = _pred_kind(pred)

            # Proportion predictors must also pass the non-zero signal gate.
            if kind == "prop" and len(sub) >= MIN_N and (not _passes_prop_signal(sub, pred)):
                continue
            if len(sub) < MIN_N:
                continue

            # lowess is fitted on time-ordered data.
            sub2 = sub.sort_values("midpoint_ma").reset_index(drop=True)
            x = sub2["midpoint_ma"].values.astype(float)
            y = sub2[pred].values.astype(float)

            try:
                fit = lowess(endog=y, exog=x, frac=LOWESS_FRAC, it=LOWESS_IT, return_sorted=True)
                for xi, yhat in fit:
                    lowess_rows.append({
                        "Predictor": pred,
                        "midpoint_ma": float(xi),
                        "lowess_y": float(yhat),
                        "N_used": int(len(sub2)),
                        "frac": float(LOWESS_FRAC),
                        "it": int(LOWESS_IT)
                    })
            except Exception as exc:
                # Still best-effort (the remaining predictors are processed), but a
                # failed smooth is reported instead of silently missing from the CSV.
                print(f"  [WARN] LOWESS failed for {pred}: {exc}")
                continue

        lowess_df = pd.DataFrame(lowess_rows)
        lowess_df.to_csv(OUTPUT_DIR / "results_lowess_predictors_vs_time_long.csv",
                         index=False, encoding="utf-8-sig")
        print("  -> Saved results_lowess_predictors_vs_time_long.csv")
        display(lowess_df.head(20))

# ==============================================================================
# From here down: advanced stats blocks.
# basal_strom_prop excluded for multivariate blocks to avoid singularity.
# ==============================================================================
# Every predictor except basal_strom_prop; this list drives all analysis loops that follow.
stats_predictors = [p for p in all_predictors if p != "basal_strom_prop"]

print("\n[INFO] Predictors used for correlation/bootstrap/permutation/LOO/detrend/partial/univariate:",
      stats_predictors)

# ==============================================================================
# 0. ORIGINAL SPEARMAN CORRELATIONS
# ==============================================================================
print("\n0. Computing original Spearman correlations...")
corr_rows = []
# Tally of outcomes, printed at the end when VERBOSE_SKIP_COUNTS is set.
skip_counts = {"too_few":0, "prop_low_signal":0, "ok":0}

for pred in stats_predictors:
    # Complete-case subset with the presence rules for this predictor applied.
    sub = get_analysis_data(df, pred, target)
    kind = _pred_kind(pred)

    # Proportion predictors with enough rows but too few non-zero values are skipped.
    if kind == "prop" and len(sub) >= MIN_N and (not _passes_prop_signal(sub, pred)):
        corr_rows.append({"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":int(len(sub)),
                          "spearman_rho":np.nan,"spearman_p":np.nan,"Status":"SKIP_prop_low_signal"})
        skip_counts["prop_low_signal"] += 1
        continue

    r, p = _safe_spearman(sub[pred], sub[target]) if len(sub) else (np.nan, np.nan)
    # NOTE(review): any non-finite rho is labelled "SKIP_too_few_rows", including the
    # case where len(sub) >= MIN_N but rho is undefined (presumably a constant
    # predictor or target — confirm); the label is then misleading.
    status = "OK" if (len(sub) >= MIN_N and np.isfinite(r)) else "SKIP_too_few_rows"
    corr_rows.append({"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":int(len(sub)),
                      "spearman_rho":r,"spearman_p":p,"Status":status})
    skip_counts["ok" if status=="OK" else "too_few"] += 1

# One row per predictor, skipped ones included (see the Status column).
corr_df = pd.DataFrame(corr_rows)
corr_df.to_csv(OUTPUT_DIR / "results_spearman_original_all_predictors.csv", index=False, encoding="utf-8-sig")
print("  -> Saved results_spearman_original_all_predictors.csv")
if VERBOSE_SKIP_COUNTS:
    print("[INFO] Spearman status counts:", skip_counts)

# ==============================================================================
# 1. BOOTSTRAP
# ==============================================================================
# Percentile bootstrap of Spearman's rho: rows of the complete-case subset are
# resampled with replacement N_BOOT times per predictor; the 2.5th / 97.5th
# percentiles of the resampled rho values form the reported interval.
print(f"\n1. Running Bootstrap ({N_BOOT} iter)...")
boot_summary = []
boot_dist_long = []

for pred in stats_predictors:
    sub = get_analysis_data(df, pred, target)
    kind = _pred_kind(pred)
    n = len(sub)

    # Proportion predictors with enough rows but too few non-zero values are skipped.
    if kind == "prop" and n >= MIN_N and (not _passes_prop_signal(sub, pred)):
        boot_summary.append({"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":int(n),
                             "spearman_observed":np.nan,"spearman_ci_low":np.nan,"spearman_ci_high":np.nan,
                             "Status":"SKIP_prop_low_signal"})
        continue

    if n < MIN_N:
        boot_summary.append({"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":int(n),
                             "spearman_observed":np.nan,"spearman_ci_low":np.nan,"spearman_ci_high":np.nan,
                             "Status":"SKIP_too_few_rows"})
        continue

    # Pull the two columns out as NumPy arrays once; positional indexing on arrays
    # yields the same resamples as Series.iloc without per-iteration pandas overhead.
    x_vals = sub[pred].to_numpy()
    y_vals = sub[target].to_numpy()

    rho_obs, _ = stats.spearmanr(x_vals, y_vals)

    idx = np.arange(n)
    rhos = np.empty(N_BOOT, dtype=float)
    for b in range(N_BOOT):
        s = rng.choice(idx, size=n, replace=True)
        # A resample in which either variable is constant gives rho = NaN;
        # the nanpercentile calls below ignore those draws.
        r, _ = stats.spearmanr(x_vals[s], y_vals[s])
        rhos[b] = r

    ci_low = np.nanpercentile(rhos, 2.5)
    ci_high = np.nanpercentile(rhos, 97.5)

    boot_summary.append({"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":int(n),
                         "spearman_observed":float(rho_obs) if np.isfinite(rho_obs) else np.nan,
                         "spearman_ci_low":float(ci_low),"spearman_ci_high":float(ci_high),"Status":"OK"})
    # Long-format record of every bootstrap draw (iter is 1-based).
    boot_dist_long.extend([{"Predictor":pred,"iter":i+1,"rho":float(rhos[i])} for i in range(N_BOOT)])

pd.DataFrame(boot_summary).to_csv(OUTPUT_DIR / "results_bootstrap.csv", index=False, encoding="utf-8-sig")
pd.DataFrame(boot_dist_long).to_csv(OUTPUT_DIR / "results_bootstrap_dist_all_predictors_long.csv", index=False, encoding="utf-8-sig")
print("  -> Saved results_bootstrap.csv")
print("  -> Saved results_bootstrap_dist_all_predictors_long.csv")

# ==============================================================================
# 2. PERMUTATION
# ==============================================================================
# Two-sided permutation test of Spearman's rho: the target is shuffled N_PERM
# times per predictor and the p-value is (count of |rho_perm| >= |rho_obs| + 1) / (N_PERM + 1).
print(f"2. Running Permutation ({N_PERM} iter)...")
perm_summary = []
perm_dist_long = []

for pred in stats_predictors:
    sub = get_analysis_data(df, pred, target)
    kind = _pred_kind(pred)
    n = len(sub)

    # Proportion predictors with enough rows but too few non-zero values are skipped.
    if kind == "prop" and n >= MIN_N and (not _passes_prop_signal(sub, pred)):
        perm_summary.append({"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":int(n),
                             "spearman_observed":np.nan,"spearman_p_permutation":np.nan,
                             "Status":"SKIP_prop_low_signal"})
        continue

    if n < MIN_N:
        perm_summary.append({"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":int(n),
                             "spearman_observed":np.nan,"spearman_p_permutation":np.nan,
                             "Status":"SKIP_too_few_rows"})
        continue

    # Extract the columns once; they do not change across permutations.
    x = sub[pred].values
    y = sub[target].values

    rho_obs, _ = stats.spearmanr(x, y)

    # If the observed rho is undefined (e.g. a constant predictor or target), no
    # permutation can ever count as "extreme", which previously produced the
    # smallest possible p-value, 1 / (N_PERM + 1), with Status "OK". Report the
    # test as undefined instead of spuriously significant.
    if not np.isfinite(rho_obs):
        perm_summary.append({"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":int(n),
                             "spearman_observed":np.nan,"spearman_p_permutation":np.nan,
                             "Status":"SKIP_undefined_rho"})
        continue

    abs_obs = abs(rho_obs)
    perm_rhos = np.empty(N_PERM, dtype=float)
    extreme = 0
    for i in range(N_PERM):
        y_perm = rng.permutation(y)
        r_perm, _ = stats.spearmanr(x, y_perm)
        perm_rhos[i] = r_perm
        if np.isfinite(r_perm) and (abs(r_perm) >= abs_obs):
            extreme += 1

    # +1 in numerator and denominator counts the observed arrangement itself,
    # so the p-value can never be exactly zero.
    p_val = (extreme + 1) / (N_PERM + 1)
    perm_summary.append({"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":int(n),
                         "spearman_observed":float(rho_obs),
                         "spearman_p_permutation":float(p_val),"Status":"OK"})
    # Long-format record of the null distribution (iter is 1-based).
    perm_dist_long.extend([{"Predictor":pred,"iter":i+1,"rho":float(perm_rhos[i])} for i in range(N_PERM)])

pd.DataFrame(perm_summary).to_csv(OUTPUT_DIR / "results_permutation.csv", index=False, encoding="utf-8-sig")
pd.DataFrame(perm_dist_long).to_csv(OUTPUT_DIR / "results_permutation_dist_all_predictors_long.csv", index=False, encoding="utf-8-sig")
print("  -> Saved results_permutation.csv")
print("  -> Saved results_permutation_dist_all_predictors_long.csv")

# ==============================================================================
# 3. LEAVE-ONE-OUT
# ==============================================================================
print("3. Running Leave-One-Out...")
loo_detailed_long = []
loo_summary = []

# For every predictor, recompute the Spearman correlation with each stage removed
# in turn, recording the per-stage rho and a summary of its spread.
for pred in stats_predictors:
    # Columns for the correlation itself plus the optional occurrence/age columns
    # (presumably what the presence filters below inspect); each optional column
    # is requested only when df actually has it.
    cols = ["stage", pred, target]
    cols += [name for name in ("strom_total_occ", "coral_total_occ") if name in df.columns]
    cols += [name for name in ("rugose_occ", "tabulate_occ")
             if name in df.columns and name not in cols]
    if "midpoint_ma" in df.columns:
        cols.append("midpoint_ma")

    v = df[cols].copy()
    v[pred] = pd.to_numeric(v[pred], errors="coerce")
    v[target] = pd.to_numeric(v[target], errors="coerce")
    v = v.dropna(subset=["stage", pred, target])

    in_strom_group = pred in STROM_GROUPS
    in_coral_group = pred in CORAL_GROUPS
    if in_strom_group or in_coral_group:
        v = _filter_no_strom(v)
    if in_coral_group:
        v = _filter_no_coral(v)

    kind = _pred_kind(pred)
    if kind == "occdiv":
        v = _apply_occdiv_presence_rule(v, pred)

    # Pieces reused by every summary row of this predictor.
    head = {"Dataset":"Stage-Level Data","Target":target,"Predictor":pred}
    no_estimate = {"LOO_Mean_Rho":np.nan,"LOO_Max_P":np.nan,"LOO_Min_Rho":np.nan,"LOO_Max_Rho":np.nan}

    if kind == "prop" and len(v) >= MIN_N and (not _passes_prop_signal(v, pred)):
        loo_summary.append({**head, "N":int(len(v)), **no_estimate, "Status":"SKIP_prop_low_signal"})
        continue

    # A clean 0..n-1 index lets each row be dropped by its position label.
    v = v.reset_index(drop=True)
    n = len(v)
    if n < MIN_N:
        loo_summary.append({**head, "N":int(n), **no_estimate, "Status":"SKIP_too_few_rows"})
        continue

    rho_full, _ = stats.spearmanr(v[pred], v[target])

    rhos, ps = [], []
    for i in range(n):
        kept = v.loc[v.index != i]
        r, p = _safe_spearman(kept[pred], kept[target])
        rhos.append(r)
        ps.append(p)
        comparable = np.isfinite(r) and np.isfinite(rho_full)
        loo_detailed_long.append({
            "Predictor": pred,
            "Stage_Dropped": v.loc[i, "stage"],
            "LOO_Rho": r,
            "Diff_from_Full": (r - rho_full) if comparable else np.nan
        })

    loo_summary.append({
        **head,
        "N":int(n),
        "LOO_Mean_Rho":float(np.nanmean(rhos)),
        "LOO_Max_P":float(np.nanmax(ps)),
        "LOO_Min_Rho":float(np.nanmin(rhos)),
        "LOO_Max_Rho":float(np.nanmax(rhos)),
        "Status":"OK"
    })

pd.DataFrame(loo_detailed_long).to_csv(OUTPUT_DIR / "results_loo_detailed_all_predictors_long.csv", index=False, encoding="utf-8-sig")
pd.DataFrame(loo_summary).to_csv(OUTPUT_DIR / "results_loo_summary.csv", index=False, encoding="utf-8-sig")
print("  -> Saved results_loo_detailed_all_predictors_long.csv")
print("  -> Saved results_loo_summary.csv")

# ==============================================================================
# 4. DETRENDING
# ==============================================================================
print("4. Running Detrending (status logged)...")
det_results = []

# For each predictor: remove a linear trend in midpoint_ma from both the predictor
# and the target, then Spearman-correlate the two residual series. Exactly one row
# per predictor is logged, with a Status explaining any skip or failure.
if "midpoint_ma" in df.columns:
    for pred in stats_predictors:
        row = {"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"N":0,
               "Detrend_Rho":np.nan,"Detrend_P":np.nan,"Status":"INIT"}

        cols = ["midpoint_ma", pred, target]
        cols += [name for name in ("strom_total_occ", "coral_total_occ") if name in df.columns]
        cols += [name for name in ("rugose_occ", "tabulate_occ")
                 if name in df.columns and name not in cols]

        v = df[cols].copy()
        for name in ("midpoint_ma", pred, target):
            v[name] = pd.to_numeric(v[name], errors="coerce")
        v = v.dropna(subset=["midpoint_ma", pred, target])

        in_coral_group = pred in CORAL_GROUPS
        if pred in STROM_GROUPS or in_coral_group:
            v = _filter_no_strom(v)
        if in_coral_group:
            v = _filter_no_coral(v)

        kind = _pred_kind(pred)
        if kind == "occdiv":
            v = _apply_occdiv_presence_rule(v, pred)

        n = len(v)
        row["N"] = int(n)

        # The first matching condition decides the status; only the final branch fits anything.
        if kind == "prop" and n >= MIN_N and (not _passes_prop_signal(v, pred)):
            row["Status"] = "SKIP_prop_low_signal"
        elif n < MIN_N:
            row["Status"] = "SKIP_too_few_rows"
        elif v[pred].nunique(dropna=True) < 2 or float(np.nanstd(v[pred].values)) == 0.0:
            row["Status"] = "SKIP_constant_predictor"
        else:
            try:
                age = v["midpoint_ma"]
                trend_x = stats.linregress(age, v[pred])
                trend_y = stats.linregress(age, v[target])
                resid_x = v[pred] - (trend_x.slope * age + trend_x.intercept)
                resid_y = v[target] - (trend_y.slope * age + trend_y.intercept)
                r, p = stats.spearmanr(resid_x, resid_y, nan_policy="omit")
                row["Detrend_Rho"] = float(r) if np.isfinite(r) else np.nan
                row["Detrend_P"] = float(p) if np.isfinite(p) else np.nan
                row["Status"] = "OK"
            except Exception as e:
                row["Status"] = f"FAIL_{type(e).__name__}"

        det_results.append(row)
else:
    print("  [ERROR] midpoint_ma missing -> detrending cannot run.")

pd.DataFrame(det_results).to_csv(OUTPUT_DIR / "results_detrending_improved.csv", index=False, encoding="utf-8-sig")
print("  -> Saved results_detrending_improved.csv")

# ==============================================================================
# 5. PARTIAL CORRELATIONS
# ==============================================================================
print("\n5. Running Partial Correlations (predictor-type aware)...")

# Control sets, each restricted to columns that df actually contains:
#   env_controls            - held fixed when testing non-environmental predictors
#   biotic_controls_for_env - held fixed when testing environmental predictors
env_controls = [
    col for col in ("carbonate_area_km2", "temperature", "dissolved_O2")
    if col in df.columns
]
biotic_controls_for_env = [
    col for col in ("derived_strom_prop",)
    if col in df.columns
]

def _partial_spearman(sub: pd.DataFrame, pred: str, targ: str, controls: list):
    Xc = sm.add_constant(sub[controls], has_constant="add")
    res_pred = sm.OLS(sub[pred].values.astype(float), Xc.values.astype(float)).fit().resid
    res_targ = sm.OLS(sub[targ].values.astype(float), Xc.values.astype(float)).fit().resid
    r, p = stats.spearmanr(res_pred, res_targ, nan_policy="omit")
    return float(r) if np.isfinite(r) else np.nan, float(p) if np.isfinite(p) else np.nan

part_rows = []

for pred in stats_predictors:
    kind = _pred_kind(pred)

    # Environmental predictors are tested with the biotic control(s) held fixed;
    # every other predictor is tested with the environmental controls held fixed.
    if kind == "env":
        test_type, candidate_controls = "Environment (Biotic Controlled)", biotic_controls_for_env
    else:
        test_type, candidate_controls = "Biotic/Taxon (Env Controlled)", env_controls
    controls = [c for c in candidate_controls if c in df.columns]

    # Leading and NaN-estimate fields shared by every row logged for this predictor.
    lead = {"Dataset":"Stage-Level Data","Target":target,"Predictor":pred,"Test_Type":test_type,
            "Controls":",".join(controls)}
    no_estimate = {"spearman_partial":np.nan,"spearman_p_partial":np.nan}

    if len(controls) < 1:
        part_rows.append({**lead, "N":0, **no_estimate, "Status":"SKIP_insufficient_controls"})
        continue

    analysis_cols = [pred, target] + controls
    cols = list(analysis_cols)
    cols += [name for name in ("strom_total_occ", "coral_total_occ") if name in df.columns]
    cols += [name for name in ("rugose_occ", "tabulate_occ")
             if name in df.columns and name not in cols]

    sub = df[cols].copy()
    for name in analysis_cols:
        sub[name] = pd.to_numeric(sub[name], errors="coerce")
    sub = sub.dropna(subset=analysis_cols)

    in_coral_group = pred in CORAL_GROUPS
    if pred in STROM_GROUPS or in_coral_group:
        sub = _filter_no_strom(sub)
    if in_coral_group:
        sub = _filter_no_coral(sub)

    if kind == "occdiv":
        sub = _apply_occdiv_presence_rule(sub, pred)
        # MIN_POSITIVE is enforced only for presence-only occ/div variables.
        if pred not in KEEP_ZERO_OCCDIV and len(sub) < MIN_POSITIVE:
            part_rows.append({**lead, "N":int(len(sub)), **no_estimate,
                              "Status":"SKIP_too_few_positive"})
            continue
    elif kind == "prop":
        if len(sub) >= MIN_N and (not _passes_prop_signal(sub, pred)):
            part_rows.append({**lead, "N":int(len(sub)), **no_estimate,
                              "Status":"SKIP_prop_low_signal"})
            continue

    n = len(sub)
    min_n = _min_n_for_partial(len(controls))
    if n < min_n:
        part_rows.append({**lead, "N":int(n), **no_estimate,
                          "Status":f"SKIP_too_few_rows(n<{min_n})"})
        continue

    if sub[pred].nunique(dropna=True) < 2 or float(np.nanstd(sub[pred].values)) == 0.0:
        part_rows.append({**lead, "N":int(n), **no_estimate,
                          "Status":"SKIP_constant_predictor"})
        continue

    # Fit; any exception is recorded in Status rather than aborting the loop.
    try:
        r, p = _partial_spearman(sub, pred, target, controls)
        outcome = {"spearman_partial":r,"spearman_p_partial":p,"Status":"OK"}
    except Exception as e:
        outcome = {**no_estimate, "Status":f"FAIL_{type(e).__name__}"}
    part_rows.append({**lead, "N":int(n), **outcome})

part_df = pd.DataFrame(part_rows).sort_values(["Status","spearman_p_partial"], na_position="last")
part_df.to_csv(OUTPUT_DIR / "results_partial_correlations.csv", index=False, encoding="utf-8-sig")
print("  -> Saved results_partial_correlations.csv")

print("\n[DISPLAY] Partial correlations (OK rows):")
ok_rows = part_df["Status"] == "OK"
display(part_df[ok_rows].sort_values("spearman_p_partial", na_position="last"))

print("\n[DISPLAY] Partial correlations (all rows incl. skips/fails):")
display(part_df)

# ==============================================================================
# 6. VARIANCE PARTITIONING (strict)
# ==============================================================================
print("6. Running Variance Partitioning (groups; strict subset)...")

def get_adj_r2_group(data: pd.DataFrame, y_col: str, x_cols: list) -> float:
    """Adjusted R-squared of an OLS fit of `y_col` on the predictor group `x_cols`.

    All involved columns are coerced to numeric and incomplete rows are dropped.
    Predictors that are constant on the remaining rows are left out of the design
    so it does not become singular.

    Returns NaN when `x_cols` is empty, when fewer than ``len(x_cols) + 2``
    complete rows remain, when every predictor is constant, or when the fit raises.
    """
    if len(x_cols) < 1:
        return np.nan

    wanted = [y_col, *x_cols]
    frame = data[wanted].copy()
    for name in wanted:
        frame[name] = pd.to_numeric(frame[name], errors="coerce")
    frame = frame.dropna()

    if len(frame) < len(x_cols) + 2:
        return np.nan

    # Keep only predictors that actually vary on the complete rows.
    usable = [
        xc for xc in x_cols
        if frame[xc].nunique(dropna=True) >= 2 and float(np.nanstd(frame[xc].values)) > 0.0
    ]
    if not usable:
        return np.nan

    design = sm.add_constant(frame[usable], has_constant="add")
    response = frame[y_col].values.astype(float)
    try:
        fitted = sm.OLS(response, design.values.astype(float)).fit(method="qr")
        return float(fitted.rsquared_adj)
    except Exception:
        return np.nan

# Predictor groups for the partition: A = stromatoporoid, B = coral, C = environment.
A = [name for name in ("derived_strom_prop",) if name in df.columns]  # derived only
B = [name for name in ("rugose_occ", "tabulate_occ") if name in df.columns]
C = list(env_controls)
ABC = [*A, *B, *C]

# Model columns plus the occurrence totals, where df provides them.
model_cols = [target, *ABC]
vp_cols = list(model_cols)
vp_cols += [name for name in ("strom_total_occ", "coral_total_occ") if name in df.columns]
vp_cols += [name for name in ("rugose_occ", "tabulate_occ")
            if name in df.columns and name not in vp_cols]

sub_vp = df[vp_cols].copy()
for name in model_cols:
    sub_vp[name] = pd.to_numeric(sub_vp[name], errors="coerce")
sub_vp = sub_vp.dropna(subset=model_cols)

# Apply group presence (do NOT force rugose/tabulate >0; zeros allowed once corals present)
sub_vp = _filter_no_coral(_filter_no_strom(sub_vp))

# Prop signal gate for A: a single failing proportion empties the subset, so every
# statistic below comes out NaN and N is reported as 0.
if any(not _passes_prop_signal(sub_vp, name) for name in A):
    sub_vp = sub_vp.iloc[0:0].copy()

# Adjusted R2 of the full model and of each model with one group left out.
r2_abc = get_adj_r2_group(sub_vp, target, ABC)
r2_bc  = get_adj_r2_group(sub_vp, target, B + C)
r2_ac  = get_adj_r2_group(sub_vp, target, A + C)
r2_ab  = get_adj_r2_group(sub_vp, target, A + B)

# Unique fraction of a group = full-model adj-R2 minus adj-R2 without that group.
full_ok = np.isfinite(r2_abc)
unique_a = (r2_abc - r2_bc) if (full_ok and np.isfinite(r2_bc)) else np.nan
unique_b = (r2_abc - r2_ac) if (full_ok and np.isfinite(r2_ac)) else np.nan
unique_c = (r2_abc - r2_ab) if (full_ok and np.isfinite(r2_ab)) else np.nan
shared   = (r2_abc - (unique_a + unique_b + unique_c)) if full_ok else np.nan

vp_row = {
    "Dataset":"Stage-Level Data",
    "Target":target,
    "Strom_Group":",".join(A),
    "Coral_Group":",".join(B),
    "Env_Group":",".join(C),
    "Unique_Strom":unique_a,
    "Unique_Coral":unique_b,
    "Unique_Env":unique_c,
    "Shared":shared,
    "Residual":(1 - r2_abc) if full_ok else np.nan,
    "Total_R2_adj":r2_abc,
    "N":int(len(sub_vp))
}
var_out = pd.DataFrame([vp_row])
var_out.to_csv(OUTPUT_DIR / "results_variance_partition_improved.csv", index=False, encoding="utf-8-sig")
print("  -> Saved results_variance_partition_improved.csv")

# ==============================================================================
# 6b. Univariate adj-R2
# ==============================================================================
print("6b. Running Univariate adj-R2 (robust)...")
uni_rows = []

# One simple regression (target ~ predictor) per predictor. Each predictor yields
# exactly one row whose Status says whether the fit ran or why it was skipped.
for pred in stats_predictors:
    sub = get_analysis_data(df, pred, target)
    kind = _pred_kind(pred)
    n = len(sub)

    n_used = n          # row count reported in the output
    adj_r2 = np.nan

    if kind == "prop" and n >= MIN_N and (not _passes_prop_signal(sub, pred)):
        status = "SKIP_prop_low_signal"
    elif n < 6:
        status = "SKIP_too_few_rows"
    else:
        x = pd.to_numeric(sub[pred], errors="coerce")
        y = pd.to_numeric(sub[target], errors="coerce")
        both_finite = np.isfinite(x.values) & np.isfinite(y.values)
        x = x[both_finite]
        y = y[both_finite]
        n_used = len(x)

        if n_used < 6:
            status = "SKIP_too_few_rows"
        elif x.nunique(dropna=True) < 2 or float(np.nanstd(x.values)) == 0.0:
            status = "SKIP_constant_predictor"
        else:
            # Design matrix: intercept column followed by the predictor.
            X = np.column_stack((np.ones(n_used, dtype=float), x.values.astype(float)))
            yy = y.values.astype(float)
            try:
                fit = sm.OLS(yy, X).fit(method="qr")
                adj_r2 = float(fit.rsquared_adj)
                status = "OK"
            except Exception as e:
                adj_r2 = np.nan
                status = f"FAIL_{type(e).__name__}"

    uni_rows.append({"Predictor":pred,"N":int(n_used),"Adj_R2_univariate":adj_r2,"Status":status})

pd.DataFrame(uni_rows).to_csv(OUTPUT_DIR / "results_adjR2_univariate_all_predictors.csv", index=False, encoding="utf-8-sig")
print("  -> Saved results_adjR2_univariate_all_predictors.csv")

CELL13_PART_A_DONE = True
print("\n✓ CELL 13 PART A COMPLETE (v3 + LOWESS; strom+coral summary occ/div allow zeros once present).")


PERFORMING ADVANCED STATISTICAL ANALYSES (CENTRALIZED) — PART A (REWRITE v3 + LOWESS)
Iterations: Bootstrap=10000, Permutation=10000
LOWESS: True (frac=0.4, it=0)
Output dir: output
Loaded: MASTER_dataset_stage.csv (22 rows)
[INFO] Predictors included: ['derived_strom_prop', 'basal_strom_prop', 'derived_strom_occ', 'basal_strom_occ', 'derived_strom_div', 'basal_strom_div', 'rugose_occ', 'tabulate_occ', 'rugose_div', 'tabulate_div', 'carbonate_area_km2', 'temperature', 'dissolved_O2', 'd13C', 'atm_O2', 'atm_CO2', 'total_area_km2', 'carbonate_percentage', 'sea_level', 'Labechiida_prop', 'Clathrodictyida_prop', 'Actinostromatida_prop', 'Stromatoporida_prop', 'Stromatoporellida_prop', 'Syringostromatida_prop', 'Amphiporida_prop']

[LOWESS] Computing LOWESS smooths vs time (midpoint_ma)...
  LOWESS predictors: ['basal_strom_prop', 'derived_strom_prop', 'basal_strom_occ', 'derived_strom_occ', 'basal_strom_div', 'derived_strom_div']
  -> Saved results_lowess_predictors_vs_time_long.csv


Unnamed: 0,Predictor,midpoint_ma,lowess_y,N_used,frac,it
0,basal_strom_prop,365.55,0.27993,21,0.4,0
1,basal_strom_prop,377.45,0.20224,21,0.4,0
2,basal_strom_prop,385.2,0.154829,21,0.4,0
3,basal_strom_prop,390.5,0.122368,21,0.4,0
4,basal_strom_prop,400.45,0.113658,21,0.4,0
5,basal_strom_prop,409.2,0.136065,21,0.4,0
6,basal_strom_prop,415.0,0.16071,21,0.4,0
7,basal_strom_prop,421.1,0.321537,21,0.4,0
8,basal_strom_prop,424.3,0.416435,21,0.4,0
9,basal_strom_prop,426.5,0.482884,21,0.4,0



[INFO] Predictors used for correlation/bootstrap/permutation/LOO/detrend/partial/univariate: ['derived_strom_prop', 'derived_strom_occ', 'basal_strom_occ', 'derived_strom_div', 'basal_strom_div', 'rugose_occ', 'tabulate_occ', 'rugose_div', 'tabulate_div', 'carbonate_area_km2', 'temperature', 'dissolved_O2', 'd13C', 'atm_O2', 'atm_CO2', 'total_area_km2', 'carbonate_percentage', 'sea_level', 'Labechiida_prop', 'Clathrodictyida_prop', 'Actinostromatida_prop', 'Stromatoporida_prop', 'Stromatoporellida_prop', 'Syringostromatida_prop', 'Amphiporida_prop']

0. Computing original Spearman correlations...
  -> Saved results_spearman_original_all_predictors.csv
[INFO] Spearman status counts: {'too_few': 0, 'prop_low_signal': 0, 'ok': 25}

1. Running Bootstrap (10000 iter)...
  -> Saved results_bootstrap.csv
  -> Saved results_bootstrap_dist_all_predictors_long.csv
2. Running Permutation (10000 iter)...
  -> Saved results_permutation.csv
  -> Saved results_permutation_dist_all_predictors_long.csv

Unnamed: 0,Dataset,Target,Predictor,Test_Type,Controls,N,spearman_partial,spearman_p_partial,Status
0,Stage-Level Data,thickness_mean,derived_strom_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.816883,6e-06,OK
24,Stage-Level Data,thickness_mean,Amphiporida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.723377,0.000211,OK
1,Stage-Level Data,thickness_mean,derived_strom_occ,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.685714,0.000601,OK
22,Stage-Level Data,thickness_mean,Stromatoporellida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.637662,0.001873,OK
2,Stage-Level Data,thickness_mean,basal_strom_occ,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,-0.631169,0.002153,OK
18,Stage-Level Data,thickness_mean,Labechiida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,-0.628571,0.002274,OK
23,Stage-Level Data,thickness_mean,Syringostromatida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.618182,0.002819,OK
20,Stage-Level Data,thickness_mean,Actinostromatida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.563636,0.007793,OK
3,Stage-Level Data,thickness_mean,derived_strom_div,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.487013,0.025151,OK
4,Stage-Level Data,thickness_mean,basal_strom_div,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,-0.481818,0.026987,OK



[DISPLAY] Partial correlations (all rows incl. skips/fails):


Unnamed: 0,Dataset,Target,Predictor,Test_Type,Controls,N,spearman_partial,spearman_p_partial,Status
0,Stage-Level Data,thickness_mean,derived_strom_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.816883,6e-06,OK
24,Stage-Level Data,thickness_mean,Amphiporida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.723377,0.000211,OK
1,Stage-Level Data,thickness_mean,derived_strom_occ,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.685714,0.000601,OK
22,Stage-Level Data,thickness_mean,Stromatoporellida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.637662,0.001873,OK
2,Stage-Level Data,thickness_mean,basal_strom_occ,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,-0.631169,0.002153,OK
18,Stage-Level Data,thickness_mean,Labechiida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,-0.628571,0.002274,OK
23,Stage-Level Data,thickness_mean,Syringostromatida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.618182,0.002819,OK
20,Stage-Level Data,thickness_mean,Actinostromatida_prop,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.563636,0.007793,OK
3,Stage-Level Data,thickness_mean,derived_strom_div,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,0.487013,0.025151,OK
4,Stage-Level Data,thickness_mean,basal_strom_div,Biotic/Taxon (Env Controlled),"carbonate_area_km2,temperature,dissolved_O2",21,-0.481818,0.026987,OK


6. Running Variance Partitioning (groups; strict subset)...
  -> Saved results_variance_partition_improved.csv
6b. Running Univariate adj-R2 (robust)...
  -> Saved results_adjR2_univariate_all_predictors.csv

✓ CELL 13 PART A COMPLETE (v3 + LOWESS; strom+coral summary occ/div allow zeros once present).


In [18]:
# =============================================================================
# @title CELL 14: MODEL SELECTION + LOWESS + LAG (Robust guards) — NO BASAL
# =============================================================================
import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy import stats
from statsmodels.nonparametric.smoothers_lowess import lowess

print("="*80)
print("CELL 14: AICc + LOWESS + LAG (NO BASAL)")
print("="*80)

def _filter_no_strom(sub):
    if 'strom_total_occ' in sub.columns:
        sub = sub[sub['strom_total_occ'].notna() & (pd.to_numeric(sub['strom_total_occ'], errors='coerce') > 0)]
    return sub

# ---------------------------
# 7. AICc MODEL SELECTION (COMMON SUBSET; Method 1) — strom/coral/env only
# ---------------------------
print("7. Running AICc Model Selection (COMMON SUBSET; strom/coral/env only)...")

target = 'thickness_mean'

# requested groups (NO BASAL), each limited to columns present in df
strom_vars = [col for col in ('derived_strom_prop',) if col in df.columns]
coral_vars = [col for col in ('rugose_div', 'tabulate_div') if col in df.columns]
env_vars   = [col for col in ('carbonate_area_km2', 'temperature', 'dissolved_O2') if col in df.columns]

model_inputs = strom_vars + coral_vars + env_vars
needed = [target, *model_inputs]
if 'strom_total_occ' in df.columns:
    needed.append('strom_total_occ')

sub_aic = df[needed].copy()
for col in [target, *model_inputs]:
    sub_aic[col] = pd.to_numeric(sub_aic[col], errors='coerce')

# COMMON SUBSET: rows complete for the full model, so every candidate sees the same data
sub_aic = sub_aic.dropna(subset=[target, *model_inputs])
sub_aic = _filter_no_strom(sub_aic)

print(f"  Common-subset N = {len(sub_aic)} rows")

models = {
    'Strom+Coral': strom_vars + coral_vars,
    'Strom-only':  strom_vars,
    'Strom+Env':   strom_vars + env_vars,
    'Full':        model_inputs,
    'Null':        [],
    'Coral-only':  coral_vars,
    'Env-only':    env_vars,
    'Coral+Env':   coral_vars + env_vars,
}

aic_rows = []
n = len(sub_aic)

if n > 0:
    y = sub_aic[target].values.astype(float)

    for name, preds in models.items():
        # An empty predictor list is fitted as an intercept-only model.
        if preds:
            X = np.asarray(sm.add_constant(sub_aic[preds], has_constant='add'), dtype=float)
        else:
            X = np.ones((n, 1))

        k = X.shape[1]
        # The small-sample correction divides by (n - k - 1), which must be positive.
        if n <= k + 1:
            print(f"  [SKIP] {name}: too few rows for AICc (n={n}, k={k})")
            continue

        try:
            fit = sm.OLS(y, X).fit(method='qr')
            aic = float(fit.aic)
            correction = (2 * k * (k + 1)) / (n - k - 1)
            aic_rows.append({
                'Model': name,
                'Predictors': str(preds),
                'R2': float(fit.rsquared),
                'Adj_R2': float(fit.rsquared_adj),
                'AIC': aic,
                'AICc': float(aic + correction),
                'N': int(n),
                'K': int(k)
            })
        except Exception as e:
            print(f"  [FAIL] {name}: fit failed -> {e}")
else:
    print("  [SKIP] AICc: common-subset is empty.")

aic_df = pd.DataFrame(aic_rows)
if not aic_df.empty:
    aic_df = aic_df.sort_values('AICc')
    aic_df['Delta_AICc'] = aic_df['AICc'] - aic_df['AICc'].min()
    # Akaike weights: relative likelihoods exp(-delta/2), normalised to sum to 1.
    rel_likelihood = np.exp(-0.5 * aic_df['Delta_AICc'])
    aic_df['Weight'] = rel_likelihood / rel_likelihood.sum()
    aic_df.to_csv(OUTPUT_DIR / 'results_aic_improved.csv', index=False)
    print("  -> Saved results_aic_improved.csv")
    display(aic_df)
else:
    print("  [SKIP] AICc: no models could be fit.")

# ---------------------------
# 8. LOWESS (keep: derived only now that basal is removed)
# ---------------------------
print("8. LOWESS fits (derived only; 1000 boot for plotting stability/speed)...")

rng = np.random.default_rng(0)

def get_lowess_ci(x, y, frac=0.6, n_boot=1000, random_state=None):
    """LOWESS fit of `y` on `x` with a percentile-bootstrap 95% band.

    The smooth is fitted once on the observed data; the (x, y) pairs are then
    resampled with replacement `n_boot` times, each resample is smoothed again and
    interpolated onto the x grid of the original fit. The band is the 2.5th/97.5th
    percentile of those curves at each grid point.

    Parameters
    ----------
    x, y : array-like (e.g. pandas Series) of equal length.
    frac : LOWESS span passed through to `lowess`.
    n_boot : number of bootstrap resamples.
    random_state : optional numpy Generator used for resampling. When None, the
        module-level `rng` is used, as before.

    Returns
    -------
    DataFrame with columns ``x``, ``y_fit``, ``ci_low``, ``ci_high``. The band is
    NaN everywhere if no resample could be smoothed.
    """
    gen = rng if random_state is None else random_state
    x_arr = np.asarray(x, dtype=float)
    y_arr = np.asarray(y, dtype=float)

    order = np.argsort(x_arr)
    z = lowess(y_arr[order], x_arr[order], frac=frac)
    x_grid = z[:, 0]
    y_fit = z[:, 1]

    boot_curves = []
    n_failed = 0
    idx_all = np.arange(len(x_arr))
    for _ in range(n_boot):
        idx = gen.choice(idx_all, size=len(idx_all), replace=True)
        x_s = x_arr[idx]
        y_s = y_arr[idx]
        s_idx = np.argsort(x_s)
        try:
            z_b = lowess(y_s[s_idx], x_s[s_idx], frac=frac)
            boot_curves.append(np.interp(x_grid, z_b[:, 0], z_b[:, 1]))
        except Exception:
            # Best effort: a resample that cannot be smoothed is skipped and counted.
            # Catching Exception (not a bare except) keeps Ctrl-C able to stop the loop.
            n_failed += 1

    if n_failed:
        print(f"  [WARN] get_lowess_ci: {n_failed}/{n_boot} bootstrap resamples could not be smoothed")

    if boot_curves:
        stacked = np.vstack(boot_curves)
        ci_low = np.nanpercentile(stacked, 2.5, axis=0)
        ci_high = np.nanpercentile(stacked, 97.5, axis=0)
    else:
        ci_low = np.full(len(x_grid), np.nan)
        ci_high = np.full(len(x_grid), np.nan)

    return pd.DataFrame({'x': x_grid, 'y_fit': y_fit, 'ci_low': ci_low, 'ci_high': ci_high})

# Smooth thickness against the derived-stromatoporoid proportion, using only rows
# that have both values and pass the stromatoporoid-presence filter.
if 'derived_strom_prop' in df.columns:
    complete_rows = df.dropna(subset=['derived_strom_prop', target]).copy()
    v = _filter_no_strom(complete_rows)
    low = get_lowess_ci(v['derived_strom_prop'], v[target])
    low.to_csv(OUTPUT_DIR / 'results_lowess_derived.csv', index=False)
    print("  -> Saved results_lowess_derived.csv")

# =============================================================================
# LAG ANALYSIS (Max-|t| & |Cohen’s d|) — robust + saves outputs
#   Outputs:
#     - output/results_lag_summary_all_predictors.csv
#     - output/results_lag_profile_all_predictors_long.csv
#     - output/results_lag_profile_thickness.csv
#   Also defines: peak_thick, peak_deriv (for segmented regression cell)
# =============================================================================
import numpy as np
import pandas as pd
from scipy import stats
from pathlib import Path

# NOTE(review): OUTPUT_DIR is assigned here but is already used earlier in this
# same cell (AICc / LOWESS saves), so those lines rely on a value left by a
# previous cell - confirm that is intended.
DATA_DIR = Path("./output")
OUTPUT_DIR = DATA_DIR

# -------------------------
# Ensure df exists (reload the stage-level master table when it is missing,
# None, not a DataFrame, or empty)
# -------------------------
if not isinstance(globals().get("df"), pd.DataFrame) or df.empty:
    df = pd.read_csv(DATA_DIR / "MASTER_dataset_stage.csv", encoding="utf-8-sig")
    print(f"[INFO] Loaded df from MASTER_dataset_stage.csv: {len(df)} rows")

# -------------------------
# Config
# -------------------------
target = "thickness_mean"
AGE_MIN, AGE_MAX = 358.9, 485.4   # Ord–Dev window used previously
MIN_BEFORE, MIN_AFTER = 3, 3

# -------------------------
# Build predictor list (use all_predictors if available; else reconstruct)
# -------------------------
if isinstance(globals().get("all_predictors"), (list, tuple)) and len(all_predictors) > 0:
    candidate_predictors = all_predictors
else:
    candidate_predictors = [
        "derived_strom_prop", "basal_strom_prop",
        "rugose_div", "tabulate_div",
        "carbonate_area_km2", "temperature", "dissolved_O2",
        "log_derived_basal_ratio",
        "d13C", "atm_O2", "atm_CO2", "pO2", "pCO2",
        "Labechiida_prop", "Clathrodictyida_prop", "Actinostromatida_prop",
        "Stromatoporida_prop", "Stromatoporellida_prop", "Syringostromatida_prop",
        "Amphiporida_prop",
    ]
predictors = [p for p in candidate_predictors if p in df.columns]

# Ensure required columns exist
needed_cols = ["stage", "midpoint_ma", "strom_total_occ", target]
missing = [c for c in needed_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing required columns for lag analysis: {missing}")

# -------------------------
# Filter Ord–Dev + strom presence + required numeric
# -------------------------
work = df.copy()
for numeric_col in ("midpoint_ma", "strom_total_occ", target):
    work[numeric_col] = pd.to_numeric(work[numeric_col], errors="coerce")

work = work.dropna(subset=["stage", "midpoint_ma", "strom_total_occ", target]).copy()
in_window = work["midpoint_ma"].between(AGE_MIN, AGE_MAX)   # inclusive at both ends
has_strom = work["strom_total_occ"] > 0
work = work[in_window & has_strom].copy()

# Sort oldest->youngest (descending Ma)
work = work.sort_values("midpoint_ma", ascending=False).reset_index(drop=True)

print(f"[INFO] Lag dataset after filters: n={len(work)} (Ord–Dev, strom_total_occ>0)")

def _cohen_d(before, after):
    before = before[np.isfinite(before)]
    after  = after[np.isfinite(after)]
    if len(before) < 2 or len(after) < 2:
        return np.nan
    pooled_std = np.sqrt(((len(before)-1)*np.var(before, ddof=1) + (len(after)-1)*np.var(after, ddof=1)) /
                         (len(before) + len(after) - 2))
    if not np.isfinite(pooled_std) or pooled_std <= 0:
        return np.nan
    return (np.mean(after) - np.mean(before)) / pooled_std

def lag_profile_for_var(df_sorted, var, min_before=3, min_after=3):
    """Return break profile rows for a single variable.

    Every positional split ``i`` with at least ``min_before`` rows before it and
    ``min_after`` rows from it onward is evaluated: the finite values of ``var``
    in rows ``[:i]`` are compared with those in rows ``[i:]`` using a
    pooled-variance two-sample t-test and Cohen's d.

    Parameters
    ----------
    df_sorted : DataFrame with columns ``var``, ``"midpoint_ma"`` and ``"stage"``.
        Rows are split by position, so the caller must already have put them in
        the intended time order.
    var : name of the column to profile (coerced to numeric; failures become NaN).
    min_before, min_after : minimum number of rows (not finite values) on each
        side of a candidate split.

    Returns
    -------
    DataFrame with one row per candidate split. ``stage``/``age`` are taken from
    row ``i``, i.e. the first row of the "after" segment. ``n_before``/``n_after``
    are row counts on each side; ``n_before_valid``/``n_after_valid`` are the
    numbers of finite values that actually entered the test. ``t_abs`` and
    ``d_abs`` are 0.0 (not NaN) when the statistic is undefined so that
    ``idxmax`` on them never selects an untestable split. The frame is empty
    when there is no admissible split.
    """
    y = pd.to_numeric(df_sorted[var], errors="coerce").values.astype(float)
    ages = pd.to_numeric(df_sorted["midpoint_ma"], errors="coerce").values.astype(float)
    stages = df_sorted["stage"].astype(str).values

    # Computed once; each split only slices this mask
    finite = np.isfinite(y)

    rows = []
    n_all = len(df_sorted)
    for i in range(min_before, n_all - min_after):
        before = y[:i][finite[:i]]
        after = y[i:][finite[i:]]
        n_before_valid = len(before)
        n_after_valid = len(after)

        if n_before_valid >= 2 and n_after_valid >= 2:
            # Inputs are already NaN-free, so no nan_policy is needed
            t, p = stats.ttest_ind(before, after, equal_var=True)
            d = _cohen_d(before, after)
        else:
            t, p, d = np.nan, np.nan, np.nan

        rows.append({
            "Variable": var,
            "stage": stages[i],
            "age": float(ages[i]) if np.isfinite(ages[i]) else np.nan,
            "n_before": int(i),
            "n_after": int(n_all - i),
            "t_abs": float(np.abs(t)) if np.isfinite(t) else 0.0,
            "p": float(p) if np.isfinite(p) else np.nan,
            "d_abs": float(np.abs(d)) if np.isfinite(d) else 0.0,
            # Sample sizes actually used by the test (row counts above include NaNs)
            "n_before_valid": int(n_before_valid),
            "n_after_valid": int(n_after_valid),
        })
    return pd.DataFrame(rows)

# -------------------------
# Compute thickness profile + peak
# -------------------------
prof_thick = lag_profile_for_var(work, target, MIN_BEFORE, MIN_AFTER)
if prof_thick.empty:
    raise RuntimeError("[ERROR] Thickness lag profile is empty (too few rows?)")

# The changepoint is the candidate split with the largest |t|
peak_thick = prof_thick.loc[prof_thick["t_abs"].idxmax()].to_dict()

# -------------------------
# Compute predictor profiles + peaks
# -------------------------
profiles = [prof_thick]
summary_rows = []

peak_deriv = None  # composition changepoint

for pred in predictors:
    # skip if fully missing in this filtered dataset
    if pred not in work.columns:
        continue

    prof = lag_profile_for_var(work, pred, MIN_BEFORE, MIN_AFTER)
    if prof.empty:
        continue

    profiles.append(prof)

    peak = prof.loc[prof["t_abs"].idxmax()].to_dict()
    # Positive lag: the predictor's peak is at a larger Ma value than the thickness peak
    lag_vs_thick = peak["age"] - peak_thick["age"] if np.isfinite(peak.get("age", np.nan)) else np.nan

    summary_rows.append({
        "Predictor": pred,
        "Peak_Age": peak["age"],
        "Peak_Stage": peak["stage"],
        "Max_t_abs": peak["t_abs"],
        "Max_d_abs": peak["d_abs"],
        "p_at_peak": peak["p"],
        "Lag_vs_Thickness_Myr": lag_vs_thick,
        "Thickness_Peak_Age": peak_thick["age"],
        "Thickness_Peak_Stage": peak_thick["stage"],
    })

    if pred == "derived_strom_prop":
        peak_deriv = peak

# fallback for peak_deriv (so downstream cells won't break)
if peak_deriv is None:
    peak_deriv = {"age": np.nan, "stage": None, "t_abs": np.nan, "d_abs": np.nan}

# Explicit columns keep the frame well-formed (and sortable) even when no
# predictor produced a profile; otherwise sort_values would raise KeyError.
summary_columns = [
    "Predictor", "Peak_Age", "Peak_Stage", "Max_t_abs", "Max_d_abs", "p_at_peak",
    "Lag_vs_Thickness_Myr", "Thickness_Peak_Age", "Thickness_Peak_Stage",
]
lag_summary_df = (
    pd.DataFrame(summary_rows, columns=summary_columns)
    .sort_values("Max_t_abs", ascending=False)
)
lag_profiles_long = pd.concat(profiles, ignore_index=True)

# -------------------------
# Save
# -------------------------
from pathlib import Path

# OUTPUT_DIR may be a plain string (as set in Cell 1) or a Path; normalise locally
# so .mkdir() and the "/" operator work in both cases without rebinding the global.
lag_out_dir = Path(OUTPUT_DIR)
lag_out_dir.mkdir(parents=True, exist_ok=True)

lag_summary_df.to_csv(lag_out_dir / "results_lag_summary_all_predictors.csv", index=False, encoding="utf-8-sig")
lag_profiles_long.to_csv(lag_out_dir / "results_lag_profile_all_predictors_long.csv", index=False, encoding="utf-8-sig")
prof_thick.to_csv(lag_out_dir / "results_lag_profile_thickness.csv", index=False, encoding="utf-8-sig")

print("\n[OK] Saved lag outputs:")
print("  - output/results_lag_summary_all_predictors.csv")
print("  - output/results_lag_profile_all_predictors_long.csv")
print("  - output/results_lag_profile_thickness.csv")

print(f"\nChangepoint (thickness):   {peak_thick['age']:.2f} Ma ({peak_thick['stage']})")
if np.isfinite(peak_deriv.get("age", np.nan)):
    print(f"Changepoint (composition): {peak_deriv['age']:.2f} Ma ({peak_deriv['stage']})")
else:
    print("Changepoint (composition): NA (derived_strom_prop not available/insufficient)")

display(lag_summary_df.head(25))


print("\n✓ CELL 14 COMPLETE (NO BASAL).")


CELL 14: AICc + LOWESS + LAG (NO BASAL)
7. Running AICc Model Selection (COMMON SUBSET; strom/coral/env only)...
  Common-subset N = 21 rows
  -> Saved results_aic_improved.csv


Unnamed: 0,Model,Predictors,R2,Adj_R2,AIC,AICc,N,K,Delta_AICc,Weight
0,Strom+Coral,"['derived_strom_prop', 'rugose_div', 'tabulate...",0.80398,0.7693882,-16.139037,-13.639037,21,4,0.0,0.6143463
1,Strom-only,['derived_strom_prop'],0.7282392,0.713936,-13.278221,-12.611555,21,2,1.027482,0.3675347
2,Strom+Env,"['derived_strom_prop', 'carbonate_area_km2', '...",0.758271,0.6978387,-9.737422,-5.737422,21,5,7.901615,0.01181951
3,Full,"['derived_strom_prop', 'rugose_div', 'tabulate...",0.8297055,0.7567222,-13.093479,-4.478095,21,7,9.160942,0.00629709
4,Null,[],1.110223e-16,1.110223e-16,12.081272,12.291798,21,1,25.930835,1.437487e-06
5,Coral-only,"['rugose_div', 'tabulate_div']",0.1657233,0.07302591,12.276278,13.688043,21,3,27.32708,7.151765e-07
6,Env-only,"['carbonate_area_km2', 'temperature', 'dissolv...",0.1971273,0.05544383,13.470532,15.970532,21,4,29.609569,2.284426e-07
7,Coral+Env,"['rugose_div', 'tabulate_div', 'carbonate_area...",0.3493958,0.1325277,13.054342,19.054342,21,6,32.693379,4.888057e-08


8. LOWESS fits (derived only; 1000 boot for plotting stability/speed)...
  -> Saved results_lowess_derived.csv
[INFO] Lag dataset after filters: n=21 (Ord–Dev, strom_total_occ>0)

[OK] Saved lag outputs:
  - output/results_lag_summary_all_predictors.csv
  - output/results_lag_profile_all_predictors_long.csv
  - output/results_lag_profile_thickness.csv

Changepoint (thickness):   426.50 Ma (Gorstian)
Changepoint (composition): 431.95 Ma (Sheinwoodian)


Unnamed: 0,Predictor,Peak_Age,Peak_Stage,Max_t_abs,Max_d_abs,p_at_peak,Lag_vs_Thickness_Myr,Thickness_Peak_Age,Thickness_Peak_Stage
19,Labechiida_prop,444.5,Hirnantian,17.523005,8.977868,3.471399e-13,18.0,426.5,Gorstian
1,basal_strom_prop,431.95,Sheinwoodian,9.514933,4.195691,1.165251e-08,5.45,426.5,Gorstian
0,derived_strom_prop,431.95,Sheinwoodian,9.514933,4.195691,1.165251e-08,5.45,426.5,Gorstian
23,Stromatoporellida_prop,400.45,Emsian,7.150752,3.66367,8.510268e-07,-26.05,426.5,Gorstian
15,atm_CO2,390.5,Eifelian,6.970991,3.873912,1.215429e-06,-36.0,426.5,Gorstian
11,temperature,455.7,Sandbian,5.758257,3.590904,1.505465e-05,29.2,426.5,Gorstian
21,Actinostromatida_prop,435.95,Telychian,5.623248,2.526855,2.015312e-05,9.45,426.5,Gorstian
22,Stromatoporida_prop,431.95,Sheinwoodian,5.200166,2.293058,5.094031e-05,5.45,426.5,Gorstian
25,Amphiporida_prop,424.3,Ludfordian,5.149704,2.270806,5.696891e-05,-2.2,426.5,Gorstian
4,derived_strom_div,428.95,Homerian,4.67787,2.043909,0.0001639099,2.45,426.5,Gorstian



✓ CELL 14 COMPLETE (NO BASAL).


In [19]:
# =============================================================================
# @title CELL 15: TIME-SERIES-ROBUST TESTS (Segmented regression + sampling control)
# =============================================================================
import numpy as np
import pandas as pd
import scipy.stats as stats
import statsmodels.api as sm
from statsmodels.regression.linear_model import OLS, GLSAR
from statsmodels.stats.sandwich_covariance import cov_hac

def _as_1d_numeric(v):
    """Force Series/DataFrame column to 1-D numeric float array."""
    if isinstance(v, pd.DataFrame):
        v = v.iloc[:, 0]
    v = pd.to_numeric(v, errors='coerce')
    return v.values.astype(float)

def _hac_term_rows(fit, model_label, target, changepoint, col_names, n, maxlags):
    """Build one result row per coefficient of `fit` using HAC standard errors."""
    hac_se = np.sqrt(np.diag(cov_hac(fit, nlags=maxlags)))
    hac_t = fit.params / hac_se
    # sf() is the upper-tail probability; it avoids the precision loss of
    # 1 - cdf() when p-values are very small.
    hac_p = 2 * stats.t.sf(np.abs(hac_t), df=n - len(col_names))
    return [
        {'Model': model_label, 'Target': target, 'Changepoint_Ma': changepoint,
         'Term': name, 'Coef': fit.params[i], 'SE': hac_se[i], 'P': hac_p[i],
         'N': n, 'R2': fit.rsquared}
        for i, name in enumerate(col_names)
    ]


def run_segmented_regression(df_in, target, changepoint):
    """Segmented (ITS) regression with OLS+HAC, WLS+HAC, and GLSAR

    The design matrix has six terms: intercept, time (midpoint_ma), a step
    indicator for rows at or younger than `changepoint` (midpoint_ma <=
    changepoint), a post-changepoint slope (changepoint - midpoint_ma, floored
    at 0), derived_strom_prop, and log1p(strom_total_occ) as a sampling control.

    Parameters
    ----------
    df_in : DataFrame containing 'midpoint_ma', `target`, 'derived_strom_prop',
        'reef_count' and 'strom_total_occ'.
    target : name of the response column.
    changepoint : break age, on the same scale as 'midpoint_ma'.

    Returns
    -------
    list of dict, one per (model, term) with keys Model, Target, Changepoint_Ma,
    Term, Coef, SE, P, N, R2. An empty list is returned (after printing a
    warning) when a required column is missing or fewer than 10 usable rows
    remain. A model that raises is reported via print and skipped, so the list
    may contain only a subset of the three models.
    """
    # 'reef_count' is required to be present but is not used in the model below.
    needed = ['midpoint_ma', target, 'derived_strom_prop', 'reef_count', 'strom_total_occ']
    for c in needed:
        if c not in df_in.columns:
            print(f"[WARN] Missing column: {c}")
            return []

    v = df_in[needed].copy()

    # Force numeric BEFORE any comparison: comparing an object/string column
    # with 0 would raise TypeError instead of treating bad entries as missing.
    numeric_cols = ['midpoint_ma', target, 'derived_strom_prop', 'strom_total_occ']
    for c in numeric_cols:
        v[c] = pd.to_numeric(v[c], errors='coerce')
    v = v.dropna(subset=numeric_cols)

    # Exclude no-strom rows
    v = v[v['strom_total_occ'] > 0]
    v = v.sort_values('midpoint_ma', ascending=False).reset_index(drop=True)

    if len(v) < 10:
        print("[WARN] Too few observations after filtering:", len(v))
        return []

    t = _as_1d_numeric(v['midpoint_ma'])
    y = _as_1d_numeric(v[target])
    occ = _as_1d_numeric(v['strom_total_occ'])

    I_post = (t <= changepoint).astype(float)
    t_post = np.maximum(0, changepoint - t)

    x_derived = _as_1d_numeric(v['derived_strom_prop'])
    s_intensity = np.log1p(occ)

    X = np.column_stack([np.ones(len(t)), t, I_post, t_post, x_derived, s_intensity])
    col_names = ['const', 'time', 'step', 'post_slope', 'derived_prop', 'sampling']

    results = []
    n = len(y)
    maxlags = max(1, int(np.floor(4 * (n/100)**(2/9))))  # Newey–West

    # OLS + HAC
    try:
        m = OLS(y, X).fit()
        results.extend(_hac_term_rows(m, 'OLS+HAC', target, changepoint, col_names, n, maxlags))
    except Exception as e:
        print("OLS+HAC failed:", e)

    # WLS + HAC
    try:
        # NOTE(review): these weights shrink as occurrence counts grow, i.e.
        # better-sampled rows get LESS weight; WLS weights are usually
        # proportional to inverse variance. Also uses log(occ + 2) while the
        # 'sampling' regressor uses log(occ + 1). Confirm both are intended.
        weights = 1.0 / np.log1p(occ + 1.0)
        m = sm.WLS(y, X, weights=weights).fit()
        results.extend(_hac_term_rows(m, 'WLS+HAC', target, changepoint, col_names, n, maxlags))
    except Exception as e:
        print("WLS+HAC failed:", e)

    # GLSAR(AR1)
    try:
        ar_m = GLSAR(y, X, rho=1)  # integer rho is the AR order
        ar_fit = ar_m.iterative_fit(maxiter=20)
        for i, name in enumerate(col_names):
            results.append({'Model': 'GLSAR', 'Target': target, 'Changepoint_Ma': changepoint,
                            'Term': name, 'Coef': ar_fit.params[i], 'SE': ar_fit.bse[i],
                            'P': ar_fit.pvalues[i], 'N': n, 'R2': ar_fit.rsquared})
    except Exception as e:
        print("GLSAR failed:", e)

    return results

# -------------------------
# RUN (changepoints are read from the globals peak_thick / peak_deriv when
# they exist and carry a finite 'age'; otherwise hard-coded fallbacks apply)
# -------------------------
_cp_resolved = {}
for _peak_name, _cp_fallback in (('peak_thick', 426.5), ('peak_deriv', 431.9)):
    _peak_obj = globals().get(_peak_name)
    _cp_age = np.nan
    # Works for dicts and any other subscriptable container exposing 'age'
    if hasattr(_peak_obj, '__getitem__') and 'age' in _peak_obj:
        _cp_age = float(_peak_obj['age'])
    # fallbacks if missing
    _cp_resolved[_peak_name] = _cp_age if np.isfinite(_cp_age) else _cp_fallback

CP_thick = _cp_resolved['peak_thick']
CP_comp = _cp_resolved['peak_deriv']

print(f"Running segmented regression at thickness CP = {CP_thick:.1f} Ma...")
seg_thick = pd.DataFrame(run_segmented_regression(df, 'thickness_mean', changepoint=CP_thick))

print(f"Running segmented regression at composition CP = {CP_comp:.1f} Ma...")
seg_comp  = pd.DataFrame(run_segmented_regression(df, 'thickness_mean', changepoint=CP_comp))

# Stack both runs, tagging each row with the changepoint it was fitted at
seg_df = pd.concat(
    [
        seg_thick.assign(CP_type='Thickness'),
        seg_comp.assign(CP_type='Composition'),
    ],
    ignore_index=True,
)

_seg_has_rows = not seg_df.empty
if _seg_has_rows and 'OUTPUT_DIR' in globals():
    seg_df.to_csv(f'{OUTPUT_DIR}/results_segmented_regression.csv', index=False, encoding='utf-8-sig')
    print(f"Saved: {OUTPUT_DIR}/results_segmented_regression.csv")

display(seg_df if _seg_has_rows else pd.DataFrame())



Running segmented regression at thickness CP = 426.5 Ma...
Running segmented regression at composition CP = 431.9 Ma...
Saved: output/results_segmented_regression.csv


Unnamed: 0,Model,Target,Changepoint_Ma,Term,Coef,SE,P,N,R2,CP_type
0,OLS+HAC,thickness_mean,426.5,const,4.437553,1.963271,0.039107,21,0.800573,Thickness
1,OLS+HAC,thickness_mean,426.5,time,-0.006305,0.00404,0.139493,21,0.800573,Thickness
2,OLS+HAC,thickness_mean,426.5,step,0.267077,0.103465,0.02086,21,0.800573,Thickness
3,OLS+HAC,thickness_mean,426.5,post_slope,-0.007805,0.003683,0.051189,21,0.800573,Thickness
4,OLS+HAC,thickness_mean,426.5,derived_prop,0.226117,0.238352,0.357823,21,0.800573,Thickness
5,OLS+HAC,thickness_mean,426.5,sampling,-0.011693,0.032358,0.722855,21,0.800573,Thickness
6,WLS+HAC,thickness_mean,426.5,const,4.438575,1.605677,0.014464,21,0.891673,Thickness
7,WLS+HAC,thickness_mean,426.5,time,-0.006293,0.003329,0.078221,21,0.891673,Thickness
8,WLS+HAC,thickness_mean,426.5,step,0.25272,0.111388,0.038469,21,0.891673,Thickness
9,WLS+HAC,thickness_mean,426.5,post_slope,-0.00894,0.003298,0.016109,21,0.891673,Thickness


In [20]:
# ============================================================
#@title CELL 16: ZIP everything in ./output and download the zip (Colab-safe)
# ============================================================
from pathlib import Path
import zipfile, datetime, os

# Rebinds OUTPUT_DIR as a Path for this cell; the directory must already exist
OUTPUT_DIR = Path("./output")
if not OUTPUT_DIR.exists():
    raise FileNotFoundError(f"OUTPUT_DIR not found: {OUTPUT_DIR.resolve()}")

# Timestamped archive name, created in the current working directory
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
zip_path = Path(f"output_{ts}.zip")

# Every regular file below OUTPUT_DIR, at any depth
files_to_zip = [entry for entry in OUTPUT_DIR.rglob("*") if entry.is_file()]

if files_to_zip:
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for fp in files_to_zip:
            # Archive names are relative to OUTPUT_DIR, so the zip has no "output/" prefix
            zf.write(fp, arcname=fp.relative_to(OUTPUT_DIR))
    print(f"[OK] Created zip: {zip_path.resolve()}  ({len(files_to_zip)} files)")

    # Try a Colab browser download first, then an IPython file link, then just print the path
    try:
        from google.colab import files
        files.download(str(zip_path))
    except Exception:
        try:
            from IPython.display import FileLink, display
            display(FileLink(str(zip_path)))
        except Exception:
            print(f"Download not auto-supported here. Zip is at: {zip_path.resolve()}")
else:
    print(f"[WARN] No files found under: {OUTPUT_DIR.resolve()}")


[OK] Created zip: /content/output_20260118_071048.zip  (31 files)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>