In [None]:
# # no need for this since it has nothing outside classes and functions
# print(__name__)
# if __name__ == "__main__" and hasattr(__builtins__,'__IPYTHON__') and ('google.colab' in str(get_ipython())):
#     from google.colab import drive
#     drive.mount('/content/drive')
#     %cd /content/drive/MyDrive/PressureReliefWorkArea/SummerWork/
#     !ls

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim
import numpy as np
from sklearn.utils.random import sample_without_replacement
# from collections.abc import Iterable # , Sequence
from copy import deepcopy
# from inspect import signature
# from inspect import getmembers, ismethod
from itertools import chain

%run -n HelperFunctions.ipynb
# import ipynb
# from ipynb.fs.full.HelperFunctions import *

This code loads your CSV file, splits the data into a training set and a test set, and creates a DataLoader for each. The DataLoader can be used to iterate through the data in batches, which is useful for training a neural network.

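Pass the path to your actual CSV file as `file_path` when constructing `JFSKLoader`. Also, note that the file is read with `pd.read_csv(file_path)` using pandas' default settings, so the first row is treated as a header. If your CSV file doesn't have a header row, its first data row will be consumed as column names, so add a header row (or adjust the `read_csv` call) first.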

In [None]:
class JFAccelDataset(Dataset):
    """Dataset that pairs accelerometer samples with their labels.

    data:            indexable collection of samples (or of sample windows
                     once group() has been called).
    labels:          indexable collection of labels aligned with data.
    sequence_length: number of consecutive samples per window used by
                     group(); may be None when no grouping is wanted.
    """

    def __init__(self, data, labels, sequence_length=None):
        self.data = data
        self.labels = labels
        self.sequence_length = sequence_length

    def __len__(self):
        """Return the number of items (samples or windows) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (item, label) for position idx.

        The item is returned with its first two dimensions swapped
        (transposing the sequence and channel dimensions), so the stored
        items must support .transpose(0, 1) -- presumably torch tensors;
        confirm against the code that builds the DataLoader.
        """
        return self.data[idx].transpose(0, 1), self.labels[idx]

    def group(self):
        """Replace data with sliding windows and keep one label per window.

        Window i covers data[i : i + sequence_length]; its label is the one
        at offset (sequence_length // 2 - 1) inside the window, clamped to 0
        so that a window of length 1 keeps its own label instead of slicing
        the labels from index -1 (which used to drop every label).

        Raises ValueError when sequence_length is smaller than 1.
        """
        window = self.sequence_length
        if window is not None and window < 1:
            raise ValueError(f"sequence_length must be at least 1 to group, got {window}.")

        n_windows = max(0, len(self.data) - window + 1)
        first_label = max(0, window // 2 - 1)

        self.data = [self.data[i:i + window] for i in range(n_windows)]
        # Exactly one label per window, so data and labels stay aligned.
        self.labels = self.labels[first_label:first_label + n_windows]
        # TODO: change to take the majority label of each window

In [None]:
class SKDatasetHandler:

    NDARRAY_JFAD_DTYPE = ("NDArray JFAccelDataset", None)
    NDARRAY_JFAD_PAIR_DTYPE = ("NDArray JFAccelDataset Pair", 2)
    TENSOR_JFAD_PAIR_DTYPE = ("Tensor JFAccelDataset Pair", 2)
    VALID_DATASET_DATATYPES = {
        NDARRAY_JFAD_DTYPE,
        NDARRAY_JFAD_PAIR_DTYPE,
        TENSOR_JFAD_PAIR_DTYPE
    }


    # def format_scalar
    

    def get_state_slice(state, index):
        if state["ds_format"][1] is None:
            return state
        if index not in range(state["ds_format"][1]):
            raise IndexError(f"SKDatasetHandler.get_state_slice() could not find index {index} for state {state['description']}.")
        
        ret_state = {}
        for k, v in state.items():
            if k in ("ds_format", "description") or type(v) is not tuple:
                ret_state[k] = v
            else:
                ret_state[k] = v[index]
        return ret_state
    

    # ds_format = tuple[Sequence[torch.Tensor], Sequence[torch.Tensor]]
    def __init__(self, datasets, classification, ds_format = NDARRAY_JFAD_PAIR_DTYPE, is_argmax_format = False, bufferpref = "Inner", *, description):
        """Store the datasets plus their metadata and log the initial state.

        datasets:         the dataset, or tuple of datasets, being managed.
        classification:   classification type; validated through
                          SKDescriptors.validate_class_type().
        ds_format:        one of the *_DTYPE class constants, a
                          (name, slot count) pair.
        is_argmax_format: stored as given; passed on by repress_classes().
        bufferpref:       stored as given.
        description:      keyword-only label used to tell logged states apart.
        """
        SKDescriptors.validate_class_type(classification)

        # Private bookkeeping; names starting with '_' are skipped by get_state().
        self.__log = []
        self.__checkpoints = []
        self.__fail_counter = 0

        # Public attributes: together these make up a loggable "state".
        self.ds_format = ds_format
        self.description = description
        self.datasets = datasets
        self.classification = classification
        self.bufferpref = bufferpref
        self.is_argmax_format = is_argmax_format
        self.inputnum = SKDescriptors.NUM_OF_INPUTS_PER_TYPE[classification]
        # NOTE(review): only the NDArray pair format gets ("Train", "Test");
        # the tensor pair format also falls back to "Full" -- confirm intended.
        if ds_format == SKDatasetHandler.NDARRAY_JFAD_PAIR_DTYPE:
            self.dataset_types = ("Train", "Test")
        else:
            self.dataset_types = "Full"

        # The freshly built handler becomes state 0 (and the first checkpoint).
        self.log_state()


    def __str__(self):
        return str(self.description)
    

    def __len__(self):
        return len(self.__log)
    

    def has_valid_description(self):
        for s in self.__log:
            if self.description == s["description"]:
                return False
        return True


    def get_state(self, index = None, feedback = True):
        if index is None:
            vals = {}
            for name, value in vars(self).items():
                if not name.startswith('_'):
                    vals[name] = value
            # vals = tuple(vals)
            if feedback:
                print(f"Read state as {vals}")
            return vals
        if index not in range(len(self.__log)):
            raise IndexError(f"{index} is not a valid index for log of size {len(self.__log)}.")
        if feedback:
            print(f"Read state {index} as {self.__log[index]}")
        return self.__log[index]
    

    def log_state(self, mark_cp = True, return_index = False, feedback = True):
        if not self.has_valid_description():
            print("WARNING: attempting to log a state without changing the description could result in not being able to tell datasets apart.")
            self.description = str(self.__fail_counter)
            self.__fail_counter += 1
            print(f"\tLogging state with description '{self.description}'")
        vals = self.get_state(None, False)
        # if type(vals["datasets"][0]) != JFAccelDataset:
        #     self.apply(lambda ds: JFAccelDataset(ds, batch_size=64, shuffle=True))
        self.__log.append(vals)
        if mark_cp:
            self.__checkpoints.append(len(self.__log) - 1)
        if feedback:
            print(f"Logged state {len(self.__log) - 1} as {vals}")
        if return_index:
            return len(self.__log) - 1

    
    def recall_state(self, index, mark_cp = True, feedback = True, *, __from_restore = False):
        double_restore = __from_restore
        print("loggggggg")
        print(self.__log[index])
        for name, current_val in vars(self).items():
            if not name.startswith('_'):
                value = self.__log[index][name]
                if not current_val == value:
                    double_restore = False
                    setattr(self, name, value)
        if double_restore:
            self.restore_state(feedback)
            return
        if mark_cp:
            self.__checkpoints.append(index)
        if feedback:
            print(f"Recalled state {index} as {self.__log[index]}")


    def restore_state(self, feedback = True):
        index = 0
        if self.__checkpoints:
            index = self.__checkpoints.pop()
        self.recall_state(index, True, feedback, __from_restore = True)


    def reset_state(self, feedback = True):
        """Return to logged state 0 and forget every checkpoint.

        State 0 is recalled without marking a checkpoint; afterwards the
        checkpoint stack is replaced by an empty list.
        """
        self.recall_state(0, False, feedback)
        self.__checkpoints = []


# , *args, **kwargs
    def apply(self, func, copy_datasets = None, is_adj_func = False, save = None, keep_format = None, *args, **kwargs):
        # args = []
        # kwargs = {}
        # ft_info = iterable_info(func_tuple)
        # if not ft_info[0]:
        #     raise TypeError()
        # func_tuple = ft_info[0]
        # func = func_tuple[0]
        # if ft_info[2] > 1:
        #     args = func_tuple[1]
        # if ft_info[2] > 2:
        #     kwargs = func_tuple[2]
            
            # (not (func in (SKDatasetHandler.save_data_and_labels, SKDatasetHandler.get_inverse_and_counts)) and func in list(zip(*getmembers(self, ismethod)))[1])
        if not callable(func) or func == self.apply:
            raise TypeError(f"ERROR: using {func} within SKDatasetHandler.apply() as the function will likely yield undesired results.\
                            \n\tIf you are trying to combine two datasets in some fashion, try list(zip(ds1, ds2)) as the input dataset and access them\
                            \n\tby using lambda_arg[0] and lambda_arg[1].\
                            \n\tIf you are trying to use one of SKDatasetHandler's functions, set the object's state based on the dataset you want\
                            \n\tto apply it to and then call that function without using SKDatasetHandler.apply(). SKDatasetHandler.apply() is for\
                            \n\tfunctions from outside the class that take in a dataset as their first parameter")
        skip_test = False # is_adj_func and not self.adj_test
        if copy_datasets is None:
            datasets = deepcopy(self.datasets)
            if save is None:
                save = True
        else:
            datasets = deepcopy(copy_datasets)
            if save is None:
                save = False
        if keep_format is None:
            keep_format = save
        else:
            keep_format = save or keep_format

        if type(datasets) is tuple or self.ds_format[1] > 1:
            all_eq = True
            first = None
            result = []
            for ds, dst in zip(datasets, self.dataset_types):
                if dst == "Test" and skip_test:
                    result.append(ds)
                    continue
                temp = func(ds, *args, **kwargs)
                result.append(temp)

                if first is None:
                    first = temp
                if type(temp) is not int:
                    all_eq = False
                    # print("boo")
                elif all_eq:
                    all_eq = all_eq and first == temp
                    # print(f"yay -- all_eq: {all_eq}, temp: {temp}")
                
            if all_eq and not keep_format and first is not None:
                result = first
                # print(f"yay: {result}")
            else:
                result = tuple(result)
                # print("boo")
        else:
            assert not skip_test # don't be adjusting the whole dataset, only the training data
            result = func(datasets, *args, **kwargs)
        
        if save:
            self.datasets = result
        else:
            return result

    def apply_all(self, funcs_plus_args):
        # funcs_plus_args is a list of tuples
            # Each tuple holds a function to apply, how to apply it, and arguments to pass to it
        # Below we explain the different parts, but this is what it might look like (this would not be 100% valid)
            # [(combine), (True, None, lambda x: x.bufferpref = "Outer", [], {}), (False, ((2, 0, 3), (3, 1, 0)), np.mean)]
        # The default format of each tuple is (use_self, flow_info, func, args, kwargs)
            # purposes:
                # use_self is a bool saying whether to use self as the first arg or to leave it off
                # flow_info is a tuple (or None) that holds information about what outputs from other funcs should be passed into it
                # func is the function to apply
                # args is a list to be entered into the function as *args
                # kwargs is a dict to be entered into the function as **kwargs
            # everything but func is not always necessary
            # on flow_info:
                # if it is None, it does not use other funcs' outputs
                # if it is a tuple, it will contain other tuples that each correspond to one argument
                # the format for if it is a tuple is (arg_index, source_index, return_index)
                    # all three are ints
                    # purposes:
                        # arg_index tells which parameter to pass the output into
                        # source_index tells which tuple in funcs_plus_args represents where to get the output
                        # return_index tells which output of the func it should pull from
                    # arg_index:
                        # must be unique
                        # causes argument with that position passed in args to be ignored
                            # if that position is represented in kwargs, that should be ignored too
                    # return_index is not necessary since most functions only return one output
        outputs = []

        # not all of the mentioned quality of life logic is here;
            # do not remove any arguments, and don't expect flow_info to override anything from kwargs
        for fpa in funcs_plus_args:
            assert len(fpa) == 5
            args = fpa[3]
            # if len(fpa) > 1 and fpa[1] is not None:
                # firstarg = fpa[1]
            for t in fpa[1]:
                if len(t) == 2:
                    i, j, = t
                    # k = 0
                    assert i <= len(args)
                    args[i] = outputs[j]
                else:
                    i, j, k = t
                    assert i <= len(args)
                    args[i] = outputs[j][k]
            if fpa[0]:
                args = [self] + args
            output = fpa[2](args, **fpa[4])
            outputs.append(output)
        return output

            # if len(fpa) == 1:
            #     fpa[0](firstarg)
            # else:
            #     fpa[0](firstarg, *fpa[1 : ])



    def save_data_and_labels(dataset_data_labels_list):
        dataset = deepcopy(dataset_data_labels_list[0])
        dataset.data = dataset_data_labels_list[1]
        dataset.labels = dataset_data_labels_list[2]
        return dataset

    def __get_inverse_and_counts_col(col):
        """Map each distinct value in col to (row info, occurrence count).

        np.unique supplies the distinct values, the inverse index array and
        the counts. The inverse array is passed through the external helper
        invert(), and the result is looked up once per distinct value.
        NOTE(review): invert() is defined outside this file; presumably it
        groups row positions by value -- confirm against HelperFunctions.
        """
        uniques, inverse, counts = np.unique(col, return_inverse = True, return_counts = True)
        rows_by_value = invert(inverse)
        # Reorder so the row info lines up with the order of the unique values.
        ordered_rows = tuple(rows_by_value[value] for value in uniques)
        return {value: (rows, count) for value, rows, count in zip(uniques, ordered_rows, counts)}

    def get_inverse_and_counts(dataset, is_argmax_format):
        """Summarise dataset.labels as {class: (row info, count)}.

        With is_argmax_format true the labels are treated as one column and
        summarised directly, so the keys are the label values that occur.
        Otherwise every label column is summarised separately and, for
        column i, the entry recorded for value 1 is kept under key i
        (defaulting to ([], 0) when the column contains no 1).
        """
        if is_argmax_format:
            return SKDatasetHandler.__get_inverse_and_counts_col(dataset.labels)

        per_column = np.apply_along_axis(SKDatasetHandler.__get_inverse_and_counts_col, 0, dataset.labels)
        return {i: d.get(1, ([], 0)) for i, d in enumerate(per_column)}



    def result(self):
        return self.datasets, self.inputnum

    def diff(self, *args, **kwargs):
        pass
        return self

    def remove_outliers(self, rem_type = None, rem_func = None, *args, **kwargs):
        pass
        return self

    def normalize(self, norm_type = None, norm_func = None, *args, **kwargs):
        pass
        return self

    def combine(self, comb_type = None, comb_func = None, *args, **kwargs):
        if comb_type is not None or comb_func is not None or args or kwargs:
            raise NotImplementedError(f"SKInputConverter.combine() has no implemented parameters")
        self.dataframe = self.dataframe.iloc[:,:self.inputnum].apply(np.linalg.norm).join(self.dataframe.iloc[:,self.inputnum:])
        self.dataframe.columns = pd.Index(np.arange(len(self.dataframe.columns) - self.inputnum + 1))
        self.inputnum = 1
        return self



    # dataset: JFSKAccelDataset, /, 
    # target_classes = ["Other", "Stationary"], rep_format = 'str', rep_func = np.mean, skip_repressed = True, apply_to_all = False, *args, **kwargs
    def repress_classes(self, target_classes = None, rep_format = 'str', rep_func = np.mean, skip_repressed = True, apply_to_all = False, *args, **kwargs):
        """Down-sample over-represented classes in every dataset.

        target_classes: classes to repress; None selects the "other" and
                        "stationary" tags from SKDescriptors.
        rep_format:     format of target_classes, forwarded to
                        SKDescriptors.format_classes().
        rep_func:       callable turning a list of class counts into the
                        target count (np.mean by default); *args / **kwargs
                        are forwarded to it.
        skip_repressed: leave the target classes out of the counts given to
                        rep_func.
        apply_to_all:   cap every class at the target count, not only the
                        target classes.

        For each adjusted class and each dataset, rows above the target
        count are chosen with a fixed random_state and deleted. The reduced
        datasets are saved to self.datasets. Returns None.

        Raises NotImplementedError when the classification type has more
        than one output or rep_func is not callable.

        Fixes over the previous version: deletions are accumulated per
        dataset (they used to be appended in one flat list, so only the
        first adjusted class was ever removed), and the target counts keep
        one entry per dataset even when they happen to be equal.
        """
        # NOTE: the logic here only works for one-hot vectors
        if not SKDescriptors.NUM_OF_OUTPUTS_PER_TYPE[self.classification] == 1:
            raise NotImplementedError(f"JFSKLoader.repress_classes() with repress_stationary=True may not be equipped to fairly sample stationary data for classification types with outputs that are not one-hot vectors.")

        if not callable(rep_func):
            raise NotImplementedError(f"JFSKLoader.repress_classes()'s rep_func must currently be a callable (function)\
                                      \n\tIf you want more complex logic where you'd test for a string or something, feel free to alter the code")

        if target_classes is None:
            target_classes = [SKDescriptors.OTHER_TAG, SKDescriptors.STATIONARY_TAG]
            rep_format = 'tag'
        target_classes = SKDescriptors.format_classes(self.classification, target_classes, rep_format)
        if not len(target_classes) and not apply_to_all:
            return

        # Work on private copies so nothing is changed until the final save.
        repressed_dataset = self.apply(lambda ds: JFAccelDataset(deepcopy(ds.data), deepcopy(ds.labels), ds.sequence_length), None, True, False)
        all_classes = range(SKDescriptors.NUM_OF_CLASSES_PER_TYPE[self.classification])

        adjusted = all_classes if apply_to_all else target_classes
        inverse_and_counts = self.apply(SKDatasetHandler.get_inverse_and_counts, repressed_dataset, True, False, True, self.is_argmax_format)
        accounted = [c for c in all_classes if c not in (target_classes * int(skip_repressed))]
        class_indices = self.apply(lambda ds: [ds.get(v, (0, 0))[0] for v in all_classes], inverse_and_counts, True)
        class_counts = self.apply(lambda ds: [ds.get(v, (0, 0))[1] for v in accounted], inverse_and_counts, True)
        # keep_format=True: equal int results must not collapse into one
        # scalar, because the heights are zipped per dataset below.
        adjustment_height = self.apply(lambda row: int(rep_func(row, *args, **kwargs)), class_counts, True, False, True)

        # Each helper receives one zipped (left, right) pair per dataset.
        take_del_sample = lambda lz_col_ah: sample_without_replacement(len(lz_col_ah[0]), max(0, len(lz_col_ah[0]) - lz_col_ah[1]), random_state = 42)
        apply_to_nonempty = lambda lz_ds: list(np.array(lz_ds[0])[lz_ds[1]] if lz_ds[1].size > 0 else lz_ds[1])
        delete_y_from_x_data = lambda lz_ds: np.delete(lz_ds[0].data, lz_ds[1], axis = 0)
        delete_y_from_x_labels = lambda lz_ds: np.delete(lz_ds[0].labels, lz_ds[1], axis = 0)

        # One growing list of rows to delete per dataset.
        del_rows_indexed = [[] for _ in repressed_dataset]
        for i in adjusted:
            class_i_rows_indexed = self.apply(lambda row, i=i: row[i], class_indices, True)
            # Positions within the class's rows that exceed the target count...
            i_del_rows_not_indexed = self.apply(take_del_sample, list(zip(class_i_rows_indexed, adjustment_height)), True, False, True)
            # ...translated back into row numbers of the dataset.
            i_del_rows_indexed = self.apply(apply_to_nonempty, list(zip(class_i_rows_indexed, i_del_rows_not_indexed)), True)
            for dataset_rows, class_rows in zip(del_rows_indexed, i_del_rows_indexed):
                dataset_rows.extend(class_rows)

        newdata = self.apply(delete_y_from_x_data, list(zip(repressed_dataset, del_rows_indexed)), True)
        newlabels = self.apply(delete_y_from_x_labels, list(zip(repressed_dataset, del_rows_indexed)), True)
        self.apply(SKDatasetHandler.save_data_and_labels, list(zip(repressed_dataset, newdata, newlabels)), True, True, True)

In [None]:
class JFSKLoader:
    def __init__(self, file_path, sequence_length, repress_classes = True, feedback = None, *args, **kwargs):
        """Load a labeled data file and prepare train/test datasets.

        file_path:       path of the data file; only ".csv" is supported.
        sequence_length: window length used to group samples; grouping is
                         skipped when it is None or not positive.
        repress_classes: when true, the class-repressed state is selected as
                         the source of the first dataloader slot.
        feedback:        stored as self.feedback; None stores False but still
                         asks the file-name reader for feedback.
        *args, **kwargs: forwarded to SKDatasetHandler.repress_classes().

        Raises NotImplementedError for unsupported file extensions.
        """
        if feedback is None:
            # Default: quiet afterwards, but verbose while reading the file name.
            self.feedback = False
            feedback = True
        else:
            self.feedback = feedback
        # 1. open file

        # Gather file info (SKFileNameHandler and SKDescriptors are defined
        # outside this file)
        file_info, self.specifier_values = SKFileNameHandler.read_data_file_name(file_path, feedback)
        file_extension = file_info["File Extension"]
        classification_type = self.specifier_values[SKDescriptors.CLASSIFICATION_TYPE_FS]
        input_num = SKDescriptors.NUM_OF_INPUTS_PER_TYPE[classification_type]

        match file_extension:
            case ".csv":
                # read_csv defaults: the first row is taken as the header
                dataframe = pd.read_csv(file_path)
                # code test file: Data/Week 1/Left then Right/Processed/Type3-Freq10-Labeled_Motion-sessions_2023-08-26_17-25-54.csv
                # classifier training file: Data/COMBINED_Type3-Freq10-Labeled_Motion-sessions_23-24_Fall.csv
            case _:
                raise NotImplementedError(f"JFSKLoader is not equipped to open {file_extension} files.")


        # 2. split dataset into data and labels

        # The first input_num columns are inputs, everything after is labels
        data = dataframe.iloc[:, :input_num].to_numpy()  # x, y, z data
        labels = dataframe.iloc[:, input_num:].to_numpy()  # labels


        # 3. group samples into sliding windows
        self.sequence_length = sequence_length
        g_dataset = JFAccelDataset(data, labels, sequence_length)
        if sequence_length is not None and sequence_length > 0:
            g_dataset.group()


        # 4. randomize and prepare: fixed-seed 80/20 train/test split
        data_train, data_test, labels_train, labels_test = train_test_split(g_dataset.data, g_dataset.labels, test_size=0.2, random_state=42)

        # Convert to Dataset
        train_dataset = JFAccelDataset(data_train, labels_train, self.sequence_length)
        test_dataset = JFAccelDataset(data_test, labels_test, self.sequence_length)

        # Conversion to tensors (and argmax of the labels) happens later,
        # when the dataloaders are created.


        # 5. make adjustments related to the data
        # Constructing the handler logs the unchanged datasets as state 0.
        self.handler = SKDatasetHandler((train_dataset, test_dataset), classification_type, SKDatasetHandler.NDARRAY_JFAD_PAIR_DTYPE, False, "Inner", \
                                        description = "Type 3, Unchanged")
        # Cache of dataloader states, keyed by handler state index
        self.__dataloader_dict = {}
        # uses the more general format_index (module-level helper, not the method)
        self.current_dataloaders = format_index(0, SKDatasetHandler.NDARRAY_JFAD_PAIR_DTYPE[1])
        self.__selected_dataset_types = None
        # THIS is where we would use SKInputConverter

        # 6. make adjustments related to the labels
        # The repressed datasets are always built and logged (expected to be
        # state 1); the repress_classes flag only decides whether they are used.
        self.handler.repress_classes(*args, **kwargs)
        self.handler.description = "Type 3, StatAndOth Repressed"
        rep_index = self.handler.log_state(True, True, False)
        assert rep_index == 1, "It's okay if you create more datasets in JFSKLoader, but make sure self.current_dataloaders is changed accordingly"
        if repress_classes:
            # First slot from the repressed state, second slot from state 0
            self.current_dataloaders = (rep_index, 0)


        # 7. dataloaders are created lazily

        # get_dataloaders() builds them on first use, pulling from
            # self.handler if there's no corresponding entry in self.__dataloader_dict
            



    def __iter__(self):
        """Iterate over the batches of the selected dataloaders, one loader after another.

        The dataloaders come from get_dataloaders() for the current indices,
        restricted to the dataset types chosen through select().
        """
        loaders = self.get_dataloaders(None, self.__selected_dataset_types, None, self.feedback)
        return chain.from_iterable(loaders)



    def select(self, dataset_types = None):
        """Remember which dataset types iteration should cover and return self.

        dataset_types is stored as given (None means no restriction) and is
        read by __iter__; returning self allows `for batch in loader.select(...)`.
        """
        self.__selected_dataset_types = dataset_types
        return self
    


    def create_dataloader(dataset, input_is_argmax = False, batch_size = 64, shuffle = True):
        """Build a DataLoader over a tensor copy of dataset.

        dataset:         object with .data, .labels and .sequence_length; it
                         is not modified.
        input_is_argmax: when false, the labels are reduced with argmax over
                         dimension 1 (one row of class scores per sample).
        batch_size:      batch size of the DataLoader (64 as before).
        shuffle:         whether the DataLoader shuffles (True as before).

        The data is converted through a single NumPy array first: building a
        tensor directly from a list of arrays is very slow, and torch.tensor
        copies its input anyway, so no separate deep copy is needed.
        """
        data = torch.tensor(np.asarray(dataset.data), dtype=torch.float32)
        labels = torch.tensor(np.asarray(dataset.labels), dtype=torch.float32)
        if not input_is_argmax:
            labels = torch.argmax(labels, dim=1)
        tensor_dataset = JFAccelDataset(data, labels, dataset.sequence_length)
        return DataLoader(tensor_dataset, batch_size=batch_size, shuffle=shuffle)



    def format_index(self, index):
        """Expand index to the shape required by the format of the state it names.

        The first element reported by iterable_info(index) is used to find
        the state, first in the dataloader cache and then in the handler.
        The actual formatting is delegated to the module-level format_index
        helper, with the current dataloader indices as reference.
        """
        probe = iterable_info(index)[0][0]
        state = self.__dataloader_dict.get(probe, None)
        if state is None:
            state = self.handler.get_state(probe, False)
        if state is None:
            if index is None:
                return self.current_dataloaders
            print(f"WARNING: index {index} does not correspond to any valid state")
            return index
        # this is a different, more general format_index
        return format_index(index, state["ds_format"][1], self.current_dataloaders)



    # def use_dc(self, func, *args, **kwargs):
    #     try:
    #         val = func(self.handler, *args, **kwargs)
    #     except Exception as e:
    #         print(f"WARNING: Encountered Exception in JFSKLoader.use_dc({", ".join([func.__name__, *args, repr(kwargs)])}), see error below: \n{e}")
    #         val = None
    #     return val
    

    # indices = <current>, dataset_types = <all>, feedback = self.feedback
    # @test_return(None, "Returning dataloaders: ")
    def get_dataloaders(self, indices = None, dataset_types = None, return_indices = False, feedback = None):
        """Return dataloaders for the requested handler states.

        indices:        handler state index, or one index per dataset slot;
                        None uses self.current_dataloaders. With several
                        indices, slot i is taken from the state at indices[i].
        dataset_types:  None for every slot, otherwise a type name or a
                        collection of names (e.g. "Train") to keep.
        return_indices: when true and dataset_types is given, return
                        (dataloader, slot position) pairs instead of bare
                        dataloaders.
        feedback:       print progress; None uses self.feedback.

        Dataloaders are built on first use and cached per state index. The
        indices used and dataset_types are remembered on the loader.

        Raises TypeError when the requested states differ in format size.

        Fixes over the previous version: the filtered return no longer fails
        while unpacking, return_indices=False (the default) now yields bare
        dataloaders just like None, and filtering also works when a single
        state was requested.
        """
        if feedback is None:
            feedback = self.feedback
        ds_format = None
        ret_dataloaders = []
        ret_state = {}
        state_slices = []

        if indices is None:
            indices = self.current_dataloaders
        indices = list(iterable_info(indices)[0])
        ind_is_int = len(indices) == 1
        for i, index in enumerate(indices):
            state = self.__dataloader_dict.get(index, None)
            if state is None:
                # Not cached yet: fetch the state from the handler.
                try:
                    if index is None:
                        index = self.handler.log_state(True, True, feedback)
                    state = self.handler.get_state(index, feedback)
                except IndexError as ie:
                    print(f"WARNING: encountered invalid index {index} based on index {i} in {indices} in JFSKLoader.get_dataloader(), see error below: \n{ie}")
                    try:
                        if index is None:
                            raise ie
                        # Fall back to logging and using the current state.
                        index = self.handler.log_state(True, True, feedback)
                        state = self.handler.get_state(index, feedback)
                    except Exception as e:
                        raise type(e)(f"ERROR: Failed to access any dataloader, see error below: \n{e}") from e
                state = deepcopy(state)
                datasets = state.pop("datasets")
                if state["ds_format"] not in SKDatasetHandler.VALID_DATASET_DATATYPES:
                    print(f"WARNING: '{state['ds_format']}' is invalid dataset datatype. Reloading same data.")
                    loop_dataloaders = self.get_dataloaders(None, None, False, feedback)
                else:
                    loop_dataloaders = self.handler.apply(JFSKLoader.create_dataloader, datasets, False, False)
                    state["dataloaders"] = loop_dataloaders
                    state["is_argmax_format"] = True
                    self.__dataloader_dict[index] = state
                if ds_format is None:
                    ds_format = state["ds_format"][1]
                elif ds_format != state["ds_format"][1]:
                    raise TypeError(f"ERROR: index {index} has a different dataset format size to index {indices[0]}'s: {state['ds_format'][1]} versus {ds_format}\n\
                                    Index {indices[0]}'s dataset description: {self.__dataloader_dict.get(indices[0],{}).get('description','')}\n\
                                    Index {index}'s dataset description: {self.__dataloader_dict.get(index,{}).get('description','')}")
            else:
                loop_dataloaders = state["dataloaders"]

            if ind_is_int or state["ds_format"][1] in (1, None):
                # One state supplies every slot.
                indices = (index,)
                ret_dataloaders = loop_dataloaders
                ret_state = state
            else:
                # Slot i comes from the state named at position i.
                indices[i] = index
                ret_dataloaders.append(loop_dataloaders[i])
                state_slices.append(SKDatasetHandler.get_state_slice(state, i))

        if not ret_state:
            ret_state["ds_format"] = state_slices[0]["ds_format"]

        indices = tuple(indices)
        if len(indices) == 1:
            indices = self.format_index(indices)
        self.current_dataloaders = indices
        self.__selected_dataset_types = dataset_types

        if ret_state["ds_format"][1] is None:
            return ret_dataloaders
        ret_dataloaders = tuple(ret_dataloaders)
        if dataset_types is None:
            return ret_dataloaders

        if not state_slices:
            # A single state supplied every slot; derive the per-slot info
            # from it so the type filter has something to compare against.
            state_slices = [SKDatasetHandler.get_state_slice(ret_state, k) for k in range(len(ret_dataloaders))]

        def _matches(slot_type):
            if type(dataset_types) is str:
                return slot_type == dataset_types
            return slot_type in dataset_types

        selected = [(dl, i) for i, (dl, dst) in enumerate(zip(ret_dataloaders, state_slices)) \
                    if _matches(dst["dataset_types"])]
        if not return_indices:
            return tuple(dl for dl, _ in selected)
        return tuple(selected)
            

In [None]:
class SKLabelConverter:
    """Validate and apply label-type / buffer-number conversions on labeled data.

    Usage implied by the methods of this class: build the converter from a
    labeled data file name (or from an already-parsed specifier mapping),
    register the wanted conversion with ``set_label_type_conversion`` /
    ``set_buffer_num_conversion``, then call ``convert_label_type``.
    """

    # Presumably the allowed (input_label_type, output_label_type) pairs.
    # NOTE(review): no method visible in this part of the file reads this tuple
    # (validate_label_type_conversion relies on SKDescriptors.get_superclass_dict
    # instead) -- confirm whether it is still used anywhere.
    VALID_TYPE_CONVERSIONS = (
        (3, 5),
    )
    # (input_label_type, output_label_type) pairs checked by
    # validate_label_type_conversion(): a listed pair is accepted even when the
    # superclass-dict logic check fails. It does NOT bypass the consistency checks
    # (input file match, valid input/output types).
    OVERRIDE_TYPE_VALIDATION = (
        # This is here only for conversions that
            # fail validate_label_type_conversion() but not for simple reasons
            # (simple reasons like accidentally choosing the wrong input/output types
            # or not listing the correct values in the SKDescriptors type dictionaries)
        # If you add an entry here you may have to change logic of other parts of the code
            # for instance if the number of columns of the output spreadsheet will be more
            # than the input spreadsheet, you may have to change the dataframe.drop line at the end
    )



    # Currently these do nothing. If we later change how we want the buffers to function
        # (not the length of the buffers but which buffers overlap into other classes),
        # we will be able to do so using this
    VALID_BUFFER_TYPE_CONVERSIONS = ()
    OVERRIDE_BUFFER_TYPE_VALIDATION = (
        # This is here only for BufferType conversions that
            # fail buffer-type validation but not for simple reasons
            # (simple reasons like accidentally choosing the wrong input/output types
            # or not listing the correct values in the above tuples)
        # NOTE(review): validate_buffer_num_conversion() does not consult this tuple,
            # and no buffer-type validator is visible in this part of the file.
        # If you add an entry here you may have to change logic of other parts of the code
    )



    def __init__(self, labeled_data_file = None, *args):
        # if not all(n == Converter.NUM_OF_LABEL_TYPES for n in (len(Converter.NUM_OF_INPUTS_PER_TYPE), len(Converter.NUM_OF_CLASSES_PER_TYPE), len(Converter.NUM_OF_OUTPUTS_PER_TYPE))):
        #     print("Converter is not usable if defining dictionaries do not match corresponding dictionaries in size.")
        #     print("Fix and rerun the code to use the converter")
        #     return
        if labeled_data_file is not None:
            self.input_file_info, self.input_specifiers = SKFileNameHandler.read_data_file_name(labeled_data_file)
            self.output_label_type = -1
            self.output_freq = -1
            self.output_buffer_type = -1
            self.output_buffer_num = -1
        else:
            self.input_file_info = None
            self.input_specifiers = args[0]
            self.output_label_type = -1
            self.output_freq = -1
            self.output_buffer_type = -1
            self.output_buffer_num = -1

        self.type_validated = False
        self.buffer_num_validated = False



    def validate_label_type_conversion(self, input_label_type, output_label_type):
        """Report whether converting input_label_type -> output_label_type is allowed.

        The conversion is accepted when the values are consistent (the input type
        equals the one stored in the input specifiers and both types pass
        SKDescriptors.validate_class_type) AND the conversion logic is valid
        (SKDescriptors.get_superclass_dict does not raise AssertionError) or the
        pair is listed in OVERRIDE_TYPE_VALIDATION.

        Returns:
            The truth value of the combined check.
        """
        specifiers = self.input_specifiers

        # --- consistency of the requested types ---
        # the requested input type must be the one recorded for the input file
        file_type_agrees = specifiers.get(SKDescriptors.CLASSIFICATION_TYPE_FS, -1) == input_label_type
        # the input type must be valid together with the file's with-class-number value
        class_num_flag = specifiers.get(SKDescriptors.WITH_CLASS_NUMBER_FS, -1)
        input_type_ok = SKDescriptors.validate_class_type(input_label_type, class_num_flag, False)
        # the output type must be valid on its own
        output_type_ok = SKDescriptors.validate_class_type(output_label_type, 0, False)
        values_consistent = file_type_agrees and input_type_ok and output_type_ok

        # --- conversion logic ---
        # building the superclass dict is used purely as a probe; its result is discarded
        try:
            SKDescriptors.get_superclass_dict(input_label_type, output_label_type)
        except AssertionError:
            print("WARNING: if you start doing logic that requires work without one-hot vector outputs, you may want to consider handling this invalidation differently")
            logic_ok = False
        else:
            logic_ok = True

        # --- manual override ---
        # lets through more-complex conversions that were listed by hand; it never
        # rescues a conversion whose file or types are inconsistent
        overridden = (input_label_type, output_label_type) in SKLabelConverter.OVERRIDE_TYPE_VALIDATION

        return values_consistent and (logic_ok or overridden)


    def validate_buffer_num_conversion(self, input_buffer_num, output_buffer_num):
        """Report whether converting input_buffer_num -> output_buffer_num is allowed.

        All three conditions must hold: both buffer numbers are non-negative,
        input_buffer_num equals the buffer number stored in the input specifiers,
        and the stored buffer type lies in 1..SKDescriptors.NUM_OF_BUFFER_TYPES.

        Returns:
            The truth value of the combined check.
        """
        # neither buffer number may be negative
        nums_nonnegative = input_buffer_num >= 0 and output_buffer_num >= 0
        # the requested input buffer number must be the one recorded for the input file
        file_num_agrees = self.input_specifiers.get(SKDescriptors.BUFFER_NUMBER_FS, -1) == input_buffer_num
        # the recorded buffer type must be one of the known buffer types (1-based)
        file_buffer_type = self.input_specifiers.get(SKDescriptors.BUFFER_TYPE_FS, 0)
        known_buffer_types = np.arange(1, SKDescriptors.NUM_OF_BUFFER_TYPES + 1)
        buffer_type_known = file_buffer_type in known_buffer_types

        return nums_nonnegative and file_num_agrees and buffer_type_known


    def set_label_type_conversion(self, input_label_type, output_label_type):
        #NOTE that types have not yet been implemented as tuple labels
        if(not self.validate_label_type_conversion(input_label_type, output_label_type)):
            print(f"Current object/class definitions prohibit the conversion from Type {input_label_type} to Type {output_label_type}.")
            self.type_validated = False
            return
        #self.input_label_type = input_label_type
        self.output_label_type = output_label_type
        self.type_validated = True


    def set_buffer_num_conversion(self, input_buffer_num, output_buffer_num):
        if(not self.validate_buffer_num_conversion(input_buffer_num, output_buffer_num)):
            print(f"Current object/class definitions prohibit the conversion from BufferNum {input_buffer_num} to BufferNum {output_buffer_num}.")
            self.buffer_num_validated = False
            return
        #self.input_buffer_num = input_buffer_num
        self.output_buffer_num = output_buffer_num
        self.buffer_num_validated = True



# df_datatype = torch.Tensor, 
    def convert_label_type(self, input_dataframe, is_argmax_format = False, to_file = False, labels_only = True, feedback = False):
        if not self.type_validated:
            raise AssertionError("Yeah, no. You need to set the label type successfully before trying any conversions")

        input_label_type = self.input_specifiers.get(SKDescriptors.CLASSIFICATION_TYPE_FS, -1)
        df_datatype = type(input_dataframe)


        base_constr = lambda dataframe_like = None, size = 0: df_datatype(dataframe_like) if dataframe_like is not None else df_datatype(np.empty((size, 0)))
        # for indices_info, you must use iterable_info(indices)
        sc_inner = lambda dataframe, indices_info: np.expand_dims(dataframe[:, indices_info[0][0]], 1) if indices_info[2] == 1 else base_constr(np.concatenate([np.expand_dims(dataframe[:, i], 1) for i in indices_info[0]], axis = 1))
        merge_cols = lambda cols: base_constr(np.expand_dims(np.sum(cols, axis = 1), 1))
        append_col = lambda dataframe, col: base_constr(np.concatenate((dataframe, col), axis = 1))

        match df_datatype:
            case np.ndarray:
                base_constr = lambda dataframe_like = None, size = 0: np.array(dataframe_like) if dataframe_like is not None else np.empty((size, 0))
                # for indices_info, you must use iterable_info(indices)
                sc_inner = lambda dataframe, indices_info: np.expand_dims(dataframe[:, indices_info[0][0]], 1) if indices_info[2] == 1 else np.concatenate([np.expand_dims(dataframe[:, i], 1) for i in indices_info[0]], axis = 1)
                merge_cols = lambda cols: np.expand_dims(np.sum(cols, axis = 1), 1)
                append_col = lambda dataframe, col: np.concatenate((dataframe, col), axis = 1)
            case torch.Tensor:
                # copy = lambda dataframe_like: dataframe_like.clone().detach().requires_grad_(dataframe_like.requires_grad)
                base_constr = lambda dataframe_like = None, size = 0: torch.from_numpy(dataframe_like) if dataframe_like is not None else torch.empty((size, 0), dtype=input_dataframe.dtype, layout=input_dataframe.layout, requires_grad=input_dataframe.requires_grad)
                # for indices_info, you must use iterable_info(indices)
                sc_inner = lambda dataframe, indices_info: dataframe[:, indices_info[0][0]].unsqueeze(1) if indices_info[2] == 1 else torch.cat([dataframe[:, i].unsqueeze(1) for i in indices_info[0]], dim = 1)
                merge_cols = lambda cols: torch.sum(cols, dim = 1).unsqueeze(1)
                append_col = lambda dataframe, col: torch.cat((dataframe, col), dim = 1)
            case _ :
                # raise NotImplementedError(f"SKLabelConverter.convert_label_type is not yet equipped to handle '{(str)(df_datatype)}'.\
                #                           \n\t\t\tTry using 'torch.Tensor' or implement logic for a different type")
                print(f"WARNING: SKLabelConverter.convert_label_type is not specifically equipped to handle ndarrays/dataframes of type '{(str)(df_datatype)}'.\
                      \n\tEverything should still work, but results may be slower.\
                      \n\tTry using dataframes of type 'np.ndarray' or 'torch.Tensor' or implement logic for a different type.")
        select_cols = lambda dataframe, indices: sc_inner(dataframe, iterable_info(indices))


        if labels_only:
            output_dataframe = base_constr(size = np.shape(input_dataframe)[0])
        else:
            input_num = SKDescriptors.NUM_OF_INPUTS_PER_TYPE[input_label_type]
            output_dataframe = select_cols(input_dataframe, slice(input_num))
            input_dataframe = select_cols(input_dataframe, slice(input_num, input_num + SKDescriptors.NUM_OF_OUTPUTS_PER_TYPE[input_label_type]))


        if is_argmax_format:
            superclasses = SKDescriptors.get_superclass_dict(input_label_type, self.output_label_type)
            output_dataframe = base_constr(np.vectorize(superclasses.__getitem__)(input_dataframe))

            # a, b = np.unique(np.array())
            # output_dataframe = 

            # output_dataframe = np.ndarray(input_dataframe.shape)

            # output_dataframe = input_dataframe
            # for i in range(len(input_dataframe)):
            #     print(len(input_dataframe))
            #     print(i)
            #     print(input_dataframe)
            #     print(input_dataframe[i])
            #     print("yippee")
            #     output_dataframe[i] = superclasses[input_dataframe[i]]
            if feedback:
                print(f"Label conversion from type {input_label_type} to type {self.output_label_type} by changing output column num")
            return output_dataframe
        

        subclasses = SKDescriptors.get_subclass_dict(input_label_type, self.output_label_type)
        # print(superclasses)
        print(subclasses)
        # this is what converts the data
        # if you want custom conversion logic, you might want to start reworking earlier parts of the function
            # and leave this here since this is highly abstract and allows for most needed conversions
        if feedback:
            print(f"Label conversion from type {input_label_type} to type {self.output_label_type}:")
        for i, l in subclasses.items():
            # col_inds = [t[0] for t in l]
            if not l:
                raise AssertionError("Not only did we not find any superclasses, but we failed to detect such with our first assertion test. You will probably need to do some serious searching for this bug.")
            if feedback:
                print(f"\tNew class {i} replaces old classes {l}")
            old_cols = select_cols(input_dataframe, l)
            if len(l) > 1:
                new_col = merge_cols(old_cols)
            else:
                new_col = old_cols
            output_dataframe = append_col(output_dataframe, new_col)


        if to_file:
            raise NotImplementedError("Saving altered data labels to a file does not work currently. Reimplementation of this will hopefully be easy, but it is not a priority at the moment of writing this")
        else:
            return output_dataframe