In [None]:
# test update

# Essential Jupyter Notebook Magic
%matplotlib inline

# General Purpose and Data Handling Libraries
import os
import re
import glob
import numpy as np
import pandas as pd
from os import listdir
from os.path import isfile, join
from natsort import natsorted
import pickle
from operator import add
import random

# MatPlotlib for Plotting and Visualization
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.tri as tri
import matplotlib as mpl
import matplotlib.colors as mcolors
from matplotlib import cm, ticker
from matplotlib.colors import LogNorm, LightSource, ListedColormap, BoundaryNorm
from matplotlib.collections import LineCollection
from matplotlib.cm import ScalarMappable
from matplotlib.ticker import LogFormatter, LogFormatterSciNotation
from matplotlib.ticker import LogLocator, MultipleLocator, NullFormatter
from mpl_toolkits.mplot3d import axes3d
from mpl_toolkits.mplot3d import Axes3D
from streamtracer import StreamTracer, VectorGrid
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d.proj3d import proj_transform
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from cmap import Colormap
from matplotlib.ticker import MaxNLocator
from matplotlib.gridspec import GridSpec


# Scipy for Scientific Computing and Analysis
from scipy import stats, interpolate
from scipy.optimize import curve_fit
from scipy.interpolate import interp1d, griddata
from scipy.ndimage import label, gaussian_filter
from scipy.spatial import ConvexHull
from scipy.interpolate import RegularGridInterpolator
from skimage import measure
from shapely.geometry import Polygon
from scipy.ndimage import label
from scipy.ndimage import distance_transform_edt
from scipy.ndimage import binary_fill_holes

# Image Handling and Processing
from PIL import Image

# Tecplot for Scientific Data Visualization
import tecplot as tp
from tecplot.exception import *
from tecplot.constant import *

# For 3d plotting
from skimage import measure

# FLEKS toolkit
import sys  
sys.path.insert(1, '/Users/atcushen/SWMF/PC/FLEKS/tools/')
import fleks,yt

# From UMGPT
from collections import defaultdict
import statistics

# Define Constants (SI units)
# --- Fundamental constants ---
e = 1.60218e-19   # elementary charge [C]
c = 2.99e8        # speed of light [m/s]
k_b = 1.38e-23    # Boltzmann constant [J/K]
mu_0 = 1.257e-6   # vacuum permeability [H/m]
eV = 6.242e18     # number of eV in 1 J
# --- Particle masses ---
amu = 1.67e-27    # atomic mass unit [kg]
m_p = 1.67e-27    # proton mass [kg]
m_e = 9.11e-31    # electron mass [kg]
# --- Planetary length scale ---
R_M = 2440e3      # Mercury radius [m]

In [None]:
# Define utility functions
def read_dataset(mypath):
    """Load the Tecplot file ``mypath`` into a fresh layout and return the dataset.

    Connects to a Tecplot session on port 7600 (presumably a running Tecplot
    360 instance accepting PyTecplot connections), starts a new layout, loads
    the file and sets the active frame to a 3D Cartesian plot. Large files can
    take a while to load.
    """
    print("reading:", mypath)
    tp.session.connect(port=7600)
    tp.new_layout()
    loaded = tp.data.load_tecplot(mypath)
    tp.active_frame().plot_type = PlotType.Cartesian3D
    return loaded

def get_files(dir, key=".*cut_particle_region0_0.*", read_time = False, reduce = True,
              time_step = None, first_time = None, time_window = None):
    """Collect the files in ``dir`` whose names match ``key`` and index them by time.

    Parameters
    ----------
    dir : str
        Directory to scan (not recursive; sub-directories are ignored). The
        name shadows the builtin but is kept for backward compatibility.
    key : str
        Regular expression searched (``re.search``) in each file name.
    read_time : bool
        False: times are assigned from the sorted file order as
        ``round(i*time_step + first_time, 3)`` and keys are floats.
        True: the time is parsed from the last 6 characters of the file name
        (e.g. ``..._t_212.05``) and keys are ``"%.2f"`` strings.
    reduce : bool
        If True, keep only files with ``time_window[0] <= t < time_window[1]``.
    time_step, first_time, time_window : optional
        Explicit snapshot spacing, first-file time and time window. When None
        (default) the notebook-level globals ``dt``, ``start_time`` and
        ``t_bound`` are used. These globals are only looked up when actually
        needed.

    Returns
    -------
    dict
        Time key (float or str, see ``read_time``) -> file name, in sorted
        file-name order.
    """
    pattern = re.compile(key)
    files = sorted(f for f in listdir(dir)
                   if isfile(join(dir, f)) and pattern.search(f) is not None)

    named_files = {}
    if not read_time:
        if files:
            # Conditional expressions only touch the global when no override is given
            step = dt if time_step is None else time_step
            t0 = start_time if first_time is None else first_time
            for i, name in enumerate(files):
                named_files[round(i * step + t0, 3)] = name
    else:
        for name in files:
            named_files["%.2f" % float(name[-6:])] = name

    if not reduce or not named_files:
        return named_files

    bounds = t_bound if time_window is None else time_window
    lo, hi = bounds[0], bounds[1]
    return {t: str(name) for t, name in named_files.items() if lo <= float(t) < hi}

def nearest_voxel_mask(X, Y, Z, points):
    """Return a boolean mask (shape of X) that is True at the grid node nearest each point.

    X, Y, Z are 3D coordinate arrays indexed [y, x, z]: x varies along axis 1,
    y along axis 0 and z along axis 2. Each point is an (x, y, z) sequence.
    The nearest node is found independently along each axis, which assumes a
    rectilinear (meshgrid-style) grid.
    """
    x_axis = X[0, :, 0]
    y_axis = Y[:, 0, 0]
    z_axis = Z[0, 0, :]

    hit = np.zeros(X.shape, dtype=bool)
    for p in points:
        row = np.abs(y_axis - p[1]).argmin()
        col = np.abs(x_axis - p[0]).argmin()
        dep = np.abs(z_axis - p[2]).argmin()
        hit[row, col, dep] = True
    return hit

def nearest_cell(X, Y, Z, point):
    """Return the index tuple (iy, ix, iz) of the grid node nearest to ``point``.

    ``point`` is (x, y, z). X, Y, Z are 3D coordinate arrays indexed [y, x, z];
    the search is done per axis (assumes a rectilinear grid).
    """
    iy = np.abs(Y[:, 0, 0] - point[1]).argmin()
    ix = np.abs(X[0, :, 0] - point[0]).argmin()
    iz = np.abs(Z[0, 0, :] - point[2]).argmin()
    return (iy, ix, iz)

def average_value(var_ls,t0,t_start,t_stop,dt = 0.05,type='csdata',path=None,file_dt = 0.05):
    """Time-average fluid variables over the snapshots in [t0+t_start, t0+t_stop].

    Parameters
    ----------
    var_ls : list of str
        Keys to average. Besides keys stored in the pickled snapshots, the
        derived quantities "APhi", "mag_elev", "Jx", "Jy" and "Jz" are computed
        on the fly when requested.
    t0 : float
        Reference ("current") time.
    t_start, t_stop : float
        Window relative to ``t0``. Snapshots with
        ``t0 + t_start <= t <= t0 + t_stop`` (inclusive) are averaged.
    dt : float
        Desired sampling cadence. Every ``round(dt / file_dt)``-th snapshot
        (at least every one) is used, counting from the first file found in
        ``path``.
    type : str
        File flavour in the file name, e.g. 'csdata' or 'numpy'.
    path : str or None
        Directory with the snapshots. None (default) uses the notebook-level
        global ``dir`` at call time. The old default ``path=dir`` was bound
        when this def ran, which is before the notebook assigns ``dir``, so it
        captured the builtin ``dir`` function.
    file_dt : float
        Time between consecutive saved snapshots.

    Returns
    -------
    dict
        Each name in ``var_ls`` -> time-averaged array.

    Raises
    ------
    ValueError
        If no snapshot falls inside the window.
    """
    if path is None:
        path = dir

    # Raw strings: the same regex as before, without invalid-escape warnings
    temp_files = get_files(path, key=r"3d\_fluid.*" + type + r"\_t\_...\...", read_time=True, reduce=False)

    # Sample every file_skip-th snapshot. round() avoids float floor-division
    # surprises (0.15 // 0.05 == 2.0), and max() avoids a zero slice step.
    file_skip = max(1, int(round(dt / file_dt)))

    averages = {}
    count = 0
    for t in list(temp_files.keys())[::file_skip]:
        if not (t0 + t_start <= float(t) <= t0 + t_stop):
            continue

        file_path = join(path, str(temp_files[t]))
        print("reading", file_path)
        with open(file_path, 'rb') as f:
            # NOTE: pickle.load can execute arbitrary code -- only load trusted files
            temp_data = pickle.load(f)

        _average_value_add_derived(temp_data, var_ls)

        # Accumulate a running sum. The first snapshot is copied as float64 so
        # the sum neither aliases the loaded data nor accumulates in low precision.
        for var in var_ls:
            if var not in averages:
                averages[var] = np.array(temp_data[var], dtype=float)
            else:
                averages[var] += temp_data[var]
        count += 1

    if count == 0:
        raise ValueError(
            f"average_value: no '{type}' snapshots between t={t0 + t_start} and t={t0 + t_stop} in {path}")

    return {var: averages[var] / count for var in var_ls}


def _average_value_add_derived(temp_data, var_ls):
    """Add the derived fields requested in ``var_ls`` to the snapshot dict ``temp_data`` (in place)."""
    if "APhi" in var_ls:
        # Pressure tensor of species 0 (NOTE(review): presumably electrons -- confirm)
        Pxx = temp_data["pxxS0"]
        Pxy = temp_data["pxyS0"]
        Pxz = temp_data["pxzS0"]
        Pyy = temp_data["pyyS0"]
        Pyz = temp_data["pyzS0"]
        Pzz = temp_data["pzzS0"]
        Bx = temp_data["Bx"]
        By = temp_data["By"]
        Bz = temp_data["Bz"]
        # Magnetic unit vector
        B_mag = np.sqrt(Bx**2+By**2+Bz**2)
        bx,by,bz = 1/B_mag * [Bx,By,Bz]
        # APhi (Daughton 2008) via the intermediate N tensor terms
        Nxx = by*by*Pzz - 2*by*bz*Pyz + bz*bz*Pyy
        Nxy = -by*bx*Pzz + by*bz*Pxz + bz*bx*Pyz - bz*bz*Pxy
        Nxz = by*bx*Pyz - by*by*Pxz - bz*bx*Pyy + bz*by*Pxy
        Nyy = bx*bx*Pzz - 2*bx*bz*Pxz + bz*bz*Pxx
        Nyz = -bx*bx*Pyz + bx*by*Pxz + bz*bx*Pxy - bz*by*Pxx
        Nzz = bx*bx*Pyy - 2*bx*by*Pxy + by*by*Pxx
        alpha = Nxx+Nyy+Nzz
        beta = -(Nxy**2+Nxz**2+Nyz**2-Nxx*Nyy-Nxx*Nzz-Nyy*Nzz)
        temp_data["APhi"] = np.nan_to_num(2*np.sqrt(alpha**2-4*beta)/alpha, nan = 0)

    if "mag_elev" in var_ls:
        # Magnetic elevation angle [rad]
        temp_data["mag_elev"] = np.arctan(temp_data["Bz"]/np.sqrt(temp_data["Bx"]**2+temp_data["By"]**2))

    if any(name in var_ls for name in ("Jx", "Jy", "Jz")):
        # J = e * n * (u_S1 - u_S0). The 1e6 and 1e3 factors presumably convert
        # cm^-3 -> m^-3 and km/s -> m/s (TODO confirm), giving [A/m^2].
        temp_data["Jx"] = (e*temp_data["rhoS1"]*1e6*(temp_data["uxS1"]-temp_data["uxS0"])*1e3) # [A/m^2]
        temp_data["Jy"] = (e*temp_data["rhoS1"]*1e6*(temp_data["uyS1"]-temp_data["uyS0"])*1e3) # [A/m^2]
        temp_data["Jz"] = (e*temp_data["rhoS1"]*1e6*(temp_data["uzS1"]-temp_data["uzS0"])*1e3) # [A/m^2]

def plot_sphere(ax, radius=1, center=(0, 0, 0), color='b', alpha=0.5, zorder = 1, xlims = (-10,10), ylims = (-10,10), zlims = (-10,10), n_points = 1000):
    """
    Plot the -x half of a sphere (the hemisphere with x <= center[0]) on a 3D axis.

    Parameters:
    - ax: The 3D axis to draw on (must provide ``plot_surface``).
    - radius: Sphere radius (default: 1).
    - center: (x, y, z) of the sphere's centre (default: (0, 0, 0)).
    - color: Surface colour (default: 'b').
    - alpha: Surface transparency (default: 0.5).
    - zorder: Drawing order passed to ``plot_surface`` (default: 1).
    - xlims, ylims, zlims: (min, max) box. Surface points outside it are set
      to NaN to mask them out of the plot (default: (-10, 10) each).
    - n_points: Samples per angle; the mesh is n_points x n_points
      (default: 1000, the original resolution -- lower it for speed).
    """
    # Azimuth spans [pi/2, 3pi/2], so cos(u) <= 0: only the -x hemisphere is drawn
    u = np.linspace(np.pi/2, 3/2 * np.pi, n_points)
    v = np.linspace(0, np.pi, n_points)

    x = radius * np.outer(np.cos(u), np.sin(v)) + center[0]
    y = radius * np.outer(np.sin(u), np.sin(v)) + center[1]
    z = radius * np.outer(np.ones(np.size(u)), np.cos(v)) + center[2]

    # Mask out any values outside the axes limits
    outside = ((x < xlims[0]) | (x > xlims[1]) | (y < ylims[0]) | (y > ylims[1])
               | (z < zlims[0]) | (z > zlims[1]))
    x[outside] = np.nan
    y[outside] = np.nan
    z[outside] = np.nan

    ax.plot_surface(x, y, z, color=color, alpha=alpha, zorder=zorder)

def compute_dt(var_ls,time,type='csdata',path=None,step=None,file_map=None,current_data=None):
    """Finite-difference time derivative of each variable in ``var_ls`` at ``time``.

    When both neighbouring snapshots exist, the centred difference
    (var(t+step) - var(t-step)) / (2*step) is used. Otherwise a one-sided
    difference against the current snapshot is used:
    (var(t) - var(t-step)) / step or (var(t+step) - var(t)) / step. If neither
    neighbour exists, an error message is printed and an empty dict returned.

    Parameters
    ----------
    var_ls : list of str
        Keys to differentiate.
    time : str or float
        Current time. Neighbours are looked up as ``"%.2f"`` string keys.
    type : 'csdata' or 'numpy'
        Selects the default file map / current data: globals
        ``filescs``/``datacs`` or ``files3D``/``data3d``.
    path : str or None
        Directory of the neighbour files. None uses the global ``dir`` at call
        time; the old ``path=dir`` default was bound at def time to the builtin.
        NOTE(review): this notebook stores the 3D files in ``dir_3D``, so pass
        ``path=dir_3D`` for type='numpy'.
    step : float or None
        Snapshot spacing. None uses the global ``dt``.
    file_map : dict or None
        ``"%.2f"`` time string -> file name. None uses the global for ``type``.
    current_data : dict or None
        Arrays at ``time``. None uses the global for ``type``.

    Returns
    -------
    dict
        Variable name -> derivative array (empty if no neighbour was found).

    Raises
    ------
    ValueError
        If ``type`` is unknown and a default file map / data is required.
    """
    if step is None:
        step = dt
    if path is None:
        path = dir
    if file_map is None or current_data is None:
        if type == 'csdata':
            file_map = filescs if file_map is None else file_map
            current_data = datacs if current_data is None else current_data
        elif type == 'numpy':
            file_map = files3D if file_map is None else file_map
            current_data = data3d if current_data is None else current_data
        else:
            raise ValueError(f"compute_dt: unknown type {type!r}; expected 'csdata' or 'numpy'")

    def load(key):
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files
        with open(join(path, file_map[key]), 'rb') as f:
            return pickle.load(f)

    key_minus = '{:.2f}'.format(float(time) - step)
    key_plus = '{:.2f}'.format(float(time) + step)
    have_minus = key_minus in file_map
    have_plus = key_plus in file_map

    deriv_dict = {}
    if have_minus and have_plus:
        # Centred difference
        data_tminus = load(key_minus)
        data_tplus = load(key_plus)
        for var in var_ls:
            deriv_dict[var] = (data_tplus[var]-data_tminus[var])/(2*step)
    elif have_minus:
        # Backward difference
        print("Only earlier timestep available for dt!")
        data_tminus = load(key_minus)
        for var in var_ls:
            deriv_dict[var] = (current_data[var]-data_tminus[var])/(step)
    elif have_plus:
        # Forward difference
        print("Only later timestep available for dt!")
        data_tplus = load(key_plus)
        for var in var_ls:
            deriv_dict[var] = (data_tplus[var]-current_data[var])/(step)
    else:
        print("ERROR: NO OTHER TIMESTEPS FOUND FOR DERIVATIVE AT TIME =",time)

    return deriv_dict

def create_enclosing_box_mask(Xgrid, Ygrid, Zgrid, point, margin=1/64):
    """
    Mask the 8 grid nodes at the corners of the grid cell that contains ``point``.

    Parameters:
    - Xgrid, Ygrid, Zgrid: 3D coordinate arrays of the same shape, indexed
      [y, x, z] (x varies along axis 1, y along axis 0, z along axis 2), with
      ascending coordinates.
    - point: tuple (x, y, z).
    - margin: points outside the grid, or within ``margin`` of its boundary,
      give an all-False mask. Should be >= 0. The default 1/64 is presumably
      one cell width in the grid's units (TODO confirm).

    Returns:
    - mask: boolean array shaped like Xgrid, True at the 8 enclosing nodes.
    """
    px, py, pz = point

    mask = np.zeros_like(Xgrid, dtype=bool)
    near_edge = (px < np.min(Xgrid) + margin or px > np.max(Xgrid) - margin
                 or py < np.min(Ygrid) + margin or py > np.max(Ygrid) - margin
                 or pz < np.min(Zgrid) + margin or pz > np.max(Zgrid) - margin)
    if near_edge:
        return mask

    # Lower corner: last node strictly below the point along each axis
    ix = np.where(Xgrid[0, :, 0] < px)[0][-1]
    iy = np.where(Ygrid[:, 0, 0] < py)[0][-1]
    iz = np.where(Zgrid[0, 0, :] < pz)[0][-1]

    mask[iy:iy + 2, ix:ix + 2, iz:iz + 2] = True
    return mask

def create_above_surface_mask(X, Y, Z, XX, YY, ZZ):
    """Return a boolean mask (shape of ZZ) that is True where a 3D point lies above a 2D surface.

    X, Y, Z are 2D arrays describing the surface (e.g. the current sheet):
    X[0] is its x axis, Y[:, 0] its y axis, and Z[row, col] its height. XX, YY
    and ZZ are 3D meshgrid arrays. Each vertical column (i, j) is compared
    against the surface height at the surface node nearest to
    (XX[i, j, 0], YY[i, j, 0]). Used for 3D plotting to decide what is drawn above what.
    """
    assert XX.shape == YY.shape == ZZ.shape, "Arrays XX, YY, and ZZ must have the same shape"
    assert X.shape == Y.shape == Z.shape, "Arrays X, Y, and Z must have the same shape"
    _rows, _cols, _depth = XX.shape  # the 3D grid must be 3-dimensional

    # Nearest surface node for every vertical column, found with broadcasting
    col_idx = np.argmin(np.abs(X[0] - XX[:, :, 0, np.newaxis]), axis=-1)
    row_idx = np.argmin(np.abs(Y[:, 0] - YY[:, :, 0, np.newaxis]), axis=-1)
    surface_height = Z[row_idx, col_idx]

    return ZZ > surface_height[:, :, np.newaxis]


def plane_intersection(x, y, z, plane_z=0.2):
    """Locate where a sampled curve crosses the horizontal plane z = plane_z.

    Used to check whether a field line crosses the current sheet more than once.
    A crossing is recorded for consecutive samples on opposite sides of the
    plane, or when a sample lies exactly on the plane and the next one leaves
    it. The recorded [x, y, plane_z] uses the midpoint of the two samples' x
    and y (not an interpolated crossing).
    """
    hits = []
    for i in range(1, len(z)):
        prev_offset = z[i-1] - plane_z
        curr_offset = z[i] - plane_z
        straddles = prev_offset * curr_offset < 0
        leaves_plane = z[i-1] == plane_z and z[i] != plane_z
        if straddles or leaves_plane:
            hits.append([(x[i] + x[i-1]) / 2, (y[i] + y[i-1]) / 2, plane_z])
    return hits

def fit_FR_ellipse(streamplot):
    """Fit an ellipse to a flux-rope-like streamline in the x-z plane at y-index ``iy_mean``.

    ``streamplot`` is the object returned by ``ax.streamplot`` for that plane
    (its ``.lines`` paths are read; vertex columns are (x, z)). Depends on the
    notebook globals ``X3d``, ``Z3d``, ``iy_mean`` and ``mask``.
    NOTE(review): ``mask`` is presumably the dipolarization-front (DF)
    region -- confirm.

    Returns a boolean array shaped like ``X3d[iy_mean, :, :]``, True inside
    the fitted ellipse.

    Raises ValueError if no streamline passes the selection criteria.
    Previously this surfaced as an UnboundLocalError on ``spiral``.
    """
    stream_traces = [path.vertices for path in streamplot.lines.get_paths()]

    # x-extent of the masked (DF) region in this plane
    plane_x = X3d[iy_mean,:,:]
    df_x = plane_x[mask[iy_mean,:,:]]
    df_xmin = np.min(df_x)
    df_xmax = np.max(df_x)

    # Among streamlines with an end point within 0.2 of z = 0.2, pick the one
    # whose leftmost point is the smallest x still to the right of df_xmin.
    trace_xmin = df_xmax
    spiral = None
    for trace in stream_traces:
        leftmost = np.min(trace[:,0])
        if leftmost>df_xmin and leftmost<trace_xmin and (np.abs(trace[0,1]-0.2)<0.2 or np.abs(trace[-1,1]-0.2)<0.2):
            trace_xmin = leftmost
            spiral = trace
    if spiral is None:
        raise ValueError("fit_FR_ellipse: no flux-rope-like streamline found inside the masked x-range")

    # Spiral centre: whichever end point is closer to the centre of mass
    x_com = np.mean(spiral[:,0])
    z_com = np.mean(spiral[:,1])
    if np.sqrt((spiral[0,0]-x_com)**2+(spiral[0,1]-z_com)**2) < np.sqrt((spiral[-1,0]-x_com)**2+(spiral[-1,1]-z_com)**2):
        center = spiral[0,:]
        end = spiral[-1,:]
    else:
        center = spiral[-1,:]
        end = spiral[0,:]

    # If the end points are more than 0.5 apart, use the centre of mass instead.
    # The original note: spirals with a 'tail' keep the end-point centre, and
    # hollow circular spirals use the centre of mass.
    if np.sqrt((end[0]-center[0])**2+(end[1]-center[1])**2) > 0.5:
        center = [x_com,z_com]

    # Semi-axes: extent of the spiral on the side opposite to its tail
    if end[0] > center[0]:
        # Tail to the right, so measure x_r going left
        x_r = np.abs(center[0]-np.min(spiral[:,0][spiral[:,0]<center[0]]))
    else:
        x_r = np.abs(center[0]-np.max(spiral[:,0][spiral[:,0]>center[0]]))
    if end[1] > center[1]:
        # Tail above, so measure the vertical radius going down
        y_r = np.abs(center[1]-np.min(spiral[:,1][spiral[:,1]<center[1]]))
    else:
        y_r = np.abs(center[1]-np.max(spiral[:,1][spiral[:,1]>center[1]]))

    # Ellipse mask in this plane
    ellipse = ((plane_x - center[0])**2 / x_r**2 + (Z3d[iy_mean,:,:] - center[1])**2 / y_r**2) <= 1

    return ellipse

def fit_FR_ellipse2(streamplot,min_density = 20):
    """Flux-rope cross-section mask built from the density of selected streamlines.

    Streamlines whose leftmost x lies strictly inside the masked (DF) x-range,
    and that have an end point within 0.2 of z = 0.2, are histogrammed onto
    the x-z grid of plane ``iy_mean``. Cells with more than ``min_density``
    vertices are kept. The largest connected region is returned with its
    interior holes filled. Depends on the notebook globals ``X3d``, ``Z3d``,
    ``iy_mean`` and ``mask``.

    Returns a boolean ndarray shaped like ``X3d[iy_mean, :, :]``; it is all
    False if no cell exceeds ``min_density``. (Previously the non-empty case
    returned a nested list.)
    """
    plane_x = X3d[iy_mean,:,:]
    plane_z = Z3d[iy_mean,:,:]

    # x-extent of the masked (DF) region in this plane
    df_x = plane_x[mask[iy_mean,:,:]]
    df_xmin = np.min(df_x)
    df_xmax = np.max(df_x)

    # One histogram bin per grid node, spanning the plane's first..last node
    bins = [plane_x.shape[0], plane_x.shape[1]]
    hist_range = [[plane_x[0,0], plane_x[-1,0]], [plane_z[0,0], plane_z[0,-1]]]

    totals = np.zeros_like(plane_x, dtype=float)
    for path in streamplot.lines.get_paths():
        trace = path.vertices
        x_left = np.min(trace[:,0])
        near_sheet = np.abs(trace[0,1]-0.2)<0.2 or np.abs(trace[-1,1]-0.2)<0.2
        if df_xmin < x_left < df_xmax and near_sheet:
            hist, _, _ = np.histogram2d(trace[:,0], trace[:,1], bins=bins, range=hist_range)
            totals += hist

    dense = totals > min_density
    if not np.any(dense):
        return np.zeros_like(dense, dtype=bool)

    # Largest connected region, ignoring the background label 0. The old
    # argsort(counts)[-2] assumed the background was the largest region.
    labeled_mask, _num_features = label(dense)
    sizes = np.bincount(labeled_mask.ravel())
    sizes[0] = 0
    largest = labeled_mask == np.argmax(sizes)

    return binary_fill_holes(largest)
    

def fill_gaps(grid):
    """Fill enclosed holes in a boolean grid (array or nested lists).

    Returns the filled grid as nested Python lists of bools (via ``tolist``),
    not as an ndarray.
    """
    return binary_fill_holes(np.asarray(grid, dtype=bool)).tolist()

def draw_cube_outline(ax, xbox, ybox, zbox, color='red',lw = 1, zorder = 2):
    """Draw the 12 edges of an axis-aligned box directly onto a 3D axes.

    Input: ``ax`` (anything with ``plot(xs, ys, zs, **kw)``), and xbox, ybox,
    zbox as (min, max) pairs. Plotting parameters are forwarded to every edge.
    No return value.

    The previous version called ``ax.plot`` on single corner points, which
    draws no visible line, so the outline never appeared.
    """
    xs = (xbox[0], xbox[1])
    ys = (ybox[0], ybox[1])
    zs = (zbox[0], zbox[1])
    style = dict(color=color, lw=lw, zorder=zorder)

    # 4 edges parallel to x
    for y in ys:
        for z in zs:
            ax.plot([xs[0], xs[1]], [y, y], [z, z], **style)
    # 4 edges parallel to y
    for x in xs:
        for z in zs:
            ax.plot([x, x], [ys[0], ys[1]], [z, z], **style)
    # 4 edges parallel to z
    for x in xs:
        for y in ys:
            ax.plot([x, x], [y, y], [zs[0], zs[1]], **style)

def compute_para_perp(Bx,By,Bz,Ex,Ey,Ez,pxx,pyy,pzz,pxy,pxz,pyz):
    """Project the pressure tensor onto a field-aligned basis and return its diagonal.

    Basis (unit vectors):
      u = (E x B) / |E x B|   perpendicular 1, along the E x B direction
      v = u x b               perpendicular 2 (unit, since u is perpendicular to b)
      b = B / |B|             parallel
    Returns (P11, P22, P33) = (u.P.u, v.P.v, b.P.b), where P is the symmetric
    tensor built from the six given components. The outputs carry the units of
    the pressure inputs (e.g. nPa); the units of B and E cancel in the
    normalisation. Where E x B = 0 the perpendicular directions are undefined
    and the result is not finite.
    """
    def project(a1, a2, a3):
        # a . P . a for the symmetric pressure tensor
        return (a1*(pxx*a1+pxy*a2+pxz*a3)
                + a2*(pxy*a1+pyy*a2+pyz*a3)
                + a3*(pxz*a1+pyz*a2+pzz*a3))

    inv_b = 1/np.sqrt(Bx**2+By**2+Bz**2)
    bx, by, bz = inv_b*Bx, inv_b*By, inv_b*Bz

    e_cross_b = (Ey*Bz-Ez*By, Ez*Bx-Ex*Bz, Ex*By-Ey*Bx)
    inv_u = 1/np.sqrt(e_cross_b[0]**2+e_cross_b[1]**2+e_cross_b[2]**2)
    ux, uy, uz = (inv_u*comp for comp in e_cross_b)

    vx = uy*bz-uz*by
    vy = uz*bx-ux*bz
    vz = ux*by-uy*bx

    return project(ux, uy, uz), project(vx, vy, vz), project(bx, by, bz)

def compute_recon_score(Bx,By,Bz,Ex,Ey,Ez,Jx,Jy,Jz,uex,uey,uez,Pxx,Pyy,Pzz,Pxy,Pxz,Pyz):
    """Reconnection indicators and combined score following Li et al. (2024).

    Inputs (all SI): magnetic field B [T], electric field E [V/m], current
    density J [A/m^2], electron bulk velocity ue [m/s], and the six components
    of the electron pressure tensor P [Pa].

    Returns (L, D_e, APhi, root_Q, S):
      L      = log10(c |E x B| / |E|^2); dimensionless for SI inputs. NaN
               (|E| = 0) becomes 0 and -inf (E x B = 0) a large negative finite
               number.
      D_e    = J . (E + ue x B), the electron-frame dissipation [W/m^3 for SI inputs].
      APhi   = agyrotropy measure (Daughton 2008 form); NaN -> 0.
      root_Q = sqrt(Q), with Q = 1 - 4 I2 / ((I1 - P_par)(I1 + 3 P_par)); NaN -> 0.
      S      = sum of the four terms 10**(0.25 * normalised indicator).
    Depends on the module-level constant ``c`` (speed of light).
    """
    # L. nan=0.0 is passed by keyword: the old call nan_to_num(x, 0) passed 0
    # as the ``copy`` argument, not as the NaN replacement.
    L = np.nan_to_num(np.log10(c*np.sqrt((Ey*Bz-Ez*By)**2+(Ez*Bx-Ex*Bz)**2+(Ex*By-Ey*Bx)**2) / (Ex**2+Ey**2+Ez**2)), nan=0.0)

    # D_e = J . (E + ue x B)
    D_e_x = Jx*Ex + Jx*(uey*Bz-uez*By)
    D_e_y = Jy*Ey + Jy*(uez*Bx-uex*Bz)
    D_e_z = Jz*Ez + Jz*(uex*By-uey*Bx)
    D_e = D_e_x+D_e_y+D_e_z

    # APhi
    Bmag = np.sqrt(Bx**2+By**2+Bz**2)
    bx,by,bz = 1/Bmag * [Bx,By,Bz]
    Nxx = by*by*Pzz - 2*by*bz*Pyz + bz*bz*Pyy
    Nxy = -by*bx*Pzz + by*bz*Pxz + bz*bx*Pyz - bz*bz*Pxy
    Nxz = by*bx*Pyz - by*by*Pxz - bz*bx*Pyy + bz*by*Pxy
    Nyy = bx*bx*Pzz - 2*bx*bz*Pxz + bz*bz*Pxx
    Nyz = -bx*bx*Pyz + bx*by*Pxz + bz*bx*Pxy - bz*by*Pxx
    Nzz = bx*bx*Pyy - 2*bx*by*Pxy + by*by*Pxx
    alpha = Nxx+Nyy+Nzz
    beta = -(Nxy**2+Nxz**2+Nyz**2-Nxx*Nyy-Nxx*Nzz-Nyy*Nzz)
    APhi = np.nan_to_num(2*np.sqrt(alpha**2-4*beta)/alpha, nan = 0)

    # root(Q), from the parallel pressure and the tensor invariants I1, I2
    P_para = bx**2*Pxx + by**2*Pyy + bz**2*Pzz + 2*(bx*by*Pxy + bx*bz*Pxz + by*bz*Pyz)
    I1 = Pxx+Pyy+Pzz
    I2 = Pxx*Pyy+Pxx*Pzz+Pyy*Pzz - (Pxy**2+Pxz**2+Pyz**2)
    root_Q = np.nan_to_num(np.sqrt(1 - 4*I2/(I1-P_para)/(I1+3*P_para)), nan = 0)

    # Total score: each indicator is normalised by its extreme value over the domain
    S = 10**(0.25*(3-L)/(3-np.min(L))) + 10**(0.25*D_e/np.max(D_e)) + 10**(0.25*APhi/np.max(APhi)) + 10**(0.25*root_Q/np.max(root_Q))

    return L, D_e, APhi, root_Q, S

def compute_AEPIC_recon_score(Bx,By,Bz,dB_dx,dB_dy,dB_dz,dBx_dx,dBx_dy,dBx_dz,dBy_dx,dBy_dy,dBy_dz,dBz_dx,dBz_dy,dBz_dz,Jx,Jy,Jz,c1_min = 0.005, c2_min = 1e-7, grid_dx = None, length_scale = None, epsilon = 1):
    """Flag candidate reconnection sites using the two criteria of Wang et al. (2022).

    Criterion 1:  c1 = |J|^2 * dx / (|J x B| + |J| * epsilon)
    Criterion 2:  c2 = div((b . grad) b) * dx^2, the divergence of the field-line
                  curvature. When large enough it separates X-lines from
                  O-lines / flux ropes.
    A site is flagged (1) where c1 >= c1_min and c2 is not below c2_min;
    otherwise it is 0. Where |J| = 0, c1 is 0/0 (NaN); such sites now give 0,
    whereas previously the NaN leaked into the output mask.

    Arrays are 3D and indexed [y, x, z]: x derivatives are taken along axis 1,
    y along axis 0 and z along axis 2. dB_* are derivatives of |B|; dBx_*,
    dBy_* and dBz_* are derivatives of the components.

    Parameters added for reuse (defaults reproduce the original behaviour):
    - grid_dx: cell size ``dx``; None uses the notebook global ``dx``.
    - length_scale: factor converting dx to the units of the derivatives in
      the gradient spacing (dx * length_scale); None uses ``R_M``.
    - epsilon: regulariser in c1, in the units of B. The original comment
      calls it "1 nT", which only holds if B is passed in nT.
      NOTE(review): the original header says inputs are SI -- confirm units.
    NOTE(review): c2 is scaled by dx**2, not (dx*length_scale)**2; confirm the
    intended normalisation.

    Returns an array of c1's dtype holding 0/1 flags.
    """
    spacing = dx if grid_dx is None else grid_dx
    scale = R_M if length_scale is None else length_scale

    # Criterion 1 (c1 > 0.8 is the threshold recommended in the paper)
    c1 = (Jx**2+Jy**2+Jz**2)*spacing / (np.sqrt((Jy*Bz-Jz*By)**2 + (Jz*Bx-Jx*Bz)**2 + (Jx*By-Jy*Bx)**2) + np.sqrt(Jx**2+Jy**2+Jz**2)*epsilon)

    # Criterion 2: derivatives of the unit vector b = B/|B| via the chain rule,
    # d(Bi/|B|) = (dBi - bi d|B|) / |B|, reusing the precomputed derivatives
    Bmag = np.sqrt(Bx**2+By**2+Bz**2)
    bx = Bx/Bmag
    by = By/Bmag
    bz = Bz/Bmag
    dbx_dx = (1/Bmag) * (dBx_dx - bx*dB_dx)
    dbx_dy = (1/Bmag) * (dBx_dy - bx*dB_dy)
    dbx_dz = (1/Bmag) * (dBx_dz - bx*dB_dz)
    dby_dx = (1/Bmag) * (dBy_dx - by*dB_dx)
    dby_dy = (1/Bmag) * (dBy_dy - by*dB_dy)
    dby_dz = (1/Bmag) * (dBy_dz - by*dB_dz)
    dbz_dx = (1/Bmag) * (dBz_dx - bz*dB_dx)
    dbz_dy = (1/Bmag) * (dBz_dy - bz*dB_dy)
    dbz_dz = (1/Bmag) * (dBz_dz - bz*dB_dz)
    # Curvature vector (b . grad) b
    pre_c2_x = bx*dbx_dx + by*dbx_dy + bz*dbx_dz
    pre_c2_y = bx*dby_dx + by*dby_dy + bz*dby_dz
    pre_c2_z = bx*dbz_dx + by*dbz_dy + bz*dbz_dz
    # Its divergence
    c2_x = np.gradient(pre_c2_x,spacing*scale,axis=1)
    c2_y = np.gradient(pre_c2_y,spacing*scale,axis=0)
    c2_z = np.gradient(pre_c2_z,spacing*scale,axis=2)
    c2 = (c2_x+c2_y+c2_z)*(spacing)**2

    # Combine criteria. NaN c1 fails ``>=`` so it maps to 0; "not below c2_min"
    # and "c1 > 0" mirror the original zero-then-set-positive logic.
    passes = (c1 >= c1_min) & ~(c2 < c2_min) & (c1 > 0)
    return passes.astype(c1.dtype)

def plot_colored_surface(ax, X_plot, Y_plot, Z_plot, value, vmin = 0, vmax = 6, cmap = 'viridis', alpha = 1, zorder = 2, shading = False, nan_threshold = -1e5):
    """Draw a 3D surface whose face colours encode ``value``.

    Input:
    - ax: 3D axes.
    - X_plot, Y_plot, Z_plot: 2D coordinate arrays of the surface.
    - value: 2D array (same shape) mapped to colour with Normalize(vmin, vmax).
    - cmap: any Matplotlib colormap name or Colormap instance. Previously only
      a fixed list of names worked; anything else raised UnboundLocalError.
    - alpha: opacity of cells with value > nan_threshold. Cells at or below
      nan_threshold are fully transparent, which masks out dataless regions.
    - zorder: drawing order.
    - shading: if True, apply soft-light hill shading computed from Z_plot.
      Only the RGB channels are shaded, so the per-cell alpha is preserved.

    Output: the surface artist returned by ``ax.plot_surface``.
    """
    alpha_array = np.zeros_like(X_plot, dtype=float)
    alpha_array[value>nan_threshold] = alpha

    norm = plt.Normalize(vmin,vmax)
    surf_colors = plt.get_cmap(cmap)(norm(value), alpha=alpha_array)

    if shading:
        light = LightSource(azdeg = 155,altdeg = 10)  # azimuth / altitude of the light [deg]
        # shade_rgb expects an (M, N, 3) RGB array; passing RGBA also blended
        # the alpha channel as if it were a colour.
        shaded_colors = surf_colors.copy()
        shaded_colors[..., :3] = light.shade_rgb(surf_colors[..., :3], Z_plot, blend_mode='soft')
        return ax.plot_surface(X_plot, Y_plot, Z_plot, facecolors=shaded_colors, rstride=1, cstride=1, antialiased=False, zorder=zorder,
                               edgecolor='none')

    return ax.plot_surface(X_plot, Y_plot, Z_plot, facecolors=surf_colors, rstride=1, cstride=1, antialiased=False, zorder=zorder,
                           shade = False, edgecolor='none')

def add_grid(ax, xlims, ylims, x_major, x_minor, y_major, y_minor):
    """Put major/minor ticks at fixed spacings and draw a two-level grid.

    xlims, ylims: (start, stop) tuples. Ticks are ``np.arange(start, stop, spacing)``,
    so ``stop`` itself is excluded. Major grid lines are drawn at alpha 0.5,
    minor ones at alpha 0.2.
    """
    for set_ticks, lims, major, minor in ((ax.set_xticks, xlims, x_major, x_minor),
                                          (ax.set_yticks, ylims, y_major, y_minor)):
        set_ticks(np.arange(lims[0], lims[1], major))
        set_ticks(np.arange(lims[0], lims[1], minor), minor=True)

    ax.grid(which='both')
    for level, opacity in (('minor', 0.2), ('major', 0.5)):
        ax.grid(which=level, alpha=opacity)

def get_tracer(X,Y,Z,Ax,Ay,Az,nsteps = 10000,step_size = 1e-3, cell_size = 0.01562501):
    """Prepare streamtracer objects for field-line tracing of the vector field (Ax, Ay, Az).

    The component arrays are indexed [y, x, z]; they are transposed to the
    [x, y, z] layout and stacked into an (nx, ny, nz, 3) float64 array for
    ``VectorGrid``. The grid is uniform with spacing ``cell_size`` along every
    axis and origin (X.min(), Y.min(), Z.min()).

    Returns (StreamTracer(nsteps, step_size), VectorGrid).
    """
    field = np.stack([np.transpose(comp, axes=[1, 0, 2]) for comp in (Ax, Ay, Az)], axis=-1).astype(float)
    grid = VectorGrid(field, [cell_size] * 3, origin_coord = [X.min(), Y.min(), Z.min()])
    return StreamTracer(nsteps, step_size), grid
        

In [None]:
# Find flux ropes in the selected time range, and save the results

# Data locations. `dir` (current-sheet "csdata" files) shadows the builtin but is
# kept: the loop below and helper default arguments refer to it by this name.
dir = "/Users/atcushen/Documents/MercuryModelling/runs/DR_run1/alldata/"
dir_3D = "/Volumes/My Book Duo/runs/DR_run1/alldata/"    # full 3D "numpy" files
start_time = 210                 # time of the first file (used by get_files when read_time=False)
t_bound = [212.00,500]           # Start and stop times of the data to be plotted
dt = 0.05                        # time between saved snapshots
cell_size = R_M/64               # presumably the grid spacing (1/64 R_M) in metres

read_data = True                 # False skips re-reading the pickled snapshots in the loop below

# Axis limits [R_M]
xlims = [-4, 0]
ylims = [-1.2,1.2]
zlims = [-0.7,0.8]

# END INPUT

# Precomputations for nice axes limits: calculate the range for each axis
x_range = xlims[1] - xlims[0]
y_range = ylims[1] - ylims[0]
z_range = zlims[1] - zlims[0]

# Find the maximum range among x, y, z
max_range = max(x_range, y_range, z_range)

#RUN
# Raw strings: same regex as before, without invalid-escape (\_ and \.) warnings
files3D = get_files(dir_3D,key=r"3d\_fluid.*numpy\_t\_...\...",read_time = True)
filescs = get_files(dir,key=r"3d\_fluid.*csdata\_t\_...\...",read_time = True)

iter = 0
for time in list(files3D.keys()):
    print("Plotting t =",time)

    # Read in the 3D data. When read_data is False nothing is re-read or re-traced,
    # and the arrays/tracer left over from a previous run are reused.
    if read_data:
        file3D = str(files3D[time])
        with open(dir_3D+file3D, 'rb') as f:
            print("reading 3d data: ",str(dir_3D+file3D))
            data3d = pickle.load(f)

        # Read in current sheet data
        filecs = str(filescs[time])
        with open(dir+filecs, 'rb') as f:
            print("reading cs data:",str(dir+filecs))
            datacs = pickle.load(f)

        # Index ranges of the 3D grid that fall inside xlims/ylims/zlims.
        # Not used in this cell (the arrays below are full-domain).
        # NOTE(review): kept because later cells may read these notebook globals.
        trim_x = np.where((data3d["X"][0,:,0]>xlims[0]) & (data3d["X"][0,:,0]<xlims[1]))[0]
        trim_y = np.where((data3d["Y"][:,0,0]>ylims[0]) & (data3d["Y"][:,0,0]<ylims[1]))[0]
        trim_z = np.where((data3d["Z"][0,0,:]>zlims[0]) & (data3d["Z"][0,0,:]<zlims[1]))[0]

        # Unpack data over the full domain. To restrict to the limits above, slice the 3D
        # arrays with [trim_y[0]:trim_y[-1], trim_x[0]:trim_x[-1], trim_z[0]:trim_z[-1]]
        # and the 2D current-sheet arrays with the first two of those slices.
        Xcs = datacs["X"]
        Ycs = datacs["Y"]
        Zcs = datacs["Z"]
        Bzcs = datacs["Bz"]
        X3d = data3d["X"]
        Y3d = data3d["Y"]
        Z3d = data3d["Z"]
        n3d = data3d["rhoS1"]*1e6 # [1/m^3]
        Ex3d = data3d["Ex"]*1e-6 # V/m
        Ey3d = data3d["Ey"]*1e-6
        Ez3d = data3d["Ez"]*1e-6
        Bx3d = data3d["Bx"] # nT
        By3d = data3d["By"]
        Bz3d = data3d["Bz"]
        uix3d = data3d["uxS1"] # km/s
        uiy3d = data3d["uyS1"]
        uiz3d = data3d["uzS1"]
        uex3d = data3d["uxS0"] # km/s
        uey3d = data3d["uyS0"]
        uez3d = data3d["uzS0"]
        Pxx = data3d["pxxS0"] # nPa
        Pxy = data3d["pxyS0"]
        Pxz = data3d["pxzS0"]
        Pyy = data3d["pyyS0"]
        Pyz = data3d["pyzS0"]
        Pzz = data3d["pzzS0"]

        # Compute derived terms
        Jx3d = e*n3d*(uix3d-uex3d)*1e3 # A/m^2
        Jy3d = e*n3d*(uiy3d-uey3d)*1e3
        Jz3d = e*n3d*(uiz3d-uez3d)*1e3
        Bmag3d = np.sqrt(Bx3d**2+By3d**2+Bz3d**2) # nT

        # Look for flux ropes as anywhere a field line traced from the current sheet reintersects the plane
        # We are only looking for planetward flux ropes, tail moving will need to be filtered out
        # rope_candidates stores, per current-sheet cell, the index of the winding seed (-1 = none)
        rope_candidates = np.zeros_like(Xcs)-1

        # Set up field line tracing (via the shared get_tracer helper) for a short trace
        # from every cell; the grid spacing is taken from the current-sheet x coordinates
        dx = np.mean(np.diff(Xcs[0,:]))
        nsteps = 1000
        step_size = dx/10
        tracer, grid = get_tracer(X3d, Y3d, Z3d, Bx3d, By3d, Bz3d,
                                  nsteps=nsteps, step_size=step_size, cell_size=dx)

        # Define the seeds, and remove points inside the planet (radius <= 1.05)
        seeds = np.array([np.ravel(Xcs),np.ravel(Ycs),np.ravel(Zcs)]).T
        seeds = seeds[np.sqrt(np.sum(seeds**2,axis=1))>1.05]

        # Complete the line tracing
        print("starting trace")
        tracer.trace(seeds, grid)
        print("tracing done")

    # Debug plot
    fig = plt.figure(figsize=(20,12), constrained_layout=True)
    ax1 = fig.add_subplot(121, projection="3d",computed_zorder=False)
    ax2 = fig.add_subplot(122, projection="3d",computed_zorder=False)

    # Iterate through each seed to classify topology: a line whose projection onto
    # the x-z plane winds strongly enough is flagged as a flux-rope candidate
    for iseed, seed in enumerate(seeds):
        line = tracer.xs[iseed]
        xz_points = line[:, [0, 2]]

        # Direction of each segment of the projected line, with respect to the x-axis
        angles = np.arctan2(np.diff(xz_points[:, 1]), np.diff(xz_points[:, 0]))

        # Turning angle between consecutive segments, wrapped to [-pi, pi)
        angle_diffs = (np.ediff1d(angles) + np.pi) % (2 * np.pi) - np.pi

        # Net winding of the projected line, in radians
        cumulative_angle_change = np.sum(angle_diffs)

        # Threshold is 3*pi rad (540 degrees, 1.5 turns) in either direction.
        # NOTE(review): an earlier comment described this as 360 degrees — confirm which is intended.
        completes_full_rotation = abs(cumulative_angle_change) >= 3*np.pi

        if completes_full_rotation:
            # Map the seed back to its current-sheet cell; seeds are exact copies of
            # the Xcs/Ycs values, so exact float equality is used here.
            ix = np.where(seed[0] == Xcs[0,:])[0]
            iy = np.where(seed[1] == Ycs[:,0])[0]
            rope_candidates[iy,ix] = iseed
            for ax in (ax1, ax2):
                ax.plot(line[:,0],line[:,1],line[:,2],c='black',lw=0.005)

    # Compute reconnection score from Li et al. 2024
    # L (nan -> 0; note +/-inf are mapped to the largest finite floats by nan_to_num)
    L = np.nan_to_num(np.log10(c*np.sqrt((Ey3d*Bz3d-Ez3d*By3d)**2+(Ez3d*Bx3d-Ex3d*Bz3d)**2+(Ex3d*By3d-Ey3d*Bx3d)**2)*1e-9 / (Ex3d**2+Ey3d**2+Ez3d**2)), nan=0)

    # D_e
    D_e_x = Jx3d*Ex3d + Jx3d*(uey3d*Bz3d-uez3d*By3d)*1e3*1e-9
    D_e_y = Jy3d*Ey3d + Jy3d*(uez3d*Bx3d-uex3d*Bz3d)*1e3*1e-9
    D_e_z = Jz3d*Ez3d + Jz3d*(uex3d*By3d-uey3d*Bx3d)*1e3*1e-9
    D_e = D_e_x+D_e_y+D_e_z

    # APhi (bx, by, bz are the components of the unit vector along B)
    bx,by,bz = 1/Bmag3d * [Bx3d,By3d,Bz3d]
    Nxx = by*by*Pzz - 2*by*bz*Pyz + bz*bz*Pyy
    Nxy = -by*bx*Pzz + by*bz*Pxz + bz*bx*Pyz - bz*bz*Pxy
    Nxz = by*bx*Pyz - by*by*Pxz - bz*bx*Pyy + bz*by*Pxy
    Nyy = bx*bx*Pzz - 2*bx*bz*Pxz + bz*bz*Pxx
    Nyz = -bx*bx*Pyz + bx*by*Pxz + bz*bx*Pxy - bz*by*Pxx
    Nzz = bx*bx*Pyy - 2*bx*by*Pxy + by*by*Pxx
    alpha = Nxx+Nyy+Nzz
    beta = -(Nxy**2+Nxz**2+Nyz**2-Nxx*Nyy-Nxx*Nzz-Nyy*Nzz)
    APhi = np.nan_to_num(2*np.sqrt(alpha**2-4*beta)/alpha, nan = 0)

    # root(Q)
    P_para = bx**2*Pxx + by**2*Pyy + bz**2*Pzz + 2*(bx*by*Pxy + bx*bz*Pxz + by*bz*Pyz)
    I1 = Pxx+Pyy+Pzz
    I2 = Pxx*Pyy+Pxx*Pzz+Pyy*Pzz - (Pxy**2+Pxz**2+Pyz**2)
    root_Q = np.sqrt(np.nan_to_num(1 - 4*I2/(I1-P_para)/(I1+3*P_para), nan = 0))

    # Compute total score
    S = 10**(0.25*(3-L)/(3-np.min(L))) + 10**(0.25*D_e/np.max(D_e)) + 10**(0.25*APhi/np.max(APhi)) + 10**(0.25*root_Q/np.max(root_Q))

    # Iterate through the flux rope candidates, skipping some to save time
    trace_skip = 10

    # Declare mask representing all the cells in 3d space connected to flux ropes
    mask = np.zeros_like(X3d,dtype='bool')
    for iseed in np.ravel(rope_candidates[rope_candidates>=0])[::trace_skip]:
        mask = np.logical_or(mask, nearest_voxel_mask(X3d, Y3d, Z3d, tracer.xs[int(iseed)]))

    # Save the mask for this time step (context manager so the file is flushed and closed)
    mask_path = str(dir+"FRs/FR_mask_t_"+'{:06.2f}'.format(round(float(time),2)))
    with open(mask_path, 'wb') as fh:
        pickle.dump(mask, fh)

    # plot:
    # Define reconnection regions
    recon_mask = (S > 4.9)

    for ax in [ax1,ax2]:
        if np.any(mask):
            ax.scatter(X3d[mask],Y3d[mask],Z3d[mask],c=-Z3d[mask],vmin=-1.2,vmax=0.9,cmap="plasma",alpha = 0.01)
        #ax.scatter(X3d[recon_mask],Y3d[recon_mask],Z3d[recon_mask],c=-Z3d[recon_mask],vmin=-1.2,vmax=0.9,cmap="Greens",alpha = 0.3)

        # Update limits to be centered with max range
        ax.set_xlim(np.mean(xlims) - max_range / 2, np.mean(xlims) + max_range / 2)
        ax.set_ylim(np.mean(ylims) - max_range / 2, np.mean(ylims) + max_range / 2)
        ax.set_zlim(np.mean(zlims) - max_range / 2, np.mean(zlims) + max_range / 2)

        ax.set_xlabel("X [$R_M$]")
        ax.set_ylabel("Y [$R_M$]")
        ax.set_zlabel("Z [$R_M$]")

    # Viewing angles, saving, showing and closing happen once, after BOTH axes are
    # configured; inside the axis loop the figure would be closed before ax2 was set up.
    ax1.view_init(elev=90, azim=-90)
    ax2.view_init(elev=0, azim=-90)

    # Save
    fig.savefig(str(str(dir[:-1])+"_plots/flux_rope_volume_fit"+"_"+"%.2f"%round(float(time),2)+".png"),bbox_inches='tight', dpi=300)
    plt.show()
    plt.close(fig)


    

In [None]:
# Read in saved FR data and perform further analysis

# ---- Run directories -------------------------------------------------------
# `dir` should contain the flux-rope data in dir/FRs/; `dir_3D` holds the 3D data.
dir = "/Users/atcushen/Documents/MercuryModelling/runs/nightside_v5_run3/alldata/"
dir_3D = "/Volumes/My Book Duo/runs/nightside_v5_run3/alldata/"
# Alternative run, swap in as needed:
#   dir = "/Users/atcushen/Documents/MercuryModelling/runs/DR_run1/alldata/"
#   dir_3D = "/Volumes/My Book Duo/runs/DR_run1/alldata/"
#   dir_3D = dir

# ---- Time window -------------------------------------------------------------
start_time, dt = 120, 0.05
t_bound = [158, 160.00]    # start and stop times of the data to be plotted (alternative: [140, 143])

# ---- Grid / species constants ------------------------------------------------
dx = cell_size = R_M/64
mi_me = 100                # presumably the ion/electron mass ratio — confirm

read_data = True

# ---- Plotting area -----------------------------------------------------------
x_region, y_region, z_region = [-2.5, -1.0], [-1.2, 1.20], [-0.3, 0.7]

# ---- Analysis sub-region (outdated settings for reducing the analysis area) ----
# LIMITS MUST MATCH FROM ABOVE FOR MASK TO WORK
xlims, ylims, zlims = [-2.25, -1], [-0.25, 1.0], [-0.1, 0.5]

# Point used to specify the timeseries location
loc = [-1.4, 0.4, 0.3]

# Plot/analysis modes
'''
"Background_field": computes the average background field for other analyses, and the flux tube volume

"Adibaticity_1D": line plot of instantaneous pressure compared to the background profile (the preset key is spelled "Adibaticity_1D" in the code)

"FR_DF_formation": shows DFs and FRs, saving a mask for all the DFs

"dBz_dt_formation": New method for identifying DFs (currently without FR data), using dBz_dt a la 10.1029/2025JA033892

"DF_heating":

"DF_example_extraction": Based on the start time and a given coordinate, attempts to extract the position mask for a given DF

"DF_example_filtering": loads in example masks from the above loop and tries to remove extensional features based on electron velocity

"DF_example_summary": loads in filtered example masks from the above loop, plots time series data

"DF_example_visualizer":  loads in filtered example masks from the above loop, shows nice 3d view for context, in subregion of interest

"DF_example_visualizer2": loads in filtered example masks from the above loop, shows 3 superimposed times

"DF_example_visualizer3": Combines 1 and 2: a 3d view of 3 superimposed times

"force_equilib": plots JxB and grad(p), and J_y (colored according to how diamagnetic it is)

"Jy_components": plots of total Jy, Jy_dia, and Jy_RC for a given y plane

"deltaBz_xz": detail plot of deltaBz to investigate DF rereconnection geometry

"p_xz": detail plot of p to investigate DF rereconnection geometry

"xz_slice": generic xz plane slice, showing selected parameter

"Adiabaticity_2D": 2d slice showing entropy at each cell in xz

"DF_force_diagram": Binned data showing force balance as function of x

"FACs": DF and FR plot, with FACs identified

"B_timeseries": Declare a point with loc and generate a timeseries of the magnetic field there for t_bound

"reconnection_sites": average S in the tail

"AEPIC_reconnection_sites": Apply Wang et al. 2022 method for AEPIC regions to identify reconnection sites

"total_current": Computes eastward and westward currents in a given sector

"current_spectra": shows spectral-like 2d plot of cross-tail current as function of time and X, for a given Y

"reconnection_spectra": shows spectral-like 2d plot of cross-tail reconnection activity as a function of X and T, and y and T

"entropy_FAC_map": 2D map of upward and downward FACs and flux tube entropy

"flux_tube_visualizer": 

"pressure_balance": 1d lineplot along x and y=loc[1] and z=loc[2] of p_i, p_e, p_dyn, p_mag
'''

plot_preset = "xz_slice"    # one of the modes listed in the string above

#RUN

# Precomputations for nice axes limits: the largest span among x, y and z is
# used so that all three 3D axes share a common scale
x_range = xlims[1] - xlims[0]
y_range = ylims[1] - ylims[0]
z_range = zlims[1] - zlims[0]
max_range = max(x_range, y_range, z_range)

# Get the relevant file names.
# Raw strings: the patterns contain regex escapes (\_ and \.), which are invalid
# escape sequences in ordinary string literals; the pattern text is unchanged.
files3D = get_files(dir_3D, key=r"3d\_fluid.*numpy\_t\_...\...", read_time=True)
filescs = get_files(dir, key=r"3d\_fluid.*csdata\_t\_...\...", read_time=True)

iter = 0
file_found = False # Flag used to indicate if optional pre-processed data is loaded

for time in list(files3D.keys()): 
    print("Plotting t =",time)
    
    # Read in the 3D data
    if read_data: #or iter == 0:
        file3D = str(files3D[time])
        with open(dir_3D+file3D, 'rb') as f:
            print("reading 3d data: ",str(dir_3D+file3D))
            data3d = pickle.load(f) 
        
        # Read in current sheet data
        if plot_preset in ["DF_example_visualizer","reconnection_sites","AEPIC_reconnection_sites","entropy_FAC_map","flux_tube_visualizer"]:
            filecs = str(filescs[time])
            with open(dir+filecs, 'rb') as f:
                print("reading cs data:",str(dir+filecs))
                datacs = pickle.load(f)

        # Read in flux rope data
        if plot_preset in ["FR_DF_formation", "dBz_dt_formation",
                           "DF_heating","DF_example_visualizer","DF_example_visualizer2"]:
            filesFR = get_files(str(dir+"FRs/"),key="FR\_mask\_t\_...\...",read_time = True)
            fileFR = str(filesFR[time])
            with open(dir+"FRs/"+fileFR, 'rb') as f:
                print("reading FR data:",str(dir+"FRs/"+fileFR))
                FR_mask = pickle.load(f) 

        # Read in DF mask data
        if plot_preset in ["DF_example_summary","DF_example_extraction","DF_examplefiltering","DF_heating",
                           "DF_example_visualizer","DF_example_visualizer2"]:#,"deltaBz_xz"]:
            filesDF = get_files(str(dir+"DFs/"),key="DF\_mask\_t\_...\...",read_time = True)
            fileDF = str(filesDF[time])
            with open(dir+"DFs/"+fileDF, 'rb') as f:
                print("reading DF data:",str(dir+"DFs/"+fileDF))
                DF_mask = pickle.load(f) 

        # Read in background field data
        if plot_preset in ["Adibaticity_1D","Adiabaticity_2D","DF_example_summary","FR_DF_formation",
                           "DF_heating","DF_example_visualizer","DF_example_visualizer2","B_timeseries","entropy_FAC_map",
                           "flux_tube_visualizer","Jy_components","deltaBz_xz"] and iter==0:
            with open(dir+"background_field/background_field", 'rb') as f:
                print("reading background field data:",str(dir+"background_field/background_field"))
                background_field = pickle.load(f) 

        # Read in entropy data
        if (plot_preset in ["flux_tube_visualizer","Adiabaticity_2D"]):
            filesS = get_files(str(dir+"S_data/"),key="S\_t\_...\...",read_time = True)
            if time in filesS.keys():
                fileS = str(filesS[time])
                with open(dir+"S_data/"+fileS, 'rb') as f:
                    print("reading entropy data:",str(dir+"S_data/"+fileS))
                    entropy_map = pickle.load(f) 
                S_array = entropy_map['S']
                FAC_array = entropy_map['FAC']
                file_found = True
                
            else:
                print("No data for entropy found at this time step")
                file_found = False
        
    # Read timeseries data
    if plot_preset == "B_timeseries" and (not read_data): 
        filesTS = get_files(str(dir+"timeseries/"),key="B\_timeseries\_...\...",read_time = False, reduce = False)
        for key in filesTS.keys():
            if str(loc).replace(" ", "") in filesTS[key]:
                fileTS = filesTS[key]
                with open(dir+"timeseries/"+fileTS, 'rb') as f:
                    timeseries = pickle.load(f) 
                print("reading timeseries data:",str(dir+"timeseries/"+fileTS))
                file_found = True
                break
            
        # Compute the limits based on where FRs have been searched for
        trim_x = np.where((data3d["X"][0,:,0]>xlims[0]) & (data3d["X"][0,:,0]<xlims[1]))[0] #LIMITS MUST MATCH FROM ABOVE FOR MASK TO WORK
        trim_y = np.where((data3d["Y"][:,0,0]>ylims[0]) & (data3d["Y"][:,0,0]<ylims[1]))[0]
        trim_z = np.where((data3d["Z"][0,0,:]>zlims[0]) & (data3d["Z"][0,0,:]<zlims[1]))[0]


    # Read in entropy data for entropy_map_case, where read_Data=False  means to read in the saved data instead of recompute
    if ((plot_preset == "entropy_FAC_map") and (not read_data)):
        if time in filesS.keys():
            fileS = str(filesS[time])
            with open(dir+"S_data/"+fileS, 'rb') as f:
                print("reading entropy data:",str(dir+"S_data/"+fileS))
                entropy_map = pickle.load(f) 
            S_array = entropy_map['S']
            FAC_array = entropy_map['FAC']
            file_found = True
            
        else:
            print("No data for entropy found at this time step")
            file_found = False
    
    #### BEGIN PLOTTING PRESETS ####

    if plot_preset == "Background_field":

        if read_data:
            
            # Unpack data
            X3d = data3d["X"]
            Y3d = data3d["Y"]
            Z3d = data3d["Z"]
    
            # Compute average
            avgs = average_value(["Bx","By","Bz","Jx","Jy","Jz","pxxS0","pyyS0","pzzS0","pxxS1","pyyS1","pzzS1",'rhoS1'],float(time),0,5,dt = 1,type='numpy',path=dir_3D)
            Bx_avg = avgs["Bx"]
            By_avg = avgs["By"]
            Bz_avg = avgs["Bz"]
            Jx_avg = avgs["Jx"]
            Jy_avg = avgs["Jy"]
            Jz_avg = avgs["Jz"] 
            pxxS0 = avgs["pxxS0"] #[nPa]
            pyyS0 = avgs["pyyS0"] #[nPa]
            pzzS0 = avgs["pzzS0"] #[nPa]
            pxxS1 = avgs["pxxS1"] #[nPa]
            pyyS1 = avgs["pyyS1"] #[nPa]
            pzzS1 = avgs["pzzS1"] #[nPa]
            rho = avgs["rhoS1"] #amu/cc
        
            # Compute average field magnitude
            print("Averaging complete, setting up field line tracing")
            Bmag_avg = np.sqrt(Bx_avg**2+By_avg**2+Bz_avg**2)
            V_integrand = 1/Bmag_avg
            # Fix nan and infs
            V_integrand[np.isnan(V_integrand)] = 0
            V_integrand[V_integrand == np.inf] = 0

            # Compute average pressure
            p_avg = (pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1)/3 # [nPa]
    
            # Work out where the magnetic equator is
            zplane = 0.2 #np.mean(Y)
            izplane = np.where((Z3d[0,0,:]>zplane))[0][0]
            
            # Set up field line tracing
            tracer,grid = get_tracer(X3d,Y3d,Z3d,Bx_avg,By_avg,Bz_avg,nsteps = 4000,step_size = 5e-4)
        
            # Declare fieldline tracing seeds from each point in zoom region of current sheet
            seeds = np.array([np.ravel(X3d[:,:,izplane]),np.ravel(Y3d[:,:,izplane]),np.ravel(Z3d[:,:,izplane])]).T
            print("Beginning tracing for",len(seeds),"seeds")
            tracer.trace(seeds, grid)
            print("Complete!")
    
        # Declare array to save all the seeds corresponding to closed field lines, and the flux tube entropy and length
        V_list = np.zeros((len(seeds), 1))
        S_list_eq = np.zeros((len(seeds), 1))
    
        # Declare an interpolator object to compute the volume at each point along the field lines
        V_interpolator = RegularGridInterpolator((X3d[0,:,0], Y3d[:,0,0], Z3d[0,0,:]), np.swapaxes(V_integrand,0,1), bounds_error=False, fill_value=None)

        # Compute entropy content integrand according to : 10.5194/angeo-22-1773-2004
        # S = \int p^{1/\gamma) ds/B
        bx,by,bz = [Bx_avg,By_avg,Bz_avg]/(Bmag_avg)
        p = (pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1)/3 # [nPa]
        gamma = 5/3
        S_integrand = p**(1/gamma) / Bmag_avg #[(nPa)^5/3 / nT]
        
        # Declare an interpolator object to compute the volume at each point along the field lines
        print('defining interpolator')
        S_interpolator = RegularGridInterpolator((X3d[0,:,0], Y3d[:,0,0], Z3d[0,0,:]), np.swapaxes(S_integrand,0,1), bounds_error=False, fill_value=None)


        # Interate through each trace
        for iseed, seed in enumerate(seeds):
            
            # Report progress
            if iseed%2e3 == 0:
                print("Checked",iseed,"out of",len(seeds)) 
    
            # Check if its a closed field line, by seeing if both ends are *close to* planet
            trace = tracer.xs[iseed]
            if (np.sum(trace[0,:]**2)<1.2**2) and (np.sum(trace[-1,:]**2)<1.2**2):

                # INTERPOLATE: interpolate the volume value at each point along the field line
                # Compute distances between points
                diffs = np.diff(trace, axis=0)
                l = np.sqrt(np.sum(diffs**2, axis=1))
                V_interp = np.sum(V_interpolator(trace))
                S_interp = np.nansum(S_interpolator(trace))
                V_list[iseed] = np.sum(l*V_interp)
                S_list_eq[iseed] = np.nansum(l*S_interp)


        # Reshape the array
        V_array = V_list.reshape(X3d[:,:,izplane].shape)
        S_array_eq = S_list_eq.reshape(X3d[:,:,izplane].shape)

        ######### NEW SECTION ############
        # Copied from entropy_FAC_map: better practice for S and FAC mapping
        #Declare radial distance of surface map
        r_proj = 1.1

        # declare resolution
        n_long = 64
        n_lat = 128
        theta_lims = np.array([55,90]) * np.pi/180
        long_lims = np.array([120,240]) * np.pi/180
        center = [0,0,0.2]
    
        ZMSM3d = Z3d - 0.2
        RMSM3d = np.sqrt(X3d**2+Y3d**2+(ZMSM3d)**2)
    
        # Create the meshgrid of points at the chosen distance
        phi_ax = np.linspace(*long_lims, n_long)
        theta_ax = np.linspace(*theta_lims, n_lat) # theta in MSO
        phi,theta = np.meshgrid(phi_ax,theta_ax)
        x_proj = r_proj * np.sin(theta) * np.cos(phi) + center[0]
        y_proj = r_proj * np.sin(theta) * np.sin(phi) + center[1]
        z_proj = r_proj * np.cos(theta) + center[2]
            
        # Compute mlat/long meshgrid
        long,lat = np.meshgrid(phi_ax,theta_ax)
        mlat = np.arcsin((z_proj-0.2)/r_proj) * 180/np.pi
        long = long * 180/np.pi
            
        # Translate into a list of seeds for tracing
        print('making seeds')
        point_seeds = np.array([np.ravel(x_proj), np.ravel(y_proj), np.ravel(z_proj)]).T
    
        # Create a one-to-one list to store the entropy in (by default -1 for open field lines)
        S_list_planet = np.zeros_like(np.ravel(x_proj))-1

        # Compute FACs
        J_para = (np.abs(bx)*Jx_avg + np.abs(by)*Jy_avg + np.abs(bz)*Jz_avg)*1e9 # [nA/m^2]
        J_para_x = J_para*np.abs(bx)
        J_para_y = J_para*np.abs(by)
        J_para_z = J_para*np.abs(bz)

        # Create a radial (from dipole center) interpolator for FAC
        FAC_rhat = (J_para_x * X3d + J_para_y * Y3d + J_para_z * ZMSM3d) / (RMSM3d)
        FAC_interpolator = RegularGridInterpolator((X3d[0,:,0], Y3d[:,0,0], Z3d[0,0,:]), np.swapaxes(FAC_rhat,0,1), bounds_error=False, fill_value=None)
        FAC_list = FAC_interpolator(point_seeds)
        FAC_array = FAC_list.reshape(mlat.shape) # [nA/m^2]

        # Set up field line tracing
        print('Getting tracer grid')
        tracer,grid = get_tracer(X3d,Y3d,Z3d,Bx_avg,By_avg,Bz_avg)

        # Trace field lines
        print("beginning tracing of",len(point_seeds),"lines!")
        tracer.trace(point_seeds, grid)
        print("done!")

        # Interate through each trace
        for iseed, seed in enumerate(point_seeds):
            
            # Get x,y,z
            trace = tracer.xs[iseed]
            trace_x = trace[:,0]
            trace_y = trace[:,0]
            trace_z = trace[:,0]

            # Filter to closed field lines
            if (np.sum(trace[0,:]**2)<1.2**2) and (np.sum(trace[-1,:]**2)<1.2**2):
                
                # INTERPOLATE: interpolate the volume value at each point along the field line
                # Compute distances between points
                diffs = np.diff(trace, axis=0)
                l = np.sqrt(np.sum(diffs**2, axis=1))
                S_interp = np.nansum(S_interpolator(trace))
                S_list_planet[iseed] = np.sum(l*S_interp) #[(nPa)^5/3 R_M / nT]

        S_array_planet = S_list_planet.reshape(mlat.shape)

        print("Done! Saving background field and flux tube volume of closed field lines to:",str(dir+"background_field/background_field"))

        # Compose the result
        # Bx,By,Bz are full 3d arrays of the average field [nT]
        # V is the flux tube volume for closed field lines in the magnetic equator (UPDATE MAY 27 2025: PROBABLY WRONG)
        # p is average total pressure [nPa]
        # n is average density
        # S is flux tube entropy from 10.5194/angeo-22-1773-2004 [(nPa)^5/3 R_M / nT]
        # FAC is radial FAC current density [nA/m^2, sign is indep. of field direction, so it indicates up or down]
        background_field = {"Bx":Bx_avg,"By":By_avg,"Bz":Bz_avg,"V":V_array,"p":p_avg,"n":rho,
                           "S_eq":S_array_eq,"S_planet":S_array_planet,'FAC':FAC_array} # Units: nT, nT, nT, R_M/nT, nPa, amu/cc
        
        # Save the data
        pickle.dump(background_field, open(str(dir+"background_field/background_field"), 'wb'))

        break

    if plot_preset == "Adibaticity_1D":

        if read_data:
            # Unpack data
            X3d = data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Y3d = data3d["Y"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Z3d = data3d["Z"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Bx3d = data3d["Bx"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
            By3d = data3d["By"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
            Bz3d = data3d["Bz"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
            #avgs = average_value(["pxxS0","pyyS0","pzzS0","pxxS1","pyyS1","pzzS1",],float(time),0,500,dt = 1.0,type='numpy',path=dir_3D)
            pxxS0 = data3d["pxxS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pyyS0 = data3d["pyyS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pzzS0 = data3d["pzzS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pxxS1 = data3d["pxxS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pyyS1 = data3d["pyyS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pzzS1 = data3d["pzzS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]

            # Unpack average field data
            Bx_avg = background_field["Bx"] # [nT]
            By_avg = background_field["By"] # [nT]
            Bz_avg = background_field["Bz"] # [nT]
            V_avg = background_field["V"] # R_M/nT
            p_avg = background_field["p"] # [nPa]

        # Ratio of specific heats
        gamma = 5/3

        # Compute adiabatic pressure, up to some constant?
        p_adiabat = V_avg**(-gamma)
        # Fix nan and infs
        p_adiabat[np.isnan(p_adiabat)] = 0
        p_adiabat[p_adiabat == np.inf] = 0

        # Compute current pressure
        p = (pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1)/3 # [nPa]

        # Attempt to normalize to p
        p_adiabat = p_adiabat/np.min(p_adiabat[p_adiabat>0]) # unclear units

        # Work out where the magnetic equator is
        zplane = 0.2 #np.mean(Y)
        izplane = np.where((Z3d[0,0,:]>zplane))[0][0]

        # Choose y line
        yline = -0.75
        iyline = np.where((Y3d[:,0,0]>yline))[0][0]

        # Set up field line tracing
        PIC_dx = 0.01562501
        ny,nx,nz = Bx3d.shape
        field = np.zeros((nx,ny,nz,3))
        field[:,:,:,0] = np.transpose(Bx3d,axes=[1,0,2])
        field[:,:,:,1] = np.transpose(By3d,axes=[1,0,2])
        field[:,:,:,2] = np.transpose(Bz3d,axes=[1,0,2])
        grid_spacing = [PIC_dx,PIC_dx,PIC_dx]
        grid = VectorGrid(field, grid_spacing, origin_coord = [X3d.min(),Y3d.min(),Z3d.min()])
        nsteps = 1000
        step_size = PIC_dx/2
        tracer = StreamTracer(nsteps, step_size)

        # Define integrand for flux tube volume
        Bmag = np.sqrt(Bx3d**2+By3d**2+Bz3d**2)
        V_integrand = 1/Bmag
        # Fix nan and infs
        V_integrand[np.isnan(V_integrand)] = 0
        V_integrand[V_integrand == np.inf] = 0

        # Declare fieldline tracing seeds from each point along chosen line
        seeds = np.array([np.ravel(X3d[iyline,:,izplane]),np.ravel(Y3d[iyline,:,izplane]),np.ravel(Z3d[iyline,:,izplane])]).T
        print("Beginning tracing for",len(seeds),"seeds")
        tracer.trace(seeds, grid)
        print("Complete!")

         # Declare array to save all the seeds corresponding to closed field lines, and the flux tube entropy and length
        V_list = np.zeros((len(seeds), 1)) - 1 # Use a -1 to indicate open field lines
    
        # Declare an interpolator object to compute the volume at each point along the field lines
        interpolator = RegularGridInterpolator((X3d[0,:,0], Y3d[:,0,0], Z3d[0,0,:]), np.swapaxes(V_integrand,0,1), bounds_error=False, fill_value=None)

        # Interate through each trace
        for iseed, seed in enumerate(seeds):
    
            # Check if its a closed field line, by seeing if both ends are *close to* planet
            trace = tracer.xs[iseed]
            if (np.sum(trace[0,:]**2)<1.2**2) and (np.sum(trace[-1,:]**2)<1.2**2):

                # INTERPOLATE: interpolate the volume value at each point along the field line
                # Compute distances between points
                diffs = np.diff(trace, axis=0)
                l = np.sqrt(np.sum(diffs**2, axis=1))
                V_interp = np.sum(interpolator(trace))
                V_list[iseed] = np.sum(l*V_interp)


        # Reshape the array
        #V_array = V_list.reshape(X3d[iyline,:,izplane].shape)

        # Compute entropy
        H = p[iyline,:,izplane] * np.ravel(V_list)**gamma

        # Declare figure
        fig,ax = plt.subplots(figsize=(9,5), ncols=1, constrained_layout=True)

        ax.plot(X3d[iyline,:,izplane],V_avg[iyline,:],color='tab:blue',linestyle='dashed')
        ax.plot(X3d[iyline,:,izplane],V_list,label='Volume [$R_M$/nT]',color='tab:blue',linestyle='solid')
        ax.plot(X3d[iyline,:,izplane],p_avg[iyline,:,izplane],color='tab:orange',linestyle='dashed')
        ax.plot(X3d[iyline,:,izplane],p[iyline,:,izplane],label='Pressure [nPa]',color='tab:orange',linestyle='solid')
        ax.plot(X3d[iyline,:,izplane],(p_avg[iyline,:,izplane]*(V_avg[iyline,:])**gamma),color='tab:green',linestyle='dashed')
        ax.plot(X3d[iyline,:,izplane],H,label='Entropy [nPa ($R_M$/nT)$^{5/3}$]',color='tab:green',linestyle='solid')
        ax.plot(X3d[iyline,:,izplane],Bz_avg[iyline,:,izplane],color='tab:red',linestyle='dashed')
        ax.plot(X3d[iyline,:,izplane],Bz3d[iyline,:,izplane],label='$B_z$ [nT]',color='tab:red',linestyle='solid')

        ax.set_yscale('log')
        ax.set_xlim(-1.75,-1)
        ax.set_ylim(1e-1,1e3)
        ax.legend()

        ax.set_title(str("Pressure profile along Y = "+str(yline)+" at t = "+time+"s"))
        ax.set_xlabel("X [$R_M$]")

    # Disabled "Adiabaticity_2D" preset: the whole branch below is wrapped in a bare
    # triple-quoted string, so Python evaluates and discards it and none of it runs.
    # NOTE(review): if this is ever re-enabled it needs fixing first --
    #   * the p_adiabat normalisation reads `izplane` before this branch assigns it
    #     (it would silently pick up whatever `izplane` a previous preset/iteration left);
    #   * the tick configuration calls `ax.set_xticks(...)` etc. on a leftover `ax`
    #     rather than on the panels in `axs`, before the `for ax in axs` loop runs.
    '''
    if plot_preset == 'Adiabaticity_2D':

        if read_data:
            # Unpack data
            X3d = data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Y3d = data3d["Y"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Z3d = data3d["Z"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Bx3d = data3d["Bx"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
            By3d = data3d["By"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
            Bz3d = data3d["Bz"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
            pxxS0 = data3d["pxxS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pxyS0 = data3d["pxyS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pxzS0 = data3d["pxzS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pyyS0 = data3d["pyyS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pyzS0 = data3d["pyzS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pzzS0 = data3d["pzzS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pxxS1 = data3d["pxxS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pxyS1 = data3d["pxyS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pxzS1 = data3d["pxzS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pyyS1 = data3d["pyyS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pyzS1 = data3d["pyzS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pzzS1 = data3d["pzzS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]

            # Unpack average field data
            Bx_avg = background_field["Bx"] # [nT]
            By_avg = background_field["By"] # [nT]
            Bz_avg = background_field["Bz"] # [nT]
            V_avg = background_field["V"] # R_M/nT

            # Ratio of specific heats
            gamma = 5/3

            # Compute adiabatic pressure, up to some constant?
            p_adiabat = V_avg**(-gamma)
            # Fix nan and infs
            p_adiabat[np.isnan(p_adiabat)] = 0
            p_adiabat[p_adiabat == np.inf] = 0

            # Compute actual pressure
            p = (pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1)/3 # [nPa]

            # Attempt to normalize to p
            p_adiabat = p_adiabat/np.mean(p_adiabat) * np.mean(p[:,:,izplane])

        # Work out where the magnetic equator is
        zplane = 0.2 #np.mean(Y)
        izplane = np.where((Z3d[0,0,:]>zplane))[0][0]

        # Declare figure
        fig,axs = plt.subplots(figsize=(18,7), ncols=2, constrained_layout=True)

        # Show adiabatic pressure profile
        
        adiabat_plot = axs[0].imshow(np.log10(p_adiabat), origin='lower', 
                                     extent = [np.min(X3d),np.max(X3d),np.min(Y3d),np.max(Y3d)],cmap='Spectral')
        #plt.imshow(np.log10(p_adiabat),origin='lower')
        # Show instantenous pressure
        p_plot = axs[1].imshow(np.log10(p[:,:,izplane]), origin='lower', vmin = -2, vmax = 2,
                               extent = [np.min(X3d),np.max(X3d),np.min(Y3d),np.max(Y3d)],cmap='Spectral')

        # Configure axes ticks
        x_major_ticks = np.arange(xlims[0], xlims[1], 0.5)
        x_minor_ticks = np.arange(xlims[0], xlims[1], 0.1)
        y_major_ticks = np.arange(ylims[0], ylims[1], 0.4)
        y_minor_ticks = np.arange(ylims[0], ylims[1], 0.1)
    
        ax.set_xticks(x_major_ticks)
        ax.set_xticks(x_minor_ticks, minor=True)
        ax.set_yticks(y_major_ticks)
        ax.set_yticks(y_minor_ticks, minor=True)
                
        # Tidy axes
        for ax in axs:
            ax.set_aspect(1)
            ax.set_xlim(-2.5,-0.75)
            ax.set_ylim(-1.2,1.2)
            ax.set_xlabel("X [$R_M$]")
            ax.set_ylabel("Y [$R_M$]")
            ax.grid()

        # Colorbar
        clb1 = fig.colorbar(adiabat_plot,ax=axs[0],shrink=0.5)
        clb1.ax.set_title("log $p_{adiabat}$ [nPa]")
        clb2 = fig.colorbar(p_plot,ax=axs[1],shrink=0.5)
        clb2.ax.set_title("log $p$ [nPa]")

        # Title
        axs[0].set_title(str("Background adiabatic pressure"))
        axs[1].set_title(str("Total thermal pressure at t = "+time+"s"))    
    '''
    
    if plot_preset == "FR_DF_formation":
        # Flux ropes (FRs, from the externally supplied FR_mask) versus dipolarization
        # fronts (DFs: cells where Bz exceeds its background value by more than 10 nT),
        # drawn as 3D scatter clouds in a top-down panel (ax1) and a side-on panel (ax2).

        if read_data:
            # Full-domain grid and magnetic field [nT] from the 3D snapshot
            X3d, Y3d, Z3d = data3d["X"], data3d["Y"], data3d["Z"]
            Bx3d, By3d, Bz3d = data3d["Bx"], data3d["By"], data3d["Bz"]

            # Background (average) field [nT]
            Bx_avg, By_avg, Bz_avg = (background_field["Bx"],
                                      background_field["By"],
                                      background_field["Bz"])

            # Bz perturbation relative to the background [nT]
            deltaBz3d = Bz3d - Bz_avg

        # Classify every cell as DF only, FR only, or both
        DF_mask = deltaBz3d > 10
        DF_FR_mask = DF_mask & FR_mask
        DF_only_mask = DF_mask & np.invert(FR_mask)
        FR_only_mask = FR_mask & np.invert(DF_mask)

        # Two 3D panels; computed_zorder=False so the explicit zorder values decide layering
        fig = plt.figure(figsize=(20, 12), constrained_layout=True)
        ax1, ax2 = (fig.add_subplot(position, projection="3d", computed_zorder=False)
                    for position in (121, 122))

        if np.any(FR_mask):
            for ax in (ax1, ax2):
                # Which cells count as "in front of" the planet depends on the panel: +Z for ax1, -Y for ax2
                in_front = (Z3d > 0) if ax is ax1 else (Y3d < 0)
                behind = np.invert(in_front)

                # One (mask, colormap, alpha) per population; front layer at zorder 2, back layer at 0.75
                for _pop_mask, _pop_cmap, _pop_alpha in ((FR_only_mask, "Blues", 0.02),
                                                         (DF_FR_mask, "Purples", 0.2),
                                                         (DF_only_mask, "Reds", 0.02)):
                    for _side, _side_zorder in ((in_front, 2), (behind, 0.75)):
                        _sel = _pop_mask & _side
                        ax.scatter(X3d[_sel], Y3d[_sel], Z3d[_sel], c=-Z3d[_sel],
                                   vmin=-1.2, vmax=0.9, cmap=_pop_cmap,
                                   alpha=_pop_alpha, zorder=_side_zorder)

                # Fixed plotting region with equal axis scaling
                ax.set_xlim(x_region)
                ax.set_ylim(y_region)
                ax.set_zlim(z_region)
                ax.set_aspect('equal')
                ax.set(xlabel="X [$R_M$]", ylabel="Y [$R_M$]", zlabel="Z [$R_M$]")

        # Planet: translucent sphere (r=1, alpha 0.8) around an opaque inner sphere (r=0.8) in each panel
        for _radius, _color, _alpha, _zorder in ((1, 'lightgrey', 0.8, 1), (0.8, 'grey', 1, 1.25)):
            plot_sphere(ax1, radius=_radius, color=_color, alpha=_alpha, zorder=_zorder,
                        xlims=[-10, -0.5], zlims=[0, 2])
        for _radius, _color, _alpha, _zorder in ((1, 'lightgrey', 0.8, 1), (0.8, 'grey', 1, 1.25)):
            plot_sphere(ax2, radius=_radius, color=_color, alpha=_alpha, zorder=_zorder,
                        xlims=[-10, -0.5], ylims=[-2, 0])

        # ax1 viewed from above (elev=90), ax2 edge-on (elev=0)
        ax1.view_init(elev=90, azim=-90)
        ax2.view_init(elev=0, azim=-90)

        # Saving DF_mask is disabled (commented out) for this preset

        for _panel in (ax1, ax2):
            _panel.set_title("$FRs$ (blue) and DFs (red) at t=" + time + "s",
                             fontsize=12, y=1.0, pad=-14)

    if plot_preset == "dBz_dt_formation":
        # Flag dipolarization fronts (DFs) as cells where dBz/dt exceeds 100 nT/s, overlay
        # them with the externally supplied flux-rope mask (FR_mask), plot both in two 3D
        # panels, and pickle DF_mask for this time step.

        if read_data:
            # Unpack data (full domain, no trimming)
            X3d = data3d["X"]
            Y3d = data3d["Y"]
            Z3d = data3d["Z"]
            Bx3d = data3d["Bx"]
            By3d = data3d["By"]
            Bz3d = data3d["Bz"] #nT

            # Time derivative of Bz at this snapshot, via the compute_dt helper
            dBz_dt = compute_dt(["Bz"],float(time),type='numpy',path=dir_3D)["Bz"]

        # Classify cells: DF only, FR only, or both
        DF_mask = (dBz_dt > 100) # [nT/s]
        FR_only_mask = FR_mask & (np.invert(DF_mask))
        DF_only_mask = DF_mask & (np.invert(FR_mask))
        both_mask = DF_mask & FR_mask

        # Two 3D panels; computed_zorder=False so the explicit zorder values decide layering
        fig = plt.figure(figsize=(20,12), constrained_layout=True)
        ax1 = fig.add_subplot(121, projection="3d",computed_zorder=False)
        ax2 = fig.add_subplot(122, projection="3d",computed_zorder=False)

        if np.any(DF_mask):
            # Colormap per population, in the same order as the masks enumerated below
            cmaps = ['Reds','Purples','Blues']
            for ax in [ax1,ax2]:
                # Split cells into "in front of" / "behind" the planet for this panel's viewpoint
                if ax == ax1:
                    in_front = Z3d > 0
                elif ax == ax2:
                    in_front = Y3d < 0
                behind = np.invert(in_front)
                for i,mask in enumerate([DF_only_mask, both_mask, FR_only_mask]):
                    ax.scatter(X3d[mask & in_front],Y3d[mask & in_front],Z3d[mask & in_front],c=-Z3d[mask & in_front],vmin=-1.2,vmax=0.9,cmap=cmaps[i],alpha = 0.05 + i/50, zorder = 2)
                    ax.scatter(X3d[mask & behind],Y3d[mask & behind],Z3d[mask & behind],c=-Z3d[mask & behind],vmin=-1.2,vmax=0.9,cmap=cmaps[i],alpha = 0.05 + i/50, zorder = 0.75)

                # Fixed plotting region with equal axis scaling
                ax.set_xlim(x_region)
                ax.set_ylim(y_region)
                ax.set_zlim(z_region)
                ax.set_aspect('equal')

                ax.set_xlabel("X [$R_M$]")
                ax.set_ylabel("Y [$R_M$]")
                ax.set_zlabel("Z [$R_M$]")

                # Outline of the data domain
                draw_cube_outline(ax, [np.min(X3d),np.max(X3d)], [np.min(Y3d),np.max(Y3d)], [np.min(Z3d),np.max(Z3d)],zorder=1000)

        # Show planet
        plot_sphere(ax1,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=x_region,zlims=z_region)
        plot_sphere(ax1,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=x_region,zlims=z_region)
        plot_sphere(ax2,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=x_region,ylims=y_region,zlims=z_region)
        plot_sphere(ax2,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=x_region,ylims=y_region,zlims=z_region)

        # Set viewing angle: ax1 from above, ax2 edge-on
        ax1.view_init(elev=90, azim=-90)
        ax2.view_init(elev=0, azim=-90)

        # Save the DF mask for this time step. The context manager guarantees the file is
        # flushed and closed once the mask is written (or if pickling raises).
        DF_mask_path = str(dir+"DFs/DF_mask_t_"+'{:06.2f}'.format(round(float(time),2)))
        with open(DF_mask_path, 'wb') as DF_mask_file:
            pickle.dump(DF_mask, DF_mask_file)

        # Add titles
        ax1.set_title(str("DFs (red) at t="+time+"s"),fontsize=12,y=1.0, pad=-14)
        ax2.set_title(str("DFs (red) at t="+time+"s"),fontsize=12,y=1.0, pad=-14)

        plt.show()

    if plot_preset == "DF_heating":
        # Compare dipolarization fronts (DFs: Bz more than 10 nT above background) with
        # strongly heated plasma (relative pressure increase (p - p_avg)/p_avg above 2),
        # drawn as 3D scatter clouds in a top-down panel (ax1) and a side-on panel (ax2).

        if read_data:
            # Unpack data, trimmed to the trim_y / trim_x / trim_z index window
            X3d = data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Y3d = data3d["Y"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Z3d = data3d["Z"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Bx3d = data3d["Bx"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            By3d = data3d["By"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Bz3d = data3d["Bz"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #nT
            pxxS0 = data3d["pxxS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pyyS0 = data3d["pyyS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pzzS0 = data3d["pzzS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pxxS1 = data3d["pxxS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pyyS1 = data3d["pyyS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
            pzzS1 = data3d["pzzS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]

            # Unpack average field data
            Bx_avg = background_field["Bx"] # [nT]
            By_avg = background_field["By"] # [nT]
            Bz_avg = background_field["Bz"] # [nT]
            p_avg = background_field["p"] # [nPa]

            # Scalar pressure: one third of the summed diagonal components of the S0 and S1 tensors [nPa]
            p_tot = (pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1)/3

            # Bz perturbation relative to the background [nT]
            deltaBz3d = Bz3d - Bz_avg

        # Heated cells: relative pressure increase above 2. Cells where p_avg is zero still
        # evaluate to inf/nan as before; errstate only suppresses the RuntimeWarnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            heating_mask = (p_tot-p_avg)/p_avg > 2
        # DF cells, then split into DF only / heating only / both
        DF_mask = (deltaBz3d > 10)
        DF_only_mask = DF_mask & (np.invert(heating_mask))
        heating_only_mask = heating_mask & (np.invert(DF_mask))
        DF_heating_mask = heating_mask & DF_mask

        # Two 3D panels; computed_zorder=False so the explicit zorder values decide layering
        fig = plt.figure(figsize=(20,12), constrained_layout=True)
        ax1 = fig.add_subplot(121, projection="3d",computed_zorder=False)
        ax2 = fig.add_subplot(122, projection="3d",computed_zorder=False)

        # Guard on the populations this preset actually plots (heating and DF cells);
        # flux ropes (FR_mask) play no part in this figure.
        if np.any(heating_mask | DF_mask):
            for ax in [ax1,ax2]:
                # Split cells into "in front of" / "behind" the planet for this panel's viewpoint
                if ax == ax1:
                    in_front = Z3d > 0
                elif ax == ax2:
                    in_front = Y3d < 0
                behind = np.invert(in_front)

                # One (mask, colormap, alpha) per population; front layer at zorder 2, back layer at 0.75
                for _pop_mask, _pop_cmap, _pop_alpha in ((heating_only_mask, "Blues", 0.02),
                                                         (DF_heating_mask, "Purples", 0.2),
                                                         (DF_only_mask, "Reds", 0.02)):
                    for _side, _side_zorder in ((in_front, 2), (behind, 0.75)):
                        _sel = _pop_mask & _side
                        ax.scatter(X3d[_sel],Y3d[_sel],Z3d[_sel],c=-Z3d[_sel],vmin=-1.2,vmax=0.9,
                                   cmap=_pop_cmap,alpha=_pop_alpha,zorder=_side_zorder)

                # Update limits to be centered with max range
                ax.set_xlim(np.mean(xlims) - max_range / 2, np.mean(xlims) + max_range / 2)
                ax.set_ylim(np.mean(ylims) - max_range / 2, np.mean(ylims) + max_range / 2)
                ax.set_zlim(np.mean(zlims) - max_range / 2, np.mean(zlims) + max_range / 2)

                ax.set_xlabel("X [$R_M$]")
                ax.set_ylabel("Y [$R_M$]")
                ax.set_zlabel("Z [$R_M$]")

        # Show planet
        plot_sphere(ax1,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=[-10,-0.5],zlims=[0,2])
        plot_sphere(ax1,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=[-10,-0.5],zlims=[0,2])
        plot_sphere(ax2,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=[-10,-0.5],ylims=[-2,0])
        plot_sphere(ax2,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=[-10,-0.5],ylims=[-2,0])

        # Set viewing angle: ax1 from above, ax2 edge-on
        ax1.view_init(elev=90, azim=-90)
        ax2.view_init(elev=0, azim=-90)

        # Add titles
        ax1.set_title(str("$(p-p_{avg})/p_{avg}>2$ (blue) and DFs (red) at t="+time+"s"),fontsize=12,y=1.0, pad=-14)
        ax2.set_title(str("$(p-p_{avg})/p_{avg}>2$ (blue) and DFs (red) at t="+time+"s"),fontsize=12,y=1.0, pad=-14)

    if plot_preset == "DF_example_extraction":
        # Track one dipolarization front (DF) through time. Starting from a seed position,
        # pick out the connected DF_mask component containing it, plot and pickle that
        # component, then advect the seed with the component's mean electron velocity
        # (uex/uey/uez, species S0) ready for the next time step.

        # Declare starting position
        if iter==0:
            '''
            log:

            FINAL EXAMPLES
            Run3:
            "run3_eg1_DDF":
            t = 165.7: (-1.75,-0.9,0.18)
            
            "run5_eg2_FRDF": (should be run3_eg2_FRDF)
            t = 162.6: (-2.05,0.5,0.2)
            "run3_eg2a_FRDF": ("a" index means full domain mask)
            t = 162.6: (-2.05,0.5,0.2)

            "run3_eg3_DDF":
            t = 140.6: (-1.6,0.5,0.2)
            "run3_eg3a_DDF":
            t = 140.6: (-1.6,0.5,0.2)

            Now using dBz_dt mask
            "run3_eg4a_FRDF": 
            t = 150.50: (-2,0.3,0.15)
            "run3_eg5a_child": 
            t = 157.0: (-1.45,0.22,0.2)

            "run3_eg7a_DDF":
            t = 140.40: (-1.65,0.52,0.2)

            "run3_paper_DDF":
            t = 148.9: (-1.9,0.7,0.17)

            Run4:
            Now using dBz_dt mask
            "run4_eg1a_FRDF": (is actually (mostly) a DDF!!!!)
            t = 121.6: (-2.25,0.4,0.2)
            "run4_eg2a_FRDF":
            t = 118.00: (-2.15,-0.8,-0.05)
            "run4_eg3a_double":
            t = 127.80: (-2.25,0.75,0.15)
            "run4_190_dusk"
            t = 190, loc = (-2.3,0.5,0.2)
            "run4_164_dawn"
            t = 164, loc = (-2.3,-0.3,0.05)
            
            '''

            DF_start = loc #(-1.5,0.15,0.2)
            DF_name = "run3_paper_DDF"

        if read_data:
            # Unpack data (full domain). Velocities are scaled by 1e3/R_M to give R_M/s
            # (presumably km/s with R_M in metres -- confirm).
            X3d = data3d["X"]
            Y3d = data3d["Y"]
            Z3d = data3d["Z"]
            uex = data3d["uxS0"] * 1e3/R_M # [R_M/s]
            uey = data3d["uyS0"] * 1e3/R_M # [R_M/s]
            uez = data3d["uzS0"] * 1e3/R_M # [R_M/s]
            Bx3d = data3d["Bx"]
            By3d = data3d["By"]
            Bz3d = data3d["Bz"]

        # Grid indices of the first cell beyond DF_start along each axis; the 3D arrays are
        # indexed [y, x, z], so the tuple is (iy, ix, iz)
        DF_start_indices = (np.where((Y3d[:,0,0]>DF_start[1]))[0][0], 
                            np.where((X3d[0,:,0]>DF_start[0]))[0][0], 
                            np.where((Z3d[0,0,:]>DF_start[2]))[0][0])

        if not DF_mask[DF_start_indices]:
            # Seed is not on a DF cell: snap it to the nearest one. With return_indices=True the
            # Euclidean distance transform of ~DF_mask gives, for every cell, the index of the
            # closest DF_mask cell. Only later steps report that the seed drifted off the DF.
            if iter != 0:
                print("landed outside the DF, moving back to the closest one")
            distance, indices = distance_transform_edt(~DF_mask, return_indices=True)
            DF_start_indices = tuple(indices[:, DF_start_indices[0], DF_start_indices[1], DF_start_indices[2]])
            if iter != 0:
                print("Nearest DF found at",X3d[DF_start_indices],
                                           Y3d[DF_start_indices],
                                           Z3d[DF_start_indices])
        else:
            # If the position matches up with a DF, work out all the cells attached to it
            print("DF identified at",DF_start)

        # Label the connected components of DF_mask with 26-connectivity (3x3x3 structure)
        structure = np.ones((3, 3, 3), dtype=bool)
        labeled_array, num_features = label(DF_mask, structure=structure)

        # Component containing the seed. Label 0 is the background; after the snapping above
        # it can only be hit when DF_mask has no True cells at all, and selecting label 0
        # would wrongly return every non-DF cell.
        starting_label = labeled_array[DF_start_indices]
        if starting_label == 0:
            print("No DF cells in DF_mask at t =", time, "- nothing to extract")
            mask = np.zeros(labeled_array.shape, dtype=bool)
        else:
            mask = (labeled_array == starting_label)

        # Two 3D panels; computed_zorder=False so the explicit zorder values decide layering
        fig = plt.figure(figsize=(20,12), constrained_layout=True)
        ax1 = fig.add_subplot(121, projection="3d",computed_zorder=False)
        ax2 = fig.add_subplot(122, projection="3d",computed_zorder=False)

        # Find mask for other DFs
        other_DFs = DF_mask & np.invert(mask)

        for ax in [ax1,ax2]:
            # Split cells into "in front of" / "behind" the planet for this panel's viewpoint
            if ax == ax1:
                in_front = Z3d > 0
            elif ax == ax2:
                in_front = Y3d < 0
            behind = np.invert(in_front)
            # Tracked DF in blue, every other DF cell faintly in red
            ax.scatter(X3d[mask & in_front],Y3d[mask & in_front],Z3d[mask & in_front],c=-Z3d[mask & in_front],vmin=-1.2,vmax=0.9,cmap="Blues",alpha = 0.1, zorder = 2)
            ax.scatter(X3d[mask & behind],Y3d[mask & behind],Z3d[mask & behind],c=-Z3d[mask & behind],vmin=-1.2,vmax=0.9,cmap="Blues",alpha = 0.1, zorder = 0.75)
            ax.scatter(X3d[other_DFs & in_front],Y3d[other_DFs & in_front],Z3d[other_DFs & in_front],c=-Z3d[other_DFs & in_front],vmin=-1.2,vmax=0.9,cmap="Reds",alpha = 0.01, zorder = 2)
            ax.scatter(X3d[other_DFs & behind],Y3d[other_DFs & behind],Z3d[other_DFs & behind],c=-Z3d[other_DFs & behind],vmin=-1.2,vmax=0.9,cmap="Reds",alpha = 0.01, zorder = 0.75)

            # Show DF seed (black cross) and, after the first step, the previous seed (grey dot)
            ax.scatter(X3d[DF_start_indices[0],DF_start_indices[1],DF_start_indices[2]],
                       Y3d[DF_start_indices[0],DF_start_indices[1],DF_start_indices[2]],
                       Z3d[DF_start_indices[0],DF_start_indices[1],DF_start_indices[2]],
                       c='black',alpha=1,s=20,marker='x',zorder=10)
            if iter>0:
                ax.scatter(X3d[DF_start_indicies_prior[0],DF_start_indicies_prior[1],DF_start_indicies_prior[2]],
                       Y3d[DF_start_indicies_prior[0],DF_start_indicies_prior[1],DF_start_indicies_prior[2]],
                       Z3d[DF_start_indicies_prior[0],DF_start_indicies_prior[1],DF_start_indicies_prior[2]],
                       c='grey',alpha=1,s=15,marker='o',zorder=10)

            # Fixed plotting region with equal axis scaling
            ax.set_xlim(x_region)
            ax.set_ylim(y_region)
            ax.set_zlim(z_region)
            ax.set_aspect('equal')

            ax.set_xlabel("X [$R_M$]")
            ax.set_ylabel("Y [$R_M$]")
            ax.set_zlabel("Z [$R_M$]")

        # Show planet
        plot_sphere(ax1,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=[-10,-0.5],zlims=[0,2])
        plot_sphere(ax1,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=[-10,-0.5],zlims=[0,2])
        plot_sphere(ax2,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=[-10,-0.5],ylims=[-2,0])
        plot_sphere(ax2,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=[-10,-0.5],ylims=[-2,0])

        # Set viewing angle: ax1 from above, ax2 edge-on
        ax1.view_init(elev=90, azim=-90)
        ax2.view_init(elev=0, azim=-90)

        # Save this mask; the context manager closes (and flushes) the file once written
        example_mask_path = str(dir+"DFs/DF_"+DF_name+"_t_"+'{:06.2f}'.format(round(float(time),2)))
        with open(example_mask_path, 'wb') as example_mask_file:
            pickle.dump(mask, example_mask_file)

        # Keep this step's seed cell so the next step can mark it
        DF_start_indicies_prior = DF_start_indices

        # Update the DF seed for the next time step using the mean electron velocity [R_M/s].
        # NOTE(review): the displacement is scaled by 1/3 and added to the requested DF_start,
        # not to the (possibly snapped) DF_start_indices cell -- confirm both are intended.
        if np.any(mask):
            mean_ux = np.mean(uex[mask])
            mean_uy = np.mean(uey[mask])
            mean_uz = np.mean(uez[mask])
            DF_start = (DF_start[0] + (mean_ux*dt/3),
                        DF_start[1] + (mean_uy*dt/3),
                        DF_start[2] + (mean_uz*dt/3))
        else:
            # No DF to follow: keep the seed where it is instead of turning it into NaN
            print("Empty DF mask; DF_start left unchanged")

    if plot_preset == "DF_example_filtering":

        '''
        log:
        "dawn_flank_example": start at t = 142.60
        DF_xmax = -1.95
        std_fac = 2.5
        vel_fac = 3
        n_proj = 7

        then
        t = 144.2
        DF_xmax = -1.00
        std_fac = 2.75
        vel_fac = 1.0
        n_proj = 5

        then
        t = 145.8
        DF_xmax = -1.75
        DF_ymin = -0.85
        std_fac = 3.5
        vel_fac = 0.5
        n_proj = 5

        then 
        t = 150.70
        DF_xmax = -0.85
        DF_ymin = -1
        std_fac = 3.5
        vel_fac = 0.1
        n_proj = 5

        "dusk_flank_example": start at t = 151.00
        DF_xmax = -1
        DF_ymin = -0.3
        std_fac = 3.5
        vel_fac = 4
        n_proj = 5

        then 
        t = 152.1
        DF_xmax = -1.5
        DF_ymin = -0.2
        std_fac = 3.5
        vel_fac = 4
        n_proj = 5

        "new_dBz_dusk_example": start at 161.45
        DF_xmax = -1
        DF_ymin = -1
        std_fac = 3.5
        vel_fac = 0.5
        n_proj = 7

        "new_dBz_rope_example": start at 162.65
        DF_xmax = -1
        DF_ymin = -1
        std_fac = 4
        vel_fac = 1
        n_proj = 5

        "new_dBz_df_example"" start at 160.25

        "new_dBz_dawn_example" start at 165.7
        
        "eg5_DDF" start at 140.6

        "eg6_FRDF": starts at 137.5

        "run4_eg1_FRDF": starts at 123 (MISLABELLED AS "eg6_FRDF")

        FINAL EXAMPLES

        "run3_eg1_DDF": starts at 165.70
        
        "run3_eg2_FRDF": starts at 162.6 (MISLABELLED AS "run5_eg2_FRDF")
        "run3_eg2a_FRDF": starts at 162.6

        "run3_eg3_DDF": starts at 140.6
        "run3_eg3a_DDF": starts at 140.6
        "run3_eg4a_FRDF": starts at 150.5
        "run3_eg5a_child": starts at 157
        "run3_eg7a_DDF": starts at 140.45
        "run3_paper_DDF": starts at 148.8

        "run4_eg1a_FRDF": starts at 121.70 (is really a DDF)
        "run4_eg2a_FRDF": starts at 118.00
        "run4_eg3a_double": starts at 127.80
        "run4_190_dusk": starts at 189.9
        '''

        # Set up the parameters for the fitting.
        # DF_xmin/xmax/ymin/ymax trim the example mask at the first time step only
        # (e.g. to slice off extensional features); later steps are tracked automatically.
        DF_name = "run3_paper_DDF" # NOTE: must be set manually to one of the examples listed above
        DF_xmax = -1.9
        DF_xmin = -3
        DF_ymin = 0.0
        DF_ymax = 1
        std_fac = 4 # half-width of the search box, in standard deviations of the previous DF
        vel_fac = 3 # max multiple of (mean velocity * dt) used when projecting DF cells
        n_proj = 6  # number of projection steps spanning [-vel_fac, vel_fac]

        # Load in the example mask for this time step.
        # The get_files key is built from DF_name with each underscore escaped by "\",
        # e.g. "dawn_flank_example" -> "DF\_dawn\_flank\_example\_t\_...\..."
        DF_key = r"DF\_" + DF_name.replace("_", r"\_") + r"\_t\_...\..."
        exampleDFfiles = get_files(str(dir+"DFs/"),key=DF_key,read_time = True)

        exampleDFfile = str(exampleDFfiles[time])
        with open(dir+"DFs/"+exampleDFfile, 'rb') as f:
            print("reading DF example data:",str(dir+"DFs/"+exampleDFfile))
            DF_example_mask = pickle.load(f)

        if read_data:
            # Unpack grid coordinates and species velocities (S0 and S1), rescaled by 1e3/R_M to [R_M/s]
            X3d = data3d["X"]
            Y3d = data3d["Y"]
            Z3d = data3d["Z"]
            uex = data3d["uxS0"]* 1e3/R_M # [R_M/s]
            uey = data3d["uyS0"]* 1e3/R_M # [R_M/s]
            uez = data3d["uzS0"]* 1e3/R_M # [R_M/s]
            uix = data3d["uxS1"]* 1e3/R_M # [R_M/s]
            uiy = data3d["uyS1"]* 1e3/R_M # [R_M/s]
            uiz = data3d["uzS1"]* 1e3/R_M # [R_M/s]

        if iter == 0:
            # At the first time step, trim the example mask to the user-chosen box
            filtered_mask = np.copy(DF_example_mask)
            filtered_mask[X3d>DF_xmax] = False
            filtered_mask[X3d<DF_xmin] = False
            filtered_mask[Y3d<DF_ymin] = False
            filtered_mask[Y3d>DF_ymax] = False
            # Baseline statistics of the DF; at later steps these come from the end of the previous iteration
            x_mean = np.mean(X3d[filtered_mask])
            x_std = np.std(X3d[filtered_mask])
            y_mean = np.mean(Y3d[filtered_mask])
            y_std = np.std(Y3d[filtered_mask])
            z_mean = np.mean(Z3d[filtered_mask])
            z_std = np.std(Z3d[filtered_mask])
            mean_ux = np.mean(uex[filtered_mask])
            mean_uy = np.mean(uey[filtered_mask])
            mean_uz = np.mean(uez[filtered_mask])

        else:
            # Keep example-mask cells that either overlap the previous DF, or lie within
            # std_fac standard deviations of the previous DF's centre in every direction
            overlap_mask = DF_example_mask & filtered_mask_prior
            std_mask = (DF_example_mask
                        & (X3d > x_mean-std_fac*x_std) & (X3d < x_mean+std_fac*x_std)
                        & (Y3d > y_mean-std_fac*y_std) & (Y3d < y_mean+std_fac*y_std)
                        & (Z3d > z_mean-std_fac*z_std) & (Z3d < z_mean+std_fac*z_std))
            partial_filtered_mask = overlap_mask | std_mask

            # Build a "halo" of where the previous DF cells could have moved: each cell is shifted by
            # (mean velocity * f * dt) for f in linspace(-vel_fac, vel_fac, n_proj), and the mask that
            # create_enclosing_box_mask returns for every shifted point is OR-ed into proj_mask.
            # The shifts are identical for every cell, so compute them once.
            proj_shifts = [(mean_ux * ivel_fac * dt,
                            mean_uy * ivel_fac * dt,
                            mean_uz * ivel_fac * dt * 2) # HOTFIX: increase dispersion in z direction
                           for ivel_fac in np.linspace(-vel_fac, vel_fac, n_proj)]
            DF_indices = np.argwhere(filtered_mask_prior)
            proj_mask = np.zeros_like(filtered_mask_prior,dtype = bool)

            for iy, ix, iz in DF_indices:
                x0, y0, z0 = X3d[iy,ix,iz], Y3d[iy,ix,iz], Z3d[iy,ix,iz]
                for dx, dy, dz in proj_shifts:
                    proj_mask = proj_mask | create_enclosing_box_mask(X3d, Y3d, Z3d, [x0+dx, y0+dy, z0+dz])

            # Only keep candidate cells that are also reachable by the projected halo
            filtered_mask = proj_mask & partial_filtered_mask

        # Declare fig.: ax1 shows the whole region, ax2 a zoom on the DF
        fig = plt.figure(figsize=(20,12), constrained_layout=True)
        ax1 = fig.add_subplot(121, projection="3d",computed_zorder=False)
        ax2 = fig.add_subplot(122, projection="3d",computed_zorder=False)

        # Cells of the example mask that were rejected by the tracking (e.g. other DFs)
        excluded = DF_example_mask & np.invert(filtered_mask)

        # For each panel, split the points into in front of / behind the planet for that panel's viewpoint,
        # so they get zorders on either side of the planet sphere (drawn at zorder 1-1.25 on ax1)
        for ax, in_front in ((ax1, Z3d > 0), (ax2, Y3d < 0)):
            behind = np.invert(in_front)
            for point_mask, cmap, alpha in ((filtered_mask, "Blues", 0.5), (excluded, "Reds", 0.3)):
                for side_mask, zorder in ((in_front, 2), (behind, 0.75)):
                    sel = point_mask & side_mask
                    ax.scatter(X3d[sel],Y3d[sel],Z3d[sel],c=-Z3d[sel],vmin=-1.2,vmax=0.9,cmap=cmap,alpha=alpha,zorder=zorder)

            ax.set_xlabel("X [$R_M$]")
            ax.set_ylabel("Y [$R_M$]")
            ax.set_zlabel("Z [$R_M$]")

        # Show the std_fac search box (in the z = z_mean plane) on ax1
        x_lo, x_hi = x_mean-std_fac*x_std, x_mean+std_fac*x_std
        y_lo, y_hi = y_mean-std_fac*y_std, y_mean+std_fac*y_std
        ax1.plot([x_lo,x_hi,x_hi,x_lo,x_lo],
                 [y_lo,y_lo,y_hi,y_hi,y_lo],
                 [z_mean]*5, c='black',alpha=1,zorder=3)

        # Show planet
        plot_sphere(ax1,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=[-10,-0.5],zlims=[0,2])
        plot_sphere(ax1,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=[-10,-0.5],zlims=[0,2])

        # ax1: fixed region of interest
        ax1.set_xlim(x_region)
        ax1.set_ylim(y_region)
        ax1.set_zlim(z_region)
        ax1.set_aspect('equal')

        # ax2: equal-aspect cube zoomed on the DF search box (side = largest of 2*std_fac*std over x, y, z)
        x_range_DF = 2*std_fac*x_std
        y_range_DF = 2*std_fac*y_std
        z_range_DF = 2*std_fac*z_std
        max_range_DF = max(x_range_DF, y_range_DF, z_range_DF)
        # Skip the zoom when the statistics are degenerate (e.g. empty mask gives NaN), since NaN limits raise
        if np.isfinite(max_range_DF) and max_range_DF > 0:
            half_range_DF = max_range_DF/2
            ax2.set_xlim(x_mean-half_range_DF, x_mean+half_range_DF)
            ax2.set_ylim(y_mean-half_range_DF, y_mean+half_range_DF)
            ax2.set_zlim(z_mean-half_range_DF, z_mean+half_range_DF)
            ax2.set_aspect('equal')

        # Set viewing angle
        ax1.view_init(elev=90, azim=-90)
        ax2.view_init(elev=45, azim=-45)

        # Save the current mask to compare to next iteration
        filtered_mask_prior = np.copy(filtered_mask)

        # Compute the baseline statistics of the DF, which will be used for the next iteration
        x_mean = np.mean(X3d[filtered_mask])
        x_std = np.std(X3d[filtered_mask])
        y_mean = np.mean(Y3d[filtered_mask])
        y_std = np.std(Y3d[filtered_mask])
        z_mean = np.mean(Z3d[filtered_mask])
        z_std = np.std(Z3d[filtered_mask])
        mean_ux = np.mean(uex[filtered_mask])
        mean_uy = np.mean(uey[filtered_mask])
        mean_uz = np.mean(uez[filtered_mask])

        # Save the filtered mask (context manager so the file handle is closed)
        with open(str(dir+"DFs/DF_filtered_"+DF_name+"_t_"+'{:06.2f}'.format(round(float(time),2))), 'wb') as f:
            pickle.dump(filtered_mask, f)

    if plot_preset == "DF_example_summary":

        if read_data:
            time_ls = []
            Theta_z_ls = []
            Theta_y_ls = []
            frac_rope_flux_ls = []
            Bz_df_ls = []
            p_df_ls = []
            V_df_ls = []
            H_df_ls = []
            S_df_ls = []
            x_df_ls = []
            y_df_ls = []
            z_df_ls = []
            n_df_ls = []
            Bz_avg_ls = []
            p_avg_ls = []
            V_avg_ls = []
            H_avg_ls = []
            S_avg_ls = []
            n_avg_ls = []
    
            # Load in example mask
            # Copy in an example name from above, making sure to add "\" before each underscore
            # e.g. examplefilesDF = get_files(str(dir+"DFs/"),key="DF\_dawn\_flank\_example\_t\_...\...",read_time = True)
            DF_name = "run3_paper_DDF"
            exampleDFfiles = get_files(str(dir+"DFs/"),key="DF\_filtered\_run3\_paper\_DDF\_t\_...\...",read_time = True)
            
            for time in exampleDFfiles.keys():
                exampleDFfile = str(exampleDFfiles[time])
                with open(dir+"DFs/"+exampleDFfile, 'rb') as f:
                    print("reading DF example mask:",str(dir+"DFs/"+exampleDFfile))
                    mask = pickle.load(f) 
                file3D = str(files3D[time])
                with open(dir_3D+file3D, 'rb') as f:
                    print("reading 3d data for summary plot: ",str(dir_3D+file3D))
                    data3d = pickle.load(f) 
    
                # Unpack data
                X3d = data3d["X"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
                Y3d = data3d["Y"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
                Z3d = data3d["Z"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
                Bx = data3d["Bx"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
                By = data3d["By"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
                Bz = data3d["Bz"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
                rho = data3d["rhoS1"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[amu/cc]
                pxxS0 = data3d["pxxS0"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pyyS0 = data3d["pyyS0"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pzzS0 = data3d["pzzS0"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pxxS1 = data3d["pxxS1"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pyyS1 = data3d["pyyS1"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pzzS1 = data3d["pzzS1"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                Jy3d = (e*data3d["rhoS1"]*1e6*(data3d["uyS1"]-data3d["uyS0"])*1e3)*1e9#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nA/m^2]
    
                # Unpack average field data
                Bx_avg = background_field["Bx"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
                By_avg = background_field["By"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
                Bz_avg = background_field["Bz"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
                V_avg = background_field["V"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1]] # R_M/nT
                p_avg = background_field["p"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nPa]
                n_avg = background_field["n"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # amu/cc
                S_avg = background_field["S_eq"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1]] # [(nPa)^5/3 R_M / nT]

                # Temp hotfix for weird Volume values
                V_avg[V_avg<1e-3] = 0
                V_avg[V_avg==0] = np.nan

                # adiabatic index
                gamma = 5/3
                
                # Precomputations
                p_tot = (pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1)/3
                Bmag = np.sqrt(Bx**2+By**2+Bz**2) #[nT]
                V_integrand = 1/Bmag # [1/nT]
                V_integrand[np.isnan(V_integrand)] = 0
                V_integrand[V_integrand == np.inf] = 0
                S_integrand = p_tot**(1/gamma)/Bmag # [nPa/nT]
                
                # Define masks for slices through DF
                y_mean = np.sum(Y3d[mask]*Bz[mask])/np.sum(Bz[mask])
                iy_mean = np.where(Y3d[:,0,0]>y_mean)[0][0]
                xz_slice_mask = (Y3d == Y3d[iy_mean,0,0]) & mask
                z_mean = np.sum(Z3d[mask]*Bz[mask])/np.sum(Bz[mask])
                iz_mean = np.where(Z3d[0,0,:]>z_mean)[0][0]
                xy_slice_mask = (Z3d == Z3d[0,0,iz_mean]) & mask
                iz_equator = np.where(Z3d[0,0,:]>0.2)[0][0]
                equator_mask = (Z3d == Z3d[0,0,iz_equator]) & mask

                # Set up field line tracing
                tracer,grid = get_tracer(X3d,Y3d,Z3d,Bx,By,Bz,nsteps = 4000,step_size = 5e-4)

                # Declare fieldline tracing seeds from each point in zoom region of current sheet
                seeds = np.array([np.ravel(X3d[mask]),np.ravel(Y3d[mask]),np.ravel(Z3d[mask])]).T
                print("Beginning tracing for",len(seeds),"seeds")
                tracer.trace(seeds, grid)
                print("Complete!")
        
                # Declare array to save all the seeds corresponding to closed field lines, and the flux tube entropy and length
                # -1 indicates an unclosed flux rope
                V_list = np.zeros((len(seeds), 1)) - 1
                S_list = np.zeros((len(seeds), 1)) - 1
            
                # Declare an interpolator object to compute the volume at each point along the field lines
                V_interpolator = RegularGridInterpolator((X3d[0,:,0], Y3d[:,0,0], Z3d[0,0,:]), np.swapaxes(V_integrand,0,1), bounds_error=False, fill_value=None)
                S_interpolator = RegularGridInterpolator((X3d[0,:,0], Y3d[:,0,0], Z3d[0,0,:]), np.swapaxes(S_integrand,0,1), bounds_error=False, fill_value=None)

                # To compute the full Bz flux, we need to subtract out any negative contribution from the negative side of the flux rope
                # We have already traced the needed field lines, we now need to create a mask of all the cells in the xy_slice_mask plane 
                flux_mask = np.zeros_like(xy_slice_mask, dtype=bool) # This will be true in the z_mean plane everywhere that a DF field line intersects (including in non-df regions)
                rope_mask = np.zeros_like(xy_slice_mask, dtype=bool) # This will be true in the z_mean plane for all DF cells that are flux ropey
                
                # Interate through each trace
                for iseed, seed in enumerate(seeds):
            
                    # Check if its a closed field line, by seeing if both ends are *close to* planet
                    trace = tracer.xs[iseed]
                    if (np.sum(trace[0,:]**2)<1.5**2) and (np.sum(trace[-1,:]**2)<1.5**2):
        
                        # INTERPOLATE: interpolate the volume value at each point along the field line
                        # Compute distances between points
                        diffs = np.diff(trace, axis=0)
                        l = np.sqrt(np.sum(diffs**2, axis=1))
                        V_interp = np.nansum(V_interpolator(trace))
                        V_list[iseed] = np.sum(l*V_interp) 
                        S_interp = np.nansum(S_interpolator(trace))
                        S_list[iseed] = np.sum(l*S_interp) #[(nPa)^5/3 R_M / nT]

                    # Work out the coordinates of all of the plane intersections
                    intersections = plane_intersection(trace[:,0], trace[:,1], trace[:,2], plane_z=z_mean)

                    if len(intersections)>0:
                        for intersection in intersections:
                            intersect_iy,intersect_ix,intersect_iz = nearest_cell(X3d, Y3d, Z3d, intersection)
                            #intersect_ix = np.where(X3d[0,:,0]>intersection[0])[0][0]
                            #intersect_iy = np.where(Y3d[:,0,0]>intersection[1])[0][0]
                            flux_mask[intersect_iy,intersect_ix,iz_mean] = True
                        if len(intersections)>1:
                            seed_iy,seed_ix,seed_iz = nearest_cell(X3d, Y3d, Z3d, seed)
                            #seed_ix = np.where(X3d[0,:,0]>seed[0])[0][0]
                            #seed_iy = np.where(Y3d[:,0,0]>seed[1])[0][0]
                            rope_mask[seed_iy,seed_ix,seed_iz] = True

                # Reformat list of flux tube volumes
                V_array = np.ravel(np.array(V_list))
                S_array = np.ravel(np.array(S_list))

                # To avoid averaging over non-closed lines, set -1 to nan
                V_array[V_array == -1] = np.nan
                S_array[S_array == -1] = np.nan

                # Compute entropy using equatorial pressures
                H_array = np.ravel(p_tot[mask]) * (V_array)**gamma
                
                # Compute total Bz flux associated with DF in z_mean plane, including subtractions from flux rope
                Theta_z = np.sum(Bz[flux_mask]) * 1e-9 * (R_M/64)**2 * 1e-6 #[MWb]

                # Save a plot of the DF cross section at this time step, for reference
                fig, axs = plt.subplots(ncols = 2, figsize = (10,5))
    
                plot0 = axs[0].imshow(By[iy_mean,:,:].T,origin='lower',extent = (np.min(X3d),np.max(X3d),np.min(Z3d),np.max(Z3d)),
                     cmap = 'coolwarm',vmin=-50,vmax=50)
                axs[0].contour(X3d[iy_mean,:,:],Z3d[iy_mean,:,:],np.sum(mask,axis=0),[0.5],colors=['black'],linestyles='dashed')
                #axs[0].contour(X3d[iy_mean,:,:],Z3d[iy_mean,:,:],Jy3d[iy_mean,:,:],colors='green',levels=[-30,-20,-10],linestyles='solid')
                axs[1].imshow(Bz[:,:,iz_mean],origin='lower',extent = (np.min(X3d),np.max(X3d),np.min(Y3d),np.max(Y3d)),
                              cmap = 'coolwarm',vmin=-50,vmax=50)
                axs[1].contour(X3d[:,:,iz_mean],Y3d[:,:,iz_mean],np.sum(mask,axis=2),[0.5],colors=['black'],linestyles='dashed',linewidths=[2])
                axs[1].contour(X3d[:,:,iz_mean],Y3d[:,:,iz_mean],rope_mask[:,:,iz_mean],colors=['green'],linewidths=[0.9])
                axs[1].contour(X3d[:,:,iz_mean],Y3d[:,:,iz_mean],flux_mask[:,:,iz_mean],colors=['yellow'],linewidths=[0.9])

                # Show some field line traces
                #trace_skip = 1
                #for iseed, seed in enumerate(seeds[::trace_skip,:]):
                #    trace = tracer.xs[iseed]
                #    if np.abs(seed[1]-y_mean)<(1/64):
                #        axs[0].plot(trace[:,0],trace[:,2],linewidth=0.1,color='black')
                # Make a regular grid for plt streamplot
                Xgrid,Zgrid = np.meshgrid(np.linspace(X3d[iy_mean,0,0],X3d[iy_mean,-1,0],len(X3d[iy_mean,:,0])),
                            np.linspace(Z3d[iy_mean,0,0],Z3d[iy_mean,0,-1],len(Z3d[iy_mean,0,:])))

                #streamplot = axs[0].streamplot(Xgrid,Zgrid,Bx[iy_mean,:,:].T,Bz[iy_mean,:,:].T, broken_streamlines=False,
                #                density=10,linewidth=0.05,arrowsize=0,color='black')

                streamplot = axs[0].streamplot(Xgrid,Zgrid,Bx[iy_mean,:,:].T,Bz[iy_mean,:,:].T, broken_streamlines=False,
                   start_points = np.array([X3d[iy_mean,:,iz_mean],Z3d[iy_mean,:,iz_mean]]).T,
                    linewidth=0.05,arrowsize=0,color='black')

                 # Compute flux rope flux content using ellipse fit
                ellipse_mask = fit_FR_ellipse2(streamplot,min_density = 5)
                Theta_y = np.sum(By[iy_mean,:,:][ellipse_mask]*(R_M/64)**2)*1e-12 # [kWb]

                # Compute the ratio of roped to unroped flux
                frac_rope_flux = (np.sum((Bz[rope_mask])) + np.sum(By[iy_mean,:,:][ellipse_mask])) / np.sum(Bz[xy_slice_mask & np.invert(rope_mask)])

                # Show the ellipse fit
                axs[0].contour(X3d[iy_mean,:,:],Z3d[iy_mean,:,:],ellipse_mask)
    
                # Work out range to show to ensure both are evenly sized squares
                # Precomputations for nice axes limits: calculate the range for each axis
                x_range_plot = np.max(X3d[mask]) - np.min(X3d[mask])
                y_range_plot = np.max(Y3d[xz_slice_mask]) - np.min(Y3d[xz_slice_mask])
                z_range_plot = np.max(Z3d[xy_slice_mask]) - np.min(Z3d[xy_slice_mask])
                
                # Find the maximum range among x, y, z
                max_range_plot = 0.6 #max(x_range_plot, y_range_plot, z_range_plot)

                # Tidy axes
                for axi in axs:
                    axi.set_xlabel("X [$R_M$]")
                    axi.set_aspect(1)
                    axi.grid()
                    #axi.set_xlim(np.mean(X3d[xy_slice_mask]) - max_range_plot, np.mean(X3d[xy_slice_mask]) + max_range_plot)
                    axi.set_xlim(*x_region)
                    
                #axs[0].set_ylim(np.mean(Z3d[xy_slice_mask]) - max_range_plot, np.mean(Z3d[xy_slice_mask]) + max_range_plot)
                axs[0].set_ylim(*z_region)
                #axs[1].set_ylim(np.mean(Y3d[xz_slice_mask]) - max_range_plot, np.mean(Y3d[xz_slice_mask]) + max_range_plot)
                axs[1].set_ylim(*y_region)
                    
                # Colorbar
                clb1 = fig.colorbar(plot0,ax=axs[:],shrink=0.5)
                clb1.ax.set_title("$B$ [nT]")
    
                # Label
                axs[0].set_ylabel("Z [$R_M$]")
                axs[1].set_ylabel("Y [$R_M$]")
                axs[0].set_title(str("t = "+"%.2f"%round(float(time),2)+"\n$B_y$ in y = "+str(round(y_mean,2))+" plane"))
                axs[1].set_title(str("$B_z$ in z = "+str(round(z_mean,2))+" plane"))
                
                # Save
                fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"slice_"+DF_name+"_"+"%.2f"%round(float(time),2)+'.png'),bbox_inches='tight',dpi=300)
                plt.show()
                plt.close(fig)

                # Append variables of interest to lists
                time_ls.append(float(time))
                Theta_z_ls.append(Theta_z) # [MWb]
                Theta_y_ls.append(Theta_y) # [kWb]
                frac_rope_flux_ls.append(frac_rope_flux)
                Bz_df_ls.append(np.nanmean(Bz[mask])) # [nT]
                p_df_ls.append(np.nanmean(p_tot[mask])) # [nPa]
                V_df_ls.append(np.nanmean(V_array[V_array>0])) # [R_M/nT]
                H_df_ls.append(np.nanmean(H_array[V_array>0]))
                S_df_ls.append(np.nanmean(S_array[V_array>0]))
                x_df_ls.append(np.nanmean(X3d[mask]))
                y_df_ls.append(np.nanmean(Y3d[mask]))
                z_df_ls.append(np.nanmean(Z3d[mask]))
                n_df_ls.append(np.nanmean(rho[mask]))

                # Save background field quantities for the DF
                Bz_avg_ls.append(np.mean(Bz_avg[mask])) # [nT]
                p_avg_ls.append(np.mean(p_avg[mask])) # [nPa]
                V_avg_ls.append(np.nanmean(V_avg[mask[:,:,iz_equator]])) # [R_M/nT]
                H_avg_ls.append(np.nanmean(p_avg[equator_mask]*V_avg[mask[:,:,iz_equator]]**gamma))
                S_avg_ls.append(np.nanmean(S_avg[mask[:,:,iz_equator]]))
                n_avg_ls.append(np.mean(n_avg[mask])) # [amu/cc]
                
        # Create time series plot
        # Declare figure
        fig,axs = plt.subplots(figsize=(16,8), ncols=2, constrained_layout=True)

        R_ls = np.sqrt(np.array(x_df_ls)**2+np.array(y_df_ls)**2+np.array(z_df_ls)**2)

        x_axes = [R_ls, time_ls]
        x_labels = ["R [$R_M$]","t [sec]"]
        for i, axi in enumerate(axs):
            axi.plot(x_axes[i],n_avg_ls,color='tab:gray',linestyle='dashed')
            axi.plot(x_axes[i],n_df_ls,label='rho [amu/cc]',color='tab:gray',linestyle='solid')
            axi.plot(x_axes[i],V_avg_ls,color='tab:blue',linestyle='dashed')
            axi.plot(x_axes[i],V_df_ls,label='Volume [$R_M$/nT]',color='tab:blue',linestyle='solid')
            axi.plot(x_axes[i],p_avg_ls,color='tab:orange',linestyle='dashed')
            axi.plot(x_axes[i],p_df_ls,label='Pressure [nPa]',color='tab:orange',linestyle='solid')
            axi.plot(x_axes[i],S_avg_ls,color='tab:green',linestyle='dashed')
            #axi.plot(x_axes[i],H_df_ls,label='Entropy [nPa ($R_M$/nT)$^{5/3}$]',color='tab:green',linestyle='solid')
            axi.plot(x_axes[i],S_df_ls,label='Entropy [(nPa)$^{5/3}$ ($R_M$/nT)]',color='tab:green',linestyle='solid')
            axi.plot(x_axes[i],Bz_df_ls,label='$B_z$ [nT]',color='tab:red')
            axi.plot(x_axes[i],Bz_avg_ls,color='tab:red',linestyle='dashed')
            axi.plot(x_axes[i],np.array(np.abs(Theta_y_ls)),label='Y flux content [kWb]',color='tab:purple')
            axi.plot(x_axes[i],frac_rope_flux_ls,label='Rope:Dipolar flux',color='black', linestyle='dotted')
        
            axi.set_yscale('log')
            axi.set_ylim(1e-1,1e3)

            axi.set_xlim(np.min(x_axes[i]),np.max(x_axes[i]))

            axi.grid()

            axi.tick_params(axis='both',labelsize=15)

        '''
        # Show flux content
        ax0 = axs[0].twinx()
        ax1 = axs[1].twinx()
        ax0.plot(R_ls,frac_rope_flux_ls,color='black', linestyle='dotted')
        ax1.plot(time_ls,frac_rope_flux_ls,color='black', linestyle='dotted')
        ax0.set_ylim(0,1)
        ax1.set_ylim(0,1)
        ax0.set_ylabel("Frac. of flux in rope (dotted)")
        ax1.set_ylabel("Frac. of flux in rope (dotted)")
        #ax.set_xticks(np.arange(round(float(time_ls[0]),0), round(float(time_ls[-1]),0)))
        '''
        # legend
        axs[1].legend(fontsize=15,ncols=2)

        axs[0].set_xlabel(x_labels[0])

        # Set manual ticks for axs[1]
        tick_indices = []#[i_plot_start]
        tick_positions = []#[t[i_plot_start]]
        tick_labels = []#["HH:MM:SS   \n$X_{MSM}$ [$R_M$]   \n$Y_{MSM}$ [$R_M$]   \n$Z_{MSM}$ [$R_M$]   "]
        
        for it,itime in enumerate(time_ls):
            if itime*10%5 == 0: #Label cadence = 0.5s
                tick_indices.append(it)
                tick_positions.append(itime)
                tick_labels.append(str(str(itime)+'\n'+str(round(x_df_ls[it],2))+'\n'+str(round(y_df_ls[it],2))+'\n'+str(round(z_df_ls[it],2))))
            
        # Show major ticks
        axs[1].set_xticks(tick_positions, tick_labels, fontsize=15)

        # Invert x axis for axs[0]
        axs[0].xaxis.set_inverted(True)

        # Save
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+DF_name+'.png'),bbox_inches='tight',dpi=300)
        plt.show()
        plt.close(fig)

        break

    if plot_preset == "DF_example_visualizer":

        if iter==0 and read_data:
            # Declare lists to save the total energy content of the region
            t_ls = []
            i_kinetic_ls = []
            e_kinetic_ls = []
            i_para_thermal_ls = []
            i_perp_thermal_ls = []
            e_para_thermal_ls = []
            e_perp_thermal_ls = []
            mag_ls = []
            Jy_east_ls = []
            Jy_west_ls = []

        if read_data:

            # Choose which example to plot
            DF_name = "run4_eg1_FRDF"
            exampleDFfiles = get_files(str(dir+"DFs/"),key="DF\_filteredrun4\_eg1\_FRDF\_t\_...\...",read_time = True)

            # Set xmin as the tailmost cell of the DF:
            exampleDFfile = str(exampleDFfiles[time])
            with open(dir+"DFs/"+exampleDFfile, 'rb') as f:
                print("reading DF example mask:",str(dir+"DFs/"+exampleDFfile))
                example_mask = pickle.load(f)
            x_df_min = np.min(data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]][example_mask])
            
            # Set mask region
            # run3_eg2_FRDF:
            #x_region = [-2,-1.5]
            #y_region = [0.1,0.5]
            #z_region = [0,0.4]
            # run3_eg3_DDF:
            #x_region = [-1.6,-0.9]
            #y_region = [0.15,0.55]
            #z_region = [0.0,0.4]
            # run4_eg1_FRDF:
            x_region = [-2.5,-0.6]
            y_region = [-1,0]
            z_region = [-0.3,0.7]

            # Turn this into a new trim region
            trim2_x = np.where((data3d["X"][0,:,0]>x_region[0]) & (data3d["X"][0,:,0]<x_region[1]))[0] #LIMITS MUST MATCH FROM ABOVE FOR MASK TO WORK
            trim2_y = np.where((data3d["Y"][:,0,0]>y_region[0]) & (data3d["Y"][:,0,0]<y_region[1]))[0]
            trim2_z = np.where((data3d["Z"][0,0,:]>z_region[0]) & (data3d["Z"][0,0,:]<z_region[1]))[0]

            # Special treatment for the mask x trim, since it was already created based off a trimmed region
            trimmask_x = np.where((data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]][0,:,0]>x_region[0]) & (data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]][0,:,0]<x_region[1]))[0] 
            trimmask_z = np.where((data3d["Z"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]][0,0,:]>z_region[0]) & (data3d["Z"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]][0,0,:]<z_region[1]))[0] 
            example_mask = example_mask[trim2_y[0]:trim2_y[-1],trimmask_x[0]:trimmask_x[-1],trimmask_z[0]:trimmask_z[-1]]
            
            # Unpack data on the trim2 box and convert to SI units.
            # The 3-D arrays are indexed (y, x, z); the 2-D current-sheet (cs) arrays are (y, x).
            # NOTE: `idx[0]:idx[-1]` has an exclusive stop, so the last in-box index is dropped; this is
            # the same convention used to cut example_mask, so keep it for shape alignment.
            trim2_3d = np.s_[trim2_y[0]:trim2_y[-1], trim2_x[0]:trim2_x[-1], trim2_z[0]:trim2_z[-1]]
            trim2_2d = np.s_[trim2_y[0]:trim2_y[-1], trim2_x[0]:trim2_x[-1]]
            X3d = data3d["X"][trim2_3d]
            Y3d = data3d["Y"][trim2_3d]
            Z3d = data3d["Z"][trim2_3d]
            Bx3d = data3d["Bx"][trim2_3d]*1e-9 #[T]
            By3d = data3d["By"][trim2_3d]*1e-9 #[T]
            Bz3d = data3d["Bz"][trim2_3d]*1e-9 #[T]
            Xcs = datacs["X"][trim2_2d]
            Ycs = datacs["Y"][trim2_2d]
            Zcs = datacs["Z"][trim2_2d]
            Bxcs = datacs["Bx"][trim2_2d] #[nT]
            Bycs = datacs["By"][trim2_2d] #[nT]
            Bzcs = datacs["Bz"][trim2_2d] #[nT]
            Ex3d = data3d["Ex"][trim2_3d]*1e-6 # V/m
            Ey3d = data3d["Ey"][trim2_3d]*1e-6 # V/m
            Ez3d = data3d["Ez"][trim2_3d]*1e-6 # V/m
            n = data3d["rhoS1"][trim2_3d]*1e6 # 1/m^3
            rho = n*amu # [kg/m^3]
            # Current density from the S1-minus-S0 velocity difference. Slice BEFORE the arithmetic so the
            # products are formed only on the box instead of on the whole grid (values are identical).
            Jx3d = e*data3d["rhoS1"][trim2_3d]*1e6*(data3d["uxS1"][trim2_3d]-data3d["uxS0"][trim2_3d])*1e3 # [A/m^2]
            Jy3d = e*data3d["rhoS1"][trim2_3d]*1e6*(data3d["uyS1"][trim2_3d]-data3d["uyS0"][trim2_3d])*1e3 # [A/m^2]
            Jz3d = e*data3d["rhoS1"][trim2_3d]*1e6*(data3d["uzS1"][trim2_3d]-data3d["uzS0"][trim2_3d])*1e3 # [A/m^2]
            Jy3d_curl = data3d["Jy"][trim2_3d] #[A/m^2]
            dp_dx3d = data3d["dp_dx"][trim2_3d]*1e-9 #[Pa/m] = [N/m^3]
            dp_dy3d = data3d["dp_dy"][trim2_3d]*1e-9 #[Pa/m]
            dp_dz3d = data3d["dp_dz"][trim2_3d]*1e-9 #[Pa/m]
            # (disabled; previously sliced Z with trim_z instead of trim2_z, which would misalign if re-enabled)
            #dp_perp_dx = data3d["dp_perp_dx"][trim2_3d]*1e-9 #[Pa/m] = [N/m^3]
            #dp_perp_dy = data3d["dp_perp_dy"][trim2_3d]*1e-9 #[Pa/m]
            #dp_perp_dz = data3d["dp_perp_dz"][trim2_3d]*1e-9 #[Pa/m]
            pxxS0 = data3d["pxxS0"][trim2_3d]*1e-9 #[Pa]
            pxyS0 = data3d["pxyS0"][trim2_3d]*1e-9 #[Pa]
            pxzS0 = data3d["pxzS0"][trim2_3d]*1e-9 #[Pa]
            pyyS0 = data3d["pyyS0"][trim2_3d]*1e-9 #[Pa]
            pyzS0 = data3d["pyzS0"][trim2_3d]*1e-9 #[Pa]
            pzzS0 = data3d["pzzS0"][trim2_3d]*1e-9 #[Pa]
            pxxS1 = data3d["pxxS1"][trim2_3d]*1e-9 #[Pa]
            pxyS1 = data3d["pxyS1"][trim2_3d]*1e-9 #[Pa]
            pxzS1 = data3d["pxzS1"][trim2_3d]*1e-9 #[Pa]
            pyyS1 = data3d["pyyS1"][trim2_3d]*1e-9 #[Pa]
            pyzS1 = data3d["pyzS1"][trim2_3d]*1e-9 #[Pa]
            pzzS1 = data3d["pzzS1"][trim2_3d]*1e-9 #[Pa]
            dB_dx = data3d["dB_dx"][trim2_3d]*1e-9 # T/m
            dB_dy = data3d["dB_dy"][trim2_3d]*1e-9 # T/m
            dB_dz = data3d["dB_dz"][trim2_3d]*1e-9 # T/m
            uix3d = data3d["uxS1"][trim2_3d]*1e3 # [m/s]
            uiy3d = data3d["uyS1"][trim2_3d]*1e3 # [m/s]
            uiz3d = data3d["uzS1"][trim2_3d]*1e3 # [m/s]
            duix_dx = data3d["duix_dx"][trim2_3d]*1e3 # [1/s]
            duix_dy = data3d["duix_dy"][trim2_3d]*1e3 # [1/s]
            duix_dz = data3d["duix_dz"][trim2_3d]*1e3 # [1/s]
            duiz_dx = data3d["duiz_dx"][trim2_3d]*1e3 # [1/s]
            duiz_dy = data3d["duiz_dy"][trim2_3d]*1e3 # [1/s]
            duiz_dz = data3d["duiz_dz"][trim2_3d]*1e3 # [1/s]
            uex3d = data3d["uxS0"][trim2_3d]*1e3 # [m/s]
            uey3d = data3d["uyS0"][trim2_3d]*1e3 # [m/s]
            uez3d = data3d["uzS0"][trim2_3d]*1e3 # [m/s]
            # (disabled; previously sliced Z with trim_z instead of trim2_z)
            #duex_dx = data3d["duex_dx"][trim2_3d]*1e3 # [1/s]
            #duex_dy = data3d["duex_dy"][trim2_3d]*1e3 # [1/s]
            #duex_dz = data3d["duex_dz"][trim2_3d]*1e3 # [1/s]
            #duez_dx = data3d["duez_dx"][trim2_3d]*1e3 # [1/s]
            #duez_dy = data3d["duez_dy"][trim2_3d]*1e3 # [1/s]
            #duez_dz = data3d["duez_dz"][trim2_3d]*1e3 # [1/s]

            # Unpack average field data. background_field is cut with trimmask_x/z, i.e. indices relative to
            # the trim_x/trim_z sub-grid (the same as example_mask).
            # NOTE(review): the x slice starts one cell earlier (trimmask_x[0]-1) than example_mask's; if
            # trimmask_x[0] is 0 the start becomes -1 and the x slice is empty — confirm intent.
            trimmask_3d = np.s_[trim2_y[0]:trim2_y[-1], trimmask_x[0]-1:trimmask_x[-1], trimmask_z[0]:trimmask_z[-1]]
            Bx_avg = background_field["Bx"][trimmask_3d] # [nT]
            By_avg = background_field["By"][trimmask_3d] # [nT]
            Bz_avg = background_field["Bz"][trimmask_3d] # [nT]
            V_avg = background_field["V"][trim2_y[0]:trim2_y[-1],trimmask_x[0]-1:trimmask_x[-1]] # R_M/nT
            p_avg = background_field["p"][trimmask_3d] # [nPa]

            # Work out the centre planes of the DF: first grid index beyond the DF's mean Y and mean Z
            if np.any(example_mask): # Check we actually include any of the DF in the region to avoid crash
                y_mean = np.mean(Y3d[example_mask])
                iy_mean = np.where(Y3d[:,0,0]>y_mean)[0][0]
                z_mean = np.mean(Z3d[example_mask])
                iz_mean = np.where(Z3d[0,0,:]>z_mean)[0][0]

            # Compute parallel/perpendicular pressures for both species (S0 and S1)
            P11S0,P22S0,P33S0 = compute_para_perp(Bx3d,By3d,Bz3d,Ex3d,Ey3d,Ez3d,pxxS0,pyyS0,pzzS0,pxyS0,pxzS0,pyzS0) # [Pa]
            P11S1,P22S1,P33S1 = compute_para_perp(Bx3d,By3d,Bz3d,Ex3d,Ey3d,Ez3d,pxxS1,pyyS1,pzzS1,pxyS1,pxzS1,pyzS1) # [Pa]

            # Compute temperatures via T = P / (n k_B); P11/P22 are averaged as perpendicular, P33 is parallel
            T_perp_S0 = (P11S0+P22S0)/2/n/k_b # [K]
            T_para_S0 = (P33S0)/n/k_b # [K]
            T_perp_S1 = (P11S1+P22S1)/2/n/k_b # [K]
            T_para_S1 = (P33S1)/n/k_b # [K]

            # Compute reconnection scores
            L, D_e, APhi, root_Q, S = compute_recon_score(Bx3d,By3d,Bz3d,Ex3d,Ey3d,Ez3d,Jx3d,Jy3d,Jz3d,uex3d,uey3d,uez3d,pxxS0,pyyS0,pzzS0,pxyS0,pxzS0,pyzS0)

            # Set up field line tracing
            PIC_dx = 0.01562501 # grid spacing [R_M] (~1/64)
            ny,nx,nz = Bx3d.shape
            # The tracer field is built as (x, y, z, component) while the data are (y, x, z), hence the
            # swap of the first two axes. Bx3d/By3d/Bz3d are already in tesla (converted when unpacked),
            # so no further 1e-9 factor is applied; the old extra factor was only a uniform rescale that
            # does not change field-line direction.
            field = np.zeros((nx,ny,nz,3))
            field[:,:,:,0] = np.transpose(Bx3d,axes=[1,0,2])
            field[:,:,:,1] = np.transpose(By3d,axes=[1,0,2])
            field[:,:,:,2] = np.transpose(Bz3d,axes=[1,0,2])
            grid_spacing = [PIC_dx,PIC_dx,PIC_dx]
            grid = VectorGrid(field, grid_spacing, origin_coord = [X3d.min(),Y3d.min(),Z3d.min()])
            nsteps = 2000
            step_size = PIC_dx/2
            tracer = StreamTracer(nsteps, step_size)

            # Define the seeds: every DF_seed_skip-th DF cell, plus a strided sample of the current-sheet surface
            if np.any(example_mask):
                DF_seed_skip = 10
                DF_seeds = np.array([np.ravel(X3d[example_mask]),np.ravel(Y3d[example_mask]),np.ravel(Z3d[example_mask])]).T[::DF_seed_skip,:]
            cs_seed_skip = 5
            cs_seeds = np.array([np.ravel(Xcs[::cs_seed_skip,::cs_seed_skip]),
                                 np.ravel(Ycs[::cs_seed_skip,::cs_seed_skip]),
                                 np.ravel(Zcs[::cs_seed_skip,::cs_seed_skip])]).T
            # Uncomment to fill the volume with seeds instead of just the CS
            #cs_seeds = np.array([np.ravel(X3d[::cs_seed_skip,::cs_seed_skip,::cs_seed_skip]),np.ravel(Y3d[::cs_seed_skip,::cs_seed_skip,::cs_seed_skip]),np.ravel(Z3d[::cs_seed_skip,::cs_seed_skip,::cs_seed_skip])]).T

            # Plot
            fig = plt.figure(figsize=(16,12), constrained_layout=True)
            ax = fig.add_subplot(111, projection="3d",computed_zorder=False)
            ax.set_proj_type('persp', focal_length=0.2)  # FOV = 2*atan(1/0.2) = 157.4 deg

            # Show DF example region
            #in_front = Y3d < 0
            #behind = np.invert(in_front)
            #ax.scatter(X3d[example_mask & in_front],Y3d[example_mask & in_front],Z3d[example_mask & in_front],c=-Z3d[example_mask & in_front],vmin=-1.2,vmax=0.9,cmap="Reds",alpha = 0.3, zorder = 10)
            #ax.scatter(X3d[example_mask & behind],Y3d[example_mask & behind],Z3d[example_mask & behind],c=-Z3d[example_mask & behind],vmin=-1.2,vmax=0.9,cmap="Reds",alpha = 0.3, zorder = 7.5)

            # Draw bounding box, inset by box_inset on every side of the region.
            # The regions are plain Python lists, so the inset is applied per element:
            # `x_region+[0.02,-0.02]` would concatenate into a 4-element list instead of adding.
            box_inset = 0.02
            draw_cube_outline(ax,
                              [x_region[0]+box_inset, x_region[1]-box_inset],
                              [y_region[0]+box_inset, y_region[1]-box_inset],
                              [z_region[0]+box_inset, z_region[1]-box_inset])

            # Show projected reconnection score: max of S along Y drawn on the last-Y face,
            # and max of S along Z drawn on the first-Z face
            S_lims = [5,6]
            xz_surf = plot_colored_surface(ax, X3d[-1,:,:], Y3d[-1,:,:], Z3d[-1,:,:], np.max(S,axis=0), vmin = S_lims[0], vmax = S_lims[1], cmap = 'viridis', alpha = 1, zorder = 2)
            xy_surf = plot_colored_surface(ax, X3d[:,:,0], Y3d[:,:,0], Z3d[:,:,0], np.max(S,axis=2), vmin = S_lims[0], vmax = S_lims[1], cmap = 'viridis', alpha = 1, zorder = 2)

            # Plot Bz in cs
            Bz_lims = [-100,100]
            cs_surf = plot_colored_surface(ax, Xcs[:,:], Ycs[:,:], Zcs[:,:], Bzcs, vmin = Bz_lims[0], vmax = Bz_lims[1], cmap = 'bwr', alpha = 0.5, zorder = 5.0, shading = True)

            # Show reconnection regions
            #ax.scatter(X3d[recon_mask & in_front],Y3d[recon_mask & in_front],Z3d[recon_mask & in_front],color="Yellow",alpha = 0.5, zorder = 2)
            #ax.scatter(X3d[recon_mask & behind],Y3d[recon_mask & behind],Z3d[recon_mask & behind],color="Yellow",alpha = 0.5, zorder = 0.75)

            # Down-sampled current-sheet points used to interpolate the sheet height Z(x, y) along field lines
            down_resolve=10
            griddata_points = np.column_stack((Xcs[::down_resolve,::down_resolve].ravel(), Ycs[::down_resolve,::down_resolve].ravel()))
            griddata_values = Zcs[::down_resolve,::down_resolve].ravel()
            
            def _plot_split_traces(seeds, color, lw, zorder):
                """Trace field lines from `seeds` and draw them on `ax`, split at the current sheet.

                The sheet height Z(x, y) is interpolated linearly from griddata_points/griddata_values
                (NaN outside their convex hull, so such points are drawn on neither side). Points on or
                above the sheet are drawn with width `lw`, points on or below it with `lw/3`. Each side
                is broken into runs of consecutive indices so lines never jump across gaps.
                """
                print("starting trace")
                tracer.trace(seeds, grid)
                print("tracing done")
                for iseed in range(len(seeds)):
                    trace = tracer.xs[iseed]
                    trace_x, trace_y, trace_z = trace[:,0], trace[:,1], trace[:,2]
                    # Interpolate the sheet height once per trace and reuse it for both sides
                    sheet_z = griddata(griddata_points, griddata_values, trace[:,0:2], method='linear')
                    sides = ((np.where(trace_z>=sheet_z)[0], lw), (np.where(trace_z<=sheet_z)[0], lw/3))
                    for side_idx, side_lw in sides:
                        # Split wherever the index sequence jumps; the final point of each run is kept
                        for run in np.split(side_idx, np.where(np.diff(side_idx)>1)[0]+1):
                            if len(run) > 1:
                                ax.plot(trace_x[run],trace_y[run],trace_z[run],
                                        color=color,lw=side_lw,zorder=zorder)

            # Show DF field lines (only when the DF intersects the plotted region, where DF_seeds exists)
            if np.any(example_mask):
                _plot_split_traces(DF_seeds, color='green', lw=2.0, zorder=1000)

            # Show background field lines seeded on the current sheet
            _plot_split_traces(cs_seeds, color='white', lw=0.8, zorder=1000)
            
            ax.set_xlim(x_region)
            ax.set_ylim(y_region)
            ax.set_zlim(z_region)
            ax.set_aspect('equal')

            ax.set_title(str("DF fieldlines at t = "+time+"s"))

            ax.set_xlabel("X [$R_M$]")
            ax.set_ylabel("Y [$R_M$]")
            ax.set_zlabel("Z [$R_M$]")

            # Add a color bar for Bz (same limits/colormap as the current-sheet surface)
            norm = plt.Normalize(*Bz_lims)
            m = cm.ScalarMappable(cmap=cm.bwr, norm=norm)
            m.set_array(Bzcs)
            clb1 = fig.colorbar(m, ax=ax, shrink=0.3, aspect=7)#,anchor=(-0.5,0.3))
            clb1.ax.tick_params(labelsize=12)
            clb1.ax.set_title('$B_{z}$ [nT]',fontsize=12)#,pad=10)

            # Add a color bar for S (same limits/colormap as the projected score surfaces)
            norm = plt.Normalize(*S_lims)
            m = cm.ScalarMappable(cmap=cm.viridis, norm=norm)
            m.set_array(S)
            clb2 = fig.colorbar(m, ax=ax, shrink=0.3, aspect=7)#,anchor=(-20,0.9))
            clb2.ax.tick_params(labelsize=12)
            clb2.ax.set_title('S',fontsize=12)#,pad=10)

            # Show planet. The z limits use z_region, consistent with the x/y limits and with
            # ax.set_zlim above (`zlims` is not set anywhere in this preset).
            plot_sphere(ax,radius=1,color='lightgrey',alpha=0.8,zorder=3,xlims=x_region,ylims=y_region,zlims=z_region)
            plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=2.25,xlims=x_region,ylims=y_region,zlims=z_region)

            # Set viewing angle
            ax.view_init(elev=15, azim=-115)

            # A new figure is created for every time step: display it and release it so figures
            # do not pile up in memory across the loop.
            plt.show()
            plt.close(fig)
    
            # Save to lists: per-time means of each energy density multiplied by a volume element.
            # NOTE(review): (R_M/64)**3 presumably is one PIC cell volume (PIC_dx ~ 1/64 R_M) — confirm.
            cell_vol = (R_M/64)**3
            t_ls.append(float(time))
            energy_densities = (
                (i_kinetic_ls, 0.5*rho*(uix3d**2+uiy3d**2+uiz3d**2)),
                (e_kinetic_ls, 0.5*rho/mi_me*(uex3d**2+uey3d**2+uez3d**2)),
                (i_para_thermal_ls, 3/2 * n * k_b * T_para_S1),
                (i_perp_thermal_ls, 3/2 * n * k_b * T_perp_S1),
                (e_para_thermal_ls, 3/2 * n * k_b * T_para_S0),
                (e_perp_thermal_ls, 3/2 * n * k_b * T_perp_S0),
                (mag_ls, (Bx3d**2+By3d**2+Bz3d**2)/(2*mu_0)),
            )
            for store, density in energy_densities:
                store.append(np.nanmean(density*cell_vol))

            # Extremes of Jy along Y for each (x, z) column
            Jy_max = np.nanmax(Jy3d,axis=0)
            Jy_min = np.nanmin(Jy3d,axis=0)
            # Mean current density over negative-Jy ("east") and positive-Jy ("west") cells [A/m^2]
            Jy_east_ls.append(np.nanmean(Jy3d[Jy3d<0]))
            Jy_west_ls.append(np.nanmean(Jy3d[Jy3d>0]))

    if plot_preset == "DF_example_visualizer2":

         # Load in example mask
        # Copy in an example name from above, making sure to add "\" before each underscore
        # e.g. examplefilesDF = get_files(str(dir+"DFs/"),key="DF\_dawn\_flank\_example\_t\_...\...",read_time = True)
        DF_name = "run3_eg3_DDF"
        exampleDFfiles = get_files(str(dir+"DFs/"),key="DF\_filtered\_run3\_eg3\_DDF\_t\_...\...",read_time = True)
        #plot_times = ['166.00','168.00','171.00'] # run3_eg1_DDF
        #plot_times = ['162.80','163.90','165.00'] # run3_eg2_FRDF
        plot_times = ['140.80','141.60','142.50'] # run3_eg3_DDF
        #plot_times = ['124.00','128.90','129.95'] # run4_eg1_FRDF
        #plot_times = ['160.50','162.00','163.80']
        #plot_times = ['165.90','167.90','170.15']
        #plot_times = ['140.70','141.70','142.70']
        #plot_times = ['137.90','138.90','141.10'] # eg6
        #df_colors = ['maroon','lightcoral','rosybrown']
        df_colors = ['tomato','turquoise','orchid']
        J_colors = ['darkgreen','mediumseagreen','lightgreen']

        fig,ax = plt.subplots(figsize=(18,8), constrained_layout=True)

        if read_data:
            
            # Unpack average field data
            Bx_avg = background_field["Bx"] # [nT]
            By_avg = background_field["By"] # [nT]
            Bz_avg = background_field["Bz"] # [nT]
            V_avg = background_field["V"] # R_M/nT
            p_avg = background_field["p"] # [nPa]
    
            # Create list to record y slice planes
            y_ls = []
        
            for i,time in enumerate(plot_times):
                exampleDFfile = str(exampleDFfiles[time])
                with open(dir+"DFs/"+exampleDFfile, 'rb') as f:
                    print("reading DF example mask:",str(dir+"DFs/"+exampleDFfile))
                    mask = pickle.load(f) 
                file3D = str(files3D[time])
                with open(dir_3D+file3D, 'rb') as f:
                    print("reading 3d data for summary plot: ",str(dir_3D+file3D))
                    data3d = pickle.load(f) 
    
                # Unpack data
                X3d = data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
                Y3d = data3d["Y"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
                Z3d = data3d["Z"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
                Bx3d = data3d["Bx"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
                By3d = data3d["By"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
                Bz3d = data3d["Bz"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [nT]
                rho = data3d["rhoS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[amu/cc]
                pxxS0 = data3d["pxxS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pyyS0 = data3d["pyyS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pzzS0 = data3d["pzzS0"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pxxS1 = data3d["pxxS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pyyS1 = data3d["pyyS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                pzzS1 = data3d["pzzS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nPa]
                Jy3d = (e*data3d["rhoS1"]*1e6*(data3d["uyS1"]-data3d["uyS0"])*1e3)[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]*1e9 # [nA/m^2]
    
                # Work out which y plane to slice
                y_mean = np.mean(Y3d[mask])
                iy_mean = np.where(Y3d[:,0,0]>y_mean)[0][0]
                y_ls.append(iy_mean)
    
                # Create flattened_mask
                flat_mask = np.sum(mask,axis=0)
                flat_mask = flat_mask > 0
                
                # Show DF outline
                ax.contourf(X3d[iy_mean,:,:],Z3d[iy_mean,:,:],flat_mask,alpha = [0,0.4],colors=[df_colors[i]])
                ax.contour(X3d[iy_mean,:,:],Z3d[iy_mean,:,:],flat_mask,colors='black',linewidths = 0.1)
                # Show Jy contours
                #if i==1:
                #    ax.contourf(X3d[iy_mean,:,:],Z3d[iy_mean,:,:],Jy3d[iy_mean,:,:],levels=[-2000,-100,-10],alpha = [0.4,0.2], colors=J_colors[i])
                #    ax.contour(X3d[iy_mean,:,:],Z3d[iy_mean,:,:],Jy3d[iy_mean,:,:],levels=[-10],colors='black',linewidths = 0.1)
    
                # Use contourplot
                # Show background streamplot
                Xgrid,Zgrid = np.meshgrid(np.linspace(X3d[iy_mean,0,0],X3d[iy_mean,-1,0],len(X3d[iy_mean,:,0])),
                                    np.linspace(Z3d[iy_mean,0,0],Z3d[iy_mean,0,-1],len(Z3d[iy_mean,0,:])))
                # Define the seeds for the traces
                #num_lines = 40
                skip = 6 #len(X3d[mask]) // num_lines
                seeds = np.array([np.ravel(X3d[iy_mean,:,:][mask[iy_mean,:,:]][::skip]),
                                  np.ravel(Z3d[iy_mean,:,:][mask[iy_mean,:,:]][::skip])]).T
                streamplot = ax.streamplot(Xgrid,Zgrid,Bx3d[iy_mean,:,:].T,Bz3d[iy_mean,:,:].T, broken_streamlines=False,
                                           start_points = seeds, linewidth=0.75,color=df_colors[i])
    
                # Show text
                plt.text(np.mean(X3d[mask])-0.1, np.mean(Z3d[mask])+0.35, str("t = "+str(time)+"\nY = "+str(round(y_mean,1))), 
                         color=df_colors[i], fontsize=15, bbox=dict(facecolor='lightgrey', alpha=0.5))
    
            # Compute mean y plane
            avg_iy = int(np.mean(np.array(y_ls)))
            #avg_iy = np.where(Y3d[:,0,0]>avg_y)[0][0]
    
            # Work out average Jy for this time period 
            #Jyavg = average_value(["Jy"],float(plot_times[0]),0,int(float(plot_times[2])-float(plot_times[0])),dt = 0.1,type='numpy',path=dir_3D)["Jy"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] * 1e9 # [nA/m^2]

        # Show background streamplot
        Xgrid,Zgrid = np.meshgrid(np.linspace(X3d[avg_iy,0,0],X3d[avg_iy,-1,0],len(X3d[avg_iy,:,0])),
                            np.linspace(Z3d[avg_iy,0,0],Z3d[avg_iy,0,-1],len(Z3d[avg_iy,0,:])))

        streamplot = ax.streamplot(Xgrid,Zgrid,Bx_avg[avg_iy,:,:].T,Bz_avg[avg_iy,:,:].T, broken_streamlines=False,
                        density = 0.75, linewidth=0.75,color='lightgray')

        # Show avg Jy
        #Jy_plot = ax.imshow(Jyavg[avg_iy,:,:],origin='lower',extent=[np.min(X3d),np.max(X3d),np.min(Z3d),np.max(Z3d)],
        #                    vmin=-1000,vmax=1000,cmap='coolwarm')
        #Jy_plot = ax.contourf(X3d[avg_iy,:,:], Z3d[avg_iy,:,:], Jyavg[avg_iy,:,:], [-500,-250,0,250,500],
        #                    vmin=-1000,vmax=1000,cmap='coolwarm')

        ax.set_xlim(-2,-0.75)
        ax.set_ylim(-0.10,0.5)
        ax.set_aspect(1)
            
        # Colorbar
        #clb1 = fig.colorbar(Jy_plot,ax=ax,shrink=0.5)
        #clb1.ax.set_title("avg. $J_y$ [nA/m$^2$]")

        # Label
        ax.set_xlabel("X [$R_M$]")
        ax.set_ylabel("Z [$R_M$]")
        ax.set_title(str("t = "+"%.2f"%round(float(time),2)+"\n$B_z$ in y = "+str(round(y_mean,2))+" plane"))
        
        # Save
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+DF_name+"_"+"%.2f"%round(float(time),2)+'.png'),bbox_inches='tight',dpi=300)
        plt.show()
        plt.close(fig)

        break

    if plot_preset == "DF_example_visualizer3":

        # Choose which example to plot
        DF_name = "run3_eg4a_FRDF"
        # Raw string: same value as before (literal backslashes), without the invalid-escape warning
        exampleDFfiles = get_files(str(dir+"DFs/"),key=r"DF\_filtered\_run3\_eg4a\_FRDF\_t\_...\...",read_time = True)

        #plot_times = ['162.80','164.00','165.20'] # run3_eg2_FRDF
        plot_times = ['150.40','151.70','153.00'] # run3_eg4a_FRDF
        #plot_times = ['140.45','141.45','142.45'] # run3_eg7a_DDF
        #plot_times = ['118.00','122.00','124.50'] # run4_eg2a_FRDF
        #plot_times = ['122.00','122.80','123.60'] # run4_eg1a_FRDF

        # Set mask region
        # run3_eg2_FRDF:
        #x_region = [-2.5,-0.6]
        #y_region = [-1,0]
        #z_region = [-0.2,0.6]
        # run3_eg3_DDF:
        #x_region = [-1.8,-0.6]
        #y_region = [0,0.75]
        #z_region = [-0.2,0.6]

        #df_colors = ['tomato','turquoise','orchid']
        df_colors_patches = ['mediumslateblue','limegreen','gold']
        df_colors_fieldlines = ['mediumblue','seagreen','goldenrod']

        # Plot: a single 3-D figure shared by all plot_times
        fig = plt.figure(figsize=(24,18), constrained_layout=True)
        ax = fig.add_subplot(111, projection="3d",computed_zorder=False)
        ax.set_proj_type('persp', focal_length=0.5)  # FOV = 2*atan(1/0.5) = 126.9 deg

        if read_data:
    
            # Turn this into a new trim region: full-grid indices whose coordinates lie strictly inside
            # x_region / y_region / z_region (data arrays are indexed (y, x, z)).
            # NOTE(review): x/y/z_region are not assigned in this preset (the options above are commented
            # out), so whatever values are already in scope are used — set them explicitly.
            # NOTE(review): data3d is only loaded inside the time loop below, so these indices come from a
            # data3d left over from earlier code; if nothing earlier defines it, this raises NameError.
            trim2_x = np.where((data3d["X"][0,:,0]>x_region[0]) & (data3d["X"][0,:,0]<x_region[1]))[0] 
            trim2_y = np.where((data3d["Y"][:,0,0]>y_region[0]) & (data3d["Y"][:,0,0]<y_region[1]))[0]
            trim2_z = np.where((data3d["Z"][0,0,:]>z_region[0]) & (data3d["Z"][0,0,:]<z_region[1]))[0]

            # Special treatment for the mask x trim, since it was already created based off a trimmed region:
            # locate the same box on the trim_x/trim_y/trim_z sub-grid, giving indices relative to it.
            trimmask_x = np.where((data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]][0,:,0]>x_region[0]) & (data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]][0,:,0]<x_region[1]))[0] 
            trimmask_z = np.where((data3d["Z"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]][0,0,:]>z_region[0]) & (data3d["Z"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]][0,0,:]<z_region[1]))[0] 

            # Unpack average field data (currently disabled)
            #Bx_avg = background_field["Bx"][trim2_y[0]:trim2_y[-1],trimmask_x[0]:trimmask_x[-1],trimmask_z[0]:trimmask_z[-1]]*1e-9 # [T]
            #By_avg = background_field["By"][trim2_y[0]:trim2_y[-1],trimmask_x[0]:trimmask_x[-1],trimmask_z[0]:trimmask_z[-1]]*1e-9 # [T]
            #Bz_avg = background_field["Bz"][trim2_y[0]:trim2_y[-1],trimmask_x[0]:trimmask_x[-1],trimmask_z[0]:trimmask_z[-1]]*1e-9 # [T]
            #V_avg = background_field["V"][trim2_y[0]:trim2_y[-1],trimmask_x[0]:trimmask_x[-1]] # R_M/nT
            #p_avg = background_field["p"][trim2_y[0]:trim2_y[-1],trimmask_x[0]:trimmask_x[-1],trimmask_z[0]:trimmask_z[-1]] # [nPa]
        
            for i,time in enumerate(plot_times):
                exampleDFfile = str(exampleDFfiles[time])
                with open(dir+"DFs/"+exampleDFfile, 'rb') as f:
                    print("reading DF example mask:",str(dir+"DFs/"+exampleDFfile))
                    example_mask = pickle.load(f)[trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1],trim2_z[0]:trim2_z[-1]] 
                file3D = str(files3D[time])
                with open(dir_3D+file3D, 'rb') as f:
                    print("reading 3d data for summary plot: ",str(dir_3D+file3D))
                    data3d = pickle.load(f) 
                filecs = str(filescs[time])
                with open(dir+filecs, 'rb') as f:
                    print("reading cs data for summary plot: ",str(dir+filecs))
                    datacs = pickle.load(f) 
                
                # Unpack data
                X3d = data3d["X"][trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1],trim2_z[0]:trim2_z[-1]]
                Xcs = datacs["X"][trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1]]
                Y3d = data3d["Y"][trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1],trim2_z[0]:trim2_z[-1]]
                Ycs = datacs["Y"][trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1]]
                Z3d = data3d["Z"][trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1],trim2_z[0]:trim2_z[-1]]
                Zcs = datacs["Z"][trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1]]
                Bx3d = data3d["Bx"][trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1],trim2_z[0]:trim2_z[-1]]*1e-9 #[T]
                By3d = data3d["By"][trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1],trim2_z[0]:trim2_z[-1]]*1e-9 #[T]
                Bz3d = data3d["Bz"][trim2_y[0]:trim2_y[-1],trim2_x[0]:trim2_x[-1],trim2_z[0]:trim2_z[-1]]*1e-9 #[T]
                # Index tuple selecting the trimmed sub-volume for this preset (upper bounds exclusive)
                trim2_slc = np.s_[trim2_y[0]:trim2_y[-1], trim2_x[0]:trim2_x[-1], trim2_z[0]:trim2_z[-1]]

                # Electric field
                Ex3d = data3d["Ex"][trim2_slc]*1e-6 # V/m
                Ey3d = data3d["Ey"][trim2_slc]*1e-6 # V/m
                Ez3d = data3d["Ez"][trim2_slc]*1e-6 # V/m

                # S1 number density and mass density
                n = data3d["rhoS1"][trim2_slc]*1e6 # 1/m^3
                rho = n*amu # [kg/m^3]

                # Current density J = e n (u_S1 - u_S0); inputs are trimmed before the arithmetic
                Jx3d = e*data3d["rhoS1"][trim2_slc]*1e6*(data3d["uxS1"][trim2_slc]-data3d["uxS0"][trim2_slc])*1e3 # [A/m^2]
                Jy3d = e*data3d["rhoS1"][trim2_slc]*1e6*(data3d["uyS1"][trim2_slc]-data3d["uyS0"][trim2_slc])*1e3 # [A/m^2]
                Jz3d = e*data3d["rhoS1"][trim2_slc]*1e6*(data3d["uzS1"][trim2_slc]-data3d["uzS0"][trim2_slc])*1e3 # [A/m^2]

                # Pressure tensor components, species S0
                pxxS0 = data3d["pxxS0"][trim2_slc]*1e-9 #[Pa]
                pxyS0 = data3d["pxyS0"][trim2_slc]*1e-9 #[Pa]
                pxzS0 = data3d["pxzS0"][trim2_slc]*1e-9 #[Pa]
                pyyS0 = data3d["pyyS0"][trim2_slc]*1e-9 #[Pa]
                pyzS0 = data3d["pyzS0"][trim2_slc]*1e-9 #[Pa]
                pzzS0 = data3d["pzzS0"][trim2_slc]*1e-9 #[Pa]

                # Pressure tensor components, species S1
                pxxS1 = data3d["pxxS1"][trim2_slc]*1e-9 #[Pa]
                pxyS1 = data3d["pxyS1"][trim2_slc]*1e-9 #[Pa]
                pxzS1 = data3d["pxzS1"][trim2_slc]*1e-9 #[Pa]
                pyyS1 = data3d["pyyS1"][trim2_slc]*1e-9 #[Pa]
                pyzS1 = data3d["pyzS1"][trim2_slc]*1e-9 #[Pa]
                pzzS1 = data3d["pzzS1"][trim2_slc]*1e-9 #[Pa]

                # Gradient of |B|
                dB_dx = data3d["dB_dx"][trim2_slc]*1e-9 # T/m
                dB_dy = data3d["dB_dy"][trim2_slc]*1e-9 # T/m
                dB_dz = data3d["dB_dz"][trim2_slc]*1e-9 # T/m

                # S1 (ion) bulk velocity and selected velocity gradients
                uix3d = data3d["uxS1"][trim2_slc]*1e3 # [m/s]
                uiy3d = data3d["uyS1"][trim2_slc]*1e3 # [m/s]
                uiz3d = data3d["uzS1"][trim2_slc]*1e3 # [m/s]
                duix_dx = data3d["duix_dx"][trim2_slc]*1e3 # [1/s]
                duix_dy = data3d["duix_dy"][trim2_slc]*1e3 # [1/s]
                duix_dz = data3d["duix_dz"][trim2_slc]*1e3 # [1/s]
                duiz_dx = data3d["duiz_dx"][trim2_slc]*1e3 # [1/s]
                duiz_dy = data3d["duiz_dy"][trim2_slc]*1e3 # [1/s]
                duiz_dz = data3d["duiz_dz"][trim2_slc]*1e3 # [1/s]

                # S0 (electron) bulk velocity
                uex3d = data3d["uxS0"][trim2_slc]*1e3 # [m/s]
                uey3d = data3d["uyS0"][trim2_slc]*1e3 # [m/s]
                uez3d = data3d["uzS0"][trim2_slc]*1e3 # [m/s]

                # Footprint maps recording which DF (value i+1) covered each cell; created on the first DF only
                if i == 0:
                    eq_plane = np.zeros_like(X3d[:, :, 0])
                    xz_plane = np.zeros_like(X3d[0, :, :])

                # Index of the first z-level with Z > 0.2 (treated here as the magnetic equator)
                iz_eq = np.flatnonzero(Z3d[0, 0, :] > 0.2)[0]

                # Reconnection diagnostics from fields, currents, electron velocity and S0 pressure tensor
                L, D_e, APhi, root_Q, S = compute_recon_score(
                    Bx3d, By3d, Bz3d,
                    Ex3d, Ey3d, Ez3d,
                    Jx3d, Jy3d, Jz3d,
                    uex3d, uey3d, uez3d,
                    pxxS0, pyyS0, pzzS0, pxyS0, pxzS0, pyzS0)

                # DF "shadows": collapse the 3D DF mask along one array axis.
                # NOTE(review): the names appear swapped relative to content -- df_xz_mask collapses
                # axis 2 and is painted onto eq_plane (the [:, :, 0] plane); df_xy_mask collapses
                # axis 0 and is painted onto xz_plane (the [0, :, :] plane). Assumes [y, x, z]
                # array ordering -- TODO confirm. Names kept so later cells keep working.
                df_xz_mask = np.sum(example_mask, axis=2) > 0
                df_xy_mask = np.sum(example_mask, axis=0) > 0

                # Coordinates of the shadow cells on the equatorial slice and on the index-0 wall
                eq_cut = (X3d[:, :, iz_eq], Y3d[:, :, iz_eq], Z3d[:, :, iz_eq])
                X_df_xz, Y_df_xz, Z_df_xz = (np.ravel(coord[df_xz_mask]) for coord in eq_cut)
                wall_cut = (X3d[0, :, :], Y3d[0, :, :], Z3d[0, :, :])
                X_df_xy, Y_df_xy, Z_df_xy = (np.ravel(coord[df_xy_mask]) for coord in wall_cut)

                # Tag the footprint of this DF on the two planes
                eq_plane[df_xz_mask] = i + 1
                xz_plane[df_xy_mask] = i + 1
                
                # Set up field line tracing through the (possibly trimmed) B field
                tracer,grid = get_tracer(X3d,Y3d,Z3d,Bx3d,By3d,Bz3d)

                # Target number of field lines per DF
                num_traces = 5

                # Candidate seeds: coordinates of every cell inside the DF mask, one row per cell
                DF_seeds = np.column_stack((X3d[example_mask], Y3d[example_mask], Z3d[example_mask]))

                # Confine seeds to a thin band around the equator: |Z - 0.15| < 0.05 (Z in R_M)
                seed_z_center = 0.15
                seed_z_halfwidth = 0.05
                DF_seeds = DF_seeds[np.abs(DF_seeds[:,2]-seed_z_center)<seed_z_halfwidth]

                # Thin to roughly num_traces seeds. The step is clamped to >= 1: with fewer than
                # num_traces candidates the old int(len/num_traces) gave 0, and a zero slice step
                # raises ValueError.
                DF_seed_skip = max(1, len(DF_seeds) // num_traces)
                DF_seeds = DF_seeds[::DF_seed_skip,:]
            
                 # Flatten the meshgrid to interpolate Z(x,y)
                # Down-sampled (x, y) -> z samples of the Xcs/Ycs/Zcs surface (presumably the current
                # sheet), used to decide whether each trace point lies above or below it
                down_resolve=16
                cs_points = np.column_stack((Xcs[::down_resolve,::down_resolve].ravel(), Ycs[::down_resolve,::down_resolve].ravel()))
                cs_values = Zcs[::down_resolve,::down_resolve].ravel()

                # Show DF field lines
                tracer.trace(DF_seeds, grid)
                color = df_colors_fieldlines[i]
                lw = 2.0
                zorder_abo = 10 - i
                zorder_bel = 1.75
                above_style = dict(color=color, lw=lw, zorder=zorder_abo)
                below_style = dict(color=color, lw=lw/2, zorder=zorder_bel, linestyle='dashed')

                for iseed in range(len(DF_seeds)):

                    # Unpack line data
                    trace_xyz = tracer.xs[iseed]
                    trace_x = trace_xyz[:,0]
                    trace_y = trace_xyz[:,1]
                    trace_z = trace_xyz[:,2]

                    # Surface height beneath each trace point, interpolated once per trace. Points
                    # outside the surface's convex hull get NaN and fall in neither set.
                    z_cs = griddata(cs_points, cs_values, trace_xyz[:,0:2], method='linear')
                    above = np.flatnonzero(trace_z >= z_cs)
                    below = np.flatnonzero(trace_z < z_cs)

                    # Plot each contiguous run of indices as its own line so separate excursions are
                    # not joined. (The previous loop sliced [start:j] at the last index, dropping the
                    # final point of every trace and never drawing two-point runs.)
                    for idx, style in ((above, above_style), (below, below_style)):
                        for run in np.split(idx, np.flatnonzero(np.diff(idx) > 1) + 1):
                            if len(run) > 1:
                                ax.plot(trace_x[run], trace_y[run], trace_z[run], **style)

                # (Disabled) DF patch at the equator, and its projections onto the floor and wall:
                #ax.plot_trisurf(X_df_xz, Y_df_xz, Z_df_xz, color=df_colors_patches[i], edgecolor='none', zorder=2)
                #floor_plot = ax.plot_trisurf(X_df_xz, Y_df_xz, Z_df_xz*0 + z_region[0], color=df_colors_patches[i], edgecolor='none', zorder=0.1)
                #wall_plot = ax.plot_trisurf(X_df_xy, np.linspace(y_region[0],y_region[0]+0.01,len(Z_df_xy)), Z_df_xy, color=df_colors_patches[i], edgecolor='none', zorder=0.1)

                # Reconnection score S: drawn once (first DF) as max-projections onto the
                # [:, :, 0] and [0, :, :] planes of the grid
                S_lims = [5,6]
                if i == 0 :
                    recon_style = dict(vmin = S_lims[0], vmax = S_lims[1], cmap = 'Reds', alpha = 1, zorder = 0.1,
                                       shading = False, nan_threshold = 0)
                    S_plotxy = plot_colored_surface(ax, X3d[:,:,0], Y3d[:,:,0], Z3d[:,:,0], np.max(S,axis=2), **recon_style)
                    S_plotxz = plot_colored_surface(ax, X3d[0,:,:], Y3d[0,:,:], Z3d[0,:,:], np.max(S,axis=0), **recon_style)

                    # Custom shading: a faint grey layer coloured by height (both layers use the
                    # Z range of the [0, :, :] plane)
                    wall_z = Z3d[0,:,:]
                    shade_style = dict(vmin = np.min(wall_z), vmax = np.max(wall_z), cmap = 'gray', alpha = 0.1,
                                       zorder = 0.1, shading = False, nan_threshold = -2)
                    S_plotxy_shade = plot_colored_surface(ax, X3d[:,:,0], Y3d[:,:,0], Z3d[:,:,0], Z3d[:,:,0], **shade_style)
                    S_plotxz_shade = plot_colored_surface(ax, X3d[0,:,:], Y3d[0,:,:], Z3d[0,:,:], Z3d[0,:,:], **shade_style)

                # Time label, offset from the DF centroid, in this DF's field-line colour
                label_pos = (np.mean(X3d[example_mask])-0.1, np.mean(Y3d[example_mask]), np.mean(Z3d[example_mask])+0.5)
                ax.text(*label_pos, "t = "+str(time),
                        color=df_colors_fieldlines[i], fontsize=15, bbox=dict(facecolor='lightgrey', alpha=0.75), zorder=500)

        # Show magnetic equator plane
        eq_plane[np.sqrt(X3d**2+Y3d**2+Z3d**2)[:,:,0]<=0.8] = -1
        eq_plot = plot_colored_surface(ax, X3d[:,:,0], Y3d[:,:,0], Z3d[:,:,0], eq_plane, vmin = 0, vmax = 3, cmap = 'viridis', alpha = 0.5, zorder = 1.7, 
                                       shading = False, nan_threshold = 0.1)

        # Show meridonial plane
        xz_plane[np.sqrt(X3d**2+Y3d**2+Z3d**2)[0,:,:]<=0.8] = -1
        xz_plot = plot_colored_surface(ax, X3d[0,:,:], Y3d[0,:,:], Z3d[0,:,:], xz_plane, vmin = 0, vmax = 3, cmap = 'viridis', alpha = 0.5, zorder = 1.7, 
                                       shading = False, nan_threshold = 0.1)


        
        ax.set_xlim(x_region)
        ax.set_ylim(y_region)
        ax.set_zlim(z_region)
        ax.set_aspect('equal')

        ax.set_title(str("DF fieldlines at t = "+time+"s"))
    
        ax.set_xlabel("X [$R_M$]")
        ax.set_ylabel("Y [$R_M$]")
        ax.set_zlabel("Z [$R_M$]")

        # Add a color bar for S
        norm = plt.Normalize(*S_lims)
        m = cm.ScalarMappable(cmap=cm.Reds, norm=norm)
        m.set_array(S)
        clb2 = fig.colorbar(m, ax=ax, shrink=0.3, aspect=7)#,anchor=(-20,0.9))
        clb2.ax.tick_params(labelsize=12)
        clb2.ax.set_title('S',fontsize=12)#,pad=10)

        # Show planet
        plot_sphere(ax,radius=1,color='lightgrey',alpha=0.9,zorder=1,xlims=x_region,ylims=y_region,zlims=z_region)
        plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=0.75,xlims=x_region,ylims=y_region,zlims=z_region)
        
        # Set viewing angle
        ax.view_init(elev=35, azim=105)

        # Save
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+DF_name+"_"+"%.2f"%round(float(time),2)+'.png'),bbox_inches='tight',dpi=300)
        plt.show()
        plt.close(fig)

        break

    if plot_preset == "force_equilib":
        # Force-balance diagnostic: net (J x B - grad p) force on the DF cells, plus a comparison of
        # the total Jy with the isotropic diamagnetic current (B x grad p)_y / B^2.

        if read_data:
            # Trimmed sub-volume for this preset (upper bounds exclusive)
            trim_slc = np.s_[trim_y[0]:trim_y[-1], trim_x[0]:trim_x[-1], trim_z[0]:trim_z[-1]]

            # Unpack data
            X3d = data3d["X"][trim_slc]
            Y3d = data3d["Y"][trim_slc]
            Z3d = data3d["Z"][trim_slc]
            Bx3d = data3d["Bx"][trim_slc]*1e-9 #[T]
            By3d = data3d["By"][trim_slc]*1e-9 #[T]
            Bz3d = data3d["Bz"][trim_slc]*1e-9 #[T]
            rho = data3d["rhoS1"][trim_slc]*1e6*amu # [kg/m^3]

            # Current J = e n (u_S1 - u_S0). Each component must use its own velocity component:
            # Jz previously used uyS1/uyS0, which made it a copy of Jy and corrupted J x B.
            n_S1 = data3d["rhoS1"][trim_slc]*1e6 # [1/m^3]
            Jx3d = e*n_S1*(data3d["uxS1"][trim_slc]-data3d["uxS0"][trim_slc])*1e3 # [A/m^2]
            Jy3d = e*n_S1*(data3d["uyS1"][trim_slc]-data3d["uyS0"][trim_slc])*1e3 # [A/m^2]
            Jz3d = e*n_S1*(data3d["uzS1"][trim_slc]-data3d["uzS0"][trim_slc])*1e3 # [A/m^2]

            # Pressure gradient
            dp_dx3d = data3d["dp_dx"][trim_slc]*1e-9 #[Pa/m] = [N/m^3]
            dp_dy3d = data3d["dp_dy"][trim_slc]*1e-9 #[Pa/m]
            dp_dz3d = data3d["dp_dz"][trim_slc]*1e-9 #[Pa/m]

            # Compute derived terms: J x B
            JxB_x = (Jy3d*Bz3d - Jz3d*By3d)
            JxB_y = (Jz3d*Bx3d - Jx3d*Bz3d)
            JxB_z = (Jx3d*By3d - Jy3d*Bx3d)

            # Force balance, N/m^3 -> mN/km^3
            F_x = (JxB_x - dp_dx3d) *1e9 * 1e3 #mN / km^3
            F_y = (JxB_y - dp_dy3d) *1e9 * 1e3 #mN / km^3
            F_z = (JxB_z - dp_dz3d) *1e9 * 1e3 #mN / km^3

            # MHD acceleration
            a_x = (JxB_x - dp_dx3d) / rho / 1e3 #km/s /s
            a_y = (JxB_y - dp_dy3d) / rho / 1e3 #km/s /s
            a_z = (JxB_z - dp_dz3d) / rho / 1e3 #km/s /s

            # Diamagnetic current (isotropic pressure), y component of B x grad(p) / B^2
            Jy3d_dia = (Bz3d*dp_dx3d - Bx3d*dp_dz3d) / (Bx3d**2+By3d**2+Bz3d**2)

        # Plot: rows are the X, Y, Z force components; column a is a top view, column b a side view
        fig = plt.figure(figsize=(24,30), constrained_layout=True)
        ax1a = fig.add_subplot(321, projection="3d",computed_zorder=False)
        ax1b = fig.add_subplot(322, projection="3d",computed_zorder=False)
        ax2a = fig.add_subplot(323, projection="3d",computed_zorder=False)
        ax2b = fig.add_subplot(324, projection="3d",computed_zorder=False)
        ax3a = fig.add_subplot(325, projection="3d",computed_zorder=False)
        ax3b = fig.add_subplot(326, projection="3d",computed_zorder=False)

        # Declare masks which set whether the points are in front of or behind the planet
        above = (X3d**2 + Y3d**2 > 1) | (Z3d>0)
        in_front = (X3d**2 + Z3d**2 > 1) | (Y3d<0)

        if np.any(DF_mask):
            # Plotting variables
            vmin = -0.01
            vmax = 0.01
            cmap = 'bwr'

            # Value-based alphas: |F| / vmax, capped at 1
            alphax = np.minimum(np.abs(F_x)/vmax, 1)
            alphay = np.minimum(np.abs(F_y)/vmax, 1)
            alphaz = np.minimum(np.abs(F_z)/vmax, 1)

            # DF cells visible from above (column a) and from the front (column b)
            df_above = DF_mask & above
            df_front = DF_mask & in_front

            plot_xa = ax1a.scatter(X3d[df_above],Y3d[df_above],Z3d[df_above],
                       c=F_x[df_above],cmap=cmap,alpha = alphax[df_above], zorder = 2, vmin=vmin,vmax=vmax)
            plot_xb = ax1b.scatter(X3d[df_front],Y3d[df_front],Z3d[df_front],
                       c=F_x[df_front],cmap=cmap,alpha = alphax[df_front], zorder = 2, vmin=vmin,vmax=vmax)

            plot_ya = ax2a.scatter(X3d[df_above],Y3d[df_above],Z3d[df_above],
                       c=F_y[df_above],cmap=cmap,alpha = alphay[df_above], zorder = 2, vmin=vmin,vmax=vmax)
            plot_yb = ax2b.scatter(X3d[df_front],Y3d[df_front],Z3d[df_front],
                       c=F_y[df_front],cmap=cmap,alpha = alphay[df_front], zorder = 2, vmin=vmin,vmax=vmax)

            plot_za = ax3a.scatter(X3d[df_above],Y3d[df_above],Z3d[df_above],
                       c=F_z[df_above],cmap=cmap,alpha = alphaz[df_above], zorder = 2, vmin=vmin,vmax=vmax)
            plot_zb = ax3b.scatter(X3d[df_front],Y3d[df_front],Z3d[df_front],
                       c=F_z[df_front],cmap=cmap,alpha = alphaz[df_front], zorder = 2, vmin=vmin,vmax=vmax)

            plot_ls = [plot_xa,plot_ya,plot_za,plot_xb,plot_yb,plot_zb]

            # Jy diagnostics are identical for every panel, so compute them once:
            # relative departure of Jy from Jy_dia, with alpha fading as it grows (floor 0.05)
            Jy_min=-50*1e-9
            Jy_diff = ((Jy3d-Jy3d_dia)/Jy3d_dia)
            alphaj = np.maximum((2 - np.abs(Jy_diff))/2, 0.05)

            for axi,ax in enumerate([ax1a,ax2a,ax3a,ax1b,ax2b,ax3b]):
                top_view = axi<3

                # Show cells with strongly negative Jy, restricted to the visible hemisphere
                mask = (Jy3d<Jy_min) & (above if top_view else in_front)
                Jy_plot = ax.scatter(X3d[mask],Y3d[mask],Z3d[mask],c = Jy_diff[mask], cmap = 'viridis',
                           vmin=-2,vmax=2, alpha=alphaj[mask], zorder = 2)

                # Cubic view box centred on xlims/ylims/zlims with side max_range (defined elsewhere)
                ax.set_xlim(np.mean(xlims) - max_range / 2, np.mean(xlims) + max_range / 2)
                ax.set_ylim(np.mean(ylims) - max_range / 2, np.mean(ylims) + max_range / 2)
                ax.set_zlim(np.mean(zlims) - max_range / 2, np.mean(zlims) + max_range / 2)

                ax.set_xlabel("X [$R_M$]")
                ax.set_ylabel("Y [$R_M$]")
                ax.set_zlabel("Z [$R_M$]")

                # Show planet (clipped to the visible half) and set the viewing angle
                if top_view:
                    plot_sphere(ax,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=[-10,-0.5],zlims=[0,2])
                    plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=[-10,-0.5],zlims=[0,2])
                    ax.view_init(elev=90, azim=-90)
                else:
                    plot_sphere(ax,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=[-10,-0.5],ylims=[-2,0])
                    plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=[-10,-0.5],ylims=[-2,0])
                    ax.view_init(elev=0, azim=-90)

                # Colorbars: net force, and the relative Jy departure actually plotted above
                clb1 = fig.colorbar(plot_ls[axi],ax=ax,shrink=0.5)
                clb1.ax.set_title('Net force [mN/km$^3$]',fontsize=10,pad=10)
                # Stand-alone mappable (replaces an off-screen dummy scatter)
                Jy_clb_plot = cm.ScalarMappable(norm=plt.Normalize(-2,2), cmap='viridis')
                Jy_clb_plot.set_array([])
                clb2 = fig.colorbar(Jy_clb_plot,ax=ax,shrink=0.5)
                clb2.ax.set_title('$(J_{y,tot} - J_{y,dia})/J_{y,dia}$',fontsize=10,pad=10)

        # Add titles (third row is the Z component)
        ax1a.set_title(str("X Force balance at t="+time+"s"),fontsize=12,y=1.0, pad=-14)
        ax2a.set_title(str("Y Force balance at t="+time+"s"),fontsize=12,y=1.0, pad=-14)
        ax3a.set_title(str("Z Force balance at t="+time+"s"),fontsize=12,y=1.0, pad=-14)

    if plot_preset == "Jy_components":
        # Decompose Jy on an x-z slice into perpendicular, diamagnetic and inertial contributions.

        # Control whether to plot based off DF (loads a saved DF mask and overlays its outline)
        read_DF = False

        if read_data:

            # Time-series accumulators, (re)initialised on the first iteration only
            if iter==0:
                t_ls = []
                Jy_ls = []
                Jperp_y_ls = []
                Jy_dia_ls = []
                Jy_inrt_i_ls = []
                delta_Jperp_y_ls = []

            if read_DF:
                # Load in DF mask. The key is a raw string: its backslashes (\_ and \.) are meant
                # for the pattern matcher, and in a normal literal they are invalid Python escape
                # sequences (SyntaxWarning on Python >= 3.12). The string value is unchanged.
                exampleDFfiles = get_files(str(dir+"DFs/"),key=r"DF\_filtered\_run3\_eg2a\_FRDF\_t\_...\...",read_time = True)

                exampleDFfile = str(exampleDFfiles[time])
                with open(dir+"DFs/"+exampleDFfile, 'rb') as f:
                    print("reading DF example mask:",str(dir+"DFs/"+exampleDFfile))
                    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files
                    DF_mask = pickle.load(f)
            
            # Unpack the full (untrimmed) domain and convert to SI, using the scale factors below
            NANO = 1e-9    # applied to B, B gradients, pressures and pressure gradients
            KILO = 1e3     # applied to velocities and velocity gradients
            PER_CC = 1e6   # applied to the S1 number density

            # Grid coordinates
            X3d = data3d["X"]
            Y3d = data3d["Y"]
            Z3d = data3d["Z"]

            # Magnetic field and the gradients of Bx and Bz
            Bx3d = data3d["Bx"]*NANO #[T]
            By3d = data3d["By"]*NANO #[T]
            Bz3d = data3d["Bz"]*NANO #[T]
            dBx_dx = data3d["dBx_dx"]*NANO #[T/m]
            dBx_dy = data3d["dBx_dy"]*NANO #[T/m]
            dBx_dz = data3d["dBx_dz"]*NANO #[T/m]
            dBz_dx = data3d["dBz_dx"]*NANO #[T/m]
            dBz_dy = data3d["dBz_dy"]*NANO #[T/m]
            dBz_dz = data3d["dBz_dz"]*NANO #[T/m]

            # Electric field, used unscaled.
            # NOTE(review): another preset in this file multiplies E by 1e-6 to get V/m; confirm units here.
            Ex3d = data3d["Ex"]
            Ey3d = data3d["Ey"]
            Ez3d = data3d["Ez"]

            # S1 density, mass density, and current J = e n (u_S1 - u_S0)
            n = data3d["rhoS1"]*PER_CC # 1/m^3
            rho = n*amu # [kg/m^3]
            Jx3d = e*data3d["rhoS1"]*PER_CC*(data3d["uxS1"]-data3d["uxS0"])*KILO # [A/m^2]
            Jy3d = e*data3d["rhoS1"]*PER_CC*(data3d["uyS1"]-data3d["uyS0"])*KILO # [A/m^2]
            Jz3d = e*data3d["rhoS1"]*PER_CC*(data3d["uzS1"]-data3d["uzS0"])*KILO # [A/m^2]
            Jy3d_curl = data3d["Jy"] #[A/m^2]

            # Pressure gradients (total and perpendicular)
            dp_dx3d = data3d["dp_dx"]*NANO #[Pa/m] = [N/m^3]
            dp_dy3d = data3d["dp_dy"]*NANO #[Pa/m]
            dp_dz3d = data3d["dp_dz"]*NANO #[Pa/m]
            dp_perp_dx = data3d["dp_perp_dx"]*NANO #[Pa/m] = [N/m^3]
            dp_perp_dy = data3d["dp_perp_dy"]*NANO #[Pa/m]
            dp_perp_dz = data3d["dp_perp_dz"]*NANO #[Pa/m]

            # Pressure tensor components, species S0
            pxxS0 = data3d["pxxS0"]*NANO #[Pa]
            pxyS0 = data3d["pxyS0"]*NANO #[Pa]
            pxzS0 = data3d["pxzS0"]*NANO #[Pa]
            pyyS0 = data3d["pyyS0"]*NANO #[Pa]
            pyzS0 = data3d["pyzS0"]*NANO #[Pa]
            pzzS0 = data3d["pzzS0"]*NANO #[Pa]

            # Pressure tensor components, species S1
            pxxS1 = data3d["pxxS1"]*NANO #[Pa]
            pxyS1 = data3d["pxyS1"]*NANO #[Pa]
            pxzS1 = data3d["pxzS1"]*NANO #[Pa]
            pyyS1 = data3d["pyyS1"]*NANO #[Pa]
            pyzS1 = data3d["pyzS1"]*NANO #[Pa]
            pzzS1 = data3d["pzzS1"]*NANO #[Pa]

            # Gradient of |B|
            dB_dx = data3d["dB_dx"]*NANO # T/m
            dB_dy = data3d["dB_dy"]*NANO # T/m
            dB_dz = data3d["dB_dz"]*NANO # T/m

            # S1 (ion) velocity and gradients of its x and z components
            uix = data3d["uxS1"]*KILO # [m/s]
            uiy = data3d["uyS1"]*KILO # [m/s]
            uiz = data3d["uzS1"]*KILO # [m/s]
            duix_dx = data3d["duix_dx"]*KILO # [1/s]
            duix_dy = data3d["duix_dy"]*KILO # [1/s]
            duix_dz = data3d["duix_dz"]*KILO # [1/s]
            duiz_dx = data3d["duiz_dx"]*KILO # [1/s]
            duiz_dy = data3d["duiz_dy"]*KILO # [1/s]
            duiz_dz = data3d["duiz_dz"]*KILO # [1/s]

            # S0 (electron) velocity and gradients of its x and z components
            uex = data3d["uxS0"]*KILO # [m/s]
            uey = data3d["uyS0"]*KILO # [m/s]
            uez = data3d["uzS0"]*KILO # [m/s]
            duex_dx = data3d["duex_dx"]*KILO # [1/s]
            duex_dy = data3d["duex_dy"]*KILO # [1/s]
            duex_dz = data3d["duex_dz"]*KILO # [1/s]
            duez_dx = data3d["duez_dx"]*KILO # [1/s]
            duez_dy = data3d["duez_dy"]*KILO # [1/s]
            duez_dz = data3d["duez_dz"]*KILO # [1/s]
    
            # Time derivatives of the S1 (ion) and S0 (electron) x/z velocities, scaled by 1e3 to [m/s^2]
            dts = compute_dt(["uxS1","uzS1","uxS0","uzS0"],time,type='numpy',path=dir_3D)
            duix_dt, duiz_dt, duex_dt, duez_dt = (dts[key]*1e3 for key in ("uxS1", "uzS1", "uxS0", "uzS0"))

            # Time-averaged fields and currents (disabled):
            #avgs = average_value(["Bx","By","Bz","Jx","Jy","Jz"],float(time),-10,10,dt = 5.0,type='numpy',path=dir_3D)
            #Bx_avg = avgs["Bx"]*1e-9 # [T]
            #By_avg = avgs["By"]*1e-9 # [T]
            #Bz_avg = avgs["Bz"]*1e-9 # [T]
            #Jx_avg = avgs["Jx"]
            #Jy_avg = avgs["Jy"]
            #Jz_avg = avgs["Jz"]

            # Isotropic diamagnetic current (disabled; the anisotropic form is computed later in this preset):
            #Jy3d_dia = (Bz3d*dp_dx3d - Bx3d*dp_dz3d) / (Bx3d**2+By3d**2+Bz3d**2)
    
            # Inertial currents: Jy = rho/B^2 * (Bz*(Du/Dt)_x - Bx*(Du/Dt)_z), with the convective
            # derivative Du/Dt = du/dt + (u . grad) u; the electron term uses rho/mi_me
            Bmag = np.sqrt(Bx3d**2+By3d**2+Bz3d**2) # T
            Dui_x = duix_dt + uix*duix_dx + uiy*duix_dy + uiz*duix_dz
            Dui_z = duiz_dt + uix*duiz_dx + uiy*duiz_dy + uiz*duiz_dz
            Due_x = duex_dt + uex*duex_dx + uey*duex_dy + uez*duex_dz
            Due_z = duez_dt + uex*duez_dx + uey*duez_dy + uez*duez_dz
            Jy3d_inrt_i = rho / Bmag**2 * (Bz3d*Dui_x - Bx3d*Dui_z) # [A/m^2]
            Jy3d_inrt_e = rho/mi_me / Bmag**2 * (Bz3d*Due_x - Bx3d*Due_z) # [A/m^2]

            # Field-aligned current and the y-component of the perpendicular remainder
            Jpara = (Jx3d*Bx3d + Jy3d*By3d + Jz3d*Bz3d)/Bmag
            Jpara_y = Jpara * By3d/Bmag
            Jperp_y = Jy3d - Jpara_y
    
            # Compute average FAC (disabled; needs the averaging block earlier in this preset)
            #Bmag_avg = np.sqrt(Bx_avg**2+By_avg**2+Bz_avg**2) # T
            #Jpara_avg= (Jx_avg*Bx_avg+Jy_avg*By_avg+Jz_avg*Bz_avg)/Bmag_avg
            #Jpara_y_avg = Jpara_avg * By_avg/Bmag_avg
            #Jperp_y_avg = Jy_avg - Jpara_y_avg

            # --- Grad-curvature drift current ---
            # Magnetic unit vector b = B/|B|
            inv_Bmag = 1/Bmag
            bx, by, bz = inv_Bmag*Bx3d, inv_Bmag*By3d, inv_Bmag*Bz3d

            # Field-aligned frame (u, v, b): u along E x B (normalised), v = u x b.
            # See back page of space physics 2024 notebook for derivation, or photos from sep. 4, 2024.
            # NOTE: cells where E x B = 0 give NaN here.
            ux, uy, uz = Ey3d*Bz3d-Ez3d*By3d, Ez3d*Bx3d-Ex3d*Bz3d, Ex3d*By3d-Ey3d*Bx3d
            inv_u = 1/np.sqrt(ux**2+uy**2+uz**2)
            ux, uy, uz = inv_u*ux, inv_u*uy, inv_u*uz

            # v is already unit length since u and b are orthonormal
            vx, vy, vz = uy*bz-uz*by, uz*bx-ux*bz, ux*by-uy*bx

            # Diagonal pressure components in (u, v, b), e.g. P_para = b . P . b
            P11S0 = ux*(pxxS0*ux+pxyS0*uy+pxzS0*uz) + uy*(pxyS0*ux+pyyS0*uy+pyzS0*uz) + uz*(pxzS0*ux+pyzS0*uy+pzzS0*uz)
            P22S0 = vx*(pxxS0*vx+pxyS0*vy+pxzS0*vz) + vy*(pxyS0*vx+pyyS0*vy+pyzS0*vz) + vz*(pxzS0*vx+pyzS0*vy+pzzS0*vz)
            P33S0 = bx*(pxxS0*bx+pxyS0*by+pxzS0*bz) + by*(pxyS0*bx+pyyS0*by+pyzS0*bz) + bz*(pxzS0*bx+pyzS0*by+pzzS0*bz)
            P11S1 = ux*(pxxS1*ux+pxyS1*uy+pxzS1*uz) + uy*(pxyS1*ux+pyyS1*uy+pyzS1*uz) + uz*(pxzS1*ux+pyzS1*uy+pzzS1*uz)
            P22S1 = vx*(pxxS1*vx+pxyS1*vy+pxzS1*vz) + vy*(pxyS1*vx+pyyS1*vy+pyzS1*vz) + vz*(pxzS1*vx+pyzS1*vy+pzzS1*vz)
            P33S1 = bx*(pxxS1*bx+pxyS1*by+pxzS1*bz) + by*(pxyS1*bx+pyyS1*by+pyzS1*bz) + bz*(pxzS1*bx+pyzS1*by+pzzS1*bz)

            # Total (both species) perpendicular and parallel pressures
            p_perp = (P11S0+P22S0+P11S1+P22S1)/2
            p_para = P33S0 + P33S1

            # Perpendicular/parallel temperatures, using the S1 density for both species.
            # See Gurnett+Bhattacharjee "Introduction to Plasma Physics: Kinetic Theory and the Moment Equations" p151
            T_perp_S0 = (P11S0+P22S0)/2/n/k_b # [K]
            T_para_S0 = (P33S0)/n/k_b # [K]
            T_perp_S1 = (P11S1+P22S1)/2/n/k_b # [K]
            T_para_S1 = (P33S1)/n/k_b # [K]

            # Mean thermal speeds sqrt(8 k T / (pi m)); S0 uses mass amu/mi_me
            v_perp_th_S0 = np.sqrt(8*k_b*T_perp_S0/(np.pi*amu/mi_me)) #m/s
            v_para_th_S0 = np.sqrt(8*k_b*T_para_S0/(np.pi*amu/mi_me)) #m/s
            v_perp_th_S1 = np.sqrt(8*k_b*T_perp_S1/(np.pi*amu)) #m/s
            v_para_th_S1 = np.sqrt(8*k_b*T_para_S1/(np.pi*amu)) #m/s

            # Grad-curvature drift (y component), m/s; S0 carries charge -e
            v_gc_y_S0 = -1 * amu/mi_me*(0.5*v_perp_th_S0**2+v_para_th_S0**2)/e*(Bz3d*dB_dx-Bx3d*dB_dz)/Bmag**3
            v_gc_y_S1 = amu*(0.5*v_perp_th_S1**2+v_para_th_S1**2)/e*(Bz3d*dB_dx-Bx3d*dB_dz)/Bmag**3

            # Compute gc current
            Jy3d_gc = e*n*(v_gc_y_S1-v_gc_y_S0)

            # Full diamagnetic current (y component), anisotropic-pressure form:
            #   J_dia = (B / B^2) x [ grad(p_perp) - (p_perp - p_para) (B.grad)B / B^2 ]
            # (B.grad)B / B^2 has units 1/m, so the curvature term has units Pa/m, matching
            # grad(p_perp). This previously divided by |B| only once, leaving the term
            # dimensionally inconsistent (off by a factor of |B|).
            aniso_curv = (p_perp-p_para)/Bmag**2
            Jy3d_dia = (1/Bmag**2) * (Bz3d * (dp_perp_dx - aniso_curv*(Bx3d*dBx_dx+By3d*dBx_dy+Bz3d*dBx_dz)) -
                                      Bx3d * (dp_perp_dz - aniso_curv*(Bx3d*dBz_dx+By3d*dBz_dy+Bz3d*dBz_dz)))
            
            # Pick the y-slice: through the DF centroid when a DF mask was loaded, else at loc[1]
            yplane = np.mean(Y3d[DF_mask]) if read_DF else loc[1]
            if read_DF:
                # Number of DF cells along axis 0 for each (axis 1, axis 2) position
                DF_shadow = np.sum(DF_mask,axis=0)
            # First index along axis 0 with Y strictly greater than yplane (IndexError if none)
            iy = np.flatnonzero(Y3d[:,0,0] > yplane)[0]
    
            # Declare figure: six stacked x-z panels on the slice iy
            fig,axs = plt.subplots(figsize=(9,24), nrows=6, constrained_layout=True)
            cmap = "PRGn"
            J_lims = [-350,350]
            extent_xz = [np.min(X3d),np.max(X3d),np.min(Z3d),np.max(Z3d)]
            J_style = dict(origin='lower',vmin=J_lims[0],vmax=J_lims[1],cmap=cmap,extent=extent_xz)

            # Panels 0-4 in nA/m^2: Jy, J_perp,y, J_dia, J_inrt,i, and the residual
            # J_perp,y - J_dia - J_inrt,i (computed on the slice only)
            plot0 = axs[0].imshow(Jy3d[iy,:,:].T*1e9, **J_style)
            plot1 = axs[1].imshow(Jperp_y[iy,:,:].T*1e9, **J_style)
            axs[2].imshow(Jy3d_dia[iy,:,:].T*1e9, **J_style)
            axs[3].imshow(Jy3d_inrt_i[iy,:,:].T*1e9, **J_style)
            plot4 = axs[4].imshow((Jperp_y[iy,:,:]-Jy3d_dia[iy,:,:]-Jy3d_inrt_i[iy,:,:]).T*1e9, **J_style)
            #plot4 = axs[4].imshow((Jperp_y-Jperp_y_avg)[iy,:,:].T*1e9, origin='lower',vmin=-500,vmax=500,cmap='bwr',extent = [np.min(X3d),np.max(X3d),np.min(Z3d),np.max(Z3d)])

            # Panel 5: log10 of the diamagnetic-to-inertial magnitude ratio
            plot5 = axs[5].imshow(np.log10(np.abs(Jy3d_dia[iy,:,:])/np.abs(Jy3d_inrt_i[iy,:,:])).T, cmap = 'rainbow',
                                  origin='lower',vmin=-1,vmax=1,extent = extent_xz)

            # Quiver plot (disabled)
            #qskip = 3
            #vnorm = 200*1e3 # [km/s]
            #axs[4].quiver(X3d[iy,::qskip,::qskip],Z3d[iy,::qskip,::qskip],
            #      -uix[iy,::qskip,::qskip]/vnorm,uiz[iy,::qskip,::qskip]/vnorm,color='red',scale = 60)

            # Context overlays: B-field streamlines on an evenly spaced grid (as streamplot requires)
            # and, if available, the DF outline
            Xgrid,Zgrid = np.meshgrid(np.linspace(X3d[iy,0,0],X3d[iy,-1,0],len(X3d[iy,:,0])),
                        np.linspace(Z3d[iy,0,0],Z3d[iy,0,-1],len(Z3d[iy,0,:])))
            for axi in axs:
                if read_DF:
                    # contour takes `colors` (plural); the old `color=` kwarg was not the documented parameter
                    axi.contour(X3d[iy,:,:],Z3d[iy,:,:],DF_shadow,[0.5],colors='black')
                axi.streamplot(Xgrid,Zgrid,Bx3d[iy,:,:].T,Bz3d[iy,:,:].T, broken_streamlines=False, linewidth=0.25,arrowsize=0.25,color='darkblue',
                              density = 2)

            # Tidy axes (x axis reversed)
            for ax in axs:
                ax.set_aspect(1)
                ax.set_xlim(x_region[1],x_region[0])
                ax.set_ylim(z_region)
                ax.set_xlabel("X [$R_M$]")
                ax.set_ylabel("Z [$R_M$]")
                ax.grid()

            # Colorbars: one shared by the five current panels, one for the ratio panel
            clb1 = fig.colorbar(plot0,ax=axs[:5],shrink=0.1)
            clb1.ax.set_title("$J_y$ [nA/m$^2$]")
            clb2 = fig.colorbar(plot5,ax=axs[5],shrink=0.6)
            clb2.ax.set_title("log$_{10}(J_{dia}/J_{inrt})$")
            #clb3 = fig.colorbar(plot4,ax=axs[4:],shrink=0.5)
            #clb3.ax.set_title("Residual")

            # Title
            axs[0].set_title(str("Y = "+str(round(yplane,2))+"\nTotal current at t = "+time+"s"))
            axs[1].set_title(str("Perpendicular current at t = "+time+"s"))
            axs[2].set_title(str("Diamagnetic current at t = "+time+"s"))
            axs[3].set_title(str("Ion inertial current at t = "+time+"s"))
            axs[4].set_title(str("$J_{perp} - J_{dia} - J_{inrt}$ at t = "+time+"s"))
            #axs[4].set_title(str("$\Delta J_{perp,y}$ at t = "+time+"s"))
            axs[5].set_title(str("Diamagnetic:inertial at t = "+time+"s"))
    
            # Save data for timeseries plotting
            if read_DF:
                sum_mask = X3d[iy,:,:] > np.min(X3d[DF_mask])
                #axs[0].contourf(X3d[iy,:,:],Z3d[iy,:,:],sum_mask,[0.5,10],colors=['green'],alpha=0.5)
            else:
                sum_mask = X3d[iy,:,:] > -1.75

            t_ls.append(float(time))
            Jy_ls.append(np.nansum(Jy3d[iy,:,:][sum_mask]*(R_M/64)**2))
            Jperp_y_ls.append(np.nansum(Jperp_y[iy,:,:][sum_mask]*(R_M/64)**2))
            Jy_dia_ls.append(np.nansum(Jy3d_dia[iy,:,:][sum_mask]*(R_M/64)**2))
            Jy_inrt_i_ls.append(np.nansum(Jy3d_inrt_i[iy,:,:][sum_mask]*(R_M/64)**2))
            #delta_Jperp_y_ls.append(np.nansum((Jperp_y-Jperp_y_avg)[iy,:,:][sum_mask]*(R_M/64)**2))

    if plot_preset == "deltaBz_xz":
        # xz slice of the Bz perturbation relative to `background_field`, with field lines and
        # ion bulk-velocity quivers overlaid.

        if read_data:
            
            # Unpack data (arrays are indexed [iy, ix, iz])
            X3d = data3d["X"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Y3d = data3d["Y"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Z3d = data3d["Z"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Bx3d = data3d["Bx"] # [nT]
            By3d = data3d["By"] # [nT]
            Bz3d = data3d["Bz"] # [nT]
            uix = data3d["uxS1"] # [km/s]
            uiy = data3d["uyS1"] # [km/s]
            uiz = data3d["uzS1"] # [km/s]
            uex = data3d["uxS0"] # [km/s]
            uey = data3d["uyS0"] # [km/s]
            uez = data3d["uzS0"] # [km/s]

            # Compute average
            #Bz_avg = average_value(["Bz"],float(time),-5,0,dt = 0.25,type='numpy',path=dir_3D)["Bz"]
            
            # Compute delta Bz relative to the externally supplied background field
            deltaBz = Bz3d - background_field["Bz"] # Bz_avg

            yplane = loc[1] #np.mean(Y)
            # Index of the first Y grid plane strictly above yplane
            iy = np.where((Y3d[:,0,0]>yplane))[0][0]
    
        # Declare figure
        fig,ax = plt.subplots(figsize=(9,7), nrows=1, constrained_layout=True)

        B_lims = [-20,20]

        # Plot delta Bz
        plot0 = ax.imshow(deltaBz[iy,:,:].T, origin='lower',vmin=B_lims[0],vmax=B_lims[1],cmap='bwr',extent = [np.min(X3d),np.max(X3d),np.min(Z3d),np.max(Z3d)])

        # Add streamplot for field lines (grid shape (nz, nx) matches the transposed B slices)
        Xgrid,Zgrid = np.meshgrid(np.linspace(X3d[iy,0,0],X3d[iy,-1,0],len(X3d[iy,:,0])),
                    np.linspace(Z3d[iy,0,0],Z3d[iy,0,-1],len(Z3d[iy,0,:])))
        ax.streamplot(Xgrid,Zgrid,Bx3d[iy,:,:].T,Bz3d[iy,:,:].T, broken_streamlines=False, linewidth=0.45,arrowsize=0.45,color='black',
                     density = 2.5)
        
        # Show DF cross section
        #ax.contourf(X3d[iy,:,:],Z3d[iy,:,:],DF_mask[iy,:,:],[0.999,1],colors='green',alpha=0.5)
        #ax.contour(X3d[iy,:,:],Z3d[iy,:,:],DF_mask[iy,:,:],[0.999],colors='green',linewidths=0.5)

        # Show ion velocity quivers, normalised by vnorm. The x component is negated because the
        # x-axis limits are set in reverse order below.
        qskip = 3
        vnorm = 200 # [km/s]
        ax.quiver(X3d[iy,::qskip,::qskip],Z3d[iy,::qskip,::qskip],
                  -uix[iy,::qskip,::qskip]/vnorm,uiz[iy,::qskip,::qskip]/vnorm,color='purple',scale = 60)
        
        # Tidy axes
        ax.set_aspect(1)
        ax.set_xlim(x_region[1],x_region[0])
        ax.set_ylim(z_region)
        ax.set_xlabel("X [$R_M$]")
        ax.set_ylabel("Z [$R_M$]")
        ax.grid()

        # Colorbar (backslash escaped explicitly: "\D" is an invalid escape sequence)
        clb1 = fig.colorbar(plot0,ax=ax,shrink=0.7)
        clb1.ax.set_title("$\\Delta B_z$ [nT]")

        # Title
        ax.set_title(str("Y = "+str(round(yplane,2))+"\n$\\Delta B_z$, fieldlines, bulk velocity, and DFs at t = "+time+"s"))

    if plot_preset == "p_xz":
        # xz slice of p_tot (sum of diagonal pressure-tensor elements of both species) with field lines.

        if read_data:
            
            # Unpack data (arrays are indexed [iy, ix, iz])
            X3d = data3d["X"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Y3d = data3d["Y"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Z3d = data3d["Z"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Bx3d = data3d["Bx"] # [nT]
            By3d = data3d["By"] # [nT]
            Bz3d = data3d["Bz"] # [nT]
            pxxS0 = data3d["pxxS0"] #[nPa]
            pyyS0 = data3d["pyyS0"] #[nPa]
            pzzS0 = data3d["pzzS0"] #[nPa]
            pxxS1 = data3d["pxxS1"] #[nPa]
            pyyS1 = data3d["pyyS1"] #[nPa]
            pzzS1 = data3d["pzzS1"] #[nPa]
            uix = data3d["uxS1"] # [km/s]
            uiy = data3d["uyS1"] # [km/s]
            uiz = data3d["uzS1"] # [km/s]
            uex = data3d["uxS0"] # [km/s]
            uey = data3d["uyS0"] # [km/s]
            uez = data3d["uzS0"] # [km/s]

            # Compute average
            #Bz_avg = average_value(["Bz"],float(time),-5,0,dt = 0.25,type='numpy',path=dir_3D)["Bz"]
            
            # Compute p_tot
            # NOTE(review): this is the trace of both pressure tensors, i.e. 3x the total scalar pressure
            # (the temperature calculation in the xz_slice preset divides the trace by 3) — confirm intent.
            p_tot = pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1

            yplane = loc[1] #np.mean(Y)
            # Index of the first Y grid plane strictly above yplane
            iy = np.where((Y3d[:,0,0]>yplane))[0][0]
    
        # Declare figure
        fig,ax = plt.subplots(figsize=(9,7), nrows=1, constrained_layout=True)

        p_lims = [0,5]

        # Plot total pressure
        plot0 = ax.imshow(p_tot[iy,:,:].T, origin='lower',vmin=p_lims[0],vmax=p_lims[1],cmap='rainbow',extent = [np.min(X3d),np.max(X3d),np.min(Z3d),np.max(Z3d)])

        # Add streamplot for field lines (grid shape (nz, nx) matches the transposed B slices)
        Xgrid,Zgrid = np.meshgrid(np.linspace(X3d[iy,0,0],X3d[iy,-1,0],len(X3d[iy,:,0])),
                    np.linspace(Z3d[iy,0,0],Z3d[iy,0,-1],len(Z3d[iy,0,:])))
        ax.streamplot(Xgrid,Zgrid,Bx3d[iy,:,:].T,Bz3d[iy,:,:].T, broken_streamlines=False, linewidth=0.45,arrowsize=0.45,color='black',
                     density = 2.5)
        
        # Show DF cross section
        #ax.contourf(X3d[iy,:,:],Z3d[iy,:,:],DF_mask[iy,:,:],[0.999,1],colors='green',alpha=0.5)
        #ax.contour(X3d[iy,:,:],Z3d[iy,:,:],DF_mask[iy,:,:],[0.999],colors='green',linewidths=0.5)

        # Show velocity quivers
        #qskip = 3
        #vnorm = 200 # [km/s]
        #ax.quiver(X3d[iy,::qskip,::qskip],Z3d[iy,::qskip,::qskip],
        #          -uix[iy,::qskip,::qskip]/vnorm,uiz[iy,::qskip,::qskip]/vnorm,color='purple',scale = 60)
        
        # Tidy axes (x limits given in reverse order)
        ax.set_aspect(1)
        ax.set_xlim(x_region[1],x_region[0])
        ax.set_ylim(z_region)
        ax.set_xlabel("X [$R_M$]")
        ax.set_ylabel("Z [$R_M$]")
        ax.grid()

        # Colorbar
        clb1 = fig.colorbar(plot0,ax=ax,shrink=0.7)
        clb1.ax.set_title("$p_{tot}$ [nPa]")

        # Title
        ax.set_title(str("Y = "+str(round(yplane,2))+"\nPlasma pressure and fieldlines at t = "+time+"s"))

    if plot_preset == "xz_slice":

        # Plot-mode registry. Each dict configures one figure produced by the loop below:
        # "plot_var" (branch key, also used in titles and file names), subplot grid, colour limits,
        # units label, colormap, and whether to overlay B field lines / electron-velocity quivers.
        B_dict = {"plot_var":str("B"),"nrows":3,"ncols":1,"clims":[-50,50],"units": str("nT"),"cmap":"bwr","fieldline":True,"vector":False}
        pi_dict = {"plot_var":str("p_i"),"nrows":3,"ncols":1,"clims":[-2,0.5],"units": str("log(nPa)"),"cmap":"viridis","fieldline":True,"vector":False}
        pe_dict = {"plot_var":str("p_e"),"nrows":3,"ncols":1,"clims":[-2,0.5],"units": str("log(nPa)"),"cmap":"viridis","fieldline":True,"vector":False}
        ui_dict = {"plot_var":str("u_i"),"nrows":3,"ncols":1,"clims":[-1000,1000],"units": str("km/s"),"cmap":"bwr","fieldline":True,"vector":False}
        ue_dict = {"plot_var":str("u_e"),"nrows":3,"ncols":1,"clims":[-2000,2000],"units": str("km/s"),"cmap":"bwr","fieldline":True,"vector":False}
        JxB_dict = {"plot_var":str("JxB"),"nrows":3,"ncols":1,"clims":[-1,1],"units": str("uN/m^3"),"cmap":"rainbow","fieldline":True,"vector":False}
        recon_dict = {"plot_var":str("u_{e,mag}"),"nrows":1,"ncols":1,"clims":[-2,2],"units": str("log($u_{e,para}/u_{e_perp}$)"),"cmap":"coolwarm","fieldline":True,"vector":True}
        T_dict = {"plot_var":str("T"),"nrows":2,"ncols":1,"clims":[-1,1],"units": str("log(keV)"),"cmap":"viridis","fieldline":True,"vector":False}

        # Figures to produce (JxB_dict is defined but not currently selected)
        plot_variables = [B_dict,pi_dict,pe_dict,ui_dict,ue_dict,recon_dict,T_dict]

        if read_data:
            
            # Unpack data (arrays are indexed [iy, ix, iz])
            X3d = data3d["X"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Y3d = data3d["Y"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Z3d = data3d["Z"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Bx3d = data3d["Bx"] # [nT]
            By3d = data3d["By"] # [nT]
            Bz3d = data3d["Bz"] # [nT]
            Ex3d = data3d["Ex"] # 
            Ey3d = data3d["Ey"] # 
            Ez3d = data3d["Ez"] # 
            rho = data3d["rhoS1"] # amu / cc
            # Number density per m^3 — assumes one amu per ion (protons); TODO confirm
            n = rho*1e6
            pxxS0 = data3d["pxxS0"] #[nPa]
            pxyS0 = data3d["pxyS0"] #[nPa]
            pxzS0 = data3d["pxzS0"] #[nPa]
            pyyS0 = data3d["pyyS0"] #[nPa]
            pyzS0 = data3d["pyzS0"] #[nPa]
            pzzS0 = data3d["pzzS0"] #[nPa]
            pxxS1 = data3d["pxxS1"] #[nPa]
            pxyS1 = data3d["pxyS1"] #[nPa]
            pxzS1 = data3d["pxzS1"] #[nPa]
            pyyS1 = data3d["pyyS1"] #[nPa]
            pyzS1 = data3d["pyzS1"] #[nPa]
            pzzS1 = data3d["pzzS1"] #[nPa]
            uix = data3d["uxS1"] # [km/s]
            uiy = data3d["uyS1"] # [km/s]
            uiz = data3d["uzS1"] # [km/s]
            uex = data3d["uxS0"] # [km/s]
            uey = data3d["uyS0"] # [km/s]
            uez = data3d["uzS0"] # [km/s]
            # Current density from the ion/electron velocity difference (km/s -> m/s via *1e3)
            Jx3d = (e*(data3d["rhoS1"]*1e6)*(data3d["uxS1"]-data3d["uxS0"])*1e3) # [A/m^2]
            Jy3d = (e*(data3d["rhoS1"]*1e6)*(data3d["uyS1"]-data3d["uyS0"])*1e3) # [A/m^2]
            Jz3d = (e*(data3d["rhoS1"]*1e6)*(data3d["uzS1"]-data3d["uzS0"])*1e3) # [A/m^2]
            dB_dx = data3d["dB_dx"]*1e-9 # T/m
            dB_dy = data3d["dB_dy"]*1e-9 # T/m
            dB_dz = data3d["dB_dz"]*1e-9 # T/m
            if recon_dict in plot_variables:
                # The full B Jacobian is only needed by the reconnection-site score (u_{e,mag} figure)
                if "dBx_dx" in data3d.keys():
                    dBx_dx = data3d["dBx_dx"]*1e-9  #[T/m]
                    dBx_dy = data3d["dBx_dy"]*1e-9  #[T/m]
                    dBx_dz = data3d["dBx_dz"]*1e-9  #[T/m]
                    dBy_dx = data3d["dBy_dx"]*1e-9  #[T/m]
                    dBy_dy = data3d["dBy_dy"]*1e-9  #[T/m]
                    dBy_dz = data3d["dBy_dz"]*1e-9  #[T/m]
                    dBz_dx = data3d["dBz_dx"]*1e-9  #[T/m]
                    dBz_dy = data3d["dBz_dy"]*1e-9  #[T/m]
                    dBz_dz = data3d["dBz_dz"]*1e-9  #[T/m]
                else:
                    # Jacobian of magnetic field vector field was added late, not all numpy files have it.
                    # Central differences with the same spacing dx on every axis (0 = Y, 1 = X, 2 = Z).
                    print("Computing dBa_da jacobian")
                    dBx_dx = np.gradient(Bx3d*1e-9,dx,axis=1) #[T/m]
                    dBx_dy = np.gradient(Bx3d*1e-9,dx,axis=0) #[T/m]
                    dBx_dz = np.gradient(Bx3d*1e-9,dx,axis=2) #[T/m]
                    dBy_dx = np.gradient(By3d*1e-9,dx,axis=1) #[T/m]
                    dBy_dy = np.gradient(By3d*1e-9,dx,axis=0) #[T/m]
                    dBy_dz = np.gradient(By3d*1e-9,dx,axis=2) #[T/m]
                    dBz_dx = np.gradient(Bz3d*1e-9,dx,axis=1) #[T/m]
                    dBz_dy = np.gradient(Bz3d*1e-9,dx,axis=0) #[T/m]
                    dBz_dz = np.gradient(Bz3d*1e-9,dx,axis=2) #[T/m]
            
            # Compute p_tot
            # NOTE(review): trace of both pressure tensors, i.e. 3x the total scalar pressure — confirm intent.
            p_tot = pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1

            yplane = loc[1] #np.mean(Y)
            # Index of the first Y grid plane strictly above yplane
            iy = np.where((Y3d[:,0,0]>yplane))[0][0]

        # imshow extent shared by every panel; computed once instead of a full 3D min/max per panel
        slice_extent = [np.min(X3d),np.max(X3d),np.min(Z3d),np.max(Z3d)]

        for iplot, plot_dict in enumerate(plot_variables):
    
            # Declare figure
            fig,axs = plt.subplots(figsize=(8*plot_dict["ncols"],5*plot_dict["nrows"]), ncols=plot_dict["ncols"], nrows=plot_dict["nrows"], 
                                  constrained_layout=True)

            # Unpack plot controls
            clims = plot_dict["clims"]
            plot_var = plot_dict["plot_var"]
            cmap = plot_dict["cmap"]

            # Title (can be overwritten). atleast_1d handles the single-Axes (nrows=ncols=1) case.
            np.atleast_1d(axs).flat[0].set_title(str("Y = "+str(round(yplane,2))+"\n$"+plot_var+"$ at t = "+time+"s"))

            # Show the data
            if plot_var == "B":
                plot0 = axs[0].imshow(Bx3d[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot1 = axs[1].imshow(By3d[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot2 = axs[2].imshow(Bz3d[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)

                clb = fig.colorbar(plot0,ax=axs[:],shrink=0.2)
                clb.ax.set_title(str("$"+plot_var+"$ ["+plot_dict["units"]+"]"))

            if plot_var == "p_i":
                # Ion (S1) perpendicular / parallel / trace pressure, log10 scale
                p_perp1,p_perp2,p_para = compute_para_perp(Bx3d,By3d,Bz3d,Ex3d,Ey3d,Ez3d,pxxS1,pyyS1,pzzS1,pxyS1,pxzS1,pyzS1)
                plot0 = axs[0].imshow(np.log10(p_perp1+p_perp2)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot1 = axs[1].imshow(np.log10(p_para)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot2 = axs[2].imshow(np.log10(pxxS1+pyyS1+pzzS1)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)

                # Overwrite title
                axs[0].set_title(str("Y = "+str(round(yplane,2))+"\n$p_{i,perp}$ at t = "+time+"s"))
                axs[1].set_title(str("$p_{i,para}$"))
                axs[2].set_title(str("$p_{i,tot}$"))

                clb = fig.colorbar(plot0,ax=axs[:],shrink=0.2)
                clb.ax.set_title(str("$"+plot_var+"$ ["+plot_dict["units"]+"]"))

            if plot_var == "p_e":
                # Electron (S0) perpendicular / parallel / trace pressure, log10 scale.
                # The total panel must use the electron (S0) components, not the ion (S1) ones.
                p_perp1,p_perp2,p_para = compute_para_perp(Bx3d,By3d,Bz3d,Ex3d,Ey3d,Ez3d,pxxS0,pyyS0,pzzS0,pxyS0,pxzS0,pyzS0)
                plot0 = axs[0].imshow(np.log10(p_perp1+p_perp2)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot1 = axs[1].imshow(np.log10(p_para)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot2 = axs[2].imshow(np.log10(pxxS0+pyyS0+pzzS0)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)

                # Overwrite title
                axs[0].set_title(str("Y = "+str(round(yplane,2))+"\n$p_{e,perp}$ at t = "+time+"s"))
                axs[1].set_title(str("$p_{e,para}$"))
                axs[2].set_title(str("$p_{e,tot}$"))

                clb = fig.colorbar(plot0,ax=axs[:],shrink=0.2)
                clb.ax.set_title(str("$"+plot_var+"$ ["+plot_dict["units"]+"]"))
            
            if plot_var == "u_i":
                plot0 = axs[0].imshow(uix[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot1 = axs[1].imshow(uiy[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot2 = axs[2].imshow(uiz[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)

                clb = fig.colorbar(plot0,ax=axs[:],shrink=0.2)
                clb.ax.set_title(str("$"+plot_var+"$ ["+plot_dict["units"]+"]"))

            if plot_var == "u_e":
                plot0 = axs[0].imshow(uex[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot1 = axs[1].imshow(uey[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot2 = axs[2].imshow(uez[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)

                clb = fig.colorbar(plot0,ax=axs[:],shrink=0.2)
                clb.ax.set_title(str("$"+plot_var+"$ ["+plot_dict["units"]+"]"))

            if plot_var == "JxB":
                # Lorentz force density components (J in A/m^2, B in nT), scaled by 1e6
                JxB_x = (Jy3d*Bz3d - Jz3d*By3d)*1e6
                JxB_y = (Jz3d*Bx3d - Jx3d*Bz3d)*1e6
                JxB_z = (Jx3d*By3d - Jy3d*Bx3d)*1e6
                plot0 = axs[0].imshow(JxB_x[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot1 = axs[1].imshow(JxB_y[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot2 = axs[2].imshow(JxB_z[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)

                clb = fig.colorbar(plot0,ax=axs[:],shrink=0.2)
                clb.ax.set_title(str("$"+plot_var+"$ ["+plot_dict["units"]+"]"))

            if plot_var == "T":
                # Scalar temperature = trace/3 / (n k_b): nPa -> Pa (1e-9), K -> eV (/11605), eV -> keV (/1e3)
                T_i = (pxxS1+pyyS1+pzzS1)/3 * 1e-9 / (n * k_b) / 11605 / 1e3
                T_e = (pxxS0+pyyS0+pzzS0)/3 * 1e-9 / (n * k_b) / 11605 / 1e3
                plot0 = axs[0].imshow(np.log10(T_i)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                plot1 = axs[1].imshow(np.log10(T_e)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)

                # Overwrite title
                axs[0].set_title(str("Y = "+str(round(yplane,2))+"\n$T_{i}$ at t = "+time+"s"))
                axs[1].set_title(str("$T_{e}$"))

                clb = fig.colorbar(plot0,ax=axs[:],shrink=0.2)
                clb.ax.set_title(str("$"+plot_var+"$ ["+plot_dict["units"]+"]"))

            if plot_var == "u_{e,mag}":
                # Compute perp/para pressure
                # NOTE(review): perp_para is not used by this figure
                p_perp1,p_perp2,p_para = compute_para_perp(Bx3d,By3d,Bz3d,Ex3d,Ey3d,Ez3d,pxxS0,pyyS0,pzzS0,pxyS0,pxzS0,pyzS0)
                perp_para = (p_perp1+p_perp2)/p_para
                # Compute reconnection sites
                recon_sites = compute_AEPIC_recon_score(Bx3d*1e-9,By3d*1e-9,Bz3d*1e-9,dB_dx,dB_dy,dB_dz,
                                                        dBx_dx,dBx_dy,dBx_dz,dBy_dx,dBy_dy,dBy_dz,dBz_dx,dBz_dy,dBz_dz,Jx3d,Jy3d,Jz3d,
                                                        c1_min = 0.005, c2_min = 5e-7)
                # Compute vperp/vpara: project u_e onto the unit field direction b = B/|B|
                B_mag = np.sqrt(Bx3d**2+By3d**2+Bz3d**2) # [nT]
                bx, by, bz = Bx3d/B_mag, By3d/B_mag, Bz3d/B_mag
                ue_para = np.abs(uex*bx+uey*by+uez*bz)
                ue_mag = np.sqrt(uex**2+uey**2+uez**2)
                # |u|^2 = u_para^2 + u_perp^2, so the perpendicular speed is sqrt(|u|^2 - u_para^2)
                # (clipped at 0 against round-off); |u| - |u_para| is not the perpendicular component.
                ue_perp = np.sqrt(np.maximum(ue_mag**2 - ue_para**2, 0.0))
                #plot0 = axs.imshow(np.log10(np.sqrt(uex**2+uey**2+uez**2))[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = [np.min(X3d),np.max(X3d),np.min(Z3d),np.max(Z3d)])
                #plot0 = axs.imshow(np.log10((pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1)/3)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = [np.min(X3d),np.max(X3d),np.min(Z3d),np.max(Z3d)])
                plot0 = axs.imshow(np.log10(ue_para/ue_perp)[iy,:,:].T, origin='lower',vmin=clims[0],vmax=clims[1],cmap=cmap,extent = slice_extent)
                # Outline reconnection sites flagged in the three Y planes around iy
                axs.contour(X3d[iy,:,:],Z3d[iy,:,:],np.sum(recon_sites[iy-1:iy+2,:,:],axis=0),[0.99],colors=['magenta'])
                clb = fig.colorbar(plot0,ax=axs,shrink=0.9)
                clb.ax.set_title(str(plot_dict["units"]))
                

            # Add streamlines, if desired
            if plot_dict["fieldline"]:
                # Add streamplot for field lines (grid shape (nz, nx) matches the transposed B slices)
                Xgrid,Zgrid = np.meshgrid(np.linspace(X3d[iy,0,0],X3d[iy,-1,0],len(X3d[iy,:,0])),
                        np.linspace(Z3d[iy,0,0],Z3d[iy,0,-1],len(Z3d[iy,0,:])))
                for axi in np.atleast_1d(axs).flat:
                    axi.streamplot(Xgrid,Zgrid,Bx3d[iy,:,:].T,Bz3d[iy,:,:].T, broken_streamlines=False, linewidth=0.45,arrowsize=0.45,color='black',
                         density = 2.0)

            if plot_dict["vector"]:
                # Add vector quivers for electron flow
                skip=2
                Xgrid,Zgrid = np.meshgrid(np.linspace(X3d[iy,0,0],X3d[iy,-1,0],len(X3d[iy,:,0])),
                        np.linspace(Z3d[iy,0,0],Z3d[iy,0,-1],len(Z3d[iy,0,:])))
                for axi in np.atleast_1d(axs).flat:
                    axi.quiver(Xgrid[::skip,::skip],Zgrid[::skip,::skip],
                               #(uex/np.sqrt(uex**2+uez**2))[iy,::skip,::skip].T,(uez/np.sqrt(uex**2+uez**2))[iy,::skip,::skip].T, 
                               -(uex)[iy,::skip,::skip].T,(uez)[iy,::skip,::skip].T, 
                               color='black')#, scale=6e1) #negative sign to account for axes flip
            
            # Tidy axes (x limits given in reverse order)
            for axi in np.atleast_1d(axs).flat:
                axi.set_aspect(1)
                axi.set_xlim(x_region[1],x_region[0])
                axi.set_ylim(z_region)
                axi.set_xlabel("X [$R_M$]")
                axi.set_ylabel("Z [$R_M$]")
                axi.grid()
    
            # Save to "<dir without trailing char>_plots/" then display and release the figure
            fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+plot_var+"_y="+str(yplane)+"_"+"%.2f"%round(float(time),2)+'.png'),
                        bbox_inches='tight',dpi=300)
            plt.show()
            plt.close(fig)

    if plot_preset == "Adiabaticity_2D":
        # Map flux-tube entropy S (from entropy_map, indexed by footpoint long/mlat) onto an xz slice:
        # trace a field line from every grid cell in x_region/z_region, and for closed lines look up
        # S at the footpoint nearest the traced end.

        if read_data:
            
            # Unpack data (arrays are indexed [iy, ix, iz])
            X3d = data3d["X"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Y3d = data3d["Y"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Z3d = data3d["Z"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            ZMSM = Z3d - 0.2
            RMSM = np.sqrt(X3d**2+Y3d**2+ZMSM**2)
            # Copies: B is NaN-masked below, and masking the arrays stored in data3d in place would
            # silently corrupt them for any later use of data3d.
            Bx3d = data3d["Bx"].copy() # [nT]
            By3d = data3d["By"].copy() # [nT]
            Bz3d = data3d["Bz"].copy() # [nT]
            uix = data3d["uxS1"] # [km/s]
            uiy = data3d["uyS1"] # [km/s]
            uiz = data3d["uzS1"] # [km/s]
            uex = data3d["uxS0"] # [km/s]
            uey = data3d["uyS0"] # [km/s]
            uez = data3d["uzS0"] # [km/s]

            r_proj = entropy_map['R']
            mlat = entropy_map['mlat']
            long = entropy_map['long']
            S = entropy_map['S']
            FAC_array = entropy_map['FAC']

            # Convert the entropy map back into x,y,z
            theta = (90 - mlat)*np.pi/180
            phi = long * np.pi/180
            x_proj = (r_proj) * np.sin(theta)*np.cos(phi)
            y_proj = (r_proj) * np.sin(theta)*np.sin(phi)
            z_proj = (r_proj) * np.cos(theta) + 0.2 # Convert from MSM to MSO
    
            # Set B to nan inside r_proj, to avoid tracing field lines beyond there
            inside_proj = RMSM < r_proj*0.99
            Bx3d[inside_proj] = np.nan
            By3d[inside_proj] = np.nan
            Bz3d[inside_proj] = np.nan

            # Find y plane to plot, and the x/z index window covering x_region and z_region
            yplane = loc[1] #np.mean(Y)
            iy = np.where((Y3d[:,0,0]>yplane))[0][0]
            ixmin = np.where((X3d[0,:,0]>x_region[0]))[0][0]
            ixmax = np.where((X3d[0,:,0]<x_region[1]))[0][-1]
            izmin = np.where((Z3d[0,0,:]>z_region[0]))[0][0]
            izmax = np.where((Z3d[0,0,:]<z_region[1]))[0][-1]

            # Array to save flux tube entropy for each grid cell in xz. Start as a 1d array to match seed list format.
            S_xz = np.zeros_like(np.ravel(X3d[0,ixmin:ixmax,izmin:izmax])) - 1 # -1 is for unclosed field lines
    
            # Set up field line tracing: one seed per grid cell of the xz window
            #seeds = np.array([np.ravel(x_proj[seed_mask]), np.ravel(y_proj[seed_mask]), np.ravel(z_proj[seed_mask])]).T
            seeds = np.array([np.ravel(X3d[iy,ixmin:ixmax,izmin:izmax]), np.ravel(Y3d[iy,ixmin:ixmax,izmin:izmax]), 
                              np.ravel(Z3d[iy,ixmin:ixmax,izmin:izmax])]).T
            tracer,grid = get_tracer(X3d,Y3d,Z3d,Bx3d,By3d,Bz3d)
            print("beginning tracing of",len(seeds),"lines!")
            tracer.trace(seeds, grid)
            print("done!")
        
            for iseed in range(len(seeds)):
    
                # For the purposes of entropy cal, check if its a closed field line, by seeing if both ends are *close to* planet
                trace = tracer.xs[iseed]
                if (np.sum(trace[0,:]**2)<1.5**2) and (np.sum(trace[-1,:]**2)<1.5**2):
                    
                    # Find north hemisphere footpoint coords
                    if trace[0,2] > 0.2:
                        endpoint = trace[0,:]
                    else:
                        endpoint = trace[-1,:]
    
                    # Find long/lat indicies of footpoint (nearest entropy-map node)
                    dist_map = np.sqrt((x_proj-endpoint[0])**2+(y_proj-endpoint[1])**2+(z_proj-endpoint[2])**2)
                    min_index = np.argmin(dist_map)
                    ilong,imlat = np.unravel_index(min_index, dist_map.shape)
    
                    S_xz[iseed] = S[ilong,imlat] 
    
            S_xz = S_xz.reshape(X3d[iy,ixmin:ixmax,izmin:izmax].shape)

        # Declare figure
        fig,ax = plt.subplots(figsize=(9,7), nrows=1, constrained_layout=True)

        S_lims = [0.5,3]
        cmap='Spectral'
        levels = np.linspace(*S_lims, 51)
        interp_method='none'

        # Plot log10 flux-tube entropy; unclosed lines keep the -1 sentinel, which log10 turns into NaN (blank)
        plot0 = ax.imshow(np.log10(S_xz).T, origin='lower',vmin=S_lims[0],vmax=S_lims[1],cmap=cmap,extent = [x_region[0],x_region[1],
                            z_region[0],z_region[1]],interpolation=interp_method)
        #plot0 = ax.contourf(X3d[iy,ixmin:ixmax,izmin:izmax],Z3d[iy,ixmin:ixmax,izmin:izmax],np.log10(S_xz),
        #                    levels=levels,cmap=cmap)

        # Add streamplot for field lines (grid shape (nz, nx) matches the transposed B slices)
        Xgrid,Zgrid = np.meshgrid(np.linspace(X3d[iy,0,0],X3d[iy,-1,0],len(X3d[iy,:,0])),
                    np.linspace(Z3d[iy,0,0],Z3d[iy,0,-1],len(Z3d[iy,0,:])))
        ax.streamplot(Xgrid,Zgrid,Bx3d[iy,:,:].T,Bz3d[iy,:,:].T, broken_streamlines=False, linewidth=0.45,arrowsize=0.45,color='black',
                     density = 2.5)
        
        # Show DF cross section
        #ax.contourf(X3d[iy,:,:],Z3d[iy,:,:],DF_mask[iy,:,:],[0.999,1],colors='green',alpha=0.5)
        #ax.contour(X3d[iy,:,:],Z3d[iy,:,:],DF_mask[iy,:,:],[0.999],colors='green',linewidths=0.5)

        # Show velocity quivers
        #qskip = 3
        #vnorm = 200 # [km/s]
        #ax.quiver(X3d[iy,::qskip,::qskip],Z3d[iy,::qskip,::qskip],
        #          -uix[iy,::qskip,::qskip]/vnorm,uiz[iy,::qskip,::qskip]/vnorm,color='purple',scale = 60)
        
        # Tidy axes (x limits given in reverse order)
        ax.set_aspect(1)
        ax.set_xlim(x_region[1],x_region[0])
        ax.set_ylim(z_region)
        ax.set_xlabel("X [$R_M$]")
        ax.set_ylabel("Z [$R_M$]")
        ax.grid()

        # Colorbar
        clb1 = fig.colorbar(plot0,ax=ax,shrink=0.7)
        clb1.ax.set_title("$log_{10}(S)$")

        # Title
        ax.set_title(str("Y = "+str(round(yplane,2))+"\n Flux tube integrated entropy at t = "+time+"s"))

    if plot_preset == "DF_force_diagram":

        if read_data:
            # Unpack data, restricted to the trim_y/trim_x/trim_z index window and converted to SI.
            # NOTE(review): each slice stops before its trim_*[-1] index (Python end-exclusive) — confirm intent.
            X3d = data3d["X"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Y3d = data3d["Y"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Z3d = data3d["Z"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Bx3d = data3d["Bx"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]*1e-9 #[T]
            By3d = data3d["By"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]*1e-9 #[T]
            Bz3d = data3d["Bz"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]*1e-9 #[T]
            rho = data3d["rhoS1"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]*1e6*amu # [kg/m^3]
            # Current density from the ion/electron velocity difference (km/s -> m/s via *1e3).
            # Each component uses its own velocity component (Jz previously used the y velocities).
            Jx3d = (e*(data3d["rhoS1"]*1e6)*(data3d["uxS1"]-data3d["uxS0"])*1e3)[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [A/m^2]
            Jy3d = (e*(data3d["rhoS1"]*1e6)*(data3d["uyS1"]-data3d["uyS0"])*1e3)[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [A/m^2]
            Jz3d = (e*(data3d["rhoS1"]*1e6)*(data3d["uzS1"]-data3d["uzS0"])*1e3)[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [A/m^2]
            dp_dx3d = data3d["dp_dx"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]*1e-9 #[Pa/m] = [N/m^3]
            dp_dy3d = data3d["dp_dy"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]*1e-9 #[Pa/m]
            dp_dz3d = data3d["dp_dz"][trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]*1e-9 #[Pa/m]
            
            # Compute derived terms: Lorentz force density J x B [N/m^3]
            JxB_x = (Jy3d*Bz3d - Jz3d*By3d)
            JxB_y = (Jz3d*Bx3d - Jx3d*Bz3d)
            JxB_z = (Jx3d*By3d - Jy3d*Bx3d)

            # Net force density J x B - grad(p), converted N/m^3 -> mN/km^3 (x1e12)
            F_x = (JxB_x - dp_dx3d) *1e9 * 1e3 #mN / km^3
            F_y = (JxB_y - dp_dy3d) *1e9 * 1e3 #mN / km^3
            F_z = (JxB_z - dp_dz3d) *1e9 * 1e3 #mN / km^3

            # Declare figure
            fig,ax = plt.subplots(figsize=(10,7), constrained_layout=True)

            # Define function to group x
            def average_and_stddev(x, y):
                # Using a dictionary to map x values to list of corresponding y values
                data = defaultdict(list)
                
                # First, ensure all data is in float format
                x = list(map(float, x))
                y = list(map(float, y))
                
                # Populate the dictionary with corresponding y values for each x
                for xi, yi in zip(x, y):
                    data[xi].append(yi)
                
                # Prepare results
                unique_x = []
                avg_y = []
                stddev_y = []
                
                # Calculate mean and standard deviation for each unique x using numpy
                for xi in sorted(data.keys()):
                    unique_x.append(xi)
                    y_vals = np.array(data[xi])
                    avg_y.append(np.mean(y_vals))
                    if len(y_vals) > 1:
                        stddev_y.append(np.std(y_vals, ddof=1))  # Use ddof=1 for sample standard deviation
                    else:
                        stddev_y.append(0.0)  # Standard deviation is 0 if there is only one y-value for an x
                
                return np.array(unique_x), np.array(avg_y), np.array(stddev_y)

            # Compute averages
            x, JxB_avg, JxB_std = average_and_stddev(list(X3d[DF_mask]),list(JxB_x[DF_mask]))
            x, dp_dx_avg, dp_dx_std = average_and_stddev(list(X3d[DF_mask]),list(dp_dx3d[DF_mask]))
            x, F_x_avg, F_x_std = average_and_stddev(list(X3d[DF_mask]),list(F_x[DF_mask]))

            # Plot
            ax.plot(x,JxB_avg*1e12,color='red')
            ax.plot(x,-dp_dx_avg*1e12,color='blue')
            ax.plot(x,F_x_avg,color='black')
            ax.fill_between(x,(JxB_avg-JxB_std)*1e12,(JxB_avg+JxB_std)*1e12,color='red',alpha=0.1)
            ax.fill_between(x,(-dp_dx_avg-dp_dx_std)*1e12,(-dp_dx_avg+dp_dx_std)*1e12,color='blue',alpha=0.1)
            ax.axhline(y=0,color='black',linestyle='dashed')

            # Tidy axes
            ax.set_xlim(-2,-1)
            ax.set_ylim(-0.01,0.01)

    if plot_preset == "FACs":
        # Two 3-D views (ax1: Y>0 "premidnight", ax2: Y<0 "postmidnight") of cells in
        # DF_mask and of strong field-aligned currents (|J_para| > FAC_min) outside it.

        if read_data:
            # Index window shared by every 3-D field; identical to slicing with
            # [trim_y[0]:trim_y[-1], trim_x[0]:trim_x[-1], trim_z[0]:trim_z[-1]]
            trim_slc = (slice(trim_y[0], trim_y[-1]),
                        slice(trim_x[0], trim_x[-1]),
                        slice(trim_z[0], trim_z[-1]))

            # Unpack data (B kept in stored units; only its direction is used)
            X3d = data3d["X"][trim_slc]
            Y3d = data3d["Y"][trim_slc]
            Z3d = data3d["Z"][trim_slc]
            Bx3d = data3d["Bx"][trim_slc]
            By3d = data3d["By"][trim_slc]
            Bz3d = data3d["Bz"][trim_slc]
            Jx3d = (e*(data3d["rhoS1"]*1e6)*(data3d["uxS1"]-data3d["uxS0"])*1e3)[trim_slc] # [A/m^2]
            Jy3d = (e*(data3d["rhoS1"]*1e6)*(data3d["uyS1"]-data3d["uyS0"])*1e3)[trim_slc] # [A/m^2]
            Jz3d = (e*(data3d["rhoS1"]*1e6)*(data3d["uzS1"]-data3d["uzS0"])*1e3)[trim_slc] # [A/m^2]

            # Project J onto B: J_para = (J . B)/|B|  [A/m^2]
            J_para = (Bx3d*Jx3d+By3d*Jy3d+Bz3d*Jz3d)/np.sqrt(Bx3d**2+By3d**2+Bz3d**2)

        # Find DFs, neglecting FR overlap
        #DF_only_mask = DF_mask & (np.invert(FR_mask))
        #FR_only_mask = FR_mask & (np.invert(DF_mask))
        #DF_FR_mask = DF_mask & FR_mask

        # Plot
        fig = plt.figure(figsize=(20,12), constrained_layout=True)
        ax1 = fig.add_subplot(121, projection="3d",computed_zorder=False)
        ax2 = fig.add_subplot(122, projection="3d",computed_zorder=False)

        # FAC threshold
        FAC_min = 200 * 1e-9 #A/m^2
        FAC_mask = (np.abs(J_para) > FAC_min) & np.invert(DF_mask) # Don't show the FAC if it is inside the DF

        # FAC_plot is only assigned inside the branch below. Start from None so the
        # colorbar step neither raises NameError nor reuses a stale artist from an
        # earlier iteration when FR_mask is empty.
        FAC_plot = None
        # NOTE(review): the scatters show DF_mask / FAC_mask but are gated on FR_mask — confirm intended.
        if np.any(FR_mask):
            for ax in [ax1,ax2]: 
                if ax == ax1:
                    temp_ylims = [0,1.2]
                    # Masks selecting the points on this panel's side (in front of the planet)
                    in_front = Y3d>0
                elif ax == ax2:
                    in_front = Y3d<0
                    temp_ylims = [-1.2,0]
                behind = np.invert(in_front)

                ax.scatter(X3d[DF_mask & in_front],Y3d[DF_mask & in_front],Z3d[DF_mask & in_front],c=-Z3d[DF_mask & in_front],vmin=-1.2,vmax=0.9,cmap="Reds",alpha = 0.02, zorder = 2)
                #ax.scatter(X3d[DF_mask & behind],Y3d[DF_mask & behind],Z3d[DF_mask & behind],c=-Z3d[DF_mask & behind],vmin=-1.2,vmax=0.9,cmap="Reds",alpha = 0.02, zorder = 0.75)
                FAC_plot = ax.scatter(X3d[FAC_mask & in_front],Y3d[FAC_mask & in_front],Z3d[FAC_mask & in_front],c=J_para[FAC_mask & in_front]*1e9, cmap = 'PRGn',alpha=0.2,zorder=2,vmin=-500,vmax=500)

                # Update limits to be centered with max range
                ax.set_xlim(np.mean(xlims) - max_range / 2, np.mean(xlims) + max_range / 2)
                ax.set_ylim(np.mean(temp_ylims) - max_range / 2, np.mean(temp_ylims) + max_range / 2)
                ax.set_zlim(np.mean(zlims) - max_range / 2, np.mean(zlims) + max_range / 2)

                ax.set_xlabel("X [$R_M$]")
                ax.set_ylabel("Y [$R_M$]")
                ax.set_zlabel("Z [$R_M$]")

        # Show planet
        plot_sphere(ax1,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=[-10,-0.5],ylims=[0,2])
        plot_sphere(ax1,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=[-10,-0.5],ylims=[0,2])
        plot_sphere(ax2,radius=1,color='lightgrey',alpha=0.8,zorder=1,xlims=[-10,-0.5],ylims=[-2,0])
        plot_sphere(ax2,radius=0.8,color='grey',alpha=1,zorder=1.25,xlims=[-10,-0.5],ylims=[-2,0])

        # Set viewing angle
        ax1.view_init(elev=5, azim=85)
        ax2.view_init(elev=5, azim=-85)

        # Colorbar (only when a FAC scatter was drawn for this call)
        if FAC_plot is not None:
            clb1 = fig.colorbar(FAC_plot,ax=[ax1,ax2],shrink=0.5)
            clb1.ax.set_title("$J_{FAC}$ [nA/m$^2$]")

        # Save the DF mask for this time step; the context manager closes the file
        DF_mask_path = str(dir+"DFs/DF_mask_t_"+'{:06.2f}'.format(round(float(time),2)))
        with open(DF_mask_path, 'wb') as DF_mask_file:
            pickle.dump(DF_mask, DF_mask_file)

        # Add titles
        ax1.set_title(str("DFs and FACs, premidnight, at t="+time+"s"),fontsize=12,y=1.0, pad=-14)
        ax2.set_title(str("DFs and FACs, postmidnight, at t="+time+"s"),fontsize=12,y=1.0, pad=-14)

    if plot_preset == "B_timeseries":
        # Sample B, ion density, species pressures and species velocities at a single
        # grid point near `loc` = (x, y, z) and append them to running lists, so that
        # successive iterations build a time series. Values stay in the units stored
        # in data3d (no SI conversion is applied here).
    
        if read_data: # Only do anything if read_data is true

            # The lists are created on the first iteration only; later iterations
            # append to them, so this relies on `iter` starting at 0.
            if iter==0:
                t_ls = []
                Bx_ls = []
                By_ls = []
                Bz_ls = []
                n_ls = []
                pi_ls = []
                pe_ls = []
                uix_ls = []
                uiy_ls = []
                uiz_ls = []
                uex_ls = []
                uey_ls = []
                uez_ls = []
    
            # Read in data (untrimmed full-domain arrays)
            X3d = data3d["X"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Y3d = data3d["Y"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Z3d = data3d["Z"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]]
            Bx3d = data3d["Bx"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nT]
            By3d = data3d["By"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nT]
            Bz3d = data3d["Bz"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] #[nT]
            n = data3d['rhoS1']#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [amu/cc]
            # Species S1 is labelled "i" and S0 "e" (ions / electrons) in the names below
            uix = data3d["uxS1"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [km/s]
            uiy = data3d["uyS1"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [km/s]
            uiz = data3d["uzS1"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [km/s]
            uex = data3d["uxS0"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [km/s]
            uey = data3d["uyS0"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [km/s]
            uez = data3d["uzS0"]#[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]] # [km/s]
            # Scalar pressure of each species = trace of its pressure tensor / 3 (stored units)
            pi = (data3d["pxxS1"]+data3d["pyyS1"]+data3d["pzzS1"])/3 #[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]])/3
            pe = (data3d["pxxS0"]+data3d["pyyS0"]+data3d["pzzS0"])/3 #[trim_y[0]:trim_y[-1],trim_x[0]:trim_x[-1],trim_z[0]:trim_z[-1]])/3
    
            # Find indices for loc: first grid index strictly greater than loc along
            # each axis (array axes are ordered Y, X, Z). Raises IndexError if loc
            # lies beyond the grid.
            ix = np.where(X3d[0,:,0]>loc[0])[0][0]
            iy = np.where(Y3d[:,0,0]>loc[1])[0][0]
            iz = np.where(Z3d[0,0,:]>loc[2])[0][0]
    
            # Save data to list (one sample per call)
            t_ls.append(float(time))
            Bx_ls.append(Bx3d[iy,ix,iz])
            By_ls.append(By3d[iy,ix,iz])
            Bz_ls.append(Bz3d[iy,ix,iz])
            n_ls.append(n[iy,ix,iz])
            pi_ls.append(pi[iy,ix,iz])
            pe_ls.append(pe[iy,ix,iz])
            uix_ls.append(uix[iy,ix,iz])
            uiy_ls.append(uiy[iy,ix,iz])
            uiz_ls.append(uiz[iy,ix,iz])
            uex_ls.append(uex[iy,ix,iz])
            uey_ls.append(uey[iy,ix,iz])
            uez_ls.append(uez[iy,ix,iz])

    if plot_preset == "reconnection_sites":
        # Map of Bz in the current-sheet cut (datacs) with ion-flow arrows and
        # contours of the reconnection score S, maximised over 0.1 < Z < 0.3.

        if read_data:

            # Unpack data (SI unless noted; *cs arrays come from the current-sheet cut)
            X3d = data3d["X"]
            Xcs = datacs["X"]
            Y3d = data3d["Y"]
            Ycs = datacs["Y"]
            Z3d = data3d["Z"]
            Zcs = datacs["Z"]
            Bx3d = data3d["Bx"]*1e-9 #[T]
            By3d = data3d["By"]*1e-9 #[T]
            Bz3d = data3d["Bz"]*1e-9 #[T]
            Bzcs = datacs["Bz"] #[nT]
            Ex3d = data3d["Ex"]*1e-6 # V/m
            Ey3d = data3d["Ey"]*1e-6 # V/m
            Ez3d = data3d["Ez"]*1e-6 # V/m
            n = data3d["rhoS1"]*1e6 # 1/m^3
            rho = n*amu # [kg/m^3]
            Jx3d = (e*data3d["rhoS1"]*1e6*(data3d["uxS1"]-data3d["uxS0"])*1e3) # [A/m^2]
            Jy3d = (e*data3d["rhoS1"]*1e6*(data3d["uyS1"]-data3d["uyS0"])*1e3) # [A/m^2]
            Jz3d = (e*data3d["rhoS1"]*1e6*(data3d["uzS1"]-data3d["uzS0"])*1e3) # [A/m^2]
            pxxS0 = data3d["pxxS0"]*1e-9 #[Pa]
            pxyS0 = data3d["pxyS0"]*1e-9 #[Pa]
            pxzS0 = data3d["pxzS0"]*1e-9 #[Pa]
            pyyS0 = data3d["pyyS0"]*1e-9 #[Pa]
            pyzS0 = data3d["pyzS0"]*1e-9 #[Pa]
            pzzS0 = data3d["pzzS0"]*1e-9 #[Pa]
            pxxS1 = data3d["pxxS1"]*1e-9 #[Pa]
            pxyS1 = data3d["pxyS1"]*1e-9 #[Pa]
            pxzS1 = data3d["pxzS1"]*1e-9 #[Pa]
            pyyS1 = data3d["pyyS1"]*1e-9 #[Pa]
            pyzS1 = data3d["pyzS1"]*1e-9 #[Pa]
            pzzS1 = data3d["pzzS1"]*1e-9 #[Pa]
            dB_dx = data3d["dB_dx"]*1e-9 # T/m
            dB_dy = data3d["dB_dy"]*1e-9 # T/m
            dB_dz = data3d["dB_dz"]*1e-9 # T/m
            uix3d = data3d["uxS1"]*1e3 # [m/s]
            uiy3d = data3d["uyS1"]*1e3 # [m/s]
            uixcs = datacs["uxS1"] # [km/s]
            uiycs = datacs["uyS1"] # [km/s]
            uiz3d = data3d["uzS1"]*1e3 # [m/s]
            uex3d = data3d["uxS0"]*1e3 # [m/s]
            uey3d = data3d["uyS0"]*1e3 # [m/s]
            uez3d = data3d["uzS0"]*1e3 # [m/s]

            # Compute reconnection score
            L, D_e, APhi, root_Q, S = compute_recon_score(Bx3d,By3d,Bz3d,Ex3d,Ey3d,Ez3d,Jx3d,Jy3d,Jz3d,uex3d,uey3d,uez3d,pxxS0,pyyS0,pzzS0,pxyS0,pxzS0,pyzS0)

            # Declare current sheet region to max over.
            # iz_cs = [first index with Z > z_cs[0], last index with Z < z_cs[1]];
            # the slice ends at iz_cs[1]+1 so the last in-range layer is included.
            z_cs = [0.1,0.3]
            iz_cs = [np.where(Z3d[0,0,:]>z_cs[0])[0][0],np.where(Z3d[0,0,:]<z_cs[1])[0][-1]]
            S_max = np.max(S[:,:,iz_cs[0]:iz_cs[1]+1], axis = 2)

        # Setup figure
        fig,ax = plt.subplots(figsize = (10,7))

        # Plot Bz, ion-flow arrows (every qskip-th cell) and S contours
        B_lims = [-50,50]
        Bz_plot = ax.imshow(Bzcs, origin='lower', vmin=B_lims[0],vmax=B_lims[1],cmap='coolwarm',extent=[np.min(Xcs),np.max(Xcs),np.min(Ycs),np.max(Ycs)])
        qskip = 8
        ui_plot = ax.quiver(Xcs[::qskip,::qskip],Ycs[::qskip,::qskip],uixcs[::qskip,::qskip],uiycs[::qskip,::qskip],
                            color='black')
        # `linewidths` is the contour line-width parameter (`lw` is not a contour argument)
        ax.contour(Xcs,Ycs,S_max,np.linspace(4.75,5.25,3),colors=['limegreen','forestgreen','darkgreen'],linewidths=0.2)

        # Add grid
        add_grid(ax,x_region,y_region,0.5,0.1,0.4,0.1)

        # Add title
        ax.set_title(str("Reconnection score and $B_z$ at t="+time+"s"),fontsize=12)
        ax.set_xlabel("X' [$R_M$]")
        ax.set_ylabel("Y' [$R_M$]")

        # Add colorbar (Bzcs is in nT)
        clb = fig.colorbar(Bz_plot, ax = ax, shrink = 0.7)
        clb.ax.set_title("$B_z$ [nT]")

        # Set axes limits
        ax.set_xlim(x_region)
        ax.set_ylim(y_region)
        ax.set_aspect(1)

    if plot_preset == "AEPIC_reconnection_sites":
        # Reconnection-site finder after Wang et al. 2022: criterion c1 (current
        # density vs. J x B) and c2 (divergence of field-line curvature). Plots c1
        # and c2 at the first layer with Z > 0.2, plus the number of cells per column
        # (0 < Z < 0.4) that pass both thresholds.

        if read_data:

            # Unpack data (SI unless noted; *cs arrays come from the current-sheet cut)
            X3d = data3d["X"]
            Xcs = datacs["X"]
            Y3d = data3d["Y"]
            Ycs = datacs["Y"]
            Z3d = data3d["Z"]
            Zcs = datacs["Z"]
            Bx3d = data3d["Bx"]*1e-9 #[T]
            By3d = data3d["By"]*1e-9 #[T]
            Bz3d = data3d["Bz"]*1e-9 #[T]
            Bzcs = datacs["Bz"] #[nT]
            Ex3d = data3d["Ex"]*1e-6 # V/m
            Ey3d = data3d["Ey"]*1e-6 # V/m
            Ez3d = data3d["Ez"]*1e-6 # V/m
            n = data3d["rhoS1"]*1e6 # 1/m^3
            rho = n*amu # [kg/m^3]
            Jx3d = (e*data3d["rhoS1"]*1e6*(data3d["uxS1"]-data3d["uxS0"])*1e3) # [A/m^2]
            Jy3d = (e*data3d["rhoS1"]*1e6*(data3d["uyS1"]-data3d["uyS0"])*1e3) # [A/m^2]
            Jz3d = (e*data3d["rhoS1"]*1e6*(data3d["uzS1"]-data3d["uzS0"])*1e3) # [A/m^2]
            pxxS0 = data3d["pxxS0"]*1e-9 #[Pa]
            pxyS0 = data3d["pxyS0"]*1e-9 #[Pa]
            pxzS0 = data3d["pxzS0"]*1e-9 #[Pa]
            pyyS0 = data3d["pyyS0"]*1e-9 #[Pa]
            pyzS0 = data3d["pyzS0"]*1e-9 #[Pa]
            pzzS0 = data3d["pzzS0"]*1e-9 #[Pa]
            pxxS1 = data3d["pxxS1"]*1e-9 #[Pa]
            pxyS1 = data3d["pxyS1"]*1e-9 #[Pa]
            pxzS1 = data3d["pxzS1"]*1e-9 #[Pa]
            pyyS1 = data3d["pyyS1"]*1e-9 #[Pa]
            pyzS1 = data3d["pyzS1"]*1e-9 #[Pa]
            pzzS1 = data3d["pzzS1"]*1e-9 #[Pa]
            dB_dx = data3d["dB_dx"]*1e-9 # T/m
            dB_dy = data3d["dB_dy"]*1e-9 # T/m
            dB_dz = data3d["dB_dz"]*1e-9 # T/m
            dBx_dx = data3d["dBx_dx"]*1e-9  #[T/m]
            dBx_dy = data3d["dBx_dy"]*1e-9  #[T/m]
            dBx_dz = data3d["dBx_dz"]*1e-9  #[T/m]
            dBy_dx = data3d["dBy_dx"]*1e-9  #[T/m]
            dBy_dy = data3d["dBy_dy"]*1e-9  #[T/m]
            dBy_dz = data3d["dBy_dz"]*1e-9  #[T/m]
            dBz_dx = data3d["dBz_dx"]*1e-9  #[T/m]
            dBz_dy = data3d["dBz_dy"]*1e-9  #[T/m]
            dBz_dz = data3d["dBz_dz"]*1e-9  #[T/m]
            uix3d = data3d["uxS1"]*1e3 # [m/s]
            uiy3d = data3d["uyS1"]*1e3 # [m/s]
            uixcs = datacs["uxS1"] # [km/s]
            uiycs = datacs["uyS1"] # [km/s]
            uiz3d = data3d["uzS1"]*1e3 # [m/s]
            uex3d = data3d["uxS0"]*1e3 # [m/s]
            uey3d = data3d["uyS0"]*1e3 # [m/s]
            uez3d = data3d["uzS0"]*1e3 # [m/s]

            # Compute first criterion of Wang et al. 2022:
            #   c1 = |J|^2 dx / (|J x B| + |J| epsilon)
            # NOTE(review): epsilon is described as 1 nT, but B above is in tesla, so
            # epsilon = 1 is 1 T and the |J|*epsilon term dominates the denominator
            # (|J x B| <= |J||B| with |B| ~ 1e-8 T). c1_min below was tuned with this
            # value; switching to 1e-9 requires re-tuning c1_min. dx appears to be in
            # R_M here (c2 below converts it with dx*R_M) — confirm.
            epsilon = 1 # 1 nT, as given in paper
            c1 = (Jx3d**2+Jy3d**2+Jz3d**2)*dx / (np.sqrt((Jy3d*Bz3d-Jz3d*By3d)**2 + (Jz3d*Bx3d-Jx3d*Bz3d)**2 + (Jx3d*By3d-Jy3d*Bx3d)**2) + np.sqrt(Jx3d**2+Jy3d**2+Jz3d**2)*epsilon)
            # c1>0.8 is recommended threshold

            # Compute second criterion, based on curvature divergence. If large enough,
            # it separates X-lines from O-lines or flux ropes.
            Bmag = np.sqrt(Bx3d**2+By3d**2+Bz3d**2) # [T]
            bx = Bx3d/Bmag
            by = By3d/Bmag
            bz = Bz3d/Bmag
            # Partial derivatives of the unit vector b via the chain rule and the
            # pre-computed derivatives: d(b_i) = (dB_i - b_i d|B|)/|B|
            dbx_dx = (1/Bmag) * (dBx_dx - bx*dB_dx)
            dbx_dy = (1/Bmag) * (dBx_dy - bx*dB_dy)
            dbx_dz = (1/Bmag) * (dBx_dz - bx*dB_dz)
            dby_dx = (1/Bmag) * (dBy_dx - by*dB_dx)
            dby_dy = (1/Bmag) * (dBy_dy - by*dB_dy)
            dby_dz = (1/Bmag) * (dBy_dz - by*dB_dz)
            dbz_dx = (1/Bmag) * (dBz_dx - bz*dB_dx)
            dbz_dy = (1/Bmag) * (dBz_dy - bz*dB_dy)
            dbz_dz = (1/Bmag) * (dBz_dz - bz*dB_dz)
            # Curvature vector kappa = (b . grad) b
            pre_c2_x = bx*dbx_dx + by*dbx_dy + bz*dbx_dz
            pre_c2_y = bx*dby_dx + by*dby_dy + bz*dby_dz
            pre_c2_z = bx*dbz_dx + by*dbz_dy + bz*dbz_dz
            # Divergence of kappa (array axes are ordered Y, X, Z), scaled by dx^2
            c2_x = np.gradient(pre_c2_x,dx*R_M,axis=1)
            c2_y = np.gradient(pre_c2_y,dx*R_M,axis=0)
            c2_z = np.gradient(pre_c2_z,dx*R_M,axis=2)
            c2 = (c2_x+c2_y+c2_z)*(dx)**2

            # Declare current sheet region to count through.
            # iz_cs = [first index with Z > z_cs[0], last index with Z < z_cs[1]];
            # the slice ends at iz_cs[1]+1 so the last in-range layer is included.
            z_cs = [0.0,0.4]
            iz_cs = [np.where(Z3d[0,0,:]>z_cs[0])[0][0],np.where(Z3d[0,0,:]<z_cs[1])[0][-1]]

            # Thresholds on the two criteria
            c1_min = 0.015
            c2_min = 1e-9

            # Combine criteria: 1.0 where both thresholds are met, else 0.0. Comparisons
            # with NaN are False, so undefined cells (e.g. J = 0 or B = 0) count as 0
            # instead of turning the whole column count into NaN.
            recon_sites = ((c1 >= c1_min) & (c2 >= c2_min)).astype(float)

            recon_sites_count = np.sum(recon_sites[:,:,iz_cs[0]:iz_cs[1]+1], axis = 2)

        # Setup figure
        fig,axs = plt.subplots(ncols=3 , figsize = (28,7))

        # Z layer shown in the c1 / c2 panels (first index with Z > 0.2) and the
        # shared image extent of the current-sheet grid
        iz_plot = np.where(Z3d[0,0,:]>0.2)[0][0]
        cs_extent = [np.min(Xcs),np.max(Xcs),np.min(Ycs),np.max(Ycs)]

        # Plot
        B_lims = [-50,50]
        c1_lims = [0,0.025]
        c1_plot = axs[0].imshow(c1[:,:,iz_plot],origin='lower',cmap='viridis',vmin=c1_lims[0],vmax=c1_lims[1],
                                extent=cs_extent)
        axs[0].contour(Xcs,Ycs,c1[:,:,iz_plot],[c1_min],colors=['red'],linewidths=[0.5])

        c2_lims = [-1e-8,1e-8]
        c2_plot = axs[1].imshow(c2[:,:,iz_plot],origin='lower',cmap='viridis',vmin=c2_lims[0],vmax=c2_lims[1],
                                extent=cs_extent)
        axs[1].contour(Xcs,Ycs,c2[:,:,iz_plot],[c2_min],colors=['red'],linewidths=[0.5])

        recon_lims = [0,10]
        recon_sites_heatmap = axs[2].imshow(recon_sites_count,origin='lower',cmap='plasma',vmin=recon_lims[0],vmax=recon_lims[1],
                                            extent=cs_extent)

        # Add grid, labels and limits to every panel
        for axi in axs:
            add_grid(axi,x_region,y_region,0.5,0.1,0.4,0.1)
            axi.set_xlabel("X' [$R_M$]")
            axi.set_ylabel("Y' [$R_M$]")
            axi.set_xlim(x_region)
            axi.set_ylim(y_region)
            axi.set_aspect(1)

        # Add title
        axs[0].set_title(str("c1 at t="+time+"s"),fontsize=12)
        axs[1].set_title(str("c2 at t="+time+"s"),fontsize=12)
        axs[2].set_title(str("Reconnection sites at t="+time+"s"),fontsize=12)

        # Add colorbar
        fig.colorbar(c1_plot, ax = axs[0], shrink = 0.7)
        fig.colorbar(c2_plot, ax = axs[1], shrink = 0.7)
        fig.colorbar(recon_sites_heatmap, ax = axs[2], shrink = 0.7)


    if plot_preset == "total_current":
        # Time series of current-density sums over the region
        # R_region[0] < R < R_region[1], y_region[0] < Y < y_region[1]:
        # eastward / westward perpendicular current, Jx/Jy/Jz, and |J_para|
        # split by the sign of Jx.

        if iter==0 and read_data:
            # Declare lists that accumulate the per-time-step current sums of the region
            t_ls = []
            J_x_ls = []
            J_y_ls = []
            J_z_ls = []
            J_E_ls = []
            J_W_ls = []
            J_up_FAC_ls = []
            J_down_FAC_ls = []

        if read_data:

            # Set mask region
            # run3_eg2_FRDF:
            R_region = [1.1,1.5]
            y_region = [-1,0]

            # Unpack data
            X3d = data3d["X"]
            Y3d = data3d["Y"]
            Z3d = data3d["Z"]
            R3d = np.sqrt(X3d**2+Y3d**2+Z3d**2)
            Bx3d = data3d["Bx"]*1e-9 #[T]
            By3d = data3d["By"]*1e-9 #[T]
            Bz3d = data3d["Bz"]*1e-9 #[T]
            n = data3d["rhoS1"]*1e6 # 1/m^3
            rho = n*amu # [kg/m^3]
            Jx3d = (e*data3d["rhoS1"]*1e6*(data3d["uxS1"]-data3d["uxS0"])*1e3) # [A/m^2]
            Jy3d = (e*data3d["rhoS1"]*1e6*(data3d["uyS1"]-data3d["uyS0"])*1e3) # [A/m^2]
            Jz3d = (e*data3d["rhoS1"]*1e6*(data3d["uzS1"]-data3d["uzS0"])*1e3) # [A/m^2]

            # Compute current components in magnetic coords
            B_mag = np.sqrt(Bx3d**2+By3d**2+Bz3d**2)
            bx,by,bz = [Bx3d,By3d,Bz3d]/B_mag
            J_para = (bx*Jx3d + by*Jy3d + bz*Jz3d)*1e9 # [nA/m^2]
            J_para_x = J_para*bx
            J_para_y = J_para*by
            J_para_z = J_para*bz
            J_perp_x = Jx3d*1e9 - J_para_x # [nA/m^2]
            J_perp_y = Jy3d*1e9 - J_para_y # [nA/m^2]
            J_perp_z = Jz3d*1e9 - J_para_z # [nA/m^2]

            # Eastward (azimuthal) unit vector i_E = (-sin(theta), cos(theta)) with
            # theta = atan2(Y, X). arctan2 keeps the correct quadrant for X > 0 and is
            # defined at X = 0; pi + arctan(Y/X) only matches it for X < 0.
            theta = np.arctan2(Y3d, X3d)
            i_E_x = -np.sin(theta)
            i_E_y = np.cos(theta)

            # Compute east/west current (perpendicular current projected on i_E)
            J_E = J_perp_x * i_E_x +  J_perp_y * i_E_y # [nA/m^2]

            # Cells inside the analysis region (computed once, reused below)
            in_region = (Y3d>y_region[0]) & (Y3d<y_region[1]) & (R3d>R_region[0]) & (R3d<R_region[1])

            # Save to lists (J_E, J_W and FAC sums in nA/m^2; J_x/J_y/J_z sums in A/m^2)
            t_ls.append(float(time))
            J_E_ls.append(np.nansum(J_E[(J_E>0) & in_region])) # [nA/m^2]
            J_W_ls.append(np.nansum(J_E[(J_E<0) & in_region])) # [nA/m^2]
            J_x_ls.append(np.nansum(Jx3d[in_region]))
            J_y_ls.append(np.nansum(Jy3d[in_region]))
            J_z_ls.append(np.nansum(Jz3d[in_region]))
            # NOTE(review): up/down FAC is split on the sign of Jx3d, not of J_para — confirm intended.
            J_up_FAC_ls.append(np.nansum(np.abs(J_para)[(Jx3d<0) & in_region]))
            J_down_FAC_ls.append(np.nansum(np.abs(J_para)[(Jx3d>0) & in_region]))
    if plot_preset == "current_spectra":
        # Build (time x X) arrays of the Y-directed current through the plane
        # Y = loc[1], summed over Z for each X column in x_region: total,
        # perpendicular, diamagnetic and ion inertial contributions.

        if read_data: # If false, then skip to plotting

            # Settings for the plot
            yplane = loc[1] # Which y coord to compute current

            # Unpack data
            X3d = data3d["X"]
            Y3d = data3d["Y"]
            Z3d = data3d["Z"]
            Bx3d = data3d["Bx"]*1e-9  #[T]
            By3d = data3d["By"]*1e-9  #[T]
            Bz3d = data3d["Bz"]*1e-9  #[T]
            dBx_dx = data3d["dBx_dx"]*1e-9  #[T/m]
            dBx_dy = data3d["dBx_dy"]*1e-9  #[T/m]
            dBx_dz = data3d["dBx_dz"]*1e-9  #[T/m]
            dBz_dx = data3d["dBz_dx"]*1e-9  #[T/m]
            dBz_dy = data3d["dBz_dy"]*1e-9  #[T/m]
            dBz_dz = data3d["dBz_dz"]*1e-9  #[T/m]
            Ex3d = data3d["Ex"]*1e-6  # [V/m]
            Ey3d = data3d["Ey"]*1e-6  # [V/m]
            Ez3d = data3d["Ez"]*1e-6  # [V/m]
            n = data3d["rhoS1"]*1e6 # 1/m^3
            rho = n*amu # [kg/m^3]
            Jx3d = (e*data3d["rhoS1"]*1e6*(data3d["uxS1"]-data3d["uxS0"])*1e3) # [A/m^2]
            Jy3d = (e*data3d["rhoS1"]*1e6*(data3d["uyS1"]-data3d["uyS0"])*1e3) # [A/m^2]
            Jz3d = (e*data3d["rhoS1"]*1e6*(data3d["uzS1"]-data3d["uzS0"])*1e3) # [A/m^2]
            dp_dx3d = data3d["dp_dx"]*1e-9  #[Pa/m] = [N/m^3]
            dp_dy3d = data3d["dp_dy"]*1e-9  #[Pa/m]
            dp_dz3d = data3d["dp_dz"]*1e-9  #[Pa/m]
            dp_perp_dx = data3d["dp_perp_dx"]*1e-9  #[Pa/m] = [N/m^3]
            dp_perp_dy = data3d["dp_perp_dy"]*1e-9  #[Pa/m]
            dp_perp_dz = data3d["dp_perp_dz"]*1e-9  #[Pa/m]
            pxxS0 = data3d["pxxS0"]*1e-9 #[Pa]
            pxyS0 = data3d["pxyS0"]*1e-9 #[Pa]
            pxzS0 = data3d["pxzS0"]*1e-9 #[Pa]
            pyyS0 = data3d["pyyS0"]*1e-9 #[Pa]
            pyzS0 = data3d["pyzS0"]*1e-9 #[Pa]
            pzzS0 = data3d["pzzS0"]*1e-9 #[Pa]
            pxxS1 = data3d["pxxS1"]*1e-9 #[Pa]
            pxyS1 = data3d["pxyS1"]*1e-9 #[Pa]
            pxzS1 = data3d["pxzS1"]*1e-9 #[Pa]
            pyyS1 = data3d["pyyS1"]*1e-9 #[Pa]
            pyzS1 = data3d["pyzS1"]*1e-9 #[Pa]
            pzzS1 = data3d["pzzS1"]*1e-9 #[Pa]
            dB_dx = data3d["dB_dx"]*1e-9 # T/m
            dB_dy = data3d["dB_dy"]*1e-9 # T/m
            dB_dz = data3d["dB_dz"]*1e-9 # T/m
            uix = data3d["uxS1"]*1e3  # [m/s]
            uiy = data3d["uyS1"]*1e3  # [m/s]
            uiz = data3d["uzS1"]*1e3  # [m/s]
            duix_dx = data3d["duix_dx"]*1e3 # [1/s]
            duix_dy = data3d["duix_dy"]*1e3 # [1/s]
            duix_dz = data3d["duix_dz"]*1e3 # [1/s]
            duiz_dx = data3d["duiz_dx"]*1e3 # [1/s]
            duiz_dy = data3d["duiz_dy"]*1e3 # [1/s]
            duiz_dz = data3d["duiz_dz"]*1e3 # [1/s]
            uex = data3d["uxS0"]*1e3 # [m/s]
            uey = data3d["uyS0"]*1e3 # [m/s]
            uez = data3d["uzS0"]*1e3 # [m/s]
            duex_dx = data3d["duex_dx"]*1e3 # [1/s]
            duex_dy = data3d["duex_dy"]*1e3 # [1/s]
            duex_dz = data3d["duex_dz"]*1e3 # [1/s]
            duez_dx = data3d["duez_dx"]*1e3 # [1/s]
            duez_dy = data3d["duez_dy"]*1e3 # [1/s]
            duez_dz = data3d["duez_dz"]*1e3 # [1/s]

            if iter==0:

                # One X bin per grid column inside x_region
                nxbins = len(X3d[0,:,0][(X3d[0,:,0]>x_region[0]) & (X3d[0,:,0]<x_region[1])])
                # Small tolerance so a float ratio such as 0.3/0.1 = 2.999... is not
                # floored to one bin too few (which would overflow Jy_array[iter]).
                ntbins = int((t_bound[1] - t_bound[0])/dt + 1e-6)
                dxbin = (x_region[1] - x_region[0])/nxbins
                t_array = np.array([np.linspace(t_bound[0],(t_bound[1]-dt),ntbins)]).T
                x_array = np.array([np.linspace(x_region[0],(x_region[1]-dxbin),nxbins)])
                Jy_array = np.zeros((ntbins,nxbins))
                Jy_perp_array = np.zeros((ntbins,nxbins))
                Jy_dia_array = np.zeros((ntbins,nxbins))
                Jy_inrt_array = np.zeros((ntbins,nxbins))

                # Work out indices for data binning (array axes are ordered Y, X, Z)
                iyplane = np.where(Y3d[:,0,0]>yplane)[0][0]
                ixbins = np.where((X3d[0,:,0]>x_region[0]) & (X3d[0,:,0]<x_region[1]))[0]

            # Compute time derivatives
            dts = compute_dt(["uxS1","uzS1","uxS0","uzS0"],time,type='numpy',path=dir_3D)
            duix_dt = dts["uxS1"]*1e3 #[m/s^2]
            duiz_dt = dts["uzS1"]*1e3 #[m/s^2]
            duex_dt = dts["uxS0"]*1e3 #[m/s^2]
            duez_dt = dts["uzS0"]*1e3 #[m/s^2]

            # Inertial current (ions): J_y = rho/B^2 * (B x du_i/dt)_y
            Bmag = np.sqrt(Bx3d**2 + By3d**2 + Bz3d**2) # T
            Jy3d_inrt_i = rho / Bmag**2 * ( Bz3d*(duix_dt + uix*duix_dx + uiy*duix_dy + uiz*duix_dz) - Bx3d*(duiz_dt + uix*duiz_dx + uiy*duiz_dy + uiz*duiz_dz) ) # [A/m^2]
            #Jy3d_inrt_e = rho/mi_me / Bmag**2 * ( Bz3d*(duex_dt + uex*duex_dx + uey*duex_dy + uez*duex_dz) - Bx3d*(duez_dt + uex*duez_dx + uey*duez_dy + uez*duez_dz) ) # [A/m^2]

            # Compute FAC and the perpendicular remainder of Jy
            J_para = (Jx3d*Bx3d + Jy3d*By3d + Jz3d*Bz3d)/Bmag
            Jy3d_para = J_para * By3d/Bmag # [A/m^2]
            Jy3d_perp = Jy3d - Jy3d_para # [A/m^2]

            # Compute grad-curv drift ring current
            # Compute magnetic unit vector
            bx,by,bz = 1/Bmag * [Bx3d,By3d,Bz3d]

            # Field-aligned basis (u, v, b): u is along E x B, v = u x b, b is along B.
            # See back page of space physics 2024 notebook for derivation, or photos from sep. 4, 2024
            # Compute u first, making sure to normalize
            ux,uy,uz = [Ey3d*Bz3d-Ez3d*By3d, Ez3d*Bx3d-Ex3d*Bz3d, Ex3d*By3d-Ey3d*Bx3d]
            ux,uy,uz = 1/np.sqrt(ux**2+uy**2+uz**2)*[ux,uy,uz]

            # Then compute v (already normalized):
            vx,vy,vz = [uy*bz-uz*by, uz*bx-ux*bz, ux*by-uy*bx]

            # Diagonal pressure components in (u, v, b), e.g. P_33 = b . P . b
            P11S0 = ux*(pxxS0*ux+pxyS0*uy+pxzS0*uz) + uy*(pxyS0*ux+pyyS0*uy+pyzS0*uz) + uz*(pxzS0*ux+pyzS0*uy+pzzS0*uz)
            P22S0 = vx*(pxxS0*vx+pxyS0*vy+pxzS0*vz) + vy*(pxyS0*vx+pyyS0*vy+pyzS0*vz) + vz*(pxzS0*vx+pyzS0*vy+pzzS0*vz)
            P33S0 = bx*(pxxS0*bx+pxyS0*by+pxzS0*bz) + by*(pxyS0*bx+pyyS0*by+pyzS0*bz) + bz*(pxzS0*bx+pyzS0*by+pzzS0*bz)
            P11S1 = ux*(pxxS1*ux+pxyS1*uy+pxzS1*uz) + uy*(pxyS1*ux+pyyS1*uy+pyzS1*uz) + uz*(pxzS1*ux+pyzS1*uy+pzzS1*uz)
            P22S1 = vx*(pxxS1*vx+pxyS1*vy+pxzS1*vz) + vy*(pxyS1*vx+pyyS1*vy+pyzS1*vz) + vz*(pxzS1*vx+pyzS1*vy+pzzS1*vz)
            P33S1 = bx*(pxxS1*bx+pxyS1*by+pxzS1*bz) + by*(pxyS1*bx+pyyS1*by+pyzS1*bz) + bz*(pxzS1*bx+pyzS1*by+pzzS1*bz)

            # Compute perp/para total pressures (sum over both species)
            p_perp = (P11S0+P22S0+P11S1+P22S1)/2
            p_para = P33S0 + P33S1

            # Full diamagnetic current:
            # J_{dia} = \frac{\vec{B}}{B^2} \times \left[ \nabla p_\perp - \left(p_\perp - p_\parallel\right)\frac{\left(\vec{B}\cdot\nabla\right) \vec{B}}{B^2}\right]
            # The curvature term is divided by B^2 (not B): (B.grad)B is in T^2/m, so
            # only (p_perp - p_para)(B.grad)B/B^2 has the units of grad(p_perp) [Pa/m].
            Jy3d_dia = (1/Bmag**2) * (Bz3d * (dp_perp_dx - (p_perp-p_para)/Bmag**2*(Bx3d*dBx_dx+By3d*dBx_dy+Bz3d*dBx_dz)) -
                                      Bx3d * (dp_perp_dz - (p_perp-p_para)/Bmag**2*(Bx3d*dBz_dx+By3d*dBz_dy+Bz3d*dBz_dz)))

            # Save to arrays: for every X column in the plane, sum Jy over Z and multiply
            # by the cell area. (R_M/64)**2 assumes a grid spacing of R_M/64 — TODO confirm vs dx.
            cell_area = (R_M/64)**2
            Jy_array[iter,:] = np.nansum(Jy3d[iyplane,ixbins,:], axis=1)*cell_area
            Jy_perp_array[iter,:] = np.nansum(Jy3d_perp[iyplane,ixbins,:], axis=1)*cell_area
            Jy_dia_array[iter,:] = np.nansum(Jy3d_dia[iyplane,ixbins,:], axis=1)*cell_area
            Jy_inrt_array[iter,:] = np.nansum(Jy3d_inrt_i[iyplane,ixbins,:], axis=1)*cell_area

    if plot_preset == "reconnection_spectra":

        # Per frame: flag candidate reconnection sites with two criteria after Wang et al. 2022
        # (c1: current-sheet criterion, c2: divergence of the field-line curvature) and with the
        # score S returned by compute_recon_score, then count flagged cells per X bin and per Y bin.
        # Counts go into row `iter` of the *_array_x / *_array_y arrays for plotting after the loop.
        if read_data: # If false, then skip to plotting

            # Unpack data (converted to SI)
            X3d = data3d["X"]
            Y3d = data3d["Y"]
            Z3d = data3d["Z"]
            Bx3d = data3d["Bx"]*1e-9  #[T]
            By3d = data3d["By"]*1e-9  #[T]
            Bz3d = data3d["Bz"]*1e-9  #[T]
            Ex3d = data3d["Ex"]*1e-6  # [V/m]
            Ey3d = data3d["Ey"]*1e-6  # [V/m]
            Ez3d = data3d["Ez"]*1e-6  # [V/m]
            n = data3d["rhoS1"]*1e6 # 1/m^3
            rho = n*amu # [kg/m^3]
            # Current density from the species-1 minus species-0 (electron, see uex3d below) velocity
            Jx3d = (e*data3d["rhoS1"]*1e6*(data3d["uxS1"]-data3d["uxS0"])*1e3) # [A/m^2]
            Jy3d = (e*data3d["rhoS1"]*1e6*(data3d["uyS1"]-data3d["uyS0"])*1e3) # [A/m^2]
            Jz3d = (e*data3d["rhoS1"]*1e6*(data3d["uzS1"]-data3d["uzS0"])*1e3) # [A/m^2]
            pxxS0 = data3d["pxxS0"]*1e-9 #[Pa]
            pxyS0 = data3d["pxyS0"]*1e-9 #[Pa]
            pxzS0 = data3d["pxzS0"]*1e-9 #[Pa]
            pyyS0 = data3d["pyyS0"]*1e-9 #[Pa]
            pyzS0 = data3d["pyzS0"]*1e-9 #[Pa]
            pzzS0 = data3d["pzzS0"]*1e-9 #[Pa]
            uex3d = data3d["uxS0"]*1e3 # [m/s]
            uey3d = data3d["uyS0"]*1e3 # [m/s]
            uez3d = data3d["uzS0"]*1e3 # [m/s]
            dB_dx = data3d["dB_dx"]*1e-9 # T/m
            dB_dy = data3d["dB_dy"]*1e-9 # T/m
            dB_dz = data3d["dB_dz"]*1e-9 # T/m
            if "dBx_dx" in data3d.keys():
                dBx_dx = data3d["dBx_dx"]*1e-9  #[T/m]
                dBx_dy = data3d["dBx_dy"]*1e-9  #[T/m]
                dBx_dz = data3d["dBx_dz"]*1e-9  #[T/m]
                dBy_dx = data3d["dBy_dx"]*1e-9  #[T/m]
                dBy_dy = data3d["dBy_dy"]*1e-9  #[T/m]
                dBy_dz = data3d["dBy_dz"]*1e-9  #[T/m]
                dBz_dx = data3d["dBz_dx"]*1e-9  #[T/m]
                dBz_dy = data3d["dBz_dy"]*1e-9  #[T/m]
                dBz_dz = data3d["dBz_dz"]*1e-9  #[T/m]
            else:
                # Jacobian of magnetic field vector field was added late, not all numpy files have it.
                # Differentiate over the grid spacing in metres (dx*R_M, the same spacing used for the
                # c2 gradients below) so the result is in T/m like the stored derivatives it replaces
                # and like dB_dx, with which it is combined further down.
                print("Computing dBa_da jacobian")
                grad_step_m = dx*R_M # [m]
                dBx_dx = np.gradient(Bx3d,grad_step_m,axis=1)
                dBx_dy = np.gradient(Bx3d,grad_step_m,axis=0)
                dBx_dz = np.gradient(Bx3d,grad_step_m,axis=2)
                dBy_dx = np.gradient(By3d,grad_step_m,axis=1)
                dBy_dy = np.gradient(By3d,grad_step_m,axis=0)
                dBy_dz = np.gradient(By3d,grad_step_m,axis=2)
                dBz_dx = np.gradient(Bz3d,grad_step_m,axis=1)
                dBz_dy = np.gradient(Bz3d,grad_step_m,axis=0)
                dBz_dz = np.gradient(Bz3d,grad_step_m,axis=2)

            # On the first frame, set up the bins and the output arrays
            if iter==0:

                nxbins = len(X3d[0,:,0][(X3d[0,:,0]>x_region[0]) & (X3d[0,:,0]<x_region[1])])
                dxbin = (x_region[1] - x_region[0])/nxbins
                nybins = len(Y3d[:,0,0][(Y3d[:,0,0]>y_region[0]) & (Y3d[:,0,0]<y_region[1])])
                dybin = (y_region[1] - y_region[0])/nybins
                ntbins = int((t_bound[1] - t_bound[0])/dt)
                t_array = np.array([np.linspace(t_bound[0],(t_bound[1]-dt),ntbins)]).T
                x_array = np.array([np.linspace(x_region[0],(x_region[1]-dxbin),nxbins)])
                y_array = np.array([np.linspace(y_region[0],(y_region[1]-dybin),nybins)])
                S_active_array_x = np.zeros((ntbins,nxbins))
                S_active_array_y = np.zeros((ntbins,nybins))
                c_sites_array_x = np.zeros((ntbins,nxbins))
                c_sites_array_y = np.zeros((ntbins,nybins))

                # Work out indices for data binning (arrays are indexed [y, x, z])
                izmin = np.where(Z3d[0,0,:]>z_region[0])[0][0]
                izmax = np.where(Z3d[0,0,:]<z_region[1])[0][-1]
                ixbins = np.where((X3d[0,:,0]>x_region[0]) & (X3d[0,:,0]<x_region[1]))[0]
                iybins = np.where((Y3d[:,0,0]>y_region[0]) & (Y3d[:,0,0]<y_region[1]))[0]

                # Cap the reconnection cell count at X = -1.25
                ixmax = np.where(X3d[0,:,0]>-1.25)[0][0]
    
            # Compute reconnection scores
            L, D_e, APhi, root_Q, S = compute_recon_score(Bx3d,By3d,Bz3d,Ex3d,Ey3d,Ez3d,Jx3d,Jy3d,Jz3d,uex3d,uey3d,uez3d,pxxS0,pyyS0,pzzS0,pxyS0,pxzS0,pyzS0)

            # Compute first criterion of Wang et al. 2022
            # NOTE(review): epsilon is described as 1 nT but B is in T here, so |J|*epsilon dominates
            # |J x B| and c1 is approximately |J|*dx. c1_min below appears tuned to that, not to the
            # paper's recommended 0.8 — confirm the intended units.
            epsilon = 1 # 1 nT, as given in paper
            c1 = (Jx3d**2+Jy3d**2+Jz3d**2)*dx / (np.sqrt((Jy3d*Bz3d-Jz3d*By3d)**2 + (Jz3d*Bx3d-Jx3d*Bz3d)**2 + (Jx3d*By3d-Jy3d*Bx3d)**2) + np.sqrt(Jx3d**2+Jy3d**2+Jz3d**2)*epsilon)
            # c1>0.8 is recommended threshold
            
            # Compute second criterion, based on curvature divergence. If large enough, it separates x-lines from o-lines or flux ropes
            Bmag = np.sqrt(Bx3d**2+By3d**2+Bz3d**2) # [T]
            bx = Bx3d/Bmag
            by = By3d/Bmag
            bz = Bz3d/Bmag
            # We need partial derivatives of the magnetic field unit vectors.
            # To reduce the number of derivatives to take, compute them with the chain rule from the
            # pre-computed derivatives: d(B_i/|B|) = (dB_i - b_i d|B|)/|B|
            dbx_dx = (1/Bmag) * (dBx_dx - bx*dB_dx)
            dbx_dy = (1/Bmag) * (dBx_dy - bx*dB_dy)
            dbx_dz = (1/Bmag) * (dBx_dz - bx*dB_dz)
            dby_dx = (1/Bmag) * (dBy_dx - by*dB_dx)
            dby_dy = (1/Bmag) * (dBy_dy - by*dB_dy)
            dby_dz = (1/Bmag) * (dBy_dz - by*dB_dz)
            dbz_dx = (1/Bmag) * (dBz_dx - bz*dB_dx)
            dbz_dy = (1/Bmag) * (dBz_dy - bz*dB_dy)
            dbz_dz = (1/Bmag) * (dBz_dz - bz*dB_dz)
            # Curvature vector (b . grad) b, component by component
            pre_c2_x = bx*dbx_dx + by*dbx_dy + bz*dbx_dz
            pre_c2_y = bx*dby_dx + by*dby_dy + bz*dby_dz
            pre_c2_z = bx*dbz_dx + by*dbz_dy + bz*dbz_dz
            # Divergence of the curvature vector (axis 1 is x, axis 0 is y, axis 2 is z)
            c2_x = np.gradient(pre_c2_x,dx*R_M,axis=1)
            c2_y = np.gradient(pre_c2_y,dx*R_M,axis=0)
            c2_z = np.gradient(pre_c2_z,dx*R_M,axis=2)
            c2 = (c2_x+c2_y+c2_z)*(dx)**2
            
            # Downselect to cells which fulfill the criteria
            c1_min = 0.015
            c2_min = 1e-9

            # Combine criteria: 1 where both thresholds are met, 0 where either fails (NaN cells stay NaN)
            recon_sites = np.copy(c1)
            recon_sites[c1<c1_min] = 0
            recon_sites[c2<c2_min] = 0
            recon_sites[recon_sites>0]=1

            # Save to arrays: count flagged cells per X bin (all Y, izmin:izmax in Z) and per Y bin
            # (X up to ixmax, izmin:izmax in Z). NaN cells compare False and are not counted.
            for ibin in range(nxbins):
                S_active_array_x[iter,ibin] = np.count_nonzero(S[:,ixbins[ibin],izmin:izmax]>5)
                c_sites_array_x[iter,ibin] = np.count_nonzero(recon_sites[:,ixbins[ibin],izmin:izmax]>0)
            for ibin in range(nybins):
                S_active_array_y[iter,ibin] = np.count_nonzero(S[iybins[ibin],:ixmax,izmin:izmax]>5)
                c_sites_array_y[iter,ibin] = np.count_nonzero(recon_sites[iybins[ibin],:ixmax,izmin:izmax]>0)

            if np.sum(S_active_array_x[iter,:])<2:
                print("ERROR! LOW S")

    if plot_preset == "entropy_FAC_map":

        # Map flux-tube entropy S and the radial field-aligned current (FAC) onto a spherical shell
        # of radius r_proj [R_M] around `center`, compare both against background_field, and plot a
        # 2x2 panel of (longitude, magnetic latitude) maps. When read_data is set, S and FAC are
        # also pickled to dir+"S_data/".

        # Declare radial distance of surface map
        r_proj = 1.1

        # Declare resolution and angular extent of the map
        n_long = 64
        n_lat = 128
        theta_lims = np.array([55,90]) * np.pi/180
        long_lims = np.array([120,240]) * np.pi/180
        center = [0,0,0.2]
    
        # Load in grid coordinates; ZMSM3d shifts Z by 0.2 R_M to the MSM frame
        X3d = data3d["X"]
        Xcs = datacs["X"]
        Y3d = data3d["Y"]
        Ycs = datacs["Y"]
        Z3d = data3d["Z"]
        Zcs = datacs["Z"]
        ZMSM3d = Z3d - 0.2
        RMSM3d = np.sqrt(X3d**2+Y3d**2+(ZMSM3d)**2)
        if read_data:
            Bx3d = data3d["Bx"]*1e-9 #[T]
            By3d = data3d["By"]*1e-9 #[T]
            Bz3d = data3d["Bz"]*1e-9 #[T]
            Ex3d = data3d["Ex"]*1e-6 # V/m
            Ey3d = data3d["Ey"]*1e-6 # V/m
            Ez3d = data3d["Ez"]*1e-6 # V/m
            pxxS0 = data3d["pxxS0"]*1e-9 #[Pa]
            pxyS0 = data3d["pxyS0"]*1e-9 #[Pa]
            pxzS0 = data3d["pxzS0"]*1e-9 #[Pa]
            pyyS0 = data3d["pyyS0"]*1e-9 #[Pa]
            pyzS0 = data3d["pyzS0"]*1e-9 #[Pa]
            pzzS0 = data3d["pzzS0"]*1e-9 #[Pa]
            pxxS1 = data3d["pxxS1"]*1e-9 #[Pa]
            pxyS1 = data3d["pxyS1"]*1e-9 #[Pa]
            pxzS1 = data3d["pxzS1"]*1e-9 #[Pa]
            pyyS1 = data3d["pyyS1"]*1e-9 #[Pa]
            pyzS1 = data3d["pyzS1"]*1e-9 #[Pa]
            pzzS1 = data3d["pzzS1"]*1e-9 #[Pa]
            n = data3d["rhoS1"]*1e6 # 1/m^3
            rho = n*amu # [kg/m^3]
            Jx3d = (e*data3d["rhoS1"]*1e6*(data3d["uxS1"]-data3d["uxS0"])*1e3) # [A/m^2]
            Jy3d = (e*data3d["rhoS1"]*1e6*(data3d["uyS1"]-data3d["uyS0"])*1e3) # [A/m^2]
            Jz3d = (e*data3d["rhoS1"]*1e6*(data3d["uzS1"]-data3d["uzS0"])*1e3) # [A/m^2]

        # Create the meshgrid of points at the chosen distance (first frame only; reused afterwards)
        if iter == 0:
            phi_ax = np.linspace(*long_lims, n_long)
            theta_ax = np.linspace(*theta_lims, n_lat) # theta in MSO
            phi,theta = np.meshgrid(phi_ax,theta_ax)
            x_proj = r_proj * np.sin(theta) * np.cos(phi) + center[0]
            y_proj = r_proj * np.sin(theta) * np.sin(phi) + center[1]
            z_proj = r_proj * np.cos(theta) + center[2]
            
            # Compute mlat/long meshgrid [deg]
            long,lat = np.meshgrid(phi_ax,theta_ax)
            mlat = np.arcsin((z_proj-0.2)/r_proj) * 180/np.pi
            long = long * 180/np.pi
            
            # Translate into a list of seeds for tracing
            print('making seeds')
            point_seeds = np.array([np.ravel(x_proj), np.ravel(y_proj), np.ravel(z_proj)]).T
    
        # Compute entropy content following 10.5194/angeo-22-1773-2004, implemented below as
        # S = \int p^{\gamma} ds/B along each closed field line.
        # NOTE(review): flux-tube entropy is commonly written with p^{1/\gamma}; the code, units
        # and plot labels here use p^{5/3} — confirm the intended exponent.
        if not file_found:
            print("Computing S integrand")
            B_mag = np.sqrt(Bx3d**2+By3d**2+Bz3d**2)*1e9 # [nT]
            B_mag_T = B_mag*1e-9 # [T]
            bx = Bx3d/B_mag_T
            by = By3d/B_mag_T
            bz = Bz3d/B_mag_T
            p = (pxxS0+pyyS0+pzzS0+pxxS1+pyyS1+pzzS1)/3 * 1e9 # [nPa]
            gamma = 5/3
            S_integrand = p**gamma / B_mag #[(nPa)^5/3 / nT]
            
            # Interpolator for the integrand at arbitrary points along the field lines.
            # Data arrays are indexed [y, x, z], so swap the first two axes to match (x, y, z).
            print('defining interpolator')
            S_interpolator = RegularGridInterpolator((X3d[0,:,0], Y3d[:,0,0], Z3d[0,0,:]), np.swapaxes(S_integrand,0,1), bounds_error=False, fill_value=None)
        
            # One entry per seed; -1 marks open field lines
            S_list = np.zeros_like(np.ravel(x_proj))-1

            # Set up field line tracing
            print('Getting tracer grid')
            tracer,grid = get_tracer(X3d,Y3d,Z3d,Bx3d,By3d,Bz3d,nsteps = 4000,step_size = 5e-4)
        
            # Trace field lines
            print("beginning tracing of",len(point_seeds),"lines!")
            tracer.trace(point_seeds, grid)
            print("done!")

            # Iterate through each trace
            for iseed, seed in enumerate(point_seeds):
                
                # Get x,y,z
                trace = tracer.xs[iseed]
                trace_x = trace[:,0]
                trace_y = trace[:,1]
                trace_z = trace[:,2]

                # Report progress
                if iseed%2e3 == 0:
                    print("Checked",iseed,"out of",len(point_seeds)) 
        
                # Filter to closed field lines: both ends within 2 R_M of the origin
                if (np.sum(trace[0,:]**2)<2**2) and (np.sum(trace[-1,:]**2)<2**2):
                    
                    # Line integral of the integrand with the trapezoidal rule: each segment length
                    # times the mean integrand at its two endpoints. (Summing the integrand over all
                    # points and multiplying by the total length instead would scale with the number of
                    # trace points.) If background_field['S_planet'] was made with that older form, regenerate it.
                    diffs = np.diff(trace, axis=0)
                    l = np.sqrt(np.sum(diffs**2, axis=1))
                    S_interp = S_interpolator(trace)
                    S_list[iseed] = np.nansum(l*0.5*(S_interp[1:]+S_interp[:-1])) #[(nPa)^5/3 R_M / nT]
        
            S_array = S_list.reshape(mlat.shape)
        
            # Compute FACs: J_par = J.b [nA/m^2] and the field-aligned current vector J_par*b.
            # The product (J.b) b does not change if b flips sign, so no abs() is needed.
            J_para = (bx*Jx3d + by*Jy3d + bz*Jz3d)*1e9 # [nA/m^2]
            J_para_x = J_para*bx
            J_para_y = J_para*by
            J_para_z = J_para*bz
    
            # Radial (from dipole center) component of the FAC vector, interpolated onto the seeds
            FAC_rhat = (J_para_x * X3d + J_para_y * Y3d + J_para_z * ZMSM3d) / (RMSM3d)
            FAC_interpolator = RegularGridInterpolator((X3d[0,:,0], Y3d[:,0,0], Z3d[0,0,:]), np.swapaxes(FAC_rhat,0,1), bounds_error=False, fill_value=None)
            FAC_list = FAC_interpolator(point_seeds)
            FAC_array = FAC_list.reshape(mlat.shape)

        # Plot
        fig,axs = plt.subplots(figsize=(16,10),ncols=2,nrows=2)

        # Compute differences from background field (-1 marks open field lines in either map)
        if (S_array.shape == background_field['S_planet'].shape):
            S_S0 = S_array/background_field['S_planet']
            S_S0[(S_array == -1) | (background_field['S_planet'] == -1)] = -1

            delta_FAC = FAC_array-background_field['FAC']
        else:
            print("Shape mismatch between background field average and this one")
            S_S0 = background_field['S_planet']
            delta_FAC = background_field['FAC']
    
        # Plot S (log of -1 entries is NaN, so open field lines render blank)
        S_lims = [-2,5]
        S_plot = axs[0,0].imshow(np.log(S_array),extent = [np.min(long),np.max(long),np.min(mlat),np.max(mlat)],vmin=S_lims[0],vmax=S_lims[1])
        deltaS_plot = axs[1,0].imshow(np.log(S_S0), extent = [np.min(long),np.max(long),np.min(mlat),np.max(mlat)],
                                      cmap = 'rainbow', vmin=-2,vmax=2)

        # Plot FAC
        FAC_plot = axs[0,1].imshow(FAC_array,extent = [np.min(long),np.max(long),np.min(mlat),np.max(mlat)],vmin=-100,vmax=100,cmap='bwr')
        deltaFAC_plot = axs[1,1].imshow(delta_FAC,extent = [np.min(long),np.max(long),np.min(mlat),np.max(mlat)]
                                        ,vmin=-100,vmax=100,cmap='bwr')

        for row in [0,1]:
            for col in [0,1]:
                axs[row,col].set_xlabel("Longitude [deg]",fontsize=12)
                axs[row,col].set_ylabel("Magnetic latitude [deg]",fontsize=12)
                axs[row,col].grid()
                axs[row,col].set_xlim(120,240)
                axs[row,col].set_ylim(90 - theta_lims[1]*180/np.pi,90 - theta_lims[0]*180/np.pi)
                axs[row,col].set_aspect(2)

        axs[0,0].set_title(str("Flux tube entropy\nprojected at $R_{MSM}$ = "+str(r_proj)+", t="+time+"s"),fontsize=12)
        axs[1,0].set_title(str("Entropy normalized to background,\nprojected at $R_{MSM}$ = "+str(r_proj)+", t="+time+"s"),fontsize=12)
        axs[0,1].set_title(str("Field-aligned currents\nprojected at $R_{MSM}$ = "+str(r_proj)+", t="+time+"s"),fontsize=12)
        axs[1,1].set_title(str("Field-aligned currents minus background\nprojected at $R_{MSM}$ = "+str(r_proj)+", t="+time+"s"),fontsize=12)
        
        clb1 = fig.colorbar(S_plot, ax = axs[0,0])
        clb1.ax.set_title("$log(S)$")
        clb2 = fig.colorbar(deltaS_plot, ax = axs[1,0])
        clb2.ax.set_title("$log(S/S_0)$")
        clb3 = fig.colorbar(FAC_plot, ax = axs[0,1])
        clb3.ax.set_title("$J_{\\parallel,r}$\n[nA/m$^2$]")
        clb4 = fig.colorbar(deltaFAC_plot, ax = axs[1,1])
        clb4.ax.set_title("$J_{\\parallel,r} - J_{\\parallel,0}$\n[nA/m$^2$]")

        # Save S data; the with-block closes the file after dumping
        if read_data:
            S_path = str(dir+"S_data/S_t_"+'{:06.2f}'.format(round(float(time),2)))
            with open(S_path, 'wb') as S_file:
                pickle.dump({'mlat':mlat,'long':long,'R':r_proj,'S':S_array,'FAC':FAC_array}, S_file)

    if plot_preset == "flux_tube_visualizer":

        # Trace field lines seeded at the dipolarization-front (DF) mask and draw them in 3D with the
        # current sheet and the projected entropy map. Mark their northern-hemisphere footpoints, and
        # append the mean footpoint entropy (and background value) to a timeseries panel.
        if iter == 0:
            t_ls = []
            S_ls = []
            S0_ls = []
        
        if read_data:

            # Load in DF mask
            DF_name = "run3_paper_DDF"
            # Raw string: same pattern text as the original escaped literal, without invalid-escape warnings
            exampleDFfiles = get_files(str(dir+"DFs/"),key=r"DF\_filtered\_run3\_paper\_DDF\_t\_...\...",read_time = True)
            
            exampleDFfile = str(exampleDFfiles[time])
            with open(dir+"DFs/"+exampleDFfile, 'rb') as f:
                print("reading DF example mask:",str(dir+"DFs/"+exampleDFfile))
                DF_mask = pickle.load(f) 
            
            # Load in variables
            X3d = data3d["X"]
            Xcs = datacs["X"]
            Y3d = data3d["Y"]
            Ycs = datacs["Y"]
            Z3d = data3d["Z"]
            ZMSM = Z3d - 0.2
            Zcs = datacs["Z"]
            RMSM = np.sqrt(X3d**2+Y3d**2+ZMSM**2)
            Bx3d = data3d["Bx"] # [nT]
            By3d = data3d["By"] # [nT]
            Bz3d = data3d["Bz"] # [nT]
            Bzcs = datacs["Bz"]

            # Entropy map (presumably the pickle written by the entropy_FAC_map preset — same keys)
            r_proj = entropy_map['R']
            mlat = entropy_map['mlat']
            long = entropy_map['long']
            S = entropy_map['S']
            S0 = background_field['S_planet']
            FAC_array = entropy_map['FAC']

            footpoint_map = np.zeros_like(S) - 1 # -1 is no DF footpoint, +1 means DF lands there

            # Convert the entropy map back into x,y,z
            theta = (90 - mlat)*np.pi/180
            phi = long * np.pi/180
            x_proj = (r_proj) * np.sin(theta)*np.cos(phi)
            y_proj = (r_proj) * np.sin(theta)*np.sin(phi)
            z_proj = (r_proj) * np.cos(theta) + 0.2 # Convert from MSM to MSO

            # Entropy-based seed selection (not used for tracing; seeds come from DF_mask below)
            seed_mask = (np.log(S) < 1.5) & (np.log(S) > -2) & (long > 160) & (long<180)

            # Flatten a down-sampled current sheet to interpolate Z(x,y)
            down_resolve=16
            cs_points = np.column_stack((Xcs[::down_resolve,::down_resolve].ravel(), Ycs[::down_resolve,::down_resolve].ravel()))
            cs_values = Zcs[::down_resolve,::down_resolve].ravel()
    
            # Cut off current sheet inside entropy map (these assignments write into the arrays taken from datacs)
            Xcs[Xcs**2+Ycs**2+(Zcs-0.2)**2 <= r_proj**2] = np.nan
            Bzcs[Xcs < x_region[0]] = np.nan
            Bzcs[Ycs < y_region[0]] = np.nan

            # Set B to nan inside r_proj, to avoid tracing field lines beyond there
            Bx3d[RMSM < r_proj*0.99] = np.nan
            By3d[RMSM < r_proj*0.99] = np.nan
            Bz3d[RMSM < r_proj*0.99] = np.nan
    
            # Set up field line tracing, seeded at every DF cell
            seeds = np.array([np.ravel(X3d[DF_mask]), np.ravel(Y3d[DF_mask]), np.ravel(Z3d[DF_mask])]).T
            tracer,grid = get_tracer(X3d,Y3d,Z3d,Bx3d,By3d,Bz3d)
            print("beginning tracing of",len(seeds),"lines!")
            tracer.trace(seeds, grid)
            print("done!")
        
        fig = plt.figure(figsize=(10,12))
        gs = GridSpec(2, 1, height_ratios=[6, 1])
        ax = fig.add_subplot(gs[0], projection="3d",computed_zorder=False)
        ax2d = fig.add_subplot(gs[1])
        
        # Show entropy map projected onto planet
        S_lims = [-2,5]
        plot_colored_surface(ax, x_proj, y_proj, z_proj, np.log(S), vmin = S_lims[0], vmax = S_lims[1], cmap = 'viridis', alpha = 1, zorder = 2, 
                             shading = False, nan_threshold = -1e5)
        # Show flipped image for lower hemisphere (mirrored about z = 0.2)
        plot_colored_surface(ax, x_proj, y_proj, 0.4-z_proj, np.log(S), vmin = S_lims[0], vmax = S_lims[1], cmap = 'viridis', alpha = 1, zorder = 2, 
                             shading = False, nan_threshold = -1e5)

        # Show DF, split into cells above / below the current sheet
        above_mask = create_above_surface_mask(Xcs, Ycs, Zcs, X3d, Y3d, Z3d) 
        above_DF_mask = above_mask & DF_mask
        below_DF_mask = np.invert(above_mask) & DF_mask
        DF_plot1 = ax.scatter(X3d[above_DF_mask],Y3d[above_DF_mask],Z3d[above_DF_mask],c=(-Z3d[above_DF_mask]),cmap='Greens',vmin=-0.6,vmax=0.1,
                         s=15,zorder=3,alpha = 1)
        DF_plot2 = ax.scatter(X3d[below_DF_mask],Y3d[below_DF_mask],Z3d[below_DF_mask],c=(-Z3d[below_DF_mask]),cmap='Greens',vmin=-0.5,vmax=0.5,
                         s=10,zorder=2.45,alpha = 0.5)

        # Show current sheet
        Bz_lims = [-100,100]
        plot_colored_surface(ax, Xcs, Ycs, Zcs, Bzcs, vmin = Bz_lims[0], vmax = Bz_lims[1], cmap = 'bwr', alpha = 0.4, zorder = 2.5, 
                             shading = True, nan_threshold = -1e5)

        # Show DF field lines
        lw = 0.5
        zorder_abo = 10 
        zorder_bel = 2.20
        fl_color = 'darkgreen'
        fl_alpha = 0.25

        # Current-sheet height Z(x,y), triangulated once and reused for every trace. This gives the same
        # values as griddata(cs_points, cs_values, xi, method='linear'), i.e. NaN outside the convex hull.
        cs_height = interpolate.LinearNDInterpolator(cs_points, cs_values)
        
        for iseed, seed in enumerate(seeds):
        
            # Unpack line data
            trace = tracer.xs[iseed]
            trace_x = trace[:,0]
            trace_y = trace[:,1]
            trace_z = trace[:,2]
    
            # Work out which parts of the trace are above and below the sheet (points where the
            # sheet height is NaN belong to neither set)
            z_sheet = cs_height(trace[:,0:2])
            above = np.where(trace_z>=z_sheet)[0]
            below = np.where(trace_z<z_sheet)[0]
    
            # Plot each contiguous run of indices as its own line, so jumps are not bridged.
            # Splitting at the index gaps keeps every point of every run, including the last one.
            for run in np.split(above, np.where(np.diff(above)>1)[0]+1):
                if len(run) > 1:
                    ax.plot(trace_x[run],trace_y[run],trace_z[run],
                           color=fl_color,lw=lw,zorder=zorder_abo,alpha=fl_alpha) 
            for run in np.split(below, np.where(np.diff(below)>1)[0]+1):
                if len(run) > 1:
                    ax.plot(trace_x[run],trace_y[run],trace_z[run],
                           color=fl_color,lw=lw/2,zorder=zorder_bel,linestyle='dashed',alpha=fl_alpha) 

            # For the entropy calculation, keep closed field lines: both ends within 1.5 R_M of the origin
            if (np.sum(trace[0,:]**2)<1.5**2) and (np.sum(trace[-1,:]**2)<1.5**2):
                
                # Find north hemisphere footpoint coords (MSO z above the 0.2 R_M offset)
                if trace[0,2] > 0.2:
                    endpoint = trace[0,:]
                else:
                    endpoint = trace[-1,:]

                # Mark the nearest point of the projected grid. Indices are (row, column) of mlat/long;
                # rows are latitude if the map was built with np.meshgrid(long_axis, lat_axis) — confirm.
                dist_map = np.sqrt((x_proj-endpoint[0])**2+(y_proj-endpoint[1])**2+(z_proj-endpoint[2])**2)
                ilat,ilong = np.unravel_index(np.argmin(dist_map), dist_map.shape)
                footpoint_map[ilat,ilong] = 1

        # Overlay the DF footpoints on the projected map
        plot_colored_surface(ax, x_proj, y_proj, z_proj, footpoint_map, vmin = 0, vmax = 1, cmap = 'Greens', alpha = 1, zorder = 2.5, 
                             shading = False, nan_threshold = 0)

        ax.set_xlim(x_region)
        ax.set_ylim(y_region)
        ax.set_zlim(z_region)
        ax.set_aspect('equal')

        ax.set_title(str("DF fieldlines at t = "+time+"s"))
    
        ax.set_xlabel("X [$R_M$]")
        ax.set_ylabel("Y [$R_M$]")
        ax.set_zlabel("Z [$R_M$]")

        # Add a color bar for S
        norm = plt.Normalize(*S_lims)
        m = cm.ScalarMappable(cmap=cm.viridis, norm=norm)
        m.set_array(S)
        clb1 = fig.colorbar(m, ax=ax, shrink=0.1, aspect=7)
        clb1.ax.tick_params(labelsize=12)
        clb1.ax.set_title('$log(S)$',fontsize=12)
        # Add a color bar for Bz
        norm = plt.Normalize(*Bz_lims)
        m = cm.ScalarMappable(cmap=cm.bwr, norm=norm)
        m.set_array(Bzcs)
        clb2 = fig.colorbar(m, ax=ax, shrink=0.1, aspect=7)
        clb2.ax.tick_params(labelsize=12)
        clb2.ax.set_title('$B_z$',fontsize=12)

        # Show planet
        plot_sphere(ax,radius=1,color='lightgrey',alpha=0.9,zorder=1,xlims=x_region,ylims=y_region,zlims=z_region)
        plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=0.75,xlims=x_region,ylims=y_region,zlims=z_region)
        
        # Set viewing angle
        ax.view_init(elev=15, azim=165)

        # Save results for timeseries (mean over DF footpoints; NaN if there are none)
        t_ls.append(float(time))
        S_ls.append(np.mean(S[footpoint_map>0]))
        S0_ls.append(np.mean(S0[footpoint_map>0]))

        # Work on timeseries plot (log scale, matching the S colour limits)
        ax2d.set_xlim(t_bound)
        ax2d.set_ylim(S_lims)

        ax2d.plot(t_ls,np.log(S_ls),color='black')
        ax2d.plot(t_ls,np.log(S0_ls),color='black',linestyle='dashed')

        ax2d.set_ylabel("log(S), S in [(nPa)$^{5/3} R_M$/nT]")

        plt.tight_layout()
        plt.subplots_adjust(hspace=0)  # Decrease this value to bring subplots closer

    if plot_preset == "pressure_balance":

        # Plot ion, electron, magnetic and perpendicular dynamic pressure along the X line through
        # (y, z) = (loc[1], loc[2]), with B_z on a twin axis.
        if read_data:

            Y3d = data3d["Y"]
            Z3d = data3d["Z"]

            # Find index of line to plot (first grid point beyond loc in y and in z; arrays are [y, x, z])
            iy = np.where(Y3d[:,0,0]>loc[1])[0][0]
            iz = np.where(Z3d[0,0,:]>loc[2])[0][0]

            # Read in other data in SI
            X = data3d["X"][iy,:,iz] # [R_M]
            R = np.sqrt(X**2+Y3d[iy,:,iz]**2+Z3d[iy,:,iz]**2)
            p_e = (data3d["pxxS0"]+data3d["pyyS0"]+data3d["pzzS0"])[iy,:,iz] / 3 * 1e-9 #[Pa]
            p_i = (data3d["pxxS1"]+data3d["pyyS1"]+data3d["pzzS1"])[iy,:,iz] / 3 * 1e-9 #[Pa]
            Bx = data3d["Bx"][iy,:,iz] * 1e-9 # [T]
            By = data3d["By"][iy,:,iz] * 1e-9 # [T]
            Bz = data3d["Bz"][iy,:,iz] * 1e-9 # [T]
            n = data3d["rhoS1"][iy,:,iz] * 1e6 # [1/m^3]
            rho = m_p * n
            ux = data3d["uxS1"][iy,:,iz] * 1e3 # [m/s]
            uy = data3d["uyS1"][iy,:,iz] * 1e3 # [m/s]
            uz = data3d["uzS1"][iy,:,iz] * 1e3 # [m/s]

            # Compute magnetic field unit vector
            B_mag = np.sqrt(Bx**2+By**2+Bz**2) # [T]
            bx = Bx/B_mag
            by = By/B_mag
            bz = Bz/B_mag

            # Compute field-aligned velocity
            u_para = ux*bx+uy*by+uz*bz

            # Perpendicular velocity u - u_para*b and its magnitude. |u_perp| = sqrt(|u|^2 - u_para^2);
            # the difference of magnitudes |u| - |u_para| underestimates it.
            u_mag = np.sqrt(ux**2+uy**2+uz**2)
            ux_perp = ux - u_para*bx
            uy_perp = uy - u_para*by
            uz_perp = uz - u_para*bz
            u_perp = np.sqrt(ux_perp**2+uy_perp**2+uz_perp**2)

            # Compute magnetic pressure
            p_mag = B_mag**2 / (2*mu_0)

            # Compute dynamic pressure from the perpendicular flow
            p_dyn = rho*u_perp**2

            # Plot pressures in nPa
            fig,ax = plt.subplots(figsize=(9,6))

            ax.plot(X,p_i*1e9,label='$p_i$')
            ax.plot(X,p_e*1e9,label='$p_e$')
            ax.plot(X,p_mag*1e9,label='$p_{mag}$')
            ax.plot(X,p_dyn*1e9,label='$p_{dyn}$')

            # Show B_z on a twin axis
            ax2 = ax.twinx()
            ax2.plot(X,Bz*1e9,color='black')
            ax2.axhline(y=0,color='black',linestyle='dashed')

            ax.legend()
            ax.grid()
            ax.set_xlabel("X [$R_M$]")
            ax.set_ylabel("p [nPa]")
            ax2.set_ylabel("$B_z$ [nT]")

            ax.set_xlim(x_region[0],np.max(X[R>1.1]))
            ax.set_ylim(0,4)
            ax2.set_ylim(-20,120)

            ax.set_title(str("Pressure balance along y = "+str(round(Y3d[iy,0,iz],2))+", z = "+str(round(Z3d[iy,0,iz],2))+", at t = "+str(time)))

    # Update iterator
    iter +=1

    # Save
    if plot_preset in ["DF_example_extraction","DF_example_filtering","DF_example_visualizer"]:
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+DF_name+"_"+"%.2f"%round(float(time),2)+'.png'),bbox_inches='tight',dpi=300)
        plt.show()
        plt.close(fig)
    elif plot_preset in ["B_timeseries","total_current","current_spectra","reconnection_spectra","xz_slice"]:
        # For timeseries plotting, we only want to save at the end
        continue
    else:
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+"%.2f"%round(float(time),2)+'.png'),bbox_inches='tight',dpi=300)
        plt.show()
        plt.close(fig)

# TIMESERIES PLOTTING
# Some plotting presets accumulate data over every frame and produce a single figure only after the loop; they are handled here.

if plot_preset == "B_timeseries":

    if file_found and (not read_data):
        t_ls = timeseries['t']
        Bx_ls = timeseries['Bx']
        By_ls = timeseries['By']
        Bz_ls = timeseries['Bz']
        Bx_avg = timeseries['Bx0']
        By_avg = timeseries['By0']
        Bz_avg = timeseries['Bz0']
        uix_ls = timeseries['uix']
        uiy_ls = timeseries['uiy']
        uiz_ls = timeseries['uiz']
        uex_ls = timeseries['uex']
        uey_ls = timeseries['uey']
        uez_ls = timeseries['uez']
        pi_ls = timeseries['pi']
        pe_ls = timeseries['pe']
        n_ls = timeseries['n']
    else:
        # Convert to arrays
        t_ls = np.array(t_ls)
        Bx_ls = np.array(Bx_ls)
        By_ls = np.array(By_ls)
        Bz_ls = np.array(Bz_ls)
        pi_ls = np.array(pi_ls)
        pe_ls = np.array(pe_ls)
        n_ls = np.array(n_ls)
        uix_ls = np.array(uix_ls)
        uiy_ls = np.array(uiy_ls)
        uiz_ls = np.array(uiz_ls)
        uex_ls = np.array(uex_ls)
        uey_ls = np.array(uey_ls)
        uez_ls = np.array(uez_ls)
    
        # Load in background field data
        Bx_avg = background_field["Bx"][iy,ix,iz]
        By_avg = background_field["By"][iy,ix,iz]
        Bz_avg = background_field["Bz"][iy,ix,iz]

    Ti_ls = pi_ls*1e-9 / (n_ls*1e6*k_b) / 11605
    Te_ls = pe_ls*1e-9 / (n_ls*1e6*k_b) / 11605

    fig,axs = plt.subplots(figsize = (15,8),nrows = 4, gridspec_kw={'height_ratios': [4,3,2,2]},sharex=True)
    fig.subplots_adjust(hspace=0)

    # Show magnetic fields
    axs[0].plot(t_ls,Bx_ls,color='red',label = '$B_x$')
    axs[0].axhline(y=Bx_avg,color='red',linestyle='dashed',lw = 1)
    axs[0].plot(t_ls,By_ls,color='green',label = '$B_y$')
    axs[0].axhline(y=By_avg,color='green',linestyle='dashed',lw = 1)
    axs[0].plot(t_ls,Bz_ls,color='blue',label = '$B_z$')
    axs[0].axhline(y=Bz_avg,color='blue',linestyle='dashed',lw = 1)
    #axs[0].plot(t_ls,np.sqrt(Bx_ls**2+By_ls**2+Bz_ls**2),color='black',label = '$B_t$')
    #axs[0].axhline(y=np.sqrt(Bx_avg**2+By_avg**2+Bz_avg**2),color='black',linestyle='dashed',lw = 1)
    axs[0].legend(loc='lower right',ncols=4,fontsize=15)
    axs[0].set_ylabel("B [nT]",fontsize=15)
    axs[0].set_title(str("Magnetic field at "+str(loc)),fontsize=15)
    #axs[0].set_ylim(-20,140)
    
    # Show pressure
    axs[1].plot(t_ls,Ti_ls,label = "$T_i$")
    axs[1].plot(t_ls,Te_ls,label = "$T_e$")
    axs1_2 = axs[1].twinx()
    axs1_2.plot(t_ls,n_ls,color='black')
    axs[1].legend(loc='lower right',ncols=4,fontsize=15)
    axs[1].set_ylabel("T [eV]",fontsize=15)
    axs1_2.set_ylabel("n [amu/cc]",fontsize=15)

    # Show ion velocity
    axs[2].plot(t_ls,uix_ls,color='red',label='$u_x$')
    axs[2].plot(t_ls,uiy_ls,color='green',label='$u_y$')
    axs[2].plot(t_ls,uiz_ls,color='blue',label='$u_z$')
    axs[2].legend(loc='lower right',ncols=4,fontsize=15)
    axs[2].set_ylabel("$u_i$ [km/s]",fontsize=15)

    # Show electron velocity
    axs[3].plot(t_ls,uex_ls,color='red',label='$u_x$')
    axs[3].plot(t_ls,uey_ls,color='green',label='$u_y$')
    axs[3].plot(t_ls,uez_ls,color='blue',label='$u_z$')
    axs[3].legend(loc='lower right',ncols=4,fontsize=15)
    axs[3].set_ylabel("$u_e$ [km/s]",fontsize=15)

    for i,axi in enumerate(axs):
        axi.tick_params(axis='both',labelsize=15)
        axi.set_xlim(t_bound[0],t_bound[1]-dt)
        axi.set_xlabel("Time [s]",fontsize=15)
        axi.grid()

    fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_["+str(loc[0])+","+str(loc[1])+","+str(loc[2])+'].png'),bbox_inches='tight',dpi=300)
    plt.show()
    plt.close(fig)

    # Save the timeseries data
    if read_data and (not file_found):
        pickle.dump({"t":t_ls,"Bx":Bx_ls,"By":By_ls,"Bz":Bz_ls,"Bx0":Bx_avg,"By0":By_avg,"Bz0":Bz_avg,
                     "uix":uix_ls,"uiy":uiy_ls,"uiz":uiz_ls,"uex":uex_ls,"uey":uey_ls,"uez":uez_ls,
                    "pi":pi_ls,"pe":pe_ls,"n":n_ls},
                    open(str(dir+"timeseries/"+plot_preset+"_["+str(loc[0])+","+str(loc[1])+","+str(loc[2])+']'), 'wb') )


if plot_preset == "DF_example_visualizer":
    # Energy-density budget of the sampled region over time (log y axis): ion and
    # electron kinetic, parallel/perpendicular thermal and magnetic terms, plus
    # their sum as a dashed black line.
    i_kinetic_ls, e_kinetic_ls = np.array(i_kinetic_ls), np.array(e_kinetic_ls)
    i_para_thermal_ls, i_perp_thermal_ls = np.array(i_para_thermal_ls), np.array(i_perp_thermal_ls)
    e_para_thermal_ls, e_perp_thermal_ls = np.array(e_para_thermal_ls), np.array(e_perp_thermal_ls)
    # The current series are converted as well, but are not drawn in this figure
    Jy_east_ls, Jy_west_ls = np.array(Jy_east_ls), np.array(Jy_west_ls)
    mag_ls = np.array(mag_ls)

    # (series, legend label), in drawing order
    energy_terms = [
        (i_kinetic_ls, "K [ion]"),
        (e_kinetic_ls, "K [electron]"),
        (i_para_thermal_ls, "T [ion, para]"),
        (i_perp_thermal_ls, "T [ion, perp]"),
        (e_para_thermal_ls, "T [electron, para]"),
        (e_perp_thermal_ls, "T [electron, perp]"),
        (mag_ls, "B"),
    ]

    fig, ax = plt.subplots(figsize=(10,6), ncols=1)
    for term, term_label in energy_terms:
        ax.plot(t_ls, term, label=term_label)

    total_energy = (i_kinetic_ls + e_kinetic_ls + i_para_thermal_ls + i_perp_thermal_ls
                    + e_para_thermal_ls + e_perp_thermal_ls + mag_ls)
    ax.plot(t_ls, total_energy, color='black', linestyle='dashed', label='total')

    ax.set_yscale("log")

    region_text = (str(x_region[0])+"<x<"+str(x_region[1])+", "
                   +str(y_region[0])+"<y<"+str(y_region[1])+", "
                   +str(z_region[0])+"<z<"+str(z_region[1]))
    ax.set_xlabel("Time [s]", fontsize=15)
    ax.set_ylabel("Avg. energy density [J/m$^3$]", fontsize=15)
    ax.set_title(str("Energy density within "+region_text), fontsize=15)

    ax.tick_params(axis='both', labelsize=15)
    ax.legend(ncol=2, loc='lower right')

    time_tag = "%.2f"%round(float(time),2)
    fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_timeseries_"+DF_name+"_"+time_tag+'.png'),bbox_inches='tight',dpi=300)
    plt.show()
    plt.close(fig)

if plot_preset == "total_current":
    # Time series of the summed eastward/westward currents and upward/downward
    # field-aligned currents within R_region and y_region. Westward and downward
    # sums are drawn as magnitudes (np.abs).
    t_ls, J_E_ls, J_W_ls = np.array(t_ls), np.array(J_E_ls), np.array(J_W_ls)
    J_up_FAC_ls, J_down_FAC_ls = np.array(J_up_FAC_ls), np.array(J_down_FAC_ls)

    # (series, line style), in drawing order; dashed lines are the FACs
    current_curves = [
        (J_E_ls, dict(label='Eastward', color='tab:blue')),
        (np.abs(J_W_ls), dict(label='Westward', color='tab:orange')),
        (J_up_FAC_ls, dict(label='upward FAC', color='tab:blue', linestyle='dashed')),
        (np.abs(J_down_FAC_ls), dict(label='downward FAC', color='tab:orange', linestyle='dashed')),
    ]

    fig, ax = plt.subplots(figsize=(8,6), ncols=1)
    for series, line_style in current_curves:
        ax.plot(t_ls, series, **line_style)

    ax.set_xlabel("Time [s]", fontsize=15)
    ax.set_ylabel("Sum current density [nA/m$^2$]", fontsize=15)
    ax.set_title(str("Currents within "+str(R_region[0])+"<R<"+str(R_region[1])+" and  "+str(y_region[0])+"<y<"+str(y_region[1])), fontsize=15)

    ax.legend()
    ax.tick_params(axis='both', labelsize=15)
    ax.set_xlim(t_ls[0], t_ls[-1])

    time_tag = "%.2f"%round(float(time),2)
    fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+time_tag+'.png'),bbox_inches='tight',dpi=300)
    plt.show()
    plt.close(fig)

if plot_preset == "Jy_components":
    # Time series of the cross-tail current and its components (perpendicular,
    # diamagnetic, ion inertial, change in perpendicular). Values are scaled by
    # 1e-3 for the "[kA]" axis, which assumes the inputs are in A — confirm.
    t_ls = np.array(t_ls)
    Jy_ls = np.array(Jy_ls)
    Jperp_y_ls = np.array(Jperp_y_ls)
    Jy_dia_ls = np.array(Jy_dia_ls)
    Jy_inrt_i_ls = np.array(Jy_inrt_i_ls)
    delta_Jperp_y_ls = np.array(delta_Jperp_y_ls)

    # (series, legend label), in drawing order. The last label is a raw string:
    # "\D" is an invalid escape sequence in a normal literal (Deprecation/SyntaxWarning).
    components = [
        (Jy_ls, "$J_y$"),
        (Jperp_y_ls, "$J_{y,perp}$"),
        (Jy_dia_ls, "$J_{y,dia}$"),
        (Jy_inrt_i_ls, "$J_{y,inrt}$"),
        (delta_Jperp_y_ls, r"$\Delta J_{y,perp}$"),
    ]

    fig,ax = plt.subplots(figsize = (10,6),ncols=1)
    for series, series_label in components:
        ax.plot(t_ls, series*1e-3, label=series_label)

    #ax.set_ylim(-500,500)

    ax.set_xlabel("Time [s]",fontsize = 15)
    ax.set_ylabel("Total current [kA]",fontsize = 15)
    ax.set_title(str("Cross-tail currents"),fontsize=15)

    ax.tick_params(axis='both',labelsize=15)
    ax.legend(ncol=2,loc='lower right')

    fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_timeseries_"+"%.2f"%round(float(time),2)+'.png'),bbox_inches='tight',dpi=300)
    plt.show()
    plt.close(fig)

if plot_preset == "current_spectra":
    # 2x2 x-t spectrograms of the cross-tail current components (total, perp,
    # diamagnetic, inertial) in the Y = yplane plane, scaled by 1e-3 for the [kA]
    # colorbar and drawn on one shared symmetric colour scale.
    J_lims = [-20,20]
    aspect = 'auto'
    cmap = 'bwr'

    fig,axs = plt.subplots(figsize=(18,12),nrows = 2, ncols = 2)

    # Every panel spans the same x range (columns) and time range (rows)
    xt_extent = [x_array[0,0],x_array[0,-1],t_array[0,0],t_array[-1,0]]
    plane_text = " in Y = "+str(round(yplane,1))+" plane"

    # (axes, title prefix, current array), in drawing order
    panels = [
        (axs[0,0], "Total $I_y$", Jy_array),
        (axs[0,1], "Total $I_{y,perp}$", Jy_perp_array),
        (axs[1,0], "Total $I_{y,dia}$", Jy_dia_array),
        (axs[1,1], "Total $I_{y,inrt}$", Jy_inrt_array),
    ]
    images = []
    for panel_ax, title_prefix, current in panels:
        panel_ax.set_title(str(title_prefix+plane_text), fontsize=15)
        images.append(panel_ax.imshow(current*1e-3, vmin=J_lims[0], vmax=J_lims[1], cmap=cmap,
                                      origin='lower', extent=xt_extent, aspect=aspect))
    Iy_plot, Iy_perp_plot, Iy_dia_plot, Iy_inrt_plot = images

    for axi in np.ravel(axs):
        axi.invert_yaxis()
        axi.set_xlabel("$X_{MSM}$ [$R_M$]",fontsize=15)
        axi.set_ylabel("$t$ [sec]",fontsize=15)
        axi.tick_params(axis='both',labelsize=15)
        axi.set_xlim(*x_region)

    # One colorbar serves all panels since they share vmin/vmax
    clb1 = fig.colorbar(Iy_plot,ax=axs[:,:],shrink=0.5)
    clb1.ax.set_title("$I_y$ [kA]")

    time_tag = "%.2f"%round(float(time),2)
    fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+time_tag+'.png'),bbox_inches='tight',dpi=300)
    plt.show()
    plt.close(fig)

if plot_preset == "reconnection_spectra":
    # Spectrograms of reconnection activity within z_region. Left column: number of
    # points with S>5; right column: number of sites fulfilling criteria c1 and c2
    # (per the colorbar labels/titles). Top row is against x, bottom row against y.
    # The arrays are cached as pickles in `dir`: read back when read_data is False,
    # written out after plotting when read_data is True.

    plot_DFs = False  # overlay DF start points read from dir+"df_data5"

    def _spectra_path(prefix):
        # Cache file of one spectra array for the current `time`
        return str(dir+prefix+"%.2f"%round(float(time),2)+".pkl")

    def _load_spectra(prefix):
        # pickle.load can execute arbitrary code: only load caches this notebook wrote
        with open(_spectra_path(prefix), 'rb') as f:
            return pickle.load(f)

    # Load in S data in this case
    if not read_data:
        t_array = _load_spectra("S_t_array_")
        x_array = _load_spectra("S_x_array_")
        y_array = _load_spectra("S_y_array_")
        S_active_array_x = _load_spectra("S_active_array_x_")
        S_active_array_y = _load_spectra("S_active_array_y_")
        c_sites_array_x = _load_spectra("c_sites_array_x_")
        c_sites_array_y = _load_spectra("c_sites_array_y_")

    # Fixed colour limits (auto alternatives kept in the trailing comments)
    S_active_lims = [0,20]#np.max([np.max(S_active_array_x),np.max(S_active_array_y)])]
    c_sites_lims = [0,50]#np.max([np.max(c_sites_array_x),np.max(c_sites_array_y)])]
    aspect = 'auto'
    cmap = 'viridis'

    fig,axs = plt.subplots(figsize=(18,24),nrows = 2, ncols = 2)

    S_active_x_plot = axs[0,0].imshow(S_active_array_x,vmin = S_active_lims[0], vmax = S_active_lims[1], cmap = cmap, origin = 'lower',
              extent=[x_array[0,0],x_array[0,-1],t_array[0,0],t_array[-1,0]], aspect = aspect)
    S_active_y_plot = axs[1,0].imshow(S_active_array_y,vmin = S_active_lims[0], vmax = S_active_lims[1], cmap = cmap, origin = 'lower',
              extent=[y_array[0,0],y_array[0,-1],t_array[0,0],t_array[-1,0]], aspect = aspect)
    c_sites_x_plot = axs[0,1].imshow(c_sites_array_x,vmin = c_sites_lims[0], vmax = c_sites_lims[1], cmap = cmap, origin = 'lower',
              extent=[x_array[0,0],x_array[0,-1],t_array[0,0],t_array[-1,0]], aspect = aspect)
    c_sites_y_plot = axs[1,1].imshow(c_sites_array_y,vmin = c_sites_lims[0], vmax = c_sites_lims[1], cmap = cmap, origin = 'lower',
              extent=[y_array[0,0],y_array[0,-1],t_array[0,0],t_array[-1,0]], aspect = aspect)

    if plot_DFs:
        with open(dir+"df_data5", 'rb') as f:
            df_data = pickle.load(f)

        # t_array is 2-D with time along axis 0 (see the extents above), so compare
        # against scalar first/last times. Comparing with whole rows (t_array[0])
        # produced an array whose truth value is ambiguous inside `and`.
        t_first, t_last = t_array[0,0], t_array[-1,0]
        for key, value in df_data.items():
            # Long-lived (>30 samples) and large enough DFs only. 64**2 presumably
            # converts area to grid cells at 64 cells per R_M — confirm.
            # (also tried: and np.mean(value["X"].diff())>0 and value["X"][0]>-2.5)
            if len(value)>30 and np.mean(value["area"])*64**2>30:
                # Moving towards +x on average, and starting inside the plotted time range
                if np.mean(value["X"].diff())>0 and t_first <= value["time"][0] <= t_last:
                    # Light blue if frac_ropey ever exceeds 0.5, red otherwise
                    marker_color = 'lightblue' if np.max(value['frac_ropey'])>0.5 else 'red'
                    for panel, coord in ((axs[0,0],"X"),(axs[1,0],"Y"),(axs[0,1],"X"),(axs[1,1],"Y")):
                        panel.scatter(value[coord][0],value["time"][0],marker="*",color=marker_color)

    for axi in np.ravel(axs):
        axi.invert_yaxis()
        axi.set_ylabel("$t$ [sec]",fontsize=15)
        axi.tick_params(axis='both',labelsize=15)
        axi.grid(color='white')
        
    axs[0,0].set_xlabel("$X_{MSM}$ [$R_M$]",fontsize=15)
    axs[0,0].set_xlim(-3,-1.25)
    axs[0,1].set_xlabel("$X_{MSM}$ [$R_M$]",fontsize=15)
    axs[0,1].set_xlim(-3,-1.25)
    axs[1,0].set_xlabel("$Y_{MSM}$ [$R_M$]",fontsize=15)
    axs[1,0].set_xlim(y_region[0],y_region[1])
    axs[1,1].set_xlabel("$Y_{MSM}$ [$R_M$]",fontsize=15)
    axs[1,1].set_xlim(y_region[0],y_region[1])
    
    # One colorbar per column: S counts on the left, site counts on the right
    clb1 = fig.colorbar(S_active_x_plot,ax=axs[:,0],shrink=0.3)
    clb1.ax.set_title("Num. of \n$S>5$")
    clb1 = fig.colorbar(c_sites_x_plot,ax=axs[:,1],shrink=0.3)
    clb1.ax.set_title("Num. of recon.\n sites")

    axs[0,0].set_title(str("Total $S$ in\n"+str(z_region[0])+"<z<"+str(z_region[1])),fontsize=15)
    axs[0,1].set_title(str("Sites fulfilling c1 and c2 in\n"+str(z_region[0])+"<z<"+str(z_region[1])),fontsize=15)

    fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+"%.2f"%round(float(time),2)+'.png'),bbox_inches='tight',dpi=300)
    plt.show()
    plt.close(fig)

    if read_data:
        # Cache the spectra so a later run with read_data=False can skip the analysis.
        # Each file is closed (and flushed) before the next one is written.
        for prefix, spectra in (("S_t_array_", t_array),
                                ("S_x_array_", x_array),
                                ("S_y_array_", y_array),
                                ("S_active_array_x_", S_active_array_x),
                                ("S_active_array_y_", S_active_array_y),
                                ("c_sites_array_x_", c_sites_array_x),
                                ("c_sites_array_y_", c_sites_array_y)):
            with open(_spectra_path(prefix), 'wb') as f:
                pickle.dump(spectra, f)
        

In [None]:
# Leftover inspection cell: displays the current value of `dx`.
# NOTE(review): assumes `dx` already exists in the kernel namespace (it is not defined in this cell).
dx

In [None]:
#### PARTICLE DATA PLOTTING
# !!!!!!!!! REMINDER !!!!!!!!!!!
# Make sure the file codes match the species. From DR_run1 onwards, cut_particle_region0_1
# is electron and *0_2 is ion. Older nightside_particle runs use *0_0 for electron and *0_1
# for ion. This has to be updated by hand.
# NOTE(review): `file` below is region0_0 inside a DR_run1 folder — confirm which species
# that is under the convention above.

####### START USER INPUT ###########
# Run data directory and the cut-particle output to open.
# (Alternative folder: "/Users/atcushen/Documents/MercuryModelling/runs/nightside_v1_run4/ta-2_rerun")
folder = "/Volumes/My Book Duo/runs/DR_run1/alldata"
file = "cut_particle_region0_0_t00000031_n00016670_amrex/"

# Sampling sphere
sample_loc = [-1.2,0,0.15]   # centre of the sphere particles are sampled in
radius = (1.75/64)           # radius of that sphere

# Time range
multi = True                 # loop through files and save multiple panels?
start_time = 210             # earliest time saved in this directory
t_bound = [260.0,261.0]      # start and stop times to loop through
dt = 0.05                    # time step between files

# Velocity-space binning
particle_type = "ion"        # either "ion" or "electron"
ibin_limit = 3000            # |v| limit plotted for ions (older modules also use it for electrons)
ebin_limit = 30000           # |v| limit plotted for electrons
nbins = 128

plot_preset = "static_flythrough"
####### END USER INPUT ###########

dir = folder + "/"
bin_limit = ibin_limit       # older modules read this single limit for both species
""" 
PRESETS:

"Bz_movie": Bz in xy plane 

"vx_vy-vx_vz": plot of raw particle velocity distribution at loc, using standard values for controls. 

"v_para-v_perp2": plot of particle velocity distribution at loc, in parallel / exb and parallel / w coords.

"1D-vdf": 1d VDF plot at loc

"virtual_spectra": energy/count and B field timeseries in t_bound

"static_flythrough": show a spectrogram like virtual_spectra, but for a static field solution from start_pos to end_pos

"""

# The time -> directory-name mapping is built inside get_directories() below.

def get_directories(dir, key=".*cut_particle_region0_0.*"):
    """Return {time: directory name} for subdirectories of `dir` matching `key`, limited to t_bound.

    Subdirectory names in which the regex `key` is found (re.search) are sorted
    lexicographically and assigned times start_time, start_time+dt, ... (rounded
    to 3 decimals), i.e. one directory per output step. The result keeps only
    the steps from floor((t_bound[0]-start_time)/dt) up to (excluding)
    floor((t_bound[1]-start_time)/dt).

    Reads the module-level `dt`, `start_time` and `t_bound`.
    """
    pattern = re.compile(key)
    files = sorted(entry.name for entry in os.scandir(dir)
                   if entry.is_dir() and pattern.search(entry.name) is not None)

    # Now give them the appropriate name for their time
    named_files = {round(i*dt + start_time, 3): name for i, name in enumerate(files)}

    def _step_index(t):
        # Step index of time t. The 1e-6 epsilon stops float error (e.g.
        # (210.1-210)/0.05 == 1.9999999999998863) from truncating one step short,
        # and clamping at 0 keeps a bound before start_time from becoming a
        # negative, wrap-around slice index.
        return max(0, int(np.floor((t - start_time)/dt + 1e-6)))

    # Now cut the list down to files inside t_bound
    times_in_bound = list(named_files.keys())[_step_index(t_bound[0]):_step_index(t_bound[1])]
    return {time: str(named_files[time]) for time in times_in_bound}

def _unit_one(field, data):
    res = np.zeros(data[('particle', 'p_w')].shape)
    res[:] = 1
    return res

def get_B_at_loc(loc,data):
    """Return (Bx, By, Bz) at the grid node just below `loc` in x and y.

    Takes the z = loc[2] slice of `data` and returns the field at the last node
    with x < loc[0] and y < loc[1] (the lower-left neighbour, assuming ascending
    coordinates). No interpolation is performed; get_B_at_loc_from_fluid does a
    linear estimate instead.

    Parameters
    ----------
    loc : sequence of 3 floats
        (x, y, z) sample position in the slice's coordinates.
    data : dataset
        Object whose get_slice("z", z) returns a slice exposing .x.value,
        .y.value and evaluate_expression(name). Field arrays are assumed to be
        indexed [x, y], which the indexing below relies on.

    Raises
    ------
    IndexError
        If no grid node lies below loc[0] or loc[1].
    """
    dc = data.get_slice("z",loc[2])

    # Pull the coord data
    x,y = np.array(dc.x.value).T, np.array(dc.y.value).T

    # Index of the last node below loc in each direction
    xi_min = np.where(x < loc[0])[0][-1]
    yi_min = np.where(y < loc[1])[0][-1]

    return tuple(np.array(dc.evaluate_expression(name))[xi_min, yi_min] for name in ("Bx", "By", "Bz"))

def get_E_at_loc(loc,data):
    """Estimate (Ex, Ey, Ez) at `loc` from the z = loc[2] slice of `data`.

    First-order (linear) expansion about the node just below `loc` in x and y:
        E ~ E[i,j] + (E[i+1,j]-E[i,j])/(x[i+1]-x[i]) * (loc[0]-x[i])
                   + (E[i,j+1]-E[i,j])/(y[j+1]-y[j]) * (loc[1]-y[j])
    (no cross term, so not full bilinear).

    Field arrays are indexed [x, y], as the slicing [i:i+2, j:j+2] assumes. The
    previous version combined these [x, y] arrays with an 'xy' meshgrid without
    transposing, so the x-gradient used y-differences and vice versa.
    get_E_at_loc_from_fluid transposes and agrees with this version.

    Raises IndexError if loc is not strictly inside the grid in x and y.
    """
    dc = data.get_slice("z",loc[2])

    # Pull the coord data
    x,y = np.asarray(dc.x.value), np.asarray(dc.y.value)

    # Index of the last node below loc in each direction
    i = np.where(x < loc[0])[0][-1]
    j = np.where(y < loc[1])[0][-1]

    def _linear(name):
        # f[a, b] is the field at (x[i+a], y[j+b])
        f = np.array(dc.evaluate_expression(name))[i:i+2, j:j+2]
        return (f[0,0]
                + (f[1,0]-f[0,0])/(x[i+1]-x[i])*(loc[0]-x[i])
                + (f[0,1]-f[0,0])/(y[j+1]-y[j])*(loc[1]-y[j]))

    return _linear("Ex"), _linear("Ey"), _linear("Ez")

def get_B_at_loc_from_fluid(file,loc):
    """Estimate (Bx, By, Bz) at `loc` from the 3d fluid output matching cut file `file`.

    This was created as a workaround for get_B_at_loc, which started to read B as
    0. We therefore also rely on the fluid data being downloaded.

    The last 15 characters of `file` (after dropping a trailing '/') are taken as
    the iteration tag, e.g. "n00016670_amrex". The first (sorted) subdirectory of
    the module-level `dir` whose name matches ".*3d_fluid.*<tag>" is loaded with
    fleks.load. B is then estimated linearly from the node just below loc in its
    z = loc[2] slice (see get_E_at_loc for the formula; arrays indexed [x, y]).

    The directory is located here directly. get_directories is not used because
    it slices its matches by t_bound, which drops a single match unless
    t_bound[0] == start_time.

    Raises FileNotFoundError if no matching fluid directory exists, and IndexError
    if loc is not strictly inside the grid in x and y.
    """
    # First, extract the iteration number from the filename
    iteration = file.rstrip("/")[-15:]

    # Then, pull the matching 3d fluid file
    pattern = re.compile(".*3d_fluid.*" + re.escape(iteration))
    matches = sorted(entry.name for entry in os.scandir(dir)
                     if entry.is_dir() and pattern.search(entry.name) is not None)
    if not matches:
        raise FileNotFoundError("No 3d_fluid directory for iteration "+repr(iteration)+" in "+str(dir))

    # Load the data and slice
    fluid_ds = fleks.load(str(dir+matches[0]))
    dc = fluid_ds.get_slice("z",loc[2])

    # Pull the coord data
    x,y = dc.x.value, dc.y.value

    # Index of the last node below loc in each direction
    i = np.where(x < loc[0])[0][-1]
    j = np.where(y < loc[1])[0][-1]

    def _linear(name):
        # f[a, b] is the field at (x[i+a], y[j+b])
        f = np.array(dc.evaluate_expression(name))[i:i+2, j:j+2]
        return (f[0,0]
                + (f[1,0]-f[0,0])/(x[i+1]-x[i])*(loc[0]-x[i])
                + (f[0,1]-f[0,0])/(y[j+1]-y[j])*(loc[1]-y[j]))

    return _linear("Bx"), _linear("By"), _linear("Bz")

def get_E_at_loc_from_fluid(file,loc):
    """Estimate (Ex, Ey, Ez) at `loc` from the 3d fluid output matching cut file `file`.

    The last 15 characters of `file` (after dropping a trailing '/') are taken as
    the iteration tag, e.g. "n00016670_amrex". The first (sorted) subdirectory of
    the module-level `dir` whose name matches ".*3d_fluid.*<tag>" is loaded with
    fleks.load. E is then estimated linearly from the node just below loc in its
    z = loc[2] slice:
        E ~ E[i,j] + dE/dx * (loc[0]-x[i]) + dE/dy * (loc[1]-y[j])
    with one-sided differences to the next node (arrays indexed [x, y]).

    The directory is located here directly. get_directories is not used because
    it slices its matches by t_bound, which drops a single match unless
    t_bound[0] == start_time.

    Raises FileNotFoundError if no matching fluid directory exists, and IndexError
    if loc is not strictly inside the grid in x and y.
    """
    # First, extract the iteration number from the filename
    iteration = file.rstrip("/")[-15:]

    # Then, pull the matching 3d fluid file
    pattern = re.compile(".*3d_fluid.*" + re.escape(iteration))
    matches = sorted(entry.name for entry in os.scandir(dir)
                     if entry.is_dir() and pattern.search(entry.name) is not None)
    if not matches:
        raise FileNotFoundError("No 3d_fluid directory for iteration "+repr(iteration)+" in "+str(dir))

    # Load the data and slice
    fluid_ds = fleks.load(str(dir+matches[0]))
    dc = fluid_ds.get_slice("z",loc[2])

    # Pull the coord data
    x,y = dc.x.value, dc.y.value

    # Index of the last node below loc in each direction
    i = np.where(x < loc[0])[0][-1]
    j = np.where(y < loc[1])[0][-1]

    def _linear(name):
        # f[a, b] is the field at (x[i+a], y[j+b])
        f = np.array(dc.evaluate_expression(name))[i:i+2, j:j+2]
        return (f[0,0]
                + (f[1,0]-f[0,0])/(x[i+1]-x[i])*(loc[0]-x[i])
                + (f[0,1]-f[0,0])/(y[j+1]-y[j])*(loc[1]-y[j]))

    return _linear("Ex"), _linear("Ey"), _linear("Ez")

def get_particle_np(file, x_axis='p_ux', y_axis='p_uy', loc = sample_loc, bins=64, v_lim = 2500, unit_vectors = False, filename=True):
    # Get the particle distribution in numpy format.
    # File is the directory of the amrex folder (e.g. /Users/*_amrex/"
    # x_axis and y_axis are what quantities to extract e.g. 'p_x','p_y'
    # Bins is how many bins to divide into
    # v_lim is the maximum velocity to extract (in km/s)
    # Outputs the meshgrids of velocities (in km/s), and the number of particles in each of those grid cells
    # If unit_vectors is true, then it will also output two vectors representing the x,y,z components of the two axes. Useful for seeing which direction bhat is.

    # Load the file
    if filename:
        ds = fleks.load(file)
    else:
        ds = file

    # Select the right quantity to compute given the choice of x and y axis
    if x_axis == 'p_ux': # just check what the x_axis is to determine which mode we are in; this is temporary
        ds.add_field(("particle", "unit_one"), function=_unit_one, sampling_type='particle')
        sp = ds.sphere(loc,radius)
        z_field = 'unit_one'
        plot = ds.plot_phase_region(sp, x_axis, y_axis, z_field, unit_type="planet", 
                                        x_bins=bins, y_bins=bins, domain_size=(-v_lim, v_lim, -v_lim, v_lim))
    
        for var_name in plot.profile.field_data: 
            counts_temp = plot.profile.field_data[var_name].T # take the transpose, which seems to be necessary. Bit worried about that.

    # Here, v_perp is just the magnitude of the velocity not aligned with the parallel direction
    elif (y_axis == 'v_para' and x_axis == 'v_perp'): 
        # Do a little interpolation: B = B[0,0] + dB/dX dx + dB/dY dy
        Bx_loc,By_loc,Bz_loc = get_B_at_loc(loc,ds)
        
        # Define magnetic field unit vectors at loc
        bhat = [Bx_loc,By_loc,Bz_loc]/np.sqrt(Bx_loc**2+By_loc**2+Bz_loc**2)
        
        # Sometimes B is identically 0 for some weird reason... so we get nan for bhat. Not sure how this is possible.
        bhat[np.isnan(bhat)] = 0
        bx,by,bz = bhat
        
        # For the particle data, we must create user-defined fields for the parallel and perpedicular components. See notebook for handwritten derivation.
        def _vel_para(field, data): # This is basically v_para = v . b
            res = bx*data[('particle', 'p_ux')] + by*data[('particle', 'p_uy')] + bz*data[('particle', 'p_uz')]        
            return res
        def _vel_perp(field, data): # This is v_perp = v - v_para
            res = (( data[('particle', 'p_ux')] - bx * ( bx*data[('particle', 'p_ux')]) + by*data[('particle', 'p_uy')] + bz*data[('particle', 'p_uz')] )**2 +
                   ( data[('particle', 'p_uy')] - by * ( bx*data[('particle', 'p_ux')]) + by*data[('particle', 'p_uy')] + bz*data[('particle', 'p_uz')] )**2 +
                   ( data[('particle', 'p_uz')] - bz * ( bx*data[('particle', 'p_ux')]) + by*data[('particle', 'p_uy')] + bz*data[('particle', 'p_uz')] )**2 )**0.5        
            return res
        '''
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        
        '''
        #def _vel_perp(field, data): # This is v_perp = v - v_para
        #    res = (( data[('particle', 'p_ux')]**2 - bx * ( bx*data[('particle', 'p_ux')]) + by*data[('particle', 'p_uy')] + bz*data[('particle', 'p_uz')] )**2 +
        #           ( data[('particle', 'p_uy')]**2 - by * ( bx*data[('particle', 'p_ux')]) + by*data[('particle', 'p_uy')] + bz*data[('particle', 'p_uz')] )**2 +
        #           ( data[('particle', 'p_uz')]**2 - bz * ( bx*data[('particle', 'p_ux')]) + by*data[('particle', 'p_uy')] + bz*data[('particle', 'p_uz')] )**2 )**0.5        
        #    return res
        '''
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        NEED TO FIX THIS 
        
        '''
        
        # Add the fields to the dataset
        vpara_name = ds.pvar("vel_para")
        vperp_name = ds.pvar("vel_perp")
        ds.add_field(vpara_name, units="code_velocity", function=_vel_para, sampling_type='particle')
        ds.add_field(vperp_name, units="code_velocity", function=_vel_perp, sampling_type='particle')

        # Generate plot 
        ds.add_field(("particle", "unit_one"), function=_unit_one, sampling_type='particle')
        sp = ds.sphere(loc,radius)
        x_field = vperp_name
        y_field = vpara_name
        z_field = ("particle", "unit_one")#ds.pvar('p_w')
        logs = {x_field: False, y_field: False}
        profile = yt.create_profile(data_source=sp, bin_fields=[x_field,y_field], fields=z_field, n_bins=[nbins,nbins], weight_field=None, 
                                    logs=logs, extrema={x_field: [0,2*v_lim],y_field: [-v_lim,v_lim]})
        plot = yt.PhasePlot.from_profile(profile)
        plot.set_unit(x_field, "km/s")
        plot.set_unit(y_field, "km/s")
        plot.set_unit(z_field, "amu")
        for var_name in plot.profile.field_data: 
            counts_temp = plot.profile.field_data[var_name].T

    # Here, one axis is the parallel velocity and the other is the ExB bulk flow direction
    elif ((x_axis == 'v_para' and y_axis == 'v_exb') or (y_axis == 'v_para' and x_axis == 'v_exb')):
        # Get B and E at the location we are studying
        Bx_loc,By_loc,Bz_loc = get_B_at_loc(loc,ds)
        Ex_loc,Ey_loc,Ez_loc = get_E_at_loc(loc,ds)
        
        # Define magnetic and electric field unit vectors at loc
        bhat = [Bx_loc,By_loc,Bz_loc]/np.sqrt(Bx_loc**2+By_loc**2+Bz_loc**2)
        ehat = [Ex_loc,Ey_loc,Ez_loc]/np.sqrt(Ex_loc**2+Ey_loc**2+Ez_loc**2)

        # Compute the ExB unit vector
        exb = [ehat[1]*bhat[2]-ehat[2]*bhat[1], ehat[2]*bhat[0]-ehat[0]*bhat[2], ehat[0]*bhat[1]-ehat[1]*bhat[0]]
        exb = exb / np.sqrt(np.sum(np.array(exb)**2))
        print('exb:',exb,"    mag:",np.sqrt(np.sum(np.array(exb)**2)))

        # Verify orthogonal: seems to be good
        # val = exb[0]*bhat[0]+exb[1]*bhat[1]+exb[2]*bhat[2]
        # print(val) 

        # Define v_para and v_ExB directions in the particle data files
        def _vel_para(field, data): # This is basically v_para = v . b
            res = bhat[0]*data[('particle', 'p_ux')] + bhat[1]*data[('particle', 'p_uy')] + bhat[2]*data[('particle', 'p_uz')]        
            return res
        def _vel_exb(field, data): 
            res = exb[0]*data[('particle', 'p_ux')] + exb[1]*data[('particle', 'p_uy')] + exb[2]*data[('particle', 'p_uz')]        
            return res

        # Add the fields to the dataset
        vpara_name = ds.pvar("vel_para")
        vexb_name = ds.pvar("vel_exb")
        ds.add_field(vpara_name, units="code_velocity", function=_vel_para, sampling_type='particle')
        ds.add_field(vexb_name, units="code_velocity", function=_vel_exb, sampling_type='particle')

        # Generate plot 
        ds.add_field(("particle", "unit_one"), function=_unit_one, sampling_type='particle')
        sp = ds.sphere(loc,radius)
        x_field = vexb_name
        y_field = vpara_name
        z_field = ("particle", "unit_one")#ds.pvar('p_w')
        logs = {x_field: False, y_field: False}
        profile = yt.create_profile(data_source=sp, bin_fields=[x_field,y_field], fields=z_field, 
                                    n_bins=[bins,bins], weight_field=None, logs=logs, 
                                    extrema={x_field: [-bin_limit,bin_limit],y_field: [-bin_limit,bin_limit]})
        plot = yt.PhasePlot.from_profile(profile)
        #plot = ds.plot_phase_region(sp, x_field, y_field, z_field, unit_type="planet", x_bins=bins, y_bins=bins, domain_size=(-v_lim, v_lim, -v_lim, v_lim))
        #plot.set_unit(x_field, "km/s")
        #plot.set_unit(y_field, "km/s")
        #plot.set_unit(z_field, "amu")

        #plot.show()

        # .T added march 25 2025 after exhaustive testing to ensure we have no erroneous rotations
        for var_name in plot.profile.field_data: 
            counts_temp = plot.profile.field_data[var_name].T

        # Save the unit vectors
        xhat = exb
        yhat = bhat

    # Here, one axis is the parallel velocity and the other is perpendicular to both ExB and B
    elif ((x_axis == 'v_para' and y_axis == 'v_w') or (y_axis == 'v_para' and x_axis == 'v_w')):
        # Get B and E at the locationn we are studying
        Bx_loc,By_loc,Bz_loc = get_B_at_loc(loc,ds)
        Ex_loc,Ey_loc,Ez_loc = get_E_at_loc(loc,ds)
        
        # Define magnetic and electric field unit vectors at loc
        bhat = [Bx_loc,By_loc,Bz_loc]/np.sqrt(Bx_loc**2+By_loc**2+Bz_loc**2)
        ehat = [Ex_loc,Ey_loc,Ez_loc]/np.sqrt(Ex_loc**2+Ey_loc**2+Ez_loc**2)

        # Compute the ExB unit vector
        exb = [ehat[1]*bhat[2]-ehat[2]*bhat[1], ehat[2]*bhat[0]-ehat[0]*bhat[2], ehat[0]*bhat[1]-ehat[1]*bhat[0]]
        exb = exb / np.sqrt(np.sum(np.array(exb)**2))

        # Compute the vector orthogonal to both exb and b: w = bhat x (ExB)
        w = [bhat[1]*exb[2]-bhat[2]*exb[1], bhat[2]*exb[0]-bhat[0]*exb[2], bhat[0]*exb[1]-bhat[1]*exb[0]]
        w = w / np.sqrt(np.sum(np.array(w)**2))
        print('w:',w,"    mag:",np.sqrt(np.sum(np.array(w)**2)))
        
        # Verify orthogonal:
        #val0 = exb[0]*w[0]+exb[1]*w[1]+exb[2]*w[2]
        #val1 = w[0]*bhat[0]+w[1]*bhat[1]+w[2]*bhat[2]
        #print(val0)
        #print(val1)

        # Define v_para and v_ExB directions in the particle data files
        def _vel_para(field, data): # This is basically v_para = v . b
            res = bhat[0]*data[('particle', 'p_ux')] + bhat[1]*data[('particle', 'p_uy')] + bhat[2]*data[('particle', 'p_uz')]        
            return res
        def _vel_w(field, data): 
            res = w[0]*data[('particle', 'p_ux')] + w[1]*data[('particle', 'p_uy')] + w[2]*data[('particle', 'p_uz')]        
            return res

        # Add the fields to the dataset
        vpara_name = ds.pvar("vel_para")
        vw_name = ds.pvar("vel_w")
        ds.add_field(vpara_name, units="code_velocity", function=_vel_para, sampling_type='particle')
        ds.add_field(vw_name, units="code_velocity", function=_vel_w, sampling_type='particle')

        # Generate plot 
        ds.add_field(("particle", "unit_one"), function=_unit_one, sampling_type='particle')
        sp = ds.sphere(loc,radius)
        x_field = vw_name
        y_field = vpara_name
        z_field = ("particle", "unit_one")#ds.pvar('p_w')
        logs = {x_field: False, y_field: False}
        profile = yt.create_profile(data_source=sp, bin_fields=[x_field,y_field], fields=z_field, 
                                    n_bins=[bins,bins], weight_field=None, logs=logs,
                                    extrema={x_field: [-bin_limit,bin_limit],y_field: [-bin_limit,bin_limit]})
        plot = yt.PhasePlot.from_profile(profile)
        plot.set_unit(x_field, "km/s")
        plot.set_unit(y_field, "km/s")
        plot.set_unit(z_field, "amu")
        for var_name in plot.profile.field_data: 
            counts_temp = plot.profile.field_data[var_name].T

        # Save the unit vectors
        xhat = w
        yhat = bhat
    
    else:
        print("ERROR: invalid axes!")
    
    # Convert to np. Define output arrays without units
    x_temp = plot.profile.x
    y_temp = plot.profile.y
    x=np.zeros(x_temp.shape)
    y=np.zeros(y_temp.shape)
    counts=np.zeros(counts_temp.shape)
    x[:]=x_temp[:]
    y[:]=y_temp[:]
    counts[:,:]=counts_temp[:,:]
    xx,yy=np.meshgrid(x,y)

    if unit_vectors:
        return xx,yy,counts,xhat,yhat,x_temp,y_temp,counts_temp
    else:
        return xx,yy,counts

def sum_z_by_distance(x_array, y_array, z_array, r1, r2, nbins):
    """Sum ``z_array`` over equal-width radial bins of sqrt(x**2 + y**2).

    Parameters
    ----------
    x_array, y_array : array_like
        Coordinates of each sample (same shape as ``z_array``).
    z_array : array_like
        Values to accumulate per bin (e.g. macro-particle counts).
    r1, r2 : float
        Inner and outer radius of the binned range.
    nbins : int
        Number of equal-width bins between r1 and r2.

    Returns
    -------
    z_sums : ndarray of float, shape (nbins,)
        Sum of ``z_array`` over the samples falling in each bin.
    bin_edges : ndarray, shape (nbins + 1,)
        The bin edges, ``np.linspace(r1, r2, nbins + 1)``.

    Bins are half-open ``[edge_i, edge_{i+1})`` except the last, which is
    closed, so a sample lying exactly at ``r2`` is counted. This matters
    because callers pass ``r2 = max(distance)``. Samples outside
    ``[r1, r2]`` are ignored.
    """
    x = np.asarray(x_array)
    y = np.asarray(y_array)
    z = np.asarray(z_array)

    # Same formula callers use for r2, so the maximum distance equals bin_edges[-1] exactly
    distances = np.sqrt(x**2 + y**2)

    bin_edges = np.linspace(r1, r2, nbins + 1)
    z_sums = np.zeros(nbins)

    for i in range(nbins):
        lower, upper = bin_edges[i], bin_edges[i + 1]
        if i == nbins - 1:
            # Close the last bin so the outermost sample(s) at r2 are not silently dropped
            mask = (distances >= lower) & (distances <= upper)
        else:
            mask = (distances >= lower) & (distances < upper)
        z_sums[i] = np.sum(z[mask])

    return z_sums, bin_edges


def _species_settings(species, ifile, efile):
    """Return ``(data_file, v_limit, mass, color)`` for one species of a spectrum plot.

    ``species == 'ion'`` selects the ion file, ``ibin_limit``, ``m_p`` and red.
    Anything else selects the electron file, ``ebin_limit``, ``m_p/100`` and blue.
    """
    if species == 'ion':
        return dir + ifile, ibin_limit, m_p, 'red'
    return dir + efile, ebin_limit, m_p/100, 'blue'


def _energy_spectrum(source, v_limit, mass, **particle_kwargs):
    """Collapse a (v_perp, v_para) histogram into a 1-D spectrum vs. energy.

    ``source`` and any ``particle_kwargs`` are forwarded to ``get_particle_np``.
    Returns ``(e_centers, counts_mag)``: the energy of each |v| bin centre
    (labelled eV in this notebook) and the summed counts in that bin.
    """
    v_perp, v_para, counts_perp_para = get_particle_np(source, 'v_perp', 'v_para', v_lim=v_limit, **particle_kwargs)
    v_mag = np.sqrt(v_perp**2 + v_para**2)
    # Number of particles in each |v| shell, from 0 up to the largest sampled speed
    counts_mag, bin_edges = sum_z_by_distance(v_perp, v_para, counts_perp_para, 0, np.max(v_mag), nbins)
    v_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    # NOTE(review): assumes velocities are in km/s (hence *1e3) and that the global
    # `eV` converts joules to eV -- confirm against the configuration cell.
    e_centers = 0.5 * mass * (v_centers*1e3)**2 * eV  # eV
    return e_centers, counts_mag


def _format_vector(vec):
    """Format the first three components of ``vec`` as "(a, b, c)", rounded to 3 decimals."""
    return "(" + str(round(vec[0], 3)) + ", " + str(round(vec[1], 3)) + ", " + str(round(vec[2], 3)) + ")"


def _plot_local_field(ax, ds):
    """Draw |B| with in-plane field lines on the y = loc[1] slice, zoomed around ``loc``.

    Also marks the sampling point and the sampling sphere (``radius``).
    Returns the |B| image so the caller can attach a colorbar.
    """
    dc = ds.get_slice("y", loc[1])
    Bx = np.array(dc.evaluate_expression("Bx").T)
    By = np.array(dc.evaluate_expression("By").T)
    Bz = np.array(dc.evaluate_expression("Bz").T)
    Bmag = np.sqrt(Bx**2 + By**2 + Bz**2)

    # On a y-slice the slice's second coordinate (dc.y) is plotted as Z
    extent = [np.min(dc.x.value), np.max(dc.x.value), np.min(dc.y.value), np.max(dc.y.value)]
    bmag_image = ax.imshow(Bmag, extent=extent, vmin=0, vmax=100, cmap='plasma', aspect='equal', origin='lower')
    ax.streamplot(dc.x.value, dc.y.value, Bx, Bz,
                  color='white', density=15, broken_streamlines=False, linewidth=0.2)

    # Show location of sampling
    ax.scatter(loc[0], loc[2], color='green')
    ax.add_patch(plt.Circle((loc[0], loc[2]), radius, color='green', fill=False))

    ax.set_xlabel('X [$R_M$]')
    ax.set_ylabel('Z [$R_M$]')
    ax.set_xlim(loc[0]-0.2, loc[0]+0.2)
    ax.set_ylim(loc[2]-0.2, loc[2]+0.2)
    return bmag_image


def _draw_spectrogram(fig, axs, x_values, i_spectrum, i_energy_levels, e_spectrum, e_energy_levels,
                      show_electrons, Bx_ls, By_ls, Bz_ls):
    """Fill a 3-row figure with ion/electron log-count spectrograms and B components vs. ``x_values``."""
    plot0 = axs[0].imshow(np.log10(i_spectrum).T,
                          extent=[x_values[0], x_values[-1], i_energy_levels[0], i_energy_levels[-1]],
                          cmap='rainbow', vmin=0, vmax=2.3, aspect='auto', origin='lower')
    if show_electrons:
        axs[1].imshow(np.log10(e_spectrum).T,
                      extent=[x_values[0], x_values[-1], e_energy_levels[0], e_energy_levels[-1]],
                      cmap='rainbow', vmin=0, vmax=2.3, aspect='auto', origin='lower')

    fig.colorbar(plot0, ax=axs[:], shrink=0.6, label='log counts')

    axs[2].plot(x_values, Bx_ls, color='red', label='$B_x$')
    axs[2].plot(x_values, By_ls, color='green', label='$B_y$')
    axs[2].plot(x_values, Bz_ls, color='blue', label='$B_z$')
    axs[2].legend(loc='lower right')

    for axi in axs[0:2]:
        axi.set_yscale('log')
        axi.set_ylim(1, 4e2)
        axi.set_ylabel("Energy [keV]")
        axi.grid()
    axs[2].grid()


def _plot_bz_movie():
    """Mode "Bz_movie": Bz on the z = loc[2] and y = loc[1] slices for every region0_2 file."""
    files = get_directories(dir, ".*cut_particle_region0_2.*")

    for time in list(files.keys()):
        print("Plotting t =", time)
        data_file = dir + str(files[time])
        ds = fleks.load(data_file)
        # Keys may be strings (see the non-multi dicts elsewhere), so convert before rounding
        time_value = float(time)

        fig, axs = plt.subplots(ncols=2, figsize=(12, 6))
        Bz_lims = [-75, 75]
        cmap = 'bwr'

        # Plot xy plane
        dc = ds.get_slice("z", loc[2])
        axs[0].imshow(dc.evaluate_expression("Bz").T, origin='lower',
                      extent=[np.min(dc.x.value), np.max(dc.x.value), np.min(dc.y.value), np.max(dc.y.value)],
                      vmin=Bz_lims[0], vmax=Bz_lims[1], cmap=cmap)
        axs[0].set_ylim(0.7, -0.7)
        axs[0].set_ylabel("$Y$ [$R_M$]")
        axs[0].set_title(str("$B_z$ at z=" + str(loc[2]) + "   t=" + str(round(time_value, 2)) + "s"))

        # Plot xz plane, with streamlines for the field lines
        dc = ds.get_slice("y", loc[1])
        xz_plot = axs[1].imshow(dc.evaluate_expression("Bz").T, origin='lower',
                                extent=[np.min(dc.x.value), np.max(dc.x.value), np.min(dc.y.value), np.max(dc.y.value)],
                                vmin=Bz_lims[0], vmax=Bz_lims[1], cmap=cmap)
        Xgrid, Zgrid = np.meshgrid(np.linspace(np.min(dc.x.value), np.max(dc.x.value), len(dc.x.value)),
                                   np.linspace(np.min(dc.y.value), np.max(dc.y.value), len(dc.y.value)))
        axs[1].streamplot(Xgrid, Zgrid,
                          np.array(dc.evaluate_expression("Bx")).T, np.array(dc.evaluate_expression("Bz")).T,
                          broken_streamlines=False, linewidth=0.45, arrowsize=0.45, color='black',
                          density=2.5)
        axs[1].set_ylim(-0.35, 0.55)
        axs[1].set_ylabel("$Z$ [$R_M$]")
        axs[1].set_title(str("$B_z$ at y=" + str(loc[1]) + "   t=" + str(round(time_value, 2)) + "s"))

        fig.colorbar(xz_plot, ax=axs[:], shrink=0.7, label='$B_z$')

        for ax in axs:
            ax.grid()
            ax.set_xlim(-1, -2.25)
            ax.set_xlabel("$X$ [$R_M$]")

        fig.savefig(str(folder + "_plots/" + plot_preset + "_" + str(loc) + "_" + "%.2f" % round(time_value, 2) + '.png'),
                    bbox_inches='tight')

        plt.show()
        plt.close(fig)


def _plot_velocity_planes(particle_type):
    """Mode "vx_vy-vx_vz": vx-vy and vx-vz count histograms plus the local |B| map."""
    if particle_type not in ("ion", "electron"):
        print("ERROR: invalid particle type!")
        return
    # Per-species velocity window, as used by the other modes
    v_limit = ibin_limit if particle_type == "ion" else ebin_limit

    if multi:
        print("Plotting all files in dir")
        # Note: which files to read for each particle type may change if #saveplot is altered.
        # For now, ion is zone 1, electron is zone 0.
        if particle_type == "ion":
            files = get_directories(dir, ".*cut_particle_region0_1.*")
        else:
            files = get_directories(dir, ".*cut_particle_region0_0.*")
    else:
        files = {str(round(start_time, 3)): str(file)}

    for time in list(files.keys()):
        print("Plotting t =", time)
        data_file = dir + str(files[time])
        time_value = float(time)

        xx, yy, counts_xy = get_particle_np(data_file, 'p_ux', 'p_uy', v_lim=v_limit)
        xx, zz, counts_xz = get_particle_np(data_file, 'p_ux', 'p_uz', v_lim=v_limit)

        # Get B field data
        ds = fleks.load(data_file)

        # Colorbar limits for the count histograms
        cmin = 1
        cmax = 100

        fig, axs = plt.subplots(ncols=3, figsize=(22, 5))
        axs[0].imshow(counts_xy, extent=[xx.min(), xx.max(), yy.min(), yy.max()],
                      origin="lower", norm=LogNorm(vmin=cmin, vmax=cmax))
        im1 = axs[1].imshow(counts_xz, extent=[xx.min(), xx.max(), zz.min(), zz.max()],
                            origin="lower", norm=LogNorm(vmin=cmin, vmax=cmax))
        Bmag_plot = _plot_local_field(axs[2], ds)

        axs[0].set_xlabel("$v_x$ [km/s]")
        axs[0].set_ylabel("$v_y$ [km/s]")
        axs[1].set_xlabel("$v_x$ [km/s]")
        axs[1].set_ylabel("$v_z$ [km/s]")
        axs[0].set_title(str("Velocity distribution  r=" + str(loc) + "   t=" + str(round(time_value, 2)) + "s"
                             + "\n Particle type: " + particle_type))
        axs[2].set_title(str("Bmag and quivers at y = " + str(round(loc[1], 1))))

        axs[0].grid(alpha=0.5)
        axs[1].grid(alpha=0.5)

        fig.colorbar(im1, ax=axs[0:2], label="Counts")
        fig.colorbar(Bmag_plot, ax=axs[2], label="B [nT]")

        fig.savefig(str(folder + "_plots/" + plot_preset + "_" + particle_type + "_" + str(loc) + "_"
                        + "%.2f" % round(time_value, 2) + '.png'), bbox_inches='tight')

        if multi:
            plt.close(fig)
        else:
            plt.show()


def _plot_field_aligned_vdf(particle_type):
    """Mode "v_para-v_perp2": (v_ExB, v_para) and (v_w, v_para) histograms plus the local |B| map."""
    if particle_type not in ("ion", "electron"):
        print("ERROR: invalid particle type!")
        return
    v_limit = ibin_limit if particle_type == "ion" else ebin_limit

    if multi:
        print("Plotting all files in dir")
        # Note: which files to read for each particle type may change if #saveplot is altered.
        # For now, ion is zone 2, electron is zone 1.
        if particle_type == "ion":
            files = get_directories(dir, ".*cut_particle_region0_2.*")
        else:
            files = get_directories(dir, ".*cut_particle_region0_1.*")
    else:
        files = {str(round(start_time, 3)): str(file)}

    for time in list(files.keys()):
        print("Plotting t =", time)
        fname = str(files[time])
        print("reading file", fname)
        data_file = dir + fname
        time_value = float(time)

        xx0, yy0, counts0, exbhat, bhat, _, _, _ = get_particle_np(data_file, 'v_exb', 'v_para', v_lim=v_limit, unit_vectors=True)
        xx1, yy1, counts1, what, bhat, _, _, _ = get_particle_np(data_file, 'v_w', 'v_para', v_lim=v_limit, unit_vectors=True)

        # Get B field data
        ds = fleks.load(data_file)

        # Colorbar limits for the count histograms
        cmin = 1
        cmax = 100

        fig, axs = plt.subplots(ncols=3, figsize=(21, 5))
        axs[0].imshow(counts0, extent=[xx0.min(), xx0.max(), yy0.min(), yy0.max()],
                      origin="lower", norm=LogNorm(vmin=cmin, vmax=cmax))
        im1 = axs[1].imshow(counts1, extent=[xx1.min(), xx1.max(), yy1.min(), yy1.max()],
                            origin="lower", norm=LogNorm(vmin=cmin, vmax=cmax))
        By_plot = _plot_local_field(axs[2], ds)

        axs[0].set_ylabel("$v_{para}$ [km/s]")
        axs[0].set_xlabel("$v_{ExB}$ [km/s]")
        axs[1].set_ylabel("$v_{para}$ [km/s]")
        axs[1].set_xlabel("$v_{w}$ [km/s]")
        axs[0].set_title(str("Velocity distribution  r=" + str(loc) + "   t=" + str(round(time_value, 2))
                             + "s \n Local b: " + _format_vector(bhat)
                             + "\n Local exb: " + _format_vector(exbhat)))
        axs[1].set_title(str("Particle type: " + particle_type + "\n Local w: " + _format_vector(what)))
        axs[2].set_title(str("Local magnetic field at y = " + str(round(loc[1], 1))))
        for ax in axs[0:2]:
            ax.set_ylim(-v_limit, v_limit)
            ax.set_xlim(-v_limit, v_limit)
            ax.grid(alpha=0.5)

        # Fit aspect ratio
        axs[2].set_aspect(1)

        fig.colorbar(im1, ax=axs[0:2], label="Counts")
        fig.colorbar(By_plot, ax=axs[2], label="$B_{mag}$")

        fig.savefig(str(folder + "_plots/" + plot_preset + "_" + particle_type + "_" + str(loc) + "_"
                        + "%.2f" % round(time_value, 2) + '.png'), bbox_inches='tight')

        if multi:
            plt.close(fig)
        else:
            plt.show()

    if make_gif:
        generate_gif(folder)


def _plot_1d_vdf():
    """Mode "1D-vdf": ion (and electron, if present) counts vs. energy, one figure per time."""
    if not multi:
        # Both an ion and an electron file are needed per time, so a single `file` is not enough
        print("ERROR: Gotta be multi mode for this!")
        return

    print("Plotting all files in dir")
    ifiles = get_directories(dir, ".*cut_particle_region0_1.*")
    efiles = get_directories(dir, ".*cut_particle_region0_0.*")
    has_electrons = len(efiles) > 0

    for time in list(ifiles.keys()):
        print("Plotting t =", time)
        ifile = str(ifiles[time])
        efile = str(efiles[time]) if has_electrons else None
        species_list = ['ion', 'electron'] if has_electrons else ['ion']

        fig, ax = plt.subplots(figsize=(7, 5))

        for species in species_list:
            data_file, v_limit, mass, color = _species_settings(species, ifile, efile)
            e_centers, counts_mag = _energy_spectrum(data_file, v_limit, mass)
            ax.scatter(e_centers/1e3, counts_mag, color=color, label=species)

        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlim(1, 1e3)
        ax.set_ylim(1, 1e3)
        ax.set_xlabel("E/q [keV]")
        ax.set_ylabel("Macro particle counts")
        ax.legend(loc='upper right')
        ax.set_title(str("VDF at r=" + str(loc) + "   t=" + str(round(float(time), 2)) + "s"))
        ax.grid()

        fig.savefig(str(folder + "_plots/" + plot_preset + "_" + str(loc) + "_"
                        + "%.2f" % round(float(time), 2) + '.png'), bbox_inches='tight')
        plt.close(fig)


def _plot_virtual_spectra():
    """Mode "virtual_spectra": energy spectrogram and B at ``loc`` vs. time. Returns the ion spectrum."""
    if not multi:
        print("ERROR: Gotta be multi mode for this!")
        return None

    print("Plotting all files in dir")
    ifiles = get_directories(dir, ".*cut_particle_region0_2.*")
    efiles = get_directories(dir, ".*cut_particle_region0_1.*")
    if len(ifiles) == 0:
        print("ERROR: no ion files found!")
        return None
    has_electrons = len(efiles) > 0
    species_list = ['ion', 'electron'] if has_electrons else ['ion']

    fig, axs = plt.subplots(ncols=1, nrows=3, figsize=(20, 9), sharex=True)
    fig.subplots_adjust(hspace=0)

    t_ls, Bx_ls, By_ls, Bz_ls = [], [], [], []
    i_spectrum, e_spectrum = [], []
    i_energy_levels, e_energy_levels = [], []

    for time in list(ifiles.keys()):
        print("Plotting t =", time)
        ifile = str(ifiles[time])
        efile = str(efiles[time]) if has_electrons else None

        for species in species_list:
            data_file, v_limit, mass, _ = _species_settings(species, ifile, efile)
            e_centers, counts_mag = _energy_spectrum(data_file, v_limit, mass)
            if species == 'ion':
                i_energy_levels = e_centers/1e3  # keV
                i_spectrum.append(counts_mag)
            else:
                e_energy_levels = e_centers/1e3  # keV
                e_spectrum.append(counts_mag)

        # B field from the last species' file for this time
        ds = fleks.load(data_file)
        Bx_loc, By_loc, Bz_loc = get_B_at_loc(loc, ds)

        t_ls.append(float(time))
        Bx_ls.append(Bx_loc)
        By_ls.append(By_loc)
        Bz_ls.append(Bz_loc)

    t_ls = np.array(t_ls)
    i_spectrum = np.array(i_spectrum)
    e_spectrum = np.array(e_spectrum)

    _draw_spectrogram(fig, axs, t_ls, i_spectrum, i_energy_levels, e_spectrum, e_energy_levels,
                      has_electrons, np.array(Bx_ls), np.array(By_ls), np.array(Bz_ls))

    axs[0].set_title(str("Spectral at r=" + str(loc) + "\nIons"))
    axs[1].set_title("Electrons")

    # File name is tagged with the last time plotted
    fig.savefig(str(folder + "_plots/" + plot_preset + "_" + str(loc) + "_"
                    + "%.2f" % round(float(time), 2) + '.png'), bbox_inches='tight')

    plt.show()
    plt.close(fig)

    return i_spectrum


def _plot_static_flythrough():
    """Mode "static_flythrough": spectrogram and B along a fixed straight path at one time. Returns the ion spectrum."""
    # User input start: flyby start/stop
    start_pos = [-1.5, 0, -0.2]
    end_pos = [-1.5, 0, 0.6]
    npos = 100
    # User input stop

    x_rake = np.linspace(start_pos[0], end_pos[0], npos)
    y_rake = np.linspace(start_pos[1], end_pos[1], npos)
    z_rake = np.linspace(start_pos[2], end_pos[2], npos)

    print("Reading all files in dir, just plotting t =", t_bound[0])
    ifiles = get_directories(dir, ".*cut_particle_region0_2.*")
    efiles = get_directories(dir, ".*cut_particle_region0_1.*")
    if len(ifiles) == 0:
        print("ERROR: no ion files found!")
        return None
    has_electrons = len(efiles) > 0

    fig, axs = plt.subplots(ncols=1, nrows=3, figsize=(20, 9), sharex=True)
    fig.subplots_adjust(hspace=0)

    Bx_ls, By_ls, Bz_ls = [], [], []
    i_spectrum, e_spectrum = [], []
    i_energy_levels, e_energy_levels = [], []

    # Only plot data from the first available time
    time = list(ifiles.keys())[0]
    ifile = str(ifiles[time])
    efile = str(efiles[time]) if has_electrons else None
    species_list = ['ion', 'electron'] if has_electrons else ['ion']

    for species in species_list:
        data_file, v_limit, mass, _ = _species_settings(species, ifile, efile)
        ds = fleks.load(data_file)

        # Spectrum at each position along the rake
        for ipos in range(npos):
            pos = [x_rake[ipos], y_rake[ipos], z_rake[ipos]]
            e_centers, counts_mag = _energy_spectrum(ds, v_limit, mass, loc=pos, filename=False)
            if species == 'ion':
                i_energy_levels = e_centers/1e3  # keV
                i_spectrum.append(counts_mag)
            else:
                e_energy_levels = e_centers/1e3  # keV
                e_spectrum.append(counts_mag)

    # B field along the rake, from the last dataset loaded
    for ipos in range(npos):
        pos = [x_rake[ipos], y_rake[ipos], z_rake[ipos]]
        Bx_loc, By_loc, Bz_loc = get_B_at_loc(pos, ds)
        Bx_ls.append(Bx_loc)
        By_ls.append(By_loc)
        Bz_ls.append(Bz_loc)

    i_spectrum = np.array(i_spectrum)
    e_spectrum = np.array(e_spectrum)

    _draw_spectrogram(fig, axs, z_rake, i_spectrum, i_energy_levels, e_spectrum, e_energy_levels,
                      has_electrons, np.array(Bx_ls), np.array(By_ls), np.array(Bz_ls))

    # With sharex and hspace=0 only the bottom axis has room for the x label
    axs[2].set_xlabel("Position [z]")
    axs[0].set_title(str("t = " + "%.2f" % round(float(time), 2) + "\nSpectrum from " + str(start_pos)
                         + " to " + str(end_pos) + "\nIons"))
    axs[1].set_title("Electrons")

    fig.savefig(str(folder + "_plots/" + plot_preset + "_" + str(start_pos) + "_to_" + str(end_pos) + "_at_"
                    + "%.2f" % round(float(time), 2) + '.png'), bbox_inches='tight')

    plt.show()
    plt.close(fig)

    return i_spectrum


def particle_plots(mode, particle_type='ion'):
    """Generate the plot(s) for the given preset ``mode``.

    Parameters
    ----------
    mode : str
        One of "Bz_movie", "vx_vy-vx_vz", "v_para-v_perp2", "1D-vdf",
        "virtual_spectra" or "static_flythrough".
    particle_type : {'ion', 'electron'}
        Species for the "vx_vy-vx_vz" and "v_para-v_perp2" modes.

    Returns
    -------
    The ion spectrum array for "virtual_spectra" and "static_flythrough";
    None otherwise.

    Reads notebook-level settings: dir, loc, radius, folder, plot_preset,
    multi, start_time, file, ibin_limit, ebin_limit, nbins, m_p, eV,
    make_gif and t_bound. None of these names are assigned here, so they
    always resolve to the notebook globals.
    """
    if mode == "Bz_movie":
        _plot_bz_movie()
    elif mode == "vx_vy-vx_vz":
        _plot_velocity_planes(particle_type)
    elif mode == "v_para-v_perp2":
        _plot_field_aligned_vdf(particle_type)
    elif mode == '1D-vdf':
        _plot_1d_vdf()
    elif mode == "virtual_spectra":
        return _plot_virtual_spectra()
    elif mode == "static_flythrough":
        return _plot_static_flythrough()
    else:
        print("ERROR: invalid mode!")
    return None

# Run the preset selected in the configuration; only the spectra presets
# ("virtual_spectra", "static_flythrough") return an array, the others return None.
i_spectrum = particle_plots(mode=plot_preset)

In [None]:
# NOTE(review): in this section `particle_types` is only assigned inside particle_plots (a local),
# so this cell raises NameError on a fresh kernel unless an earlier cell defines it globally.
particle_types

In [None]:
def sum_z_by_distance(x_array, y_array, z_array, r1, r2, nbins):
    """Sum ``z_array`` over equal-width radial bins of sqrt(x**2 + y**2).

    NOTE: this cell redefines the function also defined next to particle_plots;
    whichever cell runs last wins, so keep the two definitions identical.

    Parameters
    ----------
    x_array, y_array : array_like
        Coordinates of each sample (same shape as ``z_array``).
    z_array : array_like
        Values to accumulate per bin (e.g. macro-particle counts).
    r1, r2 : float
        Inner and outer radius of the binned range.
    nbins : int
        Number of equal-width bins between r1 and r2.

    Returns
    -------
    z_sums : ndarray of float, shape (nbins,)
        Sum of ``z_array`` over the samples falling in each bin.
    bin_edges : ndarray, shape (nbins + 1,)
        The bin edges, ``np.linspace(r1, r2, nbins + 1)``.

    Bins are half-open ``[edge_i, edge_{i+1})`` except the last, which is
    closed, so a sample lying exactly at ``r2`` is counted. This matters
    because callers pass ``r2 = max(distance)``. Samples outside
    ``[r1, r2]`` are ignored.
    """
    x = np.asarray(x_array)
    y = np.asarray(y_array)
    z = np.asarray(z_array)

    # Same formula callers use for r2, so the maximum distance equals bin_edges[-1] exactly
    distances = np.sqrt(x**2 + y**2)

    bin_edges = np.linspace(r1, r2, nbins + 1)
    z_sums = np.zeros(nbins)

    for i in range(nbins):
        lower, upper = bin_edges[i], bin_edges[i + 1]
        if i == nbins - 1:
            # Close the last bin so the outermost sample(s) at r2 are not silently dropped
            mask = (distances >= lower) & (distances <= upper)
        else:
            mask = (distances >= lower) & (distances < upper)
        z_sums[i] = np.sum(z[mask])

    return z_sums, bin_edges

# Scratch: collapse a (v_perp, v_para) histogram into a 1-D spectrum over 64 |v| bins.
# NOTE(review): v_perp, v_para, counts_perp_para, v_mag and particle_type are locals of
# particle_plots in this section -- confirm an earlier cell defines them at module level.
v_mag_max = np.max(v_mag)
counts_mag, bin_edges = sum_z_by_distance(
    v_perp, v_para, counts_perp_para,
    0, v_mag_max, 64,
)  # particles per |v| bin
v_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])  # bin-centre speeds

# Kinetic energy per bin centre (electrons use the m_p/100 mass)
if particle_type == 'electron':
    e_centers = 0.5 * m_p/100 * (v_centers*1e3)**2 * eV  # eV
elif particle_type == 'ion':
    e_centers = 0.5 * m_p * (v_centers*1e3)**2 * eV  # eV

In [None]:
# Quick look: log10 of the 2-D (v_perp, v_para) count histogram.
plt.imshow(
    np.log10(counts_perp_para),
    extent=[v_perp[0, 0], v_perp[0, -1], v_para[0, 0], v_para[-1, 0]],
    origin='lower',
)
plt.colorbar()

In [None]:
# Quick look: 1-D spectrum of macro-particle counts vs. energy per charge (log-log).
fig, ax = plt.subplots(figsize=(7, 5))
ax.scatter(e_centers / 1e3, counts_mag)
for set_scale in (ax.set_xscale, ax.set_yscale):
    set_scale('log')
ax.set_xlim(1, 1e3)
ax.set_xlabel("E/q [keV]")
ax.set_ylabel("Macro particle counts")

In [None]:
# Save
# Save the quick-look figure alongside the preset plots, tagged with location and time.
fig.savefig(f"{folder}_plots/{plot_preset}_{loc}_{round(float(time), 2):.2f}.png", bbox_inches='tight')

In [None]:
# Display the 1-D spectrum counts produced by the scratch sum_z_by_distance cell
# (hidden state: that cell must have run first).
counts_mag