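Used with [domain_adaptation_1](./domain_adaptation_1.ipynb). This notebook does not itself import or define `np`, `torch`, `datetime`, `TensorDataset`, `DataLoader`, `device`, `dtype`, or `batch_size`; those are presumably provided by that notebook before this one runs.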

In [1]:
import subprocess
import os
import h5py
# Default directory that is searched for MMS HDF5 slice files.
mms_data_loc = "/tigress/kendrab/analysis-notebooks/mms_data/mms_slices/"
# dtypes accepted for direct torch.from_numpy conversion of 2-D datasets.
# NOTE(review): `np` is not imported in this notebook; it is presumably
# supplied by the notebook this one is run alongside.
from_numpy_types = [
    np.float64, np.float32, np.float16,
    np.complex64, np.complex128,
    np.int64, np.int32, np.int16, np.int8,
    np.uint8,
    bool,
]

In [None]:
min_file_size = 10*1024  # minimum acceptable file size in bytes to avoid trying to open empty files

## Get filenames

In [None]:
def get_filenames(folder=mms_data_loc):
    """Return the sorted names of the non-hidden entries in ``folder``.

    The directory is listed directly instead of parsing ``ls`` output.
    Splitting that output on whitespace broke names containing spaces into
    several bogus entries, and it required an external ``ls`` binary.

    As with ``ls``, names starting with ``.`` are omitted and subdirectories
    are included. Names are sorted by code point, which can differ from the
    locale-dependent collation ``ls`` uses.

    Raises:
        FileNotFoundError / NotADirectoryError: if ``folder`` is not a
            readable directory.
    """
    return sorted(name for name in os.listdir(folder) if not name.startswith('.'))

## Get data

In [None]:
def get_mms_data(filename, folder=mms_data_loc):
    """Read every top-level entry of an HDF5 file into memory.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file inside ``folder``.
    folder : str
        Directory containing the file (defaults to ``mms_data_loc``). A
        trailing path separator is optional; the path is built with
        ``os.path.join`` rather than string concatenation.

    Returns
    -------
    dict
        Maps each top-level label to ``file[label][()]``. The dict is empty
        (after printing a notice) when the file is smaller than
        ``min_file_size`` bytes.
    """
    top_level_data = {}
    path = os.path.join(folder, filename)
    # Tiny files are treated as empty and skipped rather than opened.
    if os.path.getsize(path) < min_file_size:
        print("File is empty or too small. Skipping")
        return top_level_data
    with h5py.File(path, 'r') as h5_file:
        for label, data in h5_file.items():
            # [()] reads the entire dataset stored under this label.
            # NOTE(review): assumes a flat label -> dataset layout; a nested
            # group under a top-level label would not be read correctly here.
            top_level_data[label] = data[()]
    return top_level_data

### Build a DataLoader from the data

In [None]:
def _time_strings_to_timestamps(data):
    """Parse an array of time strings into float64 POSIX timestamps.

    Elements must match ``'%Y-%m-%dT%H:%M:%S.%f'``
    (e.g. ``2017-01-01T00:00:00.000000``). ``bytes`` elements are decoded
    as ASCII first, and ``str`` elements are parsed as-is.

    Raises:
        TypeError: if any element cannot be parsed this way. The underlying
            parsing error is chained as the cause.
    """
    def to_timestamp(value):
        text = value.decode('ascii') if isinstance(value, bytes) else value
        # NOTE(review): strptime yields a naive datetime, and naive
        # .timestamp() interprets it in the host's LOCAL timezone. Results
        # therefore depend on the machine's TZ/DST settings. If these times
        # are UTC, consider .replace(tzinfo=timezone.utc). This is left
        # unchanged to stay consistent with previously processed data.
        return datetime.strptime(text, '%Y-%m-%dT%H:%M:%S.%f').timestamp()

    try:
        # A fixed otypes skips np.vectorize's trial call, so empty arrays work too.
        return np.vectorize(to_timestamp, otypes=[np.float64])(data)
    except (TypeError, ValueError, AttributeError) as err:
        # AttributeError: element has no .decode; ValueError: format mismatch
        # or non-ASCII bytes. The original only caught TypeError, which
        # these inputs never raise.
        raise TypeError(f"Dataset is of type {data.dtype} which cannot be "
                        "converted to tensor via implemented methods. Try reformatting.") from err


def format_mms_data(mms_data_dict):
    """Turn a label -> array mapping into a shuffled DataLoader.

    Each array is split into tensors of shape (batch, 1, timeseries). Every
    tensor is moved to the global ``device`` and cast to the global ``dtype``.

    * 3-D arrays (batch, vector components, timeseries) yield one tensor per
      component.
    * 2-D arrays (batch, timeseries) whose dtype is in ``from_numpy_types``
      yield one tensor.
    * Other 2-D arrays are parsed as time strings and converted to
      timestamps.

    The tensors are combined into a TensorDataset, so all arrays must share
    the same batch length. The dataset is wrapped in a DataLoader using the
    global ``batch_size`` with ``shuffle=True`` and ``drop_last=True``.

    Raises:
        ValueError: if ``mms_data_dict`` is empty, or an array is neither
            2-D nor 3-D.
        TypeError: if a 2-D non-numeric array cannot be parsed as time
            strings.
    """
    if not mms_data_dict:
        # Without this check, TensorDataset()/DataLoader fail later with an
        # opaque IndexError (e.g. for a skipped, too-small file).
        raise ValueError("No datasets to format: mms_data_dict is empty.")

    components_list = []
    for data in mms_data_dict.values():
        if data.ndim == 3:  # batch, vector components, timeseries
            # Slicing i:i+1 keeps the component axis: (batch, 1, timeseries).
            components_list.extend(
                torch.from_numpy(data[:, i:i + 1, :]).to(device, dtype=dtype)
                for i in range(data.shape[1])
            )
        elif data.ndim == 2:  # batch, timeseries
            if data.dtype in from_numpy_types:
                numeric = data
            else:  # e.g. time stored as strings
                numeric = _time_strings_to_timestamps(data)
            # Insert the singleton component axis: (batch, 1, timeseries).
            components_list.append(
                torch.from_numpy(numeric[:, None, :]).to(device, dtype=dtype))
        else:
            # Higher-rank data (e.g. 2-D/3-D distribution functions) is not supported yet.
            raise ValueError(f"dataset is of shape {data.shape} which cannot be automatically processed.")

    mms_dset = TensorDataset(*components_list)
    return DataLoader(mms_dset, batch_size=batch_size, shuffle=True, drop_last=True)