In [1]:
# Essential Jupyter Notebook Magic
%matplotlib inline

# General Purpose and Data Handling Libraries
import os
import re
import glob
import numpy as np
import pandas as pd
from os import listdir
from os.path import isfile, join
from natsort import natsorted
import pickle
from operator import add
import random

# MatPlotlib for Plotting and Visualization
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.tri as tri
import matplotlib as mpl
from matplotlib import cm, ticker
from matplotlib.colors import LogNorm, LightSource, ListedColormap, BoundaryNorm
from matplotlib.collections import LineCollection
from matplotlib.cm import ScalarMappable
from matplotlib.ticker import LogFormatter, LogFormatterSciNotation
from matplotlib.ticker import LogLocator, MultipleLocator, NullFormatter
from mpl_toolkits.mplot3d import axes3d
from streamtracer import StreamTracer, VectorGrid

# Scipy for Scientific Computing and Analysis
from scipy import stats, interpolate
from scipy.optimize import curve_fit
from scipy.interpolate import interp1d, griddata
from scipy.ndimage import label, gaussian_filter
from scipy.spatial import ConvexHull
from scipy.interpolate import RegularGridInterpolator

# Image Handling and Processing
from PIL import Image

# Tecplot for Scientific Data Visualization
import tecplot as tp
from tecplot.exception import *
from tecplot.constant import *

# For 3d plotting
from skimage import measure

In [19]:
# Define Constants (SI units)
# Masses
amu = 1.67e-27       # atomic mass unit, approximate [kg]
m_p = 1.67e-27       # proton mass, approximate [kg]
# Fundamental constants
k_b = 1.38e-23       # Boltzmann constant [J/K]
mu_0 = 1.257e-6      # vacuum permeability [H/m]
e = 1.60218e-19      # elementary charge [C]
# Planetary scale
R_M = 2.44e6         # Mercury radius [m]

# Define utility functions
def read_dataset(mypath):
    """Load a Tecplot data file into a fresh layout and return the dataset.

    Connects to the Tecplot 360 TecUtil server on port 7600, starts a new layout,
    loads ``mypath`` and switches the active frame to a 3D Cartesian plot.
    Loading may take a while for large files.
    """
    print("reading:",mypath)

    # Attach to the running Tecplot instance before touching any layout state
    tp.session.connect(port=7600)

    tp.new_layout()
    loaded = tp.data.load_tecplot(mypath)
    tp.active_frame().plot_type = PlotType.Cartesian3D
    return loaded

def Bz_dip(x_array, y_array, z_array, z_offset=0.2, moment=200.9):
    """Z component of an axially offset dipole field at the given points.

    Evaluates  Bz = -moment * (3 z'^2 - r^2) / r^5  with  z' = z - z_offset  and
    r^2 = x^2 + y^2 + z'^2, i.e. a dipole centred at (0, 0, z_offset).

    Parameters
    ----------
    x_array, y_array, z_array : float or array_like
        Planet-centred coordinates; presumably in R_M (matches the rest of the notebook).
    z_offset : float, optional
        Northward offset of the dipole centre (default 0.2, the original hard-coded value).
    moment : float, optional
        Dipole strength coefficient (default 200.9, the original hard-coded value); the
        result is in the same unit as this coefficient (nT, presumably).

    Returns
    -------
    float or ndarray
        Bz at each point (broadcast over the inputs). Diverges at the dipole centre.
    """
    dz = z_array - z_offset
    r_sq = x_array**2 + y_array**2 + dz**2
    return -moment * (3*dz**2 - r_sq) / r_sq**(5/2)

def get_files(dir, key=".*cut_particle_region0_0.*", read_time = False, reduce = True,
              start_time=None, dt=None, t_bound=None):
    """Collect the regular files in ``dir`` whose names match ``key``, keyed by time.

    Parameters
    ----------
    dir : str
        Directory to scan (non-recursively); sub-directories are ignored.
    key : str
        Regular expression passed to ``re.search`` against each file name.
    read_time : bool
        False: files are sorted by name and the i-th gets the float key
        ``round(i*dt + start_time, 3)``.
        True: the time is parsed from the last 6 characters of the name and used as a
        ``"%.2f"`` string key.
    reduce : bool
        If True, keep only the entries whose *position* in the ordered mapping lies in
        ``[(t_bound[0]-start_time)/dt, (t_bound[1]-start_time)/dt)``.
    start_time, dt, t_bound : optional
        Time configuration. Each defaults to the notebook-level global of the same name,
        which is what this function always used before these parameters existed.

    Returns
    -------
    dict
        Mapping time -> file name (name only, not the full path).
    """
    def _config(value, name):
        # Explicit argument wins; otherwise fall back to the notebook global configuration
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError(f"name '{name}' is not defined; pass {name}=... explicitly") from None

    pattern = re.compile(key)
    files = sorted(f for f in listdir(dir)
                   if isfile(join(dir, f)) and pattern.search(f) is not None)

    # Name each file by its time
    named_files = {}
    if not read_time:
        if files:
            t0 = _config(start_time, "start_time")
            step = _config(dt, "dt")
            for i, name in enumerate(files):
                named_files[round(i*step + t0, 3)] = name
    else:
        # The time is encoded in the last 6 characters of the file name
        for name in files:
            named_files[str("%.2f" % float(name[-6:]))] = name

    if not reduce:
        return named_files

    t0 = _config(start_time, "start_time")
    step = _config(dt, "dt")
    bounds = _config(t_bound, "t_bound")

    def _index(t):
        # The tolerance stops float noise from truncating e.g. 0.35/0.05 == 6.999999999999999 to 6
        return int((t - t0) / step + 1e-9)

    times = list(named_files.keys())[_index(bounds[0]):_index(bounds[1])]
    return {t: str(named_files[t]) for t in times}

def dat_to_plt(dir,files):
    """Convert each listed .dat file to Tecplot .plt format, then delete the .dat.

    ``files`` maps time -> file name inside ``dir`` (as returned by ``get_files``).
    The .plt is written next to the source, with the last three characters of the
    name replaced by "plt" (so the names are assumed to end in ".dat").
    """
    for name in files.values():
        src_path = str(dir + name)
        plt_path = str(dir + name[:-3] + "plt")

        read_dataset(src_path)
        print("saving file:", plt_path)
        tp.data.save_tecplot_plt(plt_path)

        # Only remove the original once the .plt has been written
        os.remove(src_path)
        print(f"Deleted original .dat file: {name}")

def plt_to_numpy(dataset,var_ls=["Bz"],save_cs = True, save_dir=None, file_name=None, save_time=None):
    """Resample a loaded Tecplot dataset onto a regular grid and pickle the results.

    Derived fields are computed in Tecplot, but only when one of their components is in
    ``var_ls``:
      * J = curl(B) / mu_0, in A/m^2
      * grad p, in nPa/m
      * grad |B|, in nT/m
    All variables are then linearly interpolated from zone 0 onto a new ordered zone
    built from the unique X/Y/Z node coordinates of zone 0. The outermost Y nodes are
    dropped.

    Two pickle files are written:
      * ``<file>_numpy_t_<time>``: dict with X, Y, Z and every name in ``var_ls``, each a
        3D array of shape (ny, nx, nz).
      * ``<file>_csdata_t_<time>`` (only if ``save_cs``): dict of 2D (ny, nx) arrays
        sampled on the current sheet. The sheet is the Gaussian-smoothed height of
        maximum plasma beta in each column. Columns at the inner edge of the box get
        value 0 for every variable.

    Parameters
    ----------
    dataset : tecplot Dataset
        Dataset returned by ``read_dataset``; zone 0 is the interpolation source.
    var_ls : list of str
        Variables to extract. Keep it extensive so this slow step need not be rerun.
    save_cs : bool
        Also extract and save the current-sheet surface data.
    save_dir, file_name, save_time : optional
        Output directory (with trailing separator), source file name (its last 4
        characters are replaced) and time stamp used to name the outputs. When omitted,
        the notebook globals ``dir``, ``file`` and ``time`` are used, as before.

    Returns
    -------
    dict
        The 3D data dictionary that was pickled.
    """
    def _output_path(tag):
        # Fall back to the loop variables of the processing cell when not given explicitly
        out_dir = save_dir if save_dir is not None else dir
        src_name = file_name if file_name is not None else file
        t = save_time if save_time is not None else time
        return str(out_dir + src_name[:-4] + tag + '{:06.2f}'.format(round(t, 2)))

    # Extract the coordinate axes (the first and last Y nodes are discarded)
    x_axis = np.unique(dataset.variable("X").values(0).as_numpy_array())
    y_axis = np.unique(dataset.variable("Y").values(0).as_numpy_array())[1:-1]
    z_axis = np.unique(dataset.variable("Z").values(0).as_numpy_array())

    # Create an ordered destination zone with one node per grid point
    rect_zone = dataset.add_ordered_zone('rect_zone',[len(x_axis),len(y_axis),len(z_axis)])

    # 3D coordinate meshgrids; default 'xy' indexing gives arrays of shape (ny, nx, nz)
    xxx,yyy,zzz = np.meshgrid(x_axis,y_axis,z_axis)

    # Assign coordinate values to the rect_zone using the meshgrids
    rect_zone.values('X')[:] = xxx.ravel()
    rect_zone.values('Y')[:] = yyy.ravel()
    rect_zone.values('Z')[:] = zzz.ravel()

    # Compute derivatives in tecplot, which does it efficiently
    # Compute current density, in A/m^2
    if ("Jx" in var_ls) or ("Jy" in var_ls) or ("Jz" in var_ls):
        print("Computing J = ∇xB")
        tp.data.operate.execute_equation(equation='{Jx} = (ddy({Bz}) - ddz({By}))/(1.2566*10**(-6))/2440000*10**(-9)',
            ignore_divide_by_zero=True)
        tp.data.operate.execute_equation(equation='{Jy} = (ddz({Bx}) - ddx({Bz}))/(1.2566*10**(-6))/2440000*10**(-9)',
            ignore_divide_by_zero=True)
        tp.data.operate.execute_equation(equation='{Jz} = (ddx({By}) - ddy({Bx}))/(1.2566*10**(-6))/2440000*10**(-9)',
            ignore_divide_by_zero=True)

    # Compute plasma pressure gradient, in nPa / m
    if ("dp_dx" in var_ls) or ("dp_dy" in var_ls) or ("dp_dz" in var_ls):
        print("Computing ∇$p$")
        tp.data.operate.execute_equation(equation='{dp_dx} = (ddx({pxxS1}+{pxxS0}+{pyyS1}+{pyyS0}+{pzzS1}+{pzzS0}))/3/2440000',
            ignore_divide_by_zero=True)
        tp.data.operate.execute_equation(equation='{dp_dy} = (ddy({pxxS1}+{pxxS0}+{pyyS1}+{pyyS0}+{pzzS1}+{pzzS0}))/3/2440000',
            ignore_divide_by_zero=True)
        tp.data.operate.execute_equation(equation='{dp_dz} = (ddz({pxxS1}+{pxxS0}+{pyyS1}+{pyyS0}+{pzzS1}+{pzzS0}))/3/2440000',
            ignore_divide_by_zero=True)

    # Compute magnetic field gradient , in nT / m
    if ("dB_dx" in var_ls) or ("dB_dy" in var_ls) or ("dB_dz" in var_ls):
        print("Computing ∇B")
        tp.data.operate.execute_equation(equation='{dB_dx} = (ddx(({Bx}*{Bx}+{By}*{By}+{Bz}*{Bz})**(0.5)))/2440000',
            ignore_divide_by_zero=True)
        tp.data.operate.execute_equation(equation='{dB_dy} = (ddy(({Bx}*{Bx}+{By}*{By}+{Bz}*{Bz})**(0.5)))/2440000',
            ignore_divide_by_zero=True)
        tp.data.operate.execute_equation(equation='{dB_dz} = (ddz(({Bx}*{Bx}+{By}*{By}+{Bz}*{Bz})**(0.5)))/2440000',
            ignore_divide_by_zero=True)

    print("Beginning interpolation...")
    # Interpolate zone 0 onto rect_zone (zone 1); points outside the source get 0
    tp.data.operate.interpolate_linear(source_zones=[0],
        destination_zone=1,
        fill_value=0)

    def _zone_values(name):
        # Flat numpy copy of variable `name` on the resampled zone
        return rect_zone.values(name).as_numpy_array()

    # Collect every requested variable as a (ny, nx, nz) array
    data3d = {"X":xxx,"Y":yyy,"Z":zzz}
    for var in var_ls:
        data3d[var] = _zone_values(var).reshape(xxx.shape)

    print("Extraction complete! Saving 3D data ...")
    with open(_output_path("_numpy_t_"), 'wb') as save_file:
        pickle.dump(data3d, save_file)
    print("Done!")

    if save_cs:
        # Plasma beta = 2 mu_0 <p> / B^2, with <p> the trace/3 pressure summed over both species.
        # The 1e9 factor presumably converts nPa / nT^2 to a dimensionless ratio.
        p_sum = (_zone_values("pxxS0") + _zone_values("pyyS0") + _zone_values("pzzS0")
                 + _zone_values("pxxS1") + _zone_values("pyyS1") + _zone_values("pzzS1"))
        b_sq = _zone_values("Bx")**2 + _zone_values("By")**2 + _zone_values("Bz")**2
        beta_meshgrid = (2*mu_0*p_sum*1e9/3/b_sq).reshape(xxx.shape)
        # Missing data (p = B = 0 after fill_value=0 gives 0/0) is flagged with beta = -1
        beta_meshgrid[np.isnan(beta_meshgrid)] = -1

        # Columns with a -1 flag in the middle of their z range sit at the inner edge of the box
        nz = len(z_axis)
        at_boundary = beta_meshgrid[:, :, nz//4:-nz//4].min(axis=2) == -1
        inside = ~at_boundary

        data = {"X":xxx[:,:,0],"Y":yyy[:,:,0]}
        print("Saving cs data...")

        # Rough sheet height: z of maximum beta in each column; boundary columns pinned to 0.2
        idz_max = np.argmax(beta_meshgrid, axis=2)
        z_at_max = np.take_along_axis(zzz, idz_max[:, :, None], axis=2)[:, :, 0]
        Z_rough = np.zeros_like(xxx[:,:,0]) + 0.2
        Z_rough[inside] = z_at_max[inside]

        # Smooth the sheet height (Gaussian sigma in grid cells)
        smoothing_param = 5
        data['Z'] = smooth_meshgrid(xxx[:,:,0], yyy[:,:,0], Z_rough, smoothing_param)

        # Bracketing z indices of the smoothed surface. Every column shares z_axis. Clipping keeps
        # a surface at (or beyond) either end of the axis from wrapping to index -1 or overrunning.
        z_surface = data['Z']
        lower_idz = np.clip(np.searchsorted(z_axis, z_surface) - 1, 0, nz - 2)
        upper_idz = lower_idz + 1
        Z_lower = z_axis[lower_idz]
        Z_upper = z_axis[upper_idz]

        # Linearly interpolate every variable onto the surface
        for name in var_ls:
            var = _zone_values(name).reshape(xxx.shape)
            var_lower = np.take_along_axis(var, lower_idz[:, :, None], axis=2)[:, :, 0]
            var_upper = np.take_along_axis(var, upper_idz[:, :, None], axis=2)[:, :, 0]
            interpolated = var_lower + (var_upper - var_lower) * (z_surface - Z_lower) / (Z_upper - Z_lower)
            # Boundary columns stay 0 so the edge of the box is visible in the output
            data[name] = np.zeros_like(xxx[:,:,0])
            data[name][inside] = interpolated[inside]

        print("Done!")
        with open(_output_path("_csdata_t_"), 'wb') as save_file:
            pickle.dump(data, save_file)

    return data3d


def smooth_meshgrid(X, Y, Z, smoothing_param):
    """
    Return a Gaussian-smoothed copy of the Z values of a meshgrid.

    X and Y are only used to check that the three grids agree in shape; the filter
    itself acts on array indices, so ``smoothing_param`` is the standard deviation of
    the Gaussian kernel in grid cells.

    Parameters:
    X, Y, Z (2D numpy arrays): Meshgrid coordinates; all three must share one shape.
    smoothing_param (float): Gaussian standard deviation (sigma).

    Returns:
    2D numpy array: the smoothed Z values.

    Raises:
    ValueError: if X, Y and Z differ in shape.
    """
    if not (X.shape == Y.shape == Z.shape):
        raise ValueError("X, Y, and Z meshgrids must have the same shape")
    return gaussian_filter(Z, sigma=smoothing_param)

def plot_sphere(ax, radius=1, center=(0, 0, 0), color='b', alpha=0.5, zorder = 1,
                quarter=False, xlims = [-10,10], ylims = [-10,10], zlims = [-10,10]):
    """
    Draw the -X half of a sphere (x <= center x) as a surface on a 3D axis.

    Only azimuths from pi/2 to 3pi/2 are drawn, which is the hemisphere facing -X.

    Parameters:
    - ax: the 3D axis to draw on.
    - radius: sphere radius (default: 1).
    - center: (x, y, z) of the sphere centre (default: origin).
    - color, alpha, zorder: passed through to ``ax.plot_surface``.
    - quarter: if True, draw only the part with z >= center z (a quarter sphere).
    - xlims, ylims, zlims: points outside these [min, max] ranges are set to NaN so
      they are not drawn.
    """
    n_pts = 100
    azimuth = np.linspace(np.pi/2, 3/2 * np.pi, n_pts)
    polar_end = np.pi/2 if quarter else np.pi
    polar = np.linspace(0, polar_end, n_pts)

    sin_polar = np.sin(polar)
    x = radius * np.outer(np.cos(azimuth), sin_polar) + center[0]
    y = radius * np.outer(np.sin(azimuth), sin_polar) + center[1]
    z = radius * np.outer(np.ones(np.size(azimuth)), np.cos(polar)) + center[2]

    # Blank every point that falls outside any of the axis limits
    outside = ((x < xlims[0]) | (x > xlims[1])
               | (y < ylims[0]) | (y > ylims[1])
               | (z < zlims[0]) | (z > zlims[1]))
    for coord in (x, y, z):
        coord[outside] = np.nan

    ax.plot_surface(x, y, z, color=color, alpha=alpha, zorder=zorder)

def average_value(var_ls,t0,t_start,t_stop,type='csdata', data_dir=None):
    """Time-average pickled data over the window [t0 + t_start, t0 + t_stop].

    Loads every ``3d_fluid*<type>_t_<tt.tt>`` pickle in the data directory whose time
    (parsed from the file name) lies in the closed window, and averages each variable.

    Parameters
    ----------
    var_ls : list of str
        Keys to average; each loaded pickle must contain them.
    t0 : float
        Reference (current) time.
    t_start, t_stop : float
        Window bounds relative to ``t0`` (inclusive).
    type : str
        File-type tag in the name, e.g. 'csdata' or 'numpy'.
    data_dir : str, optional
        Directory (with trailing separator). Defaults to the notebook global ``dir``.

    Returns
    -------
    dict
        var -> time-averaged array.

    Raises
    ------
    ValueError
        If no file falls inside the window.
    """
    if data_dir is None:
        data_dir = dir  # notebook-level global, as originally used

    # Raw strings: "\_" in a normal string literal is an invalid escape sequence
    pattern = r"3d_fluid.*" + type + r"_t_...\..."
    temp_files = get_files(data_dir, key=pattern, read_time=True, reduce=False)

    sums = {}
    count = 0
    for t, temp_file in temp_files.items():
        # Skip files outside the requested time window
        if not (t0 + t_start <= float(t) <= t0 + t_stop):
            continue

        # NOTE: pickle.load executes arbitrary code; only use on trusted output files
        with open(data_dir + str(temp_file), 'rb') as f:
            temp_data = pickle.load(f)

        for var in var_ls:
            if var in sums:
                sums[var] += temp_data[var]
            else:
                # Copy so the running sum never aliases a loaded array
                sums[var] = np.array(temp_data[var], copy=True)
        count += 1

    if count == 0:
        raise ValueError(
            f"average_value: no '{type}' files with t in [{t0 + t_start}, {t0 + t_stop}] in {data_dir}")

    # Divide by the number of time steps
    return {var: sums[var] / count for var in var_ls}

def find_indices(X, Y, XX, YY):
    """Nearest meshgrid indices for a list of (x, y) coordinates.

    For each pair in zip(X, Y), finds the column of XX's first row closest to x and the
    row of YY's first column closest to y (first index on ties).

    Returns
    -------
    (iy, ix) : tuple of lists
        Row indices and column indices, in input order. Note the (iy, ix) order.
    """
    x_ref = np.array(XX)[0]
    y_ref = np.array(YY)[:, 0]
    pairs = list(zip(np.array(X), np.array(Y)))

    ix = [np.abs(x_ref - x).argmin() for x, _ in pairs]
    iy = [np.abs(y_ref - y).argmin() for _, y in pairs]
    return iy, ix

def remove_duplicate_rows(arr):
    """Keep the first row for each distinct first-column value, preserving row order.

    Used by the DF tracker to drop repeated rows from the matching matrix.
    Returns a new numpy array (``np.array([])`` for empty input).
    """
    first_row_for = {}
    for row in arr:
        # setdefault keeps the earliest row seen for this leading value
        first_row_for.setdefault(row[0], row)
    return np.array(list(first_row_for.values()))

def find_boundary_points(X, Y):
    """Return the convex-hull vertices of the 2D point set (X, Y).

    Returns
    -------
    (list, list)
        x and y coordinates of the hull vertices, in the order given by
        ``scipy.spatial.ConvexHull.vertices``.
    """
    points = np.column_stack((X, Y))
    hull_points = points[ConvexHull(points).vertices]
    return hull_points[:, 0].tolist(), hull_points[:, 1].tolist()

def _nearest_index(reference, targets):
    """Index into 1D ``reference`` of the entry nearest each value of ``targets``.

    Ties resolve to the first index (``argmin`` semantics). The result has the shape of
    ``targets``. Each distinct target value is searched only once.
    """
    targets = np.asarray(targets)
    # Ravel first so the inverse index is 1D in both NumPy 1.x and 2.x
    uniq, inverse = np.unique(targets.ravel(), return_inverse=True)
    nearest = np.abs(reference[None, :] - uniq[:, None]).argmin(axis=1)
    return nearest[inverse].reshape(targets.shape)


def create_above_surface_mask(X, Y, Z, XX, YY, ZZ):
    """Boolean mask of 3D grid points lying above a 2D surface (e.g. the current sheet).

    For each (i, j) column of the 3D grid, the matching surface height is Z[yi, xi]:
      * xi is the index of X's first row nearest to XX[i, j, 0];
      * yi is the index of Y's first column nearest to YY[i, j, 0].
    The mask is True where ZZ exceeds that height. NaN surface heights give False.

    Parameters
    ----------
    X, Y, Z : 2D ndarray
        Surface meshgrid (same shape).
    XX, YY, ZZ : 3D ndarray
        Volume meshgrid (same shape).

    Returns
    -------
    ndarray of bool, same shape as ZZ.
    """
    # Check that XX, YY, ZZ have the same shape
    assert XX.shape == YY.shape == ZZ.shape, "Arrays XX, YY, and ZZ must have the same shape"

    # Check that X, Y, Z have the same shape
    assert X.shape == Y.shape == Z.shape, "Arrays X, Y, and Z must have the same shape"

    if XX.size == 0:
        return np.zeros_like(ZZ, dtype=bool)

    # Surface column matching each (i, j) of the volume, found with vectorised nearest lookups
    xi = _nearest_index(X[0], XX[:, :, 0])
    yi = _nearest_index(Y[:, 0], YY[:, :, 0])
    surface_z = Z[yi, xi]

    return ZZ > surface_z[:, :, None]

In [7]:
# STEP ONE: process data
dir = "/Users/atcushen/Documents/MercuryModelling/runs/nightside_v4_run1/ta-4e/"  # Directory with data
start_time = 71            # What is the start time of the dataset?
t_bound = [71,93] #70,107  # Start and stop times of this data to be processed
dt = 0.05                  # What is the timestep between files?
convert_first = True       # True: convert matching .dat files to .plt first. Set False if the directory holds .dat files that must not be converted

var_ls = ["Bx","By","Bz","Ex","Ey","Ez","rhoS0","uxS0","uyS0","uzS0","pxxS0","pyyS0","pzzS0","pxyS0","pxzS0","pyzS0",
          "rhoS1","uxS1","uyS1","uzS1","pxxS1","pyyS1","pzzS1","pxyS1","pxzS1","pyzS1","Jx","Jy","Jz","dp_dx","dp_dy","dp_dz",
          "dB_dx","dB_dy","dB_dz"]


# RUN
# Raw-string regexes: "\_" in a normal string literal is an invalid escape sequence
if convert_first:
    files = get_files(dir,key=r"3d_fluid.*\.dat")
    dat_to_plt(dir,files)

files = get_files(dir,key=r"3d_fluid.*\.plt")

for time in list(files.keys()):
    print("Extracting data for t =",time)
    file = str(files[time])

    # Read in dataset
    dataset = read_dataset(dir+file)
    # Save .plt as numpy data (plt_to_numpy builds its output names from dir, file and time)
    data = plt_to_numpy(dataset,var_ls=var_ls)


reading: /Users/atcushen/Documents/MercuryModelling/runs/nightside_v4_run1/ta-4e/3d_fluid_region0_0_t00000111_n00032464_amrex.dat
Connecting to Tecplot 360 TecUtil Server on:
    tcp://localhost:7600
Connection established.
saving file: /Users/atcushen/Documents/MercuryModelling/runs/nightside_v4_run1/ta-4e/3d_fluid_region0_0_t00000111_n00032464_amrex.plt
Deleted original .dat file: 3d_fluid_region0_0_t00000111_n00032464_amrex.dat
reading: /Users/atcushen/Documents/MercuryModelling/runs/nightside_v4_run1/ta-4e/3d_fluid_region0_0_t00000111_n00032491_amrex.dat
Connecting to Tecplot 360 TecUtil Server on:
    tcp://localhost:7600
Connection established.
saving file: /Users/atcushen/Documents/MercuryModelling/runs/nightside_v4_run1/ta-4e/3d_fluid_region0_0_t00000111_n00032491_amrex.plt
Deleted original .dat file: 3d_fluid_region0_0_t00000111_n00032491_amrex.dat
reading: /Users/atcushen/Documents/MercuryModelling/runs/nightside_v4_run1/ta-4e/3d_fluid_region0_0_t00000111_n00032518_amrex.dat


ZMQError: Operation cannot be accomplished in current state

In [None]:
# STEP TWO: read and plot
dir = "/Users/atcushen/Documents/MercuryModelling/runs/nightside_v4_run1/ta-234e/"     # Directory with data
start_time = 30        # Time of the first file in the dataset
t_bound = [51.5,75]    # Start and stop times of this data to be plotted
dt = 0.05              # Time step between files

# Plotmode settings
plot_preset = '3D_gridscale_df_tracker'

# Available presets:
#   "3D_Bz": Bz in PIC domain.
#   "3D_Bz1": Bz1 in PIC domain.
#   "3D_delta_Bz": delta Bz in PIC domain.
#   "3D_df_tracker": show DFs and track their trajectories, saving a dictionary of their data
#   "3D_field_lines": view of field line geometry in FLEKS domain
#   "3D_flux_tube_content": (no description)
#   "3D_current_sheet": Shows 3 xz plane slices of the plasma beta and current density, to demonstrate cs fit
#   "3D_gridscale_df_tracker": Similar to df_tracker, but does not compute DF-averaged quantities,
#                              instead finding full electric field and current terms

# Zoom controls
do_zoom = True # whether panels should zoom
zoom_time_start = 51.50 # when to start zooming
zoom_time_end = 53.00 # when to end
zoom_x_range = [-2.25,-0.75] # what x region to zoom to
zoom_y_range = [-1.1,0.5] # what y region to zoom to
zoom_z_range = [-0.1,0.5]
azim_start = -110
azim_end = -130

#RUN
# Raw-string regexes: "\_" in a normal string literal is an invalid escape sequence
files3D = get_files(dir,key=r"3d_fluid.*numpy_t_...\...",read_time = True)
filescs = get_files(dir,key=r"3d_fluid.*csdata_t_...\...",read_time = True)
iter = 0  # NOTE(review): shadows the builtin iter(); presumably a counter used later in the loop — confirm
for time in list(files3D.keys()): 
    print("Plotting t =",time)

    # Read in the 3D data: the "_numpy_t_" pickle written by plt_to_numpy in STEP ONE
    file3D = str(files3D[time])
    with open(dir+file3D, 'rb') as f:
        data3d = pickle.load(f) 
    
    # Read in current sheet data (the matching "_csdata_t_" pickle). Indexing filescs by the
    # same time key raises KeyError if no csdata file exists for this time.
    # NOTE: pickle.load can execute arbitrary code; only load files produced by this workflow.
    filecs = str(filescs[time])
    with open(dir+filecs, 'rb') as f:
        datacs = pickle.load(f) 

    # PLOT PRESET '3D_Bz': Bz on the smoothed current-sheet surface, plus traced field lines
    if plot_preset=='3D_Bz':
        fig = plt.figure(figsize=(13,6), constrained_layout=True)
        ax = fig.add_subplot(111, projection="3d",computed_zorder=False)

        # Unpack variables (current-sheet surface arrays)
        X = datacs["X"]
        Y = datacs["Y"]
        Z = datacs["Z"]
        Bz = datacs["Bz"]

        # Hide the sheet inside cylindrical radius 1.10 by setting its height to NaN.
        # Z is the datacs["Z"] array itself, so this edits the loaded data in place.
        radius = 1.10
        mask = (X**2 + Y**2) < radius**2
        Z[mask] = np.nan

        mid = 71 # Row index corresponding to midnight (NOTE(review): hard-coded; assumes row 71 of the cs grid is y = 0)
        # Map Bz to colours on a fixed ±50 range
        norm = plt.Normalize(-50,50)
        dawn_colors = cm.bwr(norm(Bz[:mid,:]),alpha=0.5)
        dusk_colors = cm.bwr(norm(Bz[mid:,:]),alpha=0.5)

        # Set the lighting
        light = LightSource()  # default azimuth/altitude (no arguments passed)
        dawn_illuminated_colors = light.shade_rgb(dawn_colors, Z[:mid,:], blend_mode='soft')  # Apply light source shading
        dusk_illuminated_colors = light.shade_rgb(dusk_colors, Z[mid:,:], blend_mode='soft')  # Apply light source shading

        # move camera view
        ax.view_init(elev=20, azim=-110)

        # Draw the dawn and dusk halves as separate surfaces with their own zorder
        # (computed_zorder=False above, so these explicit zorders decide draw order)
        surf1 = ax.plot_surface(X[:mid,:], Y[:mid,:], Z[:mid,:], facecolors=dawn_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=2)
        surf2 = ax.plot_surface(X[mid:,:], Y[mid:,:], Z[mid:,:], facecolors=dusk_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=1.5)
        plot_sphere(ax,radius=1,color='lightgrey',alpha=0.5,zorder=-1)
        plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=-0.75)
        
        # Add stream traces of B from a few seed points on the x axis
        nsteps = 10000
        step_size = 0.001
        tracer = StreamTracer(nsteps, step_size)
        ny,nx,nz = data3d["Bx"].shape
        
        # The 3D arrays are (ny, nx, nz); the tracer grid is built as (nx, ny, nz, 3), hence the transpose
        field = np.zeros((nx,ny,nz,3))
        field[:,:,:,0] = np.transpose(data3d["Bx"],axes=[1,0,2])
        field[:,:,:,1] = np.transpose(data3d["By"],axes=[1,0,2])
        field[:,:,:,2] = np.transpose(data3d["Bz"],axes=[1,0,2])
        
        # NOTE(review): spacing is hard-coded to 1/64 rather than derived from data3d — confirm it matches the grid
        grid_spacing = [1/64,1/64,1/64]
        grid = VectorGrid(field, grid_spacing, origin_coord = [data3d["X"].min(),data3d["Y"].min(),data3d["Z"].min()])
        
        seeds = np.array([[-1.5,0,0.2], [-1.75,0,0.2],[-2,0,0.2],[-2.25,0,0.2]])
        tracer.trace(seeds, grid)

        # Draw each traced line starting from its first point at or above the seed height
        for seed in range(len(seeds)):
            start = np.where(tracer.xs[seed][:,2]>=seeds[seed][2])[0][0]
            ax.plot(tracer.xs[seed][start:,0],tracer.xs[seed][start:,1],tracer.xs[seed][start:,2],color='black',linewidth=0.5,zorder=2)

        # Add a color bar 
        m = cm.ScalarMappable(cmap=cm.bwr, norm=norm)
        m.set_array(Bz)
        clb = fig.colorbar(m, ax=ax, shrink=0.3, aspect=7,anchor=(0.5,0.3))
        clb.ax.tick_params(labelsize=12)
        clb.ax.set_title('$B_z$ [nT]',fontsize=12)

        # Set axes: fixed z range; box aspect follows the data extents so R_M units look isotropic
        z_lower = -0.6
        z_upper = 1.0
        ax.set_zlim(z_lower,z_upper)
        x_range = X.max() - X.min()
        y_range = Y.max() - Y.min()
        z_range = z_upper - z_lower
        ax.set_box_aspect([x_range, y_range, z_range])  # Aspect ratio is set based on the data limits
        
        # Add labels
        ax.set_xlabel("X [$R_M$]",fontsize=12)
        ax.set_ylabel("Y [$R_M$]",fontsize=12)
        ax.set_zlabel("Z [$R_M$]",fontsize=12)
        ax.tick_params(axis='both',labelsize=12)
        ax.set_title(str("$B_z$ at t="+time+"s"),fontsize=12,y=1.0, pad=-14)

        # Save to "<dir without trailing slash>_plots/<preset>_<time>.png" and free the figure
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+time+'.png'),bbox_inches='tight',pad_inches=0.3, dpi=300)
        plt.close(fig)

    # PLOT PRESET '3D_Bz1': perturbation field Bz1 = Bz - Bz_dip on the current sheet (colour),
    # plus scatter points where Bz1 > b1min in the 3D volume
    if plot_preset=='3D_Bz1':
        fig = plt.figure(figsize=(13,6), constrained_layout=True)
        ax = fig.add_subplot(111, projection="3d",computed_zorder=False)

        # Unpack variables; Bz1 removes the offset-dipole contribution from Bz
        Xcs = datacs["X"]
        Ycs = datacs["Y"]
        Zcs = datacs["Z"]
        Bzcs = datacs["Bz"]
        Bz1cs = Bzcs - Bz_dip(Xcs,Ycs,Zcs)

        X3d = data3d["X"]
        Y3d = data3d["Y"]
        Z3d = data3d["Z"]
        Bz3d = data3d["Bz"]
        Bz13d = Bz3d - Bz_dip(X3d,Y3d,Z3d)

        # Hide the sheet inside cylindrical radius 1.10 (after Bz1cs, so the dipole term used real heights)
        radius = 1.10
        mask = (Xcs**2 + Ycs**2) < radius**2
        Zcs[mask] = np.nan

        mid = 71 # Row index corresponding to midnight
        # Define colormap
        # NOTE(review): the sheet and colour bar use norm = (-100, 10) while the scatter below uses
        # vmin/vmax = (-100, 0); confirm which range is intended.
        vmin=-100
        vmax=0
        norm = plt.Normalize(-100,10)
        dawn_colors = cm.plasma(norm(Bz1cs[:mid,:]),alpha=0.5)
        dusk_colors = cm.plasma(norm(Bz1cs[mid:,:]),alpha=0.5)

        # Shade by the current-sheet height. This branch defines only the *cs / *3d arrays,
        # so X/Y/Z must not be used here (they are undefined, or stale from another preset).
        light = LightSource()  # default azimuth/altitude
        dawn_illuminated_colors = light.shade_rgb(dawn_colors, Zcs[:mid,:], blend_mode='soft')
        dusk_illuminated_colors = light.shade_rgb(dusk_colors, Zcs[mid:,:], blend_mode='soft')

        # move camera view
        ax.view_init(elev=5, azim=-90)

        # Create the surface plot (dawn and dusk halves with separate zorders)
        surf1 = ax.plot_surface(Xcs[:mid,:], Ycs[:mid,:], Zcs[:mid,:], facecolors=dawn_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=2)
        surf2 = ax.plot_surface(Xcs[mid:,:], Ycs[mid:,:], Zcs[mid:,:], facecolors=dusk_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=1)
        plot_sphere(ax,radius=1,color='lightgrey',alpha=0.5,zorder=1)
        plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=1.25)

        # Scatter the 3D points where Bz1 exceeds b1min
        b1min = 0
        strong = Bz13d > b1min
        iso = ax.scatter(X3d[strong],Y3d[strong],Z3d[strong],c=Bz13d[strong],
                         vmin=vmin,vmax=vmax,cmap='plasma',s=1,alpha=0.5)

        # Add a color bar
        m = cm.ScalarMappable(cmap=cm.plasma, norm=norm)
        m.set_array(Bz1cs)
        clb = fig.colorbar(m, ax=ax, shrink=0.3, aspect=7,anchor=(0.5,0.3))
        clb.ax.tick_params(labelsize=12)
        clb.ax.set_title('$B_{z1}$ [nT]',fontsize=12,pad=10)

        # Set axes; box aspect follows the current-sheet extents
        z_lower = -0.6
        z_upper = 1.0
        ax.set_zlim(z_lower,z_upper)
        x_range = Xcs.max() - Xcs.min()
        y_range = Ycs.max() - Ycs.min()
        z_range = z_upper - z_lower
        ax.set_box_aspect([x_range, y_range, z_range])  # Aspect ratio is set based on the data limits

        # Add labels
        ax.set_xlabel("X [$R_M$]",fontsize=12)
        ax.set_ylabel("Y [$R_M$]",fontsize=12)
        ax.set_zlabel("Z [$R_M$]",fontsize=12)
        ax.tick_params(axis='both',labelsize=12)
        ax.set_title(str("$B_{z1}$ at t="+time+"s"),fontsize=12,y=1.0, pad=-14)

        # Save and free the figure
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+time+'.png'),bbox_inches='tight',pad_inches=0.3, dpi=300)
        plt.close(fig)


    # PLOT PRESET '3D_delta_Bz': Bz change relative to a recent time average, on the current sheet
    # (colour) and as 3D points where it exceeds Bmin (split into above/below the sheet)
    if plot_preset=='3D_delta_Bz':
        fig = plt.figure(figsize=(13,6), constrained_layout=True)
        ax = fig.add_subplot(111, projection="3d",computed_zorder=False)
        x_cutoff = 0 # How many indices in X to cutoff
        z_cutoff = 0 # How many indices in Z to cutoff
        # Background window (s, relative to t) shared by the sheet and the 3D delta Bz, so both
        # show the same quantity. The 3D average used to take (-5, 2), which included t itself
        # and future frames, while the sheet used (-5, -2).
        avg_start, avg_stop = -5, -2

        # Unpack variables
        Xcs = datacs["X"][:,x_cutoff:]
        Ycs = datacs["Y"][:,x_cutoff:]
        Zcs = datacs["Z"][:,x_cutoff:]
        Bzcs = datacs["Bz"][:,x_cutoff:]
        deltaBzcs = Bzcs - average_value(["Bz"],float(time),avg_start,avg_stop,type='csdata')["Bz"][:,x_cutoff:]

        X3d = data3d["X"][:,x_cutoff:,z_cutoff:]
        Y3d = data3d["Y"][:,x_cutoff:,z_cutoff:]
        Z3d = data3d["Z"][:,x_cutoff:,z_cutoff:]
        Bz3d = data3d["Bz"][:,x_cutoff:,z_cutoff:]
        deltaBz3d = Bz3d - average_value(["Bz"],float(time),avg_start,avg_stop,type='numpy')["Bz"][:,x_cutoff:,z_cutoff:]

        # Hide the sheet inside cylindrical radius 1.01 (Zcs is a view, so datacs["Z"] changes too)
        radius = 1.01
        mask = (Xcs**2 + Ycs**2) < radius**2
        Zcs[mask] = np.nan

        mid = 71 # Row index corresponding to midnight
        # Define colormap
        vmin=-25
        vmax=25
        norm = plt.Normalize(vmin,vmax)
        dawn_colors = cm.bwr(norm(deltaBzcs[:mid,:]),alpha=0.9)
        dusk_colors = cm.bwr(norm(deltaBzcs[mid:,:]),alpha=0.9)

        # Set the lighting
        light = LightSource()  # default azimuth/altitude
        dawn_illuminated_colors = light.shade_rgb(dawn_colors, Zcs[:mid,:], blend_mode='soft')  # Apply light source shading
        dusk_illuminated_colors = light.shade_rgb(dusk_colors, Zcs[mid:,:], blend_mode='soft')  # Apply light source shading

        # move camera view
        ax.view_init(elev=30, azim=-160)

        # Create the surface plot
        surf1 = ax.plot_surface(Xcs[:mid,:], Ycs[:mid,:], Zcs[:mid,:], facecolors=dawn_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=2)
        surf2 = ax.plot_surface(Xcs[mid:,:], Ycs[mid:,:], Zcs[mid:,:], facecolors=dusk_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=1.5)
        plot_sphere(ax,radius=1,color='lightgrey',alpha=0.5,zorder=1)
        plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=1.25)

        # Points with delta Bz > Bmin, split into above/below the sheet (mask computed once).
        # Points under a NaN (hidden) part of the sheet count as "below".
        Bmin = 10
        above_sheet = create_above_surface_mask(Xcs, Ycs, Zcs, X3d, Y3d, Z3d)
        above_mask = (deltaBz3d>Bmin) & above_sheet
        #iso1 = ax.scatter(X3d[above_mask],Y3d[above_mask],Z3d[above_mask],c=deltaBz3d[above_mask],cmap='bwr',vmin=vmin,vmax=vmax,
        #                 s=0.8,zorder=3,alpha = np.clip(((deltaBz3d[above_mask]-Bmin)/(vmax-Bmin))**0.2,0,1))
        # NOTE(review): iso1 and iso2 colour the same quantity (-Z) with different vmin/vmax — confirm intended
        iso1 = ax.scatter(X3d[above_mask],Y3d[above_mask],Z3d[above_mask],c=(-Z3d[above_mask]),cmap='Greens',vmin=-0.6,vmax=0.1,
                          s=0.8,zorder=3,alpha = 0.8)
        below_mask = (deltaBz3d>Bmin) & ~above_sheet
        iso2 = ax.scatter(X3d[below_mask],Y3d[below_mask],Z3d[below_mask],c=(-Z3d[below_mask]),cmap='Greens',vmin=-0.5,vmax=0.5,
                          s=0.8,zorder=0.75,alpha = 0.8)

        # Add Bz=0 line
        #contour_mask = np.isclose(Bzcs, 0, atol=0.5) & (Xcs<-1.3)
        #ax.scatter(Xcs[contour_mask], Ycs[contour_mask], Zcs[contour_mask],c='black',s=0.03,zorder=2.1)

        # Add a color bar (raw strings: "\d" is an invalid escape in a normal literal)
        m = cm.ScalarMappable(cmap=cm.bwr, norm=norm)
        m.set_array(deltaBzcs)
        clb = fig.colorbar(m, ax=ax, shrink=0.3, aspect=7,anchor=(-0.5,0.3))
        clb.ax.tick_params(labelsize=12)
        clb.ax.set_title(r'$\delta B_{z}$ [nT]',fontsize=12,pad=10)

        # Add big x axis
        ax.plot([np.min(Xcs[:mid,:]),-1],[0,0],[0.2,0.2],color='black',lw=2)
        ax.scatter([-4,-3,-2,-1],[0,0,0,0],[0.2,0.2,0.2,0.2],s=5,color='black')

        # Set axes
        z_lower = -0.1
        z_upper = 1.0
        ax.set_zlim(z_lower,z_upper)
        x_range = Xcs.max() - Xcs.min()
        y_range = Ycs.max() - Ycs.min()
        z_range = z_upper - z_lower
        ax.set_box_aspect([x_range, y_range, z_range])  # Aspect ratio is set based on the data limits

        # Add labels
        ax.set_xlabel("\nX [$R_M$]",fontsize=12)
        ax.set_ylabel("Y [$R_M$]",fontsize=12)
        ax.set_zlabel("Z [$R_M$]",fontsize=12)
        ax.tick_params(axis='both',labelsize=12)
        ax.set_title(str(r"$\delta B_{z}$ at t="+time+"s"),fontsize=12,y=1.0, pad=-14)

        # Save and free the figure
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+time+'.png'),bbox_inches='tight',pad_inches=0.3, dpi=300)
        plt.close(fig)

    if plot_preset=='3D_field_lines':
        # Magnetic field lines traced from a sub-sampled set of current-sheet points, coloured by -Z,
        # drawn around Mercury (two nested spheres).
        from mpl_toolkits.mplot3d.art3d import Line3DCollection # not among the notebook's top-level imports

        # Set up plot environment
        fig = plt.figure(figsize=(18,6))
        ax = fig.add_subplot(111, projection="3d",computed_zorder=False)

        # Unpack data
        Xcs = datacs["X"]
        Ycs = datacs["Y"]
        Zcs = datacs["Z"]
        X3d = data3d["X"]
        Y3d = data3d["Y"]
        Z3d = data3d["Z"]
        Bx3d = data3d["Bx"]
        By3d = data3d["By"]
        Bz3d = data3d["Bz"]
        ncs = datacs["rhoS1"]

        # Set axes
        z_lower = -1
        z_upper = 1
        ax.set_zlim(z_lower,z_upper)
        x_range = X3d.max() - X3d.min()
        y_range = Y3d.max() - Y3d.min()
        z_range = z_upper - z_lower
        ax.set_box_aspect([x_range, y_range, z_range])  # Aspect ratio is set based on the data limits

        # Set up grid for field line tracing. The 3D arrays unpack as (ny, nx, nz); the field handed
        # to VectorGrid is built with shape (nx, ny, nz, 3), hence the swap of the first two axes.
        ny,nx,nz = Bx3d.shape
        field = np.zeros((nx,ny,nz,3))
        field[:,:,:,0] = np.transpose(Bx3d,axes=[1,0,2])
        field[:,:,:,1] = np.transpose(By3d,axes=[1,0,2])
        field[:,:,:,2] = np.transpose(Bz3d,axes=[1,0,2])
        grid_spacing = [1/64,1/64,1/64]
        grid = VectorGrid(field, grid_spacing, origin_coord = [X3d.min(),Y3d.min(),Z3d.min()])
        nsteps = 10000
        step_size = 0.001
        tracer = StreamTracer(nsteps, step_size)

        # Seed one field line at every trace_skip-th current-sheet point along both grid axes
        trace_skip=16
        seeds = np.column_stack([Xcs[::trace_skip,::trace_skip].ravel(),
                                 Ycs[::trace_skip,::trace_skip].ravel(),
                                 Zcs[::trace_skip,::trace_skip].ravel()]).astype(float)

        # Trace the field lines
        tracer.trace(seeds, grid)

        # Colour map / normalisation for the field lines, which are coloured by -Z
        line_cmap = plt.colormaps["Reds"]
        line_norm = plt.Normalize(-1.2, 0.8)

        def _add_field_line_runs(path, idx):
            # Add path[idx] as coloured 3D line collections, one per run of consecutive indices,
            # so a field line is never joined across a gap in idx. Runs shorter than 2 points are skipped.
            for run in np.split(idx, np.flatnonzero(np.diff(idx) > 1) + 1):
                if len(run) < 2:
                    continue
                points = path[run].reshape(-1, 1, 3)
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                lc = Line3DCollection(segments, cmap=line_cmap, norm=line_norm, linewidth=0.5)
                lc.set_array(-path[run[:-1], 2]) # one colour value per segment (its starting point)
                ax.add_collection3d(lc)

        # Plot each field line split into the points at/above and below its seed height
        for i in range(len(seeds)):
            path = tracer.xs[i]
            _add_field_line_runs(path, np.flatnonzero(path[:,2]>=seeds[i][2]))
            _add_field_line_runs(path, np.flatnonzero(path[:,2]<seeds[i][2]))

        mid = 71 # Row index corresponding to midnight
        # Log colour scale used by the colour bar below
        norm = LogNorm(vmin=0.1, vmax=10)

        # move camera view
        ax.view_init(elev=35, azim=-145)

        # Show Mercury
        plot_sphere(ax,radius=1,color='lightgrey',alpha=0.5,zorder=1)
        plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=1.25)

        # Add a color bar
        # NOTE(review): no density surface is drawn in this preset, so this colour bar (and the
        # "Density" title) does not correspond to anything plotted — confirm it is still wanted.
        m = cm.ScalarMappable(cmap=cm.plasma, norm=norm)
        m.set_array(ncs)
        clb = fig.colorbar(m, ax=ax, shrink=0.3, aspect=7,anchor=(-0.5,0.3))
        clb.ax.tick_params(labelsize=12)
        clb.ax.set_title('$n$ [cm$^{-3}$]',fontsize=12,pad=10)

        # Add big x axis
        ax.plot([np.min(Xcs[:mid,:]),-1],[0,0],[0.2,0.2],color='black',lw=0.8,zorder=3)
        ax.scatter([-4,-3,-2,-1],[0,0,0,0],[0.2,0.2,0.2,0.2],s=5,color='black',zorder=3)

        # Add labels
        ax.set_xlabel("\nX [$R_M$]",fontsize=12)
        ax.set_ylabel("Y [$R_M$]",fontsize=12)
        ax.set_zlabel("Z [$R_M$]",fontsize=12)
        ax.tick_params(axis='both',labelsize=12)
        ax.set_title("Density at t="+time+"s",fontsize=12,y=1.0, pad=-14)

        # Save
        fig.savefig(str(dir[:-1])+"_plots/"+plot_preset+"_"+time+'.png',bbox_inches='tight',pad_inches=0.3, dpi=300)
        plt.close(fig)

    
    if plot_preset=='3D_df_tracker':

        # On the first iteration, create the containers that carry DF-tracking state between steps.
        # NOTE(review): on later iterations these names are expected to still hold the previous
        # step's state (kept by the enclosing loop/caller) — confirm.
        if iter == 0:
            df_data = {} # DF id -> DataFrame with one row of DF-averaged quantities per time step
            df_dict = None # DF id -> (X coords, Y coords) of the DF's cells at the previous step
            df_seeds = {} # DF id -> (n_seeds, 3) array of field-line seed positions

        # Number of leading columns (axis 1) dropped from every array; 0 keeps the full grid
        x_cutoff=0

        # Unpack current-sheet (2D) and volume (3D) data
        X = datacs["X"][:,x_cutoff:]
        Y = datacs["Y"][:,x_cutoff:]
        Z = datacs["Z"][:,x_cutoff:]
        X3d = data3d["X"][:,x_cutoff:,:]
        Y3d = data3d["Y"][:,x_cutoff:,:]
        Z3d = data3d["Z"][:,x_cutoff:,:]
        Bx3d = data3d["Bx"][:,x_cutoff:,:]
        By3d = data3d["By"][:,x_cutoff:,:]
        Bz3d = data3d["Bz"][:,x_cutoff:,:]
        Jx = datacs["Jx"][:,x_cutoff:]
        Jy = datacs["Jy"][:,x_cutoff:]
        Jz = datacs["Jz"][:,x_cutoff:]
        Bx = datacs["Bx"][:,x_cutoff:]
        By = datacs["By"][:,x_cutoff:]
        Bz = datacs["Bz"][:,x_cutoff:]
        n = datacs["rhoS1"][:,x_cutoff:] * 1e6 # convert to SI
        pe = ((datacs["pxxS0"]+datacs["pyyS0"]+datacs["pzzS0"])/3*1e-9)[:,x_cutoff:] # scalar (trace/3) species-0 pressure, convert to SI
        Te = pe/n/k_b / 11605 / 1e3 # p/(n k_B) in K -> eV (/11605) -> keV (/1e3)
        # Bulk velocities: species 1 -> "i" (ion), species 0 -> "e" (electron)
        uix = datacs["uxS1"][:,x_cutoff:]
        uiy = datacs["uyS1"][:,x_cutoff:]
        uiz = datacs["uzS1"][:,x_cutoff:]
        uex = datacs["uxS0"][:,x_cutoff:]
        uey = datacs["uyS0"][:,x_cutoff:]
        uez = datacs["uzS0"][:,x_cutoff:]
        # Total (both species) plasma beta 2*mu0*p/B^2; the 1e9 combines the nPa (1e-9) and nT^2 (1e-18) factors
        beta = (2*mu_0*(datacs["pxxS0"]+datacs["pyyS0"]+datacs["pzzS0"]+datacs["pxxS1"]+datacs["pyyS1"]+datacs["pzzS1"])*1e9/3/(datacs["Bx"]**2+datacs["By"]**2+datacs["Bz"]**2))[:,x_cutoff:]
        # Convection electric field E = -u_e x B
        E_convx = (-(uey*Bz-uez*By)*1000*1e-9) # Convert to V/m ie SI
        E_convy = (-(uez*Bx-uex*Bz)*1000*1e-9)
        E_convz = (-(uex*By-uey*Bx)*1000*1e-9)
        dp_dx = datacs['dp_dx'][:,x_cutoff:]
        dp_dy = datacs['dp_dy'][:,x_cutoff:]
        dp_dz = datacs['dp_dz'][:,x_cutoff:]

        # Background Bz averaged over earlier times. The meaning of the window arguments of
        # average_value (defined elsewhere) — NOTE(review): confirm (-1,0) / (-5,-2) are the intended windows.
        Bz_avg = average_value(["Bz"],float(time),-1,0)["Bz"][:,x_cutoff:]
        # NOTE(review): delta_Bz3d is not used by the rest of this preset.
        delta_Bz3d = Bz3d - average_value(["Bz"],float(time),-5,-2,type='numpy')["Bz"][:,x_cutoff:,:]

        # DF metric: current-sheet Bz minus its time-averaged background
        metric = (Bz-Bz_avg)

        # Detection thresholds
        min_value = 10 #nT, minimum metric for a cell to belong to a DF
        min_size = 10 # a DF must contain more than this many cells
        dx = 1/64

        # Set z bounds
        z_lower = -1
        z_upper = 1

        # Compute zoom rates, if activated: each view bound moves linearly from the full domain to
        # the requested zoom window between zoom_time_start and zoom_time_end
        if do_zoom:
            zoom_duration = zoom_time_end-zoom_time_start
            zoom_dxdt_min = (zoom_x_range[0]-np.min(X))/zoom_duration
            zoom_dxdt_max = (np.max(X)-zoom_x_range[1])/zoom_duration
            zoom_dydt_min = (zoom_y_range[0]-np.min(Y))/zoom_duration
            zoom_dydt_max = (np.max(Y)-zoom_y_range[1])/zoom_duration
            zoom_dzdt_min = (zoom_z_range[0]-z_lower)/zoom_duration
            zoom_dzdt_max = (z_upper-zoom_z_range[1])/zoom_duration
            zoom_dazimdt = (azim_end - azim_start)/zoom_duration

        # Find DF regions: cells whose metric exceeds min_value
        mask = metric > min_value

        # Label 4-connected regions (cross-shaped structuring element)
        structure = np.array([[0,1,0],
                              [1,1,1],
                              [0,1,0]], dtype=bool)
        labeled, num_features = label(mask, structure=structure)

        # Keep regions with more than min_size cells and mean beta > 1, numbered 1..N in label order
        new_df_dict = {}
        count=1
        for feature_num in range(1, num_features + 1):
            region = (labeled == feature_num)
            if (np.count_nonzero(region)>min_size) and (np.mean(beta[region])>1):
                new_df_dict[count] = (X[region], Y[region])
                count+=1

        print("Found",len(new_df_dict.keys()),"DFs at this time")

        ################# Plot #################
        fig = plt.figure(figsize=(18,6))
        ax = fig.add_subplot(111, projection="3d",computed_zorder=False)

        # Mask out values. Z is sliced from datacs["Z"]; when that is a NumPy array the slice is a
        # view, so the NaNs written below also land in datacs["Z"].
        if do_zoom and (float(time)>zoom_time_start):
            # Time since the zoom started, clamped at the full zoom duration so the view stays
            # at the final zoom window once zoom_time_end is reached
            zoom_time = min(float(time), zoom_time_end)-zoom_time_start
            zoom_xmin = np.min(X)+zoom_time*zoom_dxdt_min
            zoom_xmax = np.max(X)-zoom_time*zoom_dxdt_max
            zoom_ymin = np.min(Y)+zoom_time*zoom_dydt_min
            zoom_ymax = np.max(Y)-zoom_time*zoom_dydt_max
            zoom_zmin = z_lower+zoom_time*zoom_dzdt_min
            zoom_zmax = z_upper-zoom_time*zoom_dzdt_max
            # Hide current-sheet cells outside the zoom window
            zoom_mask = (X > zoom_xmax) | (X < zoom_xmin) | (Y > zoom_ymax) | (Y < zoom_ymin)
            Z[zoom_mask] = np.nan
        # Hide current-sheet cells within radius of the origin
        radius = 1.01
        mask = (X**2 + Y**2) < radius**2
        Z[mask] = np.nan

        mid = 90 # Row index corresponding to midnight
        # Colour the sheet by the DF metric, split at the midnight row into dawn/dusk halves
        vmin=-25
        vmax=25
        norm = plt.Normalize(vmin,vmax)
        dawn_colors = cm.bwr(norm(metric[:mid,:]),alpha=0.9)
        dusk_colors = cm.bwr(norm(metric[mid:,:]),alpha=0.9)

        # Shade the colours with a light source (default azimuth/altitude)
        light = LightSource()
        dawn_illuminated_colors = light.shade_rgb(dawn_colors, Z[:mid,:], blend_mode='soft')
        dusk_illuminated_colors = light.shade_rgb(dusk_colors, Z[mid:,:], blend_mode='soft')

        # Move camera view; while zooming, the azimuth rotates linearly with zoom_time
        if do_zoom and (float(time)>zoom_time_start):
            ax.view_init(elev=25, azim=azim_start + zoom_dazimdt*(zoom_time))
        else:
            ax.view_init(elev=25, azim=azim_start)

        # Create the surface plot; zorder puts the dawn half above and the dusk half below the
        # planet spheres drawn below (zorder 1 and 1.25)
        ax.plot_surface(X[:mid,:], Y[:mid,:], Z[:mid,:], facecolors=dawn_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=2)
        ax.plot_surface(X[mid:,:], Y[mid:,:], Z[mid:,:], facecolors=dusk_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=0.75)

        # Add a color bar
        m = cm.ScalarMappable(cmap=cm.bwr, norm=norm)
        m.set_array(metric)
        clb = fig.colorbar(m, ax=ax, shrink=0.3, aspect=7,anchor=(0.0,0.3))
        clb.ax.tick_params(labelsize=12)
        clb.ax.set_title(r'$\delta B_{z}$ [nT]',fontsize=12,pad=10)

        # Add big x axis (with unit tick dots) and Mercury; when zooming, clip the spheres to the zoom box
        if do_zoom and (float(time)>zoom_time_start):
            ax.plot([zoom_xmin,-1],[0,0],[0.2,0.2],color='black',lw=1)
            ticks = np.arange(int(zoom_xmin),0)
            ax.scatter(ticks,ticks*0,ticks*0+0.2,s=4,color='black')
            # Show Mercury
            plot_sphere(ax,radius=1,color='lightgrey',alpha=0.5,zorder=1,quarter=False,
                    xlims=[zoom_xmin,zoom_xmax],ylims=[zoom_ymin,zoom_ymax],zlims=[zoom_zmin,zoom_zmax])
            plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=1.25,quarter=False,
                    xlims=[zoom_xmin,zoom_xmax],ylims=[zoom_ymin,zoom_ymax],zlims=[zoom_zmin,zoom_zmax])
        else:
            ax.plot([np.min(X[:mid,:]),-1],[0,0],[0.2,0.2],color='black',lw=1)
            ax.scatter([-4,-3,-2,-1],[0,0,0,0],[0.2,0.2,0.2,0.2],s=4,color='black')
            # Show Mercury
            plot_sphere(ax,radius=1,color='lightgrey',alpha=0.5,zorder=1,quarter=False)
            plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=1.25,quarter=False)

        # Add labels
        ax.set_xlabel("X [$R_M$]",fontsize=12)
        ax.set_ylabel("Y [$R_M$]",fontsize=12)
        ax.set_zlabel("Z [$R_M$]",fontsize=12)
        ax.tick_params(axis='both',labelsize=12)
        ax.set_title(r"$\delta B_{z}$ at t="+time+"s",fontsize=12,y=1.0, pad=-5)

        ################# END PLOT #################
        
        # Compare to previous df_dict, if any, and relabel DFs for continuity

        if df_dict is not None and len(new_df_dict.keys())>0: # Only match if there is data from the last step and at least 1 DF now
            # new_df_dict numbers this step's DFs 1..N in detection order, which in general does not
            # match the ids of the previous step. next_df_dict collects every DF under its tracked id.
            next_df_dict = {}

            new_keys = list(new_df_dict.keys())
            overlap_masks = [] # Rows of [new_key, old_key, agreement_lvl = number of overlapping cells]
            for new_key in new_keys:
                for old_key in df_dict.keys():
                    # NOTE(review): X and Y membership are tested separately, so a cell counts as overlapping
                    # when its X appears anywhere in the old DF and its Y appears anywhere in the old DF.
                    # That is looser than exact (X, Y) cell overlap — confirm it is the intended criterion.
                    xmask = np.isin(new_df_dict[new_key][0],df_dict[old_key][0])
                    ymask = np.isin(new_df_dict[new_key][1],df_dict[old_key][1])
                    mask=xmask&ymask
                    if mask.any():
                        print("New DF#"+str(new_key)+" overlaps with old DF#"+str(old_key))
                        overlap_masks.append([new_key,old_key,sum(mask)]) # sum(mask) counts the True entries
            # DFs with no overlap are newly formed: register them against the old DF -1 (i.e. none).
            # Building the rows as a list first keeps EVERY new DF, including when nothing overlaps at all.
            matched_keys = {row[0] for row in overlap_masks}
            rows = overlap_masks + [[key,-1,0] for key in new_keys if key not in matched_keys]
            # Matrix relating the DFs labelled at this time to those at the previous time
            unfiltered_matrix = np.array(rows, ndmin=2)
            unfiltered_matrix = unfiltered_matrix[unfiltered_matrix[:,2].argsort()[::-1]] # Largest overlap first

            # A new DF overlapping several old DFs appears in several rows. remove_duplicate_rows (defined
            # elsewhere) is relied on to drop rows with a repeated new key (column 0); after the sort above
            # the largest-overlap row is presumably the one kept. Only needed when rows outnumber DFs.
            if len(unfiltered_matrix[:,0])>len(new_df_dict.keys()):
                matrix = remove_duplicate_rows(unfiltered_matrix)
            else:
                matrix = unfiltered_matrix

            print("NEW DF KEY   OLD DF KEY   MATCH")
            print(matrix)

            temp_key=-1 # Negative placeholder ids for new and split-off DFs; replaced with fresh ids below

            for new_key, old_key, _overlap in matrix:
                if old_key==-1: # Newly formed DF in this step
                    print("DF#"+str(new_key),"is a new one and is temporarily assigned #"+str(temp_key))
                    # Discard DFs whose mean distance from the origin is within 1.25
                    r = np.mean(np.sqrt(new_df_dict[new_key][0]**2+new_df_dict[new_key][1]**2))
                    if r>1.25:
                        next_df_dict[temp_key] = new_df_dict[new_key]
                        temp_key-=1
                    else:
                        print("This DF formed too close to the planet, throwing it out...")
                elif old_key not in next_df_dict: # First (largest-overlap) match: inherit the old DF's id
                    print("DF#"+str(old_key)+" has been tracked from the previous step")
                    next_df_dict[old_key] = new_df_dict[new_key]
                else: # Old DF already claimed by a DF with more overlap, so this one split off from it
                    print("DF#"+str(old_key)+" has split and formed a new DF, which is temporarily assigned #"+str(temp_key))
                    next_df_dict[temp_key] = new_df_dict[new_key]
                    temp_key-=1

            # Replace the negative placeholder ids with fresh ids above every id used so far
            new_df_key = max(df_data.keys(), default=0)+1
            for key in [k for k in next_df_dict if k<0]:
                print("Reassigning the temporary DF#"+str(key),"to DF#"+str(new_df_key))
                next_df_dict[new_df_key] = next_df_dict.pop(key)
                new_df_key+=1
            print("Feature tracking complete!")
            df_dict = next_df_dict
        else:
            df_dict = new_df_dict

        ####### PLOT 2 SETUP ######
        # DF colours, picked with key % 9 (nine entries). Every entry is distinct so that different
        # DF ids are not drawn in the same colour.
        color_ls = ["tab:blue","tab:green","tab:red","tab:orange","tab:purple","tab:brown","tab:pink","tab:olive","tab:cyan"]
        # Set up grid for field line tracing. The 3D arrays unpack as (ny, nx, nz); the field handed
        # to VectorGrid is built with shape (nx, ny, nz, 3), hence the swap of the first two axes.
        ny,nx,nz = Bx3d.shape
        field = np.zeros((nx,ny,nz,3))
        field[:,:,:,0] = np.transpose(Bx3d,axes=[1,0,2])
        field[:,:,:,1] = np.transpose(By3d,axes=[1,0,2])
        field[:,:,:,2] = np.transpose(Bz3d,axes=[1,0,2])
        grid_spacing = [1/64,1/64,1/64]
        grid = VectorGrid(field, grid_spacing, origin_coord = [X3d.min(),Y3d.min(),Z3d.min()])
        nsteps = 10000
        step_size = 0.001
        tracer = StreamTracer(nsteps, step_size)
        ##### END PLOT 2 SETUP #########

            
        # For each DF: append this step's DF-averaged quantities to its record, then draw its cells
        # and the field lines threading it
        zooming = do_zoom and (float(time)>zoom_time_start)
        for key in df_dict.keys():
            if key not in df_data.keys(): # Create new dataframe if this df has not been registered already
                df_data[key] = pd.DataFrame(columns=['time','X','Y','Z','Bx','By','Bz','Te','n','uix',"uiy","uiz",'uex',"uey","uez",
                                                     "E_convx","E_convy","E_convz","dp_dx","dp_dy","dp_dz",
                                                     "J_inrt,x","J_inrt,y","J_inrt,z","J_gradp,x","J_gradp,y","J_gradp,z",
                                                     "Jx","Jy","Jz","Bz_max",'area'])

            # Grid indices of this DF's cells. df_dict[key] is a tuple of the cells' X and Y coordinates,
            # taken directly from the X/Y grids, so exact equality identifies each cell.
            X_region, Y_region = df_dict[key]
            ix = np.array([np.flatnonzero(X[0,:]==x).item() for x in X_region])
            iy = np.array([np.flatnonzero(Y[:,1]==y).item() for y in Y_region])
            ncells = len(ix)

            Bx_c, By_c, Bz_c = Bx[iy,ix], By[iy,ix], Bz[iy,ix]
            B2 = Bx_c**2 + By_c**2 + Bz_c**2

            new_row = np.zeros(32)
            # Columns 1-20 and 27-29: plain averages over the DF's cells
            averaged = [X, Y, Z, Bx, By, Bz, Te, n, uix, uiy, uiz, uex, uey, uez,
                        E_convx, E_convy, E_convz, dp_dx, dp_dy, dp_dz]
            for col, quantity in enumerate(averaged, start=1):
                new_row[col] = np.mean(quantity[iy,ix])
            for col, quantity in zip((27, 28, 29), (Jx, Jy, Jz)):
                new_row[col] = np.mean(quantity[iy,ix])

            # Inertial current J = rho/B^2 (B x dU_i/dt), with dU_i/dt taken from this step's cell velocities
            # minus the DF-averaged velocity stored at the previous step (left at 0 on a DF's first step).
            # Unit factors assume B in nT and u in km/s, giving A/m^2.
            if len(df_data[key])>0:
                dux = (uix[iy,ix]-df_data[key]['uix'].iloc[-1])/dt
                duy = (uiy[iy,ix]-df_data[key]['uiy'].iloc[-1])/dt
                duz = (uiz[iy,ix]-df_data[key]['uiz'].iloc[-1])/dt
                rho_over_B2 = n[iy,ix]*m_p/(B2*1e-9)
                new_row[21] = np.mean(rho_over_B2*(By_c*duz - Bz_c*duy)*1e3)
                new_row[22] = np.mean(rho_over_B2*(Bz_c*dux - Bx_c*duz)*1e3)
                new_row[23] = np.mean(rho_over_B2*(Bx_c*duy - By_c*dux)*1e3)

            # Pressure-gradient current J = (B x grad p)/B^2
            # NOTE(review): no unit conversion is applied; confirm the units of dp_dx/dp_dy/dp_dz.
            new_row[24] = np.mean((By_c*dp_dz[iy,ix] - Bz_c*dp_dy[iy,ix])/B2)
            new_row[25] = np.mean((Bz_c*dp_dx[iy,ix] - Bx_c*dp_dz[iy,ix])/B2)
            new_row[26] = np.mean((Bx_c*dp_dy[iy,ix] - By_c*dp_dx[iy,ix])/B2)

            new_row[0] = float(time) # time stamp
            new_row[30] = np.max(Bz_c) # peak Bz within the DF
            new_row[31] = (1/64)**2*ncells # DF area: number of cells times the area of one (1/64 x 1/64) cell
            df_data[key].loc[len(df_data[key])] = new_row # Add this new row to the DF's DataFrame

            ######################## PLOT2 START ################################

            # Scatter this DF's cells on the current sheet
            Z_region = Z[find_indices(X_region, Y_region, X, Y)].tolist()
            ax.scatter(X_region, Y_region, Z_region, s = 0.15, color = color_ls[key%9],zorder=6)

            # Find seed points for field lines
            if key in df_seeds.keys():
                # Advect the existing seeds in x/y with this DF's mean electron velocity (km/s -> R_M)
                df_seeds[key][:,0] = df_seeds[key][:,0] + dt*df_data[key]['uex'].iloc[-1]*1e3/R_M
                df_seeds[key][:,1] = df_seeds[key][:,1] + dt*df_data[key]['uey'].iloc[-1]*1e3/R_M
                # Seeds can drift out of the DF: move any seed outside the DF's bounding box to a random DF cell
                x_lo, x_hi = np.min(X_region), np.max(X_region)
                y_lo, y_hi = np.min(Y_region), np.max(Y_region)
                for iseed in range(len(df_seeds[key])):
                    sx, sy = df_seeds[key][iseed,0], df_seeds[key][iseed,1]
                    if (sx > x_hi) or (sx < x_lo) or (sy > y_hi) or (sy < y_lo):
                        new_loc = random.randint(0,len(X_region)-1)
                        print("Field line seed left the DF! Moved seed at",df_seeds[key][iseed,:],"to",X_region[new_loc],Y_region[new_loc],0.2)
                        df_seeds[key][iseed,:] = [X_region[new_loc], Y_region[new_loc], 0.2]
            else:
                # First time this DF is seen: seed a field line at every trace_skip-th DF cell
                trace_skip=10
                df_seeds[key] = np.column_stack([X_region[::trace_skip],
                                                 Y_region[::trace_skip],
                                                 Z_region[::trace_skip]]).astype(float)

            # Trace the field lines
            tracer.trace(df_seeds[key], grid)

            # Plot each field line split into the points at/above and below its seed height (drawn at
            # different z-orders) and, while zooming, clipped to the zoom box
            for iseed in range(len(df_seeds[key])):
                path = tracer.xs[iseed]
                seed_z = df_seeds[key][iseed][2]
                if zooming:
                    in_view = ((path[:,0]<zoom_xmax) & (path[:,0]>zoom_xmin) &
                               (path[:,1]<zoom_ymax) & (path[:,1]>zoom_ymin) &
                               (path[:,2]<zoom_zmax) & (path[:,2]>zoom_zmin))
                else:
                    in_view = np.ones(len(path), dtype=bool)
                above = np.flatnonzero((path[:,2]>=seed_z) & in_view)
                below = np.flatnonzero((path[:,2]<seed_z) & in_view)
                for idx, layer in ((above, 3.6), (below, 0.5)):
                    # Draw each run of consecutive points separately so the line is not joined across gaps
                    for run in np.split(idx, np.flatnonzero(np.diff(idx)>1)+1):
                        if len(run)>1:
                            ax.plot(path[run,0],path[run,1],path[run,2],
                                   color=color_ls[key%9],lw=0.3,alpha=1,zorder=layer)

        # Fix the plotted volume: the animated zoom window once zooming has started, otherwise the
        # full current sheet in x/y and [z_lower, z_upper] in z
        if do_zoom and (float(time)>zoom_time_start):
            limits = ((zoom_xmin, zoom_xmax), (zoom_ymin, zoom_ymax), (zoom_zmin, zoom_zmax))
        else:
            limits = ((X.min(), X.max()), (Y.min(), Y.max()), (z_lower, z_upper))
        (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi) = limits
        ax.set_xlim(x_lo, x_hi)
        ax.set_ylim(y_lo, y_hi)
        ax.set_zlim(z_lo, z_hi)
        # Box aspect proportional to the data extent along each axis
        ax.set_box_aspect([x_hi - x_lo, y_hi - y_lo, z_hi - z_lo])

        ######################## PLOT2 END ################################

        # Save, naming the file by the time rounded to two decimals
        stamp = "%.2f"%round(float(time),2)
        out_file = str(dir[:-1]) + "_plots/" + plot_preset + "_" + stamp + '.png'
        fig.savefig(out_file,bbox_inches='tight',dpi = 300)
        plt.close(fig)

    # PLOT PRESET '3D_current_sheet'
    # Three XZ cuts through the 3D data (Y = -0.5, 0, 0.5 R_M): current density in the
    # left column and plasma beta in the right column, each overlaid with in-plane
    # magnetic field lines and the fitted current-sheet position.
    if plot_preset=='3D_current_sheet':
        fig,axs = plt.subplots(nrows = 3, ncols = 2, figsize=(18,10))

        # Define y values for xz planes
        y0=-0.5
        y1=0
        y2=0.5

        # Unpack cs variables
        Xcs = datacs["X"]
        Ycs = datacs["Y"]
        Zcs = datacs["Z"]

        # Unpack 3d variables (indexed [y, x, z], matching the slicing below)
        X3d = data3d["X"]
        Y3d = data3d["Y"]
        Z3d = data3d["Z"]
        Bx3d = data3d["Bx"]
        By3d = data3d["By"]
        Bz3d = data3d["Bz"]
        J3d = np.sqrt(data3d["Jx"]**2+data3d["Jy"]**2+data3d["Jz"]**2)
        # Plasma beta 2*mu_0*p/B^2 with p = (trace of both species' pressure tensors)/3;
        # the 1e9 factor converts nPa/nT^2 to SI.
        beta3d = (2*mu_0*(data3d["pxxS0"]+data3d["pyyS0"]+data3d["pzzS0"]+data3d["pxxS1"]+data3d["pyyS1"]+data3d["pzzS1"])*1e9/3/(Bx3d**2+By3d**2+Bz3d**2))

        # Index of the last Y grid plane strictly below each requested y
        y_ls = [y0,y1,y2]
        yi_ls = [np.where(Y3d[:,0,0]<y)[0][-1] for y in y_ls]
        y0i, y1i, y2i = yi_ls

        def _plot_xz_panel(panel_ax, yi, y, values, levels, cmap, line_color, title):
            """Draw one XZ panel at Y-plane index yi (physical Y = y).

            Filled log-scaled contours of `values`, in-plane B field lines, the
            current-sheet fit and the planet cross-section. Returns the contour set
            so the caller can attach a colorbar.
            """
            cset = panel_ax.contourf(X3d[yi,:,:],Z3d[yi,:,:],values,norm=LogNorm(),levels=levels,cmap=cmap,extend='both')
            # streamplot needs an evenly spaced grid; rebuild one from the plane's extent
            xx,zz = np.meshgrid(np.linspace(X3d[yi,0,0],X3d[yi,-1,0],len(X3d[yi,:,0])),
                                np.linspace(Z3d[yi,0,0],Z3d[yi,0,-1],len(Z3d[yi,0,:])))
            panel_ax.streamplot(xx,zz,Bx3d[yi,:,:].T,Bz3d[yi,:,:].T,color=line_color,linewidth=0.5,broken_streamlines=False,arrowsize=0.5)
            # Current-sheet fit: draw up to the first cell where rhoS1 == 0, or the whole
            # row when there is no such cell (instead of raising IndexError)
            zero_cells = np.where(datacs['rhoS1'][yi,:]==0.0)[0]
            xmax = zero_cells[0] if zero_cells.size else None
            panel_ax.plot(Xcs[yi,:xmax],Zcs[yi,:xmax],color='green',lw=2)
            # Planet (r = 1) and inner r = 0.8 disk, cut at this Y
            panel_ax.add_patch(plt.Circle((0, 0), np.sqrt(1-y**2), color='grey'))
            panel_ax.add_patch(plt.Circle((0, 0), np.sqrt(0.8**2-y**2), color='black'))
            panel_ax.set_aspect(1)
            panel_ax.set_xlim(-4,-0.5)
            panel_ax.set_ylim(-0.7,0.8)
            panel_ax.tick_params(axis='both',labelsize=15)
            panel_ax.set_title(title,fontsize=15)
            panel_ax.set_ylabel("Z [$R_M$]",fontsize=15)
            return cset

        # Left column: current density magnitude, scaled by 1e9 to nA/m^2
        levels = np.logspace(0,3,21)
        for panel_ax, yi, y in zip(axs[:,0], yi_ls, y_ls):
            jplot = _plot_xz_panel(panel_ax, yi, y, J3d[yi,:,:]*1e9, levels, 'plasma', 'white',
                                   str("Current density at Y = "+str(y)))
        axs[-1,0].set_xlabel("X [$R_M$]",fontsize=15)
        # The colorbar inherits the log norm from the contour set
        clb1 = fig.colorbar(jplot, ax=axs[:,0])
        clb1.ax.tick_params(labelsize=15)
        clb1.locator = LogLocator()
        clb1.formatter = LogFormatterSciNotation()
        clb1.ax.set_title('J [nA/m$^2$]',fontsize=15)

        # Right column: plasma beta
        levels = np.logspace(-3,3,21)
        for panel_ax, yi, y in zip(axs[:,1], yi_ls, y_ls):
            betaplot = _plot_xz_panel(panel_ax, yi, y, beta3d[yi,:,:], levels, 'bwr', 'black',
                                      str("Plasma beta at Y = "+str(y)))
        axs[-1,1].set_xlabel("X [$R_M$]",fontsize=15)
        clb2 = fig.colorbar(betaplot, ax=axs[:,1])
        clb2.ax.tick_params(labelsize=15)
        clb2.locator = LogLocator()
        clb2.formatter = LogFormatterSciNotation()
        clb2.ax.set_title('beta',fontsize=15)

        # Save
        fig.savefig(str(dir[:-1])+"_plots/"+plot_preset+"_"+time+'.png',bbox_inches='tight',pad_inches=0.3, dpi=300)
        plt.close(fig)

    if plot_preset=='3D_flux_tube_content':
        # Map of a field-line integral over the current sheet. A field line is traced
        # through the 3D B field from every trace_skip-th current-sheet cell, and
        # (p_e + p_i)^gamma / |B| is summed along it (labelled H on the colorbar).

        # Set up plot environment
        fig = plt.figure(figsize=(18,6))
        ax = fig.add_subplot(111)

        xlims = [-4,0]
        ylims = [-1.2,1.2]

        # Unpack data (3D arrays are indexed [y, x, z])
        Xcs = datacs["X"]
        Ycs = datacs["Y"]
        Zcs = datacs["Z"]
        X3d = data3d["X"]
        Y3d = data3d["Y"]
        Z3d = data3d["Z"]
        Bx3d = data3d["Bx"]
        By3d = data3d["By"]
        Bz3d = data3d["Bz"]
        ncs = datacs["rhoS1"]
        n3d = data3d["rhoS1"] * 1e6 # convert to SI
        pe3d = ((data3d["pxxS0"]+data3d["pyyS0"]+data3d["pzzS0"])/3*1e-9) # convert to SI
        pi3d = ((data3d["pxxS1"]+data3d["pyyS1"]+data3d["pzzS1"])/3*1e-9) # convert to SI

        # Ratio of specific heats
        gamma = 5/3

        # Set up grid for field line tracing: the tracer grid is [x, y, z, component]
        ny,nx,nz = Bx3d.shape
        field = np.zeros((nx,ny,nz,3))
        field[:,:,:,0] = np.transpose(Bx3d,axes=[1,0,2])
        field[:,:,:,1] = np.transpose(By3d,axes=[1,0,2])
        field[:,:,:,2] = np.transpose(Bz3d,axes=[1,0,2])
        grid_spacing = [1/64,1/64,1/64]
        grid = VectorGrid(field, grid_spacing, origin_coord = [X3d.min(),Y3d.min(),Z3d.min()])
        nsteps = 10000
        step_size = 0.001
        tracer = StreamTracer(nsteps, step_size)
        trace_skip = 2

        # Entropy integrand (SI). The interpolator works on [x, y, z] and extrapolates
        # outside the grid (bounds_error=False, fill_value=None).
        entropy = (pe3d+pi3d)**(gamma)/np.sqrt(Bx3d**2+By3d**2+Bz3d**2)*1e9 # SI
        interpolator = RegularGridInterpolator((X3d[0, :, 0], Y3d[:, 0, 0], Z3d[0, 0, :]), np.swapaxes(entropy,0,1), bounds_error=False, fill_value=None)

        # Seed positions: every trace_skip-th current-sheet cell in each direction
        X_seed = Xcs[::trace_skip,::trace_skip]
        Y_seed = Ycs[::trace_skip,::trace_skip]
        Z_seed = Zcs[::trace_skip,::trace_skip]

        # Trace one field line per seed (one at a time) and sum integrand * step length
        # along it; NaN samples contribute zero.
        # NOTE(review): the step length is taken as step_size*R_M metres, i.e. step_size
        # is assumed to be in R_M; confirm against StreamTracer's step_size definition.
        entropy_content = np.zeros_like(X_seed)
        for iy in range(X_seed.shape[0]):
            for ix in range(X_seed.shape[1]):
                seed = np.array((X_seed[iy,ix],Y_seed[iy,ix],Z_seed[iy,ix]))
                tracer.trace(seed, grid)
                entropy_content[iy,ix] = np.sum(np.nan_to_num(interpolator(tracer.xs[0])*step_size*R_M))

        levels = np.logspace(-5, 2, 31)
        plot = ax.contourf(X_seed,Y_seed,entropy_content,cmap="viridis",
                           norm=LogNorm(),levels=levels,extend='both')

        # Planet (r = 1, grey) and inner r = 0.8 disk (black), cut at the mean
        # current-sheet height (NaN cells ignored). This branch's own Zcs is used;
        # `Z` is not defined here.
        z_sheet = np.nanmean(Zcs)
        inner = plt.Circle((0, 0), np.sqrt(0.8**2-z_sheet**2), color='black')
        outer = plt.Circle((0, 0), np.sqrt(1-z_sheet**2), color='grey')

        x_major_ticks = np.arange(xlims[0], xlims[1], 0.25)
        x_minor_ticks = np.arange(xlims[0], xlims[1], 0.05)
        y_major_ticks = np.arange(ylims[0], ylims[1], 0.25)
        y_minor_ticks = np.arange(ylims[0], ylims[1], 0.05)

        ax.set_xticks(x_major_ticks)
        ax.set_xticks(x_minor_ticks, minor=True)
        ax.set_yticks(y_major_ticks)
        ax.set_yticks(y_minor_ticks, minor=True)

        ax.grid(which='both')
        ax.grid(which='minor', alpha=0.2)
        ax.grid(which='major', alpha=0.5)

        # The colorbar inherits the log norm from the contour set
        clb1 = fig.colorbar(plot, ax=ax)
        clb1.ax.tick_params(labelsize=15)
        clb1.locator = LogLocator()
        clb1.formatter = LogFormatterSciNotation()
        clb1.ax.set_title('$H$ ',fontsize=15)

        # Add labels
        ax.set_xlabel("X [$R_M$]",fontsize=12)
        ax.set_ylabel("Y [$R_M$]",fontsize=12)
        ax.add_patch(outer)
        ax.add_patch(inner)
        ax.set_xlim(xlims)
        ax.set_ylim(ylims)
        ax.tick_params(axis='both',labelsize=12)
        ax.set_title(str("Flux tube entropy content at t="+time+"s"),fontsize=12)
        ax.set_aspect(1)

        # Save
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+time+'.png'),bbox_inches='tight',pad_inches=0.3, dpi=300)
        plt.close(fig)

    if plot_preset=='3D_gridscale_df_tracker':

        # First call of a tracking run: start with empty registries.
        #   df_data  - one DataFrame per DF key, one row of cell-averaged values per step
        #   df_dict  - (X cells, Y cells) of every DF found at the previous step
        #   df_seeds - field-line seed positions for each DF
        # NOTE(review): for iter > 0 these names must already be bound by the caller
        # (e.g. as globals); that is not visible here.
        if iter == 0:
            df_data, df_dict, df_seeds = {}, None, {}

        # Define X cutoff, if required (number of leading X columns to drop)
        x_cutoff=0

        # Unpack current-sheet (2D) and 3D data, dropping the first x_cutoff X columns
        X = datacs["X"][:,x_cutoff:]
        Y = datacs["Y"][:,x_cutoff:]
        Z = datacs["Z"][:,x_cutoff:]
        X3d = data3d["X"][:,x_cutoff:,:]
        Y3d = data3d["Y"][:,x_cutoff:,:]
        Z3d = data3d["Z"][:,x_cutoff:,:]
        Bx3d = data3d["Bx"][:,x_cutoff:,:]
        By3d = data3d["By"][:,x_cutoff:,:]
        Bz3d = data3d["Bz"][:,x_cutoff:,:]
        Jx = datacs["Jx"][:,x_cutoff:]
        Jy = datacs["Jy"][:,x_cutoff:]
        Jz = datacs["Jz"][:,x_cutoff:]
        Bx = datacs["Bx"][:,x_cutoff:]
        By = datacs["By"][:,x_cutoff:]
        Bz = datacs["Bz"][:,x_cutoff:]
        n = datacs["rhoS1"][:,x_cutoff:] * 1e6 # convert to SI
        pe = ((datacs["pxxS0"]+datacs["pyyS0"]+datacs["pzzS0"])/3*1e-9)[:,x_cutoff:] # convert to SI
        Te = pe/n/k_b / 11605 / 1e3 #Convert to keV
        # Ion (S1) and electron (S0) bulk velocities; the z components must come from
        # the uz* fields (previously copy-pasted from uy*)
        uix = datacs["uxS1"][:,x_cutoff:]
        uiy = datacs["uyS1"][:,x_cutoff:]
        uiz = datacs["uzS1"][:,x_cutoff:]
        uex = datacs["uxS0"][:,x_cutoff:]
        uey = datacs["uyS0"][:,x_cutoff:]
        uez = datacs["uzS0"][:,x_cutoff:]
        beta = (2*mu_0*(datacs["pxxS0"]+datacs["pyyS0"]+datacs["pzzS0"]+datacs["pxxS1"]+datacs["pyyS1"]+datacs["pzzS1"])*1e9/3/(datacs["Bx"]**2+datacs["By"]**2+datacs["Bz"]**2))[:,x_cutoff:]
        # Convective electric field E = -u_e x B (km/s * nT -> V/m)
        E_convx = (-(uey*Bz-uez*By)*1000*1e-9) # Convert to V/m ie SI
        E_convy = (-(uez*Bx-uex*Bz)*1000*1e-9)
        E_convz = (-(uex*By-uey*Bx)*1000*1e-9)
        dp_dx = datacs['dp_dx'][:,x_cutoff:]
        dp_dy = datacs['dp_dy'][:,x_cutoff:]
        dp_dz = datacs['dp_dz'][:,x_cutoff:]

        # Background Bz from average_value: the (-1, 0) window relative to the current
        # time for the 2D current sheet, and the (-5, -2) window (type='numpy') for 3D.
        # NOTE(review): window units/semantics are defined by average_value; confirm.
        t_now = float(time)
        Bz_avg = average_value(["Bz"], t_now, -1, 0)["Bz"][:, x_cutoff:]
        Bz3d_background = average_value(["Bz"], t_now, -5, -2, type='numpy')["Bz"][:, x_cutoff:, :]
        delta_Bz3d = Bz3d - Bz3d_background

        # DF metric: enhancement of Bz over its background
        metric = Bz - Bz_avg

        # Detection thresholds: minimum metric (nT) and minimum region size (cells);
        # dx is the grid spacing
        min_value, min_size, dx = 10, 10, 1/64

        # Z extent of the 3D view
        z_lower, z_upper = -1, 1

        # Compute zoom rates, if activated: the rate at which each axis limit moves so
        # the view shrinks linearly from the full extent to the zoom_*_range box
        # between zoom_time_start and zoom_time_end (also the camera azimuth rate).
        if do_zoom:
            zoom_duration = zoom_time_end-zoom_time_start
            zoom_dxdt_min = (zoom_x_range[0]-np.min(X))/zoom_duration
            zoom_dxdt_max = (np.max(X)-zoom_x_range[1])/zoom_duration
            zoom_dydt_min = (zoom_y_range[0]-np.min(Y))/zoom_duration
            zoom_dydt_max = (np.max(Y)-zoom_y_range[1])/zoom_duration
            zoom_dzdt_min = (zoom_z_range[0]-z_lower)/zoom_duration
            # Upper Z limit moves towards the upper bound of the Z zoom range
            zoom_dzdt_max = (z_upper-zoom_z_range[1])/zoom_duration
            zoom_dazimdt = (azim_end - azim_start)/zoom_duration
        
        # Find DF regions: cells where the Bz-enhancement metric exceeds min_value
        mask = metric > min_value

        # Label 4-connected regions using a cross-shaped structuring element
        structure = np.array([[False, True, False],
                              [True,  True, True],
                              [False, True, False]])
        labeled, num_features = label(mask, structure=structure)

        # Keep regions with more than min_size cells and mean plasma beta above 1,
        # numbering the survivors 1, 2, ... in label order.
        new_df_dict = {}
        for feature_num in range(1, num_features + 1):
            region = labeled == feature_num
            DF_beta = np.mean(beta[region])
            if np.count_nonzero(region) > min_size and DF_beta > 1:
                new_df_dict[len(new_df_dict) + 1] = (X[region], Y[region])

        print("Found", len(new_df_dict), "DFs at this time")

        ################# Plot #################
        fig = plt.figure(figsize=(18,6))
        ax = fig.add_subplot(111, projection="3d",computed_zorder=False)

        # Copy before masking: Z is a slice (view) of datacs["Z"], and the NaN
        # assignments below should not write back into the caller's data.
        Z = Z.copy()

        # Mask out values outside the (time-dependent) zoom box
        if do_zoom and (float(time)>zoom_time_start):
            # Time since zooming began, clamped at zoom_time_end so the view then
            # stays at the final zoom box
            zoom_time = min(float(time), zoom_time_end)-zoom_time_start
            zoom_xmin = np.min(X)+zoom_time*zoom_dxdt_min
            zoom_xmax = np.max(X)-zoom_time*zoom_dxdt_max
            zoom_ymin = np.min(Y)+zoom_time*zoom_dydt_min
            zoom_ymax = np.max(Y)-zoom_time*zoom_dydt_max
            zoom_zmin = z_lower+zoom_time*zoom_dzdt_min
            zoom_zmax = z_upper-zoom_time*zoom_dzdt_max
            zoom_mask = (X > zoom_xmax) | (X < zoom_xmin) | (Y > zoom_ymax) | (Y < zoom_ymin)
            Z[zoom_mask] = np.nan
        # Hide current-sheet cells within r = 1.01 of the origin
        radius = 1.01
        mask = (X**2 + Y**2) < radius**2
        Z[mask] = np.nan

        mid = 90 # Row index corresponding to midnight
        # Define colormap and lighting
        vmin=-25
        vmax=25
        norm = plt.Normalize(vmin,vmax)
        dawn_colors = cm.bwr(norm(metric[:mid,:]),alpha=0.9)
        dusk_colors = cm.bwr(norm(metric[mid:,:]),alpha=0.9)

        # Shade the colours by the current-sheet height (LightSource default
        # azimuth/altitude)
        light = LightSource()
        dawn_illuminated_colors = light.shade_rgb(dawn_colors, Z[:mid,:], blend_mode='soft')  # Apply light source shading
        dusk_illuminated_colors = light.shade_rgb(dusk_colors, Z[mid:,:], blend_mode='soft')  # Apply light source shading

        # move camera view (azimuth rotates linearly while zooming)
        if do_zoom and (float(time)>zoom_time_start):
            ax.view_init(elev=25, azim=azim_start + zoom_dazimdt*(zoom_time))
        else:
            ax.view_init(elev=25, azim=azim_start)

        # Create the surface plot: with computed_zorder=False the dawn half (zorder 2)
        # is drawn above the spheres (zorder 1, 1.25) and the dusk half (0.75) below
        surf1 = ax.plot_surface(X[:mid,:], Y[:mid,:], Z[:mid,:], facecolors=dawn_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=2)
        surf2 = ax.plot_surface(X[mid:,:], Y[mid:,:], Z[mid:,:], facecolors=dusk_illuminated_colors, rstride=1, cstride=1, antialiased=False, zorder=0.75)

        # Add a color bar
        m = cm.ScalarMappable(cmap=cm.bwr, norm=norm)
        m.set_array(metric)
        clb = fig.colorbar(m, ax=ax, shrink=0.3, aspect=7,anchor=(0.0,0.3))
        clb.ax.tick_params(labelsize=12)
        clb.ax.set_title(r'$\delta B_{z}$ [nT]',fontsize=12,pad=10)

        # Add big x axis
        if do_zoom and (float(time)>zoom_time_start):
            # Show Mercury, clipped to the zoom box
            plot_sphere(ax,radius=1,color='lightgrey',alpha=0.5,zorder=1,quarter=False,
                    xlims=[zoom_xmin,zoom_xmax],ylims=[zoom_ymin,zoom_ymax],zlims=[zoom_zmin,zoom_zmax])
            plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=1.25,quarter=False,
                    xlims=[zoom_xmin,zoom_xmax],ylims=[zoom_ymin,zoom_ymax],zlims=[zoom_zmin,zoom_zmax])
        else:
            ax.plot([np.min(X[:mid,:]),-1],[0,0],[0.2,0.2],color='black',lw=1)
            ax.scatter([-4,-3,-2,-1],[0,0,0,0],[0.2,0.2,0.2,0.2],s=4,color='black')
            # Show Mercury
            plot_sphere(ax,radius=1,color='lightgrey',alpha=0.5,zorder=1,quarter=False)
            plot_sphere(ax,radius=0.8,color='grey',alpha=1,zorder=1.25,quarter=False)

        # Add labels
        ax.set_xlabel("X [$R_M$]",fontsize=12)
        ax.set_ylabel("Y [$R_M$]",fontsize=12)
        ax.set_zlabel("Z [$R_M$]",fontsize=12)
        ax.tick_params(axis='both',labelsize=12)
        ax.set_title(str(r"$\delta B_{z}$ at t="+time+"s"),fontsize=12,y=1.0, pad=-5)

        ################# END PLOT #################
        
        # Compare to previous df_dict, if any, and relabel DFs for continuity.
        # DFs found now are matched to last step's DFs by shared cells: each inherits
        # the number of the old DF it overlaps most, and otherwise gets a fresh number.
        if df_dict is not None and len(new_df_dict.keys())>0: # Only proceed with attempting to match DFs if we have data from last timestep and there is at least 1 DF in this timestep
            # Relabel all keys into a fresh dictionary
            next_df_dict = {}

            # Compare every DF found in this step with every DF from the previous step
            new_keys = list(new_df_dict.keys())
            overlap_masks = [] # Rows of [new_key, old_key, number of overlapping cells]
            for new_key in new_keys:
                for old_key in df_dict.keys():
                    # NOTE(review): X and Y membership are tested separately, so a cell
                    # counts as shared if its X occurs among the old DF's X values and its
                    # Y among the old DF's Y values, not necessarily as one (X, Y) pair.
                    xmask = np.isin(new_df_dict[new_key][0],df_dict[old_key][0])
                    ymask = np.isin(new_df_dict[new_key][1],df_dict[old_key][1])
                    mask=xmask&ymask
                    if mask.any():
                        print("New DF#"+str(new_key)+" overlaps with old DF#"+str(old_key))
                        overlap_masks.append([new_key,old_key,sum(mask)]) # sum(mask) gives the number of "True" in the list
            # New DFs with no overlap are newly formed: register them against old key -1
            # (associated with none) with zero overlap. Collecting all rows before building
            # the matrix keeps every such DF, including when nothing overlapped at all.
            matched_keys = {row[0] for row in overlap_masks}
            matrix_rows = overlap_masks + [[key,-1,0] for key in new_keys if key not in matched_keys]
            unfiltered_matrix = np.array(matrix_rows, ndmin=2)
            unfiltered_matrix = unfiltered_matrix[unfiltered_matrix[:,2].argsort()[::-1]] # Sort to start with largest overlap ones

            # A new DF can overlap several old ones; drop repeated new_key rows (column 0)
            # so each new DF is matched once. Rows are sorted by overlap, so the best
            # matches come first. Filtering is only needed when there are more rows than DFs.
            if len(unfiltered_matrix[:,0])>len(new_df_dict.keys()):
                matrix = remove_duplicate_rows(unfiltered_matrix)
            else:
                matrix = unfiltered_matrix

            print("NEW DF KEY   OLD DF KEY   MATCH")
            print(matrix)

            temp_key=-1

            for i in range(len(matrix[:,0])):
                if matrix[i,1]==-1: # this means its a newly formed DF in this step.
                    print("DF#"+str(matrix[i,0]),"is a new one and is temporarily assigned #"+str(temp_key))
                    # Discard new DFs whose mean distance from the origin is within 1.25
                    r = np.mean(np.sqrt(new_df_dict[matrix[i,0]][0]**2+new_df_dict[matrix[i,0]][1]**2))
                    if r>1.25:
                        next_df_dict[temp_key] = new_df_dict[matrix[i,0]] # Give it a temporary name, we will come back to it at the end
                        temp_key-=1
                    else:
                        print("This DF formed too close to the planet, throwing it out...")
                elif (matrix[i,1] not in next_df_dict.keys()): # Check to see if this DF has already been named for the updated dict. If its not there, add it
                    print("DF#"+str(matrix[i,1])+" has been tracked from the previous step")
                    next_df_dict[matrix[i,1]] = new_df_dict[matrix[i,0]] # The name of the DF is taken from df_dict, and is populated with data from the new dict. The matrix is used as a reference to connect the two.
                else: # This means this DF has already been identified with a previous DF that has more overlap with it i.e. it is a child
                    print("DF#"+str(matrix[i,1])+" has split and formed a new DF, which is temporarily assigned #"+str(temp_key))
                    next_df_dict[temp_key] = new_df_dict[matrix[i,0]] # Give it a temporary name, we will come back to it at the end
                    temp_key-=1
            # All DFs identified in this step have been assigned names in next_df_dict. Now, we need to rename the negative ones to the next largest names
            if len(df_data.keys())==0:
                new_df_key = 1
            else:
                new_df_key = np.max(list(df_data.keys()))+1 # Start naming at one larger than the maximum df number already used
            for key in list(next_df_dict.keys()):
                if key<0:
                    print("Reassigning the temporary DF#"+str(key),"to DF#"+str(new_df_key))
                    next_df_dict[new_df_key] = next_df_dict.pop(key)
                    new_df_key+=1
            print("Feature tracking complete!")
            df_dict = next_df_dict
        else:
            df_dict = new_df_dict

        ####### PLOT 2 SETUP ######
        # Colours cycled by DF key (key % 9).
        # NOTE(review): "tab:blue" appears twice (positions 0 and 2); confirm intended.
        color_ls = ["tab:blue","tab:green","tab:blue","tab:orange","tab:purple","tab:brown","tab:pink","tab:olive","tab:cyan"]
        # Field-line tracing grid. The 3D arrays are [y, x, z]; the tracer gets a
        # float64 array laid out [x, y, z, component].
        ny,nx,nz = Bx3d.shape
        field = np.stack([np.transpose(component, axes=[1,0,2]) for component in (Bx3d, By3d, Bz3d)],
                         axis=-1).astype(np.float64)
        grid_spacing = [1/64]*3
        grid = VectorGrid(field, grid_spacing, origin_coord = [X3d.min(),Y3d.min(),Z3d.min()])
        nsteps, step_size = 10000, 0.001
        tracer = StreamTracer(nsteps, step_size)
        ##### END PLOT 2 SETUP #########

            
        # For each DF, either create a new item to store info about it or add to an existing item
        for key in df_dict.keys():
            if key not in df_data.keys(): # Create new dataframe if this df has not been registered already
                df_data[key] = pd.DataFrame(columns=['time','X','Y','Z','Bx','By','Bz','Te','n','uix',"uiy","uiz",'uex',"uey","uez",
                                                     "E_convx","E_convy","E_convz","dp_dx","dp_dy","dp_dz",
                                                     "J_inrt,x","J_inrt,y","J_inrt,z","J_gradp,x","J_gradp,y","J_gradp,z",
                                                     "Jx","Jy","Jz","Bz_max",'area'])
            # new_row follows the column order above. Columns 1-29 are summed over the
            # DF's cells and then divided by the cell count to give DF averages.
            new_row = np.zeros(32)
            Bz_max_ls = []
            n_cells = len(df_dict[key][0])
            # The inertial current needs the previous step's DF-averaged ion velocity
            has_previous = len(df_data[key])>0
            if has_previous:
                uix_prev = df_data[key]['uix'].iloc[-1]
                uiy_prev = df_data[key]['uiy'].iloc[-1]
                uiz_prev = df_data[key]['uiz'].iloc[-1]
            for i in range(n_cells):
                coord = [df_dict[key][0][i],df_dict[key][1][i]] # Each item in df_dict is a tuple of the X coords and Y coords
                # Find which indices of the current-sheet grid these coordinates correspond to
                ix = np.where(X[0,:]==coord[0])[0]
                iy = np.where(Y[:,1]==coord[1])[0]
                # Columns 1-20: plain per-cell quantities, in DataFrame column order
                for col, cs_var in enumerate((X, Y, Z, Bx, By, Bz, Te, n, uix, uiy, uiz, uex, uey, uez,
                                              E_convx, E_convy, E_convz, dp_dx, dp_dy, dp_dz), start=1):
                    new_row[col] += cs_var[iy,ix].item()
                bx = Bx[iy,ix].item()
                by = By[iy,ix].item()
                bz = Bz[iy,ix].item()
                b2 = bx**2+by**2+bz**2
                if has_previous:
                    # Inertial current (rho/B^2) B x du_i/dt [A/m^2], using the change from
                    # the previous step's DF-averaged ion velocity.
                    # 1e3: km/s -> m/s; 1e-9: the net 1/B factor from nT -> T.
                    # Columns 21-23 stay 0 on the DF's first appearance.
                    rho = n[iy,ix].item()*m_p
                    duix_dt = (uix[iy,ix].item()-uix_prev)/dt
                    duiy_dt = (uiy[iy,ix].item()-uiy_prev)/dt
                    duiz_dt = (uiz[iy,ix].item()-uiz_prev)/dt
                    new_row[21] += rho/(b2*1e-9) * (by*duiz_dt - bz*duiy_dt)*1e3 #A/m^2
                    new_row[22] += rho/(b2*1e-9) * (bz*duix_dt - bx*duiz_dt)*1e3 #A/m^2
                    new_row[23] += rho/(b2*1e-9) * (bx*duiy_dt - by*duix_dt)*1e3 #A/m^2
                # Pressure-gradient current (B x grad p)/B^2 [A/m^2 for nPa/m and nT];
                # each component is a cross-product difference, not a sum.
                dpx = dp_dx[iy,ix].item()
                dpy = dp_dy[iy,ix].item()
                dpz = dp_dz[iy,ix].item()
                new_row[24] += (by*dpz - bz*dpy)/b2
                new_row[25] += (bz*dpx - bx*dpz)/b2
                new_row[26] += (bx*dpy - by*dpx)/b2
                new_row[27] += Jx[iy,ix].item()
                new_row[28] += Jy[iy,ix].item()
                new_row[29] += Jz[iy,ix].item()

                Bz_max_ls.append(Bz[iy,ix]) # Save all the Bz values to find the max in the DF

            # Divide by the total number of cells for this DF to get the average quantity
            new_row = new_row/n_cells
            new_row[0] = float(time) # First column: time
            new_row[30] = np.max(Bz_max_ls) # 'Bz_max' column: max Bz over the DF
            new_row[31] = (1/64)**2*n_cells # 'area' column: cell count x cell area

            df_data[key].loc[len(df_data[key])] = new_row # Append this step's row

            # Show a trace of each current DF's path
            #ax.plot(temp["X"],temp["Y"],temp["Z"], color = color_ls[key%10])

            ######################## PLOT2 START ################################

            # Plot the outline of each DF in the current sheet, its mean electron velocity,
            # and the J x B and -grad(p) force densities acting on its cells.

            # Pull out the X,Y coords of each cell of this DF. Look up the matching grid
            # indices once and reuse them for every field below; the original called
            # find_indices twelve times with identical arguments.
            X_df, Y_df = df_dict[key]
            df_idx = find_indices(X_df, Y_df, X, Y)
            Z_df = Z[df_idx]
            Jx_df = Jx[df_idx]  # Comes in A/m^2
            Jy_df = Jy[df_idx]
            Jz_df = Jz[df_idx]
            Bx_df = Bx[df_idx]*1e-9  # Comes in nT; scaled to T
            By_df = By[df_idx]*1e-9
            Bz_df = Bz[df_idx]*1e-9
            dp_dx_df = dp_dx[df_idx]*1e-9  # Comes in nPa/m; scaled to Pa/m
            dp_dy_df = dp_dy[df_idx]*1e-9
            dp_dz_df = dp_dz[df_idx]*1e-9
            uex_df = uex[df_idx]  # Comes in km/s
            uey_df = uey[df_idx]

            # Ratio of the full domain width in X to the zoom window width; used to scale markers and arrows
            zoom_scale = (np.max(X)-np.min(X))/(zoom_xmax-zoom_xmin)

            ax.scatter(X_df, Y_df, Z_df, s=0.15*zoom_scale, color=color_ls[key%9], zorder=5)
            # Mean in-plane bulk electron velocity of this DF, drawn from the DF centroid
            ax.quiver(np.mean(X_df), np.mean(Y_df), np.mean(Z_df), np.mean(uex_df), np.mean(uey_df), 0,
                      color='red', length=1e-4)

            # Draw every qskip-th arrow. Clamp to 1: for zoom_scale > 10, int(10/zoom_scale)
            # is 0 and a zero slice step raises ValueError.
            qskip = max(1, int(10/zoom_scale))

            # J x B force density (A/m^2 * T = N/m^3). The z-component uses Jy*Bx, as the cross
            # product requires, but it is zeroed so that only in-plane arrows are drawn.
            JxB_x = Jy_df*Bz_df - Jz_df*By_df
            JxB_y = Jz_df*Bx_df - Jx_df*Bz_df
            JxB_z = Jx_df*By_df - Jy_df*Bx_df
            JxB_quiver = ax.quiver(X_df[::qskip], Y_df[::qskip], Z_df[::qskip],
                                   JxB_x[::qskip], JxB_y[::qskip], JxB_z[::qskip]*0,
                                   color='fuchsia', length=6.5e12*zoom_scale, linewidths=0.4*zoom_scale, zorder=5.5)
            # Pressure-gradient force density -grad(p) (Pa/m = N/m^3), also drawn in-plane only
            gradp_quiver = ax.quiver(X_df[::qskip], Y_df[::qskip], Z_df[::qskip],
                                     -dp_dx_df[::qskip], -dp_dy_df[::qskip], -dp_dz_df[::qskip]*0,
                                     color='deepskyblue', length=6.5e12*zoom_scale, linewidths=0.4*zoom_scale, zorder=5.5)

            # Aliases kept for back-compatibility with code written against the *_region names.
            # Z_region is aliased too, so the seed coordinates below all describe the same cells.
            X_region = X_df
            Y_region = Y_df
            Z_region = Z_df

            # Find seed points for field lines
            if key in df_seeds.keys():
                # Advect existing seeds with this DF's latest bulk electron velocity:
                # km/s * 1e3 -> m/s, then / R_M -> R_M per second (assumes dt is in seconds -- TODO confirm)
                df_seeds[key][:,0] = df_seeds[key][:,0] + dt*df_data[key]['uex'].iloc[-1]*1e3/R_M
                df_seeds[key][:,1] = df_seeds[key][:,1] + dt*df_data[key]['uey'].iloc[-1]*1e3/R_M

                # Seeds tend to advect out of the DF. Move any seed that left the DF's bounding box
                # onto a randomly chosen DF cell. The bounds do not change inside the loop, so compute them once.
                xmin_df, xmax_df = np.min(X_region), np.max(X_region)
                ymin_df, ymax_df = np.min(Y_region), np.max(Y_region)
                for iseed in range(len(df_seeds[key])):
                    seed_x, seed_y = df_seeds[key][iseed,0], df_seeds[key][iseed,1]
                    if (seed_x > xmax_df) or (seed_x < xmin_df) or (seed_y > ymax_df) or (seed_y < ymin_df):
                        new_loc = random.randint(0,len(X_region)-1)
                        print("Field line seed left the DF! Moved seed at",df_seeds[key][iseed,:],"to",X_region[new_loc],Y_region[new_loc],0.2)
                        df_seeds[key][iseed,0] = X_region[new_loc]
                        df_seeds[key][iseed,1] = Y_region[new_loc]
                        df_seeds[key][iseed,2] = 0.2

            # For first time this DF is generated, make all new seed points
            else:
                # Seed every trace_skip-th cell of the DF. The stride must be trace_skip:
                # a stride of max(len, trace_skip) selects only the first cell, which then
                # broadcasts so that every seed starts at the same point. column_stack also sizes
                # the array exactly (ceil(n / trace_skip) rows), avoiding a shape mismatch.
                trace_skip = 10
                df_seeds[key] = np.column_stack((X_region[::trace_skip],
                                                 Y_region[::trace_skip],
                                                 Z_region[::trace_skip])).astype(float)

            # Trace the field lines from this DF's seeds through the vector grid
            tracer.trace(df_seeds[key], grid)

            # Once the zoom is active, only trace points inside the zoom box are drawn
            zoomed = do_zoom and (float(time) > zoom_time_start)

            # Plot them
            for iseed in range(len(df_seeds[key])):
                line = tracer.xs[iseed]
                seed_z = df_seeds[key][iseed][2]
                if zoomed:
                    # Bound y by zoom_ymin; the previous mask mistakenly used zoom_xmin here
                    inside = ((line[:,0] < zoom_xmax) & (line[:,0] > zoom_xmin) &
                              (line[:,1] < zoom_ymax) & (line[:,1] > zoom_ymin) &
                              (line[:,2] < zoom_zmax) & (line[:,2] > zoom_zmin))
                else:
                    inside = np.ones(len(line), dtype=bool)
                above = np.where((line[:,2] >= seed_z) & inside)[0]
                below = np.where((line[:,2] < seed_z) & inside)[0]

                # Points at/above the seed height are drawn with zorder 6, points below with zorder 0.5
                for idx, z_order in ((above, 6), (below, 0.5)):
                    # Split the indices into runs of consecutive points so no segment is drawn across
                    # a gap. Each run is plotted whole, keeping its first and last point.
                    for run in np.split(idx, np.where(np.diff(idx) > 1)[0] + 1):
                        if len(run) < 2:
                            continue  # a single point cannot form a line segment
                        ax.plot(line[run,0], line[run,1], line[run,2],
                                color=color_ls[key%9], lw=0.3, alpha=1, zorder=z_order)

        # Set axes: use the zoom window once zooming has started, otherwise the full data extent.
        # The box aspect is proportional to each axis span so the 3D view is not distorted.
        if do_zoom and (float(time) > zoom_time_start):
            lim_x = (zoom_xmin, zoom_xmax)
            lim_y = (zoom_ymin, zoom_ymax)
            lim_z = (zoom_zmin, zoom_zmax)
        else:
            lim_x = (X.min(), X.max())
            lim_y = (Y.min(), Y.max())
            lim_z = (z_lower, z_upper)
        ax.set_xlim(*lim_x)
        ax.set_ylim(*lim_y)
        ax.set_zlim(*lim_z)
        x_range = lim_x[1] - lim_x[0]
        y_range = lim_y[1] - lim_y[0]
        z_range = lim_z[1] - lim_z[0]
        ax.set_box_aspect([x_range, y_range, z_range])

        ######################## PLOT2 END ################################

    
        # Save
        fig.savefig(str(str(dir[:-1])+"_plots/"+plot_preset+"_"+"%.2f"%round(float(time),2)+'.png'),bbox_inches='tight',dpi = 300)
        plt.close(fig)

    iter+=1

# Save data: pickle the accumulated results for the active preset next to the input files.
# `with` guarantees each file is flushed and closed; previously the handle passed to
# pickle.dump was never closed. Only unpickle these files from trusted sources.
if plot_preset == 'cross_tail':
    out_path = str(dir+"cross_tail_data")
    with open(out_path, 'wb') as out_file:
        pickle.dump(cross_tail_data, out_file)
    print("Cross tail data saved to:",out_path)
elif plot_preset == '3D_df_tracker':
    out_path = str(dir+"df_data")
    with open(out_path, 'wb') as out_file:
        pickle.dump(df_data, out_file)
    print("DF data saved to:",out_path)

In [61]:
# Quick inspection of dp_dx_df (x pressure gradient in Pa/m after the 1e-9 scaling) for the
# last DF processed by the loop above. NOTE(review): leftover debugging cell; it depends on
# loop state, so remove it or turn it into an assertion before sharing.
dp_dx_df

array([-8.8201890e-16, -5.5039047e-16, -1.4452087e-15, -1.4749508e-15,
       -1.5864119e-15, -1.1508451e-16, -1.2425335e-15, -1.7371213e-15,
       -1.3186858e-15, -1.0470059e-15, -9.0118821e-16, -8.3771521e-16],
      dtype=float32)