In [None]:
import os
from pathlib import Path

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from stagpy.stagyydata import StagyyData
from stagpy import field as sp_field

# --- USER INPUT ---
# Selection mode:
#   "time"     -> plot the snapshot whose model time is closest to target_time_Gyr
#   "snapshot" -> plot snapshot number target_snapshot directly
plot_mode = "time"

# Used when plot_mode == "time": target model time in Gyr.
target_time_Gyr = 1.3

# Used when plot_mode == "snapshot": snapshot index to plot.
target_snapshot = 400

field_to_plot = "eta"   # "T", "eta", or "basalt"

# --- CONFIGURATION ---
# Colour-bar limits per field. The plotting step applies the "eta" limits
# through a logarithmic norm and passes the others as linear vmin/vmax.
field_limits = {
    "T": {"vmin": 300, "vmax": 4000},
    "basalt": {"vmin": 0.0, "vmax": 1.0},
    "eta": {"vmin": 1e18, "vmax": 1e25},
}
# Human-readable field names used in the colour-bar label.
field_labels = {"T": "Temperature", "eta": "Viscosity", "basalt": "Basalt Fraction"}

# Location of the StagYY run archive. Set the STAGYY_RUN_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
data_path = Path(os.environ.get(
    "STAGYY_RUN_DIR",
    "/media/aritro/f522493b-003a-404d-a839-3e0925c674b6/Aritro/StagYY/runs/venus_01/archive/",
))
# Fail fast with a clear message rather than a confusing error from StagPy later.
if not data_path.exists():
    raise FileNotFoundError(
        f"StagYY run directory not found: {data_path} "
        "(set STAGYY_RUN_DIR to point at the run archive)"
    )
sdat = StagyyData(data_path)

# Seconds per Gyr, using a 365.25-day (Julian) year.
SEC_PER_GYR = 1e9 * 365.25 * 24 * 3600

# --- 1. SELECTION LOGIC ---
# Resolve the user's choice into `snap_number` (the snapshot to plot) and
# `actual_time_Gyr` (its model time, used for the plot label). If snapshot
# mode fails, `snap_number` stays None and the plotting step is skipped.
snap_number = None
actual_time_Gyr = None


def _snap_time_seconds(snap):
    """Return the model time of a StagPy snapshot.

    Reads ``snap.time`` and, when that is ``None``, falls back to
    ``snap.timeinfo["time"]``. Any exception raised by StagPy propagates to
    the caller. The rest of this cell treats the returned value as seconds.
    NOTE(review): whether StagPy reports dimensional seconds depends on the
    run's scaling settings -- confirm for this run.
    """
    t = snap.time
    if t is None:
        # NOTE(review): StagPy's time-series column may be named "t" rather
        # than "time"; confirm this key exists for the installed StagPy.
        t = snap.timeinfo["time"]
    return t


if plot_mode == "time":
    print(f"Searching for snapshot closest to {target_time_Gyr} Gyr...")
    snap_indices = []
    snap_times_seconds = []
    n_skipped = 0

    for snap in sdat.snaps:
        try:
            t = float(_snap_time_seconds(snap))  # float(None) -> TypeError -> skipped
            isnap = snap.isnap
        except Exception:
            # Unreadable snapshot or missing time information.
            n_skipped += 1
            continue
        # A NaN time would win np.argmin below and silently select the wrong
        # snapshot, so keep only finite times.
        if not np.isfinite(t):
            n_skipped += 1
            continue
        snap_indices.append(isnap)
        snap_times_seconds.append(t)

    if n_skipped:
        print(f"Skipped {n_skipped} snapshot(s) without a usable time.")

    snap_times_seconds = np.array(snap_times_seconds)
    snap_indices = np.array(snap_indices)

    if len(snap_times_seconds) == 0:
        raise ValueError("No valid timestamps found. Check your _time.dat file.")

    target_seconds = target_time_Gyr * SEC_PER_GYR
    idx = np.abs(snap_times_seconds - target_seconds).argmin()
    snap_number = int(snap_indices[idx])
    actual_time_Gyr = snap_times_seconds[idx] / SEC_PER_GYR
    print(f"Time Mode: Found Snapshot {snap_number} at {actual_time_Gyr:.4f} Gyr")

elif plot_mode == "snapshot":
    try:
        snap_number = target_snapshot
        snap = sdat.snaps[snap_number]
        # The time is only needed for the label, but a snapshot whose time
        # cannot be read is treated as unreadable and is not plotted.
        actual_time_Gyr = _snap_time_seconds(snap) / SEC_PER_GYR
        print(f"Snapshot Mode: Plotting Snapshot {snap_number} (Time: {actual_time_Gyr:.4f} Gyr)")
    except Exception as e:
        print(f"Error: Snapshot {target_snapshot} not found or is unreadable. {e}")
        snap_number = None

else:
    # Reject typos (e.g. "Time") instead of silently falling back to snapshot mode.
    raise ValueError(f"plot_mode must be 'time' or 'snapshot', got {plot_mode!r}")

# --- 2. GENERATE THE PLOT ---
# Render `field_to_plot` for the selected snapshot, label it, and save a PNG
# to the current working directory. Errors are printed rather than raised.
# The figure is closed in every case so repeated runs don't accumulate
# open figures.
if snap_number is not None:
    fig = None
    try:
        snapshot = sdat.snaps[snap_number]

        # Colour limits come from field_limits. A field with no entry is
        # passed no limits, so its colour range is left to plot_scalar's defaults.
        lims = field_limits.get(field_to_plot, {})
        if field_to_plot == "eta":
            # Viscosity is shown on a logarithmic colour scale.
            plot_kwargs = {"norm": colors.LogNorm(vmin=lims.get("vmin"), vmax=lims.get("vmax"))}
        else:
            plot_kwargs = dict(lims)
        fig, ax, mesh, cbar = sp_field.plot_scalar(snapshot, field_to_plot, **plot_kwargs)

        # Labels and metadata
        unit = snapshot.fields[field_to_plot].meta.dim
        display_name = field_labels.get(field_to_plot, field_to_plot)
        cbar.set_label(f"{display_name} [{unit}]")

        # Overlay snapshot number and model time at the bottom centre of the axes.
        label_text = f"Snap {snap_number} | {actual_time_Gyr:.3f} Gyr"
        ax.text(0.5, 0.05, label_text, transform=ax.transAxes, ha="center",
                va="center", fontsize=16, color="black",
                bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"))

        fig.set_size_inches(10, 6)
        # Lay out this figure explicitly rather than pyplot's "current" figure.
        fig.tight_layout()

        save_name = f"snap_{snap_number}_{field_to_plot}.png"
        fig.savefig(save_name, dpi=300)
        print(f"Successfully saved: {save_name}")

    except Exception as e:
        print(f"An error occurred during plotting: {e}")
    finally:
        # Close even when labelling or saving failed. With the Agg backend the
        # figure is never displayed, so an unclosed figure only holds memory.
        if fig is not None:
            plt.close(fig)

In [None]:
# Collect the index of every snapshot StagPy discovers under the run
# directory, then report the count and the indices in one message.
snap_indices = [step.isnap for step in sdat.snaps]
print(
    f"Number of snapshots: {len(snap_indices)}\n"
    f"Snapshot indices found: {snap_indices}"
)