## Let's figure out how to section the data and massage it into the form we want to use as model input

### Then we can paste the streamlined version into the full training notebooks

### Load Data and do 2d analysis

In [1]:
import sys
sys.path.append('/tigress/kendrab/python_pkgs')

import numpy as np
from dataframework.src.datasets.vpicdataset import VPICDataset
from numpy.random import default_rng as generator
import h5py
import time

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import matplotlib as mpl

In [2]:
def clear_dataset(group, key):
    """Delete ``group[key]`` if it exists so a fresh dataset can be written under that name.

    Used before (re)creating a dataset, e.g. to drop leftovers from earlier or failed runs.

    Parameters
    ----------
    group : h5py.File / h5py.Group (or any mapping supporting ``in`` and ``del``)
        Container that may already hold ``key``.
    key : str
        Name of the dataset to remove.
    """
    # Explicit membership test instead of a bare ``except: pass``: the old version also
    # swallowed unrelated failures (permissions, I/O, KeyboardInterrupt) and silently
    # left stale data in the file.
    if key in group:
        del group[key]
        print(f"Overwrote existing dataset {group, key}")

In [3]:
def random_bounded(rand, bounds):
    """ Use numpy to generate an array of random numbers with specified mins and maxes

    Each output element is drawn uniformly from [min, max) of its own bounds pair.

    Parameters
    ----------
    rand : numpy.random.Generator
        Source of uniform [0, 1) samples (its ``random(size=...)`` method is used).
    bounds : array-like, shape (..., 2)
        Bounds for each point in the random array: ``bounds[..., 0]`` holds the mins and
        ``bounds[..., 1]`` holds the maxes. Lists/tuples are accepted as well as arrays.

    Returns
    -------
    rands : array, shape (...)
        The array of random numbers, one per bounds pair.

    Raises
    ------
    ValueError
        If ``bounds`` does not have a trailing axis of length 2.
    """
    # asarray lets callers pass nested lists; no dtype is forced, so arithmetic matches ndarray input.
    bounds = np.asarray(bounds)
    if bounds.ndim == 0 or bounds.shape[-1] != 2:
        raise ValueError(f"bounds must have shape (..., 2), got {bounds.shape}")

    unit_rands = rand.random(size=bounds.shape[:-1])
    rands = (bounds[..., 1] - bounds[..., 0]) * unit_rands + bounds[..., 0]

    return rands

def mtail_transform(sim_data):  # TODO FIX THIS
    """Rename/re-sign VPIC vector components in place so they follow magnetotail (GSM-like) axes.

    Axis mapping, for rxn geometry in vpic with Bz positive for x>0 and negative for x<0:
        vpic +x -> GSM +z
        vpic +z -> GSM +x
        vpic +y -> GSM -y
    so the x/z components swap names (ax -> az_mms, az -> ax_mms) and the y components
    flip sign (ay_mms = -ay). Only variable names, labels and y-component signs change;
    the mesh is NOT transformed (its axes are merely reinterpreted as x, z instead of z, x).

    This function is for simulations WITHOUT y spatial dimensions.
    THIS IS NOT A GOOD GENERALIZABLE FUNCTION!!!! REUSE WITH EXTREME CAUTION!!!
    If the reconnection geometry is the opposite, a different transform is needed to get
    the reconnection to line up with GSM directions.

    Parameters
    ----------
    sim_data : dataset object
        Must expose a ``variables`` mapping containing bx, bx_smooth, by, bz, bz_smooth,
        vz, jy, ex, ey, ez; each entry needs mutable ``data`` and ``label`` attributes.
        A missing entry raises KeyError (from ``pop``).
    """
    # (vpic name, magnetotail name, negate?) -- row order fixes the insertion order of the new keys.
    rename_table = (
        ('bx', 'bz_mms', False),
        ('bx_smooth', 'bz_mms_smooth', False),
        ('by', 'by_mms', True),
        ('bz', 'bx_mms', False),
        ('bz_smooth', 'bx_mms_smooth', False),
        ('vz', 'vx_mms', False),
        ('jy', 'jy_mms', True),
        ('ex', 'ez_mms', False),
        ('ey', 'ey_mms', True),
        ('ez', 'ex_mms', False),
    )
    variables = sim_data.variables
    for vpic_name, mms_name, negate in rename_table:
        variables[mms_name] = variables.pop(vpic_name)
        if negate:
            variables[mms_name].data *= -1  # GSM y points along vpic -y

    # Keep every variable's label in sync with its (possibly new) key.
    for name, variable in variables.items():
        print(f"label {variable.label} becoming {name}")
        variable.label = name

    
    

In [4]:
# Passed to ndslice as `zooms`: first axis unrestricted, second axis limited to [-20, 20].
zooms = [[-np.inf, np.inf], [-20, 20]]
smoothing = 3  # NOTE(review): not referenced in the visible cells
de_tol = 7     # forwarded to find_structures(de_tol=...)
# Variables requested from VPICDataset; their names are also joined into the output file names.
kwargs = dict(get_vars=['bx', 'by', 'bz', 'jy', 'vz', 'ex', 'ey', 'ez'])

In [5]:
# Full run: the unperturbed baseline followed by the ten perturbed runs (output0 .. output9).
files = (['/scratch/gpfs/kendrab/dataset_vpic_runs/unperturbed/06022023/data.h5']
         + [f'/scratch/gpfs/kendrab/dataset_vpic_runs/perturbed/21032023/output{run}/data.h5'
            for run in range(10)])
writepaths = (['/tigress/kendrab/06022023/new_better']
              + [f'/tigress/kendrab/21032023/{run}/new_better' for run in range(10)])
unique_seed = 308  # folded into each per-snapshot RNG seed
samples = 100      # number of 1-D slices drawn per snapshot
min_dist = 30      # accepted slice lengths lie strictly between min_dist and max_dist
max_dist = 200
# Snapshot indices per file: 5..50 for the baseline, 5..55 for every perturbed run.
time_idxs_list = [list(range(5, 55, 5))] + 10 * [list(range(5, 60, 5))]

debug = True
if debug:
    # Debug: process only perturbed run output6 at snapshot 30.
    # (swap in output4 / 21032023/4 here to debug a different run)
    files = ['/scratch/gpfs/kendrab/dataset_vpic_runs/perturbed/21032023/output6/data.h5']
    writepaths = ['/tigress/kendrab/21032023/6/new_better']
    time_idxs_list = [[30]]

In [6]:
def _sample_endpoints(rand, segment_bounds, min_dist, max_dist, max_tries=1000):
    """Draw random segment endpoints until the segment is acceptably long.

    A draw is accepted when the two endpoints differ by more than 1 along every axis and
    their Euclidean distance lies strictly between ``min_dist`` and ``max_dist``.

    Parameters
    ----------
    rand : numpy.random.Generator
    segment_bounds : array, shape (..., 2, 2)
        Per-axis [min, max] bounds repeated once per endpoint (fed to ``random_bounded``).
    min_dist, max_dist : float
        Exclusive lower/upper limits on the segment length.
    max_tries : int
        Number of draws before giving up.

    Returns
    -------
    endpts : array
        ``endpts[:, k]`` holds the coordinates of endpoint ``k``.

    Raises
    ------
    RuntimeError
        If no acceptable segment is found within ``max_tries`` draws.
    """
    for _ in range(max_tries):
        endpts = random_bounded(rand, segment_bounds)
        separation = endpts[:, 0] - endpts[:, 1]
        if np.all(np.abs(separation) > 1) and min_dist < np.linalg.norm(separation) < max_dist:
            return endpts
    # Previously `raise Error(...)`: `Error` is undefined, so this surfaced as a NameError.
    raise RuntimeError("Suitable random 1d segment not found, check minimum and maximum distances.")


def _write_slices(filename, slices_dict, seed):
    """Store each ``slices_dict`` entry as a 1-D HDF5 dataset with one element per slice.

    Array-valued entries are written with a variable-length numeric dtype, so slices may
    differ in length. Entries that cannot be written that way (e.g. scalar or string
    params) fall back to their ``str()`` representation. The RNG seed goes into a file attribute.
    """
    with h5py.File(filename, 'a') as writefile:
        writefile.attrs.create('seed', seed, dtype=np.int64)
        for key, values in slices_dict.items():
            clear_dataset(writefile, key)
            try:
                vlen_dtype = h5py.vlen_dtype(type(values[0].flatten()[0]))  # vlen dtype is list of arrays
                print(key, type(vlen_dtype))
                writefile.create_dataset(key, (len(values),), dtype=vlen_dtype)
                for idx, value in enumerate(values):
                    writefile[key][idx] = value
            except Exception:  # not narrowed further: h5py/numpy raise several types for unsupported data
                clear_dataset(writefile, key)  # clear failed attempt
                str_dtype = h5py.special_dtype(vlen=str)
                writefile.create_dataset(key, (len(values),), dtype=str_dtype)
                print(key, type(str_dtype))  # previously printed the stale numeric dtype
                for idx, value in enumerate(values):
                    writefile[key][idx] = str(value)


# nearest interpolation so topology flags (0/1 masks) stay integral along each slice
varkwargs = {'separatricesinterp': 'nearest', 'o_structuresinterp': 'nearest',
             'current_sheetsinterp': 'nearest'}
var_tag = ''.join(kwargs['get_vars'])

for file, time_idxs, writepath in zip(files, time_idxs_list, writepaths):
    sim_dset = VPICDataset(vpicfiles=[file, ''], **kwargs)
    for time_idx in time_idxs:
        # Slice a window of one output step around the requested snapshot.
        desired_time = sim_dset.timeseries[time_idx]
        time_step = sim_dset.timeseries[1] - sim_dset.timeseries[0]
        t_lo, t_hi = desired_time - time_step/2, desired_time + time_step/2
        onetime_dset = sim_dset.ndslice(timelims=[t_lo, t_hi])
        onetime_dset.add_param('orig_idx', time_idx)
        onetime_dset.find_structures(b1_name='bz', b2_name='bx', de_tol=de_tol, wrap=True)
        zoomed_dset = onetime_dset.ndslice(timelims=[t_lo, t_hi], zooms=zooms)

        # Wall-clock dependent, so reruns differ; the seed is saved to the HDF5 file for reproducibility.
        seed = unique_seed*time_idx*(time.time_ns() % 1000000)
        rand = generator(seed=seed)

        bounds = zoomed_dset.bounds(time=False)
        # Repeat each axis' [min, max] once per endpoint so random_bounded draws two points.
        segment_bounds = np.repeat(np.expand_dims(bounds, axis=-2), 2, axis=-2)
        endpts_list = []
        # slices_dict format: slices_dict['variable'] = [var_from_slice_1, var_from_slice_2, ...]
        # Also stores the parameters
        slices_dict = {}
        # transform the data to a more magnetotail-esque config
        mtail_transform(zoomed_dset)

        for _ in range(samples):
            endpts = _sample_endpoints(rand, segment_bounds, min_dist, max_dist)
            slce = zoomed_dset.ndslice(set_pts=endpts.T, zooms=endpts, **varkwargs)
            endpts_list.append(endpts.T)  # to make first dimension the number of points
            # sanity check variables: absolute coordinates of every point along the slice
            s_mesh = slce.default_mesh[0]
            orig_mesh = np.stack([slce.params['zero_pt'] + slce.params['unit_vec'] * s
                                  for s in s_mesh], axis=-1)
            slce.add_var('x_mms', slce.timeseries, slce.default_mesh, orig_mesh[0, :])
            slce.add_var('z_mms', slce.timeseries, slce.default_mesh, orig_mesh[1, :])
            # this will lose info if the variables have different meshes,
            # so all must be interpolated to the same mesh first
            for var_name, var in slce.variables.items():
                slices_dict.setdefault(var_name, []).append(var.data.reshape(-1))
            for param_name, param in slce.params.items():
                slices_dict.setdefault(param_name, []).append(param)
            c_sheet = slce.variables['current_sheets'].data[0]
            o_struct = slce.variables['o_structures'].data[0]  # not `os`: avoid shadowing the stdlib module
            slices_dict.setdefault('s', []).append(s_mesh.reshape(-1))
            # topology code: 0 neither, 1 o-structure only, 2 current sheet only, 3 both
            slices_dict.setdefault('topo', []).append((2*c_sheet + o_struct).reshape(-1))

        # visualize the slices over the plasmoid mask and flux-function contours
        flux_fn = zoomed_dset.variables['flux_fn']
        plasmoids = zoomed_dset.variables['o_structures']
        X, Y = np.meshgrid(*flux_fn.mesh, indexing='ij')
        fig, ax = plt.subplots(figsize=(10, 2))
        ax.contourf(X, Y, plasmoids.data[0], alpha=0.5,
                    cmap=mpl.colors.ListedColormap(['white', 'blue']))
        ax.contour(X, Y, flux_fn.data[0], levels=50, colors='navy', linewidths=1)  # dashed lines mean flux_fn negative
        ax.add_collection(LineCollection(endpts_list, colors='black'))
        for endpts in endpts_list:
            ax.scatter(*endpts[0], color='black', marker='.')  # marks starting pt

        out_stem = f"{writepath}/{samples}samples_idx{time_idx}_{var_tag}"
        fig.savefig(out_stem + "_plot.svg")
        plt.close(fig)  # close this figure explicitly, not whichever is "current"

        _write_slices(out_stem + '.hdf5', slices_dict, seed)

NO PARAMS ADDED, FUNCTIONALIITY NOT ADDED YET!!!! SORRY
parameter filename = /scratch/gpfs/kendrab/dataset_vpic_runs/perturbed/21032023/output6/data.h5
Added bx Variable
Added by Variable
Added bz Variable
Added jy Variable
Added vz Variable
Added ex Variable
Added ey Variable
Added ez Variable
parameter orig_idx = 30
Finding structures at simulation time 29.9930477142334
parameter d_per_de = 4
Added bz_smooth Variable
Added bx_smooth Variable
Added flux_fn Variable
Number of nulls:  16
30
parameter x_coords = [[1609.76694659  562.70148091]
 [1350.45361429  551.62001001]
 [1136.56421226  544.91259414]
 [ 848.19241048  550.32141964]
 [1827.28104122  576.54318607]
 [2770.89062935  512.35278283]
 [3421.77773089  558.43320612]
 [3864.93137151  561.42220032]]
parameter o_coords = [[1731.02320892  573.28237059]
 [1512.18936009  554.45718753]
 [1245.58786285  548.07844567]
 [2227.86343179  557.58795012]
 [3674.36609917  569.64717736]
 [3243.65103173  534.05003381]
 [4351.17928408  545.3055663

### Write the 2d stuff to a separate file

In [7]:
# writefile = h5py.File(writepath+"/"+"idx"+str(time_idx)+'_' +''.join(kwargs['get_vars']) +'.hdf5', 'a')

# clear_dataset(writefile, 'default_x') # save mesh
# dset = writefile.create_dataset('default_x', data=sim_dset.default_mesh[0])
# clear_dataset(writefile, 'default_z') # save mesh
# dset = writefile.create_dataset('default_z', data=sim_dset.default_mesh[1])
# for key in sim_dset.variables.keys():  # this is a mess. Is this how you are supposed to use try except blocks? Prob not
#     clear_dataset(writefile, key)
#     dset = writefile.create_dataset(key, data=sim_dset.variables[key].data)

# for key in sim_dset.params.keys():
#     clear_dataset(writefile, key)
#     dset = writefile.create_dataset(key, data=sim_dset.params[key])

Close the output file.

In [8]:
#writefile.close()

### Visualize the slices (optional)

In [9]:
# %matplotlib widget
# import matplotlib.pyplot as plt
# from matplotlib.collections import LineCollection

# flux_fn = sim_dset.variables['flux_fn']

# X,Y = np.meshgrid(*flux_fn.mesh, indexing='ij')
# fig, ax = plt.subplots(figsize=(10,2))
# ctr =ax.contour(X,Y,flux_fn.data[0], levels=100)
# plt.colorbar(mappable=ctr)
# paths = LineCollection(endpts_list, colors='black')
# ax.add_collection(paths)

# for endpts in endpts_list:
#     print(endpts[0])
#     ax.scatter(*endpts[0], color='black') # marks starting pt
    
# plt.show()    