In [None]:
#!/usr/bin/env python3
"""
saber_analysis.py

Comprehensive analysis toolkit for SABER datasets:
- saber_1d: ['year_month', 'ktemp']
- saber_2d: ['year_month', 'altitude', 'ktemp']
- saber_4d: ['year_month', 'altitude', 'latitude', 'longitude', 'ktemp']

Features:
- Preprocessing (datetime conversion + sorting)
- Basic statistics (mean, median, std)
- Time-series plotting with rolling mean and linear trend
- 2D: time x altitude heatmap, altitude-slice time series
- 4D: time x altitude heatmap, lat-lon maps for time slices, 3D scatter

Usage:
- Import functions from this file and call `analyze_all(saber_1d, saber_2d, saber_4d)`

Author: ChatGPT (for user)
"""

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import LinearRegression
import warnings

# NOTE(review): with no category or module argument this installs an "ignore"
# filter for every warning, and it runs as a side effect of importing this file.
warnings.filterwarnings("ignore")

In [None]:
# ----------------------
# Preprocessing helpers
# ----------------------

def preprocess_time(df, time_col="year_month"):
    """Return a copy of ``df`` with ``time_col`` as datetime, sorted by time.

    Parameters
    ----------
    df : pd.DataFrame
        Input table; it is copied first and never modified.
    time_col : str
        Name of the timestamp column. Values that are not already datetimes
        are parsed with ``pd.to_datetime``.

    Returns
    -------
    pd.DataFrame
        Copy of ``df`` ordered by ``time_col`` with a fresh 0..n-1 index.

    Raises
    ------
    KeyError
        If ``time_col`` is not a column of ``df``.
    """
    df = df.copy()
    # is_datetime64_any_dtype also recognises timezone-aware columns; handing
    # such pandas extension dtypes to np.issubdtype raises TypeError.
    if not pd.api.types.is_datetime64_any_dtype(df[time_col]):
        df[time_col] = pd.to_datetime(df[time_col])
    # A stable sort keeps the input order of rows that share a timestamp.
    return df.sort_values(time_col, kind="mergesort").reset_index(drop=True)


# ----------------------
# Basic statistics
# ----------------------

def compute_basic_stats(df, kcol="ktemp"):
    """Summarise column ``kcol`` as a one-row frame.

    The returned frame has index ``[0]`` and the columns ``mean_ktemp``,
    ``median_ktemp`` and ``std_ktemp`` (population standard deviation,
    ``ddof=0``), whatever the name of ``kcol``.
    """
    target = df[kcol]
    reducers = (
        ("mean_ktemp", target.mean),
        ("median_ktemp", target.median),
        ("std_ktemp", lambda: target.std(ddof=0)),
    )
    summary = {label: float(reduce()) for label, reduce in reducers}
    return pd.DataFrame(summary, index=[0])


# ----------------------
# Time-series & trend
# ----------------------

def fit_linear_trend(series_time, series_y):
    """Regress ``series_y`` on elapsed seconds and return the fitted line.

    The regressor is the number of seconds between each timestamp and the
    earliest timestamp in ``series_time``. Missing values in ``series_y`` are
    replaced by the series mean before fitting.

    Returns
    -------
    tuple
        ``(trend, model)``: the flattened array of fitted values (one per
        input row) and the fitted ``LinearRegression`` instance.
    """
    elapsed = series_time - series_time.min()
    seconds = elapsed.dt.total_seconds().values.reshape(-1, 1)
    # Mean-impute gaps so the regression receives a complete target column.
    target = series_y.fillna(series_y.mean()).values.reshape(-1, 1)
    model = LinearRegression().fit(seconds, target)
    return model.predict(seconds).ravel(), model


def plot_time_series_with_trend(df, time_col="year_month", kcol="ktemp", title=None, window=12, ax=None):
    """Plot the raw series, its centred rolling mean and a linear trend.

    Parameters
    ----------
    df : pd.DataFrame
        Table holding ``time_col`` and ``kcol``; it is copied and sorted by
        time before plotting.
    time_col : str
        Name of the time column.
    kcol : str
        Name of the value column to plot.
    title : str or None
        Axes title; no title is set when falsy.
    window : int
        Rolling window in periods (rows), e.g. months for monthly data.
    ax : matplotlib.axes.Axes or None
        Axes to draw on. A new 10 x 4.5 figure is created when None.

    Returns
    -------
    tuple
        ``(fig, extras)`` where ``fig`` is the figure owning the axes and
        ``extras`` is a dict with keys ``"rolling"`` (rolling-mean Series),
        ``"trend"`` (array of fitted values) and ``"model"`` (fitted
        regression object).
    """
    df = preprocess_time(df, time_col)
    t = df[time_col]
    y = df[kcol]

    # min_periods=1 keeps the rolling mean defined at both ends of the series.
    rolling = y.rolling(window=window, min_periods=1, center=True).mean()
    trend, model = fit_linear_trend(t, y)

    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 4.5))
    else:
        fig = ax.figure

    ax.plot(t, y, label="ktemp (raw)", linewidth=1)
    ax.plot(t, rolling, label=f"Rolling mean (window={window})", linewidth=1.5)
    ax.plot(t, trend, label="Linear trend", linestyle="--", linewidth=1.5)
    ax.set_xlabel("Time")
    ax.set_ylabel("ktemp")
    if title:
        ax.set_title(title)
    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    return fig, {"rolling": rolling, "trend": trend, "model": model}


# ----------------------
# Aggregation helpers
# ----------------------

def aggregate_time_mean(df, time_col="year_month", kcol="ktemp", agg_dims=None):
    """Return the mean of ``kcol`` per time step.

    Parameters
    ----------
    df : pd.DataFrame
        Table holding ``time_col``, ``kcol`` and any columns in ``agg_dims``.
    time_col : str
        Name of the time column.
    kcol : str
        Name of the value column.
    agg_dims : str, iterable of str, or None
        Extra grouping columns. When given, ``kcol`` is first averaged per
        (time, *agg_dims) cell and those cell means are then averaged per
        time, so every cell carries equal weight. A single column name may
        be passed as a plain string.

    Returns
    -------
    pd.DataFrame
        Columns ``time_col`` and ``"mean_ktemp"``, one row per time step.
    """
    df = preprocess_time(df, time_col)

    # Normalise agg_dims to a list so tuples, arrays and a lone column name
    # work as well as lists (list + tuple / list + str would raise TypeError).
    if agg_dims is None:
        dims = []
    elif isinstance(agg_dims, str):
        dims = [agg_dims]
    else:
        dims = list(agg_dims)

    source = df
    if dims:
        source = df.groupby([time_col, *dims])[kcol].mean().reset_index()
    return source.groupby(time_col)[kcol].mean().reset_index(name="mean_ktemp")


# ----------------------
# 2D analysis (time x altitude)
# ----------------------

def analyze_2d(saber_2d, time_col="year_month", alt_col="altitude", kcol="ktemp", rolling_window=12, save_figs=False, prefix="2d"):
    """Run the time x altitude analyses and plots for ``saber_2d``.

    Parameters
    ----------
    saber_2d : pd.DataFrame
        Table with the columns named by ``time_col``, ``alt_col`` and ``kcol``.
    time_col, alt_col, kcol : str
        Names of the time, altitude and value columns.
    rolling_window : int
        Rolling-mean window (in periods) used by the time-series plots.
    save_figs : bool
        When True every figure is also written to disk as a PNG at 200 dpi.
    prefix : str
        Filename prefix for the saved figures.

    Returns
    -------
    dict
        ``"stats"``      one-row frame of mean/median/std of ``kcol``;
        ``"time_mean"``  per-time mean of ``kcol`` (column ``mean_ktemp``);
        ``"fig_time"``   time-series figure of that mean;
        ``"pivot"``      time x altitude table of mean ``kcol``;
        ``"fig_heat"``   heatmap figure of the pivot;
        ``"slice_figs"`` dict mapping sampled altitude -> time-series figure.
    """
    df = preprocess_time(saber_2d, time_col)

    results = {}
    results["stats"] = compute_basic_stats(df, kcol=kcol)

    # 1) mean across altitude per time
    time_mean = df.groupby(time_col)[kcol].mean().reset_index(name="mean_ktemp")
    # Forward the column names: the plotting helper otherwise falls back to
    # its own defaults and fails for non-default time_col / kcol.
    fig_time, _ = plot_time_series_with_trend(
        time_mean.rename(columns={"mean_ktemp": kcol}),
        time_col=time_col,
        kcol=kcol,
        title="saber_2d: mean ktemp over time (averaged across altitude)",
        window=rolling_window,
    )
    results["time_mean"] = time_mean
    results["fig_time"] = fig_time

    if save_figs:
        fig_time.savefig(f"{prefix}_time_mean.png", dpi=200)

    # 2) time x altitude pivot for heatmap
    pivot = df.groupby([time_col, alt_col])[kcol].mean().unstack(level=1)
    try:
        pivot = pivot.sort_index(axis=1)
    except TypeError:
        # Altitude labels of mixed types cannot be ordered; keep them as they are.
        pass
    results["pivot"] = pivot

    # Heatmap
    fig_heat, ax = plt.subplots(figsize=(12, 6))
    times = pivot.index
    alts = pivot.columns.values
    # Integer cell edges. With shading="flat" pcolormesh requires
    # C.shape == (len(Y) - 1, len(X) - 1), i.e. (n_altitudes, n_times) here.
    x_edges = np.arange(len(times) + 1)
    y_edges = np.arange(len(alts) + 1)
    mesh = ax.pcolormesh(x_edges, y_edges, pivot.values.T, shading="flat")
    # Ticks sit at cell centres (edge + 0.5) so each label names its own cell.
    ax.set_yticks(np.arange(len(alts)) + 0.5)
    ax.set_yticklabels(alts)
    step = max(1, len(times) // 10)
    ax.set_xticks(np.arange(0, len(times), step) + 0.5)
    ax.set_xticklabels([t.strftime("%Y-%m") for t in times[::step]], rotation=45)
    ax.set_xlabel("Time (year-month)")
    ax.set_ylabel("Altitude")
    ax.set_title("saber_2d: ktemp (time x altitude heatmap)")
    fig_heat.colorbar(mesh, ax=ax, label="ktemp")
    fig_heat.tight_layout()
    results["fig_heat"] = fig_heat
    if save_figs:
        fig_heat.savefig(f"{prefix}_heatmap.png", dpi=200)

    # 3) slice time series (some representative altitudes)
    # NaN altitudes are dropped: they never match an equality filter and would
    # produce an empty slice.
    unique_alts = np.sort(df[alt_col].dropna().unique())
    if len(unique_alts) > 4:
        # lowest, roughly one third, roughly two thirds, highest
        idxs = [0, len(unique_alts) // 3, 2 * len(unique_alts) // 3, -1]
        sample_alts = unique_alts[idxs]
    else:
        sample_alts = unique_alts

    slice_figs = {}
    for alt in sample_alts:
        slice_df = df[df[alt_col] == alt].groupby(time_col)[kcol].mean().reset_index(name=kcol)
        fig_s, _ = plot_time_series_with_trend(
            slice_df,
            time_col=time_col,
            kcol=kcol,
            title=f"ktemp over time at altitude={alt}",
            window=rolling_window,
        )
        slice_figs[alt] = fig_s
        if save_figs:
            fig_s.savefig(f"{prefix}_slice_alt_{alt}.png", dpi=200)

    results["slice_figs"] = slice_figs
    return results


# ----------------------
# 4D analysis (time x altitude x lat x lon)
# ----------------------

def analyze_4d(saber_4d, time_col="year_month", alt_col="altitude", lat_col="latitude", lon_col="longitude", kcol="ktemp", rolling_window=12, time_slices=None, save_figs=False, prefix="4d"):
    """Run the time x altitude x latitude x longitude analyses for ``saber_4d``.

    Parameters
    ----------
    saber_4d : pd.DataFrame
        Table with the columns named by ``time_col``, ``alt_col``,
        ``lat_col``, ``lon_col`` and ``kcol``.
    time_col, alt_col, lat_col, lon_col, kcol : str
        Names of the time, altitude, latitude, longitude and value columns.
    rolling_window : int
        Rolling-mean window (in periods) used by the time-series plot.
    time_slices : iterable or None
        Times for which a lat-lon map is drawn. When None, up to four times
        spread over the record (first, ~1/3, ~2/3, last) are used. Times with
        no matching rows are skipped.
    save_figs : bool
        When True every figure is also written to disk as a PNG at 200 dpi.
    prefix : str
        Filename prefix for the saved figures.

    Returns
    -------
    dict
        ``"stats"``        one-row frame of mean/median/std of ``kcol``;
        ``"time_mean"``    per-time mean of ``kcol`` (column ``mean_ktemp``);
        ``"fig_time"``     time-series figure of that mean;
        ``"pivot_alt"``    time x altitude table of mean ``kcol``;
        ``"fig_alt"``      heatmap figure of the pivot;
        ``"spatial_figs"`` dict mapping timestamp -> lat-lon scatter figure;
        ``"fig3d"``        3D scatter figure for the middle time step, or
                           None when the data holds no valid time.
    """
    df = preprocess_time(saber_4d, time_col)
    results = {}
    results["stats"] = compute_basic_stats(df, kcol=kcol)

    # 1) global mean time series
    time_mean = df.groupby(time_col)[kcol].mean().reset_index(name="mean_ktemp")
    # Forward the column names: the plotting helper otherwise falls back to
    # its own defaults and fails for non-default time_col / kcol.
    fig_time, _ = plot_time_series_with_trend(
        time_mean.rename(columns={"mean_ktemp": kcol}),
        time_col=time_col,
        kcol=kcol,
        title="saber_4d: global mean ktemp over time (averaged across space)",
        window=rolling_window,
    )
    results["time_mean"] = time_mean
    results["fig_time"] = fig_time
    if save_figs:
        fig_time.savefig(f"{prefix}_time_mean.png", dpi=200)

    # 2) altitude x time pivot
    pivot_alt = df.groupby([time_col, alt_col])[kcol].mean().unstack(level=1)
    try:
        pivot_alt = pivot_alt.sort_index(axis=1)
    except TypeError:
        # Altitude labels of mixed types cannot be ordered; keep them as they are.
        pass
    results["pivot_alt"] = pivot_alt

    fig_alt, ax = plt.subplots(figsize=(12, 6))
    times = pivot_alt.index
    alts = pivot_alt.columns.values
    # Integer cell edges. With shading="flat" pcolormesh requires
    # C.shape == (len(Y) - 1, len(X) - 1), i.e. (n_altitudes, n_times) here.
    x_edges = np.arange(len(times) + 1)
    y_edges = np.arange(len(alts) + 1)
    mesh = ax.pcolormesh(x_edges, y_edges, pivot_alt.values.T, shading="flat")
    # Ticks sit at cell centres (edge + 0.5) so each label names its own cell.
    ax.set_yticks(np.arange(len(alts)) + 0.5)
    ax.set_yticklabels(alts)
    step = max(1, len(times) // 10)
    ax.set_xticks(np.arange(0, len(times), step) + 0.5)
    ax.set_xticklabels([t.strftime("%Y-%m") for t in times[::step]], rotation=45)
    ax.set_xlabel("Time (year-month)")
    ax.set_ylabel("Altitude")
    ax.set_title("saber_4d: mean ktemp by altitude over time")
    fig_alt.colorbar(mesh, ax=ax, label="ktemp")
    fig_alt.tight_layout()
    results["fig_alt"] = fig_alt
    if save_figs:
        fig_alt.savefig(f"{prefix}_alt_time_heatmap.png", dpi=200)

    # Distinct, valid time steps in ascending order; shared by steps 3 and 4.
    unique_times = pd.DatetimeIndex(df[time_col].dropna().unique()).sort_values()

    # 3) spatial maps for selected times (lat-lon scatter colored by ktemp averaged over altitude)
    if time_slices is None:
        n = len(unique_times)
        picks = [0, n // 3, 2 * n // 3, n - 1] if n > 4 else list(range(n))
        time_slices = [unique_times[i] for i in picks]

    spatial_figs = {}
    for t in time_slices:
        stamp = pd.to_datetime(t)
        sub = df[df[time_col] == stamp]
        if sub.empty:
            continue
        spatial = sub.groupby([lat_col, lon_col])[kcol].mean().reset_index()
        fig, ax = plt.subplots(figsize=(8, 5))
        sc = ax.scatter(spatial[lon_col], spatial[lat_col], c=spatial[kcol], s=35)
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")
        ax.set_title(f"Spatial lat-lon ktemp at {stamp.strftime('%Y-%m')} (averaged across altitude)")
        fig.colorbar(sc, ax=ax, label="ktemp")
        fig.tight_layout()
        spatial_figs[stamp] = fig
        if save_figs:
            fig.savefig(f"{prefix}_spatial_{stamp.strftime('%Y%m')}.png", dpi=200)

    results["spatial_figs"] = spatial_figs

    # 4) 3D scatter at the middle time step
    # Take the middle *observed* time. A row-wise median can average the two
    # central rows into a timestamp that is absent from the data, which would
    # leave the selection empty and silently skip the figure.
    fig3d = None
    if len(unique_times) > 0:
        mid_time = unique_times[len(unique_times) // 2]
        sub = df[df[time_col] == mid_time]
        fig3d = plt.figure(figsize=(8, 6))
        ax3 = fig3d.add_subplot(111, projection="3d")
        sc3 = ax3.scatter(sub[lon_col], sub[lat_col], sub[alt_col], s=8, c=sub[kcol])
        ax3.set_xlabel("Longitude")
        ax3.set_ylabel("Latitude")
        ax3.set_zlabel("Altitude")
        ax3.set_title(f"3D scatter (lon, lat, alt) colored by ktemp at {mid_time.strftime('%Y-%m')}")
        fig3d.colorbar(sc3, ax=ax3, label="ktemp")
        fig3d.tight_layout()
        if save_figs:
            fig3d.savefig(f"{prefix}_3d_{mid_time.strftime('%Y%m')}.png", dpi=200)

    results["fig3d"] = fig3d

    return results


# ----------------------
# Master runner
# ----------------------

def analyze_all(saber_1d, saber_2d, saber_4d, rolling_window=12, save_figs=False, prefix=None):
    """Analyse the 1D, 2D and 4D datasets in one call.

    Parameters
    ----------
    saber_1d, saber_2d, saber_4d : pd.DataFrame
        The three input tables.
    rolling_window : int
        Rolling-mean window handed to every time-series plot.
    save_figs : bool
        Whether figures are also written to disk.
    prefix : str or None
        Filename prefix for saved figures; ``"saber"`` is used when None.

    Returns
    -------
    dict
        Keys ``'1d'``, ``'2d'`` and ``'4d'``, each holding that dataset's
        results dict. The ``'1d'`` entry has keys ``'stats'``, ``'fig'`` and
        ``'extras'``.
    """
    base = "saber" if prefix is None else prefix

    # Normalise all three inputs up front, before any plotting starts.
    frame_1d, frame_2d, frame_4d = (
        preprocess_time(frame) for frame in (saber_1d, saber_2d, saber_4d)
    )

    # 1D: statistics plus a single time-series figure.
    stats_1d = compute_basic_stats(frame_1d)
    fig_1d, extras_1d = plot_time_series_with_trend(
        frame_1d,
        title="saber_1d: ktemp over time (1D)",
        window=rolling_window,
    )
    if save_figs:
        fig_1d.savefig(f"{base}_1d_time.png", dpi=200)

    summary = {"1d": {"stats": stats_1d, "fig": fig_1d, "extras": extras_1d}}

    # 2D and 4D are handled by their dedicated runners.
    summary["2d"] = analyze_2d(
        frame_2d,
        rolling_window=rolling_window,
        save_figs=save_figs,
        prefix=(base + "_2d"),
    )
    summary["4d"] = analyze_4d(
        frame_4d,
        rolling_window=rolling_window,
        save_figs=save_figs,
        prefix=(base + "_4d"),
    )
    return summary

In [None]:
# ----------------------
# Demonstration (synthetic) when run as a script
# ----------------------
if __name__ == "__main__":
    def make_synthetic():
        """Build small random 1D, 2D and 4D demo tables on a monthly grid.

        Returns ``(df1, df2, df4)`` covering 2018-01 through 2022-12 with a
        sinusoidal cycle, linear offsets and Gaussian noise.
        """
        months = pd.date_range("2018-01-01", "2022-12-01", freq="MS")
        count = len(months)

        def annual_phase(t):
            # Angle of the one-year sine cycle for month-start timestamp t.
            return (t.year - 2018 + t.month / 12) * 2 * np.pi / 1.0

        # 1D: one slow sine over the whole record plus a small upward drift.
        series_1d = 200 + 5 * np.sin(np.linspace(0, 2 * np.pi, count)) + np.linspace(0, 1.5, count) + np.random.normal(0, 0.5, count)
        df1 = pd.DataFrame({"year_month": months, "ktemp": series_1d})

        # 2D: one row per (month, altitude).
        alts = np.array([80, 85, 90, 95, 100])
        profile_rows = [
            {
                "year_month": t,
                "altitude": a,
                "ktemp": 200 + 0.5 * (a - 80) + 5 * np.sin(annual_phase(t)) + np.random.normal(0, 0.7),
            }
            for t in months
            for a in alts
        ]
        df2 = pd.DataFrame(profile_rows)

        # 4D: one row per (month, altitude, latitude, longitude).
        lats = np.linspace(-60, 60, 6)
        lons = np.linspace(-180, 180, 8)
        grid_rows = [
            {
                "year_month": t,
                "altitude": a,
                "latitude": lat,
                "longitude": lon,
                "ktemp": 200 + 0.4 * (a - 80) + 0.02 * lat + 0.01 * lon + 4 * np.sin(annual_phase(t)) + np.random.normal(0, 1.0),
            }
            for t in months
            for a in alts
            for lat in lats
            for lon in lons
        ]
        df4 = pd.DataFrame(grid_rows)
        return df1, df2, df4

    print("Creating synthetic datasets and running analysis (demo)...")
    s1, s2, s4 = make_synthetic()
    results = analyze_all(s1, s2, s4, rolling_window=12, save_figs=False)
    print("Demo complete. Returned results keys:", list(results.keys()))
