In [None]:
import matplotlib
matplotlib.use('Agg')
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from stagpy.stagyydata import StagyyData
from stagpy import field as sp_field

# --- USER INPUT ---
# Field to render and the inclusive range of snapshot numbers to scan.
field_to_plot = "eta"
snap_min, snap_max = 0, 800

# --- TOGGLE ---
# Frame-selection strategy. Options: "constant_time" or "constant_frame".
#   constant_time  -> target times spaced dt_Gyr apart, each mapped to the
#                     nearest available snapshot
#   constant_frame -> every snap_step-th available snapshot
mode = "constant_frame"

# --- CONSTANT_TIME SETTINGS ---
dt_Gyr = 0.01  # spacing between target times, in Gyr

# --- CONSTANT_FRAME SETTINGS ---
snap_step = 10  # 1 = every snapshot, 10 = every 10th snapshot, etc.

# --- PLOT SETTINGS ---
fig_width, fig_height = 8, 6  # figure size in inches

# Colour-scale limits per field (applied through a LogNorm for "eta").
field_limits = dict(
    T=dict(vmin=300, vmax=4000),
    basalt=dict(vmin=0.0, vmax=1.0),
    eta=dict(vmin=1e18, vmax=1e25),
)

# Human-readable colorbar names; a field missing here is labelled by its key.
field_labels = dict(
    T="Temperature",
    eta="Viscosity",
    basalt="Basalt Fraction",
)

# --- DATA PATH ---
# StagYY "archive" output directory of the run being animated.
data_path = Path(
    "/media/aritro/f522493b-003a-404d-a839-3e0925c674b6/Aritro/StagYY/runs/venus_02/archive"
)
sdat = StagyyData(data_path)

# The run name is the directory holding the archive (e.g. 'venus_02').
folder_name = data_path.parent.name

# Frames go to [folder_name]_frames_[field_to_plot]_[mode], relative to the
# current working directory.
output_dir = Path("_".join([folder_name, "frames", field_to_plot, mode]))
output_dir.mkdir(exist_ok=True, parents=True)

# Seconds in one billion Julian years (365.25-day years).
SEC_PER_GYR = 365.25 * 24 * 3600 * 1e9

# --- 1. PREPARE THE FRAME LIST ---
# Collect the snapshot numbers and times that exist in [snap_min, snap_max].
frames_to_render = []

print(f"Scanning snapshots {snap_min} to {snap_max} in {folder_name}...")
available_snaps = []
available_times = []
n_skipped = 0

for n in range(snap_min, snap_max + 1):
    try:
        snap = sdat.snaps[n]
        # StagPy fix: use snap.time, falling back to snap.timeinfo when it is None.
        # NOTE(review): StagPy's time series usually names its time column "t";
        # confirm "time" is valid for this StagPy version, otherwise every
        # snapshot that reaches this fallback is skipped below.
        t = snap.time
        if t is None:
            t = snap.timeinfo["time"]
        isnap = snap.isnap
    except Exception:
        # Missing/unreadable snapshot: skip it. Catching Exception (not a bare
        # ``except:``) lets KeyboardInterrupt through so the scan can be stopped.
        n_skipped += 1
        continue
    available_snaps.append(isnap)
    available_times.append(t)

available_snaps = np.array(available_snaps)
# Force a numeric dtype so later time arithmetic fails here, not mid-render.
available_times = np.array(available_times, dtype=float)
print(f"Found {len(available_snaps)} snapshots ({n_skipped} skipped).")

def _select_frames(snaps, times):
    """Choose which snapshots to render, according to ``mode``.

    Args:
        snaps: available snapshot numbers, in scan order.
        times: matching snapshot times, treated as seconds (divided by
            SEC_PER_GYR to get Gyr).

    Returns:
        list of ``(snapshot_number, time)`` tuples, in render order.

    Raises:
        ValueError: if ``mode`` is not "constant_time" or "constant_frame",
            or its step setting (``dt_Gyr`` / ``snap_step``) is not positive.
    """
    if mode == "constant_time":
        if dt_Gyr <= 0:
            raise ValueError(f"dt_Gyr must be positive, got {dt_Gyr!r}")
        t_start, t_end = times.min(), times.max()
        # np.arange excludes its stop value (and is empty when start == stop),
        # so t_end is appended: the last snapshot is always reached, and a
        # single available snapshot still yields one frame.
        target_times = np.append(np.arange(t_start, t_end, dt_Gyr * SEC_PER_GYR), t_end)
        # Nearest available snapshot for each target time. Neighbouring targets
        # can hit the same snapshot when dt_Gyr is finer than the snapshot
        # spacing; those repeats would only overwrite the same PNG, so keep each
        # index once (dict.fromkeys preserves first-seen order).
        indices = list(dict.fromkeys(int(np.abs(times - t).argmin()) for t in target_times))
    elif mode == "constant_frame":
        if snap_step < 1:
            raise ValueError(f"snap_step must be >= 1, got {snap_step!r}")
        indices = range(0, len(snaps), snap_step)
    else:
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'constant_time' or 'constant_frame'"
        )
    # Store (snapshot_number, actual_time_from_file)
    return [(int(snaps[i]), times[i]) for i in indices]


def _render_frame(snap_number, t_val):
    """Plot ``field_to_plot`` for one snapshot and save it under ``output_dir``.

    Args:
        snap_number: snapshot number passed to ``sdat.snaps``.
        t_val: snapshot time in seconds; shown on the frame in Gyr.

    Returns:
        The file name (relative to ``output_dir``) that was written.
    """
    snapshot = sdat.snaps[snap_number]
    # A field without an entry in field_limits is autoscaled (None limits)
    # instead of failing with a TypeError on every frame.
    lims = field_limits.get(field_to_plot, {})
    vmin, vmax = lims.get("vmin"), lims.get("vmax")

    fig = None
    try:
        if field_to_plot == "eta":
            # Log colour scale for viscosity. The limits go into the norm because
            # matplotlib rejects a norm combined with explicit vmin/vmax.
            norm = colors.LogNorm(vmin=vmin, vmax=vmax)
            fig, ax, _, cbar = sp_field.plot_scalar(snapshot, field_to_plot, norm=norm)
        else:
            fig, ax, _, cbar = sp_field.plot_scalar(
                snapshot, field_to_plot, vmin=vmin, vmax=vmax
            )

        # Colorbar label: display name plus the dimension from the field metadata.
        unit = snapshot.fields[field_to_plot].meta.dim
        display_name = field_labels.get(field_to_plot, field_to_plot)
        if cbar is not None:  # defensive: skip labelling if no colorbar came back
            cbar.set_label(f"{display_name} [{unit}]")

        # Time label in the middle of the axes (same format as before).
        time_Gyr = t_val / SEC_PER_GYR
        ax.text(
            0.5, 0.5, f"{time_Gyr:.3f} Gyr",
            transform=ax.transAxes, ha="center", va="center",
            fontsize=24, color="black",
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
        )

        fig.set_size_inches(fig_width, fig_height)
        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()

        # FILE NAME: [folder_name]_[field]_snap-[number]_[time]-Gyr.png
        # The snapshot number is zero-padded so a lexicographic sort (shell
        # globs, ffmpeg -pattern_type glob) orders frames chronologically.
        file_name = (
            f"{folder_name}_{field_to_plot}_snap-{snap_number:05d}_{time_Gyr:.3f}-Gyr.png"
        )
        fig.savefig(output_dir / file_name, dpi=300)
        return file_name
    finally:
        # Release the figure on success and on error. If plot_scalar itself
        # failed there is no handle, so close the current figure instead.
        if fig is not None:
            plt.close(fig)
        else:
            plt.close()


if len(available_snaps) == 0:
    print("Error: No data found in that snapshot range.")
else:
    frames_to_render = _select_frames(available_snaps, available_times)
    print(f"Mode: {mode.upper()}. Generating {len(frames_to_render)} frames.")

    # --- 2. RENDER THE FRAMES ---
    last_index = len(frames_to_render) - 1
    for i, (snap_number, t_val) in enumerate(frames_to_render):
        try:
            file_name = _render_frame(snap_number, t_val)
        except Exception as e:
            # Best effort: report the failing snapshot and keep going.
            print(f"Error at Snap {snap_number}: {e}")
            continue
        if i % 10 == 0 or i == last_index:
            print(f"Saved: {file_name}")