In [None]:
# following code from: https://www.analyticsvidhya.com/blog/2020/08/top-4-pre-trained-models-for-image-classification-with-python-code/
!pip install csiread

from google.colab import drive

import tensorflow as tf
import keras
import os
import os.path
from keras import layers, optimizers, regularizers, models
import numpy as np
import csiread

import re
import matplotlib.pyplot as plt
from matplotlib import colors

# used for confusion matrix
import seaborn as sn
import pandas as pd

# used for kfold cross validation
import sklearn
import sklearn.model_selection

# for hyperparameter search
import itertools
import contextlib

# for saving the data
import json

In [None]:
# mount script

drive.mount('/Drive')
!ln -s '/Drive/MyDrive/google_colab_files_for_CSI' '/content/REU'
root_data_path = os.path.join('REU','csi_data_all')
REU = os.path.join('REU')

In [None]:
# visualization funcs for easy debugging
def subcarLineChart(gph, log = False, save_path = '', plot_title="Magnitude of Signal by Channel and Packet Number"): # log: use log scale or not
    """Draw one thin line per column of a 2-D array against the row index.

    Args:
        gph: 2-D array; rows go on the x axis ("Packet #"), every column
            becomes its own line.
        log: when truthy, the y axis uses a logarithmic scale.
        save_path: if non-empty the figure is written there, otherwise it is
            shown interactively.
        plot_title: title placed above the axes.

    The figure is always closed before returning.
    """
    packet_count, subcarrier_count = gph.shape[0], gph.shape[1]

    figure = plt.figure()
    axes = figure.add_subplot(111)
    plt.title(plot_title)
    plt.xlabel("Packet #")
    plt.ylabel("Magnitude")
    if log:
        axes.set_yscale("log")

    for column in range(subcarrier_count):
        # np.take with a list index keeps a trailing axis of length 1
        trace = np.take(gph, [column], axis = 1)
        axes.plot(range(packet_count), trace, lw = .25, label = column)

    if save_path == '':
        plt.show()
    else:
        plt.savefig(save_path)
    plt.close()

def subcarSpectrogram(gph, log = False, subcar_start = 0, subcar_end = -1, save_path=""): # log: use log scale or not
    """Render a 2-D array as a heat map (columns on x, rows on y, row 0 at the top).

    Args:
        gph: 2-D array to display.
        log: when truthy, colours are mapped through a logarithmic norm.
        subcar_start: value shown at the left edge of the x axis.
        subcar_end: value shown at the right edge; -1 means "number of columns".
        save_path: if non-empty the figure is written there, otherwise it is
            shown interactively.

    The current figure is always closed before returning.
    """
    if subcar_end == -1:
        # -1 is a sentinel for "through the last column"
        subcar_end = gph.shape[1]

    # extent is (left, right, bottom, top); bottom > top puts row 0 at the top
    axis_extent = (subcar_start, subcar_end, gph.shape[0], 0)
    color_norm = colors.LogNorm() if log else None

    plt.imshow(gph, extent = axis_extent, norm = color_norm, interpolation='none')
    plt.xlabel('Subcarriers')
    plt.ylabel('Packet #')
    plt.colorbar(label = 'intensity')

    if save_path == "":
        plt.show()
    else:
        plt.savefig(save_path)
    plt.close()

# preprocessing if needed

# some funcs for preprocessing
def subcar_anomaly_removal(subcar, std_threshold):
  """Replace outliers in a 1-D series with the series median.

  A value is an outlier when its z-score (distance from the mean in units of
  the standard deviation) exceeds ``std_threshold`` in absolute value.

  Args:
    subcar: 1-D array-like of numbers; it is not modified.
    std_threshold: z-score magnitude above which a value is replaced.

  Returns:
    A new array of the same length with outliers replaced by the median of
    the original series.
  """
  values = np.asarray(subcar)
  cleaned = np.copy(values)
  if values.size == 0:
    return cleaned

  spread = np.std(values)
  if spread == 0:
    # Constant series: nothing can be an outlier, and dividing by a zero
    # standard deviation would only produce NaNs plus a RuntimeWarning.
    return cleaned

  z_scores = (values - np.mean(values)) / spread
  cleaned[np.abs(z_scores) > std_threshold] = np.median(values)
  return cleaned

def df_anomaly_removal(df, std_threshold):
  """Run subcar_anomaly_removal on every axis-0 slice (column) of an array.

  Args:
    df: array whose columns are cleaned independently.
    std_threshold: forwarded to subcar_anomaly_removal for every column.

  Returns:
    The array assembled by np.apply_along_axis from the cleaned columns.
  """
  def _clean_column(column):
    return subcar_anomaly_removal(column, std_threshold)

  return np.apply_along_axis(_clean_column, 0, df)

## filters ##
def mov_avg(subcar, width):
  """Moving average of a 1-D series over windows of ``width`` elements.

  Only fully overlapping windows are kept ('valid' convolution), so the
  result is shorter than the input by ``width - 1``.
  """
  window = np.ones(width)
  window_totals = np.convolve(subcar, window, mode = 'valid')
  return window_totals / width

def gaussian_filter(size, sigma):
  """Build a normalised 1-D Gaussian kernel.

  Args:
    size: number of taps; must be odd so the kernel has a centre sample.
    sigma: standard deviation of the Gaussian, in samples.

  Returns:
    Array of ``size`` weights, centred on zero offset and summing to 1.

  Raises:
    ValueError: if ``size`` is even.
  """
  if size % 2 == 0:
      raise ValueError("Size must be an odd integer.")

  half_width = size // 2
  offsets = np.linspace(-half_width, half_width, size)
  weights = np.exp(-offsets**2 / (2 * sigma**2))
  return weights / np.sum(weights)  # normalise so the weights sum to 1

# use this to apply a filter to an array (1d)
def apply_filter(array, filter):
    """Convolve a 1-D array with a 1-D kernel after zero-padding both ends.

    The array is padded with ``len(filter) // 2`` zeros on each side and then
    convolved in 'valid' mode.

    Raises:
        ValueError: if either argument is not 1-dimensional.
    """
    checks = (
        (array, "Input array must be 1-dimensional."),
        (filter, "Filter must be 1-dimensional."),
    )
    for operand, message in checks:
        if len(operand.shape) != 1:
            raise ValueError(message)

    half_width = len(filter) // 2
    zero_padded = np.pad(array, half_width, mode='constant')
    return np.convolve(zero_padded, filter, mode='valid')

def rolling_difference_1d(arr, window_size):
    """Absolute difference between the last and first element of each window.

    Element ``i`` of the result is ``abs(arr[i + window_size - 1] - arr[i])``.
    No padding is applied, so the result has ``len(arr) - window_size + 1``
    elements and index ``i`` refers to the window that starts at ``i``.

    Args:
        arr: 1-D sequence of numbers.
        window_size: window length in elements; must be at least 1.

    Returns:
        Array of window differences; empty when the input is shorter than
        one window.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be at least 1.")

    values = np.asarray(arr)
    window_count = len(values) - window_size + 1
    if window_count <= 0:
        return np.array([])

    # One vectorised subtraction instead of a Python-level loop per window.
    return np.abs(values[window_size - 1:] - values[:window_count])

def rolling_sum(arr, window_size):
    """Sum of every fully overlapping window of ``window_size`` elements.

    Implemented as a 'valid' convolution with a window of ones.
    """
    unit_window = np.ones(window_size)
    return np.convolve(arr, unit_window, mode = 'valid')

def ash_trimming(csi, n_roll = 20, return_index = False):
    """Pick the ``n_roll`` consecutive rows of ``csi`` with the most change.

    Steps:
      1. along each column, take the absolute difference across 3-row windows;
      2. along each row of that result, add up the absolute differences
         between neighbouring columns, giving one activity value per row;
      3. sum the activity over sliding windows of ``n_roll`` rows and take the
         window with the largest total.

    The winning start index is found on the differenced series and is applied
    to ``csi`` as-is.

    Args:
        csi: 2-D array.
        n_roll: number of rows in the selected window.
        return_index: when truthy, return only the start index.

    Returns:
        The start index, or ``csi[start : start + n_roll]``.
    """
    def _change_over_rows(column):
        return rolling_difference_1d(column, 3)

    def _activity_of_row(row):
        return sum(rolling_difference_1d(row, 2))

    change = np.apply_along_axis(_change_over_rows, 0, csi)
    activity = np.apply_along_axis(_activity_of_row, 1, change)
    windowed_activity = rolling_sum(activity, n_roll)

    start_index = np.argmax(windowed_activity)
    if return_index:
        return start_index
    return csi[start_index : start_index + n_roll]

# background removal using singular value decomposition --> works with any 2d matrix
def svd_background_removal(data, num_sv = 1):
    """Subtract the leading singular components from a 2-D matrix.

    The "background" is the matrix rebuilt from the ``num_sv`` largest
    singular values. After subtracting it, the result is shifted so that its
    smallest entry is 0.

    Args:
        data: 2-D array.
        num_sv: number of leading singular components to remove.

    Returns:
        Array of the same shape as ``data`` with minimum value 0.
    """
    # Thin SVD: only the first num_sv columns of u are used, so there is no
    # need to build the full square u that the default computes.
    u, s, vh = np.linalg.svd(data, full_matrices = False)
    # Scaling the columns of u by s equals u @ diag(s) without the extra matrix.
    background = (u[:, :num_sv] * s[:num_sv]) @ vh[:num_sv, :]
    residual = data - background
    return residual - np.min(residual)  # shift such that the lowest value is 0

def simple_background_removal(data):
    """Subtract each column's mean from that column.

    Whatever is constant down a column is removed, leaving only the variation
    around the column mean.
    """
    column_means = np.mean(data, axis = 0)
    static_profile = column_means[np.newaxis, :]
    return data - static_profile

# actual script starts here
def clean(csi_matrix, drop = [], log = False, trim = True, remove_bg = False, window_size = 20, filter_size = 2, filter = '', bg_remove_first = False):
    """Turn one raw CSI matrix into a cleaned magnitude matrix.

    Processing order:
      1. delete the columns listed in ``drop``;
      2. take the element-wise magnitude;
      3. replace per-column outliers (more than 2 standard deviations away);
      4. optionally smooth every column (``filter``);
      5. optionally trim to the most active ``window_size`` rows and/or remove
         the SVD background, in the order chosen by ``bg_remove_first``;
      6. optionally take the natural log, with values <= 1 mapped to 0.

    Args:
        csi_matrix: 2-D array (columns are indexed by ``drop``).
        drop: column indices to delete; the default list is never mutated.
        log: apply step 6.
        trim: apply ash_trimming.
        remove_bg: apply svd_background_removal.
        window_size: number of rows kept by trimming.
        filter_size: window length for the chosen filter. 'gaussian' needs an
            odd size, so the default of 2 raises ValueError with that filter.
        filter: 'average', 'gaussian', or anything else for no smoothing.
        bg_remove_first: when truthy the background is removed before
            trimming, otherwise after.

    Returns:
        The processed 2-D array.
    """
    csi_matrix = np.delete(csi_matrix, drop, axis = 1)
    csi_matrix = np.abs(csi_matrix)
    csi_matrix = df_anomaly_removal(csi_matrix, 2)

    # Column-wise smoothing; an unrecognised filter name leaves the data as is.
    if filter == 'average':
        csi_matrix = np.apply_along_axis(lambda column: mov_avg(column, filter_size), 0, csi_matrix)
    elif filter == 'gaussian':
        kernel = gaussian_filter(filter_size, 1)
        csi_matrix = np.apply_along_axis(lambda column: apply_filter(column, kernel), 0, csi_matrix)

    # Background removal runs either before or after trimming, never both.
    if remove_bg and bg_remove_first:
        csi_matrix = svd_background_removal(csi_matrix)
    if trim:
        csi_matrix = ash_trimming(csi_matrix, n_roll = window_size)
    if remove_bg and not bg_remove_first:
        csi_matrix = svd_background_removal(csi_matrix)

    if log:
        # np.where evaluates np.log for every element, including the ones it
        # then replaces with 0 (zeros, negatives). Silence the warnings those
        # discarded values would raise; the result is unchanged.
        with np.errstate(divide = 'ignore', invalid = 'ignore'):
            csi_matrix = np.where(csi_matrix <= 1, 0, np.log(csi_matrix))

    return csi_matrix

def process_dataset_folder(path, trimmed = False, version = 3, create_graphs = False):
    """Read, clean and stack every capture found under ``path``.

    ``path`` is listed for sub-folders (one per gesture) and each sub-folder
    is listed for capture files, both in sorted order. Every file is read with
    csiread.Nexmon and passed through clean().

    Args:
        path: dataset folder containing one sub-folder per gesture.
        trimmed: forwarded to clean() as ``trim``.
        version: preprocessing version; 4 and 5 enable background removal and
            5 additionally removes the background before trimming.
        create_graphs: when truthy, save a spectrogram and a line chart for
            every sample under REU/plots/<dataset>_plots/<gesture>/.

    Returns:
        Array indexed as [gesture, sample, ...sample dimensions].

    Raises:
        ValueError: if ``path`` has no entries, a gesture folder has no files,
            or the cleaned samples do not all share one shape.
    """
    # Column (subcarrier) indices that clean() deletes from every capture.
    dropped_columns = [0, 1, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38]

    gesture_arrays = []
    for folder in sorted(os.listdir(path)):
        gesture_folder = os.path.join(path, folder)
        file_names = sorted(os.listdir(gesture_folder))

        gesture_samples = []
        for file in file_names:
            filepath = os.path.join(gesture_folder, file)
            csireader = csiread.Nexmon(filepath, chip='4339', bw=20)
            csireader.read()

            ### PREPROCESSING ADJUSTMENTS ###
            one_sample = clean(
                csireader.csi,
                log = False,
                drop = dropped_columns,
                remove_bg = (version in (4, 5)),
                bg_remove_first = (version == 5),
                trim = trimmed,
                filter = 'gaussian',
                filter_size = 3,
                window_size = 20,
            ) # <-- CHANGE THIS TO ADJUST PREPROCESSING
            #################################

            # PLOTS ARE SAVED IN THE "plots" FOLDER IN /REU
            if create_graphs:
                # Strip only a trailing ".pcap"; the unescaped pattern used
                # before matched any character followed by "pcap" anywhere.
                plot_name = re.sub(r'\.pcap$', '', file)
                dataset_name = os.path.basename(os.path.normpath(path))
                plot_location = os.path.join('REU','plots',f'{dataset_name}_plots', folder)
                os.makedirs(plot_location, exist_ok = True)
                subcarSpectrogram(one_sample, save_path = os.path.join(plot_location, f'{plot_name}_spectrogram'))
                subcarLineChart(one_sample, save_path = os.path.join(plot_location, f'{plot_name}_line'))

            gesture_samples.append(np.asarray(one_sample))

        if not gesture_samples:
            # An empty gesture folder would shift every later class index.
            raise ValueError(f"No capture files found in gesture folder: {gesture_folder}")

        # Collect in lists and stack once instead of growing arrays with np.append.
        one_gesture_data = np.stack(gesture_samples, axis = 0)
        shape_so_far = (len(gesture_arrays), *gesture_arrays[0].shape) if gesture_arrays else (0,)
        print(shape_so_far, one_gesture_data.shape, "Number of Samples:", len(file_names), gesture_folder)
        gesture_arrays.append(one_gesture_data)

    if not gesture_arrays:
        raise ValueError(f"No gesture folders found in dataset folder: {path}")

    return np.stack(gesture_arrays, axis = 0)

def read_csi(dataset_folder_path, dataset_numpy_path, trimmed = False, version = 3, use_cache = True, create_graphs = False):
    """Return a dataset array, from the .npy cache when allowed and present.

    Args:
        dataset_folder_path: folder handed to process_dataset_folder on a miss.
        dataset_numpy_path: location of the cached .npy file.
        trimmed, version, create_graphs: forwarded to process_dataset_folder.
        use_cache: when falsy the cache file is ignored even if it exists.

    This function never writes the cache; it only reads it.
    """
    cache_hit = use_cache and os.path.exists(dataset_numpy_path)
    if not cache_hit:
        return process_dataset_folder(
            dataset_folder_path,
            trimmed = trimmed,
            version = version,
            create_graphs = create_graphs,
        )
    return np.load(dataset_numpy_path)

# this code has non-functional parts that I didn't bother to change
def preprocess(dataset_nums, version = 3, trimmed = False, create_graphs = False):
    """Process one raw dataset folder and save the result as a .npy file.

    Args:
        dataset_nums: a single dataset number (anything ``int()`` accepts),
            despite the plural name.
        version: preprocessing version; 4 and 5 select their own output folder.
        trimmed: forwarded to the reader; untrimmed output goes to a folder
            with an ``_unchopped`` suffix.
        create_graphs: forwarded to the reader.

    Returns:
        The processed array, which is also written to
        REU/csi_data_all/<output folder>/csi_data_<n>.npy.

    Any existing cache file is ignored and overwritten.
    """
    data_root = os.path.join('REU','csi_data_all')
    dataset_id = int(dataset_nums)

    # Output folder name: csi_numpy[_v4|_v5][_unchopped]
    version_suffix = "_v4" if version == 4 else ('_v5' if version == 5 else '')
    trim_suffix = '' if trimmed else '_unchopped'
    output_dir = os.path.join(data_root, 'csi_numpy' + version_suffix + trim_suffix)
    os.makedirs(output_dir, exist_ok=True)

    print(f"processing dataset {dataset_id}")
    dataset_numpy_path = os.path.join(output_dir, f'csi_data_{dataset_id}.npy')
    print('saving in...', dataset_numpy_path)

    dataset_folder_path = os.path.join(data_root, f'csi_data_{dataset_id}')
    combined_csi_data = read_csi(
        dataset_folder_path,
        dataset_numpy_path,
        trimmed = trimmed,
        version = version,
        use_cache = False,  # always rebuild from the raw captures
        create_graphs = create_graphs,
    )
    np.save(dataset_numpy_path, combined_csi_data)
    return combined_csi_data

### loader function ###
# put this at the top of your model and you should be set!
# don't forget to call it

# load data
def npy_files_to_numpy(dataset_nums, preprocessing_version, from_chopped = True):
    """Load several preprocessed datasets and merge them along the sample axis.

    Missing .npy files are generated on the fly with preprocess(). Paths are
    built from the module-level ``root_data_path``.

    Args:
        dataset_nums: iterable of dataset numbers.
        preprocessing_version: 3, 4 or 5; selects the cache folder.
        from_chopped: when falsy the ``_unchopped`` variant of the folder is
            used.

    Returns:
        ``(merged, classes)``: the datasets concatenated on axis 1, and the
        sorted entry names of the raw folder of the last dataset in
        ``dataset_nums`` as an array.

    Raises:
        Exception: if ``preprocessing_version`` is not 3, 4 or 5.
    """
    known_folders = ((3, 'csi_numpy'), (4, 'csi_numpy_v4'), (5, 'csi_numpy_v5'))
    for known_version, folder_name in known_folders:
        if preprocessing_version == known_version:
            csi_numpy_folder = folder_name
            break
    else:
        raise Exception('Invalid preprocessing_version!!')

    if not from_chopped:
        csi_numpy_folder = f'{csi_numpy_folder}_unchopped'

    datasets = []
    for dataset_num in dataset_nums:
        dataset_path = os.path.join(root_data_path, csi_numpy_folder, f'csi_data_{dataset_num}.npy')

        # build the cache file first if it does not exist yet
        if not os.path.isfile(dataset_path):
            preprocess(dataset_num, version = preprocessing_version, trimmed = from_chopped)

        dataset = np.load(dataset_path)
        datasets.append(dataset)
        print("imported", dataset_path, dataset.shape)

    merged_datasets = np.concatenate(datasets, axis = 1)

    # class names come from the raw folder of the last dataset that was loaded
    raw_folder = os.path.join(root_data_path, f'csi_data_{dataset_num}')
    classes = sorted(os.listdir(raw_folder))
    print(f'dataset {dataset_num} classes: ', classes)

    return (merged_datasets, np.asarray(classes))

def single_dataset_from_flattened_numpy(samples, labels_numeric, batch_size, random_seed = int(np.random.rand() * 100)):
    """Shuffle samples and labels together and wrap them in a batched tf dataset.

    Args:
        samples: array-like with one sample per entry along axis 0.
        labels_numeric: numeric labels, one per sample.
        batch_size: batch size of the returned dataset.
        random_seed: seed of the shuffle. The default is computed once, when
            this function is defined, so every call that omits it reuses the
            same seed until the definition is executed again.

    Returns:
        A batched ``tf.data.Dataset`` of ``(sample, label)`` pairs.

    Raises:
        ValueError: if samples and labels differ in length.
    """
    samples = np.asarray(samples)
    labels_numeric = np.asarray(labels_numeric)
    if len(samples) != len(labels_numeric):
        raise ValueError("samples and labels_numeric must have the same length.")

    # One index permutation applied to both arrays keeps every sample with its
    # label by construction, instead of relying on two separately seeded
    # permutation calls happening to produce the same order.
    order = np.random.RandomState(random_seed).permutation(len(samples))

    return tf.data.Dataset.from_tensor_slices((samples[order], labels_numeric[order])).batch(batch_size)

def single_dataset_from_non_flat_numpy(merged_data, classes, batch_size, random_seed = int(np.random.rand() * 100), give_numpy = False): # only used for cross validation
    """Flatten a [gesture, sample, rows, cols] array into shuffled samples and labels.

    Every sample is labelled with the index of its gesture's class name in
    ``classes``. Inputs are not normalised here.

    Args:
        merged_data: 4-D array indexed as [gesture, sample, rows, cols].
        classes: class names, one per gesture, in gesture order.
        batch_size: batch size of the returned dataset (unused when
            ``give_numpy`` is truthy).
        random_seed: seed of the shuffle. The default is computed once, when
            this function is defined, so every call that omits it reuses the
            same seed until the definition is executed again.
        give_numpy: when truthy, return ``(samples, labels)`` as arrays
            instead of a tf dataset.

    Returns:
        ``(samples, labels)`` arrays, or a batched ``tf.data.Dataset``.
    """
    print(classes)
    print(merged_data.shape)

    # numeric label = position of the gesture's class name in `classes`
    label_to_index = {label: index for index, label in enumerate(classes)}
    labels = np.array([
        label_to_index[classes[gesture_index]]
        for gesture_index in range(len(merged_data))
        for _ in range(len(merged_data[gesture_index]))
    ])

    # flatten gesture and sample dimensions
    samples = merged_data.reshape(-1, merged_data.shape[2], merged_data.shape[3])
    print("number of samples:", len(samples))

    # One index permutation applied to both arrays keeps every sample with its
    # label by construction, instead of relying on two separately seeded
    # permutation calls happening to produce the same order.
    order = np.random.RandomState(random_seed).permutation(len(samples))
    samples = samples[order]
    labels = labels[order]

    if give_numpy:
        return (samples, labels)

    return tf.data.Dataset.from_tensor_slices((samples, labels)).batch(batch_size)

def create_datasets(merged_data, classes, train_ratio, valid_ratio, batch_size, random_seed = int(np.random.rand() * 100)):
    """Shuffle a [gesture, sample, rows, cols] array and split it three ways.

    Every sample is labelled with the index of its gesture's class name in
    ``classes``. After shuffling, the first ``train_ratio`` share goes to
    training, the next ``valid_ratio`` share to validation and the rest to
    test. Inputs are not normalised here.

    Args:
        merged_data: 4-D array indexed as [gesture, sample, rows, cols].
        classes: class names, one per gesture, in gesture order.
        train_ratio: fraction of samples used for training.
        valid_ratio: fraction of samples used for validation.
        batch_size: batch size of all three datasets.
        random_seed: seed of the shuffle. The default is computed once, when
            this function is defined, so every call that omits it reuses the
            same seed until the definition is executed again.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` as batched ``tf.data.Dataset`` objects.
    """
    # numeric label = position of the gesture's class name in `classes`
    label_to_index = {label: index for index, label in enumerate(classes)}
    labels = np.array([
        label_to_index[classes[gesture_index]]
        for gesture_index in range(len(merged_data))
        for _ in range(len(merged_data[gesture_index]))
    ])

    # flatten gesture and sample dimensions
    samples = merged_data.reshape(-1, merged_data.shape[2], merged_data.shape[3])
    print("total number of samples:", len(samples))

    # One index permutation applied to both arrays keeps every sample with its
    # label by construction, instead of relying on two separately seeded
    # permutation calls happening to produce the same order.
    order = np.random.RandomState(random_seed).permutation(len(samples))
    samples = samples[order]
    labels = labels[order]

    ds = tf.data.Dataset.from_tensor_slices((samples, labels))

    len_ds = len(ds)
    num_train = int(len_ds * train_ratio)
    print('LEN OF DS:', len_ds)
    num_valid = int(len_ds * valid_ratio)
    num_test = len_ds - num_train - num_valid # test size is implicit

    train_ds = ds.take(num_train).batch(batch_size)
    after_train = ds.skip(num_train)
    valid_ds = after_train.take(num_valid).batch(batch_size)
    test_ds = after_train.skip(num_valid).batch(batch_size) # whatever is left

    print("total samples:", len_ds)
    print("train_batches:", len(train_ds), 'samples:', num_train)
    print('valid_batches:', len(valid_ds), 'samples:', num_valid)
    print("test_baches:", len(test_ds), 'samples:', num_test)

    return train_ds, valid_ds, test_ds

def get_labels(dataset):
    """Collect the label part of every ``(x, y)`` batch into one array.

    ``dataset`` is iterated once; the label batches are joined along axis 0
    in iteration order.
    """
    label_batches = []
    for _, batch_labels in dataset:
        label_batches.append(batch_labels)
    return np.concatenate(label_batches, axis = 0)

# functions for training / testing model

def plot_confusion_matrix(conf_matrix, labels):
    """Draw a confusion matrix as an annotated seaborn heat map.

    Values are rounded to 2 decimal places and ``labels`` names both the rows
    and the columns.
    """
    rounded = np.round(conf_matrix, 2)
    labelled = pd.DataFrame(rounded, index = labels, columns = labels)
    sn.heatmap(labelled, annot=True, cmap = 'mako')

def confusion_matrix (dataset, model, classes, include_precision_recall = False, plot = True): # remember to pass you test dataset
    """Evaluate ``model`` on ``dataset`` and build a row-normalised confusion matrix.

    Rows are true classes and columns are predicted classes. Each row is
    divided by its total (plus 1e-7, so an empty row gives zeros, not NaN).

    Args:
        dataset: batched dataset of ``(x, y)`` pairs; normally the test set.
            It is iterated twice (predictions, then labels).
        model: object with a ``predict(dataset)`` method returning one score
            vector per sample.
        classes: class names; their count fixes the matrix size.
        include_precision_recall: also return per-class precision and recall.
        plot: draw the normalised matrix with plot_confusion_matrix.

    Returns:
        ``(normalised_matrix,)`` or ``(normalised_matrix, precisions, recalls)``.
        Precision of a class that was never predicted, and recall of a class
        with no samples, are reported as 0.0 so they cannot turn averages over
        folds into NaN.
    """
    scores = np.asarray(model.predict(dataset))
    # highest-scoring class per sample
    predictions = scores.reshape(len(scores), -1).argmax(axis = 1)

    num_classes = len(classes)
    conf_matrix = np.asarray(tf.math.confusion_matrix(get_labels(dataset), predictions, num_classes = num_classes))

    true_positives = np.diag(conf_matrix).astype(float)
    predicted_totals = conf_matrix.sum(axis = 0) # column sums
    actual_totals = conf_matrix.sum(axis = 1)    # row sums

    # Divide only where the denominator is non-zero; undefined ratios stay 0.
    precisions = np.divide(true_positives, predicted_totals, out = np.zeros_like(true_positives), where = predicted_totals > 0)
    recalls = np.divide(true_positives, actual_totals, out = np.zeros_like(true_positives), where = actual_totals > 0)

    # 1e-7 avoids nan when dividing by 0 in an empty row
    conf_matrix_normalized = conf_matrix / (actual_totals[:, np.newaxis] + 1e-7)
    if plot:
        plot_confusion_matrix(conf_matrix_normalized, classes)

    if include_precision_recall:
        return (conf_matrix_normalized, precisions, recalls)
    return (conf_matrix_normalized,)

def _plot_training_curve(history, metric, title, y_label, has_validation):
    """Plot one metric of a fit history per epoch, plus its ``val_`` twin if present."""
    plt.plot(history.history[metric])
    if has_validation:
        plt.plot(history.history[f'val_{metric}'])
    plt.title(title)
    plt.xlabel('epoch')
    plt.ylabel(y_label)
    plt.legend(['train','val'], loc='upper left')
    plt.grid()
    plt.show()

def train_and_evaluate(model, train_ds, test_ds, classes, epochs, valid_ds = None, lr_scheduler = None, silent = False, include_precision_recall = False):
    """Fit ``model``, optionally plot its learning curves, and evaluate it.

    Args:
        model: compiled Keras model that reports the ``acc`` metric.
        train_ds: training dataset.
        test_ds: dataset used for the final evaluation and confusion matrix.
        classes: class names, forwarded to confusion_matrix.
        epochs: number of training epochs.
        valid_ds: optional validation dataset used during fitting.
        lr_scheduler: optional ``schedule(epoch, lr)`` function, wrapped in a
            LearningRateScheduler callback.
        silent: when truthy, skip the accuracy/loss plots and the confusion
            matrix plot.
        include_precision_recall: forwarded to confusion_matrix.

    Returns:
        ``(accuracy, normalised_confusion_matrix)``, extended with
        ``precisions, recalls`` when ``include_precision_recall`` is truthy.
    """
    # summary() prints the table itself and returns None, so it is not
    # wrapped in display().
    model.summary()

    history = model.fit(
        train_ds,
        validation_data = valid_ds, # if nothing passed, this is ignored
        epochs = epochs,
        verbose = 1,
        callbacks = [keras.callbacks.LearningRateScheduler(lr_scheduler)] if lr_scheduler else None
    )

    if not silent:
        has_validation = valid_ds is not None
        _plot_training_curve(history, 'acc', 'Accuracy', 'accuracy', has_validation)
        _plot_training_curve(history, 'loss', 'Loss', 'loss', has_validation)

    print("Evaluation:")
    loss, accuracy = model.evaluate(test_ds)

    response = confusion_matrix(test_ds, model, classes, plot = not silent, include_precision_recall = include_precision_recall)

    return (accuracy, *response)

def build_model(classes, init_learning_rate, complexity_scale = None, dense_regularizer_intensity = None, conv_regularizer_intensity = None):
    """Build and compile the CNN classifier.

    Architecture: three blocks of BatchNormalization -> two 3x3 'same' ReLU
    convolutions -> 2x2 max pooling, with 16, 32 and 64 filters, followed by
    Flatten, two Dense(256) + Dropout(0.5) layers and a softmax output with
    one unit per class. The input shape is fixed at (20, 50, 1).

    Args:
        classes: class names; only their count is used.
        init_learning_rate: learning rate of the Adam optimizer.
        complexity_scale: accepted but not used by this function.
        dense_regularizer_intensity: L2 factor for the Dense(256) kernels.
        conv_regularizer_intensity: L2 factor for the convolution kernels.

    Returns:
        The compiled model (sparse categorical cross-entropy, metric 'acc').
    """
    def conv_layer(filters):
        return layers.Conv2D(filters, 3, activation = 'relu', padding = 'same', kernel_regularizer = regularizers.l2(conv_regularizer_intensity))

    def hidden_dense_layer():
        return layers.Dense(256, activation = 'relu', kernel_regularizer = regularizers.l2(dense_regularizer_intensity))

    stack = []
    for block_index, filters in enumerate((16, 32, 64)):
        # only the very first layer declares the input shape
        if block_index == 0:
            stack.append(layers.BatchNormalization(input_shape = (20, 50, 1)))
        else:
            stack.append(layers.BatchNormalization())
        stack.append(conv_layer(filters))
        stack.append(conv_layer(filters))
        stack.append(layers.MaxPooling2D(2, strides=2))

    stack.append(layers.Flatten())
    for _ in range(2):
        stack.append(hidden_dense_layer())
        stack.append(layers.Dropout(0.5))
    stack.append(layers.Dense(len(classes), activation = 'softmax'))

    custom_model = models.Sequential(stack)
    custom_model.compile(
        loss = keras.losses.SparseCategoricalCrossentropy(),
        metrics = ['acc'],
        optimizer = keras.optimizers.Adam(learning_rate = init_learning_rate),
    )

    return custom_model

def create_and_run(dataset_nums_full_train, preprocessing_version, trimmed, batch_size, epochs, complexity_scale, dense_regularizer_intensity, conv_regularizer_intensity,
                   dataset_nums_test = None, plotting = False, init_learning_rate = 0.001, lr_scheduler = None, cross_validation_num = None,
                   ):
    """Load data, build the model, train it and report its accuracy.

    Three modes:
      * ``cross_validation_num`` set: k-fold cross validation on the training
        datasets, with a fresh model per fold and results averaged over folds;
      * ``dataset_nums_test`` set: train on all training datasets and test on
        the explicitly given ones;
      * neither: a shuffled 80/10/10 train/validation/test split.

    Args:
        dataset_nums_full_train: dataset numbers used for training.
        preprocessing_version: forwarded to npy_files_to_numpy.
        trimmed: load the trimmed ("chopped") version of the data.
        batch_size: batch size of every dataset.
        epochs: training epochs per model.
        complexity_scale: forwarded to build_model.
        dense_regularizer_intensity: forwarded to build_model.
        conv_regularizer_intensity: forwarded to build_model.
        dataset_nums_test: optional dataset numbers used only for testing.
        plotting: accepted but not used by this function.
        init_learning_rate: forwarded to build_model.
        lr_scheduler: optional ``schedule(epoch, lr)`` function.
        cross_validation_num: number of folds, or None to disable.

    Returns:
        The test accuracy: averaged over folds with cross validation,
        otherwise that of the single run.

    Raises:
        Exception: if cross validation is combined with explicit test datasets.
    """
    use_cross_validation = cross_validation_num is not None

    # Reject the invalid combination before any data is loaded.
    if use_cross_validation and dataset_nums_test is not None:
        raise Exception("Can't use cross validation with explicit test datasets!!")

    data, classes = npy_files_to_numpy(dataset_nums_full_train, preprocessing_version, from_chopped = trimmed)

    print('merged data shape:', data.shape)
    # Show one sample of the input data. The indices are clamped so that a
    # dataset with fewer than 4 gestures or 22 samples does not raise.
    preview_gesture = min(3, data.shape[0] - 1)
    preview_sample = min(21, data.shape[1] - 1)
    subcarSpectrogram(data[preview_gesture, preview_sample])

    def new_model():
        return build_model(classes, init_learning_rate, complexity_scale, dense_regularizer_intensity = dense_regularizer_intensity, conv_regularizer_intensity = conv_regularizer_intensity)

    if use_cross_validation:
        kfold = sklearn.model_selection.KFold(n_splits=cross_validation_num)
        samples, labels = single_dataset_from_non_flat_numpy(data, classes, batch_size, give_numpy = True) # this does the shuffling

        accuracy_scores = []
        confusion_matrices = []
        all_precisions = []
        all_recalls = []

        for fold, (train_index, test_index) in enumerate(kfold.split(samples)):
            print(f"Fold: {fold+1}")

            train_ds = single_dataset_from_flattened_numpy(samples[train_index], labels[train_index], batch_size)
            test_ds = single_dataset_from_flattened_numpy(samples[test_index], labels[test_index], batch_size)

            # a new model per fold, otherwise weights would carry over
            custom_model = new_model()
            accuracy, conf_matrix, precisions, recalls = train_and_evaluate(custom_model, train_ds, test_ds, classes, epochs, lr_scheduler=lr_scheduler, silent = True, include_precision_recall = True)

            print('accuracy: ', accuracy)
            print('precisions: ', precisions)
            print('recalls: ', recalls)
            print(f'conf matrix of fold {fold+1}: \n', conf_matrix)

            accuracy_scores.append(accuracy)
            all_precisions.append(precisions)
            all_recalls.append(recalls)
            confusion_matrices.append(conf_matrix)

        average_acc = np.mean(np.asarray(accuracy_scores))
        average_precisions = np.mean(np.asarray(all_precisions), axis = 0)
        average_recalls = np.mean(np.asarray(all_recalls), axis = 0)
        average_conf_matrix = np.mean(np.asarray(confusion_matrices), axis = 0)

        print('COPY AFTER HERE -----------------------------')
        print('average acc: ', average_acc)
        print('average precisions: ', average_precisions)
        print('average recalls: ', average_recalls)
        print('average conf matrix \n', average_conf_matrix)
        plot_confusion_matrix(average_conf_matrix, classes)

        return average_acc

    # No cross validation: one model, with accuracy and loss plotted afterwards.
    if dataset_nums_test is not None:
        data_explicit_test, _ = npy_files_to_numpy(dataset_nums_test, preprocessing_version, from_chopped = trimmed)
        custom_model = new_model()

        print('using selected test datasets as test data...')
        train_ds = single_dataset_from_non_flat_numpy(data, classes, batch_size)
        test_ds = single_dataset_from_non_flat_numpy(data_explicit_test, classes, batch_size)
        accuracy, *_ = train_and_evaluate(custom_model, train_ds, test_ds, classes, epochs, lr_scheduler = lr_scheduler)
    else:
        custom_model = new_model()
        train_ratio = 0.8
        valid_ratio = 0.1
        train_ds, valid_ds, test_ds = create_datasets(data, classes, train_ratio, valid_ratio, batch_size)
        accuracy, *_ = train_and_evaluate(custom_model, train_ds, test_ds, classes, epochs, valid_ds = valid_ds, lr_scheduler = lr_scheduler, include_precision_recall = True)

    # Return the accuracy here too, matching the cross-validation path.
    return accuracy

def create_step_decay(drop = 0.8, update_interval = 10):
    """Create a step-decay schedule function ``schedule(epoch_num, lr)``.

    The returned function multiplies the learning rate it is given by ``drop``
    on every epoch that is a positive multiple of ``update_interval`` and
    returns it unchanged otherwise.

    Args:
        drop: factor applied at each step.
        update_interval: number of epochs between steps.
    """
    def step_decay(epoch_num, lr):
        # Epoch 0 is also divisible by the interval, but decaying there would
        # mean the configured initial learning rate is never used.
        at_step = epoch_num > 0 and epoch_num % update_interval == 0
        return lr * drop if at_step else lr

    return step_decay

def create_exponential_decay(drop = 0.99, update_interval = 10):
    """Create a schedule function ``schedule(epoch_num, lr)``.

    On epochs divisible by ``update_interval`` the returned function gives
    back ``drop ** epoch_num * lr``; on all other epochs it returns ``lr``
    unchanged.

    NOTE(review): if the caller passes the current (already decayed) learning
    rate as ``lr``, successive updates compound -- confirm that is intended.
    """
    def exponential_decay(epoch_num, lr):
        if epoch_num % update_interval != 0:
            return lr
        return drop ** epoch_num * lr

    return exponential_decay

In [None]:
# impact_factors = {
#     'overall': {
#         '1': tuple(range(41, 53)),
#     },
#     'phones': {
#         '1': (47, 48), # white
#         '2': (49, 50), # black
#     },
#     'people': {
#         '1': (51, 52), # p1
#         '2': (60, 61), # p2
#         '3': (62, 63), # p3
#         '4': (64, 65), # p4
#         'combined': (51, 52, 60, 61, 62, 63, 64, 65),
#     },
#     'dist-AP': {
#         "60''": (45, 46),
#         "20''": (51, 52),
#     },
#     'dist-phone': {
#         "60''": (47, 48),
#         "20''": (51, 52),
#     },
#     'data-points': { # white, setup 3, p1
#         '1': (68,),
#         '2': (68, 69),
#         '3': (68, 69, 70),
#         '4': (68, 69, 70, 71),
#     }
# }

# Run one experiment: 5-fold cross validation on datasets 41-52 using
# version-4 preprocessing (trimmed samples).
create_and_run(
    dataset_nums_full_train = tuple(range(41, 53)), # datasets 41..52 inclusive
    # dataset_nums_test = (65,), 
    # hide_partial_train = False, # BROKEN
    preprocessing_version = 4,
    trimmed = True, # if you choose False, change the input shape hard-coded in build_model (20, 50, 1)
    batch_size = 64,
    epochs = 150,
    complexity_scale = 3, # passed through to build_model, which does not use it
    dense_regularizer_intensity = 0.1, 
    conv_regularizer_intensity = 0, 
    init_learning_rate = 0.0005,
    # lr_scheduler = create_step_decay(drop = 0.8, update_interval = 10),
    cross_validation_num = 5, # number of folds; cannot be combined with dataset_nums_test
)