# ProtoNN in Tensorflow

This is a simple notebook that illustrates the usage of Tensorflow implementation of ProtoNN. We are using the USPS dataset. Please refer to `fetch_usps.py` for more details on downloading the dataset.

In [1]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from __future__ import print_function
import sys
import os
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

#sys.path.insert(0, '../../')
# from edgeml.trainer.protoNNTrainer import ProtoNNTrainer
# from edgeml.graph.protoNN import ProtoNN
# import edgeml.utils as utils
# import helpermethods as helper
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV


sys.path.append(r"D:\programming\practice\research\protoNN\EdgeML\examples\tf\ProtoNN")
import helpermethods as helper

Instructions for updating:
non-resource variables are not supported in the long term


# Helper Methods

In [2]:
#helper methods
sys.path.insert(0, '../')
import argparse


def getModelSize(matrixList, sparcityList, expected=True, bytesPerVar=4):
    '''
    expected: Expected size according to the parameters set. The number of
        zeros could actually be more than that is required to satisfy the
        sparsity constraint.
    '''
    nnzList, sizeList, isSparseList = [], [], []
    hasSparse = False
    for i in range(len(matrixList)):
        A, s = matrixList[i], sparcityList[i]
        assert A.ndim == 2
        assert s >= 0
        assert s <= 1
        nnz, size, sparse = countnnZ(A, s, bytesPerVar=bytesPerVar)
        nnzList.append(nnz)
        sizeList.append(size)
        hasSparse = (hasSparse or sparse)

    totalnnZ = np.sum(nnzList)
    totalSize = np.sum(sizeList)
    if expected:
        return totalnnZ, totalSize, hasSparse
    numNonZero = 0
    totalSize = 0
    hasSparse = False
    for i in range(len(matrixList)):
        A, s = matrixList[i], sparcityList[i]
        numNonZero_ = np.count_nonzero(A)
        numNonZero += numNonZero_
        hasSparse = (hasSparse or (s < 0.5))
        if s <= 0.5:
            totalSize += numNonZero_ * 2 * bytesPerVar
        else:
            totalSize += A.size * bytesPerVar
    return numNonZero, totalSize, hasSparse


def getGamma(gammaInit, projectionDim, dataDim, numPrototypes, x_train):
    if gammaInit is None:
        print("Using median heuristic to estimate gamma.")
        gamma, W, B = medianHeuristic(x_train, projectionDim,
                                            numPrototypes)
        print("Gamma estimate is: %f" % gamma)
        return W, B, gamma
    return None, None, gammaInit


def preprocessData(dataDir,w):
    '''
    Loads data from the dataDir and does some initial preprocessing
    steps. Data is assumed to be contained in two files,
    train.npy and test.npy. Each containing a 2D numpy array of dimension
    [numberOfExamples, numberOfFeatures + 1]. The first column of each
    matrix is assumed to contain label information.

    For an N-Class problem, we assume the labels are integers from 0 through
    N-1.
    '''
    # Uncomment for usual training data
    # train = np.load(dataDir + '/train_'+str(w)+'.npy')
    # test = np.load(dataDir + '/test_'+str(w)+'.npy')
    # Uncomment for time domain training data
    train = np.load(dataDir + '/ttrain_'+str(w)+'.npy')
    test = np.load(dataDir + '/ttest_'+str(w)+'.npy')
    # Uncomment for 1 sensordrop training data
    # train = np.load(dataDir + '/train_'+str(w)+'.npy')
    # test = np.load(dataDir + '/test_'+str(w)+'.npy')

    dataDimension = int(train.shape[1]) - 1
    x_train = train[:, 1:dataDimension + 1]
    y_train_ = train[:, 0]
    x_test = test[:, 1:dataDimension + 1]
    y_test_ = test[:, 0]

    numClasses = max(y_train_) - min(y_train_) + 1
    numClasses = max(numClasses, max(y_test_) - min(y_test_) + 1)
    numClasses = int(numClasses)

    # mean-var
    mean = np.mean(x_train, 0)
    std = np.std(x_train, 0)
    std[std[:] < 0.000001] = 1
    x_train = (x_train - mean) / std
    x_test = (x_test - mean) / std

    # one hot y-train
    lab = y_train_.astype('uint8')
    lab = np.array(lab) - min(lab)
    lab_ = np.zeros((x_train.shape[0], numClasses))
    lab_[np.arange(x_train.shape[0]), lab] = 1
    y_train = lab_

    # one hot y-test
    lab = y_test_.astype('uint8')
    lab = np.array(lab) - min(lab)
    lab_ = np.zeros((x_test.shape[0], numClasses))
    lab_[np.arange(x_test.shape[0]), lab] = 1
    y_test = lab_

    return dataDimension, numClasses, x_train, y_train, x_test, y_test



def getProtoNNArgs():
    def checkIntPos(value):
        ivalue = int(value)
        if ivalue <= 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid positive int value" % value)
        return ivalue

    def checkIntNneg(value):
        ivalue = int(value)
        if ivalue < 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid non-neg int value" % value)
        return ivalue

    def checkFloatNneg(value):
        fvalue = float(value)
        if fvalue < 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid non-neg float value" % value)
        return fvalue

    def checkFloatPos(value):
        fvalue = float(value)
        if fvalue <= 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid positive float value" % value)
        return fvalue

    '''
    Parse protoNN commandline arguments
    '''
    parser = argparse.ArgumentParser(
        description='Hyperparameters for ProtoNN Algorithm')

    msg = 'Data directory containing train and test data. The '
    msg += 'data is assumed to be saved as 2-D numpy matrices with '
    msg += 'names `train.npy` and `test.npy`, of dimensions\n'
    msg += '\t[numberOfInstances, numberOfFeatures + 1].\n'
    msg += 'The first column of each file is assumed to contain label information.'
    msg += ' For a N-class problem, labels are assumed to be integers from 0 to'
    msg += ' N-1 (inclusive).'
    parser.add_argument('-d', '--data-dir', required=True, help=msg)
    parser.add_argument('-l', '--projection-dim', type=checkIntPos, default=10,
                        help='Projection Dimension.')
    parser.add_argument('-p', '--num-prototypes', type=checkIntPos, default=20,
                        help='Number of prototypes.')
    parser.add_argument('-g', '--gamma', type=checkFloatPos, default=None,
                        help='Gamma for Gaussian kernel. If not provided, ' +
                        'median heuristic will be used to estimate gamma.')

    parser.add_argument('-e', '--epochs', type=checkIntPos, default=100,
                        help='Total training epochs.')
    parser.add_argument('-b', '--batch-size', type=checkIntPos, default=32,
                        help='Batch size for each pass.')
    parser.add_argument('-r', '--learning-rate', type=checkFloatPos,
                        default=0.001,
                        help='Initial Learning rate for ADAM Optimizer.')

    parser.add_argument('-rW', type=float, default=0.000,
                        help='Coefficient for l2 regularizer for predictor' +
                        ' parameter W ' + '(default = 0.0).')
    parser.add_argument('-rB', type=float, default=0.00,
                        help='Coefficient for l2 regularizer for predictor' +
                        ' parameter B ' + '(default = 0.0).')
    parser.add_argument('-rZ', type=float, default=0.00,
                        help='Coefficient for l2 regularizer for predictor' +
                        'parameter Z ' +
                        '(default = 0.0).')

    parser.add_argument('-sW', type=float, default=1.000,
                        help='Sparsity constraint for predictor parameter W ' +
                        '(default = 1.0, i.e. dense matrix).')
    parser.add_argument('-sB', type=float, default=1.00,
                        help='Sparsity constraint for predictor parameter B ' +
                        '(default = 1.0, i.e. dense matrix).')
    parser.add_argument('-sZ', type=float, default=1.00,
                        help='Sparsity constraint for predictor parameter Z ' +
                        '(default = 1.0, i.e. dense matrix).')
    parser.add_argument('-pS', '--print-step', type=int, default=200,
                        help='The number of update steps between print ' +
                        'calls to console.')
    parser.add_argument('-vS', '--val-step', type=int, default=3,
                        help='The number of epochs between validation' +
                        'performance evaluation')
    return parser.parse_args()

# Utils 

In [3]:
#utils
import scipy.cluster
import scipy.spatial
import os


def medianHeuristic(data, projectionDimension, numPrototypes, W_init=None):
    '''
    This method can be used to estimate gamma for ProtoNN. An approximation to
    median heuristic is used here.
    1. First the data is collapsed into the projectionDimension by W_init. If
    W_init is not provided, it is initialized from a random normal(0, 1). Hence
    data normalization is essential.
    2. Prototype are computed by running a  k-means clustering on the projected
    data.
    3. The median distance is then estimated by calculating median distance
    between prototypes and projected data points.

    data needs to be [-1, numFeats]
    If using this method to initialize gamma, please use the W and B as well.

    TODO: Return estimate of Z (prototype labels) based on cluster centroids
    andand labels

    TODO: Clustering fails due to singularity error if projecting upwards

    W [dxd_cap]
    B [d_cap, m]
    returns gamma, W, B
    '''
    assert data.ndim == 2
    X = data
    featDim = data.shape[1]
    if projectionDimension > featDim:
        print("Warning: Projection dimension > feature dimension. Gamma")
        print("\t estimation due to median heuristic could fail.")
        print("\tTo retain the projection dataDimension, provide")
        print("\ta value for gamma.")

    if W_init is None:
        W_init = np.random.normal(size=[featDim, projectionDimension])
    W = W_init
    XW = np.matmul(X, W)
    assert XW.shape[1] == projectionDimension
    assert XW.shape[0] == len(X)
    # Requires [N x d_cap] data matrix of N observations of d_cap-dimension and
    # the number of centroids m. Returns, [n x d_cap] centroids and
    # elementwise center information.
    B, centers = scipy.cluster.vq.kmeans2(XW, numPrototypes)
    # Requires two matrices. Number of observations x dimension of observation
    # space. Distances[i,j] is the distance between XW[i] and B[j]
    distances = scipy.spatial.distance.cdist(XW, B, metric='euclidean')
    distances = np.reshape(distances, [-1])
    gamma = np.median(distances)
    gamma = 1 / (2.5 * gamma)
    return gamma.astype('float32'), W.astype('float32'), B.T.astype('float32')


def multiClassHingeLoss(logits, label, batch_th):
    '''
    MultiClassHingeLoss to match C++ Version - No TF internal version
    '''
    flatLogits = tf.reshape(logits, [-1, ])
    label_ = tf.argmax(label, 1)

    correctId = tf.range(0, batch_th) * label.shape[1] + label_
    correctLogit = tf.gather(flatLogits, correctId)

    maxLabel = tf.argmax(logits, 1)
    top2, _ = tf.nn.top_k(logits, k=2, sorted=True)

    wrongMaxLogit = tf.where(
        tf.equal(maxLabel, label_), top2[:, 1], top2[:, 0])

    return tf.reduce_mean(tf.nn.relu(1. + wrongMaxLogit - correctLogit))


def crossEntropyLoss(logits, label):
    '''
    Cross Entropy loss for MultiClass case in joint training for
    faster convergence
    '''
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                   labels=tf.stop_gradient(label)))


def mean_absolute_error(logits, label):
    '''
    Function to compute the mean absolute error.
    '''
    return tf.reduce_mean(tf.abs(tf.subtract(logits, label)))


def hardThreshold(A, s):
    '''
    Hard thresholding function on Tensor A with sparsity s
    '''
    A_ = np.copy(A)
    A_ = A_.ravel()
    if len(A_) > 0:
        th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
        A_[np.abs(A_) < th] = 0.0
    A_ = A_.reshape(A.shape)
    return A_


def copySupport(src, dest):
    '''
    copy support of src tensor to dest tensor
    '''
    support = np.nonzero(src)
    dest_ = dest
    dest = np.zeros(dest_.shape)
    dest[support] = dest_[support]
    return dest


def countnnZ(A, s, bytesPerVar=4):
    '''
    Returns # of non-zeros and representative size of the tensor
    Uses dense for s >= 0.5 - 4 byte
    Else uses sparse - 8 byte
    '''
    params = 1
    hasSparse = False
    for i in range(0, len(A.shape)):
        params *= int(A.shape[i])
    if s < 0.5:
        nnZ = np.ceil(params * s)
        hasSparse = True
        return nnZ, nnZ * 2 * bytesPerVar, hasSparse
    else:
        nnZ = params
        return nnZ, nnZ * bytesPerVar, hasSparse


def getConfusionMatrix(predicted, target, numClasses):
    '''
    Returns a confusion matrix for a multiclass classification
    problem. `predicted` is a 1-D array of integers representing
    the predicted classes and `target` is the target classes.

    confusion[i][j]: Number of elements of class j
        predicted as class i
    Labels are assumed to be in range(0, numClasses)
    Use`printFormattedConfusionMatrix` to echo the confusion matrix
    in a user friendly form.
    '''
    assert(predicted.ndim == 1)
    assert(target.ndim == 1)
    arr = np.zeros([numClasses, numClasses])

    for i in range(len(predicted)):
        arr[predicted[i]][target[i]] += 1
    return arr


def printFormattedConfusionMatrix(matrix):
    '''
    Given a 2D confusion matrix, prints it in a human readable way.
    The confusion matrix is expected to be a 2D numpy array with
    square dimensions
    '''
    assert(matrix.ndim == 2)
    assert(matrix.shape[0] == matrix.shape[1])
    RECALL = 'Recall'
    PRECISION = 'PRECISION'
    print("|%s|" % ('True->'), end='')
    for i in range(matrix.shape[0]):
        print("%7d|" % i, end='')
    print("%s|" % 'Precision')

    print("|%s|" % ('-' * len(RECALL)), end='')
    for i in range(matrix.shape[0]):
        print("%s|" % ('-' * 7), end='')
    print("%s|" % ('-' * len(PRECISION)))

    precisionlist = np.sum(matrix, axis=1)
    recalllist = np.sum(matrix, axis=0)
    precisionlist = [matrix[i][i] / x if x !=
                     0 else -1 for i, x in enumerate(precisionlist)]
    recalllist = [matrix[i][i] / x if x !=
                  0 else -1 for i, x in enumerate(recalllist)]
    for i in range(matrix.shape[0]):
        # len recall = 6
        print("|%6d|" % (i), end='')
        for j in range(matrix.shape[0]):
            print("%7d|" % (matrix[i][j]), end='')
        print("%s" % (" " * (len(PRECISION) - 7)), end='')
        if precisionlist[i] != -1:
            print("%1.5f|" % precisionlist[i])
        else:
            print("%7s|" % "nan")

    print("|%s|" % ('-' * len(RECALL)), end='')
    for i in range(matrix.shape[0]):
        print("%s|" % ('-' * 7), end='')
    print("%s|" % ('-' * len(PRECISION)))
    print("|%s|" % ('Recall'), end='')

    for i in range(matrix.shape[0]):
        if recalllist[i] != -1:
            print("%1.5f|" % (recalllist[i]), end='')
        else:
            print("%7s|" % "nan", end='')

    print('%s|' % (' ' * len(PRECISION)))


def getPrecisionRecall(cmatrix, label=1):
    trueP = cmatrix[label][label]
    denom = np.sum(cmatrix, axis=0)[label]
    if denom == 0:
        denom = 1
    recall = trueP / denom
    denom = np.sum(cmatrix, axis=1)[label]
    if denom == 0:
        denom = 1
    precision = trueP / denom
    return precision, recall


def getMacroPrecisionRecall(cmatrix):
    # TP + FP
    precisionlist = np.sum(cmatrix, axis=1)
    # TP + FN
    recalllist = np.sum(cmatrix, axis=0)
    precisionlist__ = [cmatrix[i][i] / x if x !=
                       0 else 0 for i, x in enumerate(precisionlist)]
    recalllist__ = [cmatrix[i][i] / x if x !=
                    0 else 0 for i, x in enumerate(recalllist)]
    precision = np.sum(precisionlist__)
    precision /= len(precisionlist__)
    recall = np.sum(recalllist__)
    recall /= len(recalllist__)
    return precision, recall


def getMicroPrecisionRecall(cmatrix):
    # TP + FP
    precisionlist = np.sum(cmatrix, axis=1)
    # TP + FN
    recalllist = np.sum(cmatrix, axis=0)
    num = 0.0
    for i in range(len(cmatrix)):
        num += cmatrix[i][i]

    precision = num / np.sum(precisionlist)
    recall = num / np.sum(recalllist)
    return precision, recall


def getMacroMicroFScore(cmatrix):
    '''
    Returns macro and micro f-scores.
    Refer: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf
    '''
    precisionlist = np.sum(cmatrix, axis=1)
    recalllist = np.sum(cmatrix, axis=0)
    precisionlist__ = [cmatrix[i][i] / x if x !=
                       0 else 0 for i, x in enumerate(precisionlist)]
    recalllist__ = [cmatrix[i][i] / x if x !=
                    0 else 0 for i, x in enumerate(recalllist)]
    macro = 0.0
    for i in range(len(precisionlist)):
        denom = precisionlist__[i] + recalllist__[i]
        numer = precisionlist__[i] * recalllist__[i] * 2
        if denom == 0:
            denom = 1
        macro += numer / denom
    macro /= len(precisionlist)

    num = 0.0
    for i in range(len(precisionlist)):
        num += cmatrix[i][i]

    denom1 = np.sum(precisionlist)
    denom2 = np.sum(recalllist)
    pi = num / denom1
    rho = num / denom2
    denom = pi + rho
    if denom == 0:
        denom = 1
    micro = 2 * pi * rho / denom
    return macro, micro


class GraphManager:
    '''
    Manages saving and restoring graphs. Designed to be used with EMI-RNN
    though is general enough to be useful otherwise as well.
    '''

    def __init__(self):
        pass

    def checkpointModel(self, saver, sess, modelPrefix,
                        globalStep=1000, redirFile=None):
        saver.save(sess, modelPrefix, global_step=globalStep)
        print('Model saved to %s, global_step %d' % (modelPrefix, globalStep),
              file=redirFile)

    def loadCheckpoint(self, sess, modelPrefix, globalStep,
                       redirFile=None):
        metaname = modelPrefix + '-%d.meta' % globalStep
        basename = os.path.basename(metaname)
        fileList = os.listdir(os.path.dirname(modelPrefix))
        fileList = [x for x in fileList if x.startswith(basename)]
        assert len(fileList) > 0, 'Checkpoint file not found'
        msg = 'Too many or too few checkpoint files for globalStep: %d' % globalStep
        assert len(fileList) is 1, msg
        chkpt = basename + '/' + fileList[0]
        saver = tf.train.import_meta_graph(metaname)
        metaname = metaname[:-5]
        saver.restore(sess, metaname)
        graph = tf.get_default_graph()
        return graph

# Model Trainer - ProtoNN

In [4]:
#Trainer
class ProtoNNTrainer:
    def __init__(self, protoNNObj, regW, regB, regZ,
                 sparcityW, sparcityB, sparcityZ,
                 learningRate, X, Y, lossType='l2'):
        '''
        A wrapper for the various techniques used for training ProtoNN. This
        subsumes both the responsibility of loss graph construction and
        performing training. The original training routine that is part of the
        C++ implementation of EdgeML used iterative hard thresholding (IHT),
        gamma estimation through median heuristic and other tricks for
        training ProtoNN. This module implements the same in Tensorflow
        and python.

        protoNNObj: An instance of ProtoNN class defining the forward
            computation graph. The loss functions and training routines will be
            attached to this instance.
        regW, regB, regZ: Regularization constants for W, B, and
            Z matrices of protoNN.
        sparcityW, sparcityB, sparcityZ: Sparsity constraints
            for W, B and Z matrices. A value between 0 (exclusive) and 1
            (inclusive) is expected. A value of 1 indicates dense training.
        learningRate: Initial learning rate for ADAM optimizer.
        X, Y : Placeholders for data and labels.
            X [-1, featureDimension]
            Y [-1, num Labels]
        lossType: ['l2', 'xentropy']
        '''
        self.protoNNObj = protoNNObj
        self.__regW = regW
        self.__regB = regB
        self.__regZ = regZ
        self.__sW = sparcityW
        self.__sB = sparcityB
        self.__sZ = sparcityZ
        self.__lR = learningRate
        self.X = X
        self.Y = Y
        self.sparseTraining = True
        if (sparcityW == 1.0) and (sparcityB == 1.0) and (sparcityZ == 1.0):
            self.sparseTraining = False
            print("Sparse training disabled.", file=sys.stderr)
        # Define placeholders for sparse training
        self.W_th = None
        self.B_th = None
        self.Z_th = None
        self.__lossType = lossType
        self.__validInit = False
        self.__validInit = self.__validateInit()
        self.__protoNNOut = protoNNObj(X, Y)
        self.loss = self.__lossGraph()
        self.trainStep = self.__trainGraph()
        self.__hthOp = self.__getHardThresholdOp()
        self.accuracy = protoNNObj.getAccuracyOp()

    def __validateInit(self):
        self.__validInit = False
        msg = "Sparsity value should be between"
        msg += " 0 and 1 (both inclusive)."
        assert self.__sW >= 0. and self.__sW <= 1., 'W:' + msg
        assert self.__sB >= 0. and self.__sB <= 1., 'B:' + msg
        assert self.__sZ >= 0. and self.__sZ <= 1., 'Z:' + msg
        d, dcap, m, L, _ = self.protoNNObj.getHyperParams()
        msg = 'Y should be of dimension [-1, num labels/classes]'
        msg += ' specified as part of ProtoNN object.'
        assert (len(self.Y.shape)) == 2, msg
        assert (self.Y.shape[1] == L), msg
        msg = 'X should be of dimension [-1, featureDimension]'
        msg += ' specified as part of ProtoNN object.'
        assert (len(self.X.shape) == 2), msg
        assert (self.X.shape[1] == d), msg
        self.__validInit = True
        msg = 'Values can be \'l2\', or \'xentropy\''
        if self.__lossType not in ['l2', 'xentropy']:
            raise ValueError(msg)
        return True

    def __lossGraph(self):
        pnnOut = self.__protoNNOut
        l1, l2, l3 = self.__regW, self.__regB, self.__regZ
        W, B, Z, _ = self.protoNNObj.getModelMatrices()
        if self.__lossType == 'l2':
            with tf.name_scope('protonn-l2-loss'):
                loss_0 = tf.nn.l2_loss(self.Y - pnnOut)
                reg = l1 * tf.nn.l2_loss(W) + l2 * tf.nn.l2_loss(B)
                reg += l3 * tf.nn.l2_loss(Z)
                loss = loss_0 + reg
        elif self.__lossType == 'xentropy':
            with tf.name_scope('protonn-xentropy-loss'):
                loss_0 = tf.nn.softmax_cross_entropy_with_logits_v2(logits=pnnOut,
                                                         labels=tf.stop_gradient(self.Y))
                loss_0 = tf.reduce_mean(loss_0)
                reg = l1 * tf.nn.l2_loss(W) + l2 * tf.nn.l2_loss(B)
                reg += l3 * tf.nn.l2_loss(Z)
                loss = loss_0 + reg
        return loss

    def __trainGraph(self):
        with tf.name_scope('protonn-gradient-adam'):
            trainStep = tf.train.AdamOptimizer(self.__lR)
            trainStep = trainStep.minimize(self.loss)
        return trainStep

    def __getHardThresholdOp(self):
        W, B, Z, _ = self.protoNNObj.getModelMatrices()
        self.W_th = tf.placeholder(tf.float32, name='W_th')
        self.B_th = tf.placeholder(tf.float32, name='B_th')
        self.Z_th = tf.placeholder(tf.float32, name='Z_th')
        with tf.name_scope('hard-threshold-assignments'):
            hard_thrsd_W = W.assign(self.W_th)
            hard_thrsd_B = B.assign(self.B_th)
            hard_thrsd_Z = Z.assign(self.Z_th)
            hard_thrsd_op = tf.group(hard_thrsd_W, hard_thrsd_B, hard_thrsd_Z)
        return hard_thrsd_op

    def train(self, batchSize, totalEpochs, sess,
              x_train, x_val, y_train, y_val, noInit=False,
              redirFile=None, printStep=10, valStep=3):
        '''
        Performs dense training of ProtoNN followed by iterative hard
        thresholding to enforce sparsity constraints.

        batchSize: Batch size per update
        totalEpochs: The number of epochs to run training for. One epoch is
            defined as one pass over the entire training data.
        sess: The Tensorflow session to use for running various graph
            operators.
        x_train, x_val, y_train, y_val: The numpy array containing train and
            validation data. x data is assumed to in of shape [-1,
            featureDimension] while y should have shape [-1, numberLabels].
        noInit: By default, all the tensors of the computation graph are
        initialized at the start of the training session. Set noInit=False to
        disable this behaviour.
        printStep: Number of batches between echoing of loss and train accuracy.
        valStep: Number of epochs between evolutions on validation set.
        '''
        d, d_cap, m, L, gamma = self.protoNNObj.getHyperParams()
        assert batchSize >= 1, 'Batch size should be positive integer'
        assert totalEpochs >= 1, 'Total epochs should be positive integer'
        assert x_train.ndim == 2, 'Expected training data to be of rank 2'
        assert x_train.shape[1] == d, 'Expected x_train to be [-1, %d]' % d
        assert x_val.ndim == 2, 'Expected validation data to be of rank 2'
        assert x_val.shape[1] == d, 'Expected x_val to be [-1, %d]' % d
        assert y_train.ndim == 2, 'Expected training labels to be of rank 2'
        assert y_train.shape[1] == L, 'Expected y_train to be [-1, %d]' % L
        assert y_val.ndim == 2, 'Expected validation labels to be of rank 2'
        assert y_val.shape[1] == L, 'Expected y_val to be [-1, %d]' % L

        # Numpy will throw asserts for arrays
        if sess is None:
            raise ValueError('sess must be valid Tensorflow session.')

        trainNumBatches = int(np.ceil(len(x_train) / batchSize))
        valNumBatches = int(np.ceil(len(x_val) / batchSize))
        x_train_batches = np.array_split(x_train, trainNumBatches)
        y_train_batches = np.array_split(y_train, trainNumBatches)
        x_val_batches = np.array_split(x_val, valNumBatches)
        y_val_batches = np.array_split(y_val, valNumBatches)
        if not noInit:
            sess.run(tf.global_variables_initializer())
        X, Y = self.X, self.Y
        W, B, Z, _ = self.protoNNObj.getModelMatrices()
        for epoch in range(totalEpochs):
            for i in range(len(x_train_batches)):
                batch_x = x_train_batches[i]
                batch_y = y_train_batches[i]
                feed_dict = {
                    X: batch_x,
                    Y: batch_y
                }
                sess.run(self.trainStep, feed_dict=feed_dict)
                if i % printStep == 0:
                    loss, acc = sess.run([self.loss, self.accuracy],
                                         feed_dict=feed_dict)
                    msg = "Epoch: %3d Batch: %3d" % (epoch, i)
                    msg += " Loss: %3.5f Accuracy: %2.5f" % (loss, acc)
                    print(msg, file=redirFile)

            # Perform Hard thresholding
            if self.sparseTraining:
                W_, B_, Z_ = sess.run([W, B, Z])
                fd_thrsd = {
                    self.W_th: hardThreshold(W_, self.__sW),
                    self.B_th: hardThreshold(B_, self.__sB),
                    self.Z_th: hardThreshold(Z_, self.__sZ)
                }
                sess.run(self.__hthOp, feed_dict=fd_thrsd)

            if (epoch + 1) % valStep  == 0:
                acc = 0.0
                loss = 0.0
                for j in range(len(x_val_batches)):
                    batch_x = x_val_batches[j]
                    batch_y = y_val_batches[j]
                    feed_dict = {
                        X: batch_x,
                        Y: batch_y
                    }
                    acc_, loss_ = sess.run([self.accuracy, self.loss],
                                           feed_dict=feed_dict)
                    acc += acc_
                    loss += loss_
                acc /= len(y_val_batches)
                loss /= len(y_val_batches)
                print("Test Loss: %2.5f Accuracy: %2.5f" % (loss, acc))


# Model Graph - ProtoNN

In [5]:

class ProtoNN:
    def __init__(self, inputDimension, projectionDimension, numPrototypes,
                 numOutputLabels, gamma,
                 W = None, B = None, Z = None):
        '''
        Forward computation graph for ProtoNN.

        inputDimension: Input data dimension or feature dimension.
        projectionDimension: hyperparameter
        numPrototypes: hyperparameter
        numOutputLabels: The number of output labels or classes
        W, B, Z: Numpy matrices that can be used to initialize
            projection matrix(W), prototype matrix (B) and prototype labels
            matrix (B).
            Expected Dimensions:
                W   inputDimension (d) x projectionDimension (d_cap)
                B   projectionDimension (d_cap) x numPrototypes (m)
                Z   numOutputLabels (L) x numPrototypes (m)
        '''
        with tf.name_scope('protoNN') as ns:
            self.__nscope = ns
        self.__d = inputDimension
        self.__d_cap = projectionDimension
        self.__m = numPrototypes
        self.__L = numOutputLabels

        self.__inW = W
        self.__inB = B
        self.__inZ = Z
        self.__inGamma = gamma
        self.W, self.B, self.Z = None, None, None
        self.gamma = None

        self.__validInit = False
        self.__initWBZ()
        self.__initGamma()
        self.__validateInit()
        self.protoNNOut = None
        self.predictions = None
        self.accuracy = None

    def __validateInit(self):
        self.__validInit = False
        errmsg = "Dimensions mismatch! Should be W[d, d_cap]"
        errmsg += ", B[d_cap, m] and Z[L, m]"
        d, d_cap, m, L, _ = self.getHyperParams()
        assert self.W.shape[0] == d, errmsg
        assert self.W.shape[1] == d_cap, errmsg
        assert self.B.shape[0] == d_cap, errmsg
        assert self.B.shape[1] == m, errmsg
        assert self.Z.shape[0] == L, errmsg
        assert self.Z.shape[1] == m, errmsg
        self.__validInit = True

    def __initWBZ(self):
        with tf.name_scope(self.__nscope):
            W = self.__inW
            if W is None:
                W = tf.random_normal_initializer()
                W = W([self.__d, self.__d_cap])
            self.W = tf.Variable(W, name='W', dtype=tf.float32)

            B = self.__inB
            if B is None:
                B = tf.random_uniform_initializer()
                B = B([self.__d_cap, self.__m])
            self.B = tf.Variable(B, name='B', dtype=tf.float32)

            Z = self.__inZ
            if Z is None:
                Z = tf.random_normal_initializer()
                Z = Z([self.__L, self.__m])
            Z = tf.Variable(Z, name='Z', dtype=tf.float32)
            self.Z = Z
        return self.W, self.B, self.Z

    def __initGamma(self):
        with tf.name_scope(self.__nscope):
            gamma = self.__inGamma
            self.gamma = tf.constant(gamma, name='gamma')

    def getHyperParams(self):
        '''
        Returns the model hyperparameters:
            [inputDimension, projectionDimension,
            numPrototypes, numOutputLabels, gamma]
        '''
        d = self.__d
        dcap = self.__d_cap
        m = self.__m
        L = self.__L
        return d, dcap, m, L, self.gamma

    def getModelMatrices(self):
        '''
        Returns Tensorflow tensors of the model matrices, which
        can then be evaluated to obtain corresponding numpy arrays.

        These can then be exported as part of other implementations of
        ProtonNN, for instance a C++ implementation or pure python
        implementation.
        Returns
            [ProjectionMatrix (W), prototypeMatrix (B),
             prototypeLabelsMatrix (Z), gamma]
        '''
        return self.W, self.B, self.Z, self.gamma

    def __call__(self, X, Y=None):
        '''
        This method is responsible for construction of the forward computation
        graph. The end point of the computation graph, or in other words the
        output operator for the forward computation is returned. Additionally,
        if the argument Y is provided, a classification accuracy operator with
        Y as target will also be created. For this, Y is assumed to in one-hot
        encoded format and the class with the maximum prediction score is
        compared to the encoded class in Y.  This accuracy operator is returned
        by getAccuracyOp() method. If a different accuracyOp is required, it
        can be defined by overriding the createAccOp(protoNNScoresOut, Y)
        method.

        X: Input tensor or placeholder of shape [-1, inputDimension]
        Y: Optional tensor or placeholder for targets (labels or classes).
            Expected shape is [-1, numOutputLabels].
        returns: The forward computation outputs, self.protoNNOut
        '''
        # This should never execute
        assert self.__validInit is True, "Initialization failed!"
        if self.protoNNOut is not None:
            return self.protoNNOut

        W, B, Z, gamma = self.W, self.B, self.Z, self.gamma
        with tf.name_scope(self.__nscope):
            WX = tf.matmul(X, W)
            # Convert WX to tensor so that broadcasting can work
            dim = [-1, WX.shape.as_list()[1], 1]
            WX = tf.reshape(WX, dim)
            dim = [1, B.shape.as_list()[0], -1]
            B = tf.reshape(B, dim)
            l2sim = B - WX
            l2sim = tf.pow(l2sim, 2)
            l2sim = tf.reduce_sum(l2sim, 1, keepdims=True)
            self.l2sim = l2sim
            gammal2sim = (-1 * gamma * gamma) * l2sim
            M = tf.exp(gammal2sim)
            dim = [1] + Z.shape.as_list()
            Z = tf.reshape(Z, dim)
            y = tf.multiply(Z, M)
            y = tf.reduce_sum(y, 2, name='protoNNScoreOut')
            self.protoNNOut = y
            self.predictions = tf.argmax(y, 1, name='protoNNPredictions')
            if Y is not None:
                self.createAccOp(self.protoNNOut, Y)
        return y

    def createAccOp(self, outputs, target):
        '''
        Define an accuracy operation on ProtoNN's output scores and targets.
        Here a simple classification accuracy operator is defined. More
        complicated operators (for multiple label problems and so forth) can be
        defined by overriding this method
        '''
        assert self.predictions is not None
        target = tf.argmax(target, 1)
        correctPrediction = tf.equal(self.predictions, target)
        acc = tf.reduce_mean(tf.cast(correctPrediction, tf.float32),
                             name='protoNNAccuracy')
        self.accuracy = acc

    def getPredictionsOp(self):
        '''
        The predictions operator is defined as argmax(protoNNScores) for each
        prediction.
        '''
        return self.predictions

    def getAccuracyOp(self):
        '''
        returns accuracyOp as defined by createAccOp. It defaults to
        multi-class classification accuracy.
        '''
        msg = "Accuracy operator not defined in graph. Did you provide Y as an"
        msg += " argument to _call_?"
        assert self.accuracy is not None, msg
        return self.accuracy

# Obtain Data

It is assumed that the Daphnet data has already been downloaded,preprocessed and set up in subdirectory.

In [6]:
DATA_DIR = r"./experiments"
windowLen = 'data'
out = preprocessData(DATA_DIR,windowLen)
dataDimension = out[0]
numClasses = out[1]
x_train, y_train = out[2], out[3]
x_test, y_test = out[4], out[5]
print("Feature Dimension: ", dataDimension)
print("Num classes: ", numClasses)

Feature Dimension:  423
Num classes:  2


In [7]:
DATA_DIR = r"./experiments"
train, test = np.load(DATA_DIR + '/ttrain_data.npy'), np.load(DATA_DIR + '/ttest_data.npy')
x_train, y_train = train[:, 1:], train[:, 0]
x_test, y_test = test[:, 1:], test[:, 0]

numClasses = max(y_train) - min(y_train) + 1
numClasses = max(numClasses, max(y_test) - min(y_test) + 1)
numClasses = int(numClasses)

y_train = helper.to_onehot(y_train, numClasses)
y_test = helper.to_onehot(y_test, numClasses)
dataDimension = x_train.shape[1]
numClasses = y_train.shape[1]

# Model Parameters

Note that ProtoNN is very sensitive to the value of the hyperparameter $\gamma$, here stored in valiable `GAMMA`. If `GAMMA` is set to `None`, median heuristic will be used to estimate a good value of $\gamma$ through the `helper.getGamma()` method. This method also returns the corresponding `W` and `B` matrices which should be used to initialize ProtoNN (as is done here).

In [8]:
PROJECTION_DIM = 5 #d^
NUM_PROTOTYPES = 40 #m
REG_W = 0.000005
REG_B = 0.0
REG_Z = 0.00005
SPAR_W = 1.0
SPAR_B = 0.8
SPAR_Z = 0.8
LEARNING_RATE = 0.001
NUM_EPOCHS = 600
BATCH_SIZE = 32
GAMMA = 0.007586

In [9]:
W, B, gamma = getGamma(GAMMA, PROJECTION_DIM, dataDimension,
                       NUM_PROTOTYPES, x_train)

In [10]:
X = tf.placeholder(tf.float32, [None, dataDimension], name='X')
Y = tf.placeholder(tf.float32, [None, numClasses], name='Y')

In [11]:
from sklearn.model_selection import cross_val_score
from sklearn.metrics import confusion_matrix,classification_report
from functools import partial

X = tf.placeholder(tf.float32, [None, dataDimension], name='X')
Y = tf.placeholder(tf.float32, [None, numClasses], name='Y')
def objective(trial,x_train, x_test, y_train, y_test):
    W, B, gamma = getGamma(GAMMA, PROJECTION_DIM, dataDimension,
                       NUM_PROTOTYPES, x_train)
    # Inside the optimization function, you use the 'trial' object to suggest hyperparameters
    REG_W = trial.suggest_float('REG_W', 2e-6, 5e-6)
    REG_B = trial.suggest_float('REG_B', 0.0, 0.01)
    REG_Z = trial.suggest_float('REG_Z', 2e-5, 5e-5)
    SPAR_W = trial.suggest_float('SPAR_W', 0.5, 1.0)
    SPAR_B = trial.suggest_float('SPAR_B', 0.5, 1.0)
    SPAR_Z = trial.suggest_float('SPAR_Z', 0.5, 1.0)
    LEARNING_RATE = trial.suggest_float('LEARNING_RATE', 1e-4, 1e-3)
    NUM_EPOCHS = trial.suggest_int('NUM_EPOCHS', 200, 600)

    # Set the suggested hyperparameters in the trainer
    protoNN = ProtoNN(dataDimension, PROJECTION_DIM,
                  NUM_PROTOTYPES, numClasses,
                  gamma, W=W, B=B)
    trainer = ProtoNNTrainer(protoNN, REG_W, REG_B, REG_Z,
                         SPAR_W, SPAR_B, SPAR_Z,
                         LEARNING_RATE, X, Y, lossType='xentropy')
    # Call your ProtoNN trainer function or use it as needed
    sess = tf.Session()

    trainer.train(BATCH_SIZE, NUM_EPOCHS, sess, x_train, x_test, y_train, y_test,printStep=600, valStep=10)
    acc = sess.run(protoNN.accuracy, feed_dict={X: x_test, Y: y_test})
    pred = sess.run(protoNN.predictions, feed_dict={X: x_test, Y: y_test})
    # W, B, Z are tensorflow graph nodes
    W, B, Z, _ = protoNN.getModelMatrices()
    matrixList = sess.run([W, B, Z])
    sparcityList = [SPAR_W, SPAR_B, SPAR_Z]                       
    nnz, size, sparse = getModelSize(matrixList, sparcityList)
    y_test = np.argmax(y_test,axis=1)
    sensitivity = confusion_matrix(y_test,pred)[1][1]/(confusion_matrix(y_test,pred)[1][1] + confusion_matrix(y_test,pred)[1][0])
    specificity = confusion_matrix(y_test,pred)[0][0]/(confusion_matrix(y_test,pred)[0][0] + confusion_matrix(y_test,pred)[0][1])
    return (sensitivity+specificity)/2



In [12]:
import optuna
study = optuna.create_study(direction='maximize')


[I 2023-11-17 09:40:02,802] A new study created in memory with name: no-name-58267e5c-7498-40f8-9415-5f6bbf07060c


In [13]:
op_fun = partial(objective,x_train=x_train, x_test=x_test, y_train=y_train, y_test=y_test)
study.optimize(op_fun,n_trials=20)

Epoch:   0 Batch:   0 Loss: 0.18710 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.50604 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 2.22825 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.40010 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.44677 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.46837 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.47680 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.47562 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.46665 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.45089 Accuracy: 0.00000
Test Loss: 1.26358 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.42958 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.40142 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.36982 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.33277 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.29247 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.24897 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.20095 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.15023 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.79672 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.79645 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.79620 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.79594 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.79570 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.79545 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.79523 Accuracy: 0.00000
Test Loss: 0.71058 Accuracy: 0.50060
Epoch: 150 Batch:   0 Loss: 0.79504 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.79483 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.79464 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.79444 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.79425 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.79409 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.79392 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.79375 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.79356 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.79340 Accuracy: 0.00000
Test Loss: 0.70650 Accuracy: 0.50060
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.77308 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.77249 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.77193 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.77138 Accuracy: 0.00000
Test Loss: 0.69292 Accuracy: 0.50050
Epoch: 290 Batch:   0 Loss: 0.77088 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.77039 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.76993 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.76949 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.76907 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.76867 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.76827 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.76791 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.76756 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.76722 Accuracy: 0.00000
Test Loss: 0.69396 Accuracy: 0.50050
Epoch: 300 Batch:   0 Loss: 0.76690 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.76661 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.76632 Accuracy: 0.00000
Epoch: 3

[I 2023-11-17 09:43:02,041] Trial 0 finished with value: 0.5003015681544029 and parameters: {'REG_W': 4.4063589336745566e-06, 'REG_B': 0.00550086588574879, 'REG_Z': 2.738256523855445e-05, 'SPAR_W': 0.7924326708349763, 'SPAR_B': 0.9252195803215768, 'SPAR_Z': 0.9738285286612246, 'LEARNING_RATE': 0.0005235324510491807, 'NUM_EPOCHS': 311}. Best is trial 0 with value: 0.5003015681544029.


Epoch:   0 Batch:   0 Loss: 0.23718 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 2.83710 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.79966 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.79775 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.79150 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.77877 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.76072 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.73840 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.71271 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.68440 Accuracy: 0.00000
Test Loss: 1.38927 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.65402 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.62198 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.58844 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.55356 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.51738 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.48008 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.44126 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.39951 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.81399 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.81382 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.81366 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.81348 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.81331 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.81320 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.81301 Accuracy: 0.00000
Test Loss: 0.76047 Accuracy: 0.50070
Epoch: 150 Batch:   0 Loss: 0.81281 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.81261 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.81242 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.81225 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.81209 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.81193 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.81178 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.81169 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.81153 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.81142 Accuracy: 0.00000
Test Loss: 0.75660 Accuracy: 0.50070
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.78019 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.78001 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.77984 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.77967 Accuracy: 0.00000
Test Loss: 0.74181 Accuracy: 0.50040
Epoch: 290 Batch:   0 Loss: 0.77950 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.77934 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.77918 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.77903 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.77888 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.77873 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.77859 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.77845 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.77832 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.77819 Accuracy: 0.00000
Test Loss: 0.74275 Accuracy: 0.50040
Epoch: 300 Batch:   0 Loss: 0.77807 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.77795 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.77784 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.77178 Accuracy: 0.00000
Test Loss: 0.75307 Accuracy: 0.50060
Epoch: 430 Batch:   0 Loss: 0.77175 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.77172 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.77170 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.77167 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.77165 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.77162 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.77160 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.77157 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.77155 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.77153 Accuracy: 0.00000
Test Loss: 0.75357 Accuracy: 0.50060
Epoch: 440 Batch:   0 Loss: 0.77151 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.77149 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.77147 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.77145 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.77143 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.77141 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 09:48:31,299] Trial 1 finished with value: 0.5004020908725372 and parameters: {'REG_W': 4.989016392035679e-06, 'REG_B': 0.006285881469037077, 'REG_Z': 4.620919544050023e-05, 'SPAR_W': 0.5580069993122618, 'SPAR_B': 0.9961478147253725, 'SPAR_Z': 0.8006718659290757, 'LEARNING_RATE': 0.0006325314401711697, 'NUM_EPOCHS': 495}. Best is trial 1 with value: 0.5004020908725372.


Epoch:   0 Batch:   0 Loss: 3.20639 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 3.24029 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.98119 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.89997 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.86077 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.82731 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.78818 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.74154 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.68951 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.63373 Accuracy: 0.00000
Test Loss: 1.46175 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.57327 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.50877 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.44131 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.37155 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.29998 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.22805 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.15471 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.08098 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77240 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.77231 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.77220 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.77209 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.77198 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.77187 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.77177 Accuracy: 0.00000
Test Loss: 0.74330 Accuracy: 0.50100
Epoch: 150 Batch:   0 Loss: 0.77167 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.77160 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.77152 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.77142 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.77132 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.77123 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.77117 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.77109 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.77102 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.77096 Accuracy: 0.00000
Test Loss: 0.73930 Accuracy: 0.50100
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.74867 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.74858 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.74849 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.74839 Accuracy: 0.00000
Test Loss: 0.72228 Accuracy: 0.50080
Epoch: 290 Batch:   0 Loss: 0.74830 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.74821 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.74812 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.74803 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.74795 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.74786 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.74778 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.74770 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.74763 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.74755 Accuracy: 0.00000
Test Loss: 0.72458 Accuracy: 0.50090
Epoch: 300 Batch:   0 Loss: 0.74748 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.74741 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.74734 Accuracy: 0.00000
Epoch: 3

[I 2023-11-17 09:51:26,586] Trial 2 finished with value: 0.5008041817450744 and parameters: {'REG_W': 2.6045101624763913e-06, 'REG_B': 0.0008140312513964287, 'REG_Z': 2.905458883387219e-05, 'SPAR_W': 0.7803300030330671, 'SPAR_B': 0.5086272015109017, 'SPAR_Z': 0.6420072077306945, 'LEARNING_RATE': 0.000710503022728103, 'NUM_EPOCHS': 328}. Best is trial 2 with value: 0.5008041817450744.


Epoch:   0 Batch:   0 Loss: 0.31497 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.18979 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 0.91891 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.00794 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.25683 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.32840 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.36272 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.38367 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.39622 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.40205 Accuracy: 0.00000
Test Loss: 1.20858 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.40230 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.39724 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.38771 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.37483 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.35785 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.33712 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.31276 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.28374 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.79946 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.79907 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.79869 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.79831 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.79793 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.79757 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.79721 Accuracy: 0.00000
Test Loss: 0.71377 Accuracy: 0.50080
Epoch: 150 Batch:   0 Loss: 0.79685 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.79650 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.79617 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.79583 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.79551 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.79516 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.79484 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.79453 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.79422 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.79391 Accuracy: 0.00000
Test Loss: 0.70987 Accuracy: 0.50080
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.77899 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.77868 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.77837 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.77807 Accuracy: 0.00000
Test Loss: 0.69874 Accuracy: 0.50080
Epoch: 290 Batch:   0 Loss: 0.77779 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.77752 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.77726 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.77702 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.77680 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.77660 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.77642 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.77626 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.77614 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.77603 Accuracy: 0.00000
Test Loss: 0.69762 Accuracy: 0.50080
Epoch: 300 Batch:   0 Loss: 0.77594 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.77588 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.77583 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.75198 Accuracy: 0.00000
Test Loss: 0.71400 Accuracy: 0.50090
Epoch: 430 Batch:   0 Loss: 0.75192 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.75186 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.75180 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.75174 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.75169 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.75163 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.75158 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.75152 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.75147 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.75141 Accuracy: 0.00000
Test Loss: 0.71602 Accuracy: 0.50090
Epoch: 440 Batch:   0 Loss: 0.75136 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.75131 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.75126 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.75121 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.75116 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.75112 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 09:59:07,840] Trial 3 finished with value: 0.5007036590269401 and parameters: {'REG_W': 2.6161397357985874e-06, 'REG_B': 0.009477218367191434, 'REG_Z': 2.6137996213348635e-05, 'SPAR_W': 0.7924278818592528, 'SPAR_B': 0.9869714957831939, 'SPAR_Z': 0.8812742144423291, 'LEARNING_RATE': 0.00047823494417497005, 'NUM_EPOCHS': 505}. Best is trial 2 with value: 0.5008041817450744.


Epoch:   0 Batch:   0 Loss: 3.32370 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 4.83086 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 3.15451 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.50528 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.26500 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.18157 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.15504 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.14638 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.14021 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.13123 Accuracy: 0.00000
Test Loss: 1.05755 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.11931 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.10343 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.08405 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.06258 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.03798 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.01102 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.98309 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.95288 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.81594 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.81496 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.81399 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.81301 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.81206 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.81112 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.81021 Accuracy: 0.00000
Test Loss: 0.68947 Accuracy: 0.50070
Epoch: 150 Batch:   0 Loss: 0.80931 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.80843 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.80756 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.80670 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.80586 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.80505 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.80424 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.80347 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.80271 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.80198 Accuracy: 0.00000
Test Loss: 0.68555 Accuracy: 0.50070
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.76730 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.76720 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.76710 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.76700 Accuracy: 0.00000
Test Loss: 0.66639 Accuracy: 0.50060
Epoch: 290 Batch:   0 Loss: 0.76691 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.76682 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.76673 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.76663 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.76654 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.76645 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.76636 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.76626 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.76617 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.76607 Accuracy: 0.00000
Test Loss: 0.66647 Accuracy: 0.50070
Epoch: 300 Batch:   0 Loss: 0.76598 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.76589 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.76580 Accuracy: 0.00000
Epoch: 3

[I 2023-11-17 10:02:31,014] Trial 4 finished with value: 0.5004020908725372 and parameters: {'REG_W': 3.7655374124460386e-06, 'REG_B': 0.005034276869007759, 'REG_Z': 3.003075871217921e-05, 'SPAR_W': 0.8573520329247419, 'SPAR_B': 0.6245691349325497, 'SPAR_Z': 0.5970282447732955, 'LEARNING_RATE': 0.00038398908555143177, 'NUM_EPOCHS': 374}. Best is trial 2 with value: 0.5008041817450744.


Epoch:   0 Batch:   0 Loss: 3.35493 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 1.62445 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.43811 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.32024 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.24843 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.20529 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.17949 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.16401 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.15458 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.14864 Accuracy: 0.00000
Test Loss: 0.75496 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.14467 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.14178 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.13945 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.13741 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.13548 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.13360 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.13155 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.12951 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.82695 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.82556 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.82419 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.82281 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.82149 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.82016 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.81887 Accuracy: 0.03125
Test Loss: 0.71722 Accuracy: 0.56209
Epoch: 150 Batch:   0 Loss: 0.81761 Accuracy: 0.03125
Epoch: 151 Batch:   0 Loss: 0.81637 Accuracy: 0.03125
Epoch: 152 Batch:   0 Loss: 0.81509 Accuracy: 0.03125
Epoch: 153 Batch:   0 Loss: 0.81385 Accuracy: 0.03125
Epoch: 154 Batch:   0 Loss: 0.81263 Accuracy: 0.03125
Epoch: 155 Batch:   0 Loss: 0.81143 Accuracy: 0.03125
Epoch: 156 Batch:   0 Loss: 0.81023 Accuracy: 0.03125
Epoch: 157 Batch:   0 Loss: 0.80905 Accuracy: 0.03125
Epoch: 158 Batch:   0 Loss: 0.80787 Accuracy: 0.06250
Epoch: 159 Batch:   0 Loss: 0.80669 Accuracy: 0.06250
Test Loss: 0.71652 Accuracy: 0.56983
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.72429 Accuracy: 0.90625
Epoch: 287 Batch:   0 Loss: 0.72391 Accuracy: 0.90625
Epoch: 288 Batch:   0 Loss: 0.72354 Accuracy: 0.90625
Epoch: 289 Batch:   0 Loss: 0.72317 Accuracy: 0.90625
Test Loss: 0.71203 Accuracy: 0.64338
Epoch: 290 Batch:   0 Loss: 0.72280 Accuracy: 0.90625
Epoch: 291 Batch:   0 Loss: 0.72244 Accuracy: 0.90625
Epoch: 292 Batch:   0 Loss: 0.72208 Accuracy: 0.90625
Epoch: 293 Batch:   0 Loss: 0.72172 Accuracy: 0.90625
Epoch: 294 Batch:   0 Loss: 0.72136 Accuracy: 0.93750
Epoch: 295 Batch:   0 Loss: 0.72100 Accuracy: 0.93750
Epoch: 296 Batch:   0 Loss: 0.72064 Accuracy: 0.93750
Epoch: 297 Batch:   0 Loss: 0.72029 Accuracy: 0.96875
Epoch: 298 Batch:   0 Loss: 0.71994 Accuracy: 0.96875
Epoch: 299 Batch:   0 Loss: 0.71959 Accuracy: 1.00000
Test Loss: 0.71164 Accuracy: 0.64630
Epoch: 300 Batch:   0 Loss: 0.71924 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.71890 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.71856 Accuracy: 1.00000
Epoch: 3

[I 2023-11-17 10:06:12,118] Trial 5 finished with value: 0.6766184157619622 and parameters: {'REG_W': 3.979090600968805e-06, 'REG_B': 0.0007549210782041626, 'REG_Z': 2.1754592187918303e-05, 'SPAR_W': 0.70713793643758, 'SPAR_B': 0.6708918841540703, 'SPAR_Z': 0.5049920499452771, 'LEARNING_RATE': 0.00011383022856341377, 'NUM_EPOCHS': 425}. Best is trial 5 with value: 0.6766184157619622.


Epoch:   0 Batch:   0 Loss: 0.26583 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 5.43781 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 4.34331 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 3.87133 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 3.67692 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 3.60126 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 3.57393 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 3.56401 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 3.55835 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 3.55158 Accuracy: 0.00000
Test Loss: 1.68333 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 3.54145 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 3.52679 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 3.50764 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 3.48229 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 3.45207 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 3.41548 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 3.37315 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 3.32212 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.76429 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.76419 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.76414 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.76406 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.76398 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.76391 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.76382 Accuracy: 0.00000
Test Loss: 0.78558 Accuracy: 0.50111
Epoch: 150 Batch:   0 Loss: 0.76375 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.76367 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.76362 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.76356 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.76353 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.76354 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.76352 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.76351 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.76352 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.76350 Accuracy: 0.00000
Test Loss: 0.77752 Accuracy: 0.50111
Epoch: 1

[I 2023-11-17 10:08:27,692] Trial 6 finished with value: 0.5006031363088058 and parameters: {'REG_W': 2.0528350097010382e-06, 'REG_B': 0.006829258612051157, 'REG_Z': 4.787300524527533e-05, 'SPAR_W': 0.5443711183252911, 'SPAR_B': 0.6858431200088022, 'SPAR_Z': 0.5348947280603884, 'LEARNING_RATE': 0.0008885332323979621, 'NUM_EPOCHS': 257}. Best is trial 5 with value: 0.6766184157619622.


Epoch:   0 Batch:   0 Loss: 0.10138 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 2.15106 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.79645 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 3.11885 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 3.25503 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 3.29441 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 3.28624 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 3.25507 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 3.21394 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 3.16907 Accuracy: 0.00000
Test Loss: 1.61415 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 3.12352 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 3.07332 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 3.02088 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.96194 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.89383 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.81728 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.72952 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.62992 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77236 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.77255 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.77280 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.77312 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.77351 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.77397 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.77447 Accuracy: 0.00000
Test Loss: 0.73116 Accuracy: 0.50121
Epoch: 150 Batch:   0 Loss: 0.77496 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.77538 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.77561 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.77562 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.77540 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.77499 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.77444 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.77378 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.77308 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.77237 Accuracy: 0.00000
Test Loss: 0.72754 Accuracy: 0.50121
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.74030 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.74027 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.74024 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.74020 Accuracy: 0.00000
Test Loss: 0.73449 Accuracy: 0.50121
Epoch: 290 Batch:   0 Loss: 0.74016 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.74013 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.74009 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.74005 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.74001 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.73997 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.73992 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.73989 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.73986 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.73981 Accuracy: 0.00000
Test Loss: 0.73737 Accuracy: 0.50121
Epoch: 300 Batch:   0 Loss: 0.73976 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.73972 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.73968 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.73716 Accuracy: 0.00000
Test Loss: 0.77556 Accuracy: 0.50121
Epoch: 430 Batch:   0 Loss: 0.73714 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.73713 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.73712 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.73711 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.73710 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.73709 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.73708 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.73707 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.73706 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.73706 Accuracy: 0.00000
Test Loss: 0.77826 Accuracy: 0.50121
Epoch: 440 Batch:   0 Loss: 0.73705 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.73705 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.73705 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.73704 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.73703 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.73703 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 10:12:33,292] Trial 7 finished with value: 0.501005227181343 and parameters: {'REG_W': 2.065814770435679e-06, 'REG_B': 0.0027971383710636054, 'REG_Z': 2.2493241430179776e-05, 'SPAR_W': 0.7398953282035616, 'SPAR_B': 0.9512823691406264, 'SPAR_Z': 0.5377375024841398, 'LEARNING_RATE': 0.0009050175708167293, 'NUM_EPOCHS': 468}. Best is trial 5 with value: 0.6766184157619622.


Epoch:   0 Batch:   0 Loss: 0.06252 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 2.26880 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.57547 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.72824 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.79466 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.81565 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.81334 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.79865 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.77735 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.75240 Accuracy: 0.00000
Test Loss: 1.38972 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.72536 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.69725 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.66853 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.63918 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.60954 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.58056 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.55132 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.52101 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77979 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.78062 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.78141 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.78216 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.78290 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.78364 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.78436 Accuracy: 0.00000
Test Loss: 0.76561 Accuracy: 0.50070
Epoch: 150 Batch:   0 Loss: 0.78508 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.78579 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.78648 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.78713 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.78787 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.78856 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.78925 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.78992 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.79060 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.79125 Accuracy: 0.00000
Test Loss: 0.75479 Accuracy: 0.50070
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.78064 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.78047 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.78030 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.78012 Accuracy: 0.00000
Test Loss: 0.72768 Accuracy: 0.50060
Epoch: 290 Batch:   0 Loss: 0.77995 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.77977 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.77961 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.77945 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.77928 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.77912 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.77897 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.77883 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.77869 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.77855 Accuracy: 0.00000
Test Loss: 0.72884 Accuracy: 0.50060
Epoch: 300 Batch:   0 Loss: 0.77842 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.77829 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.77817 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.76194 Accuracy: 0.00000
Test Loss: 0.73069 Accuracy: 0.50040
Epoch: 430 Batch:   0 Loss: 0.76173 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.76151 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.76130 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.76110 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.76091 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.76073 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.76056 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.76039 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.76022 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.76007 Accuracy: 0.00000
Test Loss: 0.73294 Accuracy: 0.50040
Epoch: 440 Batch:   0 Loss: 0.75992 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.75978 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.75964 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.75951 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.75939 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.75927 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 10:16:44,997] Trial 8 finished with value: 0.5002010454362686 and parameters: {'REG_W': 3.817346363815336e-06, 'REG_B': 0.0010276432618562658, 'REG_Z': 3.062080140298306e-05, 'SPAR_W': 0.5598304338719655, 'SPAR_B': 0.7150516014943187, 'SPAR_Z': 0.5354972341993837, 'LEARNING_RATE': 0.0006150948015361265, 'NUM_EPOCHS': 478}. Best is trial 5 with value: 0.6766184157619622.


Epoch:   0 Batch:   0 Loss: 1.57460 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 3.76517 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 3.55688 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 3.44636 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 3.35841 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 3.27620 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 3.19699 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 3.11884 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 3.03908 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.95463 Accuracy: 0.00000
Test Loss: 1.69742 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.86491 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.76644 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.66131 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.54549 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.42377 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.29167 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.15594 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.01169 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.79596 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.79424 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.79264 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.79118 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.78981 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.78857 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.78739 Accuracy: 0.00000
Test Loss: 0.75415 Accuracy: 0.50080
Epoch: 150 Batch:   0 Loss: 0.78631 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.78528 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.78434 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.78350 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.78268 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.78193 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.78123 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.78058 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.77998 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.77941 Accuracy: 0.00000
Test Loss: 0.75594 Accuracy: 0.50080
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.76300 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.76298 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.76295 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.76292 Accuracy: 0.00000
Test Loss: 0.78517 Accuracy: 0.50060
Epoch: 290 Batch:   0 Loss: 0.76291 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.76289 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.76288 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.76286 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.76285 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.76284 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.76282 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.76281 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.76279 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.76278 Accuracy: 0.00000
Test Loss: 0.78717 Accuracy: 0.50060
Epoch: 300 Batch:   0 Loss: 0.76276 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.76276 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.76275 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.76187 Accuracy: 0.00000
Test Loss: 0.80943 Accuracy: 0.50060
Epoch: 430 Batch:   0 Loss: 0.76186 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.76185 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.76184 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.76183 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.76182 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.76182 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.76181 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.76180 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.76180 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.76179 Accuracy: 0.00000
Test Loss: 0.81085 Accuracy: 0.50060
Epoch: 440 Batch:   0 Loss: 0.76178 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.76177 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.76177 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.76176 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.76175 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.76175 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 10:22:49,477] Trial 9 finished with value: 0.5004020908725372 and parameters: {'REG_W': 4.868350288107119e-06, 'REG_B': 0.007451850465311245, 'REG_Z': 4.6738002974418115e-05, 'SPAR_W': 0.8302334538032026, 'SPAR_B': 0.9143803036852067, 'SPAR_Z': 0.8810790109315592, 'LEARNING_RATE': 0.0009517824797391942, 'NUM_EPOCHS': 568}. Best is trial 5 with value: 0.6766184157619622.


Epoch:   0 Batch:   0 Loss: 2.89099 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 1.90194 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.53739 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.35860 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.27392 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.23346 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.21331 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.20234 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.19542 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.19026 Accuracy: 0.00000
Test Loss: 0.77289 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.18588 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.18190 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.17812 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.17456 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.17114 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.16797 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.16497 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.16215 Accuracy:

Epoch: 143 Batch:   0 Loss: 1.05025 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 1.04682 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 1.04341 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 1.04004 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 1.03668 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 1.03335 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 1.03005 Accuracy: 0.00000
Test Loss: 0.68085 Accuracy: 0.50030
Epoch: 150 Batch:   0 Loss: 1.02677 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 1.02352 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 1.02028 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 1.01708 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 1.01390 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 1.01075 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 1.00763 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 1.00455 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 1.00150 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.99847 Accuracy: 0.00000
Test Loss: 0.67880 Accuracy: 0.50030
Epoch: 1

[I 2023-11-17 10:24:49,189] Trial 10 finished with value: 0.5003015681544029 and parameters: {'REG_W': 3.2622571574382e-06, 'REG_B': 0.003250951059039567, 'REG_Z': 2.1267256299300327e-05, 'SPAR_W': 0.9975001651887614, 'SPAR_B': 0.7955717878438994, 'SPAR_Z': 0.6933435937087871, 'LEARNING_RATE': 0.00011841834233424083, 'NUM_EPOCHS': 224}. Best is trial 5 with value: 0.6766184157619622.


Epoch:   0 Batch:   0 Loss: 0.09468 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 1.57375 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.47806 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.43629 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.42412 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.42447 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.42894 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.43386 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.43780 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.44045 Accuracy: 0.00000
Test Loss: 0.84859 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.44187 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.44230 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.44205 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.44128 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.44024 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.43898 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.43757 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.43608 Accuracy:

Epoch: 143 Batch:   0 Loss: 1.05844 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 1.05388 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 1.04932 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 1.04479 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 1.04025 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 1.03579 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 1.03136 Accuracy: 0.00000
Test Loss: 0.71039 Accuracy: 0.50030
Epoch: 150 Batch:   0 Loss: 1.02696 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 1.02260 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 1.01820 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 1.01380 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 1.00951 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 1.00526 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 1.00102 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.99697 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.99290 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.98889 Accuracy: 0.00000
Test Loss: 0.70462 Accuracy: 0.50030
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.78475 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.78426 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.78377 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.78330 Accuracy: 0.00000
Test Loss: 0.66586 Accuracy: 0.50040
Epoch: 290 Batch:   0 Loss: 0.78282 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.78235 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.78189 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.78143 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.78099 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.78055 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.78012 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.77969 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.77927 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.77886 Accuracy: 0.00000
Test Loss: 0.66402 Accuracy: 0.50040
Epoch: 300 Batch:   0 Loss: 0.77845 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.77804 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.77765 Accuracy: 0.00000
Epoch: 3

[I 2023-11-17 10:28:38,692] Trial 11 finished with value: 0.5002010454362686 and parameters: {'REG_W': 3.4072378154173003e-06, 'REG_B': 0.002602785572417885, 'REG_Z': 2.0170246749393997e-05, 'SPAR_W': 0.6724704939309168, 'SPAR_B': 0.8145582260082351, 'SPAR_Z': 0.517161356568251, 'LEARNING_RATE': 0.0001906247918707398, 'NUM_EPOCHS': 426}. Best is trial 5 with value: 0.6766184157619622.


Epoch:   0 Batch:   0 Loss: 11.54125 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 7.95788 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 5.49431 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 3.69179 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.68513 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.17771 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.93412 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.82319 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.77736 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.76217 Accuracy: 0.00000
Test Loss: 0.92040 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.76013 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.76242 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.76482 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.76549 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.76386 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.75996 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.75410 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.74664 Accuracy

Epoch: 143 Batch:   0 Loss: 0.84355 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.84329 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.84306 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.84285 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.84266 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.84250 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.84234 Accuracy: 0.00000
Test Loss: 0.71003 Accuracy: 0.50050
Epoch: 150 Batch:   0 Loss: 0.84221 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.84209 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.84198 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.84187 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.84176 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.84164 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.84150 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.84141 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.84132 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.84122 Accuracy: 0.00000
Test Loss: 0.70391 Accuracy: 0.50050
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.78806 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.78779 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.78751 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.78725 Accuracy: 0.00000
Test Loss: 0.66690 Accuracy: 0.50050
Epoch: 290 Batch:   0 Loss: 0.78698 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.78673 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.78647 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.78621 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.78596 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.78571 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.78547 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.78523 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.78499 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.78476 Accuracy: 0.00000
Test Loss: 0.66610 Accuracy: 0.50050
Epoch: 300 Batch:   0 Loss: 0.78454 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.78431 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.78409 Accuracy: 0.00000
Epoch: 3

[I 2023-11-17 10:32:24,736] Trial 12 finished with value: 0.5003015681544029 and parameters: {'REG_W': 4.11158743995577e-06, 'REG_B': 0.002997184575350683, 'REG_Z': 2.3415161857228485e-05, 'SPAR_W': 0.6842777308500527, 'SPAR_B': 0.8517741688183632, 'SPAR_Z': 0.5059483083458255, 'LEARNING_RATE': 0.00029404840238769303, 'NUM_EPOCHS': 422}. Best is trial 5 with value: 0.6766184157619622.


Epoch:   0 Batch:   0 Loss: 0.03627 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 1.91066 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.83649 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 3.20125 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 3.32216 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 3.34748 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 3.33565 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 3.30698 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 3.26917 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 3.22545 Accuracy: 0.00000
Test Loss: 1.49759 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 3.17393 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 3.11581 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 3.05002 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.97829 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.89900 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.81792 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.73287 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.64386 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.73935 Accuracy: 0.75000
Epoch: 144 Batch:   0 Loss: 0.73759 Accuracy: 0.75000
Epoch: 145 Batch:   0 Loss: 0.73582 Accuracy: 0.75000
Epoch: 146 Batch:   0 Loss: 0.73409 Accuracy: 0.75000
Epoch: 147 Batch:   0 Loss: 0.73235 Accuracy: 0.75000
Epoch: 148 Batch:   0 Loss: 0.73062 Accuracy: 0.75000
Epoch: 149 Batch:   0 Loss: 0.72891 Accuracy: 0.75000
Test Loss: 1.14164 Accuracy: 0.58239
Epoch: 150 Batch:   0 Loss: 0.72727 Accuracy: 0.75000
Epoch: 151 Batch:   0 Loss: 0.72560 Accuracy: 0.75000
Epoch: 152 Batch:   0 Loss: 0.72397 Accuracy: 0.75000
Epoch: 153 Batch:   0 Loss: 0.72239 Accuracy: 0.75000
Epoch: 154 Batch:   0 Loss: 0.72076 Accuracy: 0.75000
Epoch: 155 Batch:   0 Loss: 0.71915 Accuracy: 0.75000
Epoch: 156 Batch:   0 Loss: 0.71756 Accuracy: 0.75000
Epoch: 157 Batch:   0 Loss: 0.71598 Accuracy: 0.75000
Epoch: 158 Batch:   0 Loss: 0.71450 Accuracy: 0.75000
Epoch: 159 Batch:   0 Loss: 0.71302 Accuracy: 0.75000
Test Loss: 1.12222 Accuracy: 0.59143
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.55282 Accuracy: 0.96875
Epoch: 287 Batch:   0 Loss: 0.55181 Accuracy: 0.96875
Epoch: 288 Batch:   0 Loss: 0.55079 Accuracy: 0.96875
Epoch: 289 Batch:   0 Loss: 0.54976 Accuracy: 0.96875
Test Loss: 0.98141 Accuracy: 0.67061
Epoch: 290 Batch:   0 Loss: 0.54879 Accuracy: 0.96875
Epoch: 291 Batch:   0 Loss: 0.54780 Accuracy: 0.96875
Epoch: 292 Batch:   0 Loss: 0.54681 Accuracy: 0.96875
Epoch: 293 Batch:   0 Loss: 0.54582 Accuracy: 0.96875
Epoch: 294 Batch:   0 Loss: 0.54482 Accuracy: 0.96875
Epoch: 295 Batch:   0 Loss: 0.54380 Accuracy: 0.96875
Epoch: 296 Batch:   0 Loss: 0.54279 Accuracy: 0.96875
Epoch: 297 Batch:   0 Loss: 0.54175 Accuracy: 0.96875
Epoch: 298 Batch:   0 Loss: 0.54075 Accuracy: 0.96875
Epoch: 299 Batch:   0 Loss: 0.53976 Accuracy: 0.96875
Test Loss: 0.97507 Accuracy: 0.67352
Epoch: 300 Batch:   0 Loss: 0.53876 Accuracy: 0.96875
Epoch: 301 Batch:   0 Loss: 0.53777 Accuracy: 0.96875
Epoch: 302 Batch:   0 Loss: 0.53678 Accuracy: 0.96875
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.43240 Accuracy: 1.00000
Test Loss: 0.92155 Accuracy: 0.70135
Epoch: 430 Batch:   0 Loss: 0.43171 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.43102 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.43033 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.42965 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.42897 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.42829 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.42762 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.42695 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.42631 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.42567 Accuracy: 1.00000
Test Loss: 0.91860 Accuracy: 0.70336
Epoch: 440 Batch:   0 Loss: 0.42500 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.42433 Accuracy: 1.00000
Epoch: 442 Batch:   0 Loss: 0.42368 Accuracy: 1.00000
Epoch: 443 Batch:   0 Loss: 0.42303 Accuracy: 1.00000
Epoch: 444 Batch:   0 Loss: 0.42239 Accuracy: 1.00000
Epoch: 445 Batch:   0 Loss: 0.42174 Accuracy: 1.00000
Epoch: 4

[I 2023-11-17 10:37:24,632] Trial 13 finished with value: 0.7160233212706072 and parameters: {'REG_W': 3.11323756630625e-06, 'REG_B': 0.00018466641837387807, 'REG_Z': 3.370459835067725e-05, 'SPAR_W': 0.6947857066796415, 'SPAR_B': 0.7339627145414906, 'SPAR_Z': 0.6210527176469138, 'LEARNING_RATE': 0.0007713727368692295, 'NUM_EPOCHS': 552}. Best is trial 13 with value: 0.7160233212706072.


Epoch:   0 Batch:   0 Loss: 0.02725 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.30922 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 0.95055 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.42574 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.65662 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.74970 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.77781 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.77682 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.76212 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.74181 Accuracy: 0.00000
Test Loss: 0.96706 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.71853 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.69396 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.66957 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.64530 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.62187 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.59948 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.57837 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.55864 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.61052 Accuracy: 1.00000
Epoch: 144 Batch:   0 Loss: 0.60908 Accuracy: 1.00000
Epoch: 145 Batch:   0 Loss: 0.60768 Accuracy: 1.00000
Epoch: 146 Batch:   0 Loss: 0.60631 Accuracy: 1.00000
Epoch: 147 Batch:   0 Loss: 0.60495 Accuracy: 1.00000
Epoch: 148 Batch:   0 Loss: 0.60362 Accuracy: 1.00000
Epoch: 149 Batch:   0 Loss: 0.60229 Accuracy: 1.00000
Test Loss: 0.77294 Accuracy: 0.66669
Epoch: 150 Batch:   0 Loss: 0.60098 Accuracy: 1.00000
Epoch: 151 Batch:   0 Loss: 0.59968 Accuracy: 1.00000
Epoch: 152 Batch:   0 Loss: 0.59840 Accuracy: 1.00000
Epoch: 153 Batch:   0 Loss: 0.59711 Accuracy: 1.00000
Epoch: 154 Batch:   0 Loss: 0.59584 Accuracy: 1.00000
Epoch: 155 Batch:   0 Loss: 0.59462 Accuracy: 1.00000
Epoch: 156 Batch:   0 Loss: 0.59336 Accuracy: 1.00000
Epoch: 157 Batch:   0 Loss: 0.59213 Accuracy: 1.00000
Epoch: 158 Batch:   0 Loss: 0.59092 Accuracy: 1.00000
Epoch: 159 Batch:   0 Loss: 0.58973 Accuracy: 1.00000
Test Loss: 0.76764 Accuracy: 0.67261
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.49539 Accuracy: 1.00000
Epoch: 287 Batch:   0 Loss: 0.49482 Accuracy: 1.00000
Epoch: 288 Batch:   0 Loss: 0.49425 Accuracy: 1.00000
Epoch: 289 Batch:   0 Loss: 0.49369 Accuracy: 1.00000
Test Loss: 0.71839 Accuracy: 0.72506
Epoch: 290 Batch:   0 Loss: 0.49313 Accuracy: 1.00000
Epoch: 291 Batch:   0 Loss: 0.49258 Accuracy: 1.00000
Epoch: 292 Batch:   0 Loss: 0.49203 Accuracy: 1.00000
Epoch: 293 Batch:   0 Loss: 0.49148 Accuracy: 1.00000
Epoch: 294 Batch:   0 Loss: 0.49092 Accuracy: 1.00000
Epoch: 295 Batch:   0 Loss: 0.49037 Accuracy: 1.00000
Epoch: 296 Batch:   0 Loss: 0.48981 Accuracy: 1.00000
Epoch: 297 Batch:   0 Loss: 0.48925 Accuracy: 1.00000
Epoch: 298 Batch:   0 Loss: 0.48869 Accuracy: 1.00000
Epoch: 299 Batch:   0 Loss: 0.48814 Accuracy: 1.00000
Test Loss: 0.71603 Accuracy: 0.72727
Epoch: 300 Batch:   0 Loss: 0.48758 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.48703 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.48651 Accuracy: 1.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.42757 Accuracy: 1.00000
Test Loss: 0.69559 Accuracy: 0.74445
Epoch: 430 Batch:   0 Loss: 0.42717 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.42678 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.42639 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.42599 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.42560 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.42521 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.42482 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.42443 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.42404 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.42365 Accuracy: 1.00000
Test Loss: 0.69465 Accuracy: 0.74575
Epoch: 440 Batch:   0 Loss: 0.42326 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.42286 Accuracy: 1.00000
Epoch: 442 Batch:   0 Loss: 0.42247 Accuracy: 1.00000
Epoch: 443 Batch:   0 Loss: 0.42207 Accuracy: 1.00000
Epoch: 444 Batch:   0 Loss: 0.42167 Accuracy: 1.00000
Epoch: 445 Batch:   0 Loss: 0.42128 Accuracy: 1.00000
Epoch: 4

Epoch: 571 Batch:   0 Loss: 0.37723 Accuracy: 1.00000
Epoch: 572 Batch:   0 Loss: 0.37692 Accuracy: 1.00000
Epoch: 573 Batch:   0 Loss: 0.37661 Accuracy: 1.00000
Epoch: 574 Batch:   0 Loss: 0.37630 Accuracy: 1.00000
Epoch: 575 Batch:   0 Loss: 0.37600 Accuracy: 1.00000
Epoch: 576 Batch:   0 Loss: 0.37569 Accuracy: 1.00000
Epoch: 577 Batch:   0 Loss: 0.37538 Accuracy: 1.00000
Epoch: 578 Batch:   0 Loss: 0.37508 Accuracy: 1.00000
Epoch: 579 Batch:   0 Loss: 0.37479 Accuracy: 1.00000
Test Loss: 0.68646 Accuracy: 0.75559
Epoch: 580 Batch:   0 Loss: 0.37449 Accuracy: 1.00000
Epoch: 581 Batch:   0 Loss: 0.37419 Accuracy: 1.00000
Epoch: 582 Batch:   0 Loss: 0.37389 Accuracy: 1.00000
Epoch: 583 Batch:   0 Loss: 0.37358 Accuracy: 1.00000
Epoch: 584 Batch:   0 Loss: 0.37329 Accuracy: 1.00000
Epoch: 585 Batch:   0 Loss: 0.37299 Accuracy: 1.00000
Epoch: 586 Batch:   0 Loss: 0.37269 Accuracy: 1.00000
Epoch: 587 Batch:   0 Loss: 0.37239 Accuracy: 1.00000
Epoch: 588 Batch:   0 Loss: 0.37209 Accuracy:

[I 2023-11-17 10:42:47,287] Trial 14 finished with value: 0.756031363088058 and parameters: {'REG_W': 3.1545272147130644e-06, 'REG_B': 0.00025738817748430687, 'REG_Z': 3.534809578054365e-05, 'SPAR_W': 0.6622823978261199, 'SPAR_B': 0.651918182305612, 'SPAR_Z': 0.6197401480486052, 'LEARNING_RATE': 0.00029791501581002825, 'NUM_EPOCHS': 595}. Best is trial 14 with value: 0.756031363088058.


Epoch:   0 Batch:   0 Loss: 0.01370 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.95344 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.26497 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.84428 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 3.06146 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 3.13722 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 3.15652 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 3.15042 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 3.13104 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 3.10229 Accuracy: 0.00000
Test Loss: 1.51141 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 3.06513 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 3.01950 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.96723 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.90512 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.83923 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.76650 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.68622 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.60254 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.70141 Accuracy: 0.78125
Epoch: 144 Batch:   0 Loss: 0.69991 Accuracy: 0.78125
Epoch: 145 Batch:   0 Loss: 0.69838 Accuracy: 0.78125
Epoch: 146 Batch:   0 Loss: 0.69681 Accuracy: 0.78125
Epoch: 147 Batch:   0 Loss: 0.69531 Accuracy: 0.78125
Epoch: 148 Batch:   0 Loss: 0.69382 Accuracy: 0.78125
Epoch: 149 Batch:   0 Loss: 0.69234 Accuracy: 0.81250
Test Loss: 1.11215 Accuracy: 0.58359
Epoch: 150 Batch:   0 Loss: 0.69087 Accuracy: 0.81250
Epoch: 151 Batch:   0 Loss: 0.68948 Accuracy: 0.81250
Epoch: 152 Batch:   0 Loss: 0.68815 Accuracy: 0.81250
Epoch: 153 Batch:   0 Loss: 0.68685 Accuracy: 0.81250
Epoch: 154 Batch:   0 Loss: 0.68551 Accuracy: 0.81250
Epoch: 155 Batch:   0 Loss: 0.68419 Accuracy: 0.81250
Epoch: 156 Batch:   0 Loss: 0.68286 Accuracy: 0.81250
Epoch: 157 Batch:   0 Loss: 0.68153 Accuracy: 0.81250
Epoch: 158 Batch:   0 Loss: 0.68019 Accuracy: 0.81250
Epoch: 159 Batch:   0 Loss: 0.67887 Accuracy: 0.81250
Test Loss: 1.09718 Accuracy: 0.59193
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.54751 Accuracy: 0.96875
Epoch: 287 Batch:   0 Loss: 0.54671 Accuracy: 0.96875
Epoch: 288 Batch:   0 Loss: 0.54593 Accuracy: 0.96875
Epoch: 289 Batch:   0 Loss: 0.54517 Accuracy: 0.96875
Test Loss: 0.98558 Accuracy: 0.66569
Epoch: 290 Batch:   0 Loss: 0.54437 Accuracy: 0.96875
Epoch: 291 Batch:   0 Loss: 0.54358 Accuracy: 0.96875
Epoch: 292 Batch:   0 Loss: 0.54280 Accuracy: 0.96875
Epoch: 293 Batch:   0 Loss: 0.54202 Accuracy: 0.96875
Epoch: 294 Batch:   0 Loss: 0.54124 Accuracy: 0.96875
Epoch: 295 Batch:   0 Loss: 0.54047 Accuracy: 0.96875
Epoch: 296 Batch:   0 Loss: 0.53969 Accuracy: 0.96875
Epoch: 297 Batch:   0 Loss: 0.53892 Accuracy: 0.96875
Epoch: 298 Batch:   0 Loss: 0.53815 Accuracy: 0.96875
Epoch: 299 Batch:   0 Loss: 0.53738 Accuracy: 0.96875
Test Loss: 0.98070 Accuracy: 0.66850
Epoch: 300 Batch:   0 Loss: 0.53662 Accuracy: 0.96875
Epoch: 301 Batch:   0 Loss: 0.53586 Accuracy: 0.96875
Epoch: 302 Batch:   0 Loss: 0.53510 Accuracy: 0.96875
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.45726 Accuracy: 1.00000
Test Loss: 0.93576 Accuracy: 0.69633
Epoch: 430 Batch:   0 Loss: 0.45676 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.45626 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.45576 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.45527 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.45478 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.45429 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.45381 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.45332 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.45284 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.45236 Accuracy: 1.00000
Test Loss: 0.93299 Accuracy: 0.69754
Epoch: 440 Batch:   0 Loss: 0.45188 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.45140 Accuracy: 1.00000
Epoch: 442 Batch:   0 Loss: 0.45092 Accuracy: 1.00000
Epoch: 443 Batch:   0 Loss: 0.45044 Accuracy: 1.00000
Epoch: 444 Batch:   0 Loss: 0.44997 Accuracy: 1.00000
Epoch: 445 Batch:   0 Loss: 0.44949 Accuracy: 1.00000
Epoch: 4

[I 2023-11-17 10:47:54,590] Trial 15 finished with value: 0.712605548854041 and parameters: {'REG_W': 3.096750166716164e-06, 'REG_B': 0.00027084498079148574, 'REG_Z': 3.662498911469182e-05, 'SPAR_W': 0.6223364353859038, 'SPAR_B': 0.7503984140113251, 'SPAR_Z': 0.6836677974438338, 'LEARNING_RATE': 0.0007492425529136179, 'NUM_EPOCHS': 565}. Best is trial 14 with value: 0.756031363088058.


Epoch:   0 Batch:   0 Loss: 0.05363 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.03570 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 0.02754 Accuracy: 1.00000
Epoch:   3 Batch:   0 Loss: 0.06600 Accuracy: 1.00000
Epoch:   4 Batch:   0 Loss: 0.45104 Accuracy: 1.00000
Epoch:   5 Batch:   0 Loss: 1.07294 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.49025 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.70909 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.81781 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.87126 Accuracy: 0.00000
Test Loss: 1.09371 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.89728 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.90888 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.91240 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.91107 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.90652 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.89988 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.89145 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.88173 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.79743 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.79743 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.79745 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.79746 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.79745 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.79745 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.79740 Accuracy: 0.00000
Test Loss: 0.74437 Accuracy: 0.50070
Epoch: 150 Batch:   0 Loss: 0.79736 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.79729 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.79720 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.79713 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.79707 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.79703 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.79698 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.79694 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.79688 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.79681 Accuracy: 0.00000
Test Loss: 0.73711 Accuracy: 0.50070
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.78800 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.78794 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.78788 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.78783 Accuracy: 0.00000
Test Loss: 0.68990 Accuracy: 0.50070
Epoch: 290 Batch:   0 Loss: 0.78777 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.78772 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.78766 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.78758 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.78751 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.78741 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.78732 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.78723 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.78713 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.78704 Accuracy: 0.00000
Test Loss: 0.68830 Accuracy: 0.50070
Epoch: 300 Batch:   0 Loss: 0.78695 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.78686 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.78677 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.77521 Accuracy: 0.00000
Test Loss: 0.67677 Accuracy: 0.50060
Epoch: 430 Batch:   0 Loss: 0.77503 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.77485 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.77468 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.77453 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.77438 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.77425 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.77413 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.77403 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.77394 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.77387 Accuracy: 0.00000
Test Loss: 0.67588 Accuracy: 0.50060
Epoch: 440 Batch:   0 Loss: 0.77381 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.77376 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.77373 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.77370 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.77368 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.77368 Accuracy: 0.00000
Epoch: 4

Epoch: 571 Batch:   0 Loss: 0.75216 Accuracy: 0.00000
Epoch: 572 Batch:   0 Loss: 0.75210 Accuracy: 0.00000
Epoch: 573 Batch:   0 Loss: 0.75204 Accuracy: 0.00000
Epoch: 574 Batch:   0 Loss: 0.75198 Accuracy: 0.00000
Epoch: 575 Batch:   0 Loss: 0.75193 Accuracy: 0.00000
Epoch: 576 Batch:   0 Loss: 0.75187 Accuracy: 0.00000
Epoch: 577 Batch:   0 Loss: 0.75181 Accuracy: 0.00000
Epoch: 578 Batch:   0 Loss: 0.75176 Accuracy: 0.00000
Epoch: 579 Batch:   0 Loss: 0.75171 Accuracy: 0.00000
Test Loss: 0.68335 Accuracy: 0.50080
Epoch: 580 Batch:   0 Loss: 0.75165 Accuracy: 0.00000
Epoch: 581 Batch:   0 Loss: 0.75160 Accuracy: 0.00000
Epoch: 582 Batch:   0 Loss: 0.75155 Accuracy: 0.00000
Epoch: 583 Batch:   0 Loss: 0.75150 Accuracy: 0.00000
Epoch: 584 Batch:   0 Loss: 0.75145 Accuracy: 0.00000
Epoch: 585 Batch:   0 Loss: 0.75141 Accuracy: 0.00000
Epoch: 586 Batch:   0 Loss: 0.75136 Accuracy: 0.00000
Epoch: 587 Batch:   0 Loss: 0.75131 Accuracy: 0.00000
Epoch: 588 Batch:   0 Loss: 0.75126 Accuracy:

[I 2023-11-17 10:53:20,668] Trial 16 finished with value: 0.5006031363088058 and parameters: {'REG_W': 2.9921967736687803e-06, 'REG_B': 0.0017399308847832016, 'REG_Z': 3.73477848815666e-05, 'SPAR_W': 0.6251354755252909, 'SPAR_B': 0.6131562427040257, 'SPAR_Z': 0.6032833747352059, 'LEARNING_RATE': 0.00042360329673525035, 'NUM_EPOCHS': 600}. Best is trial 14 with value: 0.756031363088058.


Epoch:   0 Batch:   0 Loss: 1.76903 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 1.85262 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.89864 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.91419 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.91487 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.90810 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.89684 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.88228 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.86532 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.84615 Accuracy: 0.00000
Test Loss: 0.98552 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.82486 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.80212 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.77745 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.75155 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.72473 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.69662 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.66896 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.64190 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.56690 Accuracy: 1.00000
Epoch: 144 Batch:   0 Loss: 0.56548 Accuracy: 1.00000
Epoch: 145 Batch:   0 Loss: 0.56408 Accuracy: 1.00000
Epoch: 146 Batch:   0 Loss: 0.56270 Accuracy: 1.00000
Epoch: 147 Batch:   0 Loss: 0.56132 Accuracy: 1.00000
Epoch: 148 Batch:   0 Loss: 0.55997 Accuracy: 1.00000
Epoch: 149 Batch:   0 Loss: 0.55863 Accuracy: 1.00000
Test Loss: 0.76969 Accuracy: 0.68024
Epoch: 150 Batch:   0 Loss: 0.55730 Accuracy: 1.00000
Epoch: 151 Batch:   0 Loss: 0.55599 Accuracy: 1.00000
Epoch: 152 Batch:   0 Loss: 0.55467 Accuracy: 1.00000
Epoch: 153 Batch:   0 Loss: 0.55334 Accuracy: 1.00000
Epoch: 154 Batch:   0 Loss: 0.55202 Accuracy: 1.00000
Epoch: 155 Batch:   0 Loss: 0.55071 Accuracy: 1.00000
Epoch: 156 Batch:   0 Loss: 0.54941 Accuracy: 1.00000
Epoch: 157 Batch:   0 Loss: 0.54808 Accuracy: 1.00000
Epoch: 158 Batch:   0 Loss: 0.54682 Accuracy: 1.00000
Epoch: 159 Batch:   0 Loss: 0.54557 Accuracy: 1.00000
Test Loss: 0.76477 Accuracy: 0.68657
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.43538 Accuracy: 1.00000
Epoch: 287 Batch:   0 Loss: 0.43474 Accuracy: 1.00000
Epoch: 288 Batch:   0 Loss: 0.43412 Accuracy: 1.00000
Epoch: 289 Batch:   0 Loss: 0.43348 Accuracy: 1.00000
Test Loss: 0.72597 Accuracy: 0.73138
Epoch: 290 Batch:   0 Loss: 0.43283 Accuracy: 1.00000
Epoch: 291 Batch:   0 Loss: 0.43220 Accuracy: 1.00000
Epoch: 292 Batch:   0 Loss: 0.43156 Accuracy: 1.00000
Epoch: 293 Batch:   0 Loss: 0.43092 Accuracy: 1.00000
Epoch: 294 Batch:   0 Loss: 0.43029 Accuracy: 1.00000
Epoch: 295 Batch:   0 Loss: 0.42966 Accuracy: 1.00000
Epoch: 296 Batch:   0 Loss: 0.42904 Accuracy: 1.00000
Epoch: 297 Batch:   0 Loss: 0.42841 Accuracy: 1.00000
Epoch: 298 Batch:   0 Loss: 0.42778 Accuracy: 1.00000
Epoch: 299 Batch:   0 Loss: 0.42717 Accuracy: 1.00000
Test Loss: 0.72438 Accuracy: 0.73329
Epoch: 300 Batch:   0 Loss: 0.42655 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.42593 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.42531 Accuracy: 1.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.36484 Accuracy: 1.00000
Test Loss: 0.70891 Accuracy: 0.74725
Epoch: 430 Batch:   0 Loss: 0.36445 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.36405 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.36366 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.36326 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.36287 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.36248 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.36209 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.36171 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.36132 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.36094 Accuracy: 1.00000
Test Loss: 0.70797 Accuracy: 0.74826
Epoch: 440 Batch:   0 Loss: 0.36056 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.36018 Accuracy: 1.00000
Epoch: 442 Batch:   0 Loss: 0.35980 Accuracy: 1.00000
Epoch: 443 Batch:   0 Loss: 0.35943 Accuracy: 1.00000
Epoch: 444 Batch:   0 Loss: 0.35904 Accuracy: 1.00000
Epoch: 445 Batch:   0 Loss: 0.35897 Accuracy: 1.00000
Epoch: 4

[I 2023-11-17 10:58:06,766] Trial 17 finished with value: 0.7543224768797748 and parameters: {'REG_W': 3.5406815975622515e-06, 'REG_B': 0.00016864863899221193, 'REG_Z': 3.381830624969882e-05, 'SPAR_W': 0.6256238933757005, 'SPAR_B': 0.7487258140054214, 'SPAR_Z': 0.7340778613007167, 'LEARNING_RATE': 0.00031638933153247143, 'NUM_EPOCHS': 541}. Best is trial 14 with value: 0.756031363088058.


Epoch:   0 Batch:   0 Loss: 0.12972 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.12719 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 0.67977 Accuracy: 1.00000
Epoch:   3 Batch:   0 Loss: 1.41454 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.72430 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.82880 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.86169 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.86994 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.86907 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.86433 Accuracy: 0.00000
Test Loss: 1.00016 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.85754 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.84953 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.84067 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.83124 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.82150 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.81160 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.80164 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.79168 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.78629 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.78659 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.78694 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.78732 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.78770 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.78811 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.78847 Accuracy: 0.00000
Test Loss: 0.78890 Accuracy: 0.50060
Epoch: 150 Batch:   0 Loss: 0.78888 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.78933 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.78975 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.79019 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.79064 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.79107 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.79150 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.79192 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.79235 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.79278 Accuracy: 0.00000
Test Loss: 0.77564 Accuracy: 0.50060
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.80002 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.79985 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.79967 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.79950 Accuracy: 0.00000
Test Loss: 0.70459 Accuracy: 0.50050
Epoch: 290 Batch:   0 Loss: 0.79933 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.79917 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.79902 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.79889 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.79876 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.79862 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.79849 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.79835 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.79823 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.79812 Accuracy: 0.00000
Test Loss: 0.70226 Accuracy: 0.50050
Epoch: 300 Batch:   0 Loss: 0.79798 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.79782 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.79769 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.79020 Accuracy: 0.00000
Test Loss: 0.68198 Accuracy: 0.50050
Epoch: 430 Batch:   0 Loss: 0.79023 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.79026 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.79028 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.79027 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.79027 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.79025 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.79024 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.79022 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.79018 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.79015 Accuracy: 0.00000
Test Loss: 0.68072 Accuracy: 0.50050
Epoch: 440 Batch:   0 Loss: 0.79010 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.79005 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.79000 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.78995 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.78989 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.78983 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 11:03:01,976] Trial 18 finished with value: 0.5002010454362686 and parameters: {'REG_W': 3.582384022396309e-06, 'REG_B': 0.004066284365854493, 'REG_Z': 4.028592909037885e-05, 'SPAR_W': 0.5073628275195186, 'SPAR_B': 0.7855526698024307, 'SPAR_Z': 0.7414618993232829, 'LEARNING_RATE': 0.00032117831832901276, 'NUM_EPOCHS': 537}. Best is trial 14 with value: 0.756031363088058.


Epoch:   0 Batch:   0 Loss: 0.07572 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.10939 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 0.67295 Accuracy: 1.00000
Epoch:   3 Batch:   0 Loss: 1.33738 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.61726 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.71323 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.74384 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.75130 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.74954 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.74347 Accuracy: 0.00000
Test Loss: 0.94540 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.73490 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.72466 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.71330 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.70128 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.68896 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.67651 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.66414 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.65205 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.78622 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.78607 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.78599 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.78597 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.78600 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.78605 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.78617 Accuracy: 0.00000
Test Loss: 0.80368 Accuracy: 0.50050
Epoch: 150 Batch:   0 Loss: 0.78633 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.78646 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.78658 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.78670 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.78686 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.78697 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.78706 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.78714 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.78722 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.78734 Accuracy: 0.00000
Test Loss: 0.78680 Accuracy: 0.50060
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.81015 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.81013 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.81008 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.81005 Accuracy: 0.00000
Test Loss: 0.69496 Accuracy: 0.50070
Epoch: 290 Batch:   0 Loss: 0.81002 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.80997 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.80992 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.80987 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.80982 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.80976 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.80969 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.80962 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.80955 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.80947 Accuracy: 0.00000
Test Loss: 0.69268 Accuracy: 0.50070
Epoch: 300 Batch:   0 Loss: 0.80939 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.80932 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.80923 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.79171 Accuracy: 0.00000
Test Loss: 0.67857 Accuracy: 0.50070
Epoch: 430 Batch:   0 Loss: 0.79158 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.79145 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.79132 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.79120 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.79107 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.79096 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.79084 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.79073 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.79061 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.79050 Accuracy: 0.00000
Test Loss: 0.67838 Accuracy: 0.50070
Epoch: 440 Batch:   0 Loss: 0.79039 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.79027 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.79017 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.79006 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.78996 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.78985 Accuracy: 0.00000
Epoch: 4

Epoch: 571 Batch:   0 Loss: 0.78310 Accuracy: 0.00000
Epoch: 572 Batch:   0 Loss: 0.78304 Accuracy: 0.00000
Epoch: 573 Batch:   0 Loss: 0.78299 Accuracy: 0.00000
Epoch: 574 Batch:   0 Loss: 0.78295 Accuracy: 0.00000
Epoch: 575 Batch:   0 Loss: 0.78291 Accuracy: 0.00000
Epoch: 576 Batch:   0 Loss: 0.78287 Accuracy: 0.00000
Epoch: 577 Batch:   0 Loss: 0.78284 Accuracy: 0.00000
Epoch: 578 Batch:   0 Loss: 0.78281 Accuracy: 0.00000
Epoch: 579 Batch:   0 Loss: 0.78278 Accuracy: 0.00000
Test Loss: 0.67207 Accuracy: 0.50050
Epoch: 580 Batch:   0 Loss: 0.78275 Accuracy: 0.00000
Epoch: 581 Batch:   0 Loss: 0.78273 Accuracy: 0.00000
Epoch: 582 Batch:   0 Loss: 0.78270 Accuracy: 0.00000
Epoch: 583 Batch:   0 Loss: 0.78268 Accuracy: 0.00000
Epoch: 584 Batch:   0 Loss: 0.78265 Accuracy: 0.00000
Epoch: 585 Batch:   0 Loss: 0.78261 Accuracy: 0.00000
Epoch: 586 Batch:   0 Loss: 0.78257 Accuracy: 0.00000


[I 2023-11-17 11:12:19,738] Trial 19 finished with value: 0.5003015681544029 and parameters: {'REG_W': 3.5816754063519616e-06, 'REG_B': 0.0019308848784722011, 'REG_Z': 3.2192243443148886e-05, 'SPAR_W': 0.6279903470101106, 'SPAR_B': 0.8571723793461878, 'SPAR_Z': 0.759147611268914, 'LEARNING_RATE': 0.0002724335761200872, 'NUM_EPOCHS': 587}. Best is trial 14 with value: 0.756031363088058.


In [14]:
study.best_trials

[FrozenTrial(number=14, state=TrialState.COMPLETE, values=[0.756031363088058], datetime_start=datetime.datetime(2023, 11, 17, 10, 37, 24, 632687), datetime_complete=datetime.datetime(2023, 11, 17, 10, 42, 47, 287661), params={'REG_W': 3.1545272147130644e-06, 'REG_B': 0.00025738817748430687, 'REG_Z': 3.534809578054365e-05, 'SPAR_W': 0.6622823978261199, 'SPAR_B': 0.651918182305612, 'SPAR_Z': 0.6197401480486052, 'LEARNING_RATE': 0.00029791501581002825, 'NUM_EPOCHS': 595}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'REG_W': FloatDistribution(high=5e-06, log=False, low=2e-06, step=None), 'REG_B': FloatDistribution(high=0.01, log=False, low=0.0, step=None), 'REG_Z': FloatDistribution(high=5e-05, log=False, low=2e-05, step=None), 'SPAR_W': FloatDistribution(high=1.0, log=False, low=0.5, step=None), 'SPAR_B': FloatDistribution(high=1.0, log=False, low=0.5, step=None), 'SPAR_Z': FloatDistribution(high=1.0, log=False, low=0.5, step=None), 'LEARNING_RATE': FloatDistrib

# Model Training

In [None]:

# Setup input and train protoNN


X = tf.placeholder(tf.float32, [None, dataDimension], name='X')
Y = tf.placeholder(tf.float32, [None, numClasses], name='Y')
protoNN = ProtoNN(dataDimension, PROJECTION_DIM,
                  NUM_PROTOTYPES, numClasses,
                  gamma, W=W, B=B)
trainer = ProtoNNTrainer(protoNN, REG_W, REG_B, REG_Z,
                         SPAR_W, SPAR_B, SPAR_Z,
                         LEARNING_RATE, X, Y, lossType='xentropy')
sess = tf.Session()

trainer.train(BATCH_SIZE, NUM_EPOCHS, sess, x_train, x_test, y_train, y_test,
              printStep=600, valStep=10)


# Model Evaluation

In [None]:
acc = sess.run(protoNN.accuracy, feed_dict={X: x_test, Y: y_test})
pred = sess.run(protoNN.predictions, feed_dict={X: x_test, Y: y_test})
# W, B, Z are tensorflow graph nodes
W, B, Z, _ = protoNN.getModelMatrices()
matrixList = sess.run([W, B, Z])
sparcityList = [SPAR_W, SPAR_B, SPAR_Z]                       
nnz, size, sparse = getModelSize(matrixList, sparcityList)
print("Final test accuracy", acc)
print("Model size constraint (Bytes): ", size)
print("Number of non-zeros: ", nnz)

In [None]:
from sklearn.metrics import confusion_matrix,classification_report
y_test = np.argmax(y_test,axis=1)
print (confusion_matrix(y_test,pred))
print (classification_report(y_test,pred,digits=5))

In [None]:
sensitivity = confusion_matrix(y_test,pred)[1][1]/(confusion_matrix(y_test,pred)[1][1] + confusion_matrix(y_test,pred)[1][0])
sensitivity

In [None]:
specificity = confusion_matrix(y_test,pred)[0][0]/(confusion_matrix(y_test,pred)[0][0] + confusion_matrix(y_test,pred)[0][1])
specificity

In [None]:
# accuracy = (TP + TN) / (TP + TN + FP + FN) 
# Precision = TP / (TP + FP)
# Recall = TP / (TP + FN)
# sensitivity = TP / (TP + FN) 
# specificity = TN / (TN + FP) 
# positive predictive value (PPV) = TP / (TP + FP)